Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion common/types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package common

import "math/big"
import (
"math/big"
"math/rand"
"reflect"
)

const (
hashLength = 32
Expand Down Expand Up @@ -48,6 +52,15 @@ func (h *Hash) Set(other Hash) {
}
}

// Generate implements testing/quick.Generator.
func (h Hash) Generate(rand *rand.Rand, size int) reflect.Value {
m := rand.Intn(len(h))
for i := len(h) - 1; i > m; i-- {
h[i] = byte(rand.Uint32())
}
return reflect.ValueOf(h)
}

/////////// Address
func BytesToAddress(b []byte) Address {
var a Address
Expand Down
4 changes: 2 additions & 2 deletions eth/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,8 @@ func (s *Ethereum) NodeInfo() *NodeInfo {
NodeUrl: node.String(),
NodeID: node.ID.String(),
IP: node.IP.String(),
DiscPort: node.DiscPort,
TCPPort: node.TCPPort,
DiscPort: int(node.UDP),
TCPPort: int(node.TCP),
ListenAddr: s.net.ListenAddr,
Td: s.ChainManager().Td().String(),
}
Expand Down
2 changes: 2 additions & 0 deletions p2p/discover/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sync"
"time"

"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/rlp"
Expand Down Expand Up @@ -167,6 +168,7 @@ func (db *nodeDB) node(id NodeID) *Node {
glog.V(logger.Warn).Infof("failed to decode node RLP: %v", err)
return nil
}
node.sha = crypto.Sha3Hash(node.ID[:])
return node
}

Expand Down
74 changes: 43 additions & 31 deletions p2p/discover/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net"
"os"
"path/filepath"
"reflect"
"testing"
"time"
)
Expand Down Expand Up @@ -85,11 +86,12 @@ func TestNodeDBInt64(t *testing.T) {
}

func TestNodeDBFetchStore(t *testing.T) {
node := &Node{
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: net.IP([]byte{192, 168, 0, 1}),
TCPPort: 30303,
}
node := newNode(
MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
net.IP{192, 168, 0, 1},
30303,
30303,
)
inst := time.Now()

db, _ := newNodeDB("", Version)
Expand Down Expand Up @@ -124,34 +126,40 @@ func TestNodeDBFetchStore(t *testing.T) {
}
if stored := db.node(node.ID); stored == nil {
t.Errorf("node: not found")
} else if !bytes.Equal(stored.ID[:], node.ID[:]) || !stored.IP.Equal(node.IP) || stored.TCPPort != node.TCPPort {
} else if !reflect.DeepEqual(stored, node) {
t.Errorf("node: data mismatch: have %v, want %v", stored, node)
}
}

var nodeDBSeedQueryNodes = []struct {
node Node
node *Node
pong time.Time
}{
{
node: Node{
ID: MustHexID("0x01d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: []byte{127, 0, 0, 1},
},
node: newNode(
MustHexID("0x01d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
net.IP{127, 0, 0, 1},
30303,
30303,
),
pong: time.Now().Add(-2 * time.Second),
},
{
node: Node{
ID: MustHexID("0x02d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: []byte{127, 0, 0, 2},
},
node: newNode(
MustHexID("0x02d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
net.IP{127, 0, 0, 2},
30303,
30303,
),
pong: time.Now().Add(-3 * time.Second),
},
{
node: Node{
ID: MustHexID("0x03d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: []byte{127, 0, 0, 3},
},
node: newNode(
MustHexID("0x03d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
net.IP{127, 0, 0, 3},
30303,
30303,
),
pong: time.Now().Add(-1 * time.Second),
},
}
Expand All @@ -162,7 +170,7 @@ func TestNodeDBSeedQuery(t *testing.T) {

// Insert a batch of nodes for querying
for i, seed := range nodeDBSeedQueryNodes {
if err := db.updateNode(&seed.node); err != nil {
if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err)
}
}
Expand Down Expand Up @@ -202,7 +210,7 @@ func TestNodeDBSeedQueryContinuation(t *testing.T) {

// Insert a batch of nodes for querying
for i, seed := range nodeDBSeedQueryNodes {
if err := db.updateNode(&seed.node); err != nil {
if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err)
}
}
Expand Down Expand Up @@ -266,22 +274,26 @@ func TestNodeDBPersistency(t *testing.T) {
}

var nodeDBExpirationNodes = []struct {
node Node
node *Node
pong time.Time
exp bool
}{
{
node: Node{
ID: MustHexID("0x01d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: []byte{127, 0, 0, 1},
},
node: newNode(
MustHexID("0x01d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
net.IP{127, 0, 0, 1},
30303,
30303,
),
pong: time.Now().Add(-nodeDBNodeExpiration + time.Minute),
exp: false,
}, {
node: Node{
ID: MustHexID("0x02d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: []byte{127, 0, 0, 2},
},
node: newNode(
MustHexID("0x02d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
net.IP{127, 0, 0, 2},
30303,
30303,
),
pong: time.Now().Add(-nodeDBNodeExpiration - time.Minute),
exp: true,
},
Expand All @@ -293,7 +305,7 @@ func TestNodeDBExpiration(t *testing.T) {

// Add all the test nodes and set their last pong time
for i, seed := range nodeDBExpirationNodes {
if err := db.updateNode(&seed.node); err != nil {
if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err)
}
if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil {
Expand Down
98 changes: 48 additions & 50 deletions p2p/discover/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,59 +6,62 @@ import (
"encoding/hex"
"errors"
"fmt"
"io"
"math/big"
"math/rand"
"net"
"net/url"
"strconv"
"strings"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/ethereum/go-ethereum/rlp"
)

const nodeIDBits = 512

// Node represents a host on the network.
type Node struct {
ID NodeID
IP net.IP
IP net.IP // len 4 for IPv4 or 16 for IPv6
UDP, TCP uint16 // port numbers
ID NodeID // the node's public key

DiscPort int // UDP listening port for discovery protocol
TCPPort int // TCP listening port for RLPx
// This is a cached copy of sha3(ID) which is used for node
// distance calculations. This is part of Node in order to make it
// possible to write tests that need a node at a certain distance.
// In those tests, the content of sha will not actually correspond
// with ID.
sha common.Hash
}

func newNode(id NodeID, addr *net.UDPAddr) *Node {
func newNode(id NodeID, ip net.IP, udpPort, tcpPort uint16) *Node {
if ipv4 := ip.To4(); ipv4 != nil {
ip = ipv4
}
return &Node{
ID: id,
IP: addr.IP,
DiscPort: addr.Port,
TCPPort: addr.Port,
IP: ip,
UDP: udpPort,
TCP: tcpPort,
ID: id,
sha: crypto.Sha3Hash(id[:]),
}
}

func (n *Node) isValid() bool {
// TODO: don't accept localhost, LAN addresses from internet hosts
return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
}

func (n *Node) addr() *net.UDPAddr {
return &net.UDPAddr{IP: n.IP, Port: n.DiscPort}
return &net.UDPAddr{IP: n.IP, Port: int(n.UDP)}
}

// The string representation of a Node is a URL.
// Please see ParseNode for a description of the format.
func (n *Node) String() string {
addr := net.TCPAddr{IP: n.IP, Port: n.TCPPort}
addr := net.TCPAddr{IP: n.IP, Port: int(n.TCP)}
u := url.URL{
Scheme: "enode",
User: url.User(fmt.Sprintf("%x", n.ID[:])),
Host: addr.String(),
}
if n.DiscPort != n.TCPPort {
u.RawQuery = "discport=" + strconv.Itoa(n.DiscPort)
if n.UDP != n.TCP {
u.RawQuery = "discport=" + strconv.Itoa(int(n.UDP))
}
return u.String()
}
Expand All @@ -80,36 +83,47 @@ func (n *Node) String() string {
//
// enode://<hex node id>@10.3.58.6:30303?discport=30301
func ParseNode(rawurl string) (*Node, error) {
var n Node
var (
id NodeID
ip net.IP
tcpPort, udpPort uint64
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps use uint16 for consistency with Node.UDP/TCP and to catch potential future overflow bugs quicker

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, nevermind, I see it's used due to strconv.ParseUInt returning uint64.

)
u, err := url.Parse(rawurl)
if u.Scheme != "enode" {
return nil, errors.New("invalid URL scheme, want \"enode\"")
}
// Parse the Node ID from the user portion.
if u.User == nil {
return nil, errors.New("does not contain node ID")
}
if n.ID, err = HexID(u.User.String()); err != nil {
if id, err = HexID(u.User.String()); err != nil {
return nil, fmt.Errorf("invalid node ID (%v)", err)
}
ip, port, err := net.SplitHostPort(u.Host)
// Parse the IP address.
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
return nil, fmt.Errorf("invalid host: %v", err)
}
if n.IP = net.ParseIP(ip); n.IP == nil {
if ip = net.ParseIP(host); ip == nil {
return nil, errors.New("invalid IP address")
}
if n.TCPPort, err = strconv.Atoi(port); err != nil {
// Ensure the IP is 4 bytes long for IPv4 addresses.
if ipv4 := ip.To4(); ipv4 != nil {
ip = ipv4
}
// Parse the port numbers.
if tcpPort, err = strconv.ParseUint(port, 10, 16); err != nil {
return nil, errors.New("invalid port")
}
udpPort = tcpPort
qv := u.Query()
if qv.Get("discport") == "" {
n.DiscPort = n.TCPPort
} else {
if n.DiscPort, err = strconv.Atoi(qv.Get("discport")); err != nil {
if qv.Get("discport") != "" {
udpPort, err = strconv.ParseUint(qv.Get("discport"), 10, 16)
if err != nil {
return nil, errors.New("invalid discport in query")
}
}
return &n, nil
return newNode(id, ip, uint16(udpPort), uint16(tcpPort)), nil
}

// MustParseNode parses a node URL. It panics if the URL is not valid.
Expand All @@ -121,22 +135,6 @@ func MustParseNode(rawurl string) *Node {
return n
}

func (n Node) EncodeRLP(w io.Writer) error {
return rlp.Encode(w, rpcNode{IP: n.IP.String(), Port: uint16(n.TCPPort), ID: n.ID})
}
func (n *Node) DecodeRLP(s *rlp.Stream) (err error) {
var ext rpcNode
if err = s.Decode(&ext); err == nil {
n.TCPPort = int(ext.Port)
n.DiscPort = int(ext.Port)
n.ID = ext.ID
if n.IP = net.ParseIP(ext.IP); n.IP == nil {
return errors.New("invalid IP string")
}
}
return err
}

// NodeID is a unique identifier for each node.
// The node identifier is a marshaled elliptic curve public key.
type NodeID [nodeIDBits / 8]byte
Expand Down Expand Up @@ -221,7 +219,7 @@ func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
// distcmp compares the distances a->target and b->target.
// Returns -1 if a is closer to target, 1 if b is closer to target
// and 0 if they are equal.
func distcmp(target, a, b NodeID) int {
func distcmp(target, a, b common.Hash) int {
for i := range target {
da := a[i] ^ target[i]
db := b[i] ^ target[i]
Expand Down Expand Up @@ -271,7 +269,7 @@ var lzcount = [256]int{
}

// logdist returns the logarithmic distance between a and b, log2(a ^ b).
func logdist(a, b NodeID) int {
func logdist(a, b common.Hash) int {
lz := 0
for i := range a {
x := a[i] ^ b[i]
Expand All @@ -285,8 +283,8 @@ func logdist(a, b NodeID) int {
return len(a)*8 - lz
}

// randomID returns a random NodeID such that logdist(a, b) == n
func randomID(a NodeID, n int) (b NodeID) {
// hashAtDistance returns a random hash such that logdist(a, b) == n
func hashAtDistance(a common.Hash, n int) (b common.Hash) {
if n == 0 {
return a
}
Expand Down
Loading