diff --git a/client/cmd/bootnode/main.go b/client/cmd/bootnode/main.go index 189cc9c0e..12ea1d748 100644 --- a/client/cmd/bootnode/main.go +++ b/client/cmd/bootnode/main.go @@ -4,6 +4,7 @@ import ( "crypto/ecdsa" "flag" "fmt" + "net" "os" "github.com/kowala-tech/kcoin/client/cmd/utils" @@ -79,12 +80,37 @@ func main() { } } + addr, err := net.ResolveUDPAddr("udp", *listenAddr) + if err != nil { + utils.Fatalf("-ResolveUDPAddr: %v", err) + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + utils.Fatalf("-ListenUDP: %v", err) + } + + realaddr := conn.LocalAddr().(*net.UDPAddr) + if natm != nil { + if !realaddr.IP.IsLoopback() { + go nat.Map(natm, nil, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") + } + // TODO: react to external IP changes over time. + if ext, err := natm.ExternalIP(); err == nil { + realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port} + } + } + if *runv5 { - if _, err := discv5.ListenUDP(nodeKey, *listenAddr, natm, "", restrictList); err != nil { + if _, err := discv5.ListenUDP(nodeKey, conn, realaddr, "", restrictList); err != nil { utils.Fatalf("%v", err) } } else { - if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm, "", restrictList); err != nil { + cfg := discover.Config{ + PrivateKey: nodeKey, + AnnounceAddr: realaddr, + NetRestrict: restrictList, + } + if _, err := discover.ListenUDP(conn, cfg); err != nil { utils.Fatalf("%v", err) } } diff --git a/client/crypto/crypto.go b/client/crypto/crypto.go index e5bd6df90..a881e8ef7 100644 --- a/client/crypto/crypto.go +++ b/client/crypto/crypto.go @@ -35,8 +35,8 @@ import ( ) var ( - secp256k1_N, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) - secp256k1_halfN = new(big.Int).Div(secp256k1_N, big.NewInt(2)) + secp256k1N, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) + secp256k1halfN = new(big.Int).Div(secp256k1N, big.NewInt(2)) ) // Keccak256 calculates and returns the Keccak256 hash of the input data. @@ -68,7 +68,7 @@ func Keccak512(data ...[]byte) []byte { return d.Sum(nil) } -// Creates an ethereum address given the bytes and the nonce +// CreateAddress creates an ethereum address given the bytes and the nonce func CreateAddress(b common.Address, nonce uint64) common.Address { data, _ := rlp.EncodeToBytes([]interface{}{b, nonce}) return common.BytesToAddress(Keccak256(data)[12:]) @@ -79,7 +79,7 @@ func ToECDSA(d []byte) (*ecdsa.PrivateKey, error) { return toECDSA(d, true) } -// ToECDSAUnsafe blidly converts a binary blob to a private key. It should almost +// ToECDSAUnsafe blindly converts a binary blob to a private key. It should almost // never be used unless you are sure the input is valid and want to avoid hitting // errors due to bad origin encoding (0 prefixes cut off). func ToECDSAUnsafe(d []byte) *ecdsa.PrivateKey { @@ -97,7 +97,20 @@ func toECDSA(d []byte, strict bool) (*ecdsa.PrivateKey, error) { return nil, fmt.Errorf("invalid length, need %d bits", priv.Params().BitSize) } priv.D = new(big.Int).SetBytes(d) + + // The priv.D must < N + if priv.D.Cmp(secp256k1N) >= 0 { + return nil, fmt.Errorf("invalid private key, >=N") + } + // The priv.D must not be zero or negative. + if priv.D.Sign() <= 0 { + return nil, fmt.Errorf("invalid private key, zero or negative") + } + priv.PublicKey.X, priv.PublicKey.Y = priv.PublicKey.Curve.ScalarBaseMult(d) + if priv.PublicKey.X == nil { + return nil, errors.New("invalid private key") + } return priv, nil } @@ -171,11 +184,11 @@ func ValidateSignatureValues(v byte, r, s *big.Int, homestead bool) bool { } // reject upper range of s values (ECDSA malleability) // see discussion in secp256k1/libsecp256k1/include/secp256k1.h - if homestead && s.Cmp(secp256k1_halfN) > 0 { + if homestead && s.Cmp(secp256k1halfN) > 0 { return false } // Frontier: allow s to be in full N range - return r.Cmp(secp256k1_N) < 0 && s.Cmp(secp256k1_N) < 0 && (v == 0 || v == 1) + return r.Cmp(secp256k1N) < 0 && s.Cmp(secp256k1N) < 0 && (v == 0 || v == 1) } func PubkeyToAddress(p ecdsa.PublicKey) common.Address { diff --git a/client/crypto/crypto_test.go b/client/crypto/crypto_test.go index b8cde049a..e2cf7c9fe 100644 --- a/client/crypto/crypto_test.go +++ b/client/crypto/crypto_test.go @@ -20,12 +20,10 @@ import ( "bytes" "crypto/ecdsa" "encoding/hex" - "fmt" "io/ioutil" "math/big" "os" "testing" - "time" "github.com/kowala-tech/kcoin/client/common" ) @@ -42,15 +40,20 @@ func TestKeccak256Hash(t *testing.T) { checkhash(t, "Sha3-256-array", func(in []byte) []byte { h := Keccak256Hash(in); return h[:] }, msg, exp) } +func TestToECDSAErrors(t *testing.T) { + if _, err := HexToECDSA("0000000000000000000000000000000000000000000000000000000000000000"); err == nil { + t.Fatal("HexToECDSA should've returned error") + } + if _, err := HexToECDSA("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"); err == nil { + t.Fatal("HexToECDSA should've returned error") + } +} + func BenchmarkSha3(b *testing.B) { a := []byte("hello world") - amount := 1000000 - start := time.Now() - for i := 0; i < amount; i++ { + for i := 0; i < b.N; i++ { Keccak256(a) } - - fmt.Println(amount, ":", time.Since(start)) } func TestSign(t *testing.T) { @@ -151,7 +154,7 @@ func TestValidateSignatureValues(t *testing.T) { minusOne := big.NewInt(-1) one := common.Big1 zero := common.Big0 - secp256k1nMinus1 := new(big.Int).Sub(secp256k1_N, common.Big1) + secp256k1nMinus1 := new(big.Int).Sub(secp256k1N, common.Big1) // correct v,r,s check(true, 0, one, one) @@ -178,9 +181,9 @@ func TestValidateSignatureValues(t *testing.T) { // correct sig with max r,s check(true, 0, secp256k1nMinus1, secp256k1nMinus1) // correct v, combinations of incorrect r,s at upper limit - check(false, 0, secp256k1_N, secp256k1nMinus1) - check(false, 0, secp256k1nMinus1, secp256k1_N) - check(false, 0, secp256k1_N, secp256k1_N) + check(false, 0, secp256k1N, secp256k1nMinus1) + check(false, 0, secp256k1nMinus1, secp256k1N) + check(false, 0, secp256k1N, secp256k1N) // current callers ensures r,s cannot be negative, but let's test for that too // as crypto package could be used stand-alone diff --git a/client/crypto/ecies/ecies.go b/client/crypto/ecies/ecies.go index 1d5f96ed2..147418148 100644 --- a/client/crypto/ecies/ecies.go +++ b/client/crypto/ecies/ecies.go @@ -230,7 +230,7 @@ func symEncrypt(rand io.Reader, params *ECIESParams, key, m []byte) (ct []byte, // symDecrypt carries out CTR decryption using the block cipher specified in // the parameters -func symDecrypt(rand io.Reader, params *ECIESParams, key, ct []byte) (m []byte, err error) { +func symDecrypt(params *ECIESParams, key, ct []byte) (m []byte, err error) { c, err := params.Cipher(key) if err != nil { return @@ -292,7 +292,7 @@ func Encrypt(rand io.Reader, pub *PublicKey, m, s1, s2 []byte) (ct []byte, err e } // Decrypt decrypts an ECIES ciphertext. -func (prv *PrivateKey) Decrypt(rand io.Reader, c, s1, s2 []byte) (m []byte, err error) { +func (prv *PrivateKey) Decrypt(c, s1, s2 []byte) (m []byte, err error) { if len(c) == 0 { return nil, ErrInvalidMessage } @@ -314,7 +314,7 @@ func (prv *PrivateKey) Decrypt(rand io.Reader, c, s1, s2 []byte) (m []byte, err switch c[0] { case 2, 3, 4: - rLen = ((prv.PublicKey.Curve.Params().BitSize + 7) / 4) + rLen = (prv.PublicKey.Curve.Params().BitSize + 7) / 4 if len(c) < (rLen + hLen + 1) { err = ErrInvalidMessage return @@ -361,6 +361,6 @@ func (prv *PrivateKey) Decrypt(rand io.Reader, c, s1, s2 []byte) (m []byte, err return } - m, err = symDecrypt(rand, params, Ke, c[mStart:mEnd]) + m, err = symDecrypt(params, Ke, c[mStart:mEnd]) return } diff --git a/client/crypto/secp256k1/curve.go b/client/crypto/secp256k1/curve.go index 7024e1169..ffdcaf244 100644 --- a/client/crypto/secp256k1/curve.go +++ b/client/crypto/secp256k1/curve.go @@ -34,7 +34,6 @@ package secp256k1 import ( "crypto/elliptic" "math/big" - "sync" "unsafe" "github.com/kowala-tech/kcoin/client/common/math" @@ -42,7 +41,7 @@ import ( /* #include "libsecp256k1/include/secp256k1.h" -extern int secp256k1_pubkey_scalar_mul(const secp256k1_context* ctx, const unsigned char *point, const unsigned char *scalar); +extern int secp256k1_ext_scalar_mul(const secp256k1_context* ctx, const unsigned char *point, const unsigned char *scalar); */ import "C" @@ -78,7 +77,7 @@ func (BitCurve *BitCurve) Params() *elliptic.CurveParams { } } -// IsOnBitCurve returns true if the given (x,y) lies on the BitCurve. +// IsOnCurve returns true if the given (x,y) lies on the BitCurve. func (BitCurve *BitCurve) IsOnCurve(x, y *big.Int) bool { // y² = x³ + b y2 := new(big.Int).Mul(y, y) //y² @@ -236,7 +235,7 @@ func (BitCurve *BitCurve) ScalarMult(Bx, By *big.Int, scalar []byte) (*big.Int, math.ReadBits(By, point[32:]) pointPtr := (*C.uchar)(unsafe.Pointer(&point[0])) scalarPtr := (*C.uchar)(unsafe.Pointer(&scalar[0])) - res := C.secp256k1_pubkey_scalar_mul(context, pointPtr, scalarPtr) + res := C.secp256k1_ext_scalar_mul(context, pointPtr, scalarPtr) // Unpack the result and clear temporaries. x := new(big.Int).SetBytes(point[:32]) @@ -263,14 +262,10 @@ func (BitCurve *BitCurve) ScalarBaseMult(k []byte) (*big.Int, *big.Int) { // X9.62. func (BitCurve *BitCurve) Marshal(x, y *big.Int) []byte { byteLen := (BitCurve.BitSize + 7) >> 3 - ret := make([]byte, 1+2*byteLen) - ret[0] = 4 // uncompressed point - - xBytes := x.Bytes() - copy(ret[1+byteLen-len(xBytes):], xBytes) - yBytes := y.Bytes() - copy(ret[1+2*byteLen-len(yBytes):], yBytes) + ret[0] = 4 // uncompressed point flag + math.ReadBits(x, ret[1:1+byteLen]) + math.ReadBits(y, ret[1+byteLen:]) return ret } @@ -289,24 +284,21 @@ func (BitCurve *BitCurve) Unmarshal(data []byte) (x, y *big.Int) { return } -var ( - initonce sync.Once - theCurve *BitCurve -) +var theCurve = new(BitCurve) + +func init() { + // See SEC 2 section 2.7.1 + // curve parameters taken from: + // http://www.secg.org/collateral/sec2_final.pdf + theCurve.P = math.MustParseBig256("0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F") + theCurve.N = math.MustParseBig256("0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141") + theCurve.B = math.MustParseBig256("0x0000000000000000000000000000000000000000000000000000000000000007") + theCurve.Gx = math.MustParseBig256("0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798") + theCurve.Gy = math.MustParseBig256("0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8") + theCurve.BitSize = 256 +} -// S256 returns a BitCurve which implements secp256k1 (see SEC 2 section 2.7.1) +// S256 returns a BitCurve which implements secp256k1. func S256() *BitCurve { - initonce.Do(func() { - // See SEC 2 section 2.7.1 - // curve parameters taken from: - // http://www.secg.org/collateral/sec2_final.pdf - theCurve = new(BitCurve) - theCurve.P, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F", 16) - theCurve.N, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141", 16) - theCurve.B, _ = new(big.Int).SetString("0000000000000000000000000000000000000000000000000000000000000007", 16) - theCurve.Gx, _ = new(big.Int).SetString("79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798", 16) - theCurve.Gy, _ = new(big.Int).SetString("483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8", 16) - theCurve.BitSize = 256 - }) return theCurve } diff --git a/client/crypto/secp256k1/ext.h b/client/crypto/secp256k1/ext.h index ee759fde6..9b043c724 100644 --- a/client/crypto/secp256k1/ext.h +++ b/client/crypto/secp256k1/ext.h @@ -19,7 +19,7 @@ static secp256k1_context* secp256k1_context_create_sign_verify() { return secp256k1_context_create(SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_VERIFY); } -// secp256k1_ecdsa_recover_pubkey recovers the public key of an encoded compact signature. +// secp256k1_ext_ecdsa_recover recovers the public key of an encoded compact signature. // // Returns: 1: recovery was successful // 0: recovery was not successful @@ -27,7 +27,7 @@ static secp256k1_context* secp256k1_context_create_sign_verify() { // Out: pubkey_out: the serialized 65-byte public key of the signer (cannot be NULL) // In: sigdata: pointer to a 65-byte signature with the recovery id at the end (cannot be NULL) // msgdata: pointer to a 32-byte message (cannot be NULL) -static int secp256k1_ecdsa_recover_pubkey( +static int secp256k1_ext_ecdsa_recover( const secp256k1_context* ctx, unsigned char *pubkey_out, const unsigned char *sigdata, @@ -46,7 +46,62 @@ static int secp256k1_ecdsa_recover_pubkey( return secp256k1_ec_pubkey_serialize(ctx, pubkey_out, &outputlen, &pubkey, SECP256K1_EC_UNCOMPRESSED); } -// secp256k1_pubkey_scalar_mul multiplies a point by a scalar in constant time. +// secp256k1_ext_ecdsa_verify verifies an encoded compact signature. +// +// Returns: 1: signature is valid +// 0: signature is invalid +// Args: ctx: pointer to a context object (cannot be NULL) +// In: sigdata: pointer to a 64-byte signature (cannot be NULL) +// msgdata: pointer to a 32-byte message (cannot be NULL) +// pubkeydata: pointer to public key data (cannot be NULL) +// pubkeylen: length of pubkeydata +static int secp256k1_ext_ecdsa_verify( + const secp256k1_context* ctx, + const unsigned char *sigdata, + const unsigned char *msgdata, + const unsigned char *pubkeydata, + size_t pubkeylen +) { + secp256k1_ecdsa_signature sig; + secp256k1_pubkey pubkey; + + if (!secp256k1_ecdsa_signature_parse_compact(ctx, &sig, sigdata)) { + return 0; + } + if (!secp256k1_ec_pubkey_parse(ctx, &pubkey, pubkeydata, pubkeylen)) { + return 0; + } + return secp256k1_ecdsa_verify(ctx, &sig, msgdata, &pubkey); +} + +// secp256k1_ext_reencode_pubkey decodes then encodes a public key. It can be used to +// convert between public key formats. The input/output formats are chosen depending on the +// length of the input/output buffers. +// +// Returns: 1: conversion successful +// 0: conversion unsuccessful +// Args: ctx: pointer to a context object (cannot be NULL) +// Out: out: output buffer that will contain the reencoded key (cannot be NULL) +// In: outlen: length of out (33 for compressed keys, 65 for uncompressed keys) +// pubkeydata: the input public key (cannot be NULL) +// pubkeylen: length of pubkeydata +static int secp256k1_ext_reencode_pubkey( + const secp256k1_context* ctx, + unsigned char *out, + size_t outlen, + const unsigned char *pubkeydata, + size_t pubkeylen +) { + secp256k1_pubkey pubkey; + + if (!secp256k1_ec_pubkey_parse(ctx, &pubkey, pubkeydata, pubkeylen)) { + return 0; + } + unsigned int flag = (outlen == 33) ? SECP256K1_EC_COMPRESSED : SECP256K1_EC_UNCOMPRESSED; + return secp256k1_ec_pubkey_serialize(ctx, out, &outlen, &pubkey, flag); +} + +// secp256k1_ext_scalar_mul multiplies a point by a scalar in constant time. // // Returns: 1: multiplication was successful // 0: scalar was invalid (zero or overflow) @@ -55,7 +110,7 @@ static int secp256k1_ecdsa_recover_pubkey( // In: point: pointer to a 64-byte public point, // encoded as two 256bit big-endian numbers. // scalar: a 32-byte scalar with which to multiply the point -int secp256k1_pubkey_scalar_mul(const secp256k1_context* ctx, unsigned char *point, const unsigned char *scalar) { +int secp256k1_ext_scalar_mul(const secp256k1_context* ctx, unsigned char *point, const unsigned char *scalar) { int ret = 0; int overflow = 0; secp256k1_fe feX, feY; diff --git a/client/crypto/secp256k1/secp256.go b/client/crypto/secp256k1/secp256.go index 0ffa04fe0..eefbb99ee 100644 --- a/client/crypto/secp256k1/secp256.go +++ b/client/crypto/secp256k1/secp256.go @@ -38,6 +38,7 @@ import "C" import ( "errors" + "math/big" "unsafe" ) @@ -55,6 +56,7 @@ var ( ErrInvalidSignatureLen = errors.New("invalid signature length") ErrInvalidRecoveryID = errors.New("invalid signature recovery id") ErrInvalidKey = errors.New("invalid private key") + ErrInvalidPubkey = errors.New("invalid public key") ErrSignFailed = errors.New("signing failed") ErrRecoverFailed = errors.New("recovery failed") ) @@ -113,12 +115,59 @@ func RecoverPubkey(msg []byte, sig []byte) ([]byte, error) { sigdata = (*C.uchar)(unsafe.Pointer(&sig[0])) msgdata = (*C.uchar)(unsafe.Pointer(&msg[0])) ) - if C.secp256k1_ecdsa_recover_pubkey(context, (*C.uchar)(unsafe.Pointer(&pubkey[0])), sigdata, msgdata) == 0 { + if C.secp256k1_ext_ecdsa_recover(context, (*C.uchar)(unsafe.Pointer(&pubkey[0])), sigdata, msgdata) == 0 { return nil, ErrRecoverFailed } return pubkey, nil } +// VerifySignature checks that the given pubkey created signature over message. +// The signature should be in [R || S] format. +func VerifySignature(pubkey, msg, signature []byte) bool { + if len(msg) != 32 || len(signature) != 64 || len(pubkey) == 0 { + return false + } + sigdata := (*C.uchar)(unsafe.Pointer(&signature[0])) + msgdata := (*C.uchar)(unsafe.Pointer(&msg[0])) + keydata := (*C.uchar)(unsafe.Pointer(&pubkey[0])) + return C.secp256k1_ext_ecdsa_verify(context, sigdata, msgdata, keydata, C.size_t(len(pubkey))) != 0 +} + +// DecompressPubkey parses a public key in the 33-byte compressed format. +// It returns non-nil coordinates if the public key is valid. +func DecompressPubkey(pubkey []byte) (x, y *big.Int) { + if len(pubkey) != 33 { + return nil, nil + } + var ( + pubkeydata = (*C.uchar)(unsafe.Pointer(&pubkey[0])) + pubkeylen = C.size_t(len(pubkey)) + out = make([]byte, 65) + outdata = (*C.uchar)(unsafe.Pointer(&out[0])) + outlen = C.size_t(len(out)) + ) + if C.secp256k1_ext_reencode_pubkey(context, outdata, outlen, pubkeydata, pubkeylen) == 0 { + return nil, nil + } + return new(big.Int).SetBytes(out[1:33]), new(big.Int).SetBytes(out[33:]) +} + +// CompressPubkey encodes a public key to 33-byte compressed format. +func CompressPubkey(x, y *big.Int) []byte { + var ( + pubkey = S256().Marshal(x, y) + pubkeydata = (*C.uchar)(unsafe.Pointer(&pubkey[0])) + pubkeylen = C.size_t(len(pubkey)) + out = make([]byte, 33) + outdata = (*C.uchar)(unsafe.Pointer(&out[0])) + outlen = C.size_t(len(out)) + ) + if C.secp256k1_ext_reencode_pubkey(context, outdata, outlen, pubkeydata, pubkeylen) == 0 { + panic("libsecp256k1 error") + } + return out +} + func checkSignature(sig []byte) error { if len(sig) != 65 { return ErrInvalidSignatureLen diff --git a/client/crypto/secp256k1/secp256_test.go b/client/crypto/secp256k1/secp256_test.go index ebd7bdf92..360e0fa03 100644 --- a/client/crypto/secp256k1/secp256_test.go +++ b/client/crypto/secp256k1/secp256_test.go @@ -49,7 +49,7 @@ func randSig() []byte { // tests for malleability // highest bit of signature ECDSA s value must be 0, in the 33th byte func compactSigCheck(t *testing.T, sig []byte) { - var b int = int(sig[32]) + var b = int(sig[32]) if b < 0 { t.Errorf("highest bit is negative: %d", b) } diff --git a/client/crypto/signature_cgo.go b/client/crypto/signature_cgo.go index 11c1e0212..faf8e09d7 100644 --- a/client/crypto/signature_cgo.go +++ b/client/crypto/signature_cgo.go @@ -27,10 +27,12 @@ import ( "github.com/kowala-tech/kcoin/client/crypto/secp256k1" ) +// Ecrecover returns the uncompressed public key that created the given signature. func Ecrecover(hash, sig []byte) ([]byte, error) { return secp256k1.RecoverPubkey(hash, sig) } +// SigToPub returns the public key that created the given signature. func SigToPub(hash, sig []byte) (*ecdsa.PublicKey, error) { s, err := Ecrecover(hash, sig) if err != nil { @@ -58,6 +60,27 @@ func Sign(hash []byte, prv *ecdsa.PrivateKey) (sig []byte, err error) { return secp256k1.Sign(hash, seckey) } +// VerifySignature checks that the given public key created signature over hash. +// The public key should be in compressed (33 bytes) or uncompressed (65 bytes) format. +// The signature should have the 64 byte [R || S] format. +func VerifySignature(pubkey, hash, signature []byte) bool { + return secp256k1.VerifySignature(pubkey, hash, signature) +} + +// DecompressPubkey parses a public key in the 33-byte compressed format. +func DecompressPubkey(pubkey []byte) (*ecdsa.PublicKey, error) { + x, y := secp256k1.DecompressPubkey(pubkey) + if x == nil { + return nil, fmt.Errorf("invalid public key") + } + return &ecdsa.PublicKey{X: x, Y: y, Curve: S256()}, nil +} + +// CompressPubkey encodes a public key to the 33-byte compressed format. +func CompressPubkey(pubkey *ecdsa.PublicKey) []byte { + return secp256k1.CompressPubkey(pubkey.X, pubkey.Y) +} + // S256 returns an instance of the secp256k1 curve. func S256() elliptic.Curve { return secp256k1.S256() diff --git a/client/crypto/signature_nocgo.go b/client/crypto/signature_nocgo.go index a022eef9a..e8fa18ed4 100644 --- a/client/crypto/signature_nocgo.go +++ b/client/crypto/signature_nocgo.go @@ -21,11 +21,14 @@ package crypto import ( "crypto/ecdsa" "crypto/elliptic" + "errors" "fmt" + "math/big" "github.com/btcsuite/btcd/btcec" ) +// Ecrecover returns the uncompressed public key that created the given signature. func Ecrecover(hash, sig []byte) ([]byte, error) { pub, err := SigToPub(hash, sig) if err != nil { @@ -35,6 +38,7 @@ func Ecrecover(hash, sig []byte) ([]byte, error) { return bytes, err } +// SigToPub returns the public key that created the given signature. func SigToPub(hash, sig []byte) (*ecdsa.PublicKey, error) { // Convert to btcec input format with 'recovery id' v at the beginning. btcsig := make([]byte, 65) @@ -71,6 +75,42 @@ func Sign(hash []byte, prv *ecdsa.PrivateKey) ([]byte, error) { return sig, nil } +// VerifySignature checks that the given public key created signature over hash. +// The public key should be in compressed (33 bytes) or uncompressed (65 bytes) format. +// The signature should have the 64 byte [R || S] format. +func VerifySignature(pubkey, hash, signature []byte) bool { + if len(signature) != 64 { + return false + } + sig := &btcec.Signature{R: new(big.Int).SetBytes(signature[:32]), S: new(big.Int).SetBytes(signature[32:])} + key, err := btcec.ParsePubKey(pubkey, btcec.S256()) + if err != nil { + return false + } + // Reject malleable signatures. libsecp256k1 does this check but btcec doesn't. + if sig.S.Cmp(secp256k1halfN) > 0 { + return false + } + return sig.Verify(hash, key) +} + +// DecompressPubkey parses a public key in the 33-byte compressed format. +func DecompressPubkey(pubkey []byte) (*ecdsa.PublicKey, error) { + if len(pubkey) != 33 { + return nil, errors.New("invalid compressed public key length") + } + key, err := btcec.ParsePubKey(pubkey, btcec.S256()) + if err != nil { + return nil, err + } + return key.ToECDSA(), nil +} + +// CompressPubkey encodes a public key to the 33-byte compressed format. +func CompressPubkey(pubkey *ecdsa.PublicKey) []byte { + return (*btcec.PublicKey)(pubkey).SerializeCompressed() +} + // S256 returns an instance of the secp256k1 curve. func S256() elliptic.Curve { return btcec.S256() diff --git a/client/crypto/signature_test.go b/client/crypto/signature_test.go index aefd9e38d..fb6897ef4 100644 --- a/client/crypto/signature_test.go +++ b/client/crypto/signature_test.go @@ -18,19 +18,143 @@ package crypto import ( "bytes" - "encoding/hex" + "crypto/ecdsa" + "reflect" "testing" + + "github.com/kowala-tech/kcoin/client/common" + "github.com/kowala-tech/kcoin/client/common/hexutil" + "github.com/kowala-tech/kcoin/client/common/math" +) + +var ( + testmsg = hexutil.MustDecode("0xce0677bb30baa8cf067c88db9811f4333d131bf8bcf12fe7065d211dce971008") + testsig = hexutil.MustDecode("0x90f27b8b488db00b00606796d2987f6a5f59ae62ea05effe84fef5b8b0e549984a691139ad57a3f0b906637673aa2f63d1f55cb1a69199d4009eea23ceaddc9301") + testpubkey = hexutil.MustDecode("0x04e32df42865e97135acfb65f3bae71bdc86f4d49150ad6a440b6f15878109880a0a2b2667f7e725ceea70c673093bf67663e0312623c8e091b13cf2c0f11ef652") + testpubkeyc = hexutil.MustDecode("0x02e32df42865e97135acfb65f3bae71bdc86f4d49150ad6a440b6f15878109880a") ) -func TestRecoverSanity(t *testing.T) { - msg, _ := hex.DecodeString("ce0677bb30baa8cf067c88db9811f4333d131bf8bcf12fe7065d211dce971008") - sig, _ := hex.DecodeString("90f27b8b488db00b00606796d2987f6a5f59ae62ea05effe84fef5b8b0e549984a691139ad57a3f0b906637673aa2f63d1f55cb1a69199d4009eea23ceaddc9301") - pubkey1, _ := hex.DecodeString("04e32df42865e97135acfb65f3bae71bdc86f4d49150ad6a440b6f15878109880a0a2b2667f7e725ceea70c673093bf67663e0312623c8e091b13cf2c0f11ef652") - pubkey2, err := Ecrecover(msg, sig) +func TestEcrecover(t *testing.T) { + pubkey, err := Ecrecover(testmsg, testsig) if err != nil { t.Fatalf("recover error: %s", err) } - if !bytes.Equal(pubkey1, pubkey2) { - t.Errorf("pubkey mismatch: want: %x have: %x", pubkey1, pubkey2) + if !bytes.Equal(pubkey, testpubkey) { + t.Errorf("pubkey mismatch: want: %x have: %x", testpubkey, pubkey) + } +} + +func TestVerifySignature(t *testing.T) { + sig := testsig[:len(testsig)-1] // remove recovery id + if !VerifySignature(testpubkey, testmsg, sig) { + t.Errorf("can't verify signature with uncompressed key") + } + if !VerifySignature(testpubkeyc, testmsg, sig) { + t.Errorf("can't verify signature with compressed key") + } + + if VerifySignature(nil, testmsg, sig) { + t.Errorf("signature valid with no key") + } + if VerifySignature(testpubkey, nil, sig) { + t.Errorf("signature valid with no message") + } + if VerifySignature(testpubkey, testmsg, nil) { + t.Errorf("nil signature valid") + } + if VerifySignature(testpubkey, testmsg, append(common.CopyBytes(sig), 1, 2, 3)) { + t.Errorf("signature valid with extra bytes at the end") + } + if VerifySignature(testpubkey, testmsg, sig[:len(sig)-2]) { + t.Errorf("signature valid even though it's incomplete") + } + wrongkey := common.CopyBytes(testpubkey) + wrongkey[10]++ + if VerifySignature(wrongkey, testmsg, sig) { + t.Errorf("signature valid with with wrong public key") + } +} + +// This test checks that VerifySignature rejects malleable signatures with s > N/2. +func TestVerifySignatureMalleable(t *testing.T) { + sig := hexutil.MustDecode("0x638a54215d80a6713c8d523a6adc4e6e73652d859103a36b700851cb0e61b66b8ebfc1a610c57d732ec6e0a8f06a9a7a28df5051ece514702ff9cdff0b11f454") + key := hexutil.MustDecode("0x03ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd3138") + msg := hexutil.MustDecode("0xd301ce462d3e639518f482c7f03821fec1e602018630ce621e1e7851c12343a6") + if VerifySignature(key, msg, sig) { + t.Error("VerifySignature returned true for malleable signature") + } +} + +func TestDecompressPubkey(t *testing.T) { + key, err := DecompressPubkey(testpubkeyc) + if err != nil { + t.Fatal(err) + } + if uncompressed := FromECDSAPub(key); !bytes.Equal(uncompressed, testpubkey) { + t.Errorf("wrong public key result: got %x, want %x", uncompressed, testpubkey) + } + if _, err := DecompressPubkey(nil); err == nil { + t.Errorf("no error for nil pubkey") + } + if _, err := DecompressPubkey(testpubkeyc[:5]); err == nil { + t.Errorf("no error for incomplete pubkey") + } + if _, err := DecompressPubkey(append(common.CopyBytes(testpubkeyc), 1, 2, 3)); err == nil { + t.Errorf("no error for pubkey with extra bytes at the end") + } +} + +func TestCompressPubkey(t *testing.T) { + key := &ecdsa.PublicKey{ + Curve: S256(), + X: math.MustParseBig256("0xe32df42865e97135acfb65f3bae71bdc86f4d49150ad6a440b6f15878109880a"), + Y: math.MustParseBig256("0x0a2b2667f7e725ceea70c673093bf67663e0312623c8e091b13cf2c0f11ef652"), + } + compressed := CompressPubkey(key) + if !bytes.Equal(compressed, testpubkeyc) { + t.Errorf("wrong public key result: got %x, want %x", compressed, testpubkeyc) + } +} + +func TestPubkeyRandom(t *testing.T) { + const runs = 200 + + for i := 0; i < runs; i++ { + key, err := GenerateKey() + if err != nil { + t.Fatalf("iteration %d: %v", i, err) + } + pubkey2, err := DecompressPubkey(CompressPubkey(&key.PublicKey)) + if err != nil { + t.Fatalf("iteration %d: %v", i, err) + } + if !reflect.DeepEqual(key.PublicKey, *pubkey2) { + t.Fatalf("iteration %d: keys not equal", i) + } + } +} + +func BenchmarkEcrecoverSignature(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := Ecrecover(testmsg, testsig); err != nil { + b.Fatal("ecrecover error", err) + } + } +} + +func BenchmarkVerifySignature(b *testing.B) { + sig := testsig[:len(testsig)-1] // remove recovery id + for i := 0; i < b.N; i++ { + if !VerifySignature(testpubkey, testmsg, sig) { + b.Fatal("verify error") + } + } +} + +func BenchmarkDecompressPubkey(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := DecompressPubkey(testpubkeyc); err != nil { + b.Fatal(err) + } } } diff --git a/client/knode/downloader/downloader.go b/client/knode/downloader/downloader.go index ceebc1924..4741ac8b3 100644 --- a/client/knode/downloader/downloader.go +++ b/client/knode/downloader/downloader.go @@ -304,8 +304,13 @@ func (d *Downloader) Synchronise(id string, head common.Hash, blockNumber *big.I errEmptyHeaderSet, errPeersUnavailable, errInvalidAncestor, errInvalidChain: log.Warn("Synchronisation failed, dropping peer", "peer", id, "err", err) - d.dropPeer(id) - + if d.dropPeer == nil { + // The dropPeer method is nil when `--copydb` is used for a local copy. + // Timeouts can occur if e.g. compaction hits at the wrong time, and can be ignored + log.Warn("Downloader wants to drop peer, but peerdrop-function is not set", "peer", id) + } else { + d.dropPeer(id) + } default: log.Warn("Synchronisation failed, retrying", "err", err) } @@ -551,6 +556,11 @@ func (d *Downloader) fetchHeight(p *peerConnection) (*types.Header, error) { } // Make sure the peer actually gave something valid headers := packet.(*headerPack).headers + if len(headers) == 0 { + log.Debug("Received empty headers from peer", "peer", packet.PeerId()) + break + } + if len(headers) != 1 { p.log.Debug("Multiple headers for single request", "headers", len(headers)) return nil, errBadPeer diff --git a/client/p2p/dial.go b/client/p2p/dial.go index 3642fd5dc..abc424f08 100644 --- a/client/p2p/dial.go +++ b/client/p2p/dial.go @@ -138,10 +138,13 @@ func (s *dialstate) addStatic(n *discover.Node) { func (s *dialstate) removeStatic(n *discover.Node) { // This removes a task so future attempts to connect will not be made. delete(s.static, n.ID) + // This removes a previous dial timestamp so that application + // can force a server to reconnect with chosen peer immediately. + s.hist.remove(n.ID) } func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { - if s.start == (time.Time{}) { + if s.start.IsZero() { s.start = now } @@ -275,11 +278,14 @@ func (t *dialTask) Do(srv *Server) { return } } - success := t.dial(srv, t.dest) - // Try resolving the ID of static nodes if dialing failed. - if !success && t.flags&staticDialedConn != 0 { - if t.resolve(srv) { - t.dial(srv, t.dest) + err := t.dial(srv, t.dest) + if err != nil { + log.Trace("Dial error", "task", t, "err", err) + // Try resolving the ID of static nodes if dialing failed. + if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 { + if t.resolve(srv) { + t.dial(srv, t.dest) + } } } } @@ -318,16 +324,19 @@ func (t *dialTask) resolve(srv *Server) bool { return true } +type dialError struct { + error +} + // dial performs the actual connection attempt. -func (t *dialTask) dial(srv *Server, dest *discover.Node) bool { +func (t *dialTask) dial(srv *Server, dest *discover.Node) error { fd, err := srv.Dialer.Dial(dest) if err != nil { log.Trace("Dial error", "task", t, "err", err) - return false + return &dialError{err} } mfd := newMeteredConn(fd, false) - srv.SetupConn(mfd, t.flags, dest) - return true + return srv.SetupConn(mfd, t.flags, dest) } func (t *dialTask) String() string { @@ -369,6 +378,16 @@ func (h dialHistory) min() pastDial { } func (h *dialHistory) add(id discover.NodeID, exp time.Time) { heap.Push(h, pastDial{id, exp}) + +} +func (h *dialHistory) remove(id discover.NodeID) bool { + for i, v := range *h { + if v.id == id { + heap.Remove(h, i) + return true + } + } + return false } func (h dialHistory) contains(id discover.NodeID) bool { for _, v := range h { diff --git a/client/p2p/discover/database.go b/client/p2p/discover/database.go index 3284d0402..7cdb72458 100644 --- a/client/p2p/discover/database.go +++ b/client/p2p/discover/database.go @@ -210,14 +210,14 @@ func (db *nodeDB) ensureExpirer() { // expirer should be started in a go routine, and is responsible for looping ad // infinitum and dropping stale data from the database. func (db *nodeDB) expirer() { - tick := time.Tick(nodeDBCleanupCycle) + tick := time.NewTicker(nodeDBCleanupCycle) + defer tick.Stop() for { select { - case <-tick: + case <-tick.C: if err := db.expireNodes(); err != nil { log.Error("Failed to expire nodedb items", "err", err) } - case <-db.quit: return } @@ -241,7 +241,7 @@ func (db *nodeDB) expireNodes() error { } // Skip the node if not expired yet (and not self) if !bytes.Equal(id[:], db.self[:]) { - if seen := db.lastPong(id); seen.After(threshold) { + if seen := db.bondTime(id); seen.After(threshold) { continue } } @@ -262,13 +262,18 @@ func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error { return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix()) } -// lastPong retrieves the time of the last successful contact from remote node. -func (db *nodeDB) lastPong(id NodeID) time.Time { +// bondTime retrieves the time of the last successful pong from remote node. +func (db *nodeDB) bondTime(id NodeID) time.Time { return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0) } -// updateLastPong updates the last time a remote node successfully contacted. -func (db *nodeDB) updateLastPong(id NodeID, instance time.Time) error { +// hasBond reports whether the given node is considered bonded. +func (db *nodeDB) hasBond(id NodeID) bool { + return time.Since(db.bondTime(id)) < nodeDBNodeExpiration +} + +// updateBondTime updates the last pong time of a node. +func (db *nodeDB) updateBondTime(id NodeID, instance time.Time) error { return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix()) } @@ -311,7 +316,7 @@ seek: if n.ID == db.self { continue seek } - if now.Sub(db.lastPong(n.ID)) > maxAge { + if now.Sub(db.bondTime(n.ID)) > maxAge { continue seek } for i := range nodes { diff --git a/client/p2p/discover/node.go b/client/p2p/discover/node.go index 49479d3ec..c61b3c630 100644 --- a/client/p2p/discover/node.go +++ b/client/p2p/discover/node.go @@ -13,6 +13,7 @@ import ( "regexp" "strconv" "strings" + "time" "github.com/kowala-tech/kcoin/client/common" "github.com/kowala-tech/kcoin/client/crypto" @@ -35,9 +36,8 @@ type Node struct { // with ID. sha common.Hash - // whether this node is currently being pinged in order to replace - // it in a bucket - contested bool + // Time when the node was added to the table. + addedAt time.Time } // NewNode creates a new node. It is mostly meant to be used for diff --git a/client/p2p/discover/table.go b/client/p2p/discover/table.go index d66107192..70a7e4e34 100644 --- a/client/p2p/discover/table.go +++ b/client/p2p/discover/table.go @@ -7,10 +7,11 @@ package discover import ( - "crypto/rand" + crand "crypto/rand" "encoding/binary" "errors" "fmt" + mrand "math/rand" "net" "sort" "sync" @@ -19,29 +20,45 @@ import ( "github.com/kowala-tech/kcoin/client/common" "github.com/kowala-tech/kcoin/client/crypto" "github.com/kowala-tech/kcoin/client/log" + "github.com/kowala-tech/kcoin/client/p2p/netutil" ) const ( - alpha = 3 // Kademlia concurrency factor - bucketSize = 16 // Kademlia bucket size - hashBits = len(common.Hash{}) * 8 - nBuckets = hashBits + 1 // Number of buckets - - maxBondingPingPongs = 16 - maxFindnodeFailures = 5 - - autoRefreshInterval = 1 * time.Hour - seedCount = 30 - seedMaxAge = 5 * 24 * time.Hour + alpha = 3 // Kademlia concurrency factor + bucketSize = 16 // Kademlia bucket size + maxReplacements = 10 // Size of per-bucket replacement list + + // We keep buckets for the upper 1/15 of distances because + // it's very unlikely we'll ever encounter a node that's closer. + hashBits = len(common.Hash{}) * 8 + nBuckets = hashBits / 15 // Number of buckets + bucketMinDistance = hashBits - nBuckets // Log distance of closest bucket + + // IP address limits. + bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24 + tableIPLimit, tableSubnet = 10, 24 + + maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions + maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped + + refreshInterval = 30 * time.Minute + revalidateInterval = 10 * time.Second + copyNodesInterval = 30 * time.Second + seedMinTableTime = 5 * time.Minute + seedCount = 30 + seedMaxAge = 5 * 24 * time.Hour ) type Table struct { - mutex sync.Mutex // protects buckets, their content, and nursery + mutex sync.Mutex // protects buckets, bucket content, nursery, rand buckets [nBuckets]*bucket // index of known nodes by distance nursery []*Node // bootstrap nodes - db *nodeDB // database of known nodes + rand *mrand.Rand // source of randomness, periodically reseeded + ips netutil.DistinctNetSet + db *nodeDB // database of known nodes refreshReq chan chan struct{} + initDone chan struct{} closeReq chan struct{} closed chan struct{} @@ -73,9 +90,13 @@ type transport interface { // bucket contains nodes, ordered by their last activity. the entry // that was most recently active is the first element in entries. -type bucket struct{ entries []*Node } +type bucket struct { + entries []*Node // live entries, sorted by time of last contact + replacements []*Node // recently seen nodes to be used if revalidation fails + ips netutil.DistinctNetSet +} -func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string) (*Table, error) { +func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) { // If no node database was given, use an in-memory one db, err := newNodeDB(nodeDBPath, Version, ourID) if err != nil { @@ -88,19 +109,42 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string bonding: make(map[NodeID]*bondproc), bondslots: make(chan struct{}, maxBondingPingPongs), refreshReq: make(chan chan struct{}), + initDone: make(chan struct{}), closeReq: make(chan struct{}), closed: make(chan struct{}), + rand: mrand.New(mrand.NewSource(0)), + ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}, + } + if err := tab.setFallbackNodes(bootnodes); err != nil { + return nil, err } for i := 0; i < cap(tab.bondslots); i++ { tab.bondslots <- struct{}{} } for i := range tab.buckets { - tab.buckets[i] = new(bucket) + tab.buckets[i] = &bucket{ + ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit}, + } } - go tab.refreshLoop() + tab.seedRand() + tab.loadSeedNodes(false) + // Start the background expiration goroutine after loading seeds so that the search for + // seed nodes also considers older nodes that would otherwise be removed by the + // expiration. + tab.db.ensureExpirer() + go tab.loop() return tab, nil } +func (tab *Table) seedRand() { + var b [8]byte + crand.Read(b[:]) + + tab.mutex.Lock() + tab.rand.Seed(int64(binary.BigEndian.Uint64(b[:]))) + tab.mutex.Unlock() +} + // Self returns the local node. // The returned node should not be modified by the caller. func (tab *Table) Self() *Node { @@ -111,9 +155,12 @@ func (tab *Table) Self() *Node { // table. It will not write the same node more than once. The nodes in // the slice are copies and can be modified by the caller. func (tab *Table) ReadRandomNodes(buf []*Node) (n int) { + if !tab.isInitDone() { + return 0 + } tab.mutex.Lock() defer tab.mutex.Unlock() - // TODO: tree-based buckets would help here + // Find all non-empty buckets and get a fresh slice of their entries. var buckets [][]*Node for _, b := range tab.buckets { @@ -125,8 +172,8 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) { return 0 } // Shuffle the buckets. - for i := uint32(len(buckets)) - 1; i > 0; i-- { - j := randUint(i) + for i := len(buckets) - 1; i > 0; i-- { + j := tab.rand.Intn(len(buckets)) buckets[i], buckets[j] = buckets[j], buckets[i] } // Move head of each bucket into buf, removing buckets that become empty. @@ -145,15 +192,6 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) { return i + 1 } -func randUint(max uint32) uint32 { - if max == 0 { - return 0 - } - var b [4]byte - rand.Read(b[:]) - return binary.BigEndian.Uint32(b[:]) % max -} - // Close terminates the network listener and flushes the node database. func (tab *Table) Close() { select { @@ -164,16 +202,15 @@ func (tab *Table) Close() { } } -// SetFallbackNodes sets the initial points of contact. These nodes +// setFallbackNodes sets the initial points of contact. These nodes // are used to connect to the network if the table is empty and there // are no known nodes in the database. -func (tab *Table) SetFallbackNodes(nodes []*Node) error { +func (tab *Table) setFallbackNodes(nodes []*Node) error { for _, n := range nodes { if err := n.validateComplete(); err != nil { return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err) } } - tab.mutex.Lock() tab.nursery = make([]*Node, 0, len(nodes)) for _, n := range nodes { cpy := *n @@ -182,11 +219,19 @@ func (tab *Table) SetFallbackNodes(nodes []*Node) error { cpy.sha = crypto.Keccak256Hash(n.ID[:]) tab.nursery = append(tab.nursery, &cpy) } - tab.mutex.Unlock() - tab.refresh() return nil } +// isInitDone returns whether the table's initial seeding procedure has completed. +func (tab *Table) isInitDone() bool { + select { + case <-tab.initDone: + return true + default: + return false + } +} + // Resolve searches for a specific node with the given ID. // It returns nil if the node could not be found. func (tab *Table) Resolve(targetID NodeID) *Node { @@ -298,33 +343,49 @@ func (tab *Table) refresh() <-chan struct{} { return done } -// refreshLoop schedules doRefresh runs and coordinates shutdown. -func (tab *Table) refreshLoop() { +// loop schedules refresh, revalidate runs and coordinates shutdown. +func (tab *Table) loop() { var ( - timer = time.NewTicker(autoRefreshInterval) - waiting []chan struct{} // accumulates waiting callers while doRefresh runs - done chan struct{} // where doRefresh reports completion + revalidate = time.NewTimer(tab.nextRevalidateTime()) + refresh = time.NewTicker(refreshInterval) + copyNodes = time.NewTicker(copyNodesInterval) + revalidateDone = make(chan struct{}) + refreshDone = make(chan struct{}) // where doRefresh reports completion + waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs ) + defer refresh.Stop() + defer revalidate.Stop() + defer copyNodes.Stop() + + // Start initial refresh. + go tab.doRefresh(refreshDone) + loop: for { select { - case <-timer.C: - if done == nil { - done = make(chan struct{}) - go tab.doRefresh(done) + case <-refresh.C: + tab.seedRand() + if refreshDone == nil { + refreshDone = make(chan struct{}) + go tab.doRefresh(refreshDone) } case req := <-tab.refreshReq: waiting = append(waiting, req) - if done == nil { - done = make(chan struct{}) - go tab.doRefresh(done) + if refreshDone == nil { + refreshDone = make(chan struct{}) + go tab.doRefresh(refreshDone) } - case <-done: + case <-refreshDone: for _, ch := range waiting { close(ch) } - waiting = nil - done = nil + waiting, refreshDone = nil, nil + case <-revalidate.C: + go tab.doRevalidate(revalidateDone) + case <-revalidateDone: + revalidate.Reset(tab.nextRevalidateTime()) + case <-copyNodes.C: + go tab.copyBondedNodes() case <-tab.closeReq: break loop } @@ -333,8 +394,8 @@ loop: if tab.net != nil { tab.net.close() } - if done != nil { - <-done + if refreshDone != nil { + <-refreshDone } for _, ch := range waiting { close(ch) @@ -349,38 +410,109 @@ loop: func (tab *Table) doRefresh(done chan struct{}) { defer close(done) + // Load nodes from the database and insert + // them. This should yield a few previously seen nodes that are + // (hopefully) still alive. + tab.loadSeedNodes(true) + + // Run self lookup to discover new neighbor nodes. + tab.lookup(tab.self.ID, false) + // The Kademlia paper specifies that the bucket refresh should // perform a lookup in the least recently used bucket. We cannot // adhere to this because the findnode target is a 512bit value // (not hash-sized) and it is not easily possible to generate a // sha3 preimage that falls into a chosen bucket. - // We perform a lookup with a random target instead. - var target NodeID - rand.Read(target[:]) - result := tab.lookup(target, false) - if len(result) > 0 { - return + // We perform a few lookups with a random target instead. + for i := 0; i < 3; i++ { + var target NodeID + crand.Read(target[:]) + tab.lookup(target, false) } +} - // The table is empty. Load nodes from the database and insert - // them. This should yield a few previously seen nodes that are - // (hopefully) still alive. +func (tab *Table) loadSeedNodes(bond bool) { seeds := tab.db.querySeeds(seedCount, seedMaxAge) - seeds = tab.bondall(append(seeds, tab.nursery...)) + seeds = append(seeds, tab.nursery...) + if bond { + seeds = tab.bondall(seeds) + } + for i := range seeds { + seed := seeds[i] + age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.bondTime(seed.ID)) }} + log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age) + tab.add(seed) + } +} - if len(seeds) == 0 { - log.Debug("No discv4 seed nodes found") +// doRevalidate checks that the last node in a random bucket is still live +// and replaces or deletes the node if it isn't. +func (tab *Table) doRevalidate(done chan<- struct{}) { + defer func() { done <- struct{}{} }() + + last, bi := tab.nodeToRevalidate() + if last == nil { + // No non-empty bucket found. + return + } + + // Ping the selected node and wait for a pong. + err := tab.ping(last.ID, last.addr()) + + tab.mutex.Lock() + defer tab.mutex.Unlock() + b := tab.buckets[bi] + if err == nil { + // The node responded, move it to the front. + log.Debug("Revalidated node", "b", bi, "id", last.ID) + b.bump(last) + return } - for _, n := range seeds { - age := log.Lazy{Fn: func() time.Duration { return time.Since(tab.db.lastPong(n.ID)) }} - log.Trace("Found seed node in database", "id", n.ID, "addr", n.addr(), "age", age) + // No reply received, pick a replacement or delete the node if there aren't + // any replacements. + if r := tab.replace(b, last); r != nil { + log.Debug("Replaced dead node", "b", bi, "id", last.ID, "ip", last.IP, "r", r.ID, "rip", r.IP) + } else { + log.Debug("Removed dead node", "b", bi, "id", last.ID, "ip", last.IP) } +} + +// nodeToRevalidate returns the last node in a random, non-empty bucket. +func (tab *Table) nodeToRevalidate() (n *Node, bi int) { tab.mutex.Lock() - tab.stuff(seeds) - tab.mutex.Unlock() + defer tab.mutex.Unlock() - // Finally, do a self lookup to fill up the buckets. - tab.lookup(tab.self.ID, false) + for _, bi = range tab.rand.Perm(len(tab.buckets)) { + b := tab.buckets[bi] + if len(b.entries) > 0 { + last := b.entries[len(b.entries)-1] + return last, bi + } + } + return nil, 0 +} + +func (tab *Table) nextRevalidateTime() time.Duration { + tab.mutex.Lock() + defer tab.mutex.Unlock() + + return time.Duration(tab.rand.Int63n(int64(revalidateInterval))) +} + +// copyBondedNodes adds nodes from the table to the database if they have been in the table +// longer then minTableTime. +func (tab *Table) copyBondedNodes() { + tab.mutex.Lock() + defer tab.mutex.Unlock() + + now := time.Now() + for _, b := range tab.buckets { + for _, n := range b.entries { + if now.Sub(n.addedAt) >= seedMinTableTime { + tab.db.updateNode(n) + } + } + } } // closest returns the n nodes in the table that are closest to the @@ -443,15 +575,14 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16 if id == tab.self.ID { return nil, errors.New("is self") } - // Retrieve a previously known node and any recent findnode failures - node, fails := tab.db.node(id), 0 - if node != nil { - fails = tab.db.findFails(id) + if pinged && !tab.isInitDone() { + return nil, errors.New("still initializing") } - // If the node is unknown (non-bonded) or failed (remotely unknown), bond from scratch + // Start bonding if we haven't seen this node for a while or if it failed findnode too often. + node, fails := tab.db.node(id), tab.db.findFails(id) + age := time.Since(tab.db.bondTime(id)) var result error - age := time.Since(tab.db.lastPong(id)) - if node == nil || fails > 0 || age > nodeDBNodeExpiration { + if fails > 0 || age > nodeDBNodeExpiration { log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age) tab.bondmu.Lock() @@ -478,10 +609,10 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16 node = w.n } } + // Add the node to the table even if the bonding ping/pong + // fails. It will be relaced quickly if it continues to be + // unresponsive. if node != nil { - // Add the node to the table even if the bonding ping/pong - // fails. It will be relaced quickly if it continues to be - // unresponsive. tab.add(node) tab.db.updateFindFails(id, 0) } @@ -506,7 +637,6 @@ func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAdd } // Bonding succeeded, update the node database. w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort) - tab.db.updateNode(w.n) close(w.done) } @@ -517,17 +647,19 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error { if err := tab.net.ping(id, addr); err != nil { return err } - tab.db.updateLastPong(id, time.Now()) - - // Start the background expiration goroutine after the first - // successful communication. Subsequent calls have no effect if it - // is already running. We do this here instead of somewhere else - // so that the search for seed nodes also considers older nodes - // that would otherwise be removed by the expiration. - tab.db.ensureExpirer() + tab.db.updateBondTime(id, time.Now()) return nil } +// bucket returns the bucket for the given node ID hash. +func (tab *Table) bucket(sha common.Hash) *bucket { + d := logdist(tab.self.sha, sha) + if d <= bucketMinDistance { + return tab.buckets[0] + } + return tab.buckets[d-bucketMinDistance-1] +} + // add attempts to add the given node its corresponding bucket. If the // bucket has space available, adding the node succeeds immediately. // Otherwise, the node is added if the least recently active node in @@ -535,57 +667,29 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error { // // The caller must not hold tab.mutex. func (tab *Table) add(new *Node) { - b := tab.buckets[logdist(tab.self.sha, new.sha)] tab.mutex.Lock() defer tab.mutex.Unlock() - if b.bump(new) { - return - } - var oldest *Node - if len(b.entries) == bucketSize { - oldest = b.entries[bucketSize-1] - if oldest.contested { - // The node is already being replaced, don't attempt - // to replace it. - return - } - oldest.contested = true - // Let go of the mutex so other goroutines can access - // the table while we ping the least recently active node. - tab.mutex.Unlock() - err := tab.ping(oldest.ID, oldest.addr()) - tab.mutex.Lock() - oldest.contested = false - if err == nil { - // The node responded, don't replace it. - return - } - } - added := b.replace(new, oldest) - if added && tab.nodeAddedHook != nil { - tab.nodeAddedHook(new) + + b := tab.bucket(new.sha) + if !tab.bumpOrAdd(b, new) { + // Node is not in table. Add it to the replacement list. + tab.addReplacement(b, new) } } // stuff adds nodes the table to the end of their corresponding bucket -// if the bucket is not full. The caller must hold tab.mutex. +// if the bucket is not full. The caller must not hold tab.mutex. func (tab *Table) stuff(nodes []*Node) { -outer: + tab.mutex.Lock() + defer tab.mutex.Unlock() + for _, n := range nodes { if n.ID == tab.self.ID { continue // don't add self } - bucket := tab.buckets[logdist(tab.self.sha, n.sha)] - for i := range bucket.entries { - if bucket.entries[i].ID == n.ID { - continue outer // already in bucket - } - } - if len(bucket.entries) < bucketSize { - bucket.entries = append(bucket.entries, n) - if tab.nodeAddedHook != nil { - tab.nodeAddedHook(n) - } + b := tab.bucket(n.sha) + if len(b.entries) < bucketSize { + tab.bumpOrAdd(b, n) } } } @@ -595,36 +699,72 @@ outer: func (tab *Table) delete(node *Node) { tab.mutex.Lock() defer tab.mutex.Unlock() - bucket := tab.buckets[logdist(tab.self.sha, node.sha)] - for i := range bucket.entries { - if bucket.entries[i].ID == node.ID { - bucket.entries = append(bucket.entries[:i], bucket.entries[i+1:]...) - return - } - } + + tab.deleteInBucket(tab.bucket(node.sha), node) } -func (b *bucket) replace(n *Node, last *Node) bool { - // Don't add if b already contains n. - for i := range b.entries { - if b.entries[i].ID == n.ID { - return false - } +func (tab *Table) addIP(b *bucket, ip net.IP) bool { + if netutil.IsLAN(ip) { + return true } - // Replace last if it is still the last entry or just add n if b - // isn't full. If is no longer the last entry, it has either been - // replaced with someone else or became active. - if len(b.entries) == bucketSize && (last == nil || b.entries[bucketSize-1].ID != last.ID) { + if !tab.ips.Add(ip) { + log.Debug("IP exceeds table limit", "ip", ip) return false } - if len(b.entries) < bucketSize { - b.entries = append(b.entries, nil) + if !b.ips.Add(ip) { + log.Debug("IP exceeds bucket limit", "ip", ip) + tab.ips.Remove(ip) + return false } - copy(b.entries[1:], b.entries) - b.entries[0] = n return true } +func (tab *Table) removeIP(b *bucket, ip net.IP) { + if netutil.IsLAN(ip) { + return + } + tab.ips.Remove(ip) + b.ips.Remove(ip) +} + +func (tab *Table) addReplacement(b *bucket, n *Node) { + for _, e := range b.replacements { + if e.ID == n.ID { + return // already in list + } + } + if !tab.addIP(b, n.IP) { + return + } + var removed *Node + b.replacements, removed = pushNode(b.replacements, n, maxReplacements) + if removed != nil { + tab.removeIP(b, removed.IP) + } +} + +// replace removes n from the replacement list and replaces 'last' with it if it is the +// last entry in the bucket. If 'last' isn't the last entry, it has either been replaced +// with someone else or became active. +func (tab *Table) replace(b *bucket, last *Node) *Node { + if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID != last.ID { + // Entry has moved, don't replace it. + return nil + } + // Still the last entry. + if len(b.replacements) == 0 { + tab.deleteInBucket(b, last) + return nil + } + r := b.replacements[tab.rand.Intn(len(b.replacements))] + b.replacements = deleteNode(b.replacements, r) + b.entries[len(b.entries)-1] = r + tab.removeIP(b, last.IP) + return r +} + +// bump moves the given node to the front of the bucket entry list +// if it is contained in that list. func (b *bucket) bump(n *Node) bool { for i := range b.entries { if b.entries[i].ID == n.ID { @@ -637,6 +777,50 @@ func (b *bucket) bump(n *Node) bool { return false } +// bumpOrAdd moves n to the front of the bucket entry list or adds it if the list isn't +// full. The return value is true if n is in the bucket. +func (tab *Table) bumpOrAdd(b *bucket, n *Node) bool { + if b.bump(n) { + return true + } + if len(b.entries) >= bucketSize || !tab.addIP(b, n.IP) { + return false + } + b.entries, _ = pushNode(b.entries, n, bucketSize) + b.replacements = deleteNode(b.replacements, n) + n.addedAt = time.Now() + if tab.nodeAddedHook != nil { + tab.nodeAddedHook(n) + } + return true +} + +func (tab *Table) deleteInBucket(b *bucket, n *Node) { + b.entries = deleteNode(b.entries, n) + tab.removeIP(b, n.IP) +} + +// pushNode adds n to the front of list, keeping at most max items. +func pushNode(list []*Node, n *Node, max int) ([]*Node, *Node) { + if len(list) < max { + list = append(list, nil) + } + removed := list[len(list)-1] + copy(list[1:], list) + list[0] = n + return list, removed +} + +// deleteNode removes n from list. +func deleteNode(list []*Node, n *Node) []*Node { + for i := range list { + if list[i].ID == n.ID { + return append(list[:i], list[i+1:]...) + } + } + return list +} + // nodesByDistance is a list of nodes, ordered by // distance to target. type nodesByDistance struct { diff --git a/client/p2p/discover/udp.go b/client/p2p/discover/udp.go index f77d4126a..0262c0a94 100644 --- a/client/p2p/discover/udp.go +++ b/client/p2p/discover/udp.go @@ -33,7 +33,6 @@ var ( // Timeouts const ( respTimeout = 500 * time.Millisecond - sendTimeout = 500 * time.Millisecond expiration = 20 * time.Second ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP @@ -194,55 +193,58 @@ type reply struct { matched chan<- bool } +// ReadPacket is sent to the unhandled channel when it could not be processed +type ReadPacket struct { + Data []byte + Addr *net.UDPAddr +} + +// Config holds Table-related settings. +type Config struct { + // These settings are required and configure the UDP listener: + PrivateKey *ecdsa.PrivateKey + + // These settings are optional: + AnnounceAddr *net.UDPAddr // local address announced in the DHT + NodeDBPath string // if set, the node database is stored at this filesystem location + NetRestrict *netutil.Netlist // network whitelist + Bootnodes []*Node // list of bootstrap nodes + Unhandled chan<- ReadPacket // unhandled packets are sent on this channel +} + // ListenUDP returns a new table that listens for UDP packets on laddr. -func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) { - addr, err := net.ResolveUDPAddr("udp", laddr) +func ListenUDP(c conn, cfg Config) (*Table, error) { + tab, _, err := newUDP(c, cfg) if err != nil { return nil, err } - conn, err := net.ListenUDP("udp", addr) - if err != nil { - return nil, err - } - tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict) - if err != nil { - return nil, err - } - - enode := tab.self.String() - log.Warn("UDP listener up for DiscoveryV4. Enode: " + enode) + log.Info("UDP listener up", "self", tab.self) return tab, nil } -func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) { +func newUDP(c conn, cfg Config) (*Table, *udp, error) { udp := &udp{ conn: c, - priv: priv, - netrestrict: netrestrict, + priv: cfg.PrivateKey, + netrestrict: cfg.NetRestrict, closing: make(chan struct{}), gotreply: make(chan reply), addpending: make(chan *pending), } realaddr := c.LocalAddr().(*net.UDPAddr) - if natm != nil { - if !realaddr.IP.IsLoopback() { - go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") - } - // TODO: react to external IP changes over time. - if ext, err := natm.ExternalIP(); err == nil { - realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port} - } + if cfg.AnnounceAddr != nil { + realaddr = cfg.AnnounceAddr } // TODO: separate TCP port udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) - tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath) + tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes) if err != nil { return nil, nil, err } udp.Table = tab go udp.loop() - go udp.readLoop() + go udp.readLoop(cfg.Unhandled) return udp.Table, udp, nil } @@ -254,14 +256,20 @@ func (t *udp) close() { // ping sends a ping message to the given node and waits for a reply. func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { - // TODO: maybe check for ReplyTo field in callback to measure RTT - errc := t.pending(toid, pongPacket, func(interface{}) bool { return true }) - t.send(toaddr, pingPacket, &ping{ + req := &ping{ Version: Version, From: t.ourEndpoint, To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB Expiration: uint64(time.Now().Add(expiration).Unix()), + } + packet, hash, err := encodePacket(t.priv, pingPacket, req) + if err != nil { + return err + } + errc := t.pending(toid, pongPacket, func(p interface{}) bool { + return bytes.Equal(p.(*pong).ReplyTok, hash) }) + t.write(toaddr, req.name(), packet) return <-errc } @@ -445,41 +453,49 @@ func init() { } } -func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) error { - packet, err := encodePacket(t.priv, ptype, req) +func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) { + packet, hash, err := encodePacket(t.priv, ptype, req) if err != nil { - return err + return hash, err } - _, err = t.conn.WriteToUDP(packet, toaddr) - log.Trace(">> "+req.name(), "addr", toaddr, "err", err) + return hash, t.write(toaddr, req.name(), packet) +} + +func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error { + _, err := t.conn.WriteToUDP(packet, toaddr) + log.Trace(">> "+what, "addr", toaddr, "err", err) return err } -func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) { +func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) { b := new(bytes.Buffer) b.Write(headSpace) b.WriteByte(ptype) if err := rlp.Encode(b, req); err != nil { log.Error("Can't encode discv4 packet", "err", err) - return nil, err + return nil, nil, err } - packet := b.Bytes() + packet = b.Bytes() sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv) if err != nil { log.Error("Can't sign discv4 packet", "err", err) - return nil, err + return nil, nil, err } copy(packet[macSize:], sig) // add the hash to the front. Note: this doesn't protect the // packet in any way. Our public key will be part of this hash in // The future. - copy(packet, crypto.Keccak256(packet[macSize:])) - return packet, nil + hash = crypto.Keccak256(packet[macSize:]) + copy(packet, hash) + return packet, hash, nil } // readLoop runs in its own goroutine. it handles incoming UDP packets. -func (t *udp) readLoop() { +func (t *udp) readLoop(unhandled chan<- ReadPacket) { defer t.conn.Close() + if unhandled != nil { + defer close(unhandled) + } // Discovery packets are defined to be no larger than 1280 bytes. // Packets larger than this size will be cut at the end and treated // as invalid because their hash won't match. @@ -495,7 +511,12 @@ func (t *udp) readLoop() { log.Debug("UDP read error", "err", err) return } - t.handlePacket(from, buf[:nbytes]) + if t.handlePacket(from, buf[:nbytes]) != nil && unhandled != nil { + select { + case unhandled <- ReadPacket{buf[:nbytes], from}: + default: + } + } } } @@ -575,7 +596,7 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte if expired(req.Expiration) { return errExpired } - if t.db.node(fromID) == nil { + if !t.db.hasBond(fromID) { // No bond exists, we don't process the packet. This prevents // an attack vector where the discovery protocol could be used // to amplify traffic in a DDOS attack. A malicious actor @@ -591,18 +612,22 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte t.mutex.Unlock() p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} + var sent bool // Send neighbors in chunks with at most maxNeighbors per packet // to stay below the 1280 byte limit. - for i, n := range closest { - if netutil.CheckRelayIP(from.IP, n.IP) != nil { - continue + for _, n := range closest { + if netutil.CheckRelayIP(from.IP, n.IP) == nil { + p.Nodes = append(p.Nodes, nodeToRPC(n)) } - p.Nodes = append(p.Nodes, nodeToRPC(n)) - if len(p.Nodes) == maxNeighbors || i == len(closest)-1 { + if len(p.Nodes) == maxNeighbors { t.send(from, neighborsPacket, &p) p.Nodes = p.Nodes[:0] + sent = true } } + if len(p.Nodes) > 0 || !sent { + t.send(from, neighborsPacket, &p) + } return nil } diff --git a/client/p2p/discv5/database.go b/client/p2p/discv5/database.go index 3e70c4e9e..63c72421a 100644 --- a/client/p2p/discv5/database.go +++ b/client/p2p/discv5/database.go @@ -223,14 +223,14 @@ func (db *nodeDB) ensureExpirer() { // expirer should be started in a go routine, and is responsible for looping ad // infinitum and dropping stale data from the database. func (db *nodeDB) expirer() { - tick := time.Tick(nodeDBCleanupCycle) + tick := time.NewTicker(nodeDBCleanupCycle) + defer tick.Stop() for { select { - case <-tick: + case <-tick.C: if err := db.expireNodes(); err != nil { log.Error(fmt.Sprintf("Failed to expire nodedb items: %v", err)) } - case <-db.quit: return } diff --git a/client/p2p/discv5/net.go b/client/p2p/discv5/net.go index d515948a8..8d4af2eca 100644 --- a/client/p2p/discv5/net.go +++ b/client/p2p/discv5/net.go @@ -13,7 +13,6 @@ import ( "github.com/kowala-tech/kcoin/client/crypto" "github.com/kowala-tech/kcoin/client/crypto/sha3" "github.com/kowala-tech/kcoin/client/log" - "github.com/kowala-tech/kcoin/client/p2p/nat" "github.com/kowala-tech/kcoin/client/p2p/netutil" "github.com/kowala-tech/kcoin/client/rlp" ) @@ -21,7 +20,6 @@ import ( var ( errInvalidEvent = errors.New("invalid in current state") errNoQuery = errors.New("no pending query") - errWrongAddress = errors.New("unknown sender address") ) const ( @@ -35,16 +33,9 @@ const ( const testTopic = "foo" const ( - printDebugLogs = false printTestImgLogs = false ) -func debugLog(s string) { - if printDebugLogs { - fmt.Println(s) - } -} - // Network manages the table and all protocol interaction. type Network struct { db *nodeDB // database of known nodes @@ -125,7 +116,7 @@ type timeoutEvent struct { node *Node } -func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string, netrestrict *netutil.Netlist) (*Network, error) { +func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, dbPath string, netrestrict *netutil.Netlist) (*Network, error) { ourID := PubkeyID(&ourPubkey) var db *nodeDB @@ -363,17 +354,26 @@ func (net *Network) loop() { // Tracking the next ticket to register. var ( nextTicket *ticketRef - nextRegisterTimer = time.NewTicker(1*time.Second) - nextRegisterTime = nextRegisterTimer.C + nextRegisterTimer *time.Timer + nextRegisterTime <-chan time.Time ) defer func() { - nextRegisterTimer.Stop() + if nextRegisterTimer != nil { + nextRegisterTimer.Stop() + } }() - resetNextTicket := func() { - t, _ := net.ticketStore.nextFilteredTicket() - if t != nextTicket { - nextTicket = t + ticket, timeout := net.ticketStore.nextFilteredTicket() + if nextTicket != ticket { + nextTicket = ticket + if nextRegisterTimer != nil { + nextRegisterTimer.Stop() + nextRegisterTime = nil + } + if ticket != nil { + nextRegisterTimer = time.NewTimer(timeout) + nextRegisterTime = nextRegisterTimer.C + } } } @@ -394,12 +394,11 @@ func (net *Network) loop() { go func() { for range nextRegisterTime { + log.Trace("<-nextRegisterTime") if nextTicket == nil { continue } - debugLog("<-nextRegisterTime") net.ticketStore.ticketRegistered(*nextTicket) - //fmt.Println("sendTopicRegister", nextTicket.t.node.addr().String(), nextTicket.t.topics, nextTicket.idx, nextTicket.t.pong) net.conn.sendTopicRegister(nextTicket.t.node, nextTicket.t.topics, nextTicket.idx, nextTicket.t.pong) } }() @@ -410,13 +409,13 @@ loop: select { case <-net.closeReq: - debugLog("<-net.closeReq") + log.Trace("<-net.closeReq") break loop // Ingress packet handling. case pkt := <-net.read: //fmt.Println("read", pkt.ev) - debugLog("<-net.read") + log.Trace("<-net.read") n := net.internNode(&pkt) prestate := n.state status := "ok" @@ -431,7 +430,7 @@ loop: // State transition timeouts. case timeout := <-net.timeout: - debugLog("<-net.timeout") + log.Trace("<-net.timeout") if net.timeoutTimers[timeout] == nil { // Stale timer (was aborted). continue @@ -449,20 +448,20 @@ loop: // Querying. case q := <-net.queryReq: - debugLog("<-net.queryReq") + log.Trace("<-net.queryReq") if !q.start(net) { q.remote.deferQuery(q) } // Interacting with the table. case f := <-net.tableOpReq: - debugLog("<-net.tableOpReq") + log.Trace("<-net.tableOpReq") f() net.tableOpResp <- struct{}{} // Topic registration stuff. case req := <-net.topicRegisterReq: - debugLog("<-net.topicRegisterReq") + log.Trace("<-net.topicRegisterReq") if !req.add { net.ticketStore.removeRegisterTopic(req.topic) continue @@ -473,7 +472,7 @@ loop: // determination for new topics. // if topicRegisterLookupDone == nil { if topicRegisterLookupTarget.target == (common.Hash{}) { - debugLog("topicRegisterLookupTarget == null") + log.Trace("topicRegisterLookupTarget == null") if topicRegisterLookupTick.Stop() { <-topicRegisterLookupTick.C } @@ -483,7 +482,7 @@ loop: } case nodes := <-topicRegisterLookupDone: - debugLog("<-topicRegisterLookupDone") + log.Trace("<-topicRegisterLookupDone") net.ticketStore.registerLookupDone(topicRegisterLookupTarget, nodes, func(n *Node) []byte { net.ping(n, n.addr()) return n.pingEcho @@ -494,7 +493,7 @@ loop: topicRegisterLookupDone = nil case <-topicRegisterLookupTick.C: - debugLog("<-topicRegisterLookupTick") + log.Trace("<-topicRegisterLookupTick") if (topicRegisterLookupTarget.target == common.Hash{}) { target, delay := net.ticketStore.nextRegisterLookup() topicRegisterLookupTarget = target @@ -508,7 +507,7 @@ loop: case req := <-net.topicSearchReq: if refreshDone == nil { - debugLog("<-net.topicSearchReq") + log.Trace("<-net.topicSearchReq") info, ok := searchInfo[req.topic] if ok { if req.delay == time.Duration(0) { @@ -552,13 +551,11 @@ loop: case res := <-topicSearchLookupDone: activeSearchCount-- if lookupChn := searchInfo[res.target.topic].lookupChn; lookupChn != nil { - lookupChn <- net.ticketStore.radius[res.target.topic].converged + rad, _ := net.ticketStore.radius.get(res.target.topic) + lookupChn <- rad.converged } - net.ticketStore.searchLookupDone(res.target, res.nodes, func(n *Node) []byte { - net.ping(n, n.addr()) - return n.pingEcho - }, func(n *Node, topic Topic) []byte { - if n.state == known { + net.ticketStore.searchLookupDone(res.target, res.nodes, func(n *Node, topic Topic) []byte { + if n.state != nil && n.state.canQuery { return net.conn.send(n, topicQueryPacket, topicQuery{Topic: topic}) // TODO: set expiration } else { if n.state == unknown { @@ -569,7 +566,7 @@ loop: }) case <-statsDump.C: - debugLog("<-statsDump.C") + log.Trace("<-statsDump.C") /*r, ok := net.ticketStore.radius[testTopic] if !ok { fmt.Printf("(%x) no radius @ %v\n", net.tab.self.ID[:8], time.Now()) @@ -581,7 +578,9 @@ loop: }*/ tm := mclock.Now() - for topic, r := range net.ticketStore.radius { + + net.ticketStore.radius.RLock() + for topic, r := range net.ticketStore.radius.m { if printTestImgLogs { rad := r.radius / (maxRadius/1000000 + 1) minrad := r.minRadius / (maxRadius/1000000 + 1) @@ -589,6 +588,8 @@ loop: fmt.Printf("*MR %d %v %016x %v\n", tm/1000000, topic, net.tab.self.sha[:8], minrad) } } + net.ticketStore.radius.RUnlock() + for topic, t := range net.topictab.topics { wp := t.wcl.nextWaitPeriod(tm) if printTestImgLogs { @@ -598,7 +599,7 @@ loop: // Periodic / lookup-initiated bucket refresh. case <-refreshTimer.C: - debugLog("<-refreshTimer.C") + log.Trace("<-refreshTimer.C") // TODO: ideally we would start the refresh timer after // fallback nodes have been set for the first time. if refreshDone == nil { @@ -612,7 +613,7 @@ loop: bucketRefreshTimer.Reset(bucketRefreshInterval) }() case newNursery := <-net.refreshReq: - debugLog("<-net.refreshReq") + log.Trace("<-net.refreshReq") if newNursery != nil { net.nursery = newNursery } @@ -622,18 +623,23 @@ loop: } net.refreshResp <- refreshDone case <-refreshDone: - debugLog("<-net.refreshDone") - refreshDone = nil - list := searchReqWhenRefreshDone - searchReqWhenRefreshDone = nil - go func() { - for _, req := range list { - net.topicSearchReq <- req - } - }() + log.Trace("<-net.refreshDone", "table size", net.tab.count) + if net.tab.count != 0 { + refreshDone = nil + list := searchReqWhenRefreshDone + searchReqWhenRefreshDone = nil + go func() { + for _, req := range list { + net.topicSearchReq <- req + } + }() + } else { + refreshDone = make(chan struct{}) + net.refresh(refreshDone) + } } } - debugLog("loop stopped") + log.Trace("loop stopped") log.Debug(fmt.Sprintf("shutting down")) if net.conn != nil { @@ -665,7 +671,7 @@ func (net *Network) refresh(done chan<- struct{}) { seeds = net.nursery } if len(seeds) == 0 { - log.Trace(fmt.Sprint("no seed nodes found")) + log.Trace("no seed nodes found") close(done) return } @@ -740,7 +746,15 @@ func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n return n, err } if !n.IP.Equal(rn.IP) || n.UDP != rn.UDP || n.TCP != rn.TCP { - err = fmt.Errorf("metadata mismatch: got %v, want %v", rn, n) + if n.state == known { + // reject address change if node is known by us + err = fmt.Errorf("metadata mismatch: got %v, want %v", rn, n) + } else { + // accept otherwise; this will be handled nicer with signed ENRs + n.IP = rn.IP + n.UDP = rn.UDP + n.TCP = rn.TCP + } } return n, err } @@ -807,11 +821,10 @@ type nodeEvent uint //go:generate stringer -type=nodeEvent const ( - invalidEvent nodeEvent = iota // zero is reserved // Packet type events. // These correspond to packet types in the UDP protocol. - pingPacket + pingPacket = iota + 1 pongPacket findnodePacket neighborsPacket @@ -1090,14 +1103,14 @@ func (net *Network) ping(n *Node, addr *net.UDPAddr) { //fmt.Println(" not sent") return } - debugLog(fmt.Sprintf("ping(node = %x)", n.ID[:8])) + log.Trace("Pinging remote node", "node", n.ID) n.pingTopics = net.ticketStore.regTopicSet() n.pingEcho = net.conn.sendPing(n, addr, n.pingTopics) net.timedEvent(respTimeout, n, pongTimeout) } func (net *Network) handlePing(n *Node, pkt *ingressPacket) { - debugLog(fmt.Sprintf("handlePing(node = %x)", n.ID[:8])) + log.Trace("Handling remote ping", "node", n.ID) ping := pkt.data.(*ping) n.TCP = ping.From.TCP t := net.topictab.getTicket(n, ping.Topics) @@ -1112,7 +1125,7 @@ func (net *Network) handlePing(n *Node, pkt *ingressPacket) { } func (net *Network) handleKnownPong(n *Node, pkt *ingressPacket) error { - debugLog(fmt.Sprintf("handleKnownPong(node = %x)", n.ID[:8])) + log.Trace("Handling known pong", "node", n.ID) net.abortTimedEvent(n, pongTimeout) now := mclock.Now() ticket, err := pongToTicket(now, n.pingTopics, n, pkt) @@ -1120,9 +1133,8 @@ func (net *Network) handleKnownPong(n *Node, pkt *ingressPacket) error { // fmt.Printf("(%x) ticket: %+v\n", net.tab.self.ID[:8], pkt.data) net.ticketStore.addTicket(now, pkt.data.(*pong).ReplyTok, ticket) } else { - debugLog(fmt.Sprintf(" error: %v", err)) + log.Trace("Failed to convert pong to ticket", "err", err) } - n.pingEcho = nil n.pingTopics = nil return err @@ -1169,7 +1181,7 @@ func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket) // TODO: handle expiration topic := pkt.data.(*topicQuery).Topic results := net.topictab.getEntries(topic) - if _, ok := net.ticketStore.tickets[topic]; ok { + if _, ok := net.ticketStore.tickets.get(topic); ok { results = append(results, net.tab.self) // we're not registering in our own table but if we're advertising, return ourselves too } if len(results) > 10 { diff --git a/client/p2p/discv5/node.go b/client/p2p/discv5/node.go index a374add1e..b7aefc5d2 100644 --- a/client/p2p/discv5/node.go +++ b/client/p2p/discv5/node.go @@ -257,6 +257,11 @@ func (n NodeID) GoString() string { return fmt.Sprintf("discover.HexID(\"%x\")", n[:]) } +// TerminalString returns a shortened hex string for terminal logging. +func (n NodeID) TerminalString() string { + return hex.EncodeToString(n[:8]) +} + // HexID converts a hex string to a NodeID. // The string may be prefixed with 0x. func HexID(in string) (NodeID, error) { @@ -294,11 +299,11 @@ func PubkeyID(pub *ecdsa.PublicKey) NodeID { // Pubkey returns the public key represented by the node ID. // It returns an error if the ID is not a point on the curve. -func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) { +func (n NodeID) Pubkey() (*ecdsa.PublicKey, error) { p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)} - half := len(id) / 2 - p.X.SetBytes(id[:half]) - p.Y.SetBytes(id[half:]) + half := len(n) / 2 + p.X.SetBytes(n[:half]) + p.Y.SetBytes(n[half:]) if !p.Curve.IsOnCurve(p.X, p.Y) { return nil, errors.New("id is invalid secp256k1 curve point") } diff --git a/client/p2p/discv5/node_tickets.go b/client/p2p/discv5/node_tickets.go new file mode 100644 index 000000000..9fa4803de --- /dev/null +++ b/client/p2p/discv5/node_tickets.go @@ -0,0 +1,42 @@ +package discv5 + +import ( + "sync" + "github.com/kowala-tech/kcoin/client/log" +) + +type nodeTickets struct { + m map[*Node]*ticket + sync.RWMutex +} + +func newNodeTickets() *nodeTickets { + return &nodeTickets{m: make(map[*Node]*ticket)} +} + +func (r *nodeTickets) get(key *Node) (*ticket, bool) { + r.RLock() + v, ok := r.m[key] + r.RUnlock() + + switch ok { + case true: + log.Trace("Retrieving node ticket", "node", key.ID, "serial", v.serial) + case false: + log.Trace("Retrieving node ticket", "node", key.ID, "serial", nil) + } + + return v, ok +} + +func (r *nodeTickets) set(key *Node, v *ticket) { + r.Lock() + r.m[key] = v + r.Unlock() +} + +func (r *nodeTickets) delete(key *Node) { + r.Lock() + delete(r.m, key) + r.Unlock() +} diff --git a/client/p2p/discv5/nodeevent_string.go b/client/p2p/discv5/nodeevent_string.go index fde9045c5..38c1993ba 100644 --- a/client/p2p/discv5/nodeevent_string.go +++ b/client/p2p/discv5/nodeevent_string.go @@ -1,27 +1,17 @@ -// Code generated by "stringer -type nodeEvent"; DO NOT EDIT +// Code generated by "stringer -type=nodeEvent"; DO NOT EDIT. package discv5 -import "fmt" +import "strconv" -const ( - _nodeEvent_name_0 = "invalidEventpingPacketpongPacketfindnodePacketneighborsPacketfindnodeHashPackettopicRegisterPackettopicQueryPackettopicNodesPacket" - _nodeEvent_name_1 = "pongTimeoutpingTimeoutneighboursTimeout" -) +const _nodeEvent_name = "pongTimeoutpingTimeoutneighboursTimeout" -var ( - _nodeEvent_index_0 = [...]uint8{0, 12, 22, 32, 46, 61, 79, 98, 114, 130} - _nodeEvent_index_1 = [...]uint8{0, 11, 22, 39} -) +var _nodeEvent_index = [...]uint8{0, 11, 22, 39} func (i nodeEvent) String() string { - switch { - case 0 <= i && i <= 8: - return _nodeEvent_name_0[_nodeEvent_index_0[i]:_nodeEvent_index_0[i+1]] - case 265 <= i && i <= 267: - i -= 265 - return _nodeEvent_name_1[_nodeEvent_index_1[i]:_nodeEvent_index_1[i+1]] - default: - return fmt.Sprintf("nodeEvent(%d)", i) + i -= 264 + if i >= nodeEvent(len(_nodeEvent_index)-1) { + return "nodeEvent(" + strconv.FormatInt(int64(i+264), 10) + ")" } + return _nodeEvent_name[_nodeEvent_index[i]:_nodeEvent_index[i+1]] } diff --git a/client/p2p/discv5/ntp.go b/client/p2p/discv5/ntp.go index 4b62abb92..12b7c32e5 100644 --- a/client/p2p/discv5/ntp.go +++ b/client/p2p/discv5/ntp.go @@ -38,10 +38,10 @@ func checkClockDrift() { howtofix := fmt.Sprintf("Please enable network time synchronisation in system settings") separator := strings.Repeat("-", len(warning)) - log.Warn(fmt.Sprint(separator)) - log.Warn(fmt.Sprint(warning)) - log.Warn(fmt.Sprint(howtofix)) - log.Warn(fmt.Sprint(separator)) + log.Warn(separator) + log.Warn(warning) + log.Warn(howtofix) + log.Warn(separator) } else { log.Debug(fmt.Sprintf("Sanity NTP check reported %v drift, all ok", drift)) } diff --git a/client/p2p/discv5/radius.go b/client/p2p/discv5/radius.go new file mode 100644 index 000000000..ec280d865 --- /dev/null +++ b/client/p2p/discv5/radius.go @@ -0,0 +1,31 @@ +package discv5 + +import "sync" + +type radius struct { + m map[Topic]*topicRadius + sync.RWMutex +} + +func newRadius() *radius { + return &radius{m: make(map[Topic]*topicRadius)} +} + +func (r *radius) get(key Topic) (*topicRadius, bool) { + r.RLock() + v, ok := r.m[key] + r.RUnlock() + return v, ok +} + +func (r *radius) set(key Topic, v *topicRadius) { + r.Lock() + r.m[key] = v + r.Unlock() +} + +func (r *radius) delete(key Topic) { + r.Lock() + delete(r.m, key) + r.Unlock() +} diff --git a/client/p2p/discv5/request_info.go b/client/p2p/discv5/request_info.go new file mode 100644 index 000000000..d6c6b2508 --- /dev/null +++ b/client/p2p/discv5/request_info.go @@ -0,0 +1,31 @@ +package discv5 + +import "sync" + +type requestInfo struct { + m map[*Node]reqInfo + sync.RWMutex +} + +func newRequestInfo() *requestInfo { + return &requestInfo{m: make(map[*Node]reqInfo)} +} + +func (r *requestInfo) get(key *Node) (reqInfo, bool) { + r.RLock() + v, ok := r.m[key] + r.RUnlock() + return v, ok +} + +func (r *requestInfo) set(key *Node, req reqInfo) { + r.Lock() + r.m[key] = req + r.Unlock() +} + +func (r *requestInfo) delete(key *Node) { + r.Lock() + delete(r.m, key) + r.Unlock() +} diff --git a/client/p2p/discv5/request_set.go b/client/p2p/discv5/request_set.go new file mode 100644 index 000000000..a6769da66 --- /dev/null +++ b/client/p2p/discv5/request_set.go @@ -0,0 +1,38 @@ +package discv5 + +import "sync" + +type requestSet struct { + m map[Topic]struct{} + sync.RWMutex +} + +func newRequestSet() *requestSet { + return &requestSet{m: make(map[Topic]struct{})} +} + +func (r *requestSet) get(key Topic) bool { + r.RLock() + _, ok := r.m[key] + r.RUnlock() + return ok +} + +func (r *requestSet) set(key Topic) { + r.Lock() + r.m[key] = struct{}{} + r.Unlock() +} + +func (r *requestSet) len() int { + r.RLock() + l := len(r.m) + r.RUnlock() + return l +} + +func (r *requestSet) delete(key Topic) { + r.Lock() + delete(r.m, key) + r.Unlock() +} diff --git a/client/p2p/discv5/search_ticket.go b/client/p2p/discv5/search_ticket.go new file mode 100644 index 000000000..6d82854a6 --- /dev/null +++ b/client/p2p/discv5/search_ticket.go @@ -0,0 +1,45 @@ +package discv5 + +import "sync" + +type searchTopics struct { + m map[Topic]searchTopic + sync.RWMutex +} + +func newSearchTopics() *searchTopics { + return &searchTopics{m: make(map[Topic]searchTopic)} +} + +func (r *searchTopics) get(key Topic) (searchTopic, bool) { + r.RLock() + v, ok := r.m[key] + r.RUnlock() + return v, ok +} + +func (r *searchTopics) has(key Topic) bool { + r.RLock() + v, ok := r.m[key] + r.RUnlock() + return ok && v.foundChn != nil +} + +func (r *searchTopics) set(key Topic, v searchTopic) { + r.Lock() + r.m[key] = v + r.Unlock() +} + +func (r *searchTopics) len() int { + r.RLock() + l := len(r.m) + r.RUnlock() + return l +} + +func (r *searchTopics) delete(key Topic) { + r.Lock() + delete(r.m, key) + r.Unlock() +} diff --git a/client/p2p/discv5/table.go b/client/p2p/discv5/table.go index 43d70e55d..d220b841a 100644 --- a/client/p2p/discv5/table.go +++ b/client/p2p/discv5/table.go @@ -22,7 +22,6 @@ const ( hashBits = len(common.Hash{}) * 8 nBuckets = hashBits + 1 // Number of buckets - maxBondingPingPongs = 16 maxFindnodeFailures = 5 ) diff --git a/client/p2p/discv5/ticket.go b/client/p2p/discv5/ticket.go index 0a70ca47e..78d046b14 100644 --- a/client/p2p/discv5/ticket.go +++ b/client/p2p/discv5/ticket.go @@ -12,6 +12,7 @@ import ( "github.com/kowala-tech/kcoin/client/common" "github.com/kowala-tech/kcoin/client/common/mclock" "github.com/kowala-tech/kcoin/client/crypto" + "github.com/kowala-tech/kcoin/client/log" ) const ( @@ -107,21 +108,24 @@ func ticketToPong(t *ticket, pong *pong) { type ticketStore struct { // radius detector and target address generator // exists for both searched and registered topics - radius map[Topic]*topicRadius + radius *radius // Contains buckets (for each absolute minute) of tickets // that can be used in that minute. // This is only set if the topic is being registered. - tickets map[Topic]topicTickets - regtopics []Topic - nodes map[*Node]*ticket - nodeLastReq map[*Node]reqInfo + tickets *tickets + + regQueue []Topic // Topic registration queue for round robin attempts + regSet *requestSet // Topic registration queue contents for fast filling + + nodes *nodeTickets + nodeLastReq *requestInfo lastBucketFetched timeBucket nextTicketCached *ticketRef nextTicketReg mclock.AbsTime - searchTopicMap map[Topic]searchTopic + searchTopicMap *searchTopics nextTopicQueryCleanup mclock.AbsTime queriesSent map[*Node]map[common.Hash]sentQuery } @@ -136,91 +140,120 @@ type sentQuery struct { } type topicTickets struct { - buckets map[timeBucket][]ticketRef - nextLookup, nextReg mclock.AbsTime + buckets map[timeBucket][]ticketRef + nextLookup mclock.AbsTime + nextReg mclock.AbsTime } func newTicketStore() *ticketStore { return &ticketStore{ - radius: make(map[Topic]*topicRadius), - tickets: make(map[Topic]topicTickets), - nodes: make(map[*Node]*ticket), - nodeLastReq: make(map[*Node]reqInfo), - searchTopicMap: make(map[Topic]searchTopic), + radius: newRadius(), + tickets: newTickets(), + regSet: newRequestSet(), + nodes: newNodeTickets(), + nodeLastReq: newRequestInfo(), + searchTopicMap: newSearchTopics(), queriesSent: make(map[*Node]map[common.Hash]sentQuery), } } // addTopic starts tracking a topic. If register is true, // the local node will register the topic and tickets will be collected. -func (s *ticketStore) addTopic(t Topic, register bool) { - debugLog(fmt.Sprintf(" addTopic(%v, %v)", t, register)) - if s.radius[t] == nil { - s.radius[t] = newTopicRadius(t) +func (s *ticketStore) addTopic(topic Topic, register bool) { + log.Trace("Adding discovery topic", "topic", topic, "register", register) + if _, ok := s.radius.get(topic); !ok { + s.radius.set(topic, newTopicRadius(topic)) } - if register && s.tickets[t].buckets == nil { - s.tickets[t] = topicTickets{buckets: make(map[timeBucket][]ticketRef)} + _, ok := s.tickets.get(topic) + if register && !ok { + s.tickets.set(topic, &topicTickets{buckets: make(map[timeBucket][]ticketRef)}) } } func (s *ticketStore) addSearchTopic(t Topic, foundChn chan<- *Node) { s.addTopic(t, false) - if s.searchTopicMap[t].foundChn == nil { - s.searchTopicMap[t] = searchTopic{foundChn: foundChn} + if !s.searchTopicMap.has(t) { + s.searchTopicMap.set(t, searchTopic{foundChn: foundChn}) } } func (s *ticketStore) removeSearchTopic(t Topic) { - if st := s.searchTopicMap[t]; st.foundChn != nil { - delete(s.searchTopicMap, t) + if s.searchTopicMap.has(t) { + s.searchTopicMap.delete(t) } } // removeRegisterTopic deletes all tickets for the given topic. func (s *ticketStore) removeRegisterTopic(topic Topic) { - debugLog(fmt.Sprintf(" removeRegisterTopic(%v)", topic)) - for _, list := range s.tickets[topic].buckets { + log.Trace("Removing discovery topic", "topic", topic) + + ticket, ok := s.tickets.get(topic) + if !ok { + log.Warn("Removing non-existent discovery topic", "topic", topic) + return + } + for _, list := range ticket.buckets { for _, ref := range list { ref.t.refCnt-- if ref.t.refCnt == 0 { - delete(s.nodes, ref.t.node) - delete(s.nodeLastReq, ref.t.node) + s.nodes.delete(ref.t.node) + s.nodeLastReq.delete(ref.t.node) } } } - delete(s.tickets, topic) + s.tickets.delete(topic) } func (s *ticketStore) regTopicSet() []Topic { - topics := make([]Topic, 0, len(s.tickets)) - for topic := range s.tickets { + topics := make([]Topic, 0, s.tickets.len()) + + s.tickets.RLock() + for topic := range s.tickets.m { topics = append(topics, topic) } + s.tickets.RUnlock() return topics } // nextRegisterLookup returns the target of the next lookup for ticket collection. -func (s *ticketStore) nextRegisterLookup() (lookup lookupInfo, delay time.Duration) { - debugLog("nextRegisterLookup()") - firstTopic, ok := s.iterRegTopics() - for topic := firstTopic; ok; { - debugLog(fmt.Sprintf(" checking topic %v, len(s.tickets[topic]) = %d", topic, len(s.tickets[topic].buckets))) - if s.tickets[topic].buckets != nil && s.needMoreTickets(topic) { - next := s.radius[topic].nextTarget(false) - debugLog(fmt.Sprintf(" %x 1s", next.target[:8])) - return next, 100 * time.Millisecond +func (s *ticketStore) nextRegisterLookup() (lookupInfo, time.Duration) { + // Queue up any new topics (or discarded ones), preserving iteration order + s.tickets.RLock() + for topic := range s.tickets.m { + if ok := s.regSet.get(topic); !ok { + s.regQueue = append(s.regQueue, topic) + s.regSet.set(topic) } - topic, ok = s.iterRegTopics() - if topic == firstTopic { - break // We have checked all topics. + } + s.tickets.RUnlock() + + // Iterate over the set of all topics and look up the next suitable one + for len(s.regQueue) > 0 { + // Fetch the next topic from the queue, and ensure it still exists + topic := s.regQueue[0] + s.regQueue = s.regQueue[1:] + s.regSet.delete(topic) + + ticket, ok := s.tickets.get(topic) + if !ok { + continue + } + // If the topic needs more tickets, return it + if ticket.nextLookup < mclock.Now() { + rad, _ := s.radius.get(topic) + next, delay := rad.nextTarget(false), 100*time.Millisecond + log.Trace("Found discovery topic to register", "topic", topic, "target", next.target, "delay", delay) + return next, delay } } - debugLog(" null, 40s") - return lookupInfo{}, 40 * time.Second + // No registration topics found or all exhausted, sleep + delay := 40 * time.Second + log.Trace("No topic found to register", "delay", delay) + return lookupInfo{}, delay } func (s *ticketStore) nextSearchLookup(topic Topic) lookupInfo { - tr := s.radius[topic] + tr, _ := s.radius.get(topic) target := tr.nextTarget(tr.radiusLookupCnt >= searchForceQuery) if target.radiusLookup { tr.radiusLookupCnt++ @@ -230,40 +263,23 @@ func (s *ticketStore) nextSearchLookup(topic Topic) lookupInfo { return target } -// iterRegTopics returns topics to register in arbitrary order. -// The second return value is false if there are no topics. -func (s *ticketStore) iterRegTopics() (Topic, bool) { - debugLog("iterRegTopics()") - if len(s.regtopics) == 0 { - if len(s.tickets) == 0 { - debugLog(" false") - return "", false - } - // Refill register list. - for t := range s.tickets { - s.regtopics = append(s.regtopics, t) - } +// ticketsInWindow returns the tickets of a given topic in the registration window. +func (s *ticketStore) ticketsInWindow(topic Topic) []ticketRef { + // Sanity check that the topic still exists before operating on it + ticket, ok := s.tickets.get(topic) + if !ok { + log.Warn("Listing non-existing discovery tickets", "topic", topic) + return nil } - topic := s.regtopics[len(s.regtopics)-1] - s.regtopics = s.regtopics[:len(s.regtopics)-1] - debugLog(" " + string(topic) + " true") - return topic, true -} - -func (s *ticketStore) needMoreTickets(t Topic) bool { - return s.tickets[t].nextLookup < mclock.Now() -} + // Gather all the tickers in the next time window + var tickets []ticketRef -// ticketsInWindow returns the tickets of a given topic in the registration window. -func (s *ticketStore) ticketsInWindow(t Topic) []ticketRef { - ltBucket := s.lastBucketFetched - var res []ticketRef - tickets := s.tickets[t].buckets - for g := ltBucket; g < ltBucket+timeWindow; g++ { - res = append(res, tickets[g]...) + buckets := ticket.buckets + for idx := timeBucket(0); idx < timeWindow; idx++ { + tickets = append(tickets, buckets[s.lastBucketFetched+idx]...) } - debugLog(fmt.Sprintf("ticketsInWindow(%v) = %v", t, len(res))) - return res + log.Trace("Retrieved discovery registration tickets", "topic", topic, "from", s.lastBucketFetched, "tickets", len(tickets)) + return tickets } func (s *ticketStore) removeExcessTickets(t Topic) { @@ -284,8 +300,8 @@ func (s ticketRefByWaitTime) Len() int { return len(s) } -func (r ticketRef) waitTime() mclock.AbsTime { - return r.t.regTime[r.idx] - r.t.issueTime +func (ref ticketRef) waitTime() mclock.AbsTime { + return ref.t.regTime[ref.idx] - ref.t.issueTime } // Less reports whether the element with @@ -301,53 +317,57 @@ func (s ticketRefByWaitTime) Swap(i, j int) { func (s *ticketStore) addTicketRef(r ticketRef) { topic := r.t.topics[r.idx] - t := s.tickets[topic] - if t.buckets == nil { + tickets, ok := s.tickets.get(topic) + if !ok { + log.Warn("Adding ticket to non-existent topic", "topic", topic) return } bucket := timeBucket(r.t.regTime[r.idx] / mclock.AbsTime(ticketTimeBucketLen)) - t.buckets[bucket] = append(t.buckets[bucket], r) + tickets.buckets[bucket] = append(tickets.buckets[bucket], r) r.t.refCnt++ min := mclock.Now() - mclock.AbsTime(collectFrequency)*maxCollectDebt - if t.nextLookup < min { - t.nextLookup = min + if tickets.nextLookup < min { + tickets.nextLookup = min } - t.nextLookup += mclock.AbsTime(collectFrequency) - s.tickets[topic] = t + tickets.nextLookup += mclock.AbsTime(collectFrequency) //s.removeExcessTickets(topic) } -func (s *ticketStore) nextFilteredTicket() (t *ticketRef, wait time.Duration) { +func (s *ticketStore) nextFilteredTicket() (*ticketRef, time.Duration) { now := mclock.Now() for { - t, wait = s.nextRegisterableTicket() - if t == nil { - return + ticket, wait := s.nextRegisterableTicket() + if ticket == nil { + return ticket, wait } + log.Trace("Found discovery ticket to register", "node", ticket.t.node, "serial", ticket.t.serial, "wait", wait) + regTime := now + mclock.AbsTime(wait) - topic := t.t.topics[t.idx] - if regTime >= s.tickets[topic].nextReg { - return + topic := ticket.t.topics[ticket.idx] + + tic, ok := s.tickets.get(topic) + if ok && regTime >= tic.nextReg { + return ticket, wait } - s.removeTicketRef(*t) + s.removeTicketRef(*ticket) } } -func (s *ticketStore) ticketRegistered(t ticketRef) { +func (s *ticketStore) ticketRegistered(ref ticketRef) { now := mclock.Now() - topic := t.t.topics[t.idx] - tt := s.tickets[topic] + topic := ref.t.topics[ref.idx] + tickets, _ := s.tickets.get(topic) min := now - mclock.AbsTime(registerFrequency)*maxRegisterDebt - if min > tt.nextReg { - tt.nextReg = min + if min > tickets.nextReg { + tickets.nextReg = min } - tt.nextReg += mclock.AbsTime(registerFrequency) - s.tickets[topic] = tt + tickets.nextReg += mclock.AbsTime(registerFrequency) + s.tickets.set(topic, tickets) - s.removeTicketRef(t) + s.removeTicketRef(ref) } // nextRegisterableTicket returns the next ticket that can be used @@ -358,16 +378,7 @@ func (s *ticketStore) ticketRegistered(t ticketRef) { // // A ticket can be returned more than once with <= zero wait time in case // the ticket contains multiple topics. -func (s *ticketStore) nextRegisterableTicket() (t *ticketRef, wait time.Duration) { - defer func() { - if t == nil { - debugLog(" nil") - } else { - debugLog(fmt.Sprintf(" node = %x sn = %v wait = %v", t.t.node.ID[:8], t.t.serial, wait)) - } - }() - - debugLog("nextRegisterableTicket()") +func (s *ticketStore) nextRegisterableTicket() (*ticketRef, time.Duration) { now := mclock.Now() if s.nextTicketCached != nil { return s.nextTicketCached, time.Duration(s.nextTicketCached.topicRegTime() - now) @@ -378,27 +389,30 @@ func (s *ticketStore) nextRegisterableTicket() (t *ticketRef, wait time.Duration empty = true // true if there are no tickets nextTicket ticketRef // uninitialized if this bucket is empty ) - for _, tickets := range s.tickets { + + s.tickets.RLock() + for _, tickets := range s.tickets.m { //s.removeExcessTickets(topic) if len(tickets.buckets) != 0 { empty = false - if list := tickets.buckets[bucket]; list != nil { - for _, ref := range list { - //debugLog(fmt.Sprintf(" nrt bucket = %d node = %x sn = %v wait = %v", bucket, ref.t.node.ID[:8], ref.t.serial, time.Duration(ref.topicRegTime()-now))) - if nextTicket.t == nil || ref.topicRegTime() < nextTicket.topicRegTime() { - nextTicket = ref - } + + list := tickets.buckets[bucket] + for _, ref := range list { + //debugLog(fmt.Sprintf(" nrt bucket = %d node = %x sn = %v wait = %v", bucket, ref.t.node.ID[:8], ref.t.serial, time.Duration(ref.topicRegTime()-now))) + if nextTicket.t == nil || ref.topicRegTime() < nextTicket.topicRegTime() { + nextTicket = ref } } } } + s.tickets.RUnlock() + if empty { return nil, 0 } if nextTicket.t != nil { - wait = time.Duration(nextTicket.topicRegTime() - now) s.nextTicketCached = &nextTicket - return &nextTicket, wait + return &nextTicket, time.Duration(nextTicket.topicRegTime() - now) } s.lastBucketFetched = bucket } @@ -406,14 +420,21 @@ func (s *ticketStore) nextRegisterableTicket() (t *ticketRef, wait time.Duration // removeTicket removes a ticket from the ticket store func (s *ticketStore) removeTicketRef(ref ticketRef) { - debugLog(fmt.Sprintf("removeTicketRef(node = %x sn = %v)", ref.t.node.ID[:8], ref.t.serial)) + log.Trace("Removing discovery ticket reference", "node", ref.t.node.ID, "serial", ref.t.serial) + + // Make nextRegisterableTicket return the next available ticket. + s.nextTicketCached = nil + topic := ref.topic() - tickets := s.tickets[topic].buckets - if tickets == nil { + + tickets, ok := s.tickets.get(topic) + if !ok { + log.Trace("Removing tickets from unknown topic", "topic", topic) return } + bucket := timeBucket(ref.t.regTime[ref.idx] / mclock.AbsTime(ticketTimeBucketLen)) - list := tickets[bucket] + list := tickets.buckets[bucket] idx := -1 for i, bt := range list { if bt.t == ref.t { @@ -422,22 +443,19 @@ func (s *ticketStore) removeTicketRef(ref ticketRef) { } } if idx == -1 { - panic(nil) + panic("") } list = append(list[:idx], list[idx+1:]...) if len(list) != 0 { - tickets[bucket] = list + tickets.buckets[bucket] = list } else { - delete(tickets, bucket) + delete(tickets.buckets, bucket) } ref.t.refCnt-- if ref.t.refCnt == 0 { - delete(s.nodes, ref.t.node) - delete(s.nodeLastReq, ref.t.node) + s.nodes.delete(ref.t.node) + s.nodeLastReq.delete(ref.t.node) } - - // Make nextRegisterableTicket return the next available ticket. - s.nextTicketCached = nil } type lookupInfo struct { @@ -465,27 +483,29 @@ func (t *ticket) findIdx(topic Topic) int { func (s *ticketStore) registerLookupDone(lookup lookupInfo, nodes []*Node, ping func(n *Node) []byte) { now := mclock.Now() for i, n := range nodes { - if i == 0 || (binary.BigEndian.Uint64(n.sha[:8])^binary.BigEndian.Uint64(lookup.target[:8])) < s.radius[lookup.topic].minRadius { + rad, _ := s.radius.get(lookup.topic) + if i == 0 || (binary.BigEndian.Uint64(n.sha[:8])^binary.BigEndian.Uint64(lookup.target[:8])) < rad.minRadius { if lookup.radiusLookup { - if lastReq, ok := s.nodeLastReq[n]; !ok || time.Duration(now-lastReq.time) > radiusTC { - s.nodeLastReq[n] = reqInfo{pingHash: ping(n), lookup: lookup, time: now} + if lastReq, ok := s.nodeLastReq.get(n); !ok || time.Duration(now-lastReq.time) > radiusTC { + s.nodeLastReq.set(n, reqInfo{pingHash: ping(n), lookup: lookup, time: now}) } } else { - if s.nodes[n] == nil { - s.nodeLastReq[n] = reqInfo{pingHash: ping(n), lookup: lookup, time: now} + if _, ok := s.nodes.get(n); !ok { + s.nodeLastReq.set(n, reqInfo{pingHash: ping(n), lookup: lookup, time: now}) } } } } } -func (s *ticketStore) searchLookupDone(lookup lookupInfo, nodes []*Node, ping func(n *Node) []byte, query func(n *Node, topic Topic) []byte) { +func (s *ticketStore) searchLookupDone(lookup lookupInfo, nodes []*Node, query func(n *Node, topic Topic) []byte) { now := mclock.Now() for i, n := range nodes { - if i == 0 || (binary.BigEndian.Uint64(n.sha[:8])^binary.BigEndian.Uint64(lookup.target[:8])) < s.radius[lookup.topic].minRadius { + rad, _ := s.radius.get(lookup.topic) + if i == 0 || (binary.BigEndian.Uint64(n.sha[:8])^binary.BigEndian.Uint64(lookup.target[:8])) < rad.minRadius { if lookup.radiusLookup { - if lastReq, ok := s.nodeLastReq[n]; !ok || time.Duration(now-lastReq.time) > radiusTC { - s.nodeLastReq[n] = reqInfo{pingHash: ping(n), lookup: lookup, time: now} + if lastReq, ok := s.nodeLastReq.get(n); !ok || time.Duration(now-lastReq.time) > radiusTC { + s.nodeLastReq.set(n, reqInfo{pingHash: nil, lookup: lookup, time: now}) } } // else { if s.canQueryTopic(n, lookup.topic) { @@ -501,27 +521,28 @@ func (s *ticketStore) searchLookupDone(lookup lookupInfo, nodes []*Node, ping fu func (s *ticketStore) adjustWithTicket(now mclock.AbsTime, targetHash common.Hash, t *ticket) { for i, topic := range t.topics { - if tt, ok := s.radius[topic]; ok { + if tt, ok := s.radius.get(topic); ok { tt.adjustWithTicket(now, targetHash, ticketRef{t, i}) } } } -func (s *ticketStore) addTicket(localTime mclock.AbsTime, pingHash []byte, t *ticket) { - debugLog(fmt.Sprintf("add(node = %x sn = %v)", t.node.ID[:8], t.serial)) +func (s *ticketStore) addTicket(localTime mclock.AbsTime, pingHash []byte, ticket *ticket) { + log.Trace("Adding discovery ticket", "node", ticket.node.ID, "serial", ticket.serial) - lastReq, ok := s.nodeLastReq[t.node] + lastReq, ok := s.nodeLastReq.get(ticket.node) if !(ok && bytes.Equal(pingHash, lastReq.pingHash)) { return } - s.adjustWithTicket(localTime, lastReq.lookup.target, t) + s.adjustWithTicket(localTime, lastReq.lookup.target, ticket) - if lastReq.lookup.radiusLookup || s.nodes[t.node] != nil { + _, ok = s.nodes.get(ticket.node) + if lastReq.lookup.radiusLookup || ok { return } topic := lastReq.lookup.topic - topicIdx := t.findIdx(topic) + topicIdx := ticket.findIdx(topic) if topicIdx == -1 { return } @@ -531,32 +552,23 @@ func (s *ticketStore) addTicket(localTime mclock.AbsTime, pingHash []byte, t *ti s.lastBucketFetched = bucket } - if _, ok := s.tickets[topic]; ok { - wait := t.regTime[topicIdx] - localTime + if _, ok := s.tickets.get(topic); ok { + wait := ticket.regTime[topicIdx] - localTime rnd := rand.ExpFloat64() if rnd > 10 { rnd = 10 } if float64(wait) < float64(keepTicketConst)+float64(keepTicketExp)*rnd { // use the ticket to register this topic - //fmt.Println("addTicket", t.node.ID[:8], t.node.addr().String(), t.serial, t.pong) - s.addTicketRef(ticketRef{t, topicIdx}) + //fmt.Println("addTicket", ticket.node.ID[:8], ticket.node.addr().String(), ticket.serial, ticket.pong) + s.addTicketRef(ticketRef{ticket, topicIdx}) } } - if t.refCnt > 0 { + if ticket.refCnt > 0 { s.nextTicketCached = nil - s.nodes[t.node] = t - } -} - -func (s *ticketStore) getNodeTicket(node *Node) *ticket { - if s.nodes[node] == nil { - debugLog(fmt.Sprintf("getNodeTicket(%x) sn = nil", node.ID[:8])) - } else { - debugLog(fmt.Sprintf("getNodeTicket(%x) sn = %v", node.ID[:8], s.nodes[node].serial)) + s.nodes.set(ticket.node, ticket) } - return s.nodes[node] } func (s *ticketStore) canQueryTopic(node *Node, topic Topic) bool { @@ -616,20 +628,25 @@ func (s *ticketStore) gotTopicNodes(from *Node, hash common.Hash, nodes []rpcNod if len(nodes) > 0 { inside = 1 } - s.radius[q.lookup.topic].adjust(now, q.lookup.target, from.sha, inside) - chn := s.searchTopicMap[q.lookup.topic].foundChn - if chn == nil { + + rad, _ := s.radius.get(q.lookup.topic) + rad.adjust(now, q.lookup.target, from.sha, inside) + + if !s.searchTopicMap.has(q.lookup.topic) { //fmt.Println("no channel") return false } + + searchTopics, _ := s.searchTopicMap.get(q.lookup.topic) + for _, node := range nodes { ip := node.IP if ip.IsUnspecified() || ip.IsLoopback() { ip = from.IP } - n := NewNode(node.ID, ip, node.UDP-1, node.TCP-1) // subtract one from port while discv5 is running in test mode on UDPport+1 + n := NewNode(node.ID, ip, node.UDP, node.TCP) select { - case chn <- n: + case searchTopics.foundChn <- n: default: return false } diff --git a/client/p2p/discv5/tickets.go b/client/p2p/discv5/tickets.go new file mode 100644 index 000000000..33fae530c --- /dev/null +++ b/client/p2p/discv5/tickets.go @@ -0,0 +1,38 @@ +package discv5 + +import "sync" + +type tickets struct { + m map[Topic]*topicTickets + sync.RWMutex +} + +func newTickets() *tickets { + return &tickets{m: make(map[Topic]*topicTickets)} +} + +func (r *tickets) get(key Topic) (*topicTickets, bool) { + r.RLock() + v, ok := r.m[key] + r.RUnlock() + return v, ok +} + +func (r *tickets) set(key Topic, v *topicTickets) { + r.Lock() + r.m[key] = v + r.Unlock() +} + +func (r *tickets) len() int { + r.RLock() + l := len(r.m) + r.RUnlock() + return l +} + +func (r *tickets) delete(key Topic) { + r.Lock() + delete(r.m, key) + r.Unlock() +} diff --git a/client/p2p/discv5/topic.go b/client/p2p/discv5/topic.go index c69b98c84..0fb6ab683 100644 --- a/client/p2p/discv5/topic.go +++ b/client/p2p/discv5/topic.go @@ -9,6 +9,7 @@ import ( "time" "github.com/kowala-tech/kcoin/client/common" + "github.com/kowala-tech/kcoin/client/log" "github.com/kowala-tech/kcoin/client/common/mclock" ) @@ -221,7 +222,7 @@ func (t *topicTable) deleteEntry(e *topicEntry) { // It is assumed that topics and waitPeriods have the same length. func (t *topicTable) useTicket(node *Node, serialNo uint32, topics []Topic, idx int, issueTime uint64, waitPeriods []uint32) (registered bool) { - debugLog(fmt.Sprintf("useTicket %v %v %v", serialNo, topics, waitPeriods)) + log.Trace("Using discovery ticket", "serial", serialNo, "topics", topics, "waits", waitPeriods) //fmt.Println("useTicket", serialNo, topics, waitPeriods) t.collectGarbage() @@ -243,7 +244,7 @@ func (t *topicTable) useTicket(node *Node, serialNo uint32, topics []Topic, idx currTime := uint64(tm / mclock.AbsTime(time.Second)) regTime := issueTime + uint64(waitPeriods[idx]) relTime := int64(currTime - regTime) - if relTime <= regTimeWindow+1 { // give clients a little security margin on both ends + if relTime >= -1 && relTime <= regTimeWindow+1 { // give clients a little security margin on both ends if e := n.entries[topics[idx]]; e == nil { t.addEntry(node, topics[idx]) } else { @@ -256,15 +257,15 @@ func (t *topicTable) useTicket(node *Node, serialNo uint32, topics []Topic, idx return false } -func (topictab *topicTable) getTicket(node *Node, topics []Topic) *ticket { - topictab.collectGarbage() +func (t *topicTable) getTicket(node *Node, topics []Topic) *ticket { + t.collectGarbage() now := mclock.Now() - n := topictab.getOrNewNode(node) + n := t.getOrNewNode(node) n.lastIssuedTicket++ - topictab.storeTicketCounters(node) + t.storeTicketCounters(node) - t := &ticket{ + tic := &ticket{ issueTime: now, topics: topics, serial: n.lastIssuedTicket, @@ -272,15 +273,15 @@ func (topictab *topicTable) getTicket(node *Node, topics []Topic) *ticket { } for i, topic := range topics { var waitPeriod time.Duration - if topic := topictab.topics[topic]; topic != nil { + if topic := t.topics[topic]; topic != nil { waitPeriod = topic.wcl.waitPeriod } else { waitPeriod = minWaitPeriod } - t.regTime[i] = now + mclock.AbsTime(waitPeriod) + tic.regTime[i] = now + mclock.AbsTime(waitPeriod) } - return t + return tic } const gcInterval = time.Minute diff --git a/client/p2p/discv5/udp.go b/client/p2p/discv5/udp.go index eae72f80b..5c0015e51 100644 --- a/client/p2p/discv5/udp.go +++ b/client/p2p/discv5/udp.go @@ -20,25 +20,17 @@ const Version = 4 // Errors var ( - errPacketTooSmall = errors.New("too small") - errBadHash = errors.New("bad hash") - errExpired = errors.New("expired") - errUnsolicitedReply = errors.New("unsolicited reply") - errUnknownNode = errors.New("unknown node") - errTimeout = errors.New("RPC timeout") - errClockWarp = errors.New("reply deadline too far in the future") - errClosed = errors.New("socket closed") + errPacketTooSmall = errors.New("too small") + errBadPrefix = errors.New("bad prefix") + errTimeout = errors.New("RPC timeout") ) // Timeouts const ( respTimeout = 500 * time.Millisecond - sendTimeout = 500 * time.Millisecond expiration = 20 * time.Second - ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP - ntpWarningCooldown = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning - driftThreshold = 10 * time.Second // Allowed clock drift before warning user + driftThreshold = 10 * time.Second // Allowed clock drift before warning user ) // RPC request structures @@ -129,10 +121,11 @@ type ( } ) -const ( - macSize = 256 / 8 - sigSize = 520 / 8 - headSize = macSize + sigSize // space of packet frame data +var ( + versionPrefix = []byte("temporary discovery v5") + versionPrefixSize = len(versionPrefix) + sigSize = 520 / 8 + headSize = versionPrefixSize + sigSize // space of packet frame data ) // Neighbors replies are sent across multiple packets to @@ -221,15 +214,16 @@ type udp struct { } // ListenUDP returns a new table that listens for UDP packets on laddr. -func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) { - transport, err := listenUDP(priv, laddr) +func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) { + transport, err := listenUDP(priv, conn, realaddr) if err != nil { return nil, err } - net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict) + net, err := newNetwork(transport, priv.PublicKey, nodeDBPath, netrestrict) if err != nil { return nil, err } + log.Info("UDP listener up", "net", net.tab.self) transport.net = net go transport.readLoop() @@ -238,16 +232,8 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP return net, nil } -func listenUDP(priv *ecdsa.PrivateKey, laddr string) (*udp, error) { - addr, err := net.ResolveUDPAddr("udp", laddr) - if err != nil { - return nil, err - } - conn, err := net.ListenUDP("udp", addr) - if err != nil { - return nil, err - } - return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(addr, uint16(addr.Port))}, nil +func listenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr) (*udp, error) { + return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil } func (t *udp) localAddr() *net.UDPAddr { @@ -311,20 +297,20 @@ func (t *udp) sendTopicRegister(remote *Node, topics []Topic, idx int, pong []by func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) { p := topicNodes{Echo: queryHash} - if len(nodes) == 0 { - t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p) - return - } - for i, result := range nodes { - if netutil.CheckRelayIP(remote.IP, result.IP) != nil { - continue + var sent bool + for _, result := range nodes { + if result.IP.Equal(t.net.tab.self.IP) || netutil.CheckRelayIP(remote.IP, result.IP) == nil { + p.Nodes = append(p.Nodes, nodeToRPC(result)) } - p.Nodes = append(p.Nodes, nodeToRPC(result)) - if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 { + if len(p.Nodes) == maxTopicNodes { t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p) p.Nodes = p.Nodes[:0] + sent = true } } + if !sent || len(p.Nodes) > 0 { + t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p) + } } func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) { @@ -335,7 +321,7 @@ func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req inter return hash, err } log.Trace(fmt.Sprintf(">>> %v to %x@%v", nodeEvent(ptype), toid[:8], toaddr)) - if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil { + if _, err := t.conn.WriteToUDP(packet, toaddr); err != nil { log.Trace(fmt.Sprint("UDP send failed:", err)) } //fmt.Println(err) @@ -359,11 +345,9 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (p, hash log.Error(fmt.Sprint("could not sign packet:", err)) return nil, nil, err } - copy(packet[macSize:], sig) - // add the hash to the front. Note: this doesn't protect the - // packet in any way. - hash = crypto.Keccak256(packet[macSize:]) - copy(packet, hash) + copy(packet, versionPrefix) + copy(packet[versionPrefixSize:], sig) + hash = crypto.Keccak256(packet[versionPrefixSize:]) return packet, hash, nil } @@ -407,17 +391,16 @@ func decodePacket(buffer []byte, pkt *ingressPacket) error { } buf := make([]byte, len(buffer)) copy(buf, buffer) - hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] - shouldhash := crypto.Keccak256(buf[macSize:]) - if !bytes.Equal(hash, shouldhash) { - return errBadHash + prefix, sig, sigdata := buf[:versionPrefixSize], buf[versionPrefixSize:headSize], buf[headSize:] + if !bytes.Equal(prefix, versionPrefix) { + return errBadPrefix } fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig) if err != nil { return err } pkt.rawData = buf - pkt.hash = hash + pkt.hash = crypto.Keccak256(buf[versionPrefixSize:]) pkt.remoteID = fromID switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev { case pingPacket: diff --git a/client/p2p/enr/enr.go b/client/p2p/enr/enr.go new file mode 100644 index 000000000..e385c1f6e --- /dev/null +++ b/client/p2p/enr/enr.go @@ -0,0 +1,252 @@ +// Package enr supports the "secp256k1-keccak" identity scheme. +package enr + +import ( + "bytes" + "errors" + "fmt" + "io" + "sort" + + "github.com/kowala-tech/kcoin/client/rlp" +) + +const SizeLimit = 300 // maximum encoded size of a node record in bytes + +var ( + errNoID = errors.New("unknown or unspecified identity scheme") + errInvalidSig = errors.New("invalid signature") + errNotSorted = errors.New("record key/value pairs are not sorted by key") + errDuplicateKey = errors.New("record contains duplicate key") + errIncompletePair = errors.New("record contains incomplete k/v pair") + errTooBig = fmt.Errorf("record bigger than %d bytes", SizeLimit) + errEncodeUnsigned = errors.New("can't encode unsigned record") + errNotFound = errors.New("no such key in record") +) + +// Record represents a node record. The zero value is an empty record. +type Record struct { + seq uint64 // sequence number + signature []byte // the signature + raw []byte // RLP encoded record + pairs []pair // sorted list of all key/value pairs +} + +// pair is a key/value pair in a record. +type pair struct { + k string + v rlp.RawValue +} + +// Signed reports whether the record has a valid signature. +func (r *Record) Signed() bool { + return r.signature != nil +} + +// Seq returns the sequence number. +func (r *Record) Seq() uint64 { + return r.seq +} + +// SetSeq updates the record sequence number. This invalidates any signature on the record. +// Calling SetSeq is usually not required because setting any key in a signed record +// increments the sequence number. +func (r *Record) SetSeq(s uint64) { + r.signature = nil + r.raw = nil + r.seq = s +} + +// Load retrieves the value of a key/value pair. The given Entry must be a pointer and will +// be set to the value of the entry in the record. +// +// Errors returned by Load are wrapped in KeyError. You can distinguish decoding errors +// from missing keys using the IsNotFound function. +func (r *Record) Load(e Entry) error { + i := sort.Search(len(r.pairs), func(i int) bool { return r.pairs[i].k >= e.ENRKey() }) + if i < len(r.pairs) && r.pairs[i].k == e.ENRKey() { + if err := rlp.DecodeBytes(r.pairs[i].v, e); err != nil { + return &KeyError{Key: e.ENRKey(), Err: err} + } + return nil + } + return &KeyError{Key: e.ENRKey(), Err: errNotFound} +} + +// Set adds or updates the given entry in the record. It panics if the value can't be +// encoded. If the record is signed, Set increments the sequence number and invalidates +// the sequence number. +func (r *Record) Set(e Entry) { + blob, err := rlp.EncodeToBytes(e) + if err != nil { + panic(fmt.Errorf("enr: can't encode %s: %v", e.ENRKey(), err)) + } + r.invalidate() + + pairs := make([]pair, len(r.pairs)) + copy(pairs, r.pairs) + i := sort.Search(len(pairs), func(i int) bool { return pairs[i].k >= e.ENRKey() }) + switch { + case i < len(pairs) && pairs[i].k == e.ENRKey(): + // element is present at r.pairs[i] + pairs[i].v = blob + case i < len(r.pairs): + // insert pair before i-th elem + el := pair{e.ENRKey(), blob} + pairs = append(pairs, pair{}) + copy(pairs[i+1:], pairs[i:]) + pairs[i] = el + default: + // element should be placed at the end of r.pairs + pairs = append(pairs, pair{e.ENRKey(), blob}) + } + r.pairs = pairs +} + +func (r *Record) invalidate() { + if r.signature == nil { + r.seq++ + } + r.signature = nil + r.raw = nil +} + +// EncodeRLP implements rlp.Encoder. Encoding fails if +// the record is unsigned. +func (r Record) EncodeRLP(w io.Writer) error { + if !r.Signed() { + return errEncodeUnsigned + } + _, err := w.Write(r.raw) + return err +} + +// DecodeRLP implements rlp.Decoder. Decoding verifies the signature. +func (r *Record) DecodeRLP(s *rlp.Stream) error { + raw, err := s.Raw() + if err != nil { + return err + } + if len(raw) > SizeLimit { + return errTooBig + } + + // Decode the RLP container. + dec := Record{raw: raw} + s = rlp.NewStream(bytes.NewReader(raw), 0) + if _, err := s.List(); err != nil { + return err + } + if err = s.Decode(&dec.signature); err != nil { + return err + } + if err = s.Decode(&dec.seq); err != nil { + return err + } + // The rest of the record contains sorted k/v pairs. + var prevkey string + for i := 0; ; i++ { + var kv pair + if err := s.Decode(&kv.k); err != nil { + if err == rlp.EOL { + break + } + return err + } + if err := s.Decode(&kv.v); err != nil { + if err == rlp.EOL { + return errIncompletePair + } + return err + } + if i > 0 { + if kv.k == prevkey { + return errDuplicateKey + } + if kv.k < prevkey { + return errNotSorted + } + } + dec.pairs = append(dec.pairs, kv) + prevkey = kv.k + } + if err := s.ListEnd(); err != nil { + return err + } + + _, scheme := dec.idScheme() + if scheme == nil { + return errNoID + } + if err := scheme.Verify(&dec, dec.signature); err != nil { + return err + } + *r = dec + return nil +} + +// NodeAddr returns the node address. The return value will be nil if the record is +// unsigned or uses an unknown identity scheme. +func (r *Record) NodeAddr() []byte { + _, scheme := r.idScheme() + if scheme == nil { + return nil + } + return scheme.NodeAddr(r) +} + +// SetSig sets the record signature. It returns an error if the encoded record is larger +// than the size limit or if the signature is invalid according to the passed scheme. +func (r *Record) SetSig(idscheme string, sig []byte) error { + // Check that "id" is set and matches the given scheme. This panics because + // inconsitencies here are always implementation bugs in the signing function calling + // this method. + id, s := r.idScheme() + if s == nil { + panic(errNoID) + } + if id != idscheme { + panic(fmt.Errorf("identity scheme mismatch in Sign: record has %s, want %s", id, idscheme)) + } + + // Verify against the scheme. + if err := s.Verify(r, sig); err != nil { + return err + } + raw, err := r.encode(sig) + if err != nil { + return err + } + r.signature, r.raw = sig, raw + return nil +} + +// AppendElements appends the sequence number and entries to the given slice. +func (r *Record) AppendElements(list []interface{}) []interface{} { + list = append(list, r.seq) + for _, p := range r.pairs { + list = append(list, p.k, p.v) + } + return list +} + +func (r *Record) encode(sig []byte) (raw []byte, err error) { + list := make([]interface{}, 1, 2*len(r.pairs)+1) + list[0] = sig + list = r.AppendElements(list) + if raw, err = rlp.EncodeToBytes(list); err != nil { + return nil, err + } + if len(raw) > SizeLimit { + return nil, errTooBig + } + return raw, nil +} + +func (r *Record) idScheme() (string, IdentityScheme) { + var id ID + if err := r.Load(&id); err != nil { + return "", nil + } + return string(id), FindIdentityScheme(string(id)) +} diff --git a/client/p2p/enr/enr_test.go b/client/p2p/enr/enr_test.go new file mode 100644 index 000000000..ca518be3d --- /dev/null +++ b/client/p2p/enr/enr_test.go @@ -0,0 +1,302 @@ +package enr + +import ( + "bytes" + "encoding/hex" + "fmt" + "math/rand" + "testing" + "time" + + "github.com/kowala-tech/kcoin/client/crypto" + "github.com/kowala-tech/kcoin/client/rlp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + privkey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + pubkey = &privkey.PublicKey +) + +var rnd = rand.New(rand.NewSource(time.Now().UnixNano())) + +func randomString(strlen int) string { + b := make([]byte, strlen) + rnd.Read(b) + return string(b) +} + +// TestGetSetID tests encoding/decoding and setting/getting of the ID key. +func TestGetSetID(t *testing.T) { + id := ID("someid") + var r Record + r.Set(id) + + var id2 ID + require.NoError(t, r.Load(&id2)) + assert.Equal(t, id, id2) +} + +// TestGetSetIP4 tests encoding/decoding and setting/getting of the IP key. +func TestGetSetIP4(t *testing.T) { + ip := IP{192, 168, 0, 3} + var r Record + r.Set(ip) + + var ip2 IP + require.NoError(t, r.Load(&ip2)) + assert.Equal(t, ip, ip2) +} + +// TestGetSetIP6 tests encoding/decoding and setting/getting of the IP key. +func TestGetSetIP6(t *testing.T) { + ip := IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68} + var r Record + r.Set(ip) + + var ip2 IP + require.NoError(t, r.Load(&ip2)) + assert.Equal(t, ip, ip2) +} + +// TestGetSetDiscPort tests encoding/decoding and setting/getting of the DiscPort key. +func TestGetSetUDP(t *testing.T) { + port := UDP(30309) + var r Record + r.Set(port) + + var port2 UDP + require.NoError(t, r.Load(&port2)) + assert.Equal(t, port, port2) +} + +// TestGetSetSecp256k1 tests encoding/decoding and setting/getting of the Secp256k1 key. +func TestGetSetSecp256k1(t *testing.T) { + var r Record + if err := SignV4(&r, privkey); err != nil { + t.Fatal(err) + } + + var pk Secp256k1 + require.NoError(t, r.Load(&pk)) + assert.EqualValues(t, pubkey, &pk) +} + +func TestLoadErrors(t *testing.T) { + var r Record + ip4 := IP{127, 0, 0, 1} + r.Set(ip4) + + // Check error for missing keys. + var udp UDP + err := r.Load(&udp) + if !IsNotFound(err) { + t.Error("IsNotFound should return true for missing key") + } + assert.Equal(t, &KeyError{Key: udp.ENRKey(), Err: errNotFound}, err) + + // Check error for invalid keys. + var list []uint + err = r.Load(WithEntry(ip4.ENRKey(), &list)) + kerr, ok := err.(*KeyError) + if !ok { + t.Fatalf("expected KeyError, got %T", err) + } + assert.Equal(t, kerr.Key, ip4.ENRKey()) + assert.Error(t, kerr.Err) + if IsNotFound(err) { + t.Error("IsNotFound should return false for decoding errors") + } +} + +// TestSortedGetAndSet tests that Set produced a sorted pairs slice. +func TestSortedGetAndSet(t *testing.T) { + type pair struct { + k string + v uint32 + } + + for _, tt := range []struct { + input []pair + want []pair + }{ + { + input: []pair{{"a", 1}, {"c", 2}, {"b", 3}}, + want: []pair{{"a", 1}, {"b", 3}, {"c", 2}}, + }, + { + input: []pair{{"a", 1}, {"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}}, + want: []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}}, + }, + { + input: []pair{{"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}}, + want: []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}}, + }, + } { + var r Record + for _, i := range tt.input { + r.Set(WithEntry(i.k, &i.v)) + } + for i, w := range tt.want { + // set got's key from r.pair[i], so that we preserve order of pairs + got := pair{k: r.pairs[i].k} + assert.NoError(t, r.Load(WithEntry(w.k, &got.v))) + assert.Equal(t, w, got) + } + } +} + +// TestDirty tests record signature removal on setting of new key/value pair in record. +func TestDirty(t *testing.T) { + var r Record + + if r.Signed() { + t.Error("Signed returned true for zero record") + } + if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned { + t.Errorf("expected errEncodeUnsigned, got %#v", err) + } + + require.NoError(t, SignV4(&r, privkey)) + if !r.Signed() { + t.Error("Signed return false for signed record") + } + _, err := rlp.EncodeToBytes(r) + assert.NoError(t, err) + + r.SetSeq(3) + if r.Signed() { + t.Error("Signed returned true for modified record") + } + if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned { + t.Errorf("expected errEncodeUnsigned, got %#v", err) + } +} + +// TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record. +func TestGetSetOverwrite(t *testing.T) { + var r Record + + ip := IP{192, 168, 0, 3} + r.Set(ip) + + ip2 := IP{192, 168, 0, 4} + r.Set(ip2) + + var ip3 IP + require.NoError(t, r.Load(&ip3)) + assert.Equal(t, ip2, ip3) +} + +// TestSignEncodeAndDecode tests signing, RLP encoding and RLP decoding of a record. +func TestSignEncodeAndDecode(t *testing.T) { + var r Record + r.Set(UDP(30303)) + r.Set(IP{127, 0, 0, 1}) + require.NoError(t, SignV4(&r, privkey)) + + blob, err := rlp.EncodeToBytes(r) + require.NoError(t, err) + + var r2 Record + require.NoError(t, rlp.DecodeBytes(blob, &r2)) + assert.Equal(t, r, r2) + + blob2, err := rlp.EncodeToBytes(r2) + require.NoError(t, err) + assert.Equal(t, blob, blob2) +} + +func TestNodeAddr(t *testing.T) { + var r Record + if addr := r.NodeAddr(); addr != nil { + t.Errorf("wrong address on empty record: got %v, want %v", addr, nil) + } + + require.NoError(t, SignV4(&r, privkey)) + expected := "a448f24c6d18e575453db13171562b71999873db5b286df957af199ec94617f7" + assert.Equal(t, expected, hex.EncodeToString(r.NodeAddr())) +} + +var pyRecord, _ = hex.DecodeString("f884b8407098ad865b00a582051940cb9cf36836572411a47278783077011599ed5cd16b76f2635f4e234738f30813a89eb9137e3e3df5266e3a1f11df72ecf1145ccb9c01826964827634826970847f00000189736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd31388375647082765f") + +// TestPythonInterop checks that we can decode and verify a record produced by the Python +// implementation. +func TestPythonInterop(t *testing.T) { + var r Record + if err := rlp.DecodeBytes(pyRecord, &r); err != nil { + t.Fatalf("can't decode: %v", err) + } + + var ( + wantAddr, _ = hex.DecodeString("a448f24c6d18e575453db13171562b71999873db5b286df957af199ec94617f7") + wantSeq = uint64(1) + wantIP = IP{127, 0, 0, 1} + wantUDP = UDP(30303) + ) + if r.Seq() != wantSeq { + t.Errorf("wrong seq: got %d, want %d", r.Seq(), wantSeq) + } + if addr := r.NodeAddr(); !bytes.Equal(addr, wantAddr) { + t.Errorf("wrong addr: got %x, want %x", addr, wantAddr) + } + want := map[Entry]interface{}{new(IP): &wantIP, new(UDP): &wantUDP} + for k, v := range want { + desc := fmt.Sprintf("loading key %q", k.ENRKey()) + if assert.NoError(t, r.Load(k), desc) { + assert.Equal(t, k, v, desc) + } + } +} + +// TestRecordTooBig tests that records bigger than SizeLimit bytes cannot be signed. +func TestRecordTooBig(t *testing.T) { + var r Record + key := randomString(10) + + // set a big value for random key, expect error + r.Set(WithEntry(key, randomString(SizeLimit))) + if err := SignV4(&r, privkey); err != errTooBig { + t.Fatalf("expected to get errTooBig, got %#v", err) + } + + // set an acceptable value for random key, expect no error + r.Set(WithEntry(key, randomString(100))) + require.NoError(t, SignV4(&r, privkey)) +} + +// TestSignEncodeAndDecodeRandom tests encoding/decoding of records containing random key/value pairs. +func TestSignEncodeAndDecodeRandom(t *testing.T) { + var r Record + + // random key/value pairs for testing + pairs := map[string]uint32{} + for i := 0; i < 10; i++ { + key := randomString(7) + value := rnd.Uint32() + pairs[key] = value + r.Set(WithEntry(key, &value)) + } + + require.NoError(t, SignV4(&r, privkey)) + _, err := rlp.EncodeToBytes(r) + require.NoError(t, err) + + for k, v := range pairs { + desc := fmt.Sprintf("key %q", k) + var got uint32 + buf := WithEntry(k, &got) + require.NoError(t, r.Load(buf), desc) + require.Equal(t, v, got, desc) + } +} + +func BenchmarkDecode(b *testing.B) { + var r Record + for i := 0; i < b.N; i++ { + rlp.DecodeBytes(pyRecord, &r) + } + b.StopTimer() + r.NodeAddr() +} diff --git a/client/p2p/enr/entries.go b/client/p2p/enr/entries.go new file mode 100644 index 000000000..da7d72bdf --- /dev/null +++ b/client/p2p/enr/entries.go @@ -0,0 +1,128 @@ +package enr + +import ( + "crypto/ecdsa" + "fmt" + "io" + "net" + + "github.com/kowala-tech/kcoin/client/crypto" + "github.com/kowala-tech/kcoin/client/rlp" +) + +// Entry is implemented by known node record entry types. +// +// To define a new entry that is to be included in a node record, +// create a Go type that satisfies this interface. The type should +// also implement rlp.Decoder if additional checks are needed on the value. +type Entry interface { + ENRKey() string +} + +type generic struct { + key string + value interface{} +} + +func (g generic) ENRKey() string { return g.key } + +func (g generic) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, g.value) +} + +func (g *generic) DecodeRLP(s *rlp.Stream) error { + return s.Decode(g.value) +} + +// WithEntry wraps any value with a key name. It can be used to set and load arbitrary values +// in a record. The value v must be supported by rlp. To use WithEntry with Load, the value +// must be a pointer. +func WithEntry(k string, v interface{}) Entry { + return &generic{key: k, value: v} +} + +// TCP is the "tcp" key, which holds the TCP port of the node. +type TCP uint16 + +func (v TCP) ENRKey() string { return "tcp" } + +// UDP is the "udp" key, which holds the UDP port of the node. +type UDP uint16 + +func (v UDP) ENRKey() string { return "udp" } + +// ID is the "id" key, which holds the name of the identity scheme. +type ID string + +const IDv4 = ID("v4") // the default identity scheme + +func (v ID) ENRKey() string { return "id" } + +// IP is the "ip" key, which holds the IP address of the node. +type IP net.IP + +func (v IP) ENRKey() string { return "ip" } + +// EncodeRLP implements rlp.Encoder. +func (v IP) EncodeRLP(w io.Writer) error { + if ip4 := net.IP(v).To4(); ip4 != nil { + return rlp.Encode(w, ip4) + } + return rlp.Encode(w, net.IP(v)) +} + +// DecodeRLP implements rlp.Decoder. +func (v *IP) DecodeRLP(s *rlp.Stream) error { + if err := s.Decode((*net.IP)(v)); err != nil { + return err + } + if len(*v) != 4 && len(*v) != 16 { + return fmt.Errorf("invalid IP address, want 4 or 16 bytes: %v", *v) + } + return nil +} + +// Secp256k1 is the "secp256k1" key, which holds a public key. +type Secp256k1 ecdsa.PublicKey + +func (v Secp256k1) ENRKey() string { return "secp256k1" } + +// EncodeRLP implements rlp.Encoder. +func (v Secp256k1) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, crypto.CompressPubkey((*ecdsa.PublicKey)(&v))) +} + +// DecodeRLP implements rlp.Decoder. +func (v *Secp256k1) DecodeRLP(s *rlp.Stream) error { + buf, err := s.Bytes() + if err != nil { + return err + } + pk, err := crypto.DecompressPubkey(buf) + if err != nil { + return err + } + *v = (Secp256k1)(*pk) + return nil +} + +// KeyError is an error related to a key. +type KeyError struct { + Key string + Err error +} + +// Error implements error. +func (err *KeyError) Error() string { + if err.Err == errNotFound { + return fmt.Sprintf("missing ENR key %q", err.Key) + } + return fmt.Sprintf("ENR key %q: %v", err.Key, err.Err) +} + +// IsNotFound reports whether the given error means that a key/value pair is +// missing from a record. +func IsNotFound(err error) bool { + kerr, ok := err.(*KeyError) + return ok && kerr.Err == errNotFound +} diff --git a/client/p2p/enr/idscheme.go b/client/p2p/enr/idscheme.go new file mode 100644 index 000000000..1c2b55907 --- /dev/null +++ b/client/p2p/enr/idscheme.go @@ -0,0 +1,98 @@ +package enr + +import ( + "crypto/ecdsa" + "fmt" + "sync" + + "github.com/kowala-tech/kcoin/client/common/math" + "github.com/kowala-tech/kcoin/client/crypto" + "github.com/kowala-tech/kcoin/client/crypto/sha3" + "github.com/kowala-tech/kcoin/client/rlp" +) + +// Registry of known identity schemes. +var schemes sync.Map + +// An IdentityScheme is capable of verifying record signatures and +// deriving node addresses. +type IdentityScheme interface { + Verify(r *Record, sig []byte) error + NodeAddr(r *Record) []byte +} + +// RegisterIdentityScheme adds an identity scheme to the global registry. +func RegisterIdentityScheme(name string, scheme IdentityScheme) { + if _, loaded := schemes.LoadOrStore(name, scheme); loaded { + panic("identity scheme " + name + " already registered") + } +} + +// FindIdentityScheme resolves name to an identity scheme in the global registry. +func FindIdentityScheme(name string) IdentityScheme { + s, ok := schemes.Load(name) + if !ok { + return nil + } + return s.(IdentityScheme) +} + +// v4ID is the "v4" identity scheme. +type v4ID struct{} + +func init() { + RegisterIdentityScheme("v4", v4ID{}) +} + +// SignV4 signs a record using the v4 scheme. +func SignV4(r *Record, privkey *ecdsa.PrivateKey) error { + // Copy r to avoid modifying it if signing fails. + cpy := *r + cpy.Set(ID("v4")) + cpy.Set(Secp256k1(privkey.PublicKey)) + + h := sha3.NewKeccak256() + rlp.Encode(h, cpy.AppendElements(nil)) + sig, err := crypto.Sign(h.Sum(nil), privkey) + if err != nil { + return err + } + sig = sig[:len(sig)-1] // remove v + if err = cpy.SetSig("v4", sig); err == nil { + *r = cpy + } + return err +} + +// s256raw is an unparsed secp256k1 public key entry. +type s256raw []byte + +func (s256raw) ENRKey() string { return "secp256k1" } + +func (v4ID) Verify(r *Record, sig []byte) error { + var entry s256raw + if err := r.Load(&entry); err != nil { + return err + } else if len(entry) != 33 { + return fmt.Errorf("invalid public key") + } + + h := sha3.NewKeccak256() + rlp.Encode(h, r.AppendElements(nil)) + if !crypto.VerifySignature(entry, h.Sum(nil), sig) { + return errInvalidSig + } + return nil +} + +func (v4ID) NodeAddr(r *Record) []byte { + var pubkey Secp256k1 + err := r.Load(&pubkey) + if err != nil { + return nil + } + buf := make([]byte, 64) + math.ReadBits(pubkey.X, buf[:32]) + math.ReadBits(pubkey.Y, buf[32:]) + return crypto.Keccak256(buf) +} diff --git a/client/p2p/enr/idscheme_test.go b/client/p2p/enr/idscheme_test.go new file mode 100644 index 000000000..00c180a7d --- /dev/null +++ b/client/p2p/enr/idscheme_test.go @@ -0,0 +1,20 @@ +package enr + +import ( + "crypto/ecdsa" + "math/big" + "testing" +) + +// Checks that failure to sign leaves the record unmodified. +func TestSignError(t *testing.T) { + invalidKey := &ecdsa.PrivateKey{D: new(big.Int), PublicKey: *pubkey} + + var r Record + if err := SignV4(&r, invalidKey); err == nil { + t.Fatal("expected error from SignV4") + } + if len(r.pairs) > 0 { + t.Fatal("expected empty record, have", r.pairs) + } +} diff --git a/client/p2p/message.go b/client/p2p/message.go index ae5443dbc..05ed96648 100644 --- a/client/p2p/message.go +++ b/client/p2p/message.go @@ -6,8 +6,6 @@ import ( "fmt" "io" "io/ioutil" - "net" - "sync" "sync/atomic" "time" @@ -96,30 +94,6 @@ func SendItems(w MsgWriter, msgcode uint64, elems ...interface{}) error { return Send(w, msgcode, elems) } -// netWrapper wraps a MsgReadWriter with locks around -// ReadMsg/WriteMsg and applies read/write deadlines. -type netWrapper struct { - rmu, wmu sync.Mutex - - rtimeout, wtimeout time.Duration - conn net.Conn - wrapped MsgReadWriter -} - -func (rw *netWrapper) ReadMsg() (Msg, error) { - rw.rmu.Lock() - defer rw.rmu.Unlock() - rw.conn.SetReadDeadline(time.Now().Add(rw.rtimeout)) - return rw.wrapped.ReadMsg() -} - -func (rw *netWrapper) WriteMsg(msg Msg) error { - rw.wmu.Lock() - defer rw.wmu.Unlock() - rw.conn.SetWriteDeadline(time.Now().Add(rw.wtimeout)) - return rw.wrapped.WriteMsg(msg) -} - // eofSignal wraps a reader with eof signaling. the eof channel is // closed when the wrapped reader returns an error or when count bytes // have been read. @@ -239,21 +213,20 @@ func ExpectMsg(r MsgReader, code uint64, content interface{}) error { } if content == nil { return msg.Discard() - } else { - contentEnc, err := rlp.EncodeToBytes(content) - if err != nil { - panic("content encode error: " + err.Error()) - } - if int(msg.Size) != len(contentEnc) { - return fmt.Errorf("message size mismatch: got %d, want %d", msg.Size, len(contentEnc)) - } - actualContent, err := ioutil.ReadAll(msg.Payload) - if err != nil { - return err - } - if !bytes.Equal(actualContent, contentEnc) { - return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc) - } + } + contentEnc, err := rlp.EncodeToBytes(content) + if err != nil { + panic("content encode error: " + err.Error()) + } + if int(msg.Size) != len(contentEnc) { + return fmt.Errorf("message size mismatch: got %d, want %d", msg.Size, len(contentEnc)) + } + actualContent, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return err + } + if !bytes.Equal(actualContent, contentEnc) { + return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc) } return nil } @@ -281,15 +254,15 @@ func newMsgEventer(rw MsgReadWriter, feed *event.Feed, peerID discover.NodeID, p // ReadMsg reads a message from the underlying MsgReadWriter and emits a // "message received" event -func (self *msgEventer) ReadMsg() (Msg, error) { - msg, err := self.MsgReadWriter.ReadMsg() +func (ev *msgEventer) ReadMsg() (Msg, error) { + msg, err := ev.MsgReadWriter.ReadMsg() if err != nil { return msg, err } - self.feed.Send(&PeerEvent{ + ev.feed.Send(&PeerEvent{ Type: PeerEventTypeMsgRecv, - Peer: self.peerID, - Protocol: self.Protocol, + Peer: ev.peerID, + Protocol: ev.Protocol, MsgCode: &msg.Code, MsgSize: &msg.Size, }) @@ -298,15 +271,15 @@ func (self *msgEventer) ReadMsg() (Msg, error) { // WriteMsg writes a message to the underlying MsgReadWriter and emits a // "message sent" event -func (self *msgEventer) WriteMsg(msg Msg) error { - err := self.MsgReadWriter.WriteMsg(msg) +func (ev *msgEventer) WriteMsg(msg Msg) error { + err := ev.MsgReadWriter.WriteMsg(msg) if err != nil { return err } - self.feed.Send(&PeerEvent{ + ev.feed.Send(&PeerEvent{ Type: PeerEventTypeMsgSend, - Peer: self.peerID, - Protocol: self.Protocol, + Peer: ev.peerID, + Protocol: ev.Protocol, MsgCode: &msg.Code, MsgSize: &msg.Size, }) @@ -315,8 +288,8 @@ func (self *msgEventer) WriteMsg(msg Msg) error { // Close closes the underlying MsgReadWriter if it implements the io.Closer // interface -func (self *msgEventer) Close() error { - if v, ok := self.MsgReadWriter.(io.Closer); ok { +func (ev *msgEventer) Close() error { + if v, ok := ev.MsgReadWriter.(io.Closer); ok { return v.Close() } return nil diff --git a/client/p2p/netutil/net.go b/client/p2p/netutil/net.go index 1b0e79ebd..6d1ced744 100644 --- a/client/p2p/netutil/net.go +++ b/client/p2p/netutil/net.go @@ -2,8 +2,11 @@ package netutil import ( + "bytes" "errors" + "fmt" "net" + "sort" "strings" ) @@ -173,3 +176,131 @@ func CheckRelayIP(sender, addr net.IP) error { } return nil } + +// SameNet reports whether two IP addresses have an equal prefix of the given bit length. +func SameNet(bits uint, ip, other net.IP) bool { + ip4, other4 := ip.To4(), other.To4() + switch { + case (ip4 == nil) != (other4 == nil): + return false + case ip4 != nil: + return sameNet(bits, ip4, other4) + default: + return sameNet(bits, ip.To16(), other.To16()) + } +} + +func sameNet(bits uint, ip, other net.IP) bool { + nb := int(bits / 8) + mask := ^byte(0xFF >> (bits % 8)) + if mask != 0 && nb < len(ip) && ip[nb]&mask != other[nb]&mask { + return false + } + return nb <= len(ip) && bytes.Equal(ip[:nb], other[:nb]) +} + +// DistinctNetSet tracks IPs, ensuring that at most N of them +// fall into the same network range. +type DistinctNetSet struct { + Subnet uint // number of common prefix bits + Limit uint // maximum number of IPs in each subnet + + members map[string]uint + buf net.IP +} + +// Add adds an IP address to the set. It returns false (and doesn't add the IP) if the +// number of existing IPs in the defined range exceeds the limit. +func (s *DistinctNetSet) Add(ip net.IP) bool { + key := s.key(ip) + n := s.members[string(key)] + if n < s.Limit { + s.members[string(key)] = n + 1 + return true + } + return false +} + +// Remove removes an IP from the set. +func (s *DistinctNetSet) Remove(ip net.IP) { + key := s.key(ip) + if n, ok := s.members[string(key)]; ok { + if n == 1 { + delete(s.members, string(key)) + } else { + s.members[string(key)] = n - 1 + } + } +} + +// Contains whether the given IP is contained in the set. +func (s DistinctNetSet) Contains(ip net.IP) bool { + key := s.key(ip) + _, ok := s.members[string(key)] + return ok +} + +// Len returns the number of tracked IPs. +func (s DistinctNetSet) Len() int { + n := uint(0) + for _, i := range s.members { + n += i + } + return int(n) +} + +// key encodes the map key for an address into a temporary buffer. +// +// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types. +// The remainder of the key is the IP, truncated to the number of bits. +func (s *DistinctNetSet) key(ip net.IP) net.IP { + // Lazily initialize storage. + if s.members == nil { + s.members = make(map[string]uint) + s.buf = make(net.IP, 17) + } + // Canonicalize ip and bits. + typ := byte('6') + if ip4 := ip.To4(); ip4 != nil { + typ, ip = '4', ip4 + } + bits := s.Subnet + if bits > uint(len(ip)*8) { + bits = uint(len(ip) * 8) + } + // Encode the prefix into s.buf. + nb := int(bits / 8) + mask := ^byte(0xFF >> (bits % 8)) + s.buf[0] = typ + buf := append(s.buf[:1], ip[:nb]...) + if nb < len(ip) && mask != 0 { + buf = append(buf, ip[nb]&mask) + } + return buf +} + +// String implements fmt.Stringer +func (s DistinctNetSet) String() string { + var buf bytes.Buffer + buf.WriteString("{") + keys := make([]string, 0, len(s.members)) + for k := range s.members { + keys = append(keys, k) + } + sort.Strings(keys) + for i, k := range keys { + var ip net.IP + if k[0] == '4' { + ip = make(net.IP, 4) + } else { + ip = make(net.IP, 16) + } + copy(ip, k[1:]) + fmt.Fprintf(&buf, "%v×%d", ip, s.members[k]) + if i != len(keys)-1 { + buf.WriteString(" ") + } + } + buf.WriteString("}") + return buf.String() +} diff --git a/client/p2p/peer.go b/client/p2p/peer.go index 2c42f3cc6..c44adc20e 100644 --- a/client/p2p/peer.go +++ b/client/p2p/peer.go @@ -31,8 +31,6 @@ const ( discMsg = 0x01 pingMsg = 0x02 pongMsg = 0x03 - getPeersMsg = 0x04 - peersMsg = 0x05 ) // protoHandshake is the RLP structure of the protocol handshake. @@ -144,6 +142,11 @@ func (p *Peer) String() string { return fmt.Sprintf("Peer %x %v", p.rw.id[:8], p.RemoteAddr()) } +// Inbound returns true if the peer is an inbound connection +func (p *Peer) Inbound() bool { + return p.rw.flags&inboundConn != 0 +} + func newPeer(conn *conn, protocols []Protocol) *Peer { protomap := matchProtocols(protocols, conn.caps, conn) p := &Peer{ @@ -201,6 +204,7 @@ loop: reason = discReasonForError(err) break loop case err = <-p.disc: + reason = discReasonForError(err) break loop } } @@ -398,6 +402,9 @@ type PeerInfo struct { Network struct { LocalAddress string `json:"localAddress"` // Local endpoint of the TCP data connection RemoteAddress string `json:"remoteAddress"` // Remote endpoint of the TCP data connection + Inbound bool `json:"inbound"` + Trusted bool `json:"trusted"` + Static bool `json:"static"` } `json:"network"` Protocols map[string]interface{} `json:"protocols"` // Sub-protocol specific metadata fields } @@ -418,6 +425,9 @@ func (p *Peer) Info() *PeerInfo { } info.Network.LocalAddress = p.LocalAddr().String() info.Network.RemoteAddress = p.RemoteAddr().String() + info.Network.Inbound = p.rw.is(inboundConn) + info.Network.Trusted = p.rw.is(trustedConn) + info.Network.Static = p.rw.is(staticDialedConn) // Gather all the running protocol infos for _, proto := range p.running { diff --git a/client/p2p/peer_error.go b/client/p2p/peer_error.go index 8d9134d12..686ac8bc5 100644 --- a/client/p2p/peer_error.go +++ b/client/p2p/peer_error.go @@ -32,8 +32,8 @@ func newPeerError(code int, format string, v ...interface{}) *peerError { return err } -func (self *peerError) Error() string { - return self.message +func (pe *peerError) Error() string { + return pe.message } var errProtocolReturned = errors.New("protocol returned") diff --git a/client/p2p/protocols/protocol.go b/client/p2p/protocols/protocol.go new file mode 100644 index 000000000..77661d823 --- /dev/null +++ b/client/p2p/protocols/protocol.go @@ -0,0 +1,295 @@ +/* +Package protocols is an extension to p2p. It offers a user friendly simple way to define +devp2p subprotocols by abstracting away code standardly shared by protocols. + +* automate assigments of code indexes to messages +* automate RLP decoding/encoding based on reflecting +* provide the forever loop to read incoming messages +* standardise error handling related to communication +* standardised handshake negotiation +* TODO: automatic generation of wire protocol specification for peers + +*/ +package protocols + +import ( + "context" + "fmt" + "reflect" + "sync" + + "github.com/kowala-tech/kcoin/client/p2p" +) + +// error codes used by this protocol scheme +const ( + ErrMsgTooLong = iota + ErrDecode + ErrWrite + ErrInvalidMsgCode + ErrInvalidMsgType + ErrHandshake + ErrNoHandler + ErrHandler +) + +// error description strings associated with the codes +var errorToString = map[int]string{ + ErrMsgTooLong: "Message too long", + ErrDecode: "Invalid message (RLP error)", + ErrWrite: "Error sending message", + ErrInvalidMsgCode: "Invalid message code", + ErrInvalidMsgType: "Invalid message type", + ErrHandshake: "Handshake error", + ErrNoHandler: "No handler registered error", + ErrHandler: "Message handler error", +} + +/* +Error implements the standard go error interface. +Use: + + errorf(code, format, params ...interface{}) + +Prints as: + + :
+ +where description is given by code in errorToString +and details is fmt.Sprintf(format, params...) + +exported field Code can be checked +*/ +type Error struct { + Code int + message string + format string + params []interface{} +} + +func (e Error) Error() (message string) { + if len(e.message) == 0 { + name, ok := errorToString[e.Code] + if !ok { + panic("invalid message code") + } + e.message = name + if e.format != "" { + e.message += ": " + fmt.Sprintf(e.format, e.params...) + } + } + return e.message +} + +func errorf(code int, format string, params ...interface{}) *Error { + return &Error{ + Code: code, + format: format, + params: params, + } +} + +// Spec is a protocol specification including its name and version as well as +// the types of messages which are exchanged +type Spec struct { + // Name is the name of the protocol, often a three-letter word + Name string + + // Version is the version number of the protocol + Version uint + + // MaxMsgSize is the maximum accepted length of the message payload + MaxMsgSize uint32 + + // Messages is a list of message data types which this protocol uses, with + // each message type being sent with its array index as the code (so + // [&foo{}, &bar{}, &baz{}] would send foo, bar and baz with codes + // 0, 1 and 2 respectively) + // each message must have a single unique data type + Messages []interface{} + + initOnce sync.Once + codes map[reflect.Type]uint64 + types map[uint64]reflect.Type +} + +func (s *Spec) init() { + s.initOnce.Do(func() { + s.codes = make(map[reflect.Type]uint64, len(s.Messages)) + s.types = make(map[uint64]reflect.Type, len(s.Messages)) + for i, msg := range s.Messages { + code := uint64(i) + typ := reflect.TypeOf(msg) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + s.codes[typ] = code + s.types[code] = typ + } + }) +} + +// Length returns the number of message types in the protocol +func (s *Spec) Length() uint64 { + return uint64(len(s.Messages)) +} + +// GetCode returns the message code of a type, and boolean second argument is +// false if the message type is not found +func (s *Spec) GetCode(msg interface{}) (uint64, bool) { + s.init() + typ := reflect.TypeOf(msg) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + code, ok := s.codes[typ] + return code, ok +} + +// NewMsg construct a new message type given the code +func (s *Spec) NewMsg(code uint64) (interface{}, bool) { + s.init() + typ, ok := s.types[code] + if !ok { + return nil, false + } + return reflect.New(typ).Interface(), true +} + +// Peer represents a remote peer or protocol instance that is running on a peer connection with +// a remote peer +type Peer struct { + *p2p.Peer // the p2p.Peer object representing the remote + rw p2p.MsgReadWriter // p2p.MsgReadWriter to send messages to and read messages from + spec *Spec +} + +// NewPeer constructs a new peer +// this constructor is called by the p2p.Protocol#Run function +// the first two arguments are the arguments passed to p2p.Protocol.Run function +// the third argument is the Spec describing the protocol +func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer { + return &Peer{ + Peer: p, + rw: rw, + spec: spec, + } +} + +// Run starts the forever loop that handles incoming messages +// called within the p2p.Protocol#Run function +// the handler argument is a function which is called for each message received +// from the remote peer, a returned error causes the loop to exit +// resulting in disconnection +func (p *Peer) Run(handler func(msg interface{}) error) error { + for { + if err := p.handleIncoming(handler); err != nil { + return err + } + } +} + +// Drop disconnects a peer. +// TODO: may need to implement protocol drop only? don't want to kick off the peer +// if they are useful for other protocols +func (p *Peer) Drop(err error) { + p.Disconnect(p2p.DiscSubprotocolError) +} + +// Send takes a message, encodes it in RLP, finds the right message code and sends the +// message off to the peer +// this low level call will be wrapped by libraries providing routed or broadcast sends +// but often just used to forward and push messages to directly connected peers +func (p *Peer) Send(msg interface{}) error { + code, found := p.spec.GetCode(msg) + if !found { + return errorf(ErrInvalidMsgType, "%v", code) + } + return p2p.Send(p.rw, code, msg) +} + +// handleIncoming(code) +// is called each cycle of the main forever loop that dispatches incoming messages +// if this returns an error the loop returns and the peer is disconnected with the error +// this generic handler +// * checks message size, +// * checks for out-of-range message codes, +// * handles decoding with reflection, +// * call handlers as callbacks +func (p *Peer) handleIncoming(handle func(msg interface{}) error) error { + msg, err := p.rw.ReadMsg() + if err != nil { + return err + } + // make sure that the payload has been fully consumed + defer msg.Discard() + + if msg.Size > p.spec.MaxMsgSize { + return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize) + } + + val, ok := p.spec.NewMsg(msg.Code) + if !ok { + return errorf(ErrInvalidMsgCode, "%v", msg.Code) + } + if err := msg.Decode(val); err != nil { + return errorf(ErrDecode, "<= %v: %v", msg, err) + } + + // call the registered handler callbacks + // a registered callback take the decoded message as argument as an interface + // which the handler is supposed to cast to the appropriate type + // it is entirely safe not to check the cast in the handler since the handler is + // chosen based on the proper type in the first place + if err := handle(val); err != nil { + return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err) + } + return nil +} + +// Handshake negotiates a handshake on the peer connection +// * arguments +// * context +// * the local handshake to be sent to the remote peer +// * funcion to be called on the remote handshake (can be nil) +// * expects a remote handshake back of the same type +// * the dialing peer needs to send the handshake first and then waits for remote +// * the listening peer waits for the remote handshake and then sends it +// returns the remote handshake and an error +func (p *Peer) Handshake(ctx context.Context, hs interface{}, verify func(interface{}) error) (rhs interface{}, err error) { + if _, ok := p.spec.GetCode(hs); !ok { + return nil, errorf(ErrHandshake, "unknown handshake message type: %T", hs) + } + errc := make(chan error, 2) + handle := func(msg interface{}) error { + rhs = msg + if verify != nil { + return verify(rhs) + } + return nil + } + send := func() { errc <- p.Send(hs) } + receive := func() { errc <- p.handleIncoming(handle) } + + go func() { + if p.Inbound() { + receive() + send() + } else { + send() + receive() + } + }() + + for i := 0; i < 2; i++ { + select { + case err = <-errc: + case <-ctx.Done(): + err = ctx.Err() + } + if err != nil { + return nil, errorf(ErrHandshake, err.Error()) + } + } + return rhs, nil +} diff --git a/client/p2p/protocols/protocol_test.go b/client/p2p/protocols/protocol_test.go new file mode 100644 index 000000000..053f537a6 --- /dev/null +++ b/client/p2p/protocols/protocol_test.go @@ -0,0 +1,389 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package protocols + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + p2ptest "github.com/ethereum/go-ethereum/p2p/testing" +) + +// handshake message type +type hs0 struct { + C uint +} + +// message to kill/drop the peer with nodeID +type kill struct { + C discover.NodeID +} + +// message to drop connection +type drop struct { +} + +/// protoHandshake represents module-independent aspects of the protocol and is +// the first message peers send and receive as part the initial exchange +type protoHandshake struct { + Version uint // local and remote peer should have identical version + NetworkID string // local and remote peer should have identical network id +} + +// checkProtoHandshake verifies local and remote protoHandshakes match +func checkProtoHandshake(testVersion uint, testNetworkID string) func(interface{}) error { + return func(rhs interface{}) error { + remote := rhs.(*protoHandshake) + if remote.NetworkID != testNetworkID { + return fmt.Errorf("%s (!= %s)", remote.NetworkID, testNetworkID) + } + + if remote.Version != testVersion { + return fmt.Errorf("%d (!= %d)", remote.Version, testVersion) + } + return nil + } +} + +// newProtocol sets up a protocol +// the run function here demonstrates a typical protocol using peerPool, handshake +// and messages registered to handlers +func newProtocol(pp *p2ptest.TestPeerPool) func(*p2p.Peer, p2p.MsgReadWriter) error { + spec := &Spec{ + Name: "test", + Version: 42, + MaxMsgSize: 10 * 1024, + Messages: []interface{}{ + protoHandshake{}, + hs0{}, + kill{}, + drop{}, + }, + } + return func(p *p2p.Peer, rw p2p.MsgReadWriter) error { + peer := NewPeer(p, rw, spec) + + // initiate one-off protohandshake and check validity + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + phs := &protoHandshake{42, "420"} + hsCheck := checkProtoHandshake(phs.Version, phs.NetworkID) + _, err := peer.Handshake(ctx, phs, hsCheck) + if err != nil { + return err + } + + lhs := &hs0{42} + // module handshake demonstrating a simple repeatable exchange of same-type message + hs, err := peer.Handshake(ctx, lhs, nil) + if err != nil { + return err + } + + if rmhs := hs.(*hs0); rmhs.C > lhs.C { + return fmt.Errorf("handshake mismatch remote %v > local %v", rmhs.C, lhs.C) + } + + handle := func(msg interface{}) error { + switch msg := msg.(type) { + + case *protoHandshake: + return errors.New("duplicate handshake") + + case *hs0: + rhs := msg + if rhs.C > lhs.C { + return fmt.Errorf("handshake mismatch remote %v > local %v", rhs.C, lhs.C) + } + lhs.C += rhs.C + return peer.Send(lhs) + + case *kill: + // demonstrates use of peerPool, killing another peer connection as a response to a message + id := msg.C + pp.Get(id).Drop(errors.New("killed")) + return nil + + case *drop: + // for testing we can trigger self induced disconnect upon receiving drop message + return errors.New("dropped") + + default: + return fmt.Errorf("unknown message type: %T", msg) + } + } + + pp.Add(peer) + defer pp.Remove(peer) + return peer.Run(handle) + } +} + +func protocolTester(t *testing.T, pp *p2ptest.TestPeerPool) *p2ptest.ProtocolTester { + conf := adapters.RandomNodeConfig() + return p2ptest.NewProtocolTester(t, conf.ID, 2, newProtocol(pp)) +} + +func protoHandshakeExchange(id discover.NodeID, proto *protoHandshake) []p2ptest.Exchange { + + return []p2ptest.Exchange{ + { + Expects: []p2ptest.Expect{ + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: id, + }, + }, + }, + { + Triggers: []p2ptest.Trigger{ + { + Code: 0, + Msg: proto, + Peer: id, + }, + }, + }, + } +} + +func runProtoHandshake(t *testing.T, proto *protoHandshake, errs ...error) { + pp := p2ptest.NewTestPeerPool() + s := protocolTester(t, pp) + // TODO: make this more than one handshake + id := s.IDs[0] + if err := s.TestExchanges(protoHandshakeExchange(id, proto)...); err != nil { + t.Fatal(err) + } + var disconnects []*p2ptest.Disconnect + for i, err := range errs { + disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) + } + if err := s.TestDisconnected(disconnects...); err != nil { + t.Fatal(err) + } +} + +func TestProtoHandshakeVersionMismatch(t *testing.T) { + runProtoHandshake(t, &protoHandshake{41, "420"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 41 (!= 42)").Error())) +} + +func TestProtoHandshakeNetworkIDMismatch(t *testing.T) { + runProtoHandshake(t, &protoHandshake{42, "421"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 421 (!= 420)").Error())) +} + +func TestProtoHandshakeSuccess(t *testing.T) { + runProtoHandshake(t, &protoHandshake{42, "420"}) +} + +func moduleHandshakeExchange(id discover.NodeID, resp uint) []p2ptest.Exchange { + + return []p2ptest.Exchange{ + { + Expects: []p2ptest.Expect{ + { + Code: 1, + Msg: &hs0{42}, + Peer: id, + }, + }, + }, + { + Triggers: []p2ptest.Trigger{ + { + Code: 1, + Msg: &hs0{resp}, + Peer: id, + }, + }, + }, + } +} + +func runModuleHandshake(t *testing.T, resp uint, errs ...error) { + pp := p2ptest.NewTestPeerPool() + s := protocolTester(t, pp) + id := s.IDs[0] + if err := s.TestExchanges(protoHandshakeExchange(id, &protoHandshake{42, "420"})...); err != nil { + t.Fatal(err) + } + if err := s.TestExchanges(moduleHandshakeExchange(id, resp)...); err != nil { + t.Fatal(err) + } + var disconnects []*p2ptest.Disconnect + for i, err := range errs { + disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) + } + if err := s.TestDisconnected(disconnects...); err != nil { + t.Fatal(err) + } +} + +func TestModuleHandshakeError(t *testing.T) { + runModuleHandshake(t, 43, fmt.Errorf("handshake mismatch remote 43 > local 42")) +} + +func TestModuleHandshakeSuccess(t *testing.T) { + runModuleHandshake(t, 42) +} + +// testing complex interactions over multiple peers, relaying, dropping +func testMultiPeerSetup(a, b discover.NodeID) []p2ptest.Exchange { + + return []p2ptest.Exchange{ + { + Label: "primary handshake", + Expects: []p2ptest.Expect{ + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: a, + }, + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: b, + }, + }, + }, + { + Label: "module handshake", + Triggers: []p2ptest.Trigger{ + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: a, + }, + { + Code: 0, + Msg: &protoHandshake{42, "420"}, + Peer: b, + }, + }, + Expects: []p2ptest.Expect{ + { + Code: 1, + Msg: &hs0{42}, + Peer: a, + }, + { + Code: 1, + Msg: &hs0{42}, + Peer: b, + }, + }, + }, + + {Label: "alternative module handshake", Triggers: []p2ptest.Trigger{{Code: 1, Msg: &hs0{41}, Peer: a}, + {Code: 1, Msg: &hs0{41}, Peer: b}}}, + {Label: "repeated module handshake", Triggers: []p2ptest.Trigger{{Code: 1, Msg: &hs0{1}, Peer: a}}}, + {Label: "receiving repeated module handshake", Expects: []p2ptest.Expect{{Code: 1, Msg: &hs0{43}, Peer: a}}}} +} + +func runMultiplePeers(t *testing.T, peer int, errs ...error) { + pp := p2ptest.NewTestPeerPool() + s := protocolTester(t, pp) + + if err := s.TestExchanges(testMultiPeerSetup(s.IDs[0], s.IDs[1])...); err != nil { + t.Fatal(err) + } + // after some exchanges of messages, we can test state changes + // here this is simply demonstrated by the peerPool + // after the handshake negotiations peers must be added to the pool + // time.Sleep(1) + tick := time.NewTicker(10 * time.Millisecond) + timeout := time.NewTimer(1 * time.Second) +WAIT: + for { + select { + case <-tick.C: + if pp.Has(s.IDs[0]) { + break WAIT + } + case <-timeout.C: + t.Fatal("timeout") + } + } + if !pp.Has(s.IDs[1]) { + t.Fatalf("missing peer test-1: %v (%v)", pp, s.IDs) + } + + // peer 0 sends kill request for peer with index + err := s.TestExchanges(p2ptest.Exchange{ + Triggers: []p2ptest.Trigger{ + { + Code: 2, + Msg: &kill{s.IDs[peer]}, + Peer: s.IDs[0], + }, + }, + }) + + if err != nil { + t.Fatal(err) + } + + // the peer not killed sends a drop request + err = s.TestExchanges(p2ptest.Exchange{ + Triggers: []p2ptest.Trigger{ + { + Code: 3, + Msg: &drop{}, + Peer: s.IDs[(peer+1)%2], + }, + }, + }) + + if err != nil { + t.Fatal(err) + } + + // check the actual discconnect errors on the individual peers + var disconnects []*p2ptest.Disconnect + for i, err := range errs { + disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) + } + if err := s.TestDisconnected(disconnects...); err != nil { + t.Fatal(err) + } + // test if disconnected peers have been removed from peerPool + if pp.Has(s.IDs[peer]) { + t.Fatalf("peer test-%v not dropped: %v (%v)", peer, pp, s.IDs) + } + +} + +func TestMultiplePeersDropSelf(t *testing.T) { + runMultiplePeers(t, 0, + fmt.Errorf("subprotocol error"), + fmt.Errorf("Message handler error: (msg code 3): dropped"), + ) +} + +func TestMultiplePeersDropOther(t *testing.T) { + runMultiplePeers(t, 1, + fmt.Errorf("Message handler error: (msg code 3): dropped"), + fmt.Errorf("subprotocol error"), + ) +} diff --git a/client/p2p/rlpx.go b/client/p2p/rlpx.go index 41aefc5f7..87729adc5 100644 --- a/client/p2p/rlpx.go +++ b/client/p2p/rlpx.go @@ -92,8 +92,14 @@ func (t *rlpx) close(err error) { // Tell the remote end why we're disconnecting if possible. if t.rw != nil { if r, ok := err.(DiscReason); ok && r != DiscNetworkError { - t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)) - SendItems(t.rw, discMsg, r) + // rlpx tries to send DiscReason to disconnected peer + // if the connection is net.Pipe (in-memory simulation) + // it hangs forever, since net.Pipe does not implement + // a write deadline. Because of this only try to send + // the disconnect reason message if there is no error. + if err := t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)); err == nil { + SendItems(t.rw, discMsg, r) + } } } t.fd.Close() @@ -153,6 +159,10 @@ func readProtocolHandshake(rw MsgReader, our *protoHandshake) (*protoHandshake, return &hs, nil } +// doEncHandshake runs the protocol handshake using authenticated +// messages. the protocol handshake is the first authenticated message +// and also verifies whether the encryption handshake 'worked' and the +// remote side actually provided the right public key. func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *discover.Node) (discover.NodeID, error) { var ( sec secrets @@ -469,7 +479,7 @@ func readHandshakeMsg(msg plainDecoder, plainSize int, prv *ecdsa.PrivateKey, r } // Attempt decoding pre-EIP-8 "plain" format. key := ecies.ImportECDSA(prv) - if dec, err := key.Decrypt(rand.Reader, buf, nil, nil); err == nil { + if dec, err := key.Decrypt(buf, nil, nil); err == nil { msg.decodePlain(dec) return buf, nil } @@ -483,7 +493,7 @@ func readHandshakeMsg(msg plainDecoder, plainSize int, prv *ecdsa.PrivateKey, r if _, err := io.ReadFull(r, buf[plainSize:]); err != nil { return buf, err } - dec, err := key.Decrypt(rand.Reader, buf[2:], nil, prefix) + dec, err := key.Decrypt(buf[2:], nil, prefix) if err != nil { return buf, err } diff --git a/client/p2p/server.go b/client/p2p/server.go index 879b0fa22..a4a5b3f82 100644 --- a/client/p2p/server.go +++ b/client/p2p/server.go @@ -20,15 +20,12 @@ import ( ) const ( - defaultDialTimeout = 15 * time.Second - refreshPeersInterval = 30 * time.Second - staticPeerCheckInterval = 15 * time.Second + defaultDialTimeout = 15 * time.Second - // Maximum number of concurrently handshaking inbound connections. - maxAcceptConns = 50 - - // Maximum number of concurrently dialing outbound connections. - maxActiveDialTasks = 16 + // Connectivity defaults. + maxActiveDialTasks = 16 + defaultMaxPendingPeers = 50 + defaultDialRatio = 3 // Maximum time allowed for reading a complete message. // This is effectively the amount of time a connection can be idle. @@ -54,6 +51,11 @@ type Config struct { // Zero defaults to preset values. MaxPendingPeers int `toml:",omitempty"` + // DialRatio controls the ratio of inbound to dialed connections. + // Example: a DialRatio of 2 allows 1/2 of connections to be dialed. + // Setting DialRatio to zero defaults it to 3. + DialRatio int `toml:",omitempty"` + // NoDiscovery can be used to disable the peer discovery mechanism. // Disabling is useful for protocol debugging (manual topology). NoDiscovery bool @@ -123,6 +125,9 @@ type Config struct { // If EnableMsgEvents is set then the server will emit PeerEvents // whenever a message is sent to or received from a peer EnableMsgEvents bool + + // Logger is a custom logger to use with the p2p.Server. + Logger log.Logger `toml:",omitempty"` } // Server manages all peer connections. @@ -156,6 +161,7 @@ type Server struct { delpeer chan peerDrop loopWG sync.WaitGroup // loop, listenLoop peerFeed event.Feed + log log.Logger } type peerOpFunc func(map[discover.NodeID]*Peer) @@ -334,6 +340,32 @@ func (srv *Server) Stop() { srv.loopWG.Wait() } +// sharedUDPConn implements a shared connection. Write sends messages to the underlying connection while read returns +// messages that were found unprocessable and sent to the unhandled channel by the primary listener. +type sharedUDPConn struct { + *net.UDPConn + unhandled chan discover.ReadPacket +} + +// ReadFromUDP implements discv5.conn +func (s *sharedUDPConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { + packet, ok := <-s.unhandled + if !ok { + return 0, nil, fmt.Errorf("Connection was closed") + } + l := len(packet.Data) + if l > len(b) { + l = len(b) + } + copy(b[:l], packet.Data[:l]) + return l, packet.Addr, nil +} + +// Close implements discv5.conn +func (s *sharedUDPConn) Close() error { + return nil +} + // Start starts running the server. // Servers can not be re-used after stopping. func (srv *Server) Start() (err error) { @@ -343,7 +375,11 @@ func (srv *Server) Start() (err error) { return errors.New("server already running") } srv.running = true - log.Info("Starting P2P networking") + srv.log = srv.Config.Logger + if srv.log == nil { + srv.log = log.New() + } + srv.log.Info("Starting P2P networking") // static fields if srv.PrivateKey == nil { @@ -364,20 +400,66 @@ func (srv *Server) Start() (err error) { srv.peerOp = make(chan peerOpFunc) srv.peerOpDone = make(chan struct{}) - // node table - if !srv.NoDiscovery { - ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase, srv.NetRestrict) + var ( + conn *net.UDPConn + sconn *sharedUDPConn + realaddr *net.UDPAddr + unhandled chan discover.ReadPacket + ) + + if !srv.NoDiscovery || srv.DiscoveryV5 { + addr, err := net.ResolveUDPAddr("udp", srv.ListenAddr) + if err != nil { + return err + } + conn, err = net.ListenUDP("udp", addr) if err != nil { return err } - if err := ntab.SetFallbackNodes(srv.BootstrapNodes); err != nil { + realaddr = conn.LocalAddr().(*net.UDPAddr) + if srv.NAT != nil { + if !realaddr.IP.IsLoopback() { + go nat.Map(srv.NAT, srv.quit, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") + } + // TODO: react to external IP changes over time. + if ext, err := srv.NAT.ExternalIP(); err == nil { + realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port} + } + } + } + + if !srv.NoDiscovery && srv.DiscoveryV5 { + unhandled = make(chan discover.ReadPacket, 100) + sconn = &sharedUDPConn{conn, unhandled} + } + + // node table + if !srv.NoDiscovery { + cfg := discover.Config{ + PrivateKey: srv.PrivateKey, + AnnounceAddr: realaddr, + NodeDBPath: srv.NodeDatabase, + NetRestrict: srv.NetRestrict, + Bootnodes: srv.BootstrapNodes, + Unhandled: unhandled, + } + ntab, err := discover.ListenUDP(conn, cfg) + if err != nil { return err } srv.ntab = ntab } if srv.DiscoveryV5 { - ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "", srv.NetRestrict) //srv.NodeDatabase) + var ( + ntab *discv5.Network + err error + ) + if sconn != nil { + ntab, err = discv5.ListenUDP(srv.PrivateKey, sconn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase) + } else { + ntab, err = discv5.ListenUDP(srv.PrivateKey, conn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase) + } if err != nil { return err } @@ -387,10 +469,7 @@ func (srv *Server) Start() (err error) { srv.DiscV5 = ntab } - dynPeers := (srv.MaxPeers + 1) / 2 - if srv.NoDiscovery { - dynPeers = 0 - } + dynPeers := srv.maxDialedConns() dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict) // handshake @@ -405,7 +484,7 @@ func (srv *Server) Start() (err error) { } } if srv.NoDial && srv.ListenAddr == "" { - log.Warn("P2P server will be useless, neither dialing nor listening") + srv.log.Warn("P2P server will be useless, neither dialing nor listening") } srv.loopWG.Add(1) @@ -447,6 +526,7 @@ func (srv *Server) run(dialstate dialer) { defer srv.loopWG.Done() var ( peers = make(map[discover.NodeID]*Peer) + inboundCount = 0 trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes)) taskdone = make(chan task, maxActiveDialTasks) runningTasks []task @@ -473,7 +553,7 @@ func (srv *Server) run(dialstate dialer) { i := 0 for ; len(runningTasks) < maxActiveDialTasks && i < len(ts); i++ { t := ts[i] - log.Trace("New dial task", "task", t) + srv.log.Trace("New dial task", "task", t) go func() { t.Do(srv); taskdone <- t }() runningTasks = append(runningTasks, t) } @@ -501,13 +581,13 @@ running: // This channel is used by AddPeer to add to the // ephemeral static peer list. Add it to the dialer, // it will keep the node connected. - log.Debug("Adding static node", "node", n) + srv.log.Debug("Adding static node", "node", n) dialstate.addStatic(n) case n := <-srv.removestatic: // This channel is used by RemovePeer to send a // disconnect request to a peer and begin the // stop keeping the node connected - log.Debug("Removing static node", "node", n) + srv.log.Debug("Removing static node", "node", n) dialstate.removeStatic(n) if p, ok := peers[n.ID]; ok { p.Disconnect(DiscRequested) @@ -520,7 +600,7 @@ running: // A task got done. Tell dialstate about it so it // can update its state and remove it from the active // tasks list. - log.Trace("Dial task done", "task", t) + srv.log.Trace("Dial task done", "task", t) dialstate.taskDone(t, time.Now()) delTask(t) case c := <-srv.posthandshake: @@ -532,14 +612,14 @@ running: } // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. select { - case c.cont <- srv.encHandshakeChecks(peers, c): + case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c): case <-srv.quit: break running } case c := <-srv.addpeer: // At this point the connection is past the protocol handshake. // Its capabilities are known and the remote identity is verified. - err := srv.protoHandshakeChecks(peers, c) + err := srv.protoHandshakeChecks(peers, inboundCount, c) if err == nil { // The handshakes are done and it passed all checks. p := newPeer(c, srv.Protocols) @@ -549,9 +629,12 @@ running: p.events = &srv.peerFeed } name := truncateName(c.name) - log.Debug("Adding p2p peer", "id", c.id, "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1) - peers[c.id] = p + srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1) go srv.runPeer(p) + peers[c.id] = p + if p.Inbound() { + inboundCount++ + } } // The dialer logic relies on the assumption that // dial tasks complete after the peer has been added or @@ -566,10 +649,13 @@ running: d := common.PrettyDuration(mclock.Now() - pd.created) pd.log.Debug("Removing p2p peer", "duration", d, "peers", len(peers)-1, "req", pd.requested, "err", pd.err) delete(peers, pd.ID()) + if pd.Inbound() { + inboundCount-- + } } } - log.Trace("P2P networking is spinning down") + srv.log.Trace("P2P networking is spinning down") // Terminate discovery. If there is a running lookup it will terminate soon. if srv.ntab != nil { @@ -592,20 +678,22 @@ running: } } -func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error { +func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error { // Drop connections with no matching protocols. if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 { return DiscUselessPeer } // Repeat the encryption handshake checks because the // peer set might have changed between the handshakes. - return srv.encHandshakeChecks(peers, c) + return srv.encHandshakeChecks(peers, inboundCount, c) } -func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error { +func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error { switch { case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers: return DiscTooManyPeers + case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns(): + return DiscTooManyPeers case peers[c.id] != nil: return DiscAlreadyConnected case c.id == srv.Self().ID: @@ -615,6 +703,21 @@ func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) } } +func (srv *Server) maxInboundConns() int { + return srv.MaxPeers - srv.maxDialedConns() +} + +func (srv *Server) maxDialedConns() int { + if srv.NoDiscovery || srv.NoDial { + return 0 + } + r := srv.DialRatio + if r == 0 { + r = defaultDialRatio + } + return srv.MaxPeers / r +} + type tempError interface { Temporary() bool } @@ -623,12 +726,9 @@ type tempError interface { // inbound connections. func (srv *Server) listenLoop() { defer srv.loopWG.Done() - log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab)) + srv.log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab)) - // This channel acts as a semaphore limiting - // active inbound connections that are lingering pre-handshake. - // If all slots are taken, no further connections are accepted. - tokens := maxAcceptConns + tokens := defaultMaxPendingPeers if srv.MaxPendingPeers > 0 { tokens = srv.MaxPendingPeers } @@ -648,10 +748,10 @@ func (srv *Server) listenLoop() { for { fd, err = srv.listener.Accept() if tempErr, ok := err.(tempError); ok && tempErr.Temporary() { - log.Debug("Temporary read error", "err", err) + srv.log.Debug("Temporary read error", "err", err) continue } else if err != nil { - log.Debug("Read error", "err", err) + srv.log.Debug("Read error", "err", err) return } break @@ -660,7 +760,7 @@ func (srv *Server) listenLoop() { // Reject connections that do not match NetRestrict. if srv.NetRestrict != nil { if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) { - log.Debug("Rejected conn (not whitelisted in NetRestrict)", "addr", fd.RemoteAddr()) + srv.log.Debug("Rejected conn (not whitelisted in NetRestrict)", "addr", fd.RemoteAddr()) fd.Close() slots <- struct{}{} continue @@ -668,10 +768,7 @@ func (srv *Server) listenLoop() { } fd = newMeteredConn(fd, true) - log.Trace("Accepted connection", "addr", fd.RemoteAddr()) - - // Spawn the handler. It will give the slot back when the connection - // has been established. + srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr()) go func() { srv.SetupConn(fd, inboundConn, nil) slots <- struct{}{} @@ -682,55 +779,65 @@ func (srv *Server) listenLoop() { // SetupConn runs the handshakes and attempts to add the connection // as a peer. It returns when the connection has been added as a peer // or the handshakes have failed. -func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *discover.Node) { +func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *discover.Node) error { + self := srv.Self() + if self == nil { + return errors.New("shutdown") + } + c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)} + err := srv.setupConn(c, flags, dialDest) + if err != nil { + c.close(err) + srv.log.Trace("Setting up connection failed", "id", c.id, "err", err) + } + return err +} + +func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *discover.Node) error { // Prevent leftover pending conns from entering the handshake. srv.lock.Lock() running := srv.running srv.lock.Unlock() - c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)} if !running { - c.close(errServerStopped) - return + return errServerStopped } // Run the encryption handshake. var err error if c.id, err = c.doEncHandshake(srv.PrivateKey, dialDest); err != nil { - log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err) - c.close(err) - return + srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err) + return err } - clog := log.New("id", c.id, "addr", c.fd.RemoteAddr(), "conn", c.flags) + clog := srv.log.New("id", c.id, "addr", c.fd.RemoteAddr(), "conn", c.flags) // For dialed connections, check that the remote public key matches. if dialDest != nil && c.id != dialDest.ID { - c.close(DiscUnexpectedIdentity) clog.Trace("Dialed identity mismatch", "want", c, dialDest.ID) - return + return DiscUnexpectedIdentity } - if err := srv.checkpoint(c, srv.posthandshake); err != nil { + err = srv.checkpoint(c, srv.posthandshake) + if err != nil { clog.Trace("Rejected peer before protocol handshake", "err", err) - c.close(err) - return + return err } // Run the protocol handshake phs, err := c.doProtoHandshake(srv.ourHandshake) if err != nil { clog.Trace("Failed proto handshake", "err", err) - c.close(err) - return + return err } if phs.ID != c.id { clog.Trace("Wrong devp2p handshake identity", "err", phs.ID) - c.close(DiscUnexpectedIdentity) - return + return DiscUnexpectedIdentity } c.caps, c.name = phs.Caps, phs.Name - if err := srv.checkpoint(c, srv.addpeer); err != nil { + err = srv.checkpoint(c, srv.addpeer) + if err != nil { clog.Trace("Rejected peer", "err", err) - c.close(err) - return + return err } // If the checks completed successfully, runPeer has now been // launched by run. + clog.Trace("connection set up", "inbound", dialDest == nil) + return nil } func truncateName(s string) string { diff --git a/e2e/impl/cluster.go b/e2e/impl/cluster.go index 9d4102013..4e64c8b6a 100644 --- a/e2e/impl/cluster.go +++ b/e2e/impl/cluster.go @@ -56,7 +56,10 @@ func (ctx *Context) InitCluster(logsToStdout bool) error { } func (ctx *Context) RunCluster() error { - ctx.nodeRunner.StopAll() + if err := ctx.nodeRunner.StopAll(); err != nil { + return err + } + if err := ctx.runNodes(); err != nil { return err }