diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go
index 254b1f7fb..c39c88bef 100644
--- a/accounts/abi/abi.go
+++ b/accounts/abi/abi.go
@@ -19,8 +19,13 @@ package abi
import (
"bytes"
"encoding/json"
+ "errors"
"fmt"
"io"
+ "math/big"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/crypto"
)
// The ABI holds information about a contract's context and available
@@ -30,6 +35,12 @@ type ABI struct {
Constructor Method
Methods map[string]Method
Events map[string]Event
+
+ // Additional "special" functions introduced in solidity v0.6.0.
+ // They are separate from the original default fallback. Each contract
+ // can only define one fallback and one receive function.
+ Fallback Method // Note it's also used to represent the legacy fallback before v0.6.0
+ Receive Method
}
// JSON returns a parsed ABI interface and error if it failed.
@@ -40,7 +51,6 @@ func JSON(reader io.Reader) (ABI, error) {
if err := dec.Decode(&abi); err != nil {
return ABI{}, err
}
-
return abi, nil
}
@@ -58,89 +68,257 @@ func (abi ABI) Pack(name string, args ...interface{}) ([]byte, error) {
return nil, err
}
return arguments, nil
-
}
method, exist := abi.Methods[name]
if !exist {
return nil, fmt.Errorf("method '%s' not found", name)
}
-
arguments, err := method.Inputs.Pack(args...)
if err != nil {
return nil, err
}
// Pack up the method ID too if not a constructor and return
- return append(method.Id(), arguments...), nil
+ return append(method.ID, arguments...), nil
}
-// Unpack output in v according to the abi specification
-func (abi ABI) Unpack(v interface{}, name string, output []byte) (err error) {
- if len(output) == 0 {
- return fmt.Errorf("abi: unmarshalling empty output")
- }
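+// getArguments returns the Arguments to decode against: the outputs of the
+// named method, or the inputs of the named event. If a method and an event
+// share the same name, the event's inputs take precedence.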
+func (abi ABI) getArguments(name string, data []byte) (Arguments, error) {
// since there can't be naming collisions with contracts and events,
// we need to decide whether we're calling a method or an event
+ var args Arguments
if method, ok := abi.Methods[name]; ok {
- if len(output)%32 != 0 {
- return fmt.Errorf("abi: improperly formatted output")
+ if len(data)%32 != 0 {
+ return nil, fmt.Errorf("abi: improperly formatted output: %s - Bytes: [%+v]", string(data), data)
}
- return method.Outputs.Unpack(v, output)
- } else if event, ok := abi.Events[name]; ok {
- return event.Inputs.Unpack(v, output)
+ args = method.Outputs
+ }
+ if event, ok := abi.Events[name]; ok {
+ args = event.Inputs
+ }
+ if args == nil {
+ return nil, errors.New("abi: could not locate named method or event")
+ }
+ return args, nil
+}
+
+// Unpack unpacks the output according to the abi specification.
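+//
+// A minimal usage sketch (assumes `a` is a parsed ABI and `ret` holds the raw
+// return data of a call to its "balance" method; the names are illustrative):
+//
+//	values, err := a.Unpack("balance", ret)
+//	if err != nil {
+//		// handle the decoding failure
+//	}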
+func (abi ABI) Unpack(name string, data []byte) ([]interface{}, error) {
+ args, err := abi.getArguments(name, data)
+ if err != nil {
+ return nil, err
}
- return fmt.Errorf("abi: could not locate named method or event")
+ return args.Unpack(data)
}
-// UnmarshalJSON implements json.Unmarshaler interface
+// UnpackIntoInterface unpacks the output in v according to the abi specification.
+// It performs an additional copy. Please only use it if you want to unpack into a
+// structure that does not strictly conform to the abi structure (e.g. has additional arguments).
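+//
+// A minimal usage sketch (assumes `a` is a parsed ABI whose "send" method
+// returns a single uint256, and `ret` is its raw return data; the names are
+// illustrative):
+//
+//	var out struct{ Amount *big.Int }
+//	if err := a.UnpackIntoInterface(&out, "send", ret); err != nil {
+//		// handle the decoding failure
+//	}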
+func (abi ABI) UnpackIntoInterface(v interface{}, name string, data []byte) error {
+ args, err := abi.getArguments(name, data)
+ if err != nil {
+ return err
+ }
+ unpacked, err := args.Unpack(data)
+ if err != nil {
+ return err
+ }
+ return args.Copy(v, unpacked)
+}
+
+// UnpackIntoMap unpacks a log into the provided map[string]interface{}.
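+//
+// A minimal usage sketch (assumes `a` is a parsed ABI and `data` holds the
+// non-indexed data of a "received" log; the names are illustrative):
+//
+//	values := map[string]interface{}{}
+//	if err := a.UnpackIntoMap(values, "received", data); err != nil {
+//		// handle the decoding failure
+//	}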
+func (abi ABI) UnpackIntoMap(v map[string]interface{}, name string, data []byte) (err error) {
+ args, err := abi.getArguments(name, data)
+ if err != nil {
+ return err
+ }
+ return args.UnpackIntoMap(v, data)
+}
+
+// UnmarshalJSON implements json.Unmarshaler interface.
func (abi *ABI) UnmarshalJSON(data []byte) error {
var fields []struct {
- Type string
- Name string
- Constant bool
+ Type string
+ Name string
+ Inputs []Argument
+ Outputs []Argument
+
+ // Status indicator which can be: "pure", "view",
+ // "nonpayable" or "payable".
+ StateMutability string
+
+ // Deprecated status indicators, which were removed in v0.6.0.
+ Constant bool // True if function is either pure or view
+ Payable bool // True if function is payable
+
+ // Event-only indicator; true if the event is
+ // declared as anonymous.
Anonymous bool
- Inputs []Argument
- Outputs []Argument
}
-
if err := json.Unmarshal(data, &fields); err != nil {
return err
}
-
abi.Methods = make(map[string]Method)
abi.Events = make(map[string]Event)
for _, field := range fields {
switch field.Type {
case "constructor":
- abi.Constructor = Method{
- Inputs: field.Inputs,
+ abi.Constructor = NewMethod("", "", Constructor, field.StateMutability, field.Constant, field.Payable, field.Inputs, nil)
+ case "function":
+ name := abi.overloadedMethodName(field.Name)
+ abi.Methods[name] = NewMethod(name, field.Name, Function, field.StateMutability, field.Constant, field.Payable, field.Inputs, field.Outputs)
+ case "fallback":
+ // Function type introduced in solidity v0.6.0; see
+ // https://solidity.readthedocs.io/en/v0.6.0/contracts.html#fallback-function for details
+ if abi.HasFallback() {
+ return errors.New("only single fallback is allowed")
}
- // empty defaults to function according to the abi spec
- case "function", "":
- abi.Methods[field.Name] = Method{
- Name: field.Name,
- Const: field.Constant,
- Inputs: field.Inputs,
- Outputs: field.Outputs,
+ abi.Fallback = NewMethod("", "", Fallback, field.StateMutability, field.Constant, field.Payable, nil, nil)
+ case "receive":
+ // Function type introduced in solidity v0.6.0; see
+ // https://solidity.readthedocs.io/en/v0.6.0/contracts.html#fallback-function for details
+ if abi.HasReceive() {
+ return errors.New("only single receive is allowed")
}
- case "event":
- abi.Events[field.Name] = Event{
- Name: field.Name,
- Anonymous: field.Anonymous,
- Inputs: field.Inputs,
+ if field.StateMutability != "payable" {
+ return errors.New("the statemutability of receive can only be payable")
}
+ abi.Receive = NewMethod("", "", Receive, field.StateMutability, field.Constant, field.Payable, nil, nil)
+ case "event":
+ name := abi.overloadedEventName(field.Name)
+ abi.Events[name] = NewEvent(name, field.Name, field.Anonymous, field.Inputs)
+ default:
+ return fmt.Errorf("abi: could not recognize type %v of field %v", field.Type, field.Name)
}
}
-
return nil
}
-// MethodById looks up a method by the 4-byte id
-// returns nil if none found
+// overloadedMethodName returns the next available name for a given function.
+// Needed since solidity allows for function overloading.
+//
+// e.g. if the abi contains the methods "send" and "send0",
+// overloadedMethodName returns "send1" for the input "send".
+func (abi *ABI) overloadedMethodName(rawName string) string {
+ name := rawName
+ _, ok := abi.Methods[name]
+ for idx := 0; ok; idx++ {
+ name = fmt.Sprintf("%s%d", rawName, idx)
+ _, ok = abi.Methods[name]
+ }
+ return name
+}
+
+// overloadedEventName returns the next available name for a given event.
+// Needed since solidity allows for event overloading.
+//
+// e.g. if the abi contains the events "received" and "received0",
+// overloadedEventName returns "received1" for the input "received".
+func (abi *ABI) overloadedEventName(rawName string) string {
+ name := rawName
+ _, ok := abi.Events[name]
+ for idx := 0; ok; idx++ {
+ name = fmt.Sprintf("%s%d", rawName, idx)
+ _, ok = abi.Events[name]
+ }
+ return name
+}
+
+// MethodById looks up a method by the 4-byte id. It returns
+// an error if no matching method is found.
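+//
+// A minimal usage sketch (assumes `a` is a parsed ABI and `input` is raw
+// transaction calldata whose first four bytes are the selector; the names
+// are illustrative):
+//
+//	method, err := a.MethodById(input[:4])
+//	if err == nil {
+//		fmt.Println("called:", method.Sig)
+//	}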
func (abi *ABI) MethodById(sigdata []byte) (*Method, error) {
+ if len(sigdata) < 4 {
+ return nil, fmt.Errorf("data too short (%d bytes) for abi method lookup", len(sigdata))
+ }
for _, method := range abi.Methods {
- if bytes.Equal(method.Id(), sigdata[:4]) {
+ if bytes.Equal(method.ID, sigdata[:4]) {
return &method, nil
}
}
return nil, fmt.Errorf("no method with id: %#x", sigdata[:4])
}
+
+// EventByID looks up an event by its topic hash in the
+// ABI and returns an error if none is found.
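+//
+// A minimal usage sketch (assumes `a` is a parsed ABI and `log` is a
+// types.Log whose first topic is the event signature hash; the names are
+// illustrative):
+//
+//	event, err := a.EventByID(log.Topics[0])
+//	if err == nil {
+//		fmt.Println("emitted:", event.Sig)
+//	}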
+func (abi *ABI) EventByID(topic common.Hash) (*Event, error) {
+ for _, event := range abi.Events {
+ if bytes.Equal(event.ID.Bytes(), topic.Bytes()) {
+ return &event, nil
+ }
+ }
+ return nil, fmt.Errorf("no event with id: %s", topic.Hex())
+}
+
+// HasFallback returns an indicator whether a fallback function is included.
+func (abi *ABI) HasFallback() bool {
+ return abi.Fallback.Type == Fallback
+}
+
+// HasReceive returns an indicator whether a receive function is included.
+func (abi *ABI) HasReceive() bool {
+ return abi.Receive.Type == Receive
+}
+
+// revertSelector is a special function selector for revert reason unpacking.
+var revertSelector = crypto.Keccak256([]byte("Error(string)"))[:4]
+
+// panicSelector is a special function selector for panic reason unpacking.
+var panicSelector = crypto.Keccak256([]byte("Panic(uint256)"))[:4]
+
+// panicReasons maps panic codes to readable reason strings;
+// see this link for the details:
+// https://docs.soliditylang.org/en/v0.8.21/control-structures.html#panic-via-assert-and-error-via-require
+// The reason string list is copied from ethers.js:
+// https://github.com/ethers-io/ethers.js/blob/fa3a883ff7c88611ce766f58bdd4b8ac90814470/src.ts/abi/interface.ts#L207-L218
+var panicReasons = map[uint64]string{
+ 0x00: "generic panic",
+ 0x01: "assert(false)",
+ 0x11: "arithmetic underflow or overflow",
+ 0x12: "division or modulo by zero",
+ 0x21: "enum overflow",
+ 0x22: "invalid encoded storage byte array accessed",
+ 0x31: "out-of-bounds array access; popping on an empty array",
+ 0x32: "out-of-bounds access of an array or bytesN",
+ 0x41: "out of memory",
+ 0x51: "uninitialized function",
+}
+
+// UnpackRevert resolves the abi-encoded revert reason. According to the solidity
+// spec https://solidity.readthedocs.io/en/latest/control-structures.html#revert,
+// the provided revert reason is abi-encoded as if it were a call to function
+// `Error(string)` or `Panic(uint256)`.
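+//
+// A minimal usage sketch (assumes `ret` holds the raw return data of a
+// reverted call; the name is illustrative):
+//
+//	if reason, err := UnpackRevert(ret); err == nil {
+//		fmt.Println("revert reason:", reason)
+//	}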
+func UnpackRevert(data []byte) (string, error) {
+ if len(data) < 4 {
+ return "", errors.New("invalid data for unpacking")
+ }
+ switch {
+ case bytes.Equal(data[:4], revertSelector):
+ typ, err := NewType("string", "", nil)
+ if err != nil {
+ return "", err
+ }
+ unpacked, err := (Arguments{{Type: typ}}).Unpack(data[4:])
+ if err != nil {
+ return "", err
+ }
+ return unpacked[0].(string), nil
+ case bytes.Equal(data[:4], panicSelector):
+ typ, err := NewType("uint256", "", nil)
+ if err != nil {
+ return "", err
+ }
+ unpacked, err := (Arguments{{Type: typ}}).Unpack(data[4:])
+ if err != nil {
+ return "", err
+ }
+ pCode := unpacked[0].(*big.Int)
+ // uint64 safety check for the future;
+ // currently no panic code is bigger than MAX(uint64)
+ if pCode.IsUint64() {
+ if reason, ok := panicReasons[pCode.Uint64()]; ok {
+ return reason, nil
+ }
+ }
+ return fmt.Sprintf("unknown panic code: %#x", pCode), nil
+ default:
+ return "", errors.New("invalid data for unpacking")
+ }
+}
diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go
index 5a128bfe5..87eceb8b1 100644
--- a/accounts/abi/abi_test.go
+++ b/accounts/abi/abi_test.go
@@ -19,63 +19,113 @@ package abi
import (
"bytes"
"encoding/hex"
+ "errors"
"fmt"
- "log"
"math/big"
+ "reflect"
"strings"
"testing"
- "reflect"
-
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/common/math"
"github.com/tomochain/tomochain/crypto"
)
const jsondata = `
[
- { "type" : "function", "name" : "balance", "constant" : true },
- { "type" : "function", "name" : "send", "constant" : false, "inputs" : [ { "name" : "amount", "type" : "uint256" } ] }
+ { "type" : "function", "name" : ""},
+ { "type" : "function", "name" : "balance", "stateMutability" : "view" },
+ { "type" : "function", "name" : "send", "inputs" : [ { "name" : "amount", "type" : "uint256" } ] },
+ { "type" : "function", "name" : "test", "inputs" : [ { "name" : "number", "type" : "uint32" } ] },
+ { "type" : "function", "name" : "string", "inputs" : [ { "name" : "inputs", "type" : "string" } ] },
+ { "type" : "function", "name" : "bool", "inputs" : [ { "name" : "inputs", "type" : "bool" } ] },
+ { "type" : "function", "name" : "address", "inputs" : [ { "name" : "inputs", "type" : "address" } ] },
+ { "type" : "function", "name" : "uint64[2]", "inputs" : [ { "name" : "inputs", "type" : "uint64[2]" } ] },
+ { "type" : "function", "name" : "uint64[]", "inputs" : [ { "name" : "inputs", "type" : "uint64[]" } ] },
+ { "type" : "function", "name" : "int8", "inputs" : [ { "name" : "inputs", "type" : "int8" } ] },
+ { "type" : "function", "name" : "foo", "inputs" : [ { "name" : "inputs", "type" : "uint32" } ] },
+ { "type" : "function", "name" : "bar", "inputs" : [ { "name" : "inputs", "type" : "uint32" }, { "name" : "string", "type" : "uint16" } ] },
+ { "type" : "function", "name" : "slice", "inputs" : [ { "name" : "inputs", "type" : "uint32[2]" } ] },
+ { "type" : "function", "name" : "slice256", "inputs" : [ { "name" : "inputs", "type" : "uint256[2]" } ] },
+ { "type" : "function", "name" : "sliceAddress", "inputs" : [ { "name" : "inputs", "type" : "address[]" } ] },
+ { "type" : "function", "name" : "sliceMultiAddress", "inputs" : [ { "name" : "a", "type" : "address[]" }, { "name" : "b", "type" : "address[]" } ] },
+ { "type" : "function", "name" : "nestedArray", "inputs" : [ { "name" : "a", "type" : "uint256[2][2]" }, { "name" : "b", "type" : "address[]" } ] },
+ { "type" : "function", "name" : "nestedArray2", "inputs" : [ { "name" : "a", "type" : "uint8[][2]" } ] },
+ { "type" : "function", "name" : "nestedSlice", "inputs" : [ { "name" : "a", "type" : "uint8[][]" } ] },
+ { "type" : "function", "name" : "receive", "inputs" : [ { "name" : "memo", "type" : "bytes" }], "outputs" : [], "payable" : true, "stateMutability" : "payable" },
+ { "type" : "function", "name" : "fixedArrStr", "stateMutability" : "view", "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr", "type" : "uint256[2]" } ] },
+ { "type" : "function", "name" : "fixedArrBytes", "stateMutability" : "view", "inputs" : [ { "name" : "bytes", "type" : "bytes" }, { "name" : "fixedArr", "type" : "uint256[2]" } ] },
+ { "type" : "function", "name" : "mixedArrStr", "stateMutability" : "view", "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr", "type" : "uint256[2]" }, { "name" : "dynArr", "type" : "uint256[]" } ] },
+ { "type" : "function", "name" : "doubleFixedArrStr", "stateMutability" : "view", "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr1", "type" : "uint256[2]" }, { "name" : "fixedArr2", "type" : "uint256[3]" } ] },
+ { "type" : "function", "name" : "multipleMixedArrStr", "stateMutability" : "view", "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr1", "type" : "uint256[2]" }, { "name" : "dynArr", "type" : "uint256[]" }, { "name" : "fixedArr2", "type" : "uint256[3]" } ] },
+ { "type" : "function", "name" : "overloadedNames", "stateMutability" : "view", "inputs": [ { "components": [ { "internalType": "uint256", "name": "_f", "type": "uint256" }, { "internalType": "uint256", "name": "__f", "type": "uint256"}, { "internalType": "uint256", "name": "f", "type": "uint256"}],"internalType": "struct Overloader.F", "name": "f","type": "tuple"}]}
]`
-const jsondata2 = `
-[
- { "type" : "function", "name" : "balance", "constant" : true },
- { "type" : "function", "name" : "send", "constant" : false, "inputs" : [ { "name" : "amount", "type" : "uint256" } ] },
- { "type" : "function", "name" : "test", "constant" : false, "inputs" : [ { "name" : "number", "type" : "uint32" } ] },
- { "type" : "function", "name" : "string", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "string" } ] },
- { "type" : "function", "name" : "bool", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "bool" } ] },
- { "type" : "function", "name" : "address", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "address" } ] },
- { "type" : "function", "name" : "uint64[2]", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "uint64[2]" } ] },
- { "type" : "function", "name" : "uint64[]", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "uint64[]" } ] },
- { "type" : "function", "name" : "foo", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "uint32" } ] },
- { "type" : "function", "name" : "bar", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "uint32" }, { "name" : "string", "type" : "uint16" } ] },
- { "type" : "function", "name" : "slice", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "uint32[2]" } ] },
- { "type" : "function", "name" : "slice256", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "uint256[2]" } ] },
- { "type" : "function", "name" : "sliceAddress", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "address[]" } ] },
- { "type" : "function", "name" : "sliceMultiAddress", "constant" : false, "inputs" : [ { "name" : "a", "type" : "address[]" }, { "name" : "b", "type" : "address[]" } ] }
-]`
+var (
+ Uint256, _ = NewType("uint256", "", nil)
+ Uint32, _ = NewType("uint32", "", nil)
+ Uint16, _ = NewType("uint16", "", nil)
+ String, _ = NewType("string", "", nil)
+ Bool, _ = NewType("bool", "", nil)
+ Bytes, _ = NewType("bytes", "", nil)
+ Address, _ = NewType("address", "", nil)
+ Uint64Arr, _ = NewType("uint64[]", "", nil)
+ AddressArr, _ = NewType("address[]", "", nil)
+ Int8, _ = NewType("int8", "", nil)
+ // Special types for testing
+ Uint32Arr2, _ = NewType("uint32[2]", "", nil)
+ Uint64Arr2, _ = NewType("uint64[2]", "", nil)
+ Uint256Arr, _ = NewType("uint256[]", "", nil)
+ Uint256Arr2, _ = NewType("uint256[2]", "", nil)
+ Uint256Arr3, _ = NewType("uint256[3]", "", nil)
+ Uint256ArrNested, _ = NewType("uint256[2][2]", "", nil)
+ Uint8ArrNested, _ = NewType("uint8[][2]", "", nil)
+ Uint8SliceNested, _ = NewType("uint8[][]", "", nil)
+ TupleF, _ = NewType("tuple", "struct Overloader.F", []ArgumentMarshaling{
+ {Name: "_f", Type: "uint256"},
+ {Name: "__f", Type: "uint256"},
+ {Name: "f", Type: "uint256"}})
+)
+
+var methods = map[string]Method{
+ "": NewMethod("", "", Function, "", false, false, nil, nil),
+ "balance": NewMethod("balance", "balance", Function, "view", false, false, nil, nil),
+ "send": NewMethod("send", "send", Function, "", false, false, []Argument{{"amount", Uint256, false}}, nil),
+ "test": NewMethod("test", "test", Function, "", false, false, []Argument{{"number", Uint32, false}}, nil),
+ "string": NewMethod("string", "string", Function, "", false, false, []Argument{{"inputs", String, false}}, nil),
+ "bool": NewMethod("bool", "bool", Function, "", false, false, []Argument{{"inputs", Bool, false}}, nil),
+ "address": NewMethod("address", "address", Function, "", false, false, []Argument{{"inputs", Address, false}}, nil),
+ "uint64[]": NewMethod("uint64[]", "uint64[]", Function, "", false, false, []Argument{{"inputs", Uint64Arr, false}}, nil),
+ "uint64[2]": NewMethod("uint64[2]", "uint64[2]", Function, "", false, false, []Argument{{"inputs", Uint64Arr2, false}}, nil),
+ "int8": NewMethod("int8", "int8", Function, "", false, false, []Argument{{"inputs", Int8, false}}, nil),
+ "foo": NewMethod("foo", "foo", Function, "", false, false, []Argument{{"inputs", Uint32, false}}, nil),
+ "bar": NewMethod("bar", "bar", Function, "", false, false, []Argument{{"inputs", Uint32, false}, {"string", Uint16, false}}, nil),
+ "slice": NewMethod("slice", "slice", Function, "", false, false, []Argument{{"inputs", Uint32Arr2, false}}, nil),
+ "slice256": NewMethod("slice256", "slice256", Function, "", false, false, []Argument{{"inputs", Uint256Arr2, false}}, nil),
+ "sliceAddress": NewMethod("sliceAddress", "sliceAddress", Function, "", false, false, []Argument{{"inputs", AddressArr, false}}, nil),
+ "sliceMultiAddress": NewMethod("sliceMultiAddress", "sliceMultiAddress", Function, "", false, false, []Argument{{"a", AddressArr, false}, {"b", AddressArr, false}}, nil),
+ "nestedArray": NewMethod("nestedArray", "nestedArray", Function, "", false, false, []Argument{{"a", Uint256ArrNested, false}, {"b", AddressArr, false}}, nil),
+ "nestedArray2": NewMethod("nestedArray2", "nestedArray2", Function, "", false, false, []Argument{{"a", Uint8ArrNested, false}}, nil),
+ "nestedSlice": NewMethod("nestedSlice", "nestedSlice", Function, "", false, false, []Argument{{"a", Uint8SliceNested, false}}, nil),
+ "receive": NewMethod("receive", "receive", Function, "payable", false, true, []Argument{{"memo", Bytes, false}}, []Argument{}),
+ "fixedArrStr": NewMethod("fixedArrStr", "fixedArrStr", Function, "view", false, false, []Argument{{"str", String, false}, {"fixedArr", Uint256Arr2, false}}, nil),
+ "fixedArrBytes": NewMethod("fixedArrBytes", "fixedArrBytes", Function, "view", false, false, []Argument{{"bytes", Bytes, false}, {"fixedArr", Uint256Arr2, false}}, nil),
+ "mixedArrStr": NewMethod("mixedArrStr", "mixedArrStr", Function, "view", false, false, []Argument{{"str", String, false}, {"fixedArr", Uint256Arr2, false}, {"dynArr", Uint256Arr, false}}, nil),
+ "doubleFixedArrStr": NewMethod("doubleFixedArrStr", "doubleFixedArrStr", Function, "view", false, false, []Argument{{"str", String, false}, {"fixedArr1", Uint256Arr2, false}, {"fixedArr2", Uint256Arr3, false}}, nil),
+ "multipleMixedArrStr": NewMethod("multipleMixedArrStr", "multipleMixedArrStr", Function, "view", false, false, []Argument{{"str", String, false}, {"fixedArr1", Uint256Arr2, false}, {"dynArr", Uint256Arr, false}, {"fixedArr2", Uint256Arr3, false}}, nil),
+ "overloadedNames": NewMethod("overloadedNames", "overloadedNames", Function, "view", false, false, []Argument{{"f", TupleF, false}}, nil),
+}
func TestReader(t *testing.T) {
- Uint256, _ := NewType("uint256")
- exp := ABI{
- Methods: map[string]Method{
- "balance": {
- "balance", true, nil, nil,
- },
- "send": {
- "send", false, []Argument{
- {"amount", Uint256, false},
- }, nil,
- },
- },
+ abi := ABI{
+ Methods: methods,
}
- abi, err := JSON(strings.NewReader(jsondata))
+ exp, err := JSON(strings.NewReader(jsondata))
if err != nil {
- t.Error(err)
+ t.Fatal(err)
}
- // deep equal fails for some reason
for name, expM := range exp.Methods {
gotM, exist := abi.Methods[name]
if !exist {
@@ -97,11 +147,58 @@ func TestReader(t *testing.T) {
}
}
-func TestTestNumbers(t *testing.T) {
- abi, err := JSON(strings.NewReader(jsondata2))
+func TestInvalidABI(t *testing.T) {
+ json := `[{ "type" : "function", "name" : "", "constant" : fals }]`
+ _, err := JSON(strings.NewReader(json))
+ if err == nil {
+ t.Fatal("invalid json should produce error")
+ }
+ json2 := `[{ "type" : "function", "name" : "send", "constant" : false, "inputs" : [ { "name" : "amount", "typ" : "uint256" } ] }]`
+ _, err = JSON(strings.NewReader(json2))
+ if err == nil {
+ t.Fatal("invalid json should produce error")
+ }
+}
+
+// TestConstructor tests a constructor function.
+// The test is based on the following contract:
+//
+// contract TestConstructor {
+// constructor(uint256 a, uint256 b) public{}
+// }
+func TestConstructor(t *testing.T) {
+ json := `[{ "inputs": [{"internalType": "uint256","name": "a","type": "uint256" },{ "internalType": "uint256","name": "b","type": "uint256"}],"stateMutability": "nonpayable","type": "constructor"}]`
+ method := NewMethod("", "", Constructor, "nonpayable", false, false, []Argument{{"a", Uint256, false}, {"b", Uint256, false}}, nil)
+ // Test from JSON
+ abi, err := JSON(strings.NewReader(json))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(abi.Constructor, method) {
+ t.Error("Missing expected constructor")
+ }
+ // Test pack/unpack
+ packed, err := abi.Pack("", big.NewInt(1), big.NewInt(2))
if err != nil {
t.Error(err)
- t.FailNow()
+ }
+ unpacked, err := abi.Constructor.Inputs.Unpack(packed)
+ if err != nil {
+ t.Error(err)
+ }
+
+ if !reflect.DeepEqual(unpacked[0], big.NewInt(1)) {
+ t.Error("Unable to pack/unpack from constructor")
+ }
+ if !reflect.DeepEqual(unpacked[1], big.NewInt(2)) {
+ t.Error("Unable to pack/unpack from constructor")
+ }
+}
+
+func TestTestNumbers(t *testing.T) {
+ abi, err := JSON(strings.NewReader(jsondata))
+ if err != nil {
+ t.Fatal(err)
}
if _, err := abi.Pack("balance"); err != nil {
@@ -135,73 +232,71 @@ func TestTestNumbers(t *testing.T) {
}
}
-func TestTestString(t *testing.T) {
- abi, err := JSON(strings.NewReader(jsondata2))
- if err != nil {
- t.Error(err)
- t.FailNow()
+func TestMethodSignature(t *testing.T) {
+ m := NewMethod("foo", "foo", Function, "", false, false, []Argument{{"bar", String, false}, {"baz", String, false}}, nil)
+ exp := "foo(string,string)"
+ if m.Sig != exp {
+ t.Error("signature mismatch", exp, "!=", m.Sig)
}
- if _, err := abi.Pack("string", "hello world"); err != nil {
- t.Error(err)
+ idexp := crypto.Keccak256([]byte(exp))[:4]
+ if !bytes.Equal(m.ID, idexp) {
+ t.Errorf("expected ids to match %x != %x", m.ID, idexp)
}
-}
-func TestTestBool(t *testing.T) {
- abi, err := JSON(strings.NewReader(jsondata2))
- if err != nil {
- t.Error(err)
- t.FailNow()
+ m = NewMethod("foo", "foo", Function, "", false, false, []Argument{{"bar", Uint256, false}}, nil)
+ exp = "foo(uint256)"
+ if m.Sig != exp {
+ t.Error("signature mismatch", exp, "!=", m.Sig)
}
- if _, err := abi.Pack("bool", true); err != nil {
- t.Error(err)
+ // Method with tuple arguments
+ s, _ := NewType("tuple", "", []ArgumentMarshaling{
+ {Name: "a", Type: "int256"},
+ {Name: "b", Type: "int256[]"},
+ {Name: "c", Type: "tuple[]", Components: []ArgumentMarshaling{
+ {Name: "x", Type: "int256"},
+ {Name: "y", Type: "int256"},
+ }},
+ {Name: "d", Type: "tuple[2]", Components: []ArgumentMarshaling{
+ {Name: "x", Type: "int256"},
+ {Name: "y", Type: "int256"},
+ }},
+ })
+ m = NewMethod("foo", "foo", Function, "", false, false, []Argument{{"s", s, false}, {"bar", String, false}}, nil)
+ exp = "foo((int256,int256[],(int256,int256)[],(int256,int256)[2]),string)"
+ if m.Sig != exp {
+ t.Error("signature mismatch", exp, "!=", m.Sig)
}
}
-func TestTestSlice(t *testing.T) {
- abi, err := JSON(strings.NewReader(jsondata2))
+func TestOverloadedMethodSignature(t *testing.T) {
+ json := `[{"constant":true,"inputs":[{"name":"i","type":"uint256"},{"name":"j","type":"uint256"}],"name":"foo","outputs":[],"payable":false,"stateMutability":"pure","type":"function"},{"constant":true,"inputs":[{"name":"i","type":"uint256"}],"name":"foo","outputs":[],"payable":false,"stateMutability":"pure","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"i","type":"uint256"}],"name":"bar","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"i","type":"uint256"},{"indexed":false,"name":"j","type":"uint256"}],"name":"bar","type":"event"}]`
+ abi, err := JSON(strings.NewReader(json))
if err != nil {
- t.Error(err)
- t.FailNow()
- }
-
- slice := make([]uint64, 2)
- if _, err := abi.Pack("uint64[2]", slice); err != nil {
- t.Error(err)
- }
-
- if _, err := abi.Pack("uint64[]", slice); err != nil {
- t.Error(err)
- }
-}
-
-func TestMethodSignature(t *testing.T) {
- String, _ := NewType("string")
- m := Method{"foo", false, []Argument{{"bar", String, false}, {"baz", String, false}}, nil}
- exp := "foo(string,string)"
- if m.Sig() != exp {
- t.Error("signature mismatch", exp, "!=", m.Sig())
- }
-
- idexp := crypto.Keccak256([]byte(exp))[:4]
- if !bytes.Equal(m.Id(), idexp) {
- t.Errorf("expected ids to match %x != %x", m.Id(), idexp)
+ t.Fatal(err)
}
-
- uintt, _ := NewType("uint256")
- m = Method{"foo", false, []Argument{{"bar", uintt, false}}, nil}
- exp = "foo(uint256)"
- if m.Sig() != exp {
- t.Error("signature mismatch", exp, "!=", m.Sig())
+ check := func(name string, expect string, method bool) {
+ if method {
+ if abi.Methods[name].Sig != expect {
+ t.Fatalf("The signature of overloaded method mismatch, want %s, have %s", expect, abi.Methods[name].Sig)
+ }
+ } else {
+ if abi.Events[name].Sig != expect {
+ t.Fatalf("The signature of overloaded event mismatch, want %s, have %s", expect, abi.Events[name].Sig)
+ }
+ }
}
+ check("foo", "foo(uint256,uint256)", true)
+ check("foo0", "foo(uint256)", true)
+ check("bar", "bar(uint256)", false)
+ check("bar0", "bar(uint256,uint256)", false)
}
func TestMultiPack(t *testing.T) {
- abi, err := JSON(strings.NewReader(jsondata2))
+ abi, err := JSON(strings.NewReader(jsondata))
if err != nil {
- t.Error(err)
- t.FailNow()
+ t.Fatal(err)
}
sig := crypto.Keccak256([]byte("bar(uint32,uint16)"))[:4]
@@ -211,10 +306,8 @@ func TestMultiPack(t *testing.T) {
packed, err := abi.Pack("bar", uint32(10), uint16(11))
if err != nil {
- t.Error(err)
- t.FailNow()
+ t.Fatal(err)
}
-
if !bytes.Equal(packed, sig) {
t.Errorf("expected %x got %x", sig, packed)
}
@@ -225,11 +318,11 @@ func ExampleJSON() {
abi, err := JSON(strings.NewReader(definition))
if err != nil {
- log.Fatalln(err)
+ panic(err)
}
out, err := abi.Pack("isBar", common.HexToAddress("01"))
if err != nil {
- log.Fatalln(err)
+ panic(err)
}
fmt.Printf("%x\n", out)
@@ -366,15 +459,7 @@ func TestInputVariableInputLength(t *testing.T) {
}
func TestInputFixedArrayAndVariableInputLength(t *testing.T) {
- const definition = `[
- { "type" : "function", "name" : "fixedArrStr", "constant" : true, "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr", "type" : "uint256[2]" } ] },
- { "type" : "function", "name" : "fixedArrBytes", "constant" : true, "inputs" : [ { "name" : "str", "type" : "bytes" }, { "name" : "fixedArr", "type" : "uint256[2]" } ] },
- { "type" : "function", "name" : "mixedArrStr", "constant" : true, "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr", "type": "uint256[2]" }, { "name" : "dynArr", "type": "uint256[]" } ] },
- { "type" : "function", "name" : "doubleFixedArrStr", "constant" : true, "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr1", "type": "uint256[2]" }, { "name" : "fixedArr2", "type": "uint256[3]" } ] },
- { "type" : "function", "name" : "multipleMixedArrStr", "constant" : true, "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr1", "type": "uint256[2]" }, { "name" : "dynArr", "type" : "uint256[]" }, { "name" : "fixedArr2", "type" : "uint256[3]" } ] }
- ]`
-
- abi, err := JSON(strings.NewReader(definition))
+ abi, err := JSON(strings.NewReader(jsondata))
if err != nil {
t.Error(err)
}
@@ -521,7 +606,7 @@ func TestInputFixedArrayAndVariableInputLength(t *testing.T) {
strvalue = common.RightPadBytes([]byte(strin), 32)
fixedarrin1value1 = common.LeftPadBytes(fixedarrin1[0].Bytes(), 32)
fixedarrin1value2 = common.LeftPadBytes(fixedarrin1[1].Bytes(), 32)
- dynarroffset = U256(big.NewInt(int64(256 + ((len(strin)/32)+1)*32)))
+ dynarroffset = math.U256Bytes(big.NewInt(int64(256 + ((len(strin)/32)+1)*32)))
dynarrlength = make([]byte, 32)
dynarrlength[31] = byte(len(dynarrin))
dynarrinvalue1 = common.LeftPadBytes(dynarrin[0].Bytes(), 32)
@@ -548,7 +633,7 @@ func TestInputFixedArrayAndVariableInputLength(t *testing.T) {
}
func TestDefaultFunctionParsing(t *testing.T) {
- const definition = `[{ "name" : "balance" }]`
+ const definition = `[{ "name" : "balance", "type" : "function" }]`
abi, err := JSON(strings.NewReader(definition))
if err != nil {
@@ -564,11 +649,11 @@ func TestBareEvents(t *testing.T) {
const definition = `[
{ "type" : "event", "name" : "balance" },
{ "type" : "event", "name" : "anon", "anonymous" : true},
- { "type" : "event", "name" : "args", "inputs" : [{ "indexed":false, "name":"arg0", "type":"uint256" }, { "indexed":true, "name":"arg1", "type":"address" }] }
+ { "type" : "event", "name" : "args", "inputs" : [{ "indexed":false, "name":"arg0", "type":"uint256" }, { "indexed":true, "name":"arg1", "type":"address" }] },
+ { "type" : "event", "name" : "tuple", "inputs" : [{ "indexed":false, "name":"t", "type":"tuple", "components":[{"name":"a", "type":"uint256"}] }, { "indexed":true, "name":"arg1", "type":"address" }] }
]`
- arg0, _ := NewType("uint256")
- arg1, _ := NewType("address")
+ tuple, _ := NewType("tuple", "", []ArgumentMarshaling{{Name: "a", Type: "uint256"}})
expectedEvents := map[string]struct {
Anonymous bool
@@ -577,8 +662,12 @@ func TestBareEvents(t *testing.T) {
"balance": {false, nil},
"anon": {true, nil},
"args": {false, []Argument{
- {Name: "arg0", Type: arg0, Indexed: false},
- {Name: "arg1", Type: arg1, Indexed: true},
+ {Name: "arg0", Type: Uint256, Indexed: false},
+ {Name: "arg1", Type: Address, Indexed: true},
+ }},
+ "tuple": {false, []Argument{
+ {Name: "t", Type: tuple, Indexed: false},
+ {Name: "arg1", Type: Address, Indexed: true},
}},
}
@@ -619,16 +708,19 @@ func TestBareEvents(t *testing.T) {
}
// TestUnpackEvent is based on this contract:
-// contract T {
-// event received(address sender, uint amount, bytes memo);
-// event receivedAddr(address sender);
-// function receive(bytes memo) external payable {
-// received(msg.sender, msg.value, memo);
-// receivedAddr(msg.sender);
-// }
-// }
+//
+// contract T {
+// event received(address sender, uint amount, bytes memo);
+// event receivedAddr(address sender);
+// function receive(bytes memo) external payable {
+// received(msg.sender, msg.value, memo);
+// receivedAddr(msg.sender);
+// }
+// }
+//
// When receive("X") is called with sender 0x00... and value 1, it produces this tx receipt:
-// receipt{status=1 cgas=23949 bloom=00000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000040200000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 logs=[log: b6818c8064f645cd82d99b59a1a267d6d61117ef [75fd880d39c1daf53b6547ab6cb59451fc6452d27caa90e5b6649dd8293b9eed] 000000000000000000000000376c47978271565f56deb45495afa69e59c16ab200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158 9ae378b6d4409eada347a5dc0c180f186cb62dc68fcc0f043425eb917335aa28 0 95d429d309bb9d753954195fe2d69bd140b4ae731b9b5b605c34323de162cf00 0]}
+//
+// receipt{status=1 cgas=23949 bloom=00000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000040200000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 logs=[log: b6818c8064f645cd82d99b59a1a267d6d61117ef [75fd880d39c1daf53b6547ab6cb59451fc6452d27caa90e5b6649dd8293b9eed] 000000000000000000000000376c47978271565f56deb45495afa69e59c16ab200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158 9ae378b6d4409eada347a5dc0c180f186cb62dc68fcc0f043425eb917335aa28 0 95d429d309bb9d753954195fe2d69bd140b4ae731b9b5b605c34323de162cf00 0]}
func TestUnpackEvent(t *testing.T) {
const abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"receive","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}],"name":"received","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"}],"name":"receivedAddr","type":"event"}]`
abi, err := JSON(strings.NewReader(abiJSON))
@@ -646,70 +738,416 @@ func TestUnpackEvent(t *testing.T) {
}
type ReceivedEvent struct {
- Address common.Address
- Amount *big.Int
- Memo []byte
+ Sender common.Address
+ Amount *big.Int
+ Memo []byte
}
var ev ReceivedEvent
- err = abi.Unpack(&ev, "received", data)
+ err = abi.UnpackIntoInterface(&ev, "received", data)
if err != nil {
t.Error(err)
- } else {
- t.Logf("len(data): %d; received event: %+v", len(data), ev)
}
type ReceivedAddrEvent struct {
- Address common.Address
+ Sender common.Address
}
var receivedAddrEv ReceivedAddrEvent
- err = abi.Unpack(&receivedAddrEv, "receivedAddr", data)
+ err = abi.UnpackIntoInterface(&receivedAddrEv, "receivedAddr", data)
if err != nil {
t.Error(err)
- } else {
- t.Logf("len(data): %d; received event: %+v", len(data), receivedAddrEv)
}
}
-func TestABI_MethodById(t *testing.T) {
- const abiJSON = `[
- {"type":"function","name":"receive","constant":false,"inputs":[{"name":"memo","type":"bytes"}],"outputs":[],"payable":true,"stateMutability":"payable"},
- {"type":"event","name":"received","anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}]},
- {"type":"function","name":"fixedArrStr","constant":true,"inputs":[{"name":"str","type":"string"},{"name":"fixedArr","type":"uint256[2]"}]},
- {"type":"function","name":"fixedArrBytes","constant":true,"inputs":[{"name":"str","type":"bytes"},{"name":"fixedArr","type":"uint256[2]"}]},
- {"type":"function","name":"mixedArrStr","constant":true,"inputs":[{"name":"str","type":"string"},{"name":"fixedArr","type":"uint256[2]"},{"name":"dynArr","type":"uint256[]"}]},
- {"type":"function","name":"doubleFixedArrStr","constant":true,"inputs":[{"name":"str","type":"string"},{"name":"fixedArr1","type":"uint256[2]"},{"name":"fixedArr2","type":"uint256[3]"}]},
- {"type":"function","name":"multipleMixedArrStr","constant":true,"inputs":[{"name":"str","type":"string"},{"name":"fixedArr1","type":"uint256[2]"},{"name":"dynArr","type":"uint256[]"},{"name":"fixedArr2","type":"uint256[3]"}]},
- {"type":"function","name":"balance","constant":true},
- {"type":"function","name":"send","constant":false,"inputs":[{"name":"amount","type":"uint256"}]},
- {"type":"function","name":"test","constant":false,"inputs":[{"name":"number","type":"uint32"}]},
- {"type":"function","name":"string","constant":false,"inputs":[{"name":"inputs","type":"string"}]},
- {"type":"function","name":"bool","constant":false,"inputs":[{"name":"inputs","type":"bool"}]},
- {"type":"function","name":"address","constant":false,"inputs":[{"name":"inputs","type":"address"}]},
- {"type":"function","name":"uint64[2]","constant":false,"inputs":[{"name":"inputs","type":"uint64[2]"}]},
- {"type":"function","name":"uint64[]","constant":false,"inputs":[{"name":"inputs","type":"uint64[]"}]},
- {"type":"function","name":"foo","constant":false,"inputs":[{"name":"inputs","type":"uint32"}]},
- {"type":"function","name":"bar","constant":false,"inputs":[{"name":"inputs","type":"uint32"},{"name":"string","type":"uint16"}]},
- {"type":"function","name":"_slice","constant":false,"inputs":[{"name":"inputs","type":"uint32[2]"}]},
- {"type":"function","name":"__slice256","constant":false,"inputs":[{"name":"inputs","type":"uint256[2]"}]},
- {"type":"function","name":"sliceAddress","constant":false,"inputs":[{"name":"inputs","type":"address[]"}]},
- {"type":"function","name":"sliceMultiAddress","constant":false,"inputs":[{"name":"a","type":"address[]"},{"name":"b","type":"address[]"}]}
- ]
-`
+func TestUnpackEventIntoMap(t *testing.T) {
+ const abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"receive","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}],"name":"received","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"}],"name":"receivedAddr","type":"event"}]`
abi, err := JSON(strings.NewReader(abiJSON))
if err != nil {
t.Fatal(err)
}
+
+ const hexdata = `000000000000000000000000376c47978271565f56deb45495afa69e59c16ab200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158`
+ data, err := hex.DecodeString(hexdata)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(data)%32 == 0 {
+ t.Errorf("len(data) is %d, want a non-multiple of 32", len(data))
+ }
+
+ receivedMap := map[string]interface{}{}
+ expectedReceivedMap := map[string]interface{}{
+ "sender": common.HexToAddress("0x376c47978271565f56DEB45495afa69E59c16Ab2"),
+ "amount": big.NewInt(1),
+ "memo": []byte{88},
+ }
+ if err := abi.UnpackIntoMap(receivedMap, "received", data); err != nil {
+ t.Error(err)
+ }
+ if len(receivedMap) != 3 {
+ t.Error("unpacked `received` map expected to have length 3")
+ }
+ if receivedMap["sender"] != expectedReceivedMap["sender"] {
+ t.Error("unpacked `received` map does not match expected map")
+ }
+ if receivedMap["amount"].(*big.Int).Cmp(expectedReceivedMap["amount"].(*big.Int)) != 0 {
+ t.Error("unpacked `received` map does not match expected map")
+ }
+ if !bytes.Equal(receivedMap["memo"].([]byte), expectedReceivedMap["memo"].([]byte)) {
+ t.Error("unpacked `received` map does not match expected map")
+ }
+
+ receivedAddrMap := map[string]interface{}{}
+ if err = abi.UnpackIntoMap(receivedAddrMap, "receivedAddr", data); err != nil {
+ t.Error(err)
+ }
+ if len(receivedAddrMap) != 1 {
+ t.Error("unpacked `receivedAddr` map expected to have length 1")
+ }
+ if receivedAddrMap["sender"] != expectedReceivedMap["sender"] {
+ t.Error("unpacked `receivedAddr` map does not match expected map")
+ }
+}
+
+func TestUnpackMethodIntoMap(t *testing.T) {
+ const abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"receive","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"constant":false,"inputs":[],"name":"send","outputs":[{"name":"amount","type":"uint256"}],"payable":true,"stateMutability":"payable","type":"function"},{"constant":false,"inputs":[{"name":"addr","type":"address"}],"name":"get","outputs":[{"name":"hash","type":"bytes"}],"payable":true,"stateMutability":"payable","type":"function"}]`
+ abi, err := JSON(strings.NewReader(abiJSON))
+ if err != nil {
+ t.Fatal(err)
+ }
+ const hexdata = `00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000015800000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000158000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000001580000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000015800000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000158`
+ data, err := hex.DecodeString(hexdata)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(data)%32 != 0 {
+ t.Errorf("len(data) is %d, want a multiple of 32", len(data))
+ }
+
+ // Tests a method with no outputs
+ receiveMap := map[string]interface{}{}
+ if err = abi.UnpackIntoMap(receiveMap, "receive", data); err != nil {
+ t.Error(err)
+ }
+ if len(receiveMap) > 0 {
+ t.Error("unpacked `receive` map expected to have length 0")
+ }
+
+ // Tests a method with only outputs
+ sendMap := map[string]interface{}{}
+ if err = abi.UnpackIntoMap(sendMap, "send", data); err != nil {
+ t.Error(err)
+ }
+ if len(sendMap) != 1 {
+ t.Error("unpacked `send` map expected to have length 1")
+ }
+ if sendMap["amount"].(*big.Int).Cmp(big.NewInt(1)) != 0 {
+ t.Error("unpacked `send` map expected `amount` value of 1")
+ }
+
+ // Tests a method with outputs and inputs
+ getMap := map[string]interface{}{}
+ if err = abi.UnpackIntoMap(getMap, "get", data); err != nil {
+ t.Error(err)
+ }
+ if len(getMap) != 1 {
+ t.Error("unpacked `get` map expected to have length 1")
+ }
+ expectedBytes := []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 88, 0}
+ if !bytes.Equal(getMap["hash"].([]byte), expectedBytes) {
+ t.Errorf("unpacked `get` map expected `hash` value of %v", expectedBytes)
+ }
+}
+
+func TestUnpackIntoMapNamingConflict(t *testing.T) {
+ // Two methods have the same name
+ var abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"get","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"constant":false,"inputs":[],"name":"send","outputs":[{"name":"amount","type":"uint256"}],"payable":true,"stateMutability":"payable","type":"function"},{"constant":false,"inputs":[{"name":"addr","type":"address"}],"name":"get","outputs":[{"name":"hash","type":"bytes"}],"payable":true,"stateMutability":"payable","type":"function"}]`
+ abi, err := JSON(strings.NewReader(abiJSON))
+ if err != nil {
+ t.Fatal(err)
+ }
+ var hexdata = `00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158`
+ data, err := hex.DecodeString(hexdata)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(data)%32 == 0 {
+ t.Errorf("len(data) is %d, want a non-multiple of 32", len(data))
+ }
+ getMap := map[string]interface{}{}
+ if err = abi.UnpackIntoMap(getMap, "get", data); err == nil {
+ t.Error("naming conflict between two methods; error expected")
+ }
+
+ // Two events have the same name
+ abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"receive","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}],"name":"received","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"}],"name":"received","type":"event"}]`
+ abi, err = JSON(strings.NewReader(abiJSON))
+ if err != nil {
+ t.Fatal(err)
+ }
+ hexdata = `000000000000000000000000376c47978271565f56deb45495afa69e59c16ab200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158`
+ data, err = hex.DecodeString(hexdata)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(data)%32 == 0 {
+ t.Errorf("len(data) is %d, want a non-multiple of 32", len(data))
+ }
+ receivedMap := map[string]interface{}{}
+ if err = abi.UnpackIntoMap(receivedMap, "received", data); err != nil {
+ t.Error("naming conflict between two events; no error expected")
+ }
+
+ // Method and event have the same name
+ abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"received","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}],"name":"received","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"}],"name":"receivedAddr","type":"event"}]`
+ abi, err = JSON(strings.NewReader(abiJSON))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(data)%32 == 0 {
+ t.Errorf("len(data) is %d, want a non-multiple of 32", len(data))
+ }
+ if err = abi.UnpackIntoMap(receivedMap, "received", data); err == nil {
+ t.Error("naming conflict between an event and a method; error expected")
+ }
+
+ // Conflict is case sensitive
+ abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"received","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}],"name":"Received","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"}],"name":"receivedAddr","type":"event"}]`
+ abi, err = JSON(strings.NewReader(abiJSON))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(data)%32 == 0 {
+ t.Errorf("len(data) is %d, want a non-multiple of 32", len(data))
+ }
+ expectedReceivedMap := map[string]interface{}{
+ "sender": common.HexToAddress("0x376c47978271565f56DEB45495afa69E59c16Ab2"),
+ "amount": big.NewInt(1),
+ "memo": []byte{88},
+ }
+ if err = abi.UnpackIntoMap(receivedMap, "Received", data); err != nil {
+ t.Error(err)
+ }
+ if len(receivedMap) != 3 {
+ t.Error("unpacked `received` map expected to have length 3")
+ }
+ if receivedMap["sender"] != expectedReceivedMap["sender"] {
+ t.Error("unpacked `received` map does not match expected map")
+ }
+ if receivedMap["amount"].(*big.Int).Cmp(expectedReceivedMap["amount"].(*big.Int)) != 0 {
+ t.Error("unpacked `received` map does not match expected map")
+ }
+ if !bytes.Equal(receivedMap["memo"].([]byte), expectedReceivedMap["memo"].([]byte)) {
+ t.Error("unpacked `received` map does not match expected map")
+ }
+}
+
+func TestABI_MethodById(t *testing.T) {
+ abi, err := JSON(strings.NewReader(jsondata))
+ if err != nil {
+ t.Fatal(err)
+ }
for name, m := range abi.Methods {
a := fmt.Sprintf("%v", m)
- m2, err := abi.MethodById(m.Id())
+ m2, err := abi.MethodById(m.ID)
if err != nil {
t.Fatalf("Failed to look up ABI method: %v", err)
}
b := fmt.Sprintf("%v", m2)
if a != b {
- t.Errorf("Method %v (id %v) not 'findable' by id in ABI", name, common.ToHex(m.Id()))
+ t.Errorf("Method %v (id %x) not 'findable' by id in ABI", name, m.ID)
}
}
+ // test unsuccessful lookups
+ if _, err = abi.MethodById(crypto.Keccak256()); err == nil {
+ t.Error("Expected error: no method with this id")
+ }
+ // Also test empty
+ if _, err := abi.MethodById([]byte{0x00}); err == nil {
+ t.Errorf("Expected error, too short to decode data")
+ }
+ if _, err := abi.MethodById([]byte{}); err == nil {
+ t.Errorf("Expected error, too short to decode data")
+ }
+ if _, err := abi.MethodById(nil); err == nil {
+ t.Errorf("Expected error, nil is short to decode data")
+ }
+}
+
+func TestABI_EventById(t *testing.T) {
+ tests := []struct {
+ name string
+ json string
+ event string
+ }{
+ {
+ name: "",
+ json: `[
+ {"type":"event","name":"received","anonymous":false,"inputs":[
+ {"indexed":false,"name":"sender","type":"address"},
+ {"indexed":false,"name":"amount","type":"uint256"},
+ {"indexed":false,"name":"memo","type":"bytes"}
+ ]
+ }]`,
+ event: "received(address,uint256,bytes)",
+ }, {
+ name: "",
+ json: `[
+ { "constant": true, "inputs": [], "name": "name", "outputs": [ { "name": "", "type": "string" } ], "payable": false, "stateMutability": "view", "type": "function" },
+ { "constant": false, "inputs": [ { "name": "_spender", "type": "address" }, { "name": "_value", "type": "uint256" } ], "name": "approve", "outputs": [ { "name": "", "type": "bool" } ], "payable": false, "stateMutability": "nonpayable", "type": "function" },
+ { "constant": true, "inputs": [], "name": "totalSupply", "outputs": [ { "name": "", "type": "uint256" } ], "payable": false, "stateMutability": "view", "type": "function" },
+ { "constant": false, "inputs": [ { "name": "_from", "type": "address" }, { "name": "_to", "type": "address" }, { "name": "_value", "type": "uint256" } ], "name": "transferFrom", "outputs": [ { "name": "", "type": "bool" } ], "payable": false, "stateMutability": "nonpayable", "type": "function" },
+ { "constant": true, "inputs": [], "name": "decimals", "outputs": [ { "name": "", "type": "uint8" } ], "payable": false, "stateMutability": "view", "type": "function" },
+ { "constant": true, "inputs": [ { "name": "_owner", "type": "address" } ], "name": "balanceOf", "outputs": [ { "name": "balance", "type": "uint256" } ], "payable": false, "stateMutability": "view", "type": "function" },
+ { "constant": true, "inputs": [], "name": "symbol", "outputs": [ { "name": "", "type": "string" } ], "payable": false, "stateMutability": "view", "type": "function" },
+ { "constant": false, "inputs": [ { "name": "_to", "type": "address" }, { "name": "_value", "type": "uint256" } ], "name": "transfer", "outputs": [ { "name": "", "type": "bool" } ], "payable": false, "stateMutability": "nonpayable", "type": "function" },
+ { "constant": true, "inputs": [ { "name": "_owner", "type": "address" }, { "name": "_spender", "type": "address" } ], "name": "allowance", "outputs": [ { "name": "", "type": "uint256" } ], "payable": false, "stateMutability": "view", "type": "function" },
+ { "payable": true, "stateMutability": "payable", "type": "fallback" },
+ { "anonymous": false, "inputs": [ { "indexed": true, "name": "owner", "type": "address" }, { "indexed": true, "name": "spender", "type": "address" }, { "indexed": false, "name": "value", "type": "uint256" } ], "name": "Approval", "type": "event" },
+ { "anonymous": false, "inputs": [ { "indexed": true, "name": "from", "type": "address" }, { "indexed": true, "name": "to", "type": "address" }, { "indexed": false, "name": "value", "type": "uint256" } ], "name": "Transfer", "type": "event" }
+ ]`,
+ event: "Transfer(address,address,uint256)",
+ },
+ }
+
+ for testnum, test := range tests {
+ abi, err := JSON(strings.NewReader(test.json))
+ if err != nil {
+ t.Error(err)
+ }
+
+ topic := test.event
+ topicID := crypto.Keccak256Hash([]byte(topic))
+ event, err := abi.EventByID(topicID)
+ if err != nil {
+ t.Fatalf("Failed to look up ABI method: %v, test #%d", err, testnum)
+ }
+ if event == nil {
+ t.Errorf("We should find a event for topic %s, test #%d", topicID.Hex(), testnum)
+ }
+
+ if event.ID != topicID {
+ t.Errorf("Event id %s does not match topic %s, test #%d", event.ID.Hex(), topicID.Hex(), testnum)
+ }
+
+ unknowntopicID := crypto.Keccak256Hash([]byte("unknownEvent"))
+ unknownEvent, err := abi.EventByID(unknowntopicID)
+ if err == nil {
+ t.Errorf("EventByID should return an error if a topic is not found, test #%d", testnum)
+ }
+ if unknownEvent != nil {
+ t.Errorf("We should not find any event for topic %s, test #%d", unknowntopicID.Hex(), testnum)
+ }
+ }
+}
+
+// TestDoubleDuplicateMethodNames checks that if transfer0 already exists, there won't be a name
+// conflict and that the second transfer method will be renamed transfer1.
+func TestDoubleDuplicateMethodNames(t *testing.T) {
+ abiJSON := `[{"constant":false,"inputs":[{"name":"to","type":"address"},{"name":"value","type":"uint256"}],"name":"transfer","outputs":[{"name":"ok","type":"bool"}],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":false,"inputs":[{"name":"to","type":"address"},{"name":"value","type":"uint256"},{"name":"data","type":"bytes"}],"name":"transfer0","outputs":[{"name":"ok","type":"bool"}],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":false,"inputs":[{"name":"to","type":"address"},{"name":"value","type":"uint256"},{"name":"data","type":"bytes"},{"name":"customFallback","type":"string"}],"name":"transfer","outputs":[{"name":"ok","type":"bool"}],"payable":false,"stateMutability":"nonpayable","type":"function"}]`
+ contractAbi, err := JSON(strings.NewReader(abiJSON))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, ok := contractAbi.Methods["transfer"]; !ok {
+ t.Fatalf("Could not find original method")
+ }
+ if _, ok := contractAbi.Methods["transfer0"]; !ok {
+ t.Fatalf("Could not find duplicate method")
+ }
+ if _, ok := contractAbi.Methods["transfer1"]; !ok {
+ t.Fatalf("Could not find duplicate method")
+ }
+ if _, ok := contractAbi.Methods["transfer2"]; ok {
+ t.Fatalf("Should not have found extra method")
+ }
+}
+
+// TestDoubleDuplicateEventNames checks that if send0 already exists, there won't be a name
+// conflict and that the second send event will be renamed send1.
+// The test runs the abi of the following contract.
+//
+// contract DuplicateEvent {
+// event send(uint256 a);
+// event send0();
+// event send();
+// }
+func TestDoubleDuplicateEventNames(t *testing.T) {
+ abiJSON := `[{"anonymous": false,"inputs": [{"indexed": false,"internalType": "uint256","name": "a","type": "uint256"}],"name": "send","type": "event"},{"anonymous": false,"inputs": [],"name": "send0","type": "event"},{ "anonymous": false, "inputs": [],"name": "send","type": "event"}]`
+ contractAbi, err := JSON(strings.NewReader(abiJSON))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, ok := contractAbi.Events["send"]; !ok {
+ t.Fatalf("Could not find original event")
+ }
+ if _, ok := contractAbi.Events["send0"]; !ok {
+ t.Fatalf("Could not find duplicate event")
+ }
+ if _, ok := contractAbi.Events["send1"]; !ok {
+ t.Fatalf("Could not find duplicate event")
+ }
+ if _, ok := contractAbi.Events["send2"]; ok {
+ t.Fatalf("Should not have found extra event")
+ }
+}
+
+// TestUnnamedEventParam checks that an event with unnamed parameters is
+// correctly handled.
+// The test runs the abi of the following contract.
+//
+// contract TestEvent {
+// event send(uint256, uint256);
+// }
+func TestUnnamedEventParam(t *testing.T) {
+ abiJSON := `[{ "anonymous": false, "inputs": [{ "indexed": false,"internalType": "uint256", "name": "","type": "uint256"},{"indexed": false,"internalType": "uint256","name": "","type": "uint256"}],"name": "send","type": "event"}]`
+ contractAbi, err := JSON(strings.NewReader(abiJSON))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ event, ok := contractAbi.Events["send"]
+ if !ok {
+ t.Fatalf("Could not find event")
+ }
+ if event.Inputs[0].Name != "arg0" {
+ t.Fatalf("Could not find input")
+ }
+ if event.Inputs[1].Name != "arg1" {
+ t.Fatalf("Could not find input")
+ }
+}
+
+func TestUnpackRevert(t *testing.T) {
+ t.Parallel()
+
+ var cases = []struct {
+ input string
+ expect string
+ expectErr error
+ }{
+ {"", "", errors.New("invalid data for unpacking")},
+ {"08c379a1", "", errors.New("invalid data for unpacking")},
+ {"08c379a00000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000d72657665727420726561736f6e00000000000000000000000000000000000000", "revert reason", nil},
+ {"4e487b710000000000000000000000000000000000000000000000000000000000000000", "generic panic", nil},
+ {"4e487b7100000000000000000000000000000000000000000000000000000000000000ff", "unknown panic code: 0xff", nil},
+ }
+ for index, c := range cases {
+ t.Run(fmt.Sprintf("case %d", index), func(t *testing.T) {
+ got, err := UnpackRevert(common.Hex2Bytes(c.input))
+ if c.expectErr != nil {
+ if err == nil {
+ t.Fatalf("Expected non-nil error")
+ }
+ if err.Error() != c.expectErr.Error() {
+ t.Fatalf("Expected error mismatch, want %v, got %v", c.expectErr, err)
+ }
+ return
+ }
+ if c.expect != got {
+ t.Fatalf("Output mismatch, want %v, got %v", c.expect, got)
+ }
+ })
+ }
}
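
For context on the fixtures above: Solidity revert data begins with the 4-byte selector of Error(string) (0x08c379a0) or Panic(uint256) (0x4e487b71), followed by the ABI-encoded payload, and UnpackRevert strips the selector before decoding. A minimal standalone sketch, assuming only the import paths already used in this diff:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/accounts/abi"
	"github.com/tomochain/tomochain/common"
)

func main() {
	// keccak256("Error(string)")[:4] selector, then the ABI-encoded string:
	// a 32-byte offset, a 32-byte length (13), and the padded bytes of
	// "revert reason" (the same fixture as the test above).
	data := common.Hex2Bytes("08c379a0" +
		"0000000000000000000000000000000000000000000000000000000000000020" +
		"000000000000000000000000000000000000000000000000000000000000000d" +
		"72657665727420726561736f6e00000000000000000000000000000000000000")

	reason, err := abi.UnpackRevert(data)
	if err != nil {
		panic(err)
	}
	fmt.Println(reason) // revert reason
}
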
diff --git a/accounts/abi/argument.go b/accounts/abi/argument.go
index 512d8fdfa..2e48d539e 100644
--- a/accounts/abi/argument.go
+++ b/accounts/abi/argument.go
@@ -18,6 +18,7 @@ package abi
import (
"encoding/json"
+ "errors"
"fmt"
"reflect"
"strings"
@@ -33,41 +34,33 @@ type Argument struct {
type Arguments []Argument
-// UnmarshalJSON implements json.Unmarshaler interface
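+// ArgumentMarshaling is the external JSON representation of an ABI argument.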
+type ArgumentMarshaling struct {
+ Name string
+ Type string
+ InternalType string
+ Components []ArgumentMarshaling
+ Indexed bool
+}
+
+// UnmarshalJSON implements json.Unmarshaler interface.
func (argument *Argument) UnmarshalJSON(data []byte) error {
- var extarg struct {
- Name string
- Type string
- Indexed bool
- }
- err := json.Unmarshal(data, &extarg)
+ var arg ArgumentMarshaling
+ err := json.Unmarshal(data, &arg)
if err != nil {
return fmt.Errorf("argument json err: %v", err)
}
- argument.Type, err = NewType(extarg.Type)
+ argument.Type, err = NewType(arg.Type, arg.InternalType, arg.Components)
if err != nil {
return err
}
- argument.Name = extarg.Name
- argument.Indexed = extarg.Indexed
+ argument.Name = arg.Name
+ argument.Indexed = arg.Indexed
return nil
}
-// LengthNonIndexed returns the number of arguments when not counting 'indexed' ones. Only events
-// can ever have 'indexed' arguments, it should always be false on arguments for method input/output
-func (arguments Arguments) LengthNonIndexed() int {
- out := 0
- for _, arg := range arguments {
- if !arg.Indexed {
- out++
- }
- }
- return out
-}
-
-// NonIndexed returns the arguments with indexed arguments filtered out
+// NonIndexed returns the arguments with indexed arguments filtered out.
func (arguments Arguments) NonIndexed() Arguments {
var ret []Argument
for _, arg := range arguments {
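
To make the new three-argument NewType call above concrete, here is a small hedged sketch that builds by hand the tuple type UnmarshalJSON would produce for a two-field struct argument (the Point/x/y names are illustrative):

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/accounts/abi"
)

func main() {
	// Build the same Type that UnmarshalJSON would produce for a tuple
	// argument with two uint256 components.
	point, err := abi.NewType("tuple", "struct Point", []abi.ArgumentMarshaling{
		{Name: "x", Type: "uint256"},
		{Name: "y", Type: "uint256"},
	})
	if err != nil {
		panic(err)
	}
	// The resulting Type can be used as a regular argument for Pack/Unpack.
	args := abi.Arguments{{Name: "p", Type: point}}
	fmt.Println(len(args)) // 1
}
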
@@ -78,120 +71,126 @@ func (arguments Arguments) NonIndexed() Arguments {
return ret
}
-// isTuple returns true for non-atomic constructs, like (uint,uint) or uint[]
+// isTuple returns true for non-atomic constructs, like (uint,uint) or uint[].
func (arguments Arguments) isTuple() bool {
return len(arguments) > 1
}
-// Unpack performs the operation hexdata -> Go format
-func (arguments Arguments) Unpack(v interface{}, data []byte) error {
+// Unpack performs the operation hexdata -> Go format.
+func (arguments Arguments) Unpack(data []byte) ([]interface{}, error) {
+ if len(data) == 0 {
+ if len(arguments.NonIndexed()) != 0 {
+ return nil, errors.New("abi: attempting to unmarshall an empty string while arguments are expected")
+ }
+ return make([]interface{}, 0), nil
+ }
+ return arguments.UnpackValues(data)
+}
- // make sure the passed value is arguments pointer
- if reflect.Ptr != reflect.ValueOf(v).Kind() {
- return fmt.Errorf("abi: Unpack(non-pointer %T)", v)
+// UnpackIntoMap performs the operation hexdata -> mapping of argument name to argument value.
+func (arguments Arguments) UnpackIntoMap(v map[string]interface{}, data []byte) error {
+ // Make sure map is not nil
+ if v == nil {
+ return errors.New("abi: cannot unpack into a nil map")
+ }
+ if len(data) == 0 {
+ if len(arguments.NonIndexed()) != 0 {
+ return errors.New("abi: attempting to unmarshall an empty string while arguments are expected")
+ }
+ return nil // Nothing to unmarshal, return
}
marshalledValues, err := arguments.UnpackValues(data)
if err != nil {
return err
}
- if arguments.isTuple() {
- return arguments.unpackTuple(v, marshalledValues)
+ for i, arg := range arguments.NonIndexed() {
+ v[arg.Name] = marshalledValues[i]
}
- return arguments.unpackAtomic(v, marshalledValues)
+ return nil
}
-func (arguments Arguments) unpackTuple(v interface{}, marshalledValues []interface{}) error {
-
- var (
- value = reflect.ValueOf(v).Elem()
- typ = value.Type()
- kind = value.Kind()
- )
-
- if err := requireUnpackKind(value, typ, kind, arguments); err != nil {
- return err
+// Copy performs the operation go format -> provided struct.
+func (arguments Arguments) Copy(v interface{}, values []interface{}) error {
+ // make sure the passed value is arguments pointer
+ if reflect.Ptr != reflect.ValueOf(v).Kind() {
+ return fmt.Errorf("abi: Unpack(non-pointer %T)", v)
}
- // If the output interface is a struct, make sure names don't collide
- if kind == reflect.Struct {
- if err := requireUniqueStructFieldNames(arguments); err != nil {
- return err
+ if len(values) == 0 {
+ if len(arguments.NonIndexed()) != 0 {
+ return errors.New("abi: attempting to copy no values while arguments are expected")
}
+ return nil // Nothing to copy, return
}
- for i, arg := range arguments.NonIndexed() {
+ if arguments.isTuple() {
+ return arguments.copyTuple(v, values)
+ }
+ return arguments.copyAtomic(v, values[0])
+}
- reflectValue := reflect.ValueOf(marshalledValues[i])
+// copyAtomic copies ( go format -> provided value ) a single value
+func (arguments Arguments) copyAtomic(v interface{}, marshalledValues interface{}) error {
+ dst := reflect.ValueOf(v).Elem()
+ src := reflect.ValueOf(marshalledValues)
- switch kind {
- case reflect.Struct:
- err := unpackStruct(value, reflectValue, arg)
- if err != nil {
- return err
- }
- case reflect.Slice, reflect.Array:
- if value.Len() < i {
- return fmt.Errorf("abi: insufficient number of arguments for unpack, want %d, got %d", len(arguments), value.Len())
+ if dst.Kind() == reflect.Struct {
+ return set(dst.Field(0), src)
+ }
+ return set(dst, src)
+}
+
+// copyTuple copies a batch of values from marshalledValues to v.
+func (arguments Arguments) copyTuple(v interface{}, marshalledValues []interface{}) error {
+ value := reflect.ValueOf(v).Elem()
+ nonIndexedArgs := arguments.NonIndexed()
+
+ switch value.Kind() {
+ case reflect.Struct:
+ argNames := make([]string, len(nonIndexedArgs))
+ for i, arg := range nonIndexedArgs {
+ argNames[i] = arg.Name
+ }
+ var err error
+ abi2struct, err := mapArgNamesToStructFields(argNames, value)
+ if err != nil {
+ return err
+ }
+ for i, arg := range nonIndexedArgs {
+ field := value.FieldByName(abi2struct[arg.Name])
+ if !field.IsValid() {
+ return fmt.Errorf("abi: field %s can't be found in the given value", arg.Name)
}
- v := value.Index(i)
- if err := requireAssignable(v, reflectValue); err != nil {
+ if err := set(field, reflect.ValueOf(marshalledValues[i])); err != nil {
return err
}
-
- if err := set(v.Elem(), reflectValue, arg); err != nil {
+ }
+ case reflect.Slice, reflect.Array:
+ if value.Len() < len(marshalledValues) {
+ return fmt.Errorf("abi: insufficient number of arguments for unpack, want %d, got %d", len(arguments), value.Len())
+ }
+ for i := range nonIndexedArgs {
+ if err := set(value.Index(i), reflect.ValueOf(marshalledValues[i])); err != nil {
return err
}
- default:
- return fmt.Errorf("abi:[2] cannot unmarshal tuple in to %v", typ)
}
+ default:
+ return fmt.Errorf("abi:[2] cannot unmarshal tuple in to %v", value.Type())
}
return nil
}
-// unpackAtomic unpacks ( hexdata -> go ) a single value
-func (arguments Arguments) unpackAtomic(v interface{}, marshalledValues []interface{}) error {
- if len(marshalledValues) != 1 {
- return fmt.Errorf("abi: wrong length, expected single value, got %d", len(marshalledValues))
- }
- elem := reflect.ValueOf(v).Elem()
- kind := elem.Kind()
- reflectValue := reflect.ValueOf(marshalledValues[0])
-
- if kind == reflect.Struct {
- //make sure names don't collide
- if err := requireUniqueStructFieldNames(arguments); err != nil {
- return err
- }
-
- return unpackStruct(elem, reflectValue, arguments[0])
- }
-
- return set(elem, reflectValue, arguments.NonIndexed()[0])
-
-}
-
-// Computes the full size of an array;
-// i.e. counting nested arrays, which count towards size for unpacking.
-func getArraySize(arr *Type) int {
- size := arr.Size
- // Arrays can be nested, with each element being the same size
- arr = arr.Elem
- for arr.T == ArrayTy {
- // Keep multiplying by elem.Size while the elem is an array.
- size *= arr.Size
- arr = arr.Elem
- }
- // Now we have the full array size, including its children.
- return size
-}
-
// UnpackValues can be used to unpack ABI-encoded hexdata according to the ABI-specification,
// without supplying a struct to unpack into. Instead, this method returns a list containing the
// values. An atomic argument will be a list with one element.
func (arguments Arguments) UnpackValues(data []byte) ([]interface{}, error) {
- retval := make([]interface{}, 0, arguments.LengthNonIndexed())
+ nonIndexedArgs := arguments.NonIndexed()
+ retval := make([]interface{}, 0, len(nonIndexedArgs))
virtualArgs := 0
- for index, arg := range arguments.NonIndexed() {
+ for index, arg := range nonIndexedArgs {
marshalledValue, err := toGoType((index+virtualArgs)*32, arg.Type, data)
- if arg.Type.T == ArrayTy {
+ if err != nil {
+ return nil, err
+ }
+ if arg.Type.T == ArrayTy && !isDynamicType(arg.Type) {
// If we have a static array, like [3]uint256, these are coded as
// just like uint256,uint256,uint256.
// This means that we need to add two 'virtual' arguments when
@@ -202,28 +201,29 @@ func (arguments Arguments) UnpackValues(data []byte) ([]interface{}, error) {
//
// Calculate the full array size to get the correct offset for the next argument.
// Decrement it by 1, as the normal index increment is still applied.
- virtualArgs += getArraySize(&arg.Type) - 1
- }
- if err != nil {
- return nil, err
+ virtualArgs += getTypeSize(arg.Type)/32 - 1
+ } else if arg.Type.T == TupleTy && !isDynamicType(arg.Type) {
+ // If we have a static tuple, like (uint256, bool, uint256), these are
+ // encoded just like uint256,bool,uint256.
+ virtualArgs += getTypeSize(arg.Type)/32 - 1
}
retval = append(retval, marshalledValue)
}
return retval, nil
}
-// PackValues performs the operation Go format -> Hexdata
-// It is the semantic opposite of UnpackValues
+// PackValues performs the operation Go format -> Hexdata.
+// It is the semantic opposite of UnpackValues.
func (arguments Arguments) PackValues(args []interface{}) ([]byte, error) {
return arguments.Pack(args...)
}
-// Pack performs the operation Go format -> Hexdata
+// Pack performs the operation Go format -> Hexdata.
func (arguments Arguments) Pack(args ...interface{}) ([]byte, error) {
// Make sure arguments match up and pack them
abiArgs := arguments
if len(args) != len(abiArgs) {
- return nil, fmt.Errorf("argument count mismatch: %d for %d", len(args), len(abiArgs))
+ return nil, fmt.Errorf("argument count mismatch: got %d for %d", len(args), len(abiArgs))
}
// variable input is the output appended at the end of packed
// output. This is used for strings and bytes types input.
@@ -232,11 +232,7 @@ func (arguments Arguments) Pack(args ...interface{}) ([]byte, error) {
// input offset is the bytes offset for packed output
inputOffset := 0
for _, abiArg := range abiArgs {
- if abiArg.Type.T == ArrayTy {
- inputOffset += 32 * abiArg.Type.Size
- } else {
- inputOffset += 32
- }
+ inputOffset += getTypeSize(abiArg.Type)
}
var ret []byte
for i, a := range args {
@@ -246,14 +242,13 @@ func (arguments Arguments) Pack(args ...interface{}) ([]byte, error) {
if err != nil {
return nil, err
}
- // check for a slice type (string, bytes, slice)
- if input.Type.requiresLengthPrefix() {
- // calculate the offset
- offset := inputOffset + len(variableInput)
+ // check for dynamic types
+ if isDynamicType(input.Type) {
// set the offset
- ret = append(ret, packNum(reflect.ValueOf(offset))...)
- // Append the packed output to the variable input. The variable input
- // will be appended at the end of the input.
+ ret = append(ret, packNum(reflect.ValueOf(inputOffset))...)
+ // calculate next offset
+ inputOffset += len(packed)
+ // append to variable input
variableInput = append(variableInput, packed...)
} else {
// append the packed value to the input
@@ -266,29 +261,13 @@ func (arguments Arguments) Pack(args ...interface{}) ([]byte, error) {
return ret, nil
}
-// capitalise makes the first character of a string upper case, also removing any
-// prefixing underscores from the variable names.
-func capitalise(input string) string {
- for len(input) > 0 && input[0] == '_' {
- input = input[1:]
- }
- if len(input) == 0 {
- return ""
- }
- return strings.ToUpper(input[:1]) + input[1:]
-}
-
-//unpackStruct extracts each argument into its corresponding struct field
-func unpackStruct(value, reflectValue reflect.Value, arg Argument) error {
- name := capitalise(arg.Name)
- typ := value.Type()
- for j := 0; j < typ.NumField(); j++ {
- // TODO read tags: `abi:"fieldName"`
- if typ.Field(j).Name == name {
- if err := set(value.Field(j), reflectValue, arg); err != nil {
- return err
- }
+// ToCamelCase converts an underscore-separated string to a camel-case string.
+func ToCamelCase(input string) string {
+ parts := strings.Split(input, "_")
+ for i, s := range parts {
+ if len(s) > 0 {
+ parts[i] = strings.ToUpper(s[:1]) + s[1:]
}
}
- return nil
+ return strings.Join(parts, "")
}
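
A short usage sketch of the reworked API in this file: Unpack now returns the decoded values as a []interface{} instead of writing through a pointer, and ToCamelCase is the exported name-mangling helper used when mapping ABI names onto Go struct fields. The ABI JSON and the value 0x2a are illustrative:

package main

import (
	"fmt"
	"math/big"
	"strings"

	"github.com/tomochain/tomochain/accounts/abi"
	"github.com/tomochain/tomochain/common"
)

func main() {
	parsed, err := abi.JSON(strings.NewReader(
		`[{"constant":true,"inputs":[],"name":"count","outputs":[{"name":"","type":"uint256"}],"payable":false,"stateMutability":"view","type":"function"}]`))
	if err != nil {
		panic(err)
	}
	// One uint256 return value, encoded as a single 32-byte word (0x2a).
	data := common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000002a")

	values, err := parsed.Methods["count"].Outputs.Unpack(data)
	if err != nil {
		panic(err)
	}
	fmt.Println(values[0].(*big.Int)) // 42

	// ToCamelCase drops underscores and upper-cases each remaining part.
	fmt.Println(abi.ToCamelCase("token_owner")) // TokenOwner
}
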
diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go
index 7411f492a..21ea9ec02 100644
--- a/accounts/abi/bind/backends/simulated.go
+++ b/accounts/abi/bind/backends/simulated.go
@@ -20,19 +20,21 @@ import (
"context"
"errors"
"fmt"
- "github.com/tomochain/tomochain/consensus"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"sync"
"time"
"github.com/tomochain/tomochain"
+ "github.com/tomochain/tomochain/accounts/abi"
"github.com/tomochain/tomochain/accounts/abi/bind"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/common/hexutil"
"github.com/tomochain/tomochain/common/math"
+ "github.com/tomochain/tomochain/consensus"
"github.com/tomochain/tomochain/consensus/ethash"
"github.com/tomochain/tomochain/core"
"github.com/tomochain/tomochain/core/bloombits"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
@@ -107,7 +109,7 @@ func (b *SimulatedBackend) rollback() {
statedb, _ := b.blockchain.State()
b.pendingBlock = blocks[0]
- b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database())
+ b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database(), nil)
}
// CodeAt returns the code associated with a certain account in the blockchain.
@@ -174,7 +176,7 @@ func (b *SimulatedBackend) ForEachStorageAt(ctx context.Context, contract common
// TransactionReceipt returns the receipt of a transaction.
func (b *SimulatedBackend) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) {
- receipt, _, _, _ := core.GetReceipt(b.database, txHash)
+ receipt, _, _, _ := rawdb.GetReceipt(b.database, txHash, b.config)
return receipt, nil
}
@@ -186,6 +188,36 @@ func (b *SimulatedBackend) PendingCodeAt(ctx context.Context, contract common.Ad
return b.pendingState.GetCode(contract), nil
}
+func newRevertError(result *core.ExecutionResult) *revertError {
+ reason, errUnpack := abi.UnpackRevert(result.Revert())
+ err := errors.New("execution reverted")
+ if errUnpack == nil {
+ err = fmt.Errorf("execution reverted: %v", reason)
+ }
+ return &revertError{
+ error: err,
+ reason: hexutil.Encode(result.Revert()),
+ }
+}
+
+// revertError is an API error that encompasses an EVM revert with JSON error
+// code and a binary data blob.
+type revertError struct {
+ error
+ reason string // revert reason hex encoded
+}
+
+// ErrorCode returns the JSON error code for a revert.
+// See: https://github.com/ethereum/wiki/wiki/JSON-RPC-Error-Codes-Improvement-Proposal
+func (e *revertError) ErrorCode() int {
+ return 3
+}
+
+// ErrorData returns the hex encoded revert reason.
+func (e *revertError) ErrorData() interface{} {
+ return e.reason
+}
+
// CallContract executes a contract call.
func (b *SimulatedBackend) CallContract(ctx context.Context, call tomochain.CallMsg, blockNumber *big.Int) ([]byte, error) {
b.mu.Lock()
@@ -198,11 +230,19 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call tomochain.Call
if err != nil {
return nil, err
}
- rval, _, _, err := b.callContract(ctx, call, b.blockchain.CurrentBlock(), state)
- return rval, err
+ res, err := b.callContract(ctx, call, b.blockchain.CurrentBlock(), state)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(res.Revert()) > 0 {
+ return nil, newRevertError(res)
+ }
+
+ return res.Return(), res.Err
}
-//FIXME: please use copyState for this function
+// FIXME: please use copyState for this function
// CallContractWithState executes a contract call at the given state.
func (b *SimulatedBackend) CallContractWithState(call tomochain.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) {
// Ensure message is initialized properly.
@@ -215,11 +255,19 @@ func (b *SimulatedBackend) CallContractWithState(call tomochain.CallMsg, chain c
call.Value = new(big.Int)
}
// Execute the call.
- msg := callmsg{call}
+ msg := &core.Message{
+ To: call.To,
+ From: call.From,
+ Value: call.Value,
+ GasLimit: call.Gas,
+ GasPrice: call.GasPrice,
+ Data: call.Data,
+ SkipAccountChecks: false,
+ }
feeCapacity := state.GetTRC21FeeCapacityFromState(statedb)
- if msg.To() != nil {
- if value, ok := feeCapacity[*msg.To()]; ok {
- msg.CallMsg.BalanceTokenFee = value
+ if msg.To != nil {
+ if value, ok := feeCapacity[*msg.To]; ok {
+ msg.BalanceTokenFee = value
}
}
evmContext := core.NewEVMContext(msg, chain.CurrentHeader(), chain, nil)
@@ -228,11 +276,11 @@ func (b *SimulatedBackend) CallContractWithState(call tomochain.CallMsg, chain c
vmenv := vm.NewEVM(evmContext, statedb, nil, chain.Config(), vm.Config{})
gaspool := new(core.GasPool).AddGas(1000000)
owner := common.Address{}
- rval, _, _, err := core.NewStateTransition(vmenv, msg, gaspool).TransitionDb(owner)
+ result, err := core.NewStateTransition(vmenv, msg, gaspool).TransitionDb(owner)
if err != nil {
return nil, err
}
- return rval, err
+ return result.Return(), nil
}
// PendingCallContract executes a contract call on the pending state.
@@ -241,8 +289,15 @@ func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call tomocha
defer b.mu.Unlock()
defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot())
- rval, _, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState)
- return rval, err
+ res, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState)
+ if err != nil {
+ return nil, err
+ }
+ if len(res.Revert()) > 0 {
+ return nil, newRevertError(res)
+ }
+
+ return res.Return(), res.Err
}
// PendingNonceAt implements PendingStateReader.PendingNonceAt, retrieving
@@ -280,23 +335,32 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call tomochain.CallM
cap = hi
// Create a helper to check if a gas allowance results in an executable transaction
- executable := func(gas uint64) bool {
+ executable := func(gas uint64) (bool, *core.ExecutionResult, error) {
call.Gas = gas
snapshot := b.pendingState.Snapshot()
- _, _, failed, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState)
- fmt.Println("EstimateGas",err,failed)
+ res, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState)
b.pendingState.RevertToSnapshot(snapshot)
- if err != nil || failed {
- return false
+ if err != nil {
+ if err == core.ErrIntrinsicGas {
+ return true, nil, nil // Special case, raise gas limit
+ }
+ return true, nil, err
}
- return true
+ return res.Failed(), res, nil
}
// Execute the binary search and hone in on an executable gas limit
for lo+1 < hi {
mid := (hi + lo) / 2
- if !executable(mid) {
+ failed, _, err := executable(mid)
+ // If the error is not nil (consensus error), it means the provided message
+ // call or transaction will never be accepted no matter how much gas it is
+ // assigned. Return the error directly, don't struggle any more
+ if err != nil {
+ return 0, err
+ }
+ if failed {
lo = mid
} else {
hi = mid
@@ -304,8 +368,21 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call tomochain.CallM
}
// Reject the transaction as invalid if it still fails at the highest allowance
if hi == cap {
- if !executable(hi) {
- return 0, errGasEstimationFailed
+ failed, result, err := executable(hi)
+ if err != nil {
+ return 0, err
+ }
+ if failed {
+ if result != nil && result.Err != vm.ErrOutOfGas {
+ if len(result.Revert()) > 0 {
+ return 0, newRevertError(result)
+ }
+ return 0, result.Err
+ }
+
+ // Otherwise, the specified gas cap is too low
+ return 0, fmt.Errorf("gas required exceeds allowance (%d)", cap)
}
}
return hi, nil
@@ -313,7 +390,7 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call tomochain.CallM
// callContract implements common code between normal and pending contract calls.
// state is modified during execution, make sure to copy it if necessary.
-func (b *SimulatedBackend) callContract(ctx context.Context, call tomochain.CallMsg, block *types.Block, statedb *state.StateDB) ([]byte, uint64, bool, error) {
+func (b *SimulatedBackend) callContract(ctx context.Context, call tomochain.CallMsg, block *types.Block, statedb *state.StateDB) (*core.ExecutionResult, error) {
// Ensure message is initialized properly.
if call.GasPrice == nil {
call.GasPrice = big.NewInt(1)
@@ -328,11 +405,19 @@ func (b *SimulatedBackend) callContract(ctx context.Context, call tomochain.Call
from := statedb.GetOrNewStateObject(call.From)
from.SetBalance(math.MaxBig256)
// Execute the call.
- msg := callmsg{call}
+ msg := &core.Message{
+ To: call.To,
+ From: call.From,
+ Value: call.Value,
+ GasLimit: call.Gas,
+ GasPrice: call.GasPrice,
+ Data: call.Data,
+ SkipAccountChecks: true,
+ }
feeCapacity := state.GetTRC21FeeCapacityFromState(statedb)
- if msg.To() != nil {
- if value, ok := feeCapacity[*msg.To()]; ok {
- msg.CallMsg.BalanceTokenFee = value
+ if msg.To != nil {
+ if value, ok := feeCapacity[*msg.To]; ok {
+ msg.BalanceTokenFee = value
}
}
evmContext := core.NewEVMContext(msg, block.Header(), b.blockchain, nil)
@@ -368,7 +453,7 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa
statedb, _ := b.blockchain.State()
b.pendingBlock = blocks[0]
- b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database())
+ b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database(), nil)
return nil
}
@@ -447,7 +532,7 @@ func (b *SimulatedBackend) AdjustTime(adjustment time.Duration) error {
statedb, _ := b.blockchain.State()
b.pendingBlock = blocks[0]
- b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database())
+ b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database(), nil)
return nil
}
@@ -485,11 +570,11 @@ func (fb *filterBackend) HeaderByNumber(ctx context.Context, block rpc.BlockNumb
}
func (fb *filterBackend) GetReceipts(ctx context.Context, hash common.Hash) (types.Receipts, error) {
- return core.GetBlockReceipts(fb.db, hash, core.GetBlockNumber(fb.db, hash)), nil
+ return rawdb.GetBlockReceipts(fb.db, hash, rawdb.GetBlockNumber(fb.db, hash), fb.bc.Config()), nil
}
func (fb *filterBackend) GetLogs(ctx context.Context, hash common.Hash) ([][]*types.Log, error) {
- receipts := core.GetBlockReceipts(fb.db, hash, core.GetBlockNumber(fb.db, hash))
+ receipts := rawdb.GetBlockReceipts(fb.db, hash, rawdb.GetBlockNumber(fb.db, hash), fb.bc.Config())
if receipts == nil {
return nil, nil
}
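
The reworked EstimateGas above is a plain binary search over the gas allowance, where executable now reports whether the call still fails at a given gas. A minimal standalone sketch of just the search loop, with the call execution stubbed out (the real helper replays the message against pending state and short-circuits on consensus errors):

package main

import "fmt"

// estimateGas narrows [lo, hi] to the smallest allowance at which the
// stubbed call stops failing, mirroring the loop in EstimateGas.
func estimateGas(lo, hi uint64, failed func(gas uint64) bool) uint64 {
	for lo+1 < hi {
		mid := (hi + lo) / 2
		if failed(mid) {
			lo = mid // still failing: the answer is above mid
		} else {
			hi = mid // succeeded: try a smaller allowance
		}
	}
	return hi
}

func main() {
	// Toy stand-in: pretend the call needs exactly 21764 gas.
	needed := uint64(21764)
	got := estimateGas(21000, 8000000, func(gas uint64) bool { return gas < needed })
	fmt.Println(got) // 21764
}
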
diff --git a/accounts/abi/bind/base.go b/accounts/abi/bind/base.go
index caf164049..55c13e15c 100644
--- a/accounts/abi/bind/base.go
+++ b/accounts/abi/bind/base.go
@@ -21,6 +21,8 @@ import (
"errors"
"fmt"
"math/big"
+ "strings"
+ "sync"
"github.com/tomochain/tomochain"
"github.com/tomochain/tomochain/accounts/abi"
@@ -30,6 +32,11 @@ import (
"github.com/tomochain/tomochain/event"
)
+var (
+ errNoEventSignature = errors.New("no event signature")
+ errEventSignatureMismatch = errors.New("event signature mismatch")
+)
+
// SignerFn is a signer function callback when a contract requires a method to
// sign the transaction before submission.
type SignerFn func(types.Signer, common.Address, *types.Transaction) (*types.Transaction, error)
@@ -72,6 +79,29 @@ type WatchOpts struct {
Context context.Context // Network context to support cancellation and timeouts (nil = no timeout)
}
+// MetaData collects all metadata for a bound contract.
+type MetaData struct {
+ mu sync.Mutex
+ Sigs map[string]string
+ Bin string
+ ABI string
+ ab *abi.ABI
+}
+
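+// GetAbi returns the parsed ABI contained in the metadata, parsing and caching it on first use.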
+func (m *MetaData) GetAbi() (*abi.ABI, error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.ab != nil {
+ return m.ab, nil
+ }
+ if parsed, err := abi.JSON(strings.NewReader(m.ABI)); err != nil {
+ return nil, err
+ } else {
+ m.ab = &parsed
+ }
+ return m.ab, nil
+}
+
// BoundContract is the base wrapper object that reflects a contract on the
// Ethereum network. It contains a collection of methods that are used by the
// higher level contract bindings to operate.
@@ -149,7 +179,10 @@ func (c *BoundContract) Call(opts *CallOpts, result interface{}, method string,
}
} else {
output, err = c.caller.CallContract(ctx, msg, nil)
- if err == nil && len(output) == 0 {
+ if err != nil {
+ return err
+ }
+ if len(output) == 0 {
// Make sure we have a contract to operate on, and bail out otherwise.
if code, err = c.caller.CodeAt(ctx, c.address, nil); err != nil {
return err
@@ -161,7 +194,7 @@ func (c *BoundContract) Call(opts *CallOpts, result interface{}, method string,
if err != nil {
return err
}
- return c.abi.Unpack(result, method, output)
+ return c.abi.UnpackIntoInterface(result, method, output)
}
// Transact invokes the (paid) contract method with params as input values.
@@ -252,7 +285,7 @@ func (c *BoundContract) FilterLogs(opts *FilterOpts, name string, query ...[]int
opts = new(FilterOpts)
}
// Append the event selector to the query parameters and construct the topic set
- query = append([][]interface{}{{c.abi.Events[name].Id()}}, query...)
+ query = append([][]interface{}{{c.abi.Events[name].ID}}, query...)
topics, err := makeTopics(query...)
if err != nil {
@@ -301,7 +334,7 @@ func (c *BoundContract) WatchLogs(opts *WatchOpts, name string, query ...[]inter
opts = new(WatchOpts)
}
// Append the event selector to the query parameters and construct the topic set
- query = append([][]interface{}{{c.abi.Events[name].Id()}}, query...)
+ query = append([][]interface{}{{c.abi.Events[name].ID}}, query...)
topics, err := makeTopics(query...)
if err != nil {
@@ -326,8 +359,15 @@ func (c *BoundContract) WatchLogs(opts *WatchOpts, name string, query ...[]inter
// UnpackLog unpacks a retrieved log into the provided output structure.
func (c *BoundContract) UnpackLog(out interface{}, event string, log types.Log) error {
+ // Anonymous events are not supported.
+ if len(log.Topics) == 0 {
+ return errNoEventSignature
+ }
+ if log.Topics[0] != c.abi.Events[event].ID {
+ return errEventSignatureMismatch
+ }
if len(log.Data) > 0 {
- if err := c.abi.Unpack(out, event, log.Data); err != nil {
+ if err := c.abi.UnpackIntoInterface(out, event, log.Data); err != nil {
return err
}
}
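
A hedged sketch tying the new pieces of this file together: MetaData.GetAbi parses and caches the ABI, and UnpackLog now rejects logs with no topics or a mismatched first topic before touching the data. It assumes NewBoundContract tolerates nil caller/transactor/filterer when only log unpacking is exercised; the Ping event is illustrative:

package main

import (
	"fmt"
	"math/big"

	"github.com/tomochain/tomochain/accounts/abi/bind"
	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/types"
	"github.com/tomochain/tomochain/crypto"
)

func main() {
	meta := &bind.MetaData{ABI: `[{"anonymous":false,"inputs":[{"indexed":false,"name":"value","type":"uint256"}],"name":"Ping","type":"event"}]`}
	parsed, err := meta.GetAbi() // parsed once, cached afterwards
	if err != nil {
		panic(err)
	}
	contract := bind.NewBoundContract(common.Address{}, *parsed, nil, nil, nil)

	var out struct{ Value *big.Int }

	// No topics at all: treated as an anonymous event and rejected.
	fmt.Println(contract.UnpackLog(&out, "Ping", types.Log{})) // no event signature

	// First topic is not keccak256("Ping(uint256)"): rejected as a mismatch.
	wrong := types.Log{Topics: []common.Hash{crypto.Keccak256Hash([]byte("Pong(uint256)"))}}
	fmt.Println(contract.UnpackLog(&out, "Ping", wrong)) // event signature mismatch
}
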
diff --git a/accounts/abi/bind/bind.go b/accounts/abi/bind/bind.go
index efb24e4d8..9b73e7ef2 100644
--- a/accounts/abi/bind/bind.go
+++ b/accounts/abi/bind/bind.go
@@ -89,7 +89,7 @@ func Bind(types []string, abis []string, bytecodes []string, pkg string, lang La
}
}
// Append the methods to the call or transact lists
- if original.Const {
+ if original.IsConstant() {
calls[original.Name] = &tmplMethod{Original: original, Normalized: normalized, Structured: structured(original.Outputs)}
} else {
transacts[original.Name] = &tmplMethod{Original: original, Normalized: normalized, Structured: structured(original.Outputs)}
@@ -166,9 +166,10 @@ var bindType = map[Lang]func(kind abi.Type) string{
// Helper function for the binding generators.
// It reads the unmatched characters after the inner type-match,
-// (since the inner type is a prefix of the total type declaration),
-// looks for valid arrays (possibly a dynamic one) wrapping the inner type,
-// and returns the sizes of these arrays.
+//
+// (since the inner type is a prefix of the total type declaration),
+// looks for valid arrays (possibly a dynamic one) wrapping the inner type,
+// and returns the sizes of these arrays.
//
// Returned array sizes are in the same order as solidity signatures; inner array size first.
// Array sizes may also be "", indicating a dynamic array.
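
For completeness, a hedged sketch of driving the generator, assuming the Bind signature shown in the hunk header above; the contract name and ABI are illustrative:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/accounts/abi/bind"
)

func main() {
	const tokenABI = `[{"constant":true,"inputs":[],"name":"name","outputs":[{"name":"","type":"string"}],"payable":false,"stateMutability":"view","type":"function"}]`

	// One contract, no deploy bytecode, generated into package "token".
	code, err := bind.Bind([]string{"Token"}, []string{tokenABI}, []string{""}, "token", bind.LangGo)
	if err != nil {
		panic(err)
	}
	fmt.Println(len(code) > 0) // true: Go source for the Token binding
}
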
diff --git a/accounts/abi/bind/template.go b/accounts/abi/bind/template.go
index f49b0efd1..43985bfe6 100644
--- a/accounts/abi/bind/template.go
+++ b/accounts/abi/bind/template.go
@@ -22,17 +22,24 @@ import "github.com/tomochain/tomochain/accounts/abi"
type tmplData struct {
Package string // Name of the package to place the generated file in
Contracts map[string]*tmplContract // List of contracts to generate into this file
+ Libraries map[string]string // Map the bytecode's link pattern to the library name
+ Structs map[string]*tmplStruct // Contract struct type definitions
}
// tmplContract contains the data needed to generate an individual contract binding.
type tmplContract struct {
Type string // Type name of the main contract binding
InputABI string // JSON ABI used as the input to generate the binding from
- InputBin string // Optional EVM bytecode used to denetare deploy code from
+ InputBin string // Optional EVM bytecode used to generate deploy code from
+ FuncSigs map[string]string // Optional map: string signature -> 4-byte signature
Constructor abi.Method // Contract constructor for deploy parametrization
Calls map[string]*tmplMethod // Contract calls that only read state data
Transacts map[string]*tmplMethod // Contract calls that write state data
+ Fallback *tmplMethod // Additional special fallback function
+ Receive *tmplMethod // Additional special receive function
Events map[string]*tmplEvent // Contract events accessors
+ Libraries map[string]string // Same as tmplData, but filtered to only keep what the contract needs
+ Library bool // Indicator whether the contract is a library
}
// tmplMethod is a wrapper around an abi.Method that contains a few preprocessed
@@ -43,42 +50,120 @@ type tmplMethod struct {
Structured bool // Whether the returns should be accumulated into a struct
}
-// tmplEvent is a wrapper around an a
+// tmplEvent is a wrapper around an abi.Event that contains a few preprocessed
+// and cached data fields.
type tmplEvent struct {
Original abi.Event // Original event as parsed by the abi package
Normalized abi.Event // Normalized version of the parsed fields
}
+// tmplField is a wrapper around a struct field with binding language
+// struct type definition and relative field name.
+type tmplField struct {
+ Type string // Field type representation depends on target binding language
+ Name string // Field name converted from the raw user-defined field name
+ SolKind abi.Type // Raw abi type information
+}
+
+// tmplStruct is a wrapper around an abi.tuple and contains an auto-generated
+// struct name.
+type tmplStruct struct {
+ Name string // Auto-generated struct name (before solidity v0.5.11) or raw name.
+ Fields []*tmplField // Struct fields definition depends on the binding language.
+}
+
// tmplSource is language to template mapping containing all the supported
// programming languages the package can generate to.
var tmplSource = map[Lang]string{
- LangGo: tmplSourceGo,
- LangJava: tmplSourceJava,
+ LangGo: tmplSourceGo,
}
-// tmplSourceGo is the Go source template use to generate the contract binding
-// based on.
+// tmplSourceGo is the Go source template that the generated Go contract binding
+// is based on.
const tmplSourceGo = `
// Code generated - DO NOT EDIT.
// This file is a generated binding and any manual changes will be lost.
package {{.Package}}
+import (
+ "math/big"
+ "strings"
+ "errors"
+
+ "github.com/tomochain/tomochain"
+ "github.com/tomochain/tomochain/accounts/abi"
+ "github.com/tomochain/tomochain/accounts/abi/bind"
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/types"
+ "github.com/tomochain/tomochain/event"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var (
+ _ = errors.New
+ _ = big.NewInt
+ _ = strings.NewReader
+ _ = tomochain.NotFound
+ _ = bind.Bind
+ _ = common.Big1
+ _ = types.BloomLookup
+ _ = event.NewSubscription
+ _ = abi.ConvertType
+)
+
+{{$structs := .Structs}}
+{{range $structs}}
+ // {{.Name}} is an auto generated low-level Go binding around a user-defined struct.
+ type {{.Name}} struct {
+ {{range $field := .Fields}}
+ {{$field.Name}} {{$field.Type}}{{end}}
+ }
+{{end}}
+
{{range $contract := .Contracts}}
+ // {{.Type}}MetaData contains all metadata concerning the {{.Type}} contract.
+ var {{.Type}}MetaData = &bind.MetaData{
+ ABI: "{{.InputABI}}",
+ {{if $contract.FuncSigs -}}
+ Sigs: map[string]string{
+ {{range $strsig, $binsig := .FuncSigs}}"{{$binsig}}": "{{$strsig}}",
+ {{end}}
+ },
+ {{end -}}
+ {{if .InputBin -}}
+ Bin: "0x{{.InputBin}}",
+ {{end}}
+ }
// {{.Type}}ABI is the input ABI used to generate the binding from.
- const {{.Type}}ABI = "{{.InputABI}}"
+ // Deprecated: Use {{.Type}}MetaData.ABI instead.
+ var {{.Type}}ABI = {{.Type}}MetaData.ABI
+
+ {{if $contract.FuncSigs}}
+ // Deprecated: Use {{.Type}}MetaData.Sigs instead.
+ // {{.Type}}FuncSigs maps the 4-byte function signature to its string representation.
+ var {{.Type}}FuncSigs = {{.Type}}MetaData.Sigs
+ {{end}}
{{if .InputBin}}
// {{.Type}}Bin is the compiled bytecode used for deploying new contracts.
- const {{.Type}}Bin = ` + "`" + `{{.InputBin}}` + "`" + `
+ // Deprecated: Use {{.Type}}MetaData.Bin instead.
+ var {{.Type}}Bin = {{.Type}}MetaData.Bin
- // Deploy{{.Type}} deploys a new Ethereum contract, binding an instance of {{.Type}} to it.
- func Deploy{{.Type}}(auth *bind.TransactOpts, backend bind.ContractBackend {{range .Constructor.Inputs}}, {{.Name}} {{bindtype .Type}}{{end}}) (common.Address, *types.Transaction, *{{.Type}}, error) {
- parsed, err := abi.JSON(strings.NewReader({{.Type}}ABI))
+ // Deploy{{.Type}} deploys a new Tomochain contract, binding an instance of {{.Type}} to it.
+ func Deploy{{.Type}}(auth *bind.TransactOpts, backend bind.ContractBackend {{range .Constructor.Inputs}}, {{.Name}} {{bindtype .Type $structs}}{{end}}) (common.Address, *types.Transaction, *{{.Type}}, error) {
+ parsed, err := {{.Type}}MetaData.GetAbi()
if err != nil {
return common.Address{}, nil, nil, err
}
- address, tx, contract, err := bind.DeployContract(auth, parsed, common.FromHex({{.Type}}Bin), backend {{range .Constructor.Inputs}}, {{.Name}}{{end}})
+ if parsed == nil {
+ return common.Address{}, nil, nil, errors.New("GetABI returned nil")
+ }
+ {{range $pattern, $name := .Libraries}}
+ {{decapitalise $name}}Addr, _, _, _ := Deploy{{capitalise $name}}(auth, backend)
+ {{$contract.Type}}Bin = strings.ReplaceAll({{$contract.Type}}Bin, "__${{$pattern}}$__", {{decapitalise $name}}Addr.String()[2:])
+ {{end}}
+ address, tx, contract, err := bind.DeployContract(auth, *parsed, common.FromHex({{.Type}}Bin), backend {{range .Constructor.Inputs}}, {{.Name}}{{end}})
if err != nil {
return common.Address{}, nil, nil, err
}
@@ -86,29 +171,29 @@ package {{.Package}}
}
{{end}}
- // {{.Type}} is an auto generated Go binding around an Ethereum contract.
+ // {{.Type}} is an auto generated Go binding around a Tomochain contract.
type {{.Type}} struct {
{{.Type}}Caller // Read-only binding to the contract
{{.Type}}Transactor // Write-only binding to the contract
- {{.Type}}Filterer // Log filterer for contract events
+ {{.Type}}Filterer // Log filterer for contract events
}
- // {{.Type}}Caller is an auto generated read-only Go binding around an Ethereum contract.
+ // {{.Type}}Caller is an auto generated read-only Go binding around a Tomochain contract.
type {{.Type}}Caller struct {
contract *bind.BoundContract // Generic contract wrapper for the low level calls
}
- // {{.Type}}Transactor is an auto generated write-only Go binding around an Ethereum contract.
+ // {{.Type}}Transactor is an auto generated write-only Go binding around a Tomochain contract.
type {{.Type}}Transactor struct {
contract *bind.BoundContract // Generic contract wrapper for the low level calls
}
- // {{.Type}}Filterer is an auto generated log filtering Go binding around an Ethereum contract events.
+ // {{.Type}}Filterer is an auto generated log filtering Go binding around Tomochain contract events.
type {{.Type}}Filterer struct {
contract *bind.BoundContract // Generic contract wrapper for the low level calls
}
- // {{.Type}}Session is an auto generated Go binding around an Ethereum contract,
+ // {{.Type}}Session is an auto generated Go binding around a Tomochain contract,
// with pre-set call and transact options.
type {{.Type}}Session struct {
Contract *{{.Type}} // Generic contract binding to set the session for
@@ -116,31 +201,31 @@ package {{.Package}}
TransactOpts bind.TransactOpts // Transaction auth options to use throughout this session
}
- // {{.Type}}CallerSession is an auto generated read-only Go binding around an Ethereum contract,
+ // {{.Type}}CallerSession is an auto generated read-only Go binding around a Tomochain contract,
// with pre-set call options.
type {{.Type}}CallerSession struct {
Contract *{{.Type}}Caller // Generic contract caller binding to set the session for
CallOpts bind.CallOpts // Call options to use throughout this session
}
- // {{.Type}}TransactorSession is an auto generated write-only Go binding around an Ethereum contract,
+ // {{.Type}}TransactorSession is an auto generated write-only Go binding around a Tomochain contract,
// with pre-set transact options.
type {{.Type}}TransactorSession struct {
Contract *{{.Type}}Transactor // Generic contract transactor binding to set the session for
TransactOpts bind.TransactOpts // Transaction auth options to use throughout this session
}
- // {{.Type}}Raw is an auto generated low-level Go binding around an Ethereum contract.
+ // {{.Type}}Raw is an auto generated low-level Go binding around a Tomochain contract.
type {{.Type}}Raw struct {
Contract *{{.Type}} // Generic contract binding to access the raw methods on
}
- // {{.Type}}CallerRaw is an auto generated low-level read-only Go binding around an Ethereum contract.
+ // {{.Type}}CallerRaw is an auto generated low-level read-only Go binding around a Tomochain contract.
type {{.Type}}CallerRaw struct {
Contract *{{.Type}}Caller // Generic read-only contract binding to access the raw methods on
}
- // {{.Type}}TransactorRaw is an auto generated low-level write-only Go binding around an Ethereum contract.
+ // {{.Type}}TransactorRaw is an auto generated low-level write-only Go binding around a Tomochain contract.
type {{.Type}}TransactorRaw struct {
Contract *{{.Type}}Transactor // Generic write-only contract binding to access the raw methods on
}
@@ -183,18 +268,18 @@ package {{.Package}}
// bind{{.Type}} binds a generic wrapper to an already deployed contract.
func bind{{.Type}}(address common.Address, caller bind.ContractCaller, transactor bind.ContractTransactor, filterer bind.ContractFilterer) (*bind.BoundContract, error) {
- parsed, err := abi.JSON(strings.NewReader({{.Type}}ABI))
+ parsed, err := {{.Type}}MetaData.GetAbi()
if err != nil {
return nil, err
}
- return bind.NewBoundContract(address, parsed, caller, transactor, filterer), nil
+ return bind.NewBoundContract(address, *parsed, caller, transactor, filterer), nil
}
// Call invokes the (constant) contract method with params as input values and
// sets the output to result. The result type might be a single field for simple
// returns, a slice of interfaces for anonymous returns and a struct for named
// returns.
- func (_{{$contract.Type}} *{{$contract.Type}}Raw) Call(opts *bind.CallOpts, result interface{}, method string, params ...interface{}) error {
+ func (_{{$contract.Type}} *{{$contract.Type}}Raw) Call(opts *bind.CallOpts, result *[]interface{}, method string, params ...interface{}) error {
return _{{$contract.Type}}.Contract.{{$contract.Type}}Caller.contract.Call(opts, result, method, params...)
}
@@ -213,7 +298,7 @@ package {{.Package}}
// sets the output to result. The result type might be a single field for simple
// returns, a slice of interfaces for anonymous returns and a struct for named
// returns.
- func (_{{$contract.Type}} *{{$contract.Type}}CallerRaw) Call(opts *bind.CallOpts, result interface{}, method string, params ...interface{}) error {
+ func (_{{$contract.Type}} *{{$contract.Type}}CallerRaw) Call(opts *bind.CallOpts, result *[]interface{}, method string, params ...interface{}) error {
return _{{$contract.Type}}.Contract.contract.Call(opts, result, method, params...)
}
@@ -229,63 +314,116 @@ package {{.Package}}
}
{{range .Calls}}
- // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.Id}}.
+ // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.ID}}.
//
// Solidity: {{.Original.String}}
- func (_{{$contract.Type}} *{{$contract.Type}}Caller) {{.Normalized.Name}}(opts *bind.CallOpts {{range .Normalized.Inputs}}, {{.Name}} {{bindtype .Type}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type}};{{end}} },{{else}}{{range .Normalized.Outputs}}{{bindtype .Type}},{{end}}{{end}} error) {
- {{if .Structured}}ret := new(struct{
- {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type}}
- {{end}}
- }){{else}}var (
- {{range $i, $_ := .Normalized.Outputs}}ret{{$i}} = new({{bindtype .Type}})
- {{end}}
- ){{end}}
- out := {{if .Structured}}ret{{else}}{{if eq (len .Normalized.Outputs) 1}}ret0{{else}}&[]interface{}{
- {{range $i, $_ := .Normalized.Outputs}}ret{{$i}},
- {{end}}
- }{{end}}{{end}}
- err := _{{$contract.Type}}.contract.Call(opts, out, "{{.Original.Name}}" {{range .Normalized.Inputs}}, {{.Name}}{{end}})
- return {{if .Structured}}*ret,{{else}}{{range $i, $_ := .Normalized.Outputs}}*ret{{$i}},{{end}}{{end}} err
+ func (_{{$contract.Type}} *{{$contract.Type}}Caller) {{.Normalized.Name}}(opts *bind.CallOpts {{range .Normalized.Inputs}}, {{.Name}} {{bindtype .Type $structs}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type $structs}};{{end}} },{{else}}{{range .Normalized.Outputs}}{{bindtype .Type $structs}},{{end}}{{end}} error) {
+ var out []interface{}
+ err := _{{$contract.Type}}.contract.Call(opts, &out, "{{.Original.Name}}" {{range .Normalized.Inputs}}, {{.Name}}{{end}})
+ {{if .Structured}}
+ outstruct := new(struct{ {{range .Normalized.Outputs}} {{.Name}} {{bindtype .Type $structs}}; {{end}} })
+ if err != nil {
+ return *outstruct, err
+ }
+ {{range $i, $t := .Normalized.Outputs}}
+ outstruct.{{.Name}} = *abi.ConvertType(out[{{$i}}], new({{bindtype .Type $structs}})).(*{{bindtype .Type $structs}}){{end}}
+
+ return *outstruct, err
+ {{else}}
+ if err != nil {
+ return {{range $i, $_ := .Normalized.Outputs}}*new({{bindtype .Type $structs}}), {{end}} err
+ }
+ {{range $i, $t := .Normalized.Outputs}}
+ out{{$i}} := *abi.ConvertType(out[{{$i}}], new({{bindtype .Type $structs}})).(*{{bindtype .Type $structs}}){{end}}
+
+ return {{range $i, $t := .Normalized.Outputs}}out{{$i}}, {{end}} err
+ {{end}}
}
- // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.Id}}.
+ // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.ID}}.
//
// Solidity: {{.Original.String}}
- func (_{{$contract.Type}} *{{$contract.Type}}Session) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type}};{{end}} }, {{else}} {{range .Normalized.Outputs}}{{bindtype .Type}},{{end}} {{end}} error) {
+ func (_{{$contract.Type}} *{{$contract.Type}}Session) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type $structs}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type $structs}};{{end}} }, {{else}} {{range .Normalized.Outputs}}{{bindtype .Type $structs}},{{end}} {{end}} error) {
return _{{$contract.Type}}.Contract.{{.Normalized.Name}}(&_{{$contract.Type}}.CallOpts {{range .Normalized.Inputs}}, {{.Name}}{{end}})
}
- // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.Id}}.
+ // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.ID}}.
//
// Solidity: {{.Original.String}}
- func (_{{$contract.Type}} *{{$contract.Type}}CallerSession) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type}};{{end}} }, {{else}} {{range .Normalized.Outputs}}{{bindtype .Type}},{{end}} {{end}} error) {
+ func (_{{$contract.Type}} *{{$contract.Type}}CallerSession) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type $structs}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type $structs}};{{end}} }, {{else}} {{range .Normalized.Outputs}}{{bindtype .Type $structs}},{{end}} {{end}} error) {
return _{{$contract.Type}}.Contract.{{.Normalized.Name}}(&_{{$contract.Type}}.CallOpts {{range .Normalized.Inputs}}, {{.Name}}{{end}})
}
{{end}}
{{range .Transacts}}
- // {{.Normalized.Name}} is a paid mutator transaction binding the contract method 0x{{printf "%x" .Original.Id}}.
+ // {{.Normalized.Name}} is a paid mutator transaction binding the contract method 0x{{printf "%x" .Original.ID}}.
//
// Solidity: {{.Original.String}}
- func (_{{$contract.Type}} *{{$contract.Type}}Transactor) {{.Normalized.Name}}(opts *bind.TransactOpts {{range .Normalized.Inputs}}, {{.Name}} {{bindtype .Type}} {{end}}) (*types.Transaction, error) {
+ func (_{{$contract.Type}} *{{$contract.Type}}Transactor) {{.Normalized.Name}}(opts *bind.TransactOpts {{range .Normalized.Inputs}}, {{.Name}} {{bindtype .Type $structs}} {{end}}) (*types.Transaction, error) {
return _{{$contract.Type}}.contract.Transact(opts, "{{.Original.Name}}" {{range .Normalized.Inputs}}, {{.Name}}{{end}})
}
- // {{.Normalized.Name}} is a paid mutator transaction binding the contract method 0x{{printf "%x" .Original.Id}}.
+ // {{.Normalized.Name}} is a paid mutator transaction binding the contract method 0x{{printf "%x" .Original.ID}}.
//
// Solidity: {{.Original.String}}
- func (_{{$contract.Type}} *{{$contract.Type}}Session) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type}} {{end}}) (*types.Transaction, error) {
+ func (_{{$contract.Type}} *{{$contract.Type}}Session) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type $structs}} {{end}}) (*types.Transaction, error) {
return _{{$contract.Type}}.Contract.{{.Normalized.Name}}(&_{{$contract.Type}}.TransactOpts {{range $i, $_ := .Normalized.Inputs}}, {{.Name}}{{end}})
}
- // {{.Normalized.Name}} is a paid mutator transaction binding the contract method 0x{{printf "%x" .Original.Id}}.
+ // {{.Normalized.Name}} is a paid mutator transaction binding the contract method 0x{{printf "%x" .Original.ID}}.
//
// Solidity: {{.Original.String}}
- func (_{{$contract.Type}} *{{$contract.Type}}TransactorSession) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type}} {{end}}) (*types.Transaction, error) {
+ func (_{{$contract.Type}} *{{$contract.Type}}TransactorSession) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type $structs}} {{end}}) (*types.Transaction, error) {
return _{{$contract.Type}}.Contract.{{.Normalized.Name}}(&_{{$contract.Type}}.TransactOpts {{range $i, $_ := .Normalized.Inputs}}, {{.Name}}{{end}})
}
{{end}}
+ {{if .Fallback}}
+ // Fallback is a paid mutator transaction binding the contract fallback function.
+ //
+ // Solidity: {{.Fallback.Original.String}}
+ func (_{{$contract.Type}} *{{$contract.Type}}Transactor) Fallback(opts *bind.TransactOpts, calldata []byte) (*types.Transaction, error) {
+ return _{{$contract.Type}}.contract.RawTransact(opts, calldata)
+ }
+
+ // Fallback is a paid mutator transaction binding the contract fallback function.
+ //
+ // Solidity: {{.Fallback.Original.String}}
+ func (_{{$contract.Type}} *{{$contract.Type}}Session) Fallback(calldata []byte) (*types.Transaction, error) {
+ return _{{$contract.Type}}.Contract.Fallback(&_{{$contract.Type}}.TransactOpts, calldata)
+ }
+
+ // Fallback is a paid mutator transaction binding the contract fallback function.
+ //
+ // Solidity: {{.Fallback.Original.String}}
+ func (_{{$contract.Type}} *{{$contract.Type}}TransactorSession) Fallback(calldata []byte) (*types.Transaction, error) {
+ return _{{$contract.Type}}.Contract.Fallback(&_{{$contract.Type}}.TransactOpts, calldata)
+ }
+ {{end}}
+
+ {{if .Receive}}
+ // Receive is a paid mutator transaction binding the contract receive function.
+ //
+ // Solidity: {{.Receive.Original.String}}
+ func (_{{$contract.Type}} *{{$contract.Type}}Transactor) Receive(opts *bind.TransactOpts) (*types.Transaction, error) {
+ return _{{$contract.Type}}.contract.RawTransact(opts, nil) // calldata is disallowed for receive function
+ }
+
+ // Receive is a paid mutator transaction binding the contract receive function.
+ //
+ // Solidity: {{.Receive.Original.String}}
+ func (_{{$contract.Type}} *{{$contract.Type}}Session) Receive() (*types.Transaction, error) {
+ return _{{$contract.Type}}.Contract.Receive(&_{{$contract.Type}}.TransactOpts)
+ }
+
+ // Receive is a paid mutator transaction binding the contract receive function.
+ //
+ // Solidity: {{.Receive.Original.String}}
+ func (_{{$contract.Type}} *{{$contract.Type}}TransactorSession) Receive() (*types.Transaction, error) {
+ return _{{$contract.Type}}.Contract.Receive(&_{{$contract.Type}}.TransactOpts)
+ }
+ {{end}}
+
{{range .Events}}
// {{$contract.Type}}{{.Normalized.Name}}Iterator is returned from Filter{{.Normalized.Name}} and is used to iterate over the raw logs and unpacked data for {{.Normalized.Name}} events raised by the {{$contract.Type}} contract.
type {{$contract.Type}}{{.Normalized.Name}}Iterator struct {
@@ -295,7 +433,7 @@ package {{.Package}}
event string // Event name to use for unpacking event data
logs chan types.Log // Log channel receiving the found contract events
- sub ethereum.Subscription // Subscription for errors, completion and termination
+ sub tomochain.Subscription // Subscription for errors, completion and termination
done bool // Whether the subscription completed delivering logs
fail error // Occurred error to stop iteration
}
@@ -353,14 +491,14 @@ package {{.Package}}
// {{$contract.Type}}{{.Normalized.Name}} represents a {{.Normalized.Name}} event raised by the {{$contract.Type}} contract.
type {{$contract.Type}}{{.Normalized.Name}} struct { {{range .Normalized.Inputs}}
- {{capitalise .Name}} {{if .Indexed}}{{bindtopictype .Type}}{{else}}{{bindtype .Type}}{{end}}; {{end}}
+ {{capitalise .Name}} {{if .Indexed}}{{bindtopictype .Type $structs}}{{else}}{{bindtype .Type $structs}}{{end}}; {{end}}
Raw types.Log // Blockchain specific contextual infos
}
- // Filter{{.Normalized.Name}} is a free log retrieval operation binding the contract event 0x{{printf "%x" .Original.Id}}.
+ // Filter{{.Normalized.Name}} is a free log retrieval operation binding the contract event 0x{{printf "%x" .Original.ID}}.
//
// Solidity: {{.Original.String}}
- func (_{{$contract.Type}} *{{$contract.Type}}Filterer) Filter{{.Normalized.Name}}(opts *bind.FilterOpts{{range .Normalized.Inputs}}{{if .Indexed}}, {{.Name}} []{{bindtype .Type}}{{end}}{{end}}) (*{{$contract.Type}}{{.Normalized.Name}}Iterator, error) {
+ func (_{{$contract.Type}} *{{$contract.Type}}Filterer) Filter{{.Normalized.Name}}(opts *bind.FilterOpts{{range .Normalized.Inputs}}{{if .Indexed}}, {{.Name}} []{{bindtype .Type $structs}}{{end}}{{end}}) (*{{$contract.Type}}{{.Normalized.Name}}Iterator, error) {
{{range .Normalized.Inputs}}
{{if .Indexed}}var {{.Name}}Rule []interface{}
for _, {{.Name}}Item := range {{.Name}} {
@@ -374,10 +512,10 @@ package {{.Package}}
return &{{$contract.Type}}{{.Normalized.Name}}Iterator{contract: _{{$contract.Type}}.contract, event: "{{.Original.Name}}", logs: logs, sub: sub}, nil
}
- // Watch{{.Normalized.Name}} is a free log subscription operation binding the contract event 0x{{printf "%x" .Original.Id}}.
+ // Watch{{.Normalized.Name}} is a free log subscription operation binding the contract event 0x{{printf "%x" .Original.ID}}.
//
// Solidity: {{.Original.String}}
- func (_{{$contract.Type}} *{{$contract.Type}}Filterer) Watch{{.Normalized.Name}}(opts *bind.WatchOpts, sink chan<- *{{$contract.Type}}{{.Normalized.Name}}{{range .Normalized.Inputs}}{{if .Indexed}}, {{.Name}} []{{bindtype .Type}}{{end}}{{end}}) (event.Subscription, error) {
+ func (_{{$contract.Type}} *{{$contract.Type}}Filterer) Watch{{.Normalized.Name}}(opts *bind.WatchOpts, sink chan<- *{{$contract.Type}}{{.Normalized.Name}}{{range .Normalized.Inputs}}{{if .Indexed}}, {{.Name}} []{{bindtype .Type $structs}}{{end}}{{end}}) (event.Subscription, error) {
{{range .Normalized.Inputs}}
{{if .Indexed}}var {{.Name}}Rule []interface{}
for _, {{.Name}}Item := range {{.Name}} {
@@ -415,108 +553,19 @@ package {{.Package}}
}
}), nil
}
- {{end}}
-{{end}}
-`
-
-// tmplSourceJava is the Java source template use to generate the contract binding
-// based on.
-const tmplSourceJava = `
-// This file is an automatically generated Java binding. Do not modify as any
-// change will likely be lost upon the next re-generation!
-package {{.Package}};
-
-import org.ethereum.geth.*;
-import org.ethereum.geth.internal.*;
-
-{{range $contract := .Contracts}}
- public class {{.Type}} {
- // ABI is the input ABI used to generate the binding from.
- public final static String ABI = "{{.InputABI}}";
-
- {{if .InputBin}}
- // BYTECODE is the compiled bytecode used for deploying new contracts.
- public final static byte[] BYTECODE = "{{.InputBin}}".getBytes();
-
- // deploy deploys a new Ethereum contract, binding an instance of {{.Type}} to it.
- public static {{.Type}} deploy(TransactOpts auth, EthereumClient client{{range .Constructor.Inputs}}, {{bindtype .Type}} {{.Name}}{{end}}) throws Exception {
- Interfaces args = Geth.newInterfaces({{(len .Constructor.Inputs)}});
- {{range $index, $element := .Constructor.Inputs}}
- args.set({{$index}}, Geth.newInterface()); args.get({{$index}}).set{{namedtype (bindtype .Type) .Type}}({{.Name}});
- {{end}}
- return new {{.Type}}(Geth.deployContract(auth, ABI, BYTECODE, client, args));
- }
-
- // Internal constructor used by contract deployment.
- private {{.Type}}(BoundContract deployment) {
- this.Address = deployment.getAddress();
- this.Deployer = deployment.getDeployer();
- this.Contract = deployment;
+ // Parse{{.Normalized.Name}} is a log parse operation binding the contract event 0x{{printf "%x" .Original.ID}}.
+ //
+ // Solidity: {{.Original.String}}
+ func (_{{$contract.Type}} *{{$contract.Type}}Filterer) Parse{{.Normalized.Name}}(log types.Log) (*{{$contract.Type}}{{.Normalized.Name}}, error) {
+ event := new({{$contract.Type}}{{.Normalized.Name}})
+ if err := _{{$contract.Type}}.contract.UnpackLog(event, "{{.Original.Name}}", log); err != nil {
+ return nil, err
}
- {{end}}
-
- // Ethereum address where this contract is located at.
- public final Address Address;
-
- // Ethereum transaction in which this contract was deployed (if known!).
- public final Transaction Deployer;
-
- // Contract instance bound to a blockchain address.
- private final BoundContract Contract;
-
- // Creates a new instance of {{.Type}}, bound to a specific deployed contract.
- public {{.Type}}(Address address, EthereumClient client) throws Exception {
- this(Geth.bindContract(address, ABI, client));
+ event.Raw = log
+ return event, nil
}
- {{range .Calls}}
- {{if gt (len .Normalized.Outputs) 1}}
- // {{capitalise .Normalized.Name}}Results is the output of a call to {{.Normalized.Name}}.
- public class {{capitalise .Normalized.Name}}Results {
- {{range $index, $item := .Normalized.Outputs}}public {{bindtype .Type}} {{if ne .Name ""}}{{.Name}}{{else}}Return{{$index}}{{end}};
- {{end}}
- }
- {{end}}
-
- // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.Id}}.
- //
- // Solidity: {{.Original.String}}
- public {{if gt (len .Normalized.Outputs) 1}}{{capitalise .Normalized.Name}}Results{{else}}{{range .Normalized.Outputs}}{{bindtype .Type}}{{end}}{{end}} {{.Normalized.Name}}(CallOpts opts{{range .Normalized.Inputs}}, {{bindtype .Type}} {{.Name}}{{end}}) throws Exception {
- Interfaces args = Geth.newInterfaces({{(len .Normalized.Inputs)}});
- {{range $index, $item := .Normalized.Inputs}}args.set({{$index}}, Geth.newInterface()); args.get({{$index}}).set{{namedtype (bindtype .Type) .Type}}({{.Name}});
- {{end}}
-
- Interfaces results = Geth.newInterfaces({{(len .Normalized.Outputs)}});
- {{range $index, $item := .Normalized.Outputs}}Interface result{{$index}} = Geth.newInterface(); result{{$index}}.setDefault{{namedtype (bindtype .Type) .Type}}(); results.set({{$index}}, result{{$index}});
- {{end}}
-
- if (opts == null) {
- opts = Geth.newCallOpts();
- }
- this.Contract.call(opts, results, "{{.Original.Name}}", args);
- {{if gt (len .Normalized.Outputs) 1}}
- {{capitalise .Normalized.Name}}Results result = new {{capitalise .Normalized.Name}}Results();
- {{range $index, $item := .Normalized.Outputs}}result.{{if ne .Name ""}}{{.Name}}{{else}}Return{{$index}}{{end}} = results.get({{$index}}).get{{namedtype (bindtype .Type) .Type}}();
- {{end}}
- return result;
- {{else}}{{range .Normalized.Outputs}}return results.get(0).get{{namedtype (bindtype .Type) .Type}}();{{end}}
- {{end}}
- }
- {{end}}
-
- {{range .Transacts}}
- // {{.Normalized.Name}} is a paid mutator transaction binding the contract method 0x{{printf "%x" .Original.Id}}.
- //
- // Solidity: {{.Original.String}}
- public Transaction {{.Normalized.Name}}(TransactOpts opts{{range .Normalized.Inputs}}, {{bindtype .Type}} {{.Name}}{{end}}) throws Exception {
- Interfaces args = Geth.newInterfaces({{(len .Normalized.Inputs)}});
- {{range $index, $item := .Normalized.Inputs}}args.set({{$index}}, Geth.newInterface()); args.get({{$index}}).set{{namedtype (bindtype .Type) .Type}}({{.Name}});
- {{end}}
-
- return this.Contract.transact(opts, "{{.Original.Name}}" , args);
- }
- {{end}}
- }
+ {{end}}
{{end}}
`
diff --git a/accounts/abi/error.go b/accounts/abi/error.go
index 9d8674ad0..f0f71b6c9 100644
--- a/accounts/abi/error.go
+++ b/accounts/abi/error.go
@@ -39,23 +39,21 @@ func formatSliceString(kind reflect.Kind, sliceSize int) string {
// type in t.
func sliceTypeCheck(t Type, val reflect.Value) error {
if val.Kind() != reflect.Slice && val.Kind() != reflect.Array {
- return typeErr(formatSliceString(t.Kind, t.Size), val.Type())
+ return typeErr(formatSliceString(t.GetType().Kind(), t.Size), val.Type())
}
if t.T == ArrayTy && val.Len() != t.Size {
- return typeErr(formatSliceString(t.Elem.Kind, t.Size), formatSliceString(val.Type().Elem().Kind(), val.Len()))
+ return typeErr(formatSliceString(t.Elem.GetType().Kind(), t.Size), formatSliceString(val.Type().Elem().Kind(), val.Len()))
}
- if t.Elem.T == SliceTy {
+ if t.Elem.T == SliceTy || t.Elem.T == ArrayTy {
if val.Len() > 0 {
return sliceTypeCheck(*t.Elem, val.Index(0))
}
- } else if t.Elem.T == ArrayTy {
- return sliceTypeCheck(*t.Elem, val.Index(0))
}
- if elemKind := val.Type().Elem().Kind(); elemKind != t.Elem.Kind {
- return typeErr(formatSliceString(t.Elem.Kind, t.Size), val.Type())
+ if val.Type().Elem().Kind() != t.Elem.GetType().Kind() {
+ return typeErr(formatSliceString(t.Elem.GetType().Kind(), t.Size), val.Type())
}
return nil
}
@@ -68,10 +66,10 @@ func typeCheck(t Type, value reflect.Value) error {
}
// Check base type validity. Element types will be checked later on.
- if t.Kind != value.Kind() {
- return typeErr(t.Kind, value.Kind())
+ if t.GetType().Kind() != value.Kind() {
+ return typeErr(t.GetType().Kind(), value.Kind())
} else if t.T == FixedBytesTy && t.Size != value.Len() {
- return typeErr(t.Type, value.Type())
+ return typeErr(t.GetType(), value.Type())
} else {
return nil
}
diff --git a/accounts/abi/event.go b/accounts/abi/event.go
index 082fd71ae..d427aac79 100644
--- a/accounts/abi/event.go
+++ b/accounts/abi/event.go
@@ -28,30 +28,76 @@ import (
// holds type information (inputs) about the yielded output. Anonymous events
// don't get the signature canonical representation as the first LOG topic.
type Event struct {
- Name string
+ // Name is the event name used for internal representation. It's derived from
+ // the raw name and a suffix will be added in the case of event overloading.
+ //
+ // e.g.
+ // These are two events that have the same name:
+ // * foo(int,int)
+ // * foo(uint,uint)
+ // The event name of the first one will be resolved as foo while the second one
+ // will be resolved as foo0.
+ Name string
+
+ // RawName is the raw event name parsed from ABI.
+ RawName string
Anonymous bool
Inputs Arguments
+ str string
+
+ // Sig contains the string signature according to the ABI spec.
+ // e.g. event foo(uint32 a, int b) = "foo(uint32,int256)"
+ // Please note that "int" is substituted by its canonical representation "int256"
+ Sig string
+
+ // ID is the canonical representation of the event's signature, used by the
+ // abi definition to identify event names and types.
+ ID common.Hash
}
-func (event Event) String() string {
- inputs := make([]string, len(event.Inputs))
- for i, input := range event.Inputs {
- inputs[i] = fmt.Sprintf("%v %v", input.Name, input.Type)
+// NewEvent creates a new Event.
+// It sanitizes the input arguments by giving unnamed arguments generated names.
+// It also precomputes the id, signature and string representation
+// of the event.
+func NewEvent(name, rawName string, anonymous bool, inputs Arguments) Event {
+ // sanitize inputs: give unnamed arguments generated names
+ // and precompute the string and sig representations.
+ names := make([]string, len(inputs))
+ types := make([]string, len(inputs))
+ for i, input := range inputs {
+ if input.Name == "" {
+ inputs[i] = Argument{
+ Name: fmt.Sprintf("arg%d", i),
+ Indexed: input.Indexed,
+ Type: input.Type,
+ }
+ } else {
+ inputs[i] = input
+ }
+ // string representation
+ names[i] = fmt.Sprintf("%v %v", input.Type, inputs[i].Name)
if input.Indexed {
- inputs[i] = fmt.Sprintf("%v indexed %v", input.Name, input.Type)
+ names[i] = fmt.Sprintf("%v indexed %v", input.Type, inputs[i].Name)
}
+ // sig representation
+ types[i] = input.Type.String()
}
- return fmt.Sprintf("event %v(%v)", event.Name, strings.Join(inputs, ", "))
-}
-// Id returns the canonical representation of the event's signature used by the
-// abi definition to identify event names and types.
-func (e Event) Id() common.Hash {
- types := make([]string, len(e.Inputs))
- i := 0
- for _, input := range e.Inputs {
- types[i] = input.Type.String()
- i++
+ str := fmt.Sprintf("event %v(%v)", rawName, strings.Join(names, ", "))
+ sig := fmt.Sprintf("%v(%v)", rawName, strings.Join(types, ","))
+ id := common.BytesToHash(crypto.Keccak256([]byte(sig)))
+
+ return Event{
+ Name: name,
+ RawName: rawName,
+ Anonymous: anonymous,
+ Inputs: inputs,
+ str: str,
+ Sig: sig,
+ ID: id,
}
- return common.BytesToHash(crypto.Keccak256([]byte(fmt.Sprintf("%v(%v)", e.Name, strings.Join(types, ",")))))
+}
+
+func (e Event) String() string {
+ return e.str
}
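With `NewEvent` precomputing `Sig` and `ID`, the event topic is simply the Keccak256 hash of the canonical signature string (raw name plus canonical types, joined without spaces). A minimal sketch using the same `common` and `crypto` packages imported above:

```go
package main

import (
	"fmt"

	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/crypto"
)

func main() {
	// The canonical signature, built exactly as NewEvent builds it.
	sig := "Transfer(address,address,uint256)"
	id := common.BytesToHash(crypto.Keccak256([]byte(sig)))
	// Prints the well-known ERC20 Transfer topic:
	// 0xddf252ad1be2c89b69c2b068fc378daa952ba7f163c4a11628f55a4df523b3ef
	fmt.Println(id.Hex())
}
```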
diff --git a/accounts/abi/event_test.go b/accounts/abi/event_test.go
index c39411d8f..3a39059a4 100644
--- a/accounts/abi/event_test.go
+++ b/accounts/abi/event_test.go
@@ -27,6 +27,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/crypto"
)
@@ -58,12 +59,28 @@ var jsonEventPledge = []byte(`{
"type": "event"
}`)
+var jsonEventMixedCase = []byte(`{
+ "anonymous": false,
+ "inputs": [{
+ "indexed": false, "name": "value", "type": "uint256"
+ }, {
+ "indexed": false, "name": "_value", "type": "uint256"
+ }, {
+ "indexed": false, "name": "Value", "type": "uint256"
+ }],
+ "name": "MixedCase",
+ "type": "event"
+ }`)
+
// 1000000
var transferData1 = "00000000000000000000000000000000000000000000000000000000000f4240"
// "0x00Ce0d46d924CC8437c806721496599FC3FFA268", 2218516807680, "usd"
var pledgeData1 = "00000000000000000000000000ce0d46d924cc8437c806721496599fc3ffa2680000000000000000000000000000000000000000000000000000020489e800007573640000000000000000000000000000000000000000000000000000000000"
+// 1000000,2218516807680,1000001
+var mixedCaseData1 = "00000000000000000000000000000000000000000000000000000000000f42400000000000000000000000000000000000000000000000000000020489e8000000000000000000000000000000000000000000000000000000000000000f4241"
+
func TestEventId(t *testing.T) {
var table = []struct {
definition string
@@ -71,12 +88,45 @@ func TestEventId(t *testing.T) {
}{
{
definition: `[
- { "type" : "event", "name" : "balance", "inputs": [{ "name" : "in", "type": "uint256" }] },
- { "type" : "event", "name" : "check", "inputs": [{ "name" : "t", "type": "address" }, { "name": "b", "type": "uint256" }] }
+ { "type" : "event", "name" : "Balance", "inputs": [{ "name" : "in", "type": "uint256" }] },
+ { "type" : "event", "name" : "Check", "inputs": [{ "name" : "t", "type": "address" }, { "name": "b", "type": "uint256" }] }
]`,
expectations: map[string]common.Hash{
- "balance": crypto.Keccak256Hash([]byte("balance(uint256)")),
- "check": crypto.Keccak256Hash([]byte("check(address,uint256)")),
+ "Balance": crypto.Keccak256Hash([]byte("Balance(uint256)")),
+ "Check": crypto.Keccak256Hash([]byte("Check(address,uint256)")),
+ },
+ },
+ }
+
+ for _, test := range table {
+ abi, err := JSON(strings.NewReader(test.definition))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for name, event := range abi.Events {
+ if event.ID != test.expectations[name] {
+ t.Errorf("expected id to be %x, got %x", test.expectations[name], event.ID)
+ }
+ }
+ }
+}
+
+func TestEventString(t *testing.T) {
+ var table = []struct {
+ definition string
+ expectations map[string]string
+ }{
+ {
+ definition: `[
+ { "type" : "event", "name" : "Balance", "inputs": [{ "name" : "in", "type": "uint256" }] },
+ { "type" : "event", "name" : "Check", "inputs": [{ "name" : "t", "type": "address" }, { "name": "b", "type": "uint256" }] },
+ { "type" : "event", "name" : "Transfer", "inputs": [{ "name": "from", "type": "address", "indexed": true }, { "name": "to", "type": "address", "indexed": true }, { "name": "value", "type": "uint256" }] }
+ ]`,
+ expectations: map[string]string{
+ "Balance": "event Balance(uint256 in)",
+ "Check": "event Check(address t, uint256 b)",
+ "Transfer": "event Transfer(address indexed from, address indexed to, uint256 value)",
},
},
}
@@ -88,8 +138,8 @@ func TestEventId(t *testing.T) {
}
for name, event := range abi.Events {
- if event.Id() != test.expectations[name] {
- t.Errorf("expected id to be %x, got %x", test.expectations[name], event.Id())
+ if event.String() != test.expectations[name] {
+ t.Errorf("expected string to be %s, got %s", test.expectations[name], event.String())
}
}
}
@@ -98,10 +148,6 @@ func TestEventId(t *testing.T) {
// TestEventMultiValueWithArrayUnpack verifies that array fields will be counted after parsing array.
func TestEventMultiValueWithArrayUnpack(t *testing.T) {
definition := `[{"name": "test", "type": "event", "inputs": [{"indexed": false, "name":"value1", "type":"uint8[2]"},{"indexed": false, "name":"value2", "type":"uint8"}]}]`
- type testStruct struct {
- Value1 [2]uint8
- Value2 uint8
- }
abi, err := JSON(strings.NewReader(definition))
require.NoError(t, err)
var b bytes.Buffer
@@ -109,10 +155,10 @@ func TestEventMultiValueWithArrayUnpack(t *testing.T) {
for ; i <= 3; i++ {
b.Write(packNum(reflect.ValueOf(i)))
}
- var rst testStruct
- require.NoError(t, abi.Unpack(&rst, "test", b.Bytes()))
- require.Equal(t, [2]uint8{1, 2}, rst.Value1)
- require.Equal(t, uint8(3), rst.Value2)
+ unpacked, err := abi.Unpack("test", b.Bytes())
+ require.NoError(t, err)
+ require.Equal(t, [2]uint8{1, 2}, unpacked[0])
+ require.Equal(t, uint8(3), unpacked[1])
}
func TestEventTupleUnpack(t *testing.T) {
@@ -121,6 +167,27 @@ func TestEventTupleUnpack(t *testing.T) {
Value *big.Int
}
+ type EventTransferWithTag struct {
+ // this is valid because `value` is not exported,
+ // so the data is only unmarshalled into `Value1`.
+ value *big.Int //lint:ignore U1000 unused field is part of test
+ Value1 *big.Int `abi:"value"`
+ }
+
+ type BadEventTransferWithSameFieldAndTag struct {
+ Value *big.Int
+ Value1 *big.Int `abi:"value"`
+ }
+
+ type BadEventTransferWithDuplicatedTag struct {
+ Value1 *big.Int `abi:"value"`
+ Value2 *big.Int `abi:"value"`
+ }
+
+ type BadEventTransferWithEmptyTag struct {
+ Value *big.Int `abi:""`
+ }
+
type EventPledge struct {
Who common.Address
Wad *big.Int
@@ -133,9 +200,16 @@ func TestEventTupleUnpack(t *testing.T) {
Currency [3]byte
}
+ type EventMixedCase struct {
+ Value1 *big.Int `abi:"value"`
+ Value2 *big.Int `abi:"_value"`
+ Value3 *big.Int `abi:"Value"`
+ }
+
bigint := new(big.Int)
bigintExpected := big.NewInt(1000000)
bigintExpected2 := big.NewInt(2218516807680)
+ bigintExpected3 := big.NewInt(1000001)
addr := common.HexToAddress("0x00Ce0d46d924CC8437c806721496599FC3FFA268")
var testCases = []struct {
data string
@@ -158,6 +232,34 @@ func TestEventTupleUnpack(t *testing.T) {
jsonEventTransfer,
"",
"Can unpack ERC20 Transfer event into slice",
+ }, {
+ transferData1,
+ &EventTransferWithTag{},
+ &EventTransferWithTag{Value1: bigintExpected},
+ jsonEventTransfer,
+ "",
+ "Can unpack ERC20 Transfer event into structure with abi: tag",
+ }, {
+ transferData1,
+ &BadEventTransferWithDuplicatedTag{},
+ &BadEventTransferWithDuplicatedTag{},
+ jsonEventTransfer,
+ "struct: abi tag in 'Value2' already mapped",
+ "Can not unpack ERC20 Transfer event with duplicated abi tag",
+ }, {
+ transferData1,
+ &BadEventTransferWithSameFieldAndTag{},
+ &BadEventTransferWithSameFieldAndTag{},
+ jsonEventTransfer,
+ "abi: multiple variables maps to the same abi field 'value'",
+ "Can not unpack ERC20 Transfer event with a field and a tag mapping to the same abi variable",
+ }, {
+ transferData1,
+ &BadEventTransferWithEmptyTag{},
+ &BadEventTransferWithEmptyTag{},
+ jsonEventTransfer,
+ "struct: abi tag in 'Value' is empty",
+ "Can not unpack ERC20 Transfer event with an empty tag",
}, {
pledgeData1,
&EventPledge{},
@@ -207,15 +309,22 @@ func TestEventTupleUnpack(t *testing.T) {
&[]interface{}{common.Address{}, new(big.Int)},
&[]interface{}{},
jsonEventPledge,
- "abi: insufficient number of elements in the list/array for unpack, want 3, got 2",
+ "abi: insufficient number of arguments for unpack, want 3, got 2",
"Can not unpack Pledge event into too short slice",
}, {
pledgeData1,
new(map[string]interface{}),
&[]interface{}{},
jsonEventPledge,
- "abi: cannot unmarshal tuple into map[string]interface {}",
+ "abi:[2] cannot unmarshal tuple in to map[string]interface {}",
"Can not unpack Pledge event into map",
+ }, {
+ mixedCaseData1,
+ &EventMixedCase{},
+ &EventMixedCase{Value1: bigintExpected, Value2: bigintExpected2, Value3: bigintExpected3},
+ jsonEventMixedCase,
+ "",
+ "Can unpack abi variables with mixed case",
}}
for _, tc := range testCases {
@@ -227,7 +336,7 @@ func TestEventTupleUnpack(t *testing.T) {
assert.Nil(err, "Should be able to unpack event data.")
assert.Equal(tc.expected, tc.dest, tc.name)
} else {
- assert.EqualError(err, tc.error)
+ assert.EqualError(err, tc.error, tc.name)
}
})
}
@@ -239,48 +348,14 @@ func unpackTestEventData(dest interface{}, hexData string, jsonEvent []byte, ass
var e Event
assert.NoError(json.Unmarshal(jsonEvent, &e), "Should be able to unmarshal event ABI")
a := ABI{Events: map[string]Event{"e": e}}
- return a.Unpack(dest, "e", data)
-}
-
-/*
-Taken from
-https://github.com/tomochain/tomochain/pull/15568
-*/
-
-type testResult struct {
- Values [2]*big.Int
- Value1 *big.Int
- Value2 *big.Int
-}
-
-type testCase struct {
- definition string
- want testResult
-}
-
-func (tc testCase) encoded(intType, arrayType Type) []byte {
- var b bytes.Buffer
- if tc.want.Value1 != nil {
- val, _ := intType.pack(reflect.ValueOf(tc.want.Value1))
- b.Write(val)
- }
-
- if !reflect.DeepEqual(tc.want.Values, [2]*big.Int{nil, nil}) {
- val, _ := arrayType.pack(reflect.ValueOf(tc.want.Values))
- b.Write(val)
- }
- if tc.want.Value2 != nil {
- val, _ := intType.pack(reflect.ValueOf(tc.want.Value2))
- b.Write(val)
- }
- return b.Bytes()
+ return a.UnpackIntoInterface(dest, "e", data)
}
// TestEventUnpackIndexed verifies that indexed field will be skipped by event decoder.
func TestEventUnpackIndexed(t *testing.T) {
definition := `[{"name": "test", "type": "event", "inputs": [{"indexed": true, "name":"value1", "type":"uint8"},{"indexed": false, "name":"value2", "type":"uint8"}]}]`
type testStruct struct {
- Value1 uint8
+ Value1 uint8 // indexed
Value2 uint8
}
abi, err := JSON(strings.NewReader(definition))
@@ -288,16 +363,16 @@ func TestEventUnpackIndexed(t *testing.T) {
var b bytes.Buffer
b.Write(packNum(reflect.ValueOf(uint8(8))))
var rst testStruct
- require.NoError(t, abi.Unpack(&rst, "test", b.Bytes()))
+ require.NoError(t, abi.UnpackIntoInterface(&rst, "test", b.Bytes()))
require.Equal(t, uint8(0), rst.Value1)
require.Equal(t, uint8(8), rst.Value2)
}
-// TestEventIndexedWithArrayUnpack verifies that decoder will not overlow when static array is indexed input.
+// TestEventIndexedWithArrayUnpack verifies that decoder will not overflow when static array is indexed input.
func TestEventIndexedWithArrayUnpack(t *testing.T) {
definition := `[{"name": "test", "type": "event", "inputs": [{"indexed": true, "name":"value1", "type":"uint8[2]"},{"indexed": false, "name":"value2", "type":"string"}]}]`
type testStruct struct {
- Value1 [2]uint8
+ Value1 [2]uint8 // indexed
Value2 string
}
abi, err := JSON(strings.NewReader(definition))
@@ -310,7 +385,7 @@ func TestEventIndexedWithArrayUnpack(t *testing.T) {
b.Write(common.RightPadBytes([]byte(stringOut), 32))
var rst testStruct
- require.NoError(t, abi.Unpack(&rst, "test", b.Bytes()))
+ require.NoError(t, abi.UnpackIntoInterface(&rst, "test", b.Bytes()))
require.Equal(t, [2]uint8{0, 0}, rst.Value1)
require.Equal(t, stringOut, rst.Value2)
}
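The new test cases above exercise `abi:"..."` struct tags together with `UnpackIntoInterface`. A minimal, self-contained sketch of the same tag-based mapping, using an inline ABI definition rather than the test fixtures (names here are illustrative):

```go
package main

import (
	"fmt"
	"math/big"
	"strings"

	"github.com/tomochain/tomochain/accounts/abi"
	"github.com/tomochain/tomochain/common"
)

func main() {
	// One non-indexed uint256 input named "value".
	def := `[{"name":"Transfer","type":"event","inputs":[{"name":"value","type":"uint256"}]}]`
	parsed, err := abi.JSON(strings.NewReader(def))
	if err != nil {
		panic(err)
	}
	// 32-byte big-endian encoding of 1000000.
	data := common.Hex2Bytes("00000000000000000000000000000000000000000000000000000000000f4240")
	out := struct {
		Value1 *big.Int `abi:"value"` // tag routes the "value" field here
	}{}
	if err := parsed.UnpackIntoInterface(&out, "Transfer", data); err != nil {
		panic(err)
	}
	fmt.Println(out.Value1) // 1000000
}
```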
diff --git a/accounts/abi/method.go b/accounts/abi/method.go
index 57a2f0e0a..e2ca38420 100644
--- a/accounts/abi/method.go
+++ b/accounts/abi/method.go
@@ -23,57 +23,146 @@ import (
"github.com/tomochain/tomochain/crypto"
)
+// FunctionType represents different types of functions a contract might have.
+type FunctionType int
+
+const (
+ // Constructor represents the constructor of the contract.
+ // The constructor function is called while deploying a contract.
+ Constructor FunctionType = iota
+ // Fallback represents the fallback function.
+ // This function is executed if no other function matches the given function
+ // signature and no receive function is specified.
+ Fallback
+ // Receive represents the receive function.
+ // This function is executed on plain Ether transfers.
+ Receive
+ // Function represents a normal function.
+ Function
+)
+
// Method represents a callable given a `Name` and whether the method is a constant.
// If the method is `Const` no transaction needs to be created for this
// particular Method call. It can easily be simulated using a local VM.
// For example a `Balance()` method only needs to retrieve something
-// from the storage and therefor requires no Tx to be send to the
+// from the storage and therefore requires no Tx to be sent to the
// network. A method such as `Transact` does require a Tx and thus will
-// be flagged `true`.
+// be flagged `false`.
// Input specifies the required input parameters for this given method.
type Method struct {
+ // Name is the method name used for internal representation. It's derived from
+ // the raw name and a suffix will be added in the case of a function overload.
+ //
+ // e.g.
+ // These are two functions that have the same name:
+ // * foo(int,int)
+ // * foo(uint,uint)
+ // The method name of the first one will be resolved as foo while the second one
+ // will be resolved as foo0.
Name string
- Const bool
+ RawName string // RawName is the raw method name parsed from ABI
+
+ // Type indicates whether the method is a
+ // special fallback introduced in solidity v0.6.0
+ Type FunctionType
+
+ // StateMutability indicates the mutability state of the method;
+ // the default value is nonpayable. It can be empty if the abi
+ // is generated by a legacy compiler.
+ StateMutability string
+
+ // Legacy indicators generated by compilers before v0.6.0.
+ Constant bool
+ Payable bool
+
Inputs Arguments
Outputs Arguments
+ str string
+ // Sig contains the method's string signature according to the ABI spec.
+ // e.g. function foo(uint32 a, int b) = "foo(uint32,int256)"
+ // Please note that "int" is substituted by its canonical representation "int256"
+ Sig string
+ // ID is the canonical representation of the method's signature, used by the
+ // abi definition to identify method names and types.
+ ID []byte
}
-// Sig returns the methods string signature according to the ABI spec.
-//
-// Example
-//
-// function foo(uint32 a, int b) = "foo(uint32,int256)"
-//
-// Please note that "int" is substitute for its canonical representation "int256"
-func (method Method) Sig() string {
- types := make([]string, len(method.Inputs))
- i := 0
- for _, input := range method.Inputs {
+// NewMethod creates a new Method.
+// A method should always be created using NewMethod.
+// It also precomputes the sig representation and the string representation
+// of the method.
+func NewMethod(name string, rawName string, funType FunctionType, mutability string, isConst, isPayable bool, inputs Arguments, outputs Arguments) Method {
+ var (
+ types = make([]string, len(inputs))
+ inputNames = make([]string, len(inputs))
+ outputNames = make([]string, len(outputs))
+ )
+ for i, input := range inputs {
+ inputNames[i] = fmt.Sprintf("%v %v", input.Type, input.Name)
types[i] = input.Type.String()
- i++
- }
- return fmt.Sprintf("%v(%v)", method.Name, strings.Join(types, ","))
-}
-
-func (method Method) String() string {
- inputs := make([]string, len(method.Inputs))
- for i, input := range method.Inputs {
- inputs[i] = fmt.Sprintf("%v %v", input.Name, input.Type)
}
- outputs := make([]string, len(method.Outputs))
- for i, output := range method.Outputs {
+ for i, output := range outputs {
+ outputNames[i] = output.Type.String()
if len(output.Name) > 0 {
- outputs[i] = fmt.Sprintf("%v ", output.Name)
+ outputNames[i] += fmt.Sprintf(" %v", output.Name)
}
- outputs[i] += output.Type.String()
}
- constant := ""
- if method.Const {
- constant = "constant "
+ // calculate the signature and method id. Note that only normal
+ // functions have a meaningful signature and id.
+ var (
+ sig string
+ id []byte
+ )
+ if funType == Function {
+ sig = fmt.Sprintf("%v(%v)", rawName, strings.Join(types, ","))
+ id = crypto.Keccak256([]byte(sig))[:4]
+ }
+ // Extract the meaningful state mutability of the solidity method.
+ // If it's the default value, omit it from the string representation.
+ state := mutability
+ if state == "nonpayable" {
+ state = ""
+ }
+ if state != "" {
+ state = state + " "
+ }
+ identity := fmt.Sprintf("function %v", rawName)
+ switch funType {
+ case Fallback:
+ identity = "fallback"
+ case Receive:
+ identity = "receive"
+ case Constructor:
+ identity = "constructor"
+ }
+ str := fmt.Sprintf("%v(%v) %sreturns(%v)", identity, strings.Join(inputNames, ", "), state, strings.Join(outputNames, ", "))
+
+ return Method{
+ Name: name,
+ RawName: rawName,
+ Type: funType,
+ StateMutability: mutability,
+ Constant: isConst,
+ Payable: isPayable,
+ Inputs: inputs,
+ Outputs: outputs,
+ str: str,
+ Sig: sig,
+ ID: id,
}
- return fmt.Sprintf("function %v(%v) %sreturns(%v)", method.Name, strings.Join(inputs, ", "), constant, strings.Join(outputs, ", "))
}
-func (method Method) Id() []byte {
- return crypto.Keccak256([]byte(method.Sig()))[:4]
+func (method Method) String() string {
+ return method.str
+}
+
+// IsConstant reports whether the method is read-only.
+func (method Method) IsConstant() bool {
+ return method.StateMutability == "view" || method.StateMutability == "pure" || method.Constant
+}
+
+// IsPayable reports whether the method can process
+// plain ether transfers.
+func (method Method) IsPayable() bool {
+ return method.StateMutability == "payable" || method.Payable
}
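As the `NewMethod` body shows, only normal functions get a selector: the first four bytes of the Keccak256 hash of the canonical signature, while fallback, receive, and constructor get an empty `ID`. A minimal sketch of the same derivation:

```go
package main

import (
	"fmt"

	"github.com/tomochain/tomochain/crypto"
)

func main() {
	// Canonical signature: raw name plus canonical types, no spaces.
	sig := "transfer(address,uint256)"
	id := crypto.Keccak256([]byte(sig))[:4]
	fmt.Printf("%x\n", id) // a9059cbb, the well-known ERC20 transfer selector
}
```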
diff --git a/accounts/abi/numbers.go b/accounts/abi/numbers.go
index 3d541ee9a..491b94d34 100644
--- a/accounts/abi/numbers.go
+++ b/accounts/abi/numbers.go
@@ -25,35 +25,20 @@ import (
)
var (
- big_t = reflect.TypeOf(&big.Int{})
- derefbig_t = reflect.TypeOf(big.Int{})
- uint8_t = reflect.TypeOf(uint8(0))
- uint16_t = reflect.TypeOf(uint16(0))
- uint32_t = reflect.TypeOf(uint32(0))
- uint64_t = reflect.TypeOf(uint64(0))
- int_t = reflect.TypeOf(int(0))
- int8_t = reflect.TypeOf(int8(0))
- int16_t = reflect.TypeOf(int16(0))
- int32_t = reflect.TypeOf(int32(0))
- int64_t = reflect.TypeOf(int64(0))
- address_t = reflect.TypeOf(common.Address{})
- int_ts = reflect.TypeOf([]int(nil))
- int8_ts = reflect.TypeOf([]int8(nil))
- int16_ts = reflect.TypeOf([]int16(nil))
- int32_ts = reflect.TypeOf([]int32(nil))
- int64_ts = reflect.TypeOf([]int64(nil))
+ bigT = reflect.TypeOf(&big.Int{})
+ derefbigT = reflect.TypeOf(big.Int{})
+ uint8T = reflect.TypeOf(uint8(0))
+ uint16T = reflect.TypeOf(uint16(0))
+ uint32T = reflect.TypeOf(uint32(0))
+ uint64T = reflect.TypeOf(uint64(0))
+ int8T = reflect.TypeOf(int8(0))
+ int16T = reflect.TypeOf(int16(0))
+ int32T = reflect.TypeOf(int32(0))
+ int64T = reflect.TypeOf(int64(0))
+ addressT = reflect.TypeOf(common.Address{})
)
// U256 converts a big Int into a 256bit EVM number.
func U256(n *big.Int) []byte {
return math.PaddedBigBytes(math.U256(n), 32)
}
-
-// checks whether the given reflect value is signed. This also works for slices with a number type
-func isSigned(v reflect.Value) bool {
- switch v.Type() {
- case int_ts, int8_ts, int16_ts, int32_ts, int64_ts, int_t, int8_t, int16_t, int32_t, int64_t:
- return true
- }
- return false
-}
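The exported `U256` helper kept above reduces a `big.Int` into the 256-bit two's-complement range and left-pads it to 32 bytes. A small usage sketch (note that the underlying `math.U256` reduces in place, so fresh values are used here):

```go
package main

import (
	"fmt"
	"math/big"

	"github.com/tomochain/tomochain/accounts/abi"
)

func main() {
	// -1 becomes thirty-two 0xff bytes in two's complement.
	fmt.Printf("%x\n", abi.U256(big.NewInt(-1)))
	// 2 becomes a left-padded 32-byte word ending in 02.
	fmt.Printf("%x\n", abi.U256(big.NewInt(2)))
}
```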
diff --git a/accounts/abi/numbers_test.go b/accounts/abi/numbers_test.go
index b9ff5aef1..d25a5abcb 100644
--- a/accounts/abi/numbers_test.go
+++ b/accounts/abi/numbers_test.go
@@ -19,7 +19,6 @@ package abi
import (
"bytes"
"math/big"
- "reflect"
"testing"
)
@@ -32,13 +31,3 @@ func TestNumberTypes(t *testing.T) {
t.Errorf("expected %x got %x", ubytes, unsigned)
}
}
-
-func TestSigned(t *testing.T) {
- if isSigned(reflect.ValueOf(uint(10))) {
- t.Error("signed")
- }
-
- if !isSigned(reflect.ValueOf(int(10))) {
- t.Error("not signed")
- }
-}
diff --git a/accounts/abi/pack.go b/accounts/abi/pack.go
index 7d422f579..5d8b86edb 100644
--- a/accounts/abi/pack.go
+++ b/accounts/abi/pack.go
@@ -17,6 +17,8 @@
package abi
import (
+ "errors"
+ "fmt"
"math/big"
"reflect"
@@ -25,7 +27,7 @@ import (
)
// packBytesSlice packs the given bytes as [L, V] as the canonical representation
-// bytes slice
+// bytes slice.
func packBytesSlice(bytes []byte, l int) []byte {
len := packNum(reflect.ValueOf(l))
return append(len, common.RightPadBytes(bytes, (l+31)/32*32)...)
@@ -33,49 +35,51 @@ func packBytesSlice(bytes []byte, l int) []byte {
// packElement packs the given reflect value according to the abi specification in
// t.
-func packElement(t Type, reflectValue reflect.Value) []byte {
+func packElement(t Type, reflectValue reflect.Value) ([]byte, error) {
switch t.T {
case IntTy, UintTy:
- return packNum(reflectValue)
+ return packNum(reflectValue), nil
case StringTy:
- return packBytesSlice([]byte(reflectValue.String()), reflectValue.Len())
+ return packBytesSlice([]byte(reflectValue.String()), reflectValue.Len()), nil
case AddressTy:
if reflectValue.Kind() == reflect.Array {
reflectValue = mustArrayToByteSlice(reflectValue)
}
- return common.LeftPadBytes(reflectValue.Bytes(), 32)
+ return common.LeftPadBytes(reflectValue.Bytes(), 32), nil
case BoolTy:
if reflectValue.Bool() {
- return math.PaddedBigBytes(common.Big1, 32)
+ return math.PaddedBigBytes(common.Big1, 32), nil
}
- return math.PaddedBigBytes(common.Big0, 32)
+ return math.PaddedBigBytes(common.Big0, 32), nil
case BytesTy:
if reflectValue.Kind() == reflect.Array {
reflectValue = mustArrayToByteSlice(reflectValue)
}
- return packBytesSlice(reflectValue.Bytes(), reflectValue.Len())
+ if reflectValue.Type() != reflect.TypeOf([]byte{}) {
+ return []byte{}, errors.New("Bytes type is neither slice nor array")
+ }
+ return packBytesSlice(reflectValue.Bytes(), reflectValue.Len()), nil
case FixedBytesTy, FunctionTy:
if reflectValue.Kind() == reflect.Array {
reflectValue = mustArrayToByteSlice(reflectValue)
}
- return common.RightPadBytes(reflectValue.Bytes(), 32)
+ return common.RightPadBytes(reflectValue.Bytes(), 32), nil
default:
- panic("abi: fatal error")
+ return []byte{}, fmt.Errorf("Could not pack element, unknown type: %v", t.T)
}
}
-// packNum packs the given number (using the reflect value) and will cast it to appropriate number representation
+// packNum packs the given number (using the reflected value) and will cast it to appropriate number representation.
func packNum(value reflect.Value) []byte {
switch kind := value.Kind(); kind {
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- return U256(new(big.Int).SetUint64(value.Uint()))
+ return math.U256Bytes(new(big.Int).SetUint64(value.Uint()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- return U256(big.NewInt(value.Int()))
+ return math.U256Bytes(big.NewInt(value.Int()))
case reflect.Ptr:
- return U256(value.Interface().(*big.Int))
+ return math.U256Bytes(new(big.Int).Set(value.Interface().(*big.Int)))
default:
panic("abi: fatal error")
}
-
}
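`packNum` now goes through `math.U256Bytes`, and the pointer case copies the caller's `big.Int` before reduction, so the original value is no longer mutated. A sketch of that pattern, assuming `math.U256Bytes` from `common/math` as used in the hunk above:

```go
package main

import (
	"fmt"
	"math/big"

	"github.com/tomochain/tomochain/common/math"
)

func main() {
	n := big.NewInt(-2)
	// Copy before reducing, mirroring the new packNum pointer case.
	packed := math.U256Bytes(new(big.Int).Set(n))
	fmt.Printf("%x\n", packed) // ffff...fe, 32-byte two's complement
	fmt.Println(n)             // still -2: the caller's value is untouched
}
```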
diff --git a/accounts/abi/pack_test.go b/accounts/abi/pack_test.go
index be48cb5b1..ed5585b11 100644
--- a/accounts/abi/pack_test.go
+++ b/accounts/abi/pack_test.go
@@ -18,336 +18,51 @@ package abi
import (
"bytes"
+ "encoding/hex"
+ "fmt"
"math"
"math/big"
"reflect"
+ "strconv"
"strings"
"testing"
"github.com/tomochain/tomochain/common"
)
+// TestPack runs the general pack/unpack test cases defined in packing_test.go.
func TestPack(t *testing.T) {
- for i, test := range []struct {
- typ string
-
- input interface{}
- output []byte
- }{
- {
- "uint8",
- uint8(2),
- common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "uint8[]",
- []uint8{1, 2},
- common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "uint16",
- uint16(2),
- common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "uint16[]",
- []uint16{1, 2},
- common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "uint32",
- uint32(2),
- common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "uint32[]",
- []uint32{1, 2},
- common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "uint64",
- uint64(2),
- common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "uint64[]",
- []uint64{1, 2},
- common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "uint256",
- big.NewInt(2),
- common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "uint256[]",
- []*big.Int{big.NewInt(1), big.NewInt(2)},
- common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "int8",
- int8(2),
- common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "int8[]",
- []int8{1, 2},
- common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "int16",
- int16(2),
- common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "int16[]",
- []int16{1, 2},
- common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "int32",
- int32(2),
- common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "int32[]",
- []int32{1, 2},
- common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "int64",
- int64(2),
- common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "int64[]",
- []int64{1, 2},
- common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "int256",
- big.NewInt(2),
- common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "int256[]",
- []*big.Int{big.NewInt(1), big.NewInt(2)},
- common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
- },
- {
- "bytes1",
- [1]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes2",
- [2]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes3",
- [3]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes4",
- [4]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes5",
- [5]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes6",
- [6]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes7",
- [7]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes8",
- [8]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes9",
- [9]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes10",
- [10]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes11",
- [11]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes12",
- [12]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes13",
- [13]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes14",
- [14]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes15",
- [15]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes16",
- [16]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes17",
- [17]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes18",
- [18]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes19",
- [19]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes20",
- [20]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes21",
- [21]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes22",
- [22]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes23",
- [23]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes24",
- [24]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes24",
- [24]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes25",
- [25]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes26",
- [26]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes27",
- [27]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes28",
- [28]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes29",
- [29]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes30",
- [30]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes31",
- [31]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "bytes32",
- [32]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "uint32[2][3][4]",
- [4][3][2]uint32{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}, {{13, 14}, {15, 16}, {17, 18}}, {{19, 20}, {21, 22}, {23, 24}}},
- common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000000050000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000700000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000d000000000000000000000000000000000000000000000000000000000000000e000000000000000000000000000000000000000000000000000000000000000f000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000110000000000000000000000000000000000000000000000000000000000000012000000000000000000000000000000000000000000000000000000000000001300000000000000000000000000000000000000000000000000000000000000140000000000000000000000000000000000000000000000000000000000000015000000000000000000000000000000000000000000000000000000000000001600000000000000000000000000000000000000000000000000000000000000170000000000000000000000000000000000000000000000000000000000000018"),
- },
- {
- "address[]",
- []common.Address{{1}, {2}},
- common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000001000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000"),
- },
- {
- "bytes32[]",
- []common.Hash{{1}, {2}},
- common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000201000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "function",
- [24]byte{1},
- common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
- {
- "string",
- "foobar",
- common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000006666f6f6261720000000000000000000000000000000000000000000000000000"),
- },
- } {
- typ, err := NewType(test.typ)
- if err != nil {
- t.Fatalf("%v failed. Unexpected parse error: %v", i, err)
- }
-
- output, err := typ.pack(reflect.ValueOf(test.input))
- if err != nil {
- t.Fatalf("%v failed. Unexpected pack error: %v", i, err)
- }
-
- if !bytes.Equal(output, test.output) {
- t.Errorf("%d failed. Expected bytes: '%x' Got: '%x'", i, test.output, output)
- }
+ for i, test := range packUnpackTests {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ encb, err := hex.DecodeString(test.packed)
+ if err != nil {
+ t.Fatalf("invalid hex %s: %v", test.packed, err)
+ }
+ inDef := fmt.Sprintf(`[{ "name" : "method", "type": "function", "inputs": %s}]`, test.def)
+ inAbi, err := JSON(strings.NewReader(inDef))
+ if err != nil {
+ t.Fatalf("invalid ABI definition %s, %v", inDef, err)
+ }
+ var packed []byte
+ packed, err = inAbi.Pack("method", test.unpacked)
+
+ if err != nil {
+ t.Fatalf("test %d (%v) failed: %v", i, test.def, err)
+ }
+ if !reflect.DeepEqual(packed[4:], encb) {
+ t.Errorf("test %d (%v) failed: expected %v, got %v", i, test.def, encb, packed[4:])
+ }
+ })
}
}
func TestMethodPack(t *testing.T) {
- abi, err := JSON(strings.NewReader(jsondata2))
+ abi, err := JSON(strings.NewReader(jsondata))
if err != nil {
t.Fatal(err)
}
- sig := abi.Methods["slice"].Id()
+ sig := abi.Methods["slice"].ID
sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...)
sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
@@ -361,7 +76,7 @@ func TestMethodPack(t *testing.T) {
}
var addrA, addrB = common.Address{1}, common.Address{2}
- sig = abi.Methods["sliceAddress"].Id()
+ sig = abi.Methods["sliceAddress"].ID
sig = append(sig, common.LeftPadBytes([]byte{32}, 32)...)
sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
sig = append(sig, common.LeftPadBytes(addrA[:], 32)...)
@@ -376,7 +91,7 @@ func TestMethodPack(t *testing.T) {
}
var addrC, addrD = common.Address{3}, common.Address{4}
- sig = abi.Methods["sliceMultiAddress"].Id()
+ sig = abi.Methods["sliceMultiAddress"].ID
sig = append(sig, common.LeftPadBytes([]byte{64}, 32)...)
sig = append(sig, common.LeftPadBytes([]byte{160}, 32)...)
sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
@@ -394,7 +109,7 @@ func TestMethodPack(t *testing.T) {
t.Errorf("expected %x got %x", sig, packed)
}
- sig = abi.Methods["slice256"].Id()
+ sig = abi.Methods["slice256"].ID
sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...)
sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
@@ -406,6 +121,59 @@ func TestMethodPack(t *testing.T) {
if !bytes.Equal(packed, sig) {
t.Errorf("expected %x got %x", sig, packed)
}
+
+ a := [2][2]*big.Int{{big.NewInt(1), big.NewInt(1)}, {big.NewInt(2), big.NewInt(0)}}
+ sig = abi.Methods["nestedArray"].ID
+ sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{0}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{0xa0}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
+ sig = append(sig, common.LeftPadBytes(addrC[:], 32)...)
+ sig = append(sig, common.LeftPadBytes(addrD[:], 32)...)
+ packed, err = abi.Pack("nestedArray", a, []common.Address{addrC, addrD})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(packed, sig) {
+ t.Errorf("expected %x got %x", sig, packed)
+ }
+
+ sig = abi.Methods["nestedArray2"].ID
+ sig = append(sig, common.LeftPadBytes([]byte{0x20}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{0x40}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{0x80}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...)
+ packed, err = abi.Pack("nestedArray2", [2][]uint8{{1}, {1}})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(packed, sig) {
+ t.Errorf("expected %x got %x", sig, packed)
+ }
+
+ sig = abi.Methods["nestedSlice"].ID
+ sig = append(sig, common.LeftPadBytes([]byte{0x20}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{0x02}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{0x40}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{0xa0}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...)
+ sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...)
+ packed, err = abi.Pack("nestedSlice", [][]uint8{{1, 2}, {1, 2}})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(packed, sig) {
+ t.Errorf("expected %x got %x", sig, packed)
+ }
}
func TestPackNumber(t *testing.T) {
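The `nestedSlice` expectation built above follows the standard head/tail layout for dynamic types. A minimal sketch that reproduces the same bytes with an inline ABI definition (not the test's `jsondata` fixture):

```go
package main

import (
	"fmt"
	"strings"

	"github.com/tomochain/tomochain/accounts/abi"
)

func main() {
	def := `[{"name":"nestedSlice","type":"function","inputs":[{"name":"a","type":"uint8[][]"}]}]`
	parsed, err := abi.JSON(strings.NewReader(def))
	if err != nil {
		panic(err)
	}
	packed, err := parsed.Pack("nestedSlice", [][]uint8{{1, 2}, {1, 2}})
	if err != nil {
		panic(err)
	}
	// After the 4-byte selector: 0x20 (offset to the outer slice), then
	// 0x02 (outer length), then 0x40 and 0xa0 (offsets of the two inner
	// slices relative to the start of the outer tail), then each inner
	// slice as a length word followed by its padded elements.
	fmt.Printf("%x\n", packed[4:])
}
```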
diff --git a/accounts/abi/packing_test.go b/accounts/abi/packing_test.go
new file mode 100644
index 000000000..bdf00273a
--- /dev/null
+++ b/accounts/abi/packing_test.go
@@ -0,0 +1,990 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package abi
+
+import (
+ "math/big"
+
+ "github.com/tomochain/tomochain/common"
+)
+
+type packUnpackTest struct {
+ def string
+ unpacked interface{}
+ packed string
+}
+
+var packUnpackTests = []packUnpackTest{
+ // Booleans
+ {
+ def: `[{ "type": "bool" }]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001",
+ unpacked: true,
+ },
+ {
+ def: `[{ "type": "bool" }]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000000",
+ unpacked: false,
+ },
+ // Integers
+ {
+ def: `[{ "type": "uint8" }]`,
+ unpacked: uint8(2),
+ packed: "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{ "type": "uint8[]" }]`,
+ unpacked: []uint8{1, 2},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{ "type": "uint16" }]`,
+ unpacked: uint16(2),
+ packed: "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{ "type": "uint16[]" }]`,
+ unpacked: []uint16{1, 2},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "uint17"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001",
+ unpacked: big.NewInt(1),
+ },
+ {
+ def: `[{"type": "uint32"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001",
+ unpacked: uint32(1),
+ },
+ {
+ def: `[{"type": "uint32[]"}]`,
+ unpacked: []uint32{1, 2},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "uint64"}]`,
+ unpacked: uint64(2),
+ packed: "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "uint64[]"}]`,
+ unpacked: []uint64{1, 2},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "uint256"}]`,
+ unpacked: big.NewInt(2),
+ packed: "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "uint256[]"}]`,
+ unpacked: []*big.Int{big.NewInt(1), big.NewInt(2)},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "int8"}]`,
+ unpacked: int8(2),
+ packed: "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "int8[]"}]`,
+ unpacked: []int8{1, 2},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "int16"}]`,
+ unpacked: int16(2),
+ packed: "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "int16[]"}]`,
+ unpacked: []int16{1, 2},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "int17"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001",
+ unpacked: big.NewInt(1),
+ },
+ {
+ def: `[{"type": "int32"}]`,
+ unpacked: int32(2),
+ packed: "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "int32"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001",
+ unpacked: int32(1),
+ },
+ {
+ def: `[{"type": "int32[]"}]`,
+ unpacked: []int32{1, 2},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "int64"}]`,
+ unpacked: int64(2),
+ packed: "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "int64[]"}]`,
+ unpacked: []int64{1, 2},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "int256"}]`,
+ unpacked: big.NewInt(2),
+ packed: "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ {
+ def: `[{"type": "int256"}]`,
+ packed: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
+ unpacked: big.NewInt(-1),
+ },
+ {
+ def: `[{"type": "int256[]"}]`,
+ unpacked: []*big.Int{big.NewInt(1), big.NewInt(2)},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ },
+ // Address
+ {
+ def: `[{"type": "address"}]`,
+ packed: "0000000000000000000000000100000000000000000000000000000000000000",
+ unpacked: common.Address{1},
+ },
+ {
+ def: `[{"type": "address[]"}]`,
+ unpacked: []common.Address{{1}, {2}},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000100000000000000000000000000000000000000" +
+ "0000000000000000000000000200000000000000000000000000000000000000",
+ },
+ // Bytes
+ {
+ def: `[{"type": "bytes1"}]`,
+ unpacked: [1]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes2"}]`,
+ unpacked: [2]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes3"}]`,
+ unpacked: [3]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes4"}]`,
+ unpacked: [4]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes5"}]`,
+ unpacked: [5]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes6"}]`,
+ unpacked: [6]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes7"}]`,
+ unpacked: [7]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes8"}]`,
+ unpacked: [8]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes9"}]`,
+ unpacked: [9]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes10"}]`,
+ unpacked: [10]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes11"}]`,
+ unpacked: [11]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes12"}]`,
+ unpacked: [12]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes13"}]`,
+ unpacked: [13]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes14"}]`,
+ unpacked: [14]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes15"}]`,
+ unpacked: [15]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes16"}]`,
+ unpacked: [16]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes17"}]`,
+ unpacked: [17]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes18"}]`,
+ unpacked: [18]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes19"}]`,
+ unpacked: [19]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes20"}]`,
+ unpacked: [20]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes21"}]`,
+ unpacked: [21]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes22"}]`,
+ unpacked: [22]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes23"}]`,
+ unpacked: [23]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes24"}]`,
+ unpacked: [24]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes25"}]`,
+ unpacked: [25]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes26"}]`,
+ unpacked: [26]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes27"}]`,
+ unpacked: [27]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes28"}]`,
+ unpacked: [28]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes29"}]`,
+ unpacked: [29]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes30"}]`,
+ unpacked: [30]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes31"}]`,
+ unpacked: [31]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes32"}]`,
+ unpacked: [32]byte{1},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "bytes32"}]`,
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ unpacked: [32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ },
+ {
+ def: `[{"type": "bytes"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0100000000000000000000000000000000000000000000000000000000000000",
+ unpacked: common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
+ },
+ // Functions
+ {
+ def: `[{"type": "function"}]`,
+ packed: "0100000000000000000000000000000000000000000000000000000000000000",
+ unpacked: [24]byte{1},
+ },
+ // Slice and Array
+ {
+ def: `[{"type": "uint8[]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: []uint8{1, 2},
+ },
+ {
+ def: `[{"type": "uint8[]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000000",
+ unpacked: []uint8{},
+ },
+ {
+ def: `[{"type": "uint256[]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000000",
+ unpacked: []*big.Int{},
+ },
+ {
+ def: `[{"type": "uint8[2]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: [2]uint8{1, 2},
+ },
+ {
+ def: `[{"type": "int8[2]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: [2]int8{1, 2},
+ },
+ {
+ def: `[{"type": "int16[]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: []int16{1, 2},
+ },
+ {
+ def: `[{"type": "int16[2]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: [2]int16{1, 2},
+ },
+ {
+ def: `[{"type": "int32[]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: []int32{1, 2},
+ },
+ {
+ def: `[{"type": "int32[2]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: [2]int32{1, 2},
+ },
+ {
+ def: `[{"type": "int64[]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: []int64{1, 2},
+ },
+ {
+ def: `[{"type": "int64[2]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: [2]int64{1, 2},
+ },
+ {
+ def: `[{"type": "int256[]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: []*big.Int{big.NewInt(1), big.NewInt(2)},
+ },
+ {
+ def: `[{"type": "int256[3]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000003",
+ unpacked: [3]*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3)},
+ },
+ // multi-dimensional; if these pass, all types that don't require a length prefix should pass
+ {
+ def: `[{"type": "uint8[][]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000000",
+ unpacked: [][]uint8{},
+ },
+ {
+ def: `[{"type": "uint8[][]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000040" +
+ "00000000000000000000000000000000000000000000000000000000000000a0" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: [][]uint8{{1, 2}, {1, 2}},
+ },
+ {
+ def: `[{"type": "uint8[][]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000040" +
+ "00000000000000000000000000000000000000000000000000000000000000a0" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000003" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000003",
+ unpacked: [][]uint8{{1, 2}, {1, 2, 3}},
+ },
+ {
+ def: `[{"type": "uint8[2][2]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: [2][2]uint8{{1, 2}, {1, 2}},
+ },
+ {
+ def: `[{"type": "uint8[][2]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000040" +
+ "0000000000000000000000000000000000000000000000000000000000000060" +
+ "0000000000000000000000000000000000000000000000000000000000000000" +
+ "0000000000000000000000000000000000000000000000000000000000000000",
+ unpacked: [2][]uint8{{}, {}},
+ },
+ {
+ def: `[{"type": "uint8[][2]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000040" +
+ "0000000000000000000000000000000000000000000000000000000000000080" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000001",
+ unpacked: [2][]uint8{{1}, {1}},
+ },
+ {
+ def: `[{"type": "uint8[2][]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000000",
+ unpacked: [][2]uint8{},
+ },
+ {
+ def: `[{"type": "uint8[2][]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: [][2]uint8{{1, 2}},
+ },
+ {
+ def: `[{"type": "uint8[2][]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: [][2]uint8{{1, 2}, {1, 2}},
+ },
+ {
+ def: `[{"type": "uint16[]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: []uint16{1, 2},
+ },
+ {
+ def: `[{"type": "uint16[2]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: [2]uint16{1, 2},
+ },
+ {
+ def: `[{"type": "uint32[]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: []uint32{1, 2},
+ },
+ {
+ def: `[{"type": "uint32[2][3][4]"}]`,
+ unpacked: [4][3][2]uint32{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}, {{13, 14}, {15, 16}, {17, 18}}, {{19, 20}, {21, 22}, {23, 24}}},
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000003" +
+ "0000000000000000000000000000000000000000000000000000000000000004" +
+ "0000000000000000000000000000000000000000000000000000000000000005" +
+ "0000000000000000000000000000000000000000000000000000000000000006" +
+ "0000000000000000000000000000000000000000000000000000000000000007" +
+ "0000000000000000000000000000000000000000000000000000000000000008" +
+ "0000000000000000000000000000000000000000000000000000000000000009" +
+ "000000000000000000000000000000000000000000000000000000000000000a" +
+ "000000000000000000000000000000000000000000000000000000000000000b" +
+ "000000000000000000000000000000000000000000000000000000000000000c" +
+ "000000000000000000000000000000000000000000000000000000000000000d" +
+ "000000000000000000000000000000000000000000000000000000000000000e" +
+ "000000000000000000000000000000000000000000000000000000000000000f" +
+ "0000000000000000000000000000000000000000000000000000000000000010" +
+ "0000000000000000000000000000000000000000000000000000000000000011" +
+ "0000000000000000000000000000000000000000000000000000000000000012" +
+ "0000000000000000000000000000000000000000000000000000000000000013" +
+ "0000000000000000000000000000000000000000000000000000000000000014" +
+ "0000000000000000000000000000000000000000000000000000000000000015" +
+ "0000000000000000000000000000000000000000000000000000000000000016" +
+ "0000000000000000000000000000000000000000000000000000000000000017" +
+ "0000000000000000000000000000000000000000000000000000000000000018",
+ },
+
+ {
+ def: `[{"type": "bytes32[]"}]`,
+ unpacked: [][32]byte{{1}, {2}},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0100000000000000000000000000000000000000000000000000000000000000" +
+ "0200000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "uint32[2]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: [2]uint32{1, 2},
+ },
+ {
+ def: `[{"type": "uint64[]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: []uint64{1, 2},
+ },
+ {
+ def: `[{"type": "uint64[2]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: [2]uint64{1, 2},
+ },
+ {
+ def: `[{"type": "uint256[]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: []*big.Int{big.NewInt(1), big.NewInt(2)},
+ },
+ {
+ def: `[{"type": "uint256[3]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000003",
+ unpacked: [3]*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3)},
+ },
+ {
+ def: `[{"type": "string[4]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000080" +
+ "00000000000000000000000000000000000000000000000000000000000000c0" +
+ "0000000000000000000000000000000000000000000000000000000000000100" +
+ "0000000000000000000000000000000000000000000000000000000000000140" +
+ "0000000000000000000000000000000000000000000000000000000000000005" +
+ "48656c6c6f000000000000000000000000000000000000000000000000000000" +
+ "0000000000000000000000000000000000000000000000000000000000000005" +
+ "576f726c64000000000000000000000000000000000000000000000000000000" +
+ "000000000000000000000000000000000000000000000000000000000000000b" +
+ "476f2d657468657265756d000000000000000000000000000000000000000000" +
+ "0000000000000000000000000000000000000000000000000000000000000008" +
+ "457468657265756d000000000000000000000000000000000000000000000000",
+ unpacked: [4]string{"Hello", "World", "Go-ethereum", "Ethereum"},
+ },
+ {
+ def: `[{"type": "string[]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000040" +
+ "0000000000000000000000000000000000000000000000000000000000000080" +
+ "0000000000000000000000000000000000000000000000000000000000000008" +
+ "457468657265756d000000000000000000000000000000000000000000000000" +
+ "000000000000000000000000000000000000000000000000000000000000000b" +
+ "676f2d657468657265756d000000000000000000000000000000000000000000",
+ unpacked: []string{"Ethereum", "go-ethereum"},
+ },
+ {
+ def: `[{"type": "bytes[]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000040" +
+ "0000000000000000000000000000000000000000000000000000000000000080" +
+ "0000000000000000000000000000000000000000000000000000000000000003" +
+ "f0f0f00000000000000000000000000000000000000000000000000000000000" +
+ "0000000000000000000000000000000000000000000000000000000000000003" +
+ "f0f0f00000000000000000000000000000000000000000000000000000000000",
+ unpacked: [][]byte{{0xf0, 0xf0, 0xf0}, {0xf0, 0xf0, 0xf0}},
+ },
+ {
+ def: `[{"type": "uint256[2][][]"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000040" +
+ "00000000000000000000000000000000000000000000000000000000000000e0" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "00000000000000000000000000000000000000000000000000000000000000c8" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "00000000000000000000000000000000000000000000000000000000000003e8" +
+ "0000000000000000000000000000000000000000000000000000000000000002" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "00000000000000000000000000000000000000000000000000000000000000c8" +
+ "0000000000000000000000000000000000000000000000000000000000000001" +
+ "00000000000000000000000000000000000000000000000000000000000003e8",
+ unpacked: [][][2]*big.Int{{{big.NewInt(1), big.NewInt(200)}, {big.NewInt(1), big.NewInt(1000)}}, {{big.NewInt(1), big.NewInt(200)}, {big.NewInt(1), big.NewInt(1000)}}},
+ },
+ // struct outputs
+ {
+ def: `[{"components": [{"name":"int1","type":"int256"},{"name":"int2","type":"int256"}], "type":"tuple"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: struct {
+ Int1 *big.Int
+ Int2 *big.Int
+ }{big.NewInt(1), big.NewInt(2)},
+ },
+ {
+ def: `[{"components": [{"name":"int_one","type":"int256"}], "type":"tuple"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001",
+ unpacked: struct {
+ IntOne *big.Int
+ }{big.NewInt(1)},
+ },
+ {
+ def: `[{"components": [{"name":"int__one","type":"int256"}], "type":"tuple"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001",
+ unpacked: struct {
+ IntOne *big.Int
+ }{big.NewInt(1)},
+ },
+ {
+ def: `[{"components": [{"name":"int_one_","type":"int256"}], "type":"tuple"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001",
+ unpacked: struct {
+ IntOne *big.Int
+ }{big.NewInt(1)},
+ },
+ {
+ def: `[{"components": [{"name":"int_one","type":"int256"}, {"name":"intone","type":"int256"}], "type":"tuple"}]`,
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" +
+ "0000000000000000000000000000000000000000000000000000000000000002",
+ unpacked: struct {
+ IntOne *big.Int
+ Intone *big.Int
+ }{big.NewInt(1), big.NewInt(2)},
+ },
+ {
+ def: `[{"type": "string"}]`,
+ unpacked: "foobar",
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000006" +
+ "666f6f6261720000000000000000000000000000000000000000000000000000",
+ },
+ {
+ def: `[{"type": "string[]"}]`,
+ unpacked: []string{"hello", "foobar"},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" + // len(array) = 2
+ "0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0
+ "0000000000000000000000000000000000000000000000000000000000000080" + // offset 128 to i = 1
+ "0000000000000000000000000000000000000000000000000000000000000005" + // len(str[0]) = 5
+ "68656c6c6f000000000000000000000000000000000000000000000000000000" + // str[0]
+ "0000000000000000000000000000000000000000000000000000000000000006" + // len(str[1]) = 6
+ "666f6f6261720000000000000000000000000000000000000000000000000000", // str[1]
+ },
+ {
+ def: `[{"type": "string[2]"}]`,
+ unpacked: [2]string{"hello", "foobar"},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000040" + // offset to i = 0
+ "0000000000000000000000000000000000000000000000000000000000000080" + // offset to i = 1
+ "0000000000000000000000000000000000000000000000000000000000000005" + // len(str[0]) = 5
+ "68656c6c6f000000000000000000000000000000000000000000000000000000" + // str[0]
+ "0000000000000000000000000000000000000000000000000000000000000006" + // len(str[1]) = 6
+ "666f6f6261720000000000000000000000000000000000000000000000000000", // str[1]
+ },
+ {
+ def: `[{"type": "bytes32[][]"}]`,
+ unpacked: [][][32]byte{{{1}, {2}}, {{3}, {4}, {5}}},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" + // len(array) = 2
+ "0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0
+ "00000000000000000000000000000000000000000000000000000000000000a0" + // offset 160 to i = 1
+ "0000000000000000000000000000000000000000000000000000000000000002" + // len(array[0]) = 2
+ "0100000000000000000000000000000000000000000000000000000000000000" + // array[0][0]
+ "0200000000000000000000000000000000000000000000000000000000000000" + // array[0][1]
+ "0000000000000000000000000000000000000000000000000000000000000003" + // len(array[1]) = 3
+ "0300000000000000000000000000000000000000000000000000000000000000" + // array[1][0]
+ "0400000000000000000000000000000000000000000000000000000000000000" + // array[1][1]
+ "0500000000000000000000000000000000000000000000000000000000000000", // array[1][2]
+ },
+ {
+ def: `[{"type": "bytes32[][2]"}]`,
+ unpacked: [2][][32]byte{{{1}, {2}}, {{3}, {4}, {5}}},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0
+ "00000000000000000000000000000000000000000000000000000000000000a0" + // offset 160 to i = 1
+ "0000000000000000000000000000000000000000000000000000000000000002" + // len(array[0]) = 2
+ "0100000000000000000000000000000000000000000000000000000000000000" + // array[0][0]
+ "0200000000000000000000000000000000000000000000000000000000000000" + // array[0][1]
+ "0000000000000000000000000000000000000000000000000000000000000003" + // len(array[1]) = 3
+ "0300000000000000000000000000000000000000000000000000000000000000" + // array[1][0]
+ "0400000000000000000000000000000000000000000000000000000000000000" + // array[1][1]
+ "0500000000000000000000000000000000000000000000000000000000000000", // array[1][2]
+ },
+ {
+ def: `[{"type": "bytes32[3][2]"}]`,
+ unpacked: [2][3][32]byte{{{1}, {2}, {3}}, {{3}, {4}, {5}}},
+ packed: "0100000000000000000000000000000000000000000000000000000000000000" + // array[0][0]
+ "0200000000000000000000000000000000000000000000000000000000000000" + // array[0][1]
+ "0300000000000000000000000000000000000000000000000000000000000000" + // array[0][2]
+ "0300000000000000000000000000000000000000000000000000000000000000" + // array[1][0]
+ "0400000000000000000000000000000000000000000000000000000000000000" + // array[1][1]
+ "0500000000000000000000000000000000000000000000000000000000000000", // array[1][2]
+ },
+ {
+ // static tuple
+ def: `[{"components": [{"name":"a","type":"int64"},
+ {"name":"b","type":"int256"},
+ {"name":"c","type":"int256"},
+ {"name":"d","type":"bool"},
+ {"name":"e","type":"bytes32[3][2]"}], "type":"tuple"}]`,
+ unpacked: struct {
+ A int64
+ B *big.Int
+ C *big.Int
+ D bool
+ E [2][3][32]byte
+ }{1, big.NewInt(1), big.NewInt(-1), true, [2][3][32]byte{{{1}, {2}, {3}}, {{3}, {4}, {5}}}},
+ packed: "0000000000000000000000000000000000000000000000000000000000000001" + // struct[a]
+ "0000000000000000000000000000000000000000000000000000000000000001" + // struct[b]
+ "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // struct[c]
+ "0000000000000000000000000000000000000000000000000000000000000001" + // struct[d]
+ "0100000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[0][0]
+ "0200000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[0][1]
+ "0300000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[0][2]
+ "0300000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[1][0]
+ "0400000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[1][1]
+ "0500000000000000000000000000000000000000000000000000000000000000", // struct[e] array[1][2]
+ },
+ {
+ def: `[{"components": [{"name":"a","type":"string"},
+ {"name":"b","type":"int64"},
+ {"name":"c","type":"bytes"},
+ {"name":"d","type":"string[]"},
+ {"name":"e","type":"int256[]"},
+ {"name":"f","type":"address[]"}], "type":"tuple"}]`,
+ unpacked: struct {
+ A string
+ B int64
+ C []byte
+ D []string
+ E []*big.Int
+ F []common.Address
+ }{"foobar", 1, []byte{1}, []string{"foo", "bar"}, []*big.Int{big.NewInt(1), big.NewInt(-1)}, []common.Address{{1}, {2}}},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" + // struct a
+ "00000000000000000000000000000000000000000000000000000000000000c0" + // struct[a] offset
+ "0000000000000000000000000000000000000000000000000000000000000001" + // struct[b]
+ "0000000000000000000000000000000000000000000000000000000000000100" + // struct[c] offset
+ "0000000000000000000000000000000000000000000000000000000000000140" + // struct[d] offset
+ "0000000000000000000000000000000000000000000000000000000000000220" + // struct[e] offset
+ "0000000000000000000000000000000000000000000000000000000000000280" + // struct[f] offset
+ "0000000000000000000000000000000000000000000000000000000000000006" + // struct[a] length
+ "666f6f6261720000000000000000000000000000000000000000000000000000" + // struct[a] "foobar"
+ "0000000000000000000000000000000000000000000000000000000000000001" + // struct[c] length
+ "0100000000000000000000000000000000000000000000000000000000000000" + // []byte{1}
+ "0000000000000000000000000000000000000000000000000000000000000002" + // struct[d] length
+ "0000000000000000000000000000000000000000000000000000000000000040" + // foo offset
+ "0000000000000000000000000000000000000000000000000000000000000080" + // bar offset
+ "0000000000000000000000000000000000000000000000000000000000000003" + // foo length
+ "666f6f0000000000000000000000000000000000000000000000000000000000" + // foo
+ "0000000000000000000000000000000000000000000000000000000000000003" + // bar offset
+ "6261720000000000000000000000000000000000000000000000000000000000" + // bar
+ "0000000000000000000000000000000000000000000000000000000000000002" + // struct[e] length
+ "0000000000000000000000000000000000000000000000000000000000000001" + // 1
+ "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // -1
+ "0000000000000000000000000000000000000000000000000000000000000002" + // struct[f] length
+ "0000000000000000000000000100000000000000000000000000000000000000" + // common.Address{1}
+ "0000000000000000000000000200000000000000000000000000000000000000", // common.Address{2}
+ },
+ {
+ def: `[{"components": [{ "type": "tuple","components": [{"name": "a","type": "uint256"},
+ {"name": "b","type": "uint256[]"}],
+ "name": "a","type": "tuple"},
+ {"name": "b","type": "uint256[]"}], "type": "tuple"}]`,
+ unpacked: struct {
+ A struct {
+ A *big.Int
+ B []*big.Int
+ }
+ B []*big.Int
+ }{
+ A: struct {
+ A *big.Int
+ B []*big.Int
+ }{big.NewInt(1), []*big.Int{big.NewInt(1), big.NewInt(2)}},
+ B: []*big.Int{big.NewInt(1), big.NewInt(2)}},
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" + // struct a
+ "0000000000000000000000000000000000000000000000000000000000000040" + // a offset
+ "00000000000000000000000000000000000000000000000000000000000000e0" + // b offset
+ "0000000000000000000000000000000000000000000000000000000000000001" + // a.a value
+ "0000000000000000000000000000000000000000000000000000000000000040" + // a.b offset
+ "0000000000000000000000000000000000000000000000000000000000000002" + // a.b length
+ "0000000000000000000000000000000000000000000000000000000000000001" + // a.b[0] value
+ "0000000000000000000000000000000000000000000000000000000000000002" + // a.b[1] value
+ "0000000000000000000000000000000000000000000000000000000000000002" + // b length
+ "0000000000000000000000000000000000000000000000000000000000000001" + // b[0] value
+ "0000000000000000000000000000000000000000000000000000000000000002", // b[1] value
+ },
+
+ {
+ def: `[{"components": [{"name": "a","type": "int256"},
+ {"name": "b","type": "int256[]"}],
+ "name": "a","type": "tuple[]"}]`,
+ unpacked: []struct {
+ A *big.Int
+ B []*big.Int
+ }{
+ {big.NewInt(-1), []*big.Int{big.NewInt(1), big.NewInt(3)}},
+ {big.NewInt(1), []*big.Int{big.NewInt(2), big.NewInt(-1)}},
+ },
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000002" + // tuple length
+ "0000000000000000000000000000000000000000000000000000000000000040" + // tuple[0] offset
+ "00000000000000000000000000000000000000000000000000000000000000e0" + // tuple[1] offset
+ "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // tuple[0].A
+ "0000000000000000000000000000000000000000000000000000000000000040" + // tuple[0].B offset
+ "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[0].B length
+ "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[0].B[0] value
+ "0000000000000000000000000000000000000000000000000000000000000003" + // tuple[0].B[1] value
+ "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[1].A
+ "0000000000000000000000000000000000000000000000000000000000000040" + // tuple[1].B offset
+ "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[1].B length
+ "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[1].B[0] value
+ "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", // tuple[1].B[1] value
+ },
+ {
+ def: `[{"components": [{"name": "a","type": "int256"},
+ {"name": "b","type": "int256"}],
+ "name": "a","type": "tuple[2]"}]`,
+ unpacked: [2]struct {
+ A *big.Int
+ B *big.Int
+ }{
+ {big.NewInt(-1), big.NewInt(1)},
+ {big.NewInt(1), big.NewInt(-1)},
+ },
+ packed: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // tuple[0].a
+ "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[0].b
+ "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[1].a
+ "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", // tuple[1].b
+ },
+ {
+ def: `[{"components": [{"name": "a","type": "int256[]"}],
+ "name": "a","type": "tuple[2]"}]`,
+ unpacked: [2]struct {
+ A []*big.Int
+ }{
+ {[]*big.Int{big.NewInt(-1), big.NewInt(1)}},
+ {[]*big.Int{big.NewInt(1), big.NewInt(-1)}},
+ },
+ packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+ "0000000000000000000000000000000000000000000000000000000000000040" + // tuple[0] offset
+ "00000000000000000000000000000000000000000000000000000000000000c0" + // tuple[1] offset
+ "0000000000000000000000000000000000000000000000000000000000000020" + // tuple[0].A offset
+ "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[0].A length
+ "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // tuple[0].A[0]
+ "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[0].A[1]
+ "0000000000000000000000000000000000000000000000000000000000000020" + // tuple[1].A offset
+ "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[1].A length
+ "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[1].A[0]
+ "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", // tuple[1].A[1]
+ },
+}
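
For orientation, the entries above can be exercised end to end through the public API shown earlier (abi.JSON, Pack and the slice-returning Unpack). A minimal sketch, assuming the table's conventions and an illustrative method name "method" (not part of this patch):

```go
package main

import (
	"bytes"
	"encoding/hex"
	"fmt"
	"math/big"
	"strings"

	"github.com/tomochain/tomochain/accounts/abi"
)

func main() {
	// Mirrors the table entry {def: `[{"type": "int256"}]`, unpacked: big.NewInt(2), ...}:
	// wrap the definition in a one-method ABI, pack the value, compare the hex.
	def := `[{"name":"method","type":"function",
	         "inputs":[{"type":"int256"}],"outputs":[{"type":"int256"}]}]`
	parsed, err := abi.JSON(strings.NewReader(def))
	if err != nil {
		panic(err)
	}
	packed, err := parsed.Pack("method", big.NewInt(2))
	if err != nil {
		panic(err)
	}
	want, _ := hex.DecodeString("0000000000000000000000000000000000000000000000000000000000000002")
	// Pack prepends the 4-byte method ID, so compare everything after it.
	fmt.Println("pack ok:", bytes.Equal(packed[4:], want))

	// Unpack reads the method's outputs back into Go values.
	out, err := parsed.Unpack("method", want)
	if err != nil {
		panic(err)
	}
	fmt.Println("unpacked:", out[0].(*big.Int)) // 2
}
```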
diff --git a/accounts/abi/reflect.go b/accounts/abi/reflect.go
index 2e6bf7098..0f4948ac8 100644
--- a/accounts/abi/reflect.go
+++ b/accounts/abi/reflect.go
@@ -17,48 +17,76 @@
package abi
import (
+ "errors"
"fmt"
+ "math/big"
"reflect"
+ "strings"
)
+// ConvertType converts an interface of a runtime type into an interface of the
+// given type.
+// e.g. turn
+// var fields []reflect.StructField
+//
+// fields = append(fields, reflect.StructField{
+// Name: "X",
+// Type: reflect.TypeOf(new(big.Int)),
+// Tag: reflect.StructTag("json:\"" + "x" + "\""),
+// }
+//
+// into
+// type TupleT struct { X *big.Int }
+func ConvertType(in interface{}, proto interface{}) interface{} {
+ protoType := reflect.TypeOf(proto)
+ if reflect.TypeOf(in).ConvertibleTo(protoType) {
+ return reflect.ValueOf(in).Convert(protoType).Interface()
+ }
+ // Fall back to set as a last-ditch effort
+ if err := set(reflect.ValueOf(proto), reflect.ValueOf(in)); err != nil {
+ panic(err)
+ }
+ return proto
+}
+
// indirect recursively dereferences the value until it either gets the value
// or finds a big.Int
func indirect(v reflect.Value) reflect.Value {
- if v.Kind() == reflect.Ptr && v.Elem().Type() != derefbig_t {
+ if v.Kind() == reflect.Ptr && v.Elem().Type() != reflect.TypeOf(big.Int{}) {
return indirect(v.Elem())
}
return v
}
-// reflectIntKind returns the reflect using the given size and
+// reflectIntType returns the reflect type for the given integer size and
// unsignedness.
-func reflectIntKindAndType(unsigned bool, size int) (reflect.Kind, reflect.Type) {
+func reflectIntType(unsigned bool, size int) reflect.Type {
+ if unsigned {
+ switch size {
+ case 8:
+ return reflect.TypeOf(uint8(0))
+ case 16:
+ return reflect.TypeOf(uint16(0))
+ case 32:
+ return reflect.TypeOf(uint32(0))
+ case 64:
+ return reflect.TypeOf(uint64(0))
+ }
+ }
switch size {
case 8:
- if unsigned {
- return reflect.Uint8, uint8_t
- }
- return reflect.Int8, int8_t
+ return reflect.TypeOf(int8(0))
case 16:
- if unsigned {
- return reflect.Uint16, uint16_t
- }
- return reflect.Int16, int16_t
+ return reflect.TypeOf(int16(0))
case 32:
- if unsigned {
- return reflect.Uint32, uint32_t
- }
- return reflect.Int32, int32_t
+ return reflect.TypeOf(int32(0))
case 64:
- if unsigned {
- return reflect.Uint64, uint64_t
- }
- return reflect.Int64, int64_t
+ return reflect.TypeOf(int64(0))
}
- return reflect.Ptr, big_t
+ return reflect.TypeOf(&big.Int{})
}
-// mustArrayToBytesSlice creates a new byte slice with the exact same size as value
+// mustArrayToByteSlice creates a new byte slice with the exact same size as value
// and copies the bytes in value to the new slice.
func mustArrayToByteSlice(value reflect.Value) reflect.Value {
slice := reflect.MakeSlice(reflect.TypeOf([]byte{}), value.Len(), value.Len())
@@ -70,59 +98,176 @@ func mustArrayToByteSlice(value reflect.Value) reflect.Value {
//
// set is a bit more lenient when it comes to assignment and doesn't enforce
// as strict a ruleset as bare `reflect` does.
-func set(dst, src reflect.Value, output Argument) error {
- dstType := dst.Type()
- srcType := src.Type()
+func set(dst, src reflect.Value) error {
+ dstType, srcType := dst.Type(), src.Type()
switch {
- case dstType.AssignableTo(srcType):
+ case dstType.Kind() == reflect.Interface && dst.Elem().IsValid():
+ return set(dst.Elem(), src)
+ case dstType.Kind() == reflect.Ptr && dstType.Elem() != reflect.TypeOf(big.Int{}):
+ return set(dst.Elem(), src)
+ case srcType.AssignableTo(dstType) && dst.CanSet():
dst.Set(src)
- case dstType.Kind() == reflect.Interface:
- dst.Set(src)
- case dstType.Kind() == reflect.Ptr:
- return set(dst.Elem(), src, output)
+ case dstType.Kind() == reflect.Slice && srcType.Kind() == reflect.Slice && dst.CanSet():
+ return setSlice(dst, src)
+ case dstType.Kind() == reflect.Array:
+ return setArray(dst, src)
+ case dstType.Kind() == reflect.Struct:
+ return setStruct(dst, src)
default:
return fmt.Errorf("abi: cannot unmarshal %v in to %v", src.Type(), dst.Type())
}
return nil
}
-// requireAssignable assures that `dest` is a pointer and it's not an interface.
-func requireAssignable(dst, src reflect.Value) error {
- if dst.Kind() != reflect.Ptr && dst.Kind() != reflect.Interface {
- return fmt.Errorf("abi: cannot unmarshal %v into %v", src.Type(), dst.Type())
+// setSlice attempts to assign src to dst when slices are not assignable by default,
+// e.g. src: [][]byte -> dst: [][15]byte.
+// setSlice returns an error if an element cannot be copied or dst is not settable.
+func setSlice(dst, src reflect.Value) error {
+ slice := reflect.MakeSlice(dst.Type(), src.Len(), src.Len())
+ for i := 0; i < src.Len(); i++ {
+ // set recurses into struct elements and converts element types where
+ // needed, e.g. [][32]uint8 to []common.Hash.
+ if err := set(slice.Index(i), src.Index(i)); err != nil {
+ return err
+ }
}
- return nil
+ if dst.CanSet() {
+ dst.Set(slice)
+ return nil
+ }
+ return errors.New("Cannot set slice, destination not settable")
}
-// requireUnpackKind verifies preconditions for unpacking `args` into `kind`
-func requireUnpackKind(v reflect.Value, t reflect.Type, k reflect.Kind,
- args Arguments) error {
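+// setArray assigns src to dst element by element, dereferencing pointers
+// first and copying at most min(len(src), len(dst)) entries.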
+func setArray(dst, src reflect.Value) error {
+ if src.Kind() == reflect.Ptr {
+ return set(dst, indirect(src))
+ }
+ array := reflect.New(dst.Type()).Elem()
+ min := src.Len()
+ if src.Len() > dst.Len() {
+ min = dst.Len()
+ }
+ for i := 0; i < min; i++ {
+ if err := set(array.Index(i), src.Index(i)); err != nil {
+ return err
+ }
+ }
+ if dst.CanSet() {
+ dst.Set(array)
+ return nil
+ }
+ return errors.New("Cannot set array, destination not settable")
+}
- switch k {
- case reflect.Struct:
- case reflect.Slice, reflect.Array:
- if minLen := args.LengthNonIndexed(); v.Len() < minLen {
- return fmt.Errorf("abi: insufficient number of elements in the list/array for unpack, want %d, got %d",
- minLen, v.Len())
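+// setStruct copies the src struct into dst field by field; both values are
+// assumed to be structs with matching field layouts.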
+func setStruct(dst, src reflect.Value) error {
+ for i := 0; i < src.NumField(); i++ {
+ srcField := src.Field(i)
+ dstField := dst.Field(i)
+ if !dstField.IsValid() || !srcField.IsValid() {
+ return fmt.Errorf("Could not find src field: %v value: %v in destination", srcField.Type().Name(), srcField)
+ }
+ if err := set(dstField, srcField); err != nil {
+ return err
}
- default:
- return fmt.Errorf("abi: cannot unmarshal tuple into %v", t)
}
return nil
}
-// requireUniqueStructFieldNames makes sure field names don't collide
-func requireUniqueStructFieldNames(args Arguments) error {
- exists := make(map[string]bool)
- for _, arg := range args {
- field := capitalise(arg.Name)
- if field == "" {
- return fmt.Errorf("abi: purely underscored output cannot unpack to struct")
+// mapArgNamesToStructFields maps a slice of argument names to struct fields.
+// First round: for each exportable field that carries an `abi:""` tag whose
+// name exists in the given argument name list, pair them together.
+//
+// Second round: for each argument name that has not been linked yet, find the
+// struct field it is expected to map into; if that field exists and has not
+// been used, pair them.
+//
+// Note this function assumes the given value is a struct value.
+func mapArgNamesToStructFields(argNames []string, value reflect.Value) (map[string]string, error) {
+ typ := value.Type()
+
+ abi2struct := make(map[string]string)
+ struct2abi := make(map[string]string)
+
+ // first round ~~~
+ for i := 0; i < typ.NumField(); i++ {
+ structFieldName := typ.Field(i).Name
+
+ // skip private struct fields.
+ if structFieldName[:1] != strings.ToUpper(structFieldName[:1]) {
+ continue
+ }
+ // skip fields that have no abi:"" tag.
+ tagName, ok := typ.Field(i).Tag.Lookup("abi")
+ if !ok {
+ continue
}
- if exists[field] {
- return fmt.Errorf("abi: multiple outputs mapping to the same struct field '%s'", field)
+ // check if tag is empty.
+ if tagName == "" {
+ return nil, fmt.Errorf("struct: abi tag in '%s' is empty", structFieldName)
+ }
+ // check which argument field matches with the abi tag.
+ found := false
+ for _, arg := range argNames {
+ if arg == tagName {
+ if abi2struct[arg] != "" {
+ return nil, fmt.Errorf("struct: abi tag in '%s' already mapped", structFieldName)
+ }
+ // pair them
+ abi2struct[arg] = structFieldName
+ struct2abi[structFieldName] = arg
+ found = true
+ }
+ }
+ // check if this tag has been mapped.
+ if !found {
+ return nil, fmt.Errorf("struct: abi tag '%s' defined but not found in abi", tagName)
}
- exists[field] = true
}
- return nil
+
+ // second round ~~~
+ for _, argName := range argNames {
+
+ structFieldName := ToCamelCase(argName)
+
+ if structFieldName == "" {
+ return nil, fmt.Errorf("abi: purely underscored output cannot unpack to struct")
+ }
+
+ // this abi has already been paired, skip it... unless there exists another, yet unassigned
+ // struct field with the same field name. If so, raise an error:
+ // abi: [ { "name": "value" } ]
+ // struct { Value *big.Int , Value1 *big.Int `abi:"value"`}
+ if abi2struct[argName] != "" {
+ if abi2struct[argName] != structFieldName &&
+ struct2abi[structFieldName] == "" &&
+ value.FieldByName(structFieldName).IsValid() {
+ return nil, fmt.Errorf("abi: multiple variables maps to the same abi field '%s'", argName)
+ }
+ continue
+ }
+
+ // return an error if this struct field has already been paired.
+ if struct2abi[structFieldName] != "" {
+ return nil, fmt.Errorf("abi: multiple outputs mapping to the same struct field '%s'", structFieldName)
+ }
+
+ if value.FieldByName(structFieldName).IsValid() {
+ // pair them
+ abi2struct[argName] = structFieldName
+ struct2abi[structFieldName] = argName
+ } else {
+ // not paired, but annotate as used, to detect cases like
+ // abi : [ { "name": "value" }, { "name": "_value" } ]
+ // struct { Value *big.Int }
+ struct2abi[structFieldName] = argName
+ }
+ }
+ return abi2struct, nil
}
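
To make the two rounds concrete, a hedged sketch with illustrative names (the ABI, struct and method below are not part of the patch): the tagged field pairs in round one, the untagged field pairs in round two via ToCamelCase.

```go
package main

import (
	"encoding/hex"
	"fmt"
	"math/big"
	"strings"

	"github.com/tomochain/tomochain/accounts/abi"
	"github.com/tomochain/tomochain/common"
)

func main() {
	// A method with two outputs: "_value" (uint256) and "owner" (address).
	def := `[{"name":"info","type":"function","outputs":[
		{"name":"_value","type":"uint256"},
		{"name":"owner","type":"address"}]}]`
	parsed, err := abi.JSON(strings.NewReader(def))
	if err != nil {
		panic(err)
	}
	data, _ := hex.DecodeString(
		"000000000000000000000000000000000000000000000000000000000000002a" + // _value = 42
			"0000000000000000000000000100000000000000000000000000000000000000") // owner = common.Address{1}

	var out struct {
		Total *big.Int       `abi:"_value"` // round one: paired via the abi tag
		Owner common.Address // round two: paired via ToCamelCase("owner") == "Owner"
	}
	if err := parsed.UnpackIntoInterface(&out, "info", data); err != nil {
		panic(err)
	}
	fmt.Println(out.Total, out.Owner.Hex()) // 42 0x0100000000000000000000000000000000000000
}
```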
diff --git a/accounts/abi/type.go b/accounts/abi/type.go
index a1f13ffa2..d24387796 100644
--- a/accounts/abi/type.go
+++ b/accounts/abi/type.go
@@ -17,11 +17,16 @@
package abi
import (
+ "errors"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
+ "unicode"
+ "unicode/utf8"
+
+ "github.com/tomochain/tomochain/common"
)
// Type enumerator
@@ -32,6 +37,7 @@ const (
StringTy
SliceTy
ArrayTy
+ TupleTy
AddressTy
FixedBytesTy
BytesTy
@@ -40,16 +46,19 @@ const (
FunctionTy
)
-// Type is the reflection of the supported argument type
+// Type is the reflection of the supported argument type.
type Type struct {
Elem *Type
-
- Kind reflect.Kind
- Type reflect.Type
Size int
T byte // Our own type checking
stringKind string // holds the unparsed string for deriving signatures
+
+ // Tuple relative fields
+ TupleRawName string // Raw struct name defined in source code, may be empty.
+ TupleElems []*Type // Type information of all tuple fields
+ TupleRawNames []string // Raw field name of all tuple fields
+ TupleType reflect.Type // Underlying struct of the tuple
}
var (
@@ -58,20 +67,24 @@ var (
)
// NewType creates a new reflection type of abi type given in t.
-func NewType(t string) (typ Type, err error) {
+func NewType(t string, internalType string, components []ArgumentMarshaling) (typ Type, err error) {
// check that array brackets are equal if they exist
if strings.Count(t, "[") != strings.Count(t, "]") {
- return Type{}, fmt.Errorf("invalid arg type in abi")
+ return Type{}, errors.New("invalid arg type in abi")
}
-
typ.stringKind = t
// if there are brackets, get ready to go into slice/array mode and
// recursively create the type
if strings.Count(t, "[") != 0 {
- i := strings.LastIndex(t, "[")
+ // Note internalType can be empty here.
+ subInternal := internalType
+ if i := strings.LastIndex(internalType, "["); i != -1 {
+ subInternal = subInternal[:i]
+ }
// recursively embed the type
- embeddedType, err := NewType(t[:i])
+ i := strings.LastIndex(t, "[")
+ embeddedType, err := NewType(t[:i], subInternal, components)
if err != nil {
return Type{}, err
}
@@ -84,26 +97,29 @@ func NewType(t string) (typ Type, err error) {
if len(intz) == 0 {
// is a slice
typ.T = SliceTy
- typ.Kind = reflect.Slice
typ.Elem = &embeddedType
- typ.Type = reflect.SliceOf(embeddedType.Type)
+ typ.stringKind = embeddedType.stringKind + sliced
} else if len(intz) == 1 {
- // is a array
+ // is an array
typ.T = ArrayTy
- typ.Kind = reflect.Array
typ.Elem = &embeddedType
typ.Size, err = strconv.Atoi(intz[0])
if err != nil {
return Type{}, fmt.Errorf("abi: error parsing variable size: %v", err)
}
- typ.Type = reflect.ArrayOf(typ.Size, embeddedType.Type)
+ typ.stringKind = embeddedType.stringKind + sliced
} else {
- return Type{}, fmt.Errorf("invalid formatting of array type")
+ return Type{}, errors.New("invalid formatting of array type")
}
return typ, err
}
// parse the type and size of the abi-type.
- parsedType := typeRegex.FindAllStringSubmatch(t, -1)[0]
+ matches := typeRegex.FindAllStringSubmatch(t, -1)
+ if len(matches) == 0 {
+ return Type{}, fmt.Errorf("invalid type '%v'", t)
+ }
+ parsedType := matches[0]
+
// varSize is the size of the variable
var varSize int
if len(parsedType[3]) > 0 {
@@ -122,42 +138,87 @@ func NewType(t string) (typ Type, err error) {
// varType is the parsed abi type
switch varType := parsedType[1]; varType {
case "int":
- typ.Kind, typ.Type = reflectIntKindAndType(false, varSize)
typ.Size = varSize
typ.T = IntTy
case "uint":
- typ.Kind, typ.Type = reflectIntKindAndType(true, varSize)
typ.Size = varSize
typ.T = UintTy
case "bool":
- typ.Kind = reflect.Bool
typ.T = BoolTy
- typ.Type = reflect.TypeOf(bool(false))
case "address":
- typ.Kind = reflect.Array
- typ.Type = address_t
typ.Size = 20
typ.T = AddressTy
case "string":
- typ.Kind = reflect.String
- typ.Type = reflect.TypeOf("")
typ.T = StringTy
case "bytes":
if varSize == 0 {
typ.T = BytesTy
- typ.Kind = reflect.Slice
- typ.Type = reflect.SliceOf(reflect.TypeOf(byte(0)))
} else {
+ if varSize > 32 {
+ return Type{}, fmt.Errorf("unsupported arg type: %s", t)
+ }
typ.T = FixedBytesTy
- typ.Kind = reflect.Array
typ.Size = varSize
- typ.Type = reflect.ArrayOf(varSize, reflect.TypeOf(byte(0)))
}
+ case "tuple":
+ var (
+ fields []reflect.StructField
+ elems []*Type
+ names []string
+ expression string // canonical parameter expression
+ used = make(map[string]bool)
+ )
+ expression += "("
+ for idx, c := range components {
+ cType, err := NewType(c.Type, c.InternalType, c.Components)
+ if err != nil {
+ return Type{}, err
+ }
+ name := ToCamelCase(c.Name)
+ if name == "" {
+ return Type{}, errors.New("abi: purely anonymous or underscored field is not supported")
+ }
+ fieldName := ResolveNameConflict(name, func(s string) bool { return used[s] })
+ used[fieldName] = true
+ if !isValidFieldName(fieldName) {
+ return Type{}, fmt.Errorf("field %d has invalid name", idx)
+ }
+ fields = append(fields, reflect.StructField{
+ Name: fieldName, // reflect.StructOf will panic for any unexported field.
+ Type: cType.GetType(),
+ Tag: reflect.StructTag("json:\"" + c.Name + "\""),
+ })
+ elems = append(elems, &cType)
+ names = append(names, c.Name)
+ expression += cType.stringKind
+ if idx != len(components)-1 {
+ expression += ","
+ }
+ }
+ expression += ")"
+
+ typ.TupleType = reflect.StructOf(fields)
+ typ.TupleElems = elems
+ typ.TupleRawNames = names
+ typ.T = TupleTy
+ typ.stringKind = expression
+
+ const structPrefix = "struct "
+ // After solidity 0.5.10, a new field of abi "internalType"
+ // is introduced. From that we can obtain the struct name
+ // user defined in the source code.
+ if internalType != "" && strings.HasPrefix(internalType, structPrefix) {
+ // Foo.Bar type definition is not allowed in golang,
+ // convert the format to FooBar
+ typ.TupleRawName = strings.ReplaceAll(internalType[len(structPrefix):], ".", "")
+ }
+
case "function":
- typ.Kind = reflect.Array
typ.T = FunctionTy
typ.Size = 24
- typ.Type = reflect.ArrayOf(24, reflect.TypeOf(byte(0)))
default:
return Type{}, fmt.Errorf("unsupported arg type: %s", t)
}
@@ -165,7 +226,43 @@ func NewType(t string) (typ Type, err error) {
return
}
-// String implements Stringer
+// GetType returns the reflection type of the ABI type.
+func (t Type) GetType() reflect.Type {
+ switch t.T {
+ case IntTy:
+ return reflectIntType(false, t.Size)
+ case UintTy:
+ return reflectIntType(true, t.Size)
+ case BoolTy:
+ return reflect.TypeOf(false)
+ case StringTy:
+ return reflect.TypeOf("")
+ case SliceTy:
+ return reflect.SliceOf(t.Elem.GetType())
+ case ArrayTy:
+ return reflect.ArrayOf(t.Size, t.Elem.GetType())
+ case TupleTy:
+ return t.TupleType
+ case AddressTy:
+ return reflect.TypeOf(common.Address{})
+ case FixedBytesTy:
+ return reflect.ArrayOf(t.Size, reflect.TypeOf(byte(0)))
+ case BytesTy:
+ return reflect.SliceOf(reflect.TypeOf(byte(0)))
+ case HashTy:
+ // hashtype currently not used
+ return reflect.ArrayOf(32, reflect.TypeOf(byte(0)))
+ case FixedPointTy:
+ // fixedpoint type currently not used
+ return reflect.ArrayOf(32, reflect.TypeOf(byte(0)))
+ case FunctionTy:
+ return reflect.ArrayOf(24, reflect.TypeOf(byte(0)))
+ default:
+ panic("Invalid type")
+ }
+}
+
+// String implements Stringer.
func (t Type) String() (out string) {
return t.stringKind
}
@@ -173,32 +270,157 @@ func (t Type) String() (out string) {
func (t Type) pack(v reflect.Value) ([]byte, error) {
// dereference pointer first if it's a pointer
v = indirect(v)
-
if err := typeCheck(t, v); err != nil {
return nil, err
}
- if t.T == SliceTy || t.T == ArrayTy {
- var packed []byte
+ switch t.T {
+ case SliceTy, ArrayTy:
+ var ret []byte
+ if t.requiresLengthPrefix() {
+ // append length
+ ret = append(ret, packNum(reflect.ValueOf(v.Len()))...)
+ }
+
+ // calculate offset if any
+ offset := 0
+ offsetReq := isDynamicType(*t.Elem)
+ if offsetReq {
+ offset = getTypeSize(*t.Elem) * v.Len()
+ }
+ var tail []byte
for i := 0; i < v.Len(); i++ {
val, err := t.Elem.pack(v.Index(i))
if err != nil {
return nil, err
}
- packed = append(packed, val...)
+ if !offsetReq {
+ ret = append(ret, val...)
+ continue
+ }
+ ret = append(ret, packNum(reflect.ValueOf(offset))...)
+ offset += len(val)
+ tail = append(tail, val...)
}
- if t.T == SliceTy {
- return packBytesSlice(packed, v.Len()), nil
- } else if t.T == ArrayTy {
- return packed, nil
+ return append(ret, tail...), nil
+ case TupleTy:
+ // (T1,...,Tk) for k >= 0 and any types T1, …, Tk
+ // enc(X) = head(X(1)) ... head(X(k)) tail(X(1)) ... tail(X(k))
+ // where X = (X(1), ..., X(k)) and head and tail are defined for Ti being a static
+ // type as
+ // head(X(i)) = enc(X(i)) and tail(X(i)) = "" (the empty string)
+ // and as
+ // head(X(i)) = enc(len(head(X(1)) ... head(X(k)) tail(X(1)) ... tail(X(i-1))))
+ // tail(X(i)) = enc(X(i))
+ // otherwise, i.e. if Ti is a dynamic type.
+ fieldmap, err := mapArgNamesToStructFields(t.TupleRawNames, v)
+ if err != nil {
+ return nil, err
+ }
+ // Calculate prefix occupied size.
+ offset := 0
+ for _, elem := range t.TupleElems {
+ offset += getTypeSize(*elem)
+ }
+ var ret, tail []byte
+ for i, elem := range t.TupleElems {
+ field := v.FieldByName(fieldmap[t.TupleRawNames[i]])
+ if !field.IsValid() {
+ return nil, fmt.Errorf("field %s for tuple not found in the given struct", t.TupleRawNames[i])
+ }
+ val, err := elem.pack(field)
+ if err != nil {
+ return nil, err
+ }
+ if isDynamicType(*elem) {
+ ret = append(ret, packNum(reflect.ValueOf(offset))...)
+ tail = append(tail, val...)
+ offset += len(val)
+ } else {
+ ret = append(ret, val...)
+ }
}
+ return append(ret, tail...), nil
+
+ default:
+ return packElement(t, v)
}
- return packElement(t, v), nil
}
-// requireLengthPrefix returns whether the type requires any sort of length
+// requiresLengthPrefix returns whether the type requires any sort of length
// prefixing.
func (t Type) requiresLengthPrefix() bool {
return t.T == StringTy || t.T == BytesTy || t.T == SliceTy
}
+
+// isDynamicType returns true if the type is dynamic.
+// The following types are called “dynamic”:
+// * bytes
+// * string
+// * T[] for any T
+// * T[k] for any dynamic T and any k >= 0
+// * (T1,...,Tk) if Ti is dynamic for some 1 <= i <= k
+func isDynamicType(t Type) bool {
+ if t.T == TupleTy {
+ for _, elem := range t.TupleElems {
+ if isDynamicType(*elem) {
+ return true
+ }
+ }
+ return false
+ }
+ return t.T == StringTy || t.T == BytesTy || t.T == SliceTy || (t.T == ArrayTy && isDynamicType(*t.Elem))
+}
+
+// getTypeSize returns the size that this type needs to occupy.
+// We distinguish static and dynamic types. Static types are encoded in-place
+// and dynamic types are encoded at a separately allocated location after the
+// current block.
+// So for a static variable, the size returned represents the size that the
+// variable actually occupies.
+// For a dynamic variable, the returned size is fixed 32 bytes, which is used
+// to store the location reference for actual value storage.
+func getTypeSize(t Type) int {
+ if t.T == ArrayTy && !isDynamicType(*t.Elem) {
+ // Recursively calculate type size if it is a nested array
+ if t.Elem.T == ArrayTy || t.Elem.T == TupleTy {
+ return t.Size * getTypeSize(*t.Elem)
+ }
+ return t.Size * 32
+ } else if t.T == TupleTy && !isDynamicType(t) {
+ total := 0
+ for _, elem := range t.TupleElems {
+ total += getTypeSize(*elem)
+ }
+ return total
+ }
+ return 32
+}
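
A short in-package sketch (hypothetical test, not part of this diff; both helpers are unexported) showing how a few types are classified by isDynamicType and sized by getTypeSize:

    func TestStaticVsDynamicSketch(t *testing.T) {
    	for _, c := range []struct {
    		typ     string
    		dynamic bool
    		size    int
    	}{
    		{"bytes32[2]", false, 64},        // static array: elements encoded in place
    		{"bytes[2]", true, 32},           // dynamic element: head is one offset word
    		{"uint256[2][3]", false, 32 * 6}, // nested static array: fully inline
    	} {
    		ty, err := NewType(c.typ, "", nil)
    		if err != nil {
    			t.Fatal(err)
    		}
    		if isDynamicType(ty) != c.dynamic || getTypeSize(ty) != c.size {
    			t.Errorf("%s: got dynamic=%v size=%d", c.typ, isDynamicType(ty), getTypeSize(ty))
    		}
    	}
    }
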
+
+// isLetter reports whether a given 'rune' is classified as a Letter.
+// This method is copied from reflect/type.go
+func isLetter(ch rune) bool {
+ return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch >= utf8.RuneSelf && unicode.IsLetter(ch)
+}
+
+// isValidFieldName checks if a string is a valid (struct) field name or not.
+//
+// According to the language spec, a field name should be an identifier.
+//
+// identifier = letter { letter | unicode_digit } .
+// letter = unicode_letter | "_" .
+// This method is copied from reflect/type.go
+func isValidFieldName(fieldName string) bool {
+ for i, c := range fieldName {
+ if i == 0 && !isLetter(c) {
+ return false
+ }
+
+ if !(isLetter(c) || unicode.IsDigit(c)) {
+ return false
+ }
+ }
+
+ return len(fieldName) > 0
+}
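
For intuition, a few illustrative expectations for the identifier check (a hypothetical in-package test, not part of this diff):

    func TestIsValidFieldNameSketch(t *testing.T) {
    	for in, want := range map[string]bool{
    		"_x1": true,  // a leading underscore counts as a letter
    		"1x":  false, // identifiers may not start with a digit
    		"":    false, // empty names fail the final length check
    	} {
    		if got := isValidFieldName(in); got != want {
    			t.Errorf("isValidFieldName(%q) = %v, want %v", in, got, want)
    		}
    	}
    }
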
diff --git a/accounts/abi/type_test.go b/accounts/abi/type_test.go
index fc23b0752..3b8902941 100644
--- a/accounts/abi/type_test.go
+++ b/accounts/abi/type_test.go
@@ -22,82 +22,92 @@ import (
"testing"
"github.com/davecgh/go-spew/spew"
+
"github.com/tomochain/tomochain/common"
)
-// typeWithoutStringer is a alias for the Type type which simply doesn't implement
+// typeWithoutStringer is an alias for the Type type which simply doesn't implement
// the stringer interface to allow printing type details in the tests below.
type typeWithoutStringer Type
// Tests that all allowed types get recognized by the type parser.
func TestTypeRegexp(t *testing.T) {
tests := []struct {
- blob string
- kind Type
+ blob string
+ components []ArgumentMarshaling
+ kind Type
}{
- {"bool", Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}},
- {"bool[]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]bool(nil)), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}},
- {"bool[2]", Type{Size: 2, Kind: reflect.Array, T: ArrayTy, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}},
- {"bool[2][]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][]"}},
- {"bool[][]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([][]bool{}), Elem: &Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][]"}},
- {"bool[][2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][]bool{}), Elem: &Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][2]"}},
- {"bool[2][2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][2]"}},
- {"bool[2][][2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][][2]bool{}), Elem: &Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][]"}, stringKind: "bool[2][][2]"}},
- {"bool[2][2][2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][2][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][2]"}, stringKind: "bool[2][2][2]"}},
- {"bool[][][]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][][]bool{}), Elem: &Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][]bool{}), Elem: &Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][]"}, stringKind: "bool[][][]"}},
- {"bool[][2][]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][2][]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][]bool{}), Elem: &Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][2]"}, stringKind: "bool[][2][]"}},
- {"int8", Type{Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, stringKind: "int8"}},
- {"int16", Type{Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, stringKind: "int16"}},
- {"int32", Type{Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, stringKind: "int32"}},
- {"int64", Type{Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, stringKind: "int64"}},
- {"int256", Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}},
- {"int8[]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int8{}), Elem: &Type{Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[]"}},
- {"int8[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]int8{}), Elem: &Type{Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[2]"}},
- {"int16[]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int16{}), Elem: &Type{Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[]"}},
- {"int16[2]", Type{Size: 2, Kind: reflect.Array, T: ArrayTy, Type: reflect.TypeOf([2]int16{}), Elem: &Type{Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[2]"}},
- {"int32[]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int32{}), Elem: &Type{Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[]"}},
- {"int32[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]int32{}), Elem: &Type{Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[2]"}},
- {"int64[]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int64{}), Elem: &Type{Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[]"}},
- {"int64[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]int64{}), Elem: &Type{Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[2]"}},
- {"int256[]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]*big.Int{}), Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[]"}},
- {"int256[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]*big.Int{}), Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[2]"}},
- {"uint8", Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}},
- {"uint16", Type{Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, stringKind: "uint16"}},
- {"uint32", Type{Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, stringKind: "uint32"}},
- {"uint64", Type{Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, stringKind: "uint64"}},
- {"uint256", Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, stringKind: "uint256"}},
- {"uint8[]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]uint8{}), Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[]"}},
- {"uint8[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint8{}), Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[2]"}},
- {"uint16[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]uint16{}), Elem: &Type{Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[]"}},
- {"uint16[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint16{}), Elem: &Type{Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[2]"}},
- {"uint32[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]uint32{}), Elem: &Type{Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[]"}},
- {"uint32[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint32{}), Elem: &Type{Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[2]"}},
- {"uint64[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]uint64{}), Elem: &Type{Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[]"}},
- {"uint64[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint64{}), Elem: &Type{Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[2]"}},
- {"uint256[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]*big.Int{}), Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[]"}},
- {"uint256[2]", Type{Kind: reflect.Array, T: ArrayTy, Type: reflect.TypeOf([2]*big.Int{}), Size: 2, Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[2]"}},
- {"bytes32", Type{Kind: reflect.Array, T: FixedBytesTy, Size: 32, Type: reflect.TypeOf([32]byte{}), stringKind: "bytes32"}},
- {"bytes[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][]byte{}), Elem: &Type{Kind: reflect.Slice, Type: reflect.TypeOf([]byte{}), T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[]"}},
- {"bytes[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][]byte{}), Elem: &Type{T: BytesTy, Type: reflect.TypeOf([]byte{}), Kind: reflect.Slice, stringKind: "bytes"}, stringKind: "bytes[2]"}},
- {"bytes32[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][32]byte{}), Elem: &Type{Kind: reflect.Array, Type: reflect.TypeOf([32]byte{}), T: FixedBytesTy, Size: 32, stringKind: "bytes32"}, stringKind: "bytes32[]"}},
- {"bytes32[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][32]byte{}), Elem: &Type{Kind: reflect.Array, T: FixedBytesTy, Size: 32, Type: reflect.TypeOf([32]byte{}), stringKind: "bytes32"}, stringKind: "bytes32[2]"}},
- {"string", Type{Kind: reflect.String, T: StringTy, Type: reflect.TypeOf(""), stringKind: "string"}},
- {"string[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]string{}), Elem: &Type{Kind: reflect.String, Type: reflect.TypeOf(""), T: StringTy, stringKind: "string"}, stringKind: "string[]"}},
- {"string[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]string{}), Elem: &Type{Kind: reflect.String, T: StringTy, Type: reflect.TypeOf(""), stringKind: "string"}, stringKind: "string[2]"}},
- {"address", Type{Kind: reflect.Array, Type: address_t, Size: 20, T: AddressTy, stringKind: "address"}},
- {"address[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]common.Address{}), Elem: &Type{Kind: reflect.Array, Type: address_t, Size: 20, T: AddressTy, stringKind: "address"}, stringKind: "address[]"}},
- {"address[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]common.Address{}), Elem: &Type{Kind: reflect.Array, Type: address_t, Size: 20, T: AddressTy, stringKind: "address"}, stringKind: "address[2]"}},
+ {"bool", nil, Type{T: BoolTy, stringKind: "bool"}},
+ {"bool[]", nil, Type{T: SliceTy, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[]"}},
+ {"bool[2]", nil, Type{Size: 2, T: ArrayTy, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[2]"}},
+ {"bool[2][]", nil, Type{T: SliceTy, Elem: &Type{T: ArrayTy, Size: 2, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][]"}},
+ {"bool[][]", nil, Type{T: SliceTy, Elem: &Type{T: SliceTy, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][]"}},
+ {"bool[][2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{T: SliceTy, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][2]"}},
+ {"bool[2][2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{T: ArrayTy, Size: 2, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][2]"}},
+ {"bool[2][][2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{T: SliceTy, Elem: &Type{T: ArrayTy, Size: 2, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][]"}, stringKind: "bool[2][][2]"}},
+ {"bool[2][2][2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{T: ArrayTy, Size: 2, Elem: &Type{T: ArrayTy, Size: 2, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][2]"}, stringKind: "bool[2][2][2]"}},
+ {"bool[][][]", nil, Type{T: SliceTy, Elem: &Type{T: SliceTy, Elem: &Type{T: SliceTy, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][]"}, stringKind: "bool[][][]"}},
+ {"bool[][2][]", nil, Type{T: SliceTy, Elem: &Type{T: ArrayTy, Size: 2, Elem: &Type{T: SliceTy, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][2]"}, stringKind: "bool[][2][]"}},
+ {"int8", nil, Type{Size: 8, T: IntTy, stringKind: "int8"}},
+ {"int16", nil, Type{Size: 16, T: IntTy, stringKind: "int16"}},
+ {"int32", nil, Type{Size: 32, T: IntTy, stringKind: "int32"}},
+ {"int64", nil, Type{Size: 64, T: IntTy, stringKind: "int64"}},
+ {"int256", nil, Type{Size: 256, T: IntTy, stringKind: "int256"}},
+ {"int8[]", nil, Type{T: SliceTy, Elem: &Type{Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[]"}},
+ {"int8[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[2]"}},
+ {"int16[]", nil, Type{T: SliceTy, Elem: &Type{Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[]"}},
+ {"int16[2]", nil, Type{Size: 2, T: ArrayTy, Elem: &Type{Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[2]"}},
+ {"int32[]", nil, Type{T: SliceTy, Elem: &Type{Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[]"}},
+ {"int32[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[2]"}},
+ {"int64[]", nil, Type{T: SliceTy, Elem: &Type{Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[]"}},
+ {"int64[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[2]"}},
+ {"int256[]", nil, Type{T: SliceTy, Elem: &Type{Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[]"}},
+ {"int256[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[2]"}},
+ {"uint8", nil, Type{Size: 8, T: UintTy, stringKind: "uint8"}},
+ {"uint16", nil, Type{Size: 16, T: UintTy, stringKind: "uint16"}},
+ {"uint32", nil, Type{Size: 32, T: UintTy, stringKind: "uint32"}},
+ {"uint64", nil, Type{Size: 64, T: UintTy, stringKind: "uint64"}},
+ {"uint256", nil, Type{Size: 256, T: UintTy, stringKind: "uint256"}},
+ {"uint8[]", nil, Type{T: SliceTy, Elem: &Type{Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[]"}},
+ {"uint8[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[2]"}},
+ {"uint16[]", nil, Type{T: SliceTy, Elem: &Type{Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[]"}},
+ {"uint16[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[2]"}},
+ {"uint32[]", nil, Type{T: SliceTy, Elem: &Type{Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[]"}},
+ {"uint32[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[2]"}},
+ {"uint64[]", nil, Type{T: SliceTy, Elem: &Type{Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[]"}},
+ {"uint64[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[2]"}},
+ {"uint256[]", nil, Type{T: SliceTy, Elem: &Type{Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[]"}},
+ {"uint256[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[2]"}},
+ {"bytes32", nil, Type{T: FixedBytesTy, Size: 32, stringKind: "bytes32"}},
+ {"bytes[]", nil, Type{T: SliceTy, Elem: &Type{T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[]"}},
+ {"bytes[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[2]"}},
+ {"bytes32[]", nil, Type{T: SliceTy, Elem: &Type{T: FixedBytesTy, Size: 32, stringKind: "bytes32"}, stringKind: "bytes32[]"}},
+ {"bytes32[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{T: FixedBytesTy, Size: 32, stringKind: "bytes32"}, stringKind: "bytes32[2]"}},
+ {"string", nil, Type{T: StringTy, stringKind: "string"}},
+ {"string[]", nil, Type{T: SliceTy, Elem: &Type{T: StringTy, stringKind: "string"}, stringKind: "string[]"}},
+ {"string[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{T: StringTy, stringKind: "string"}, stringKind: "string[2]"}},
+ {"address", nil, Type{Size: 20, T: AddressTy, stringKind: "address"}},
+ {"address[]", nil, Type{T: SliceTy, Elem: &Type{Size: 20, T: AddressTy, stringKind: "address"}, stringKind: "address[]"}},
+ {"address[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 20, T: AddressTy, stringKind: "address"}, stringKind: "address[2]"}},
// TODO when fixed types are implemented properly
- // {"fixed", Type{}},
- // {"fixed128x128", Type{}},
- // {"fixed[]", Type{}},
- // {"fixed[2]", Type{}},
- // {"fixed128x128[]", Type{}},
- // {"fixed128x128[2]", Type{}},
+ // {"fixed", nil, Type{}},
+ // {"fixed128x128", nil, Type{}},
+ // {"fixed[]", nil, Type{}},
+ // {"fixed[2]", nil, Type{}},
+ // {"fixed128x128[]", nil, Type{}},
+ // {"fixed128x128[2]", nil, Type{}},
+ {"tuple", []ArgumentMarshaling{{Name: "a", Type: "int64"}}, Type{T: TupleTy, TupleType: reflect.TypeOf(struct {
+ A int64 `json:"a"`
+ }{}), stringKind: "(int64)",
+ TupleElems: []*Type{{T: IntTy, Size: 64, stringKind: "int64"}}, TupleRawNames: []string{"a"}}},
+ {"tuple with long name", []ArgumentMarshaling{{Name: "aTypicalParamName", Type: "int64"}}, Type{T: TupleTy, TupleType: reflect.TypeOf(struct {
+ ATypicalParamName int64 `json:"aTypicalParamName"`
+ }{}), stringKind: "(int64)",
+ TupleElems: []*Type{{T: IntTy, Size: 64, stringKind: "int64"}}, TupleRawNames: []string{"aTypicalParamName"}}},
}
for _, tt := range tests {
- typ, err := NewType(tt.blob)
+ typ, err := NewType(tt.blob, "", tt.components)
if err != nil {
t.Errorf("type %q: failed to parse type string: %v", tt.blob, err)
}
@@ -109,151 +119,170 @@ func TestTypeRegexp(t *testing.T) {
func TestTypeCheck(t *testing.T) {
for i, test := range []struct {
- typ string
- input interface{}
- err string
+ typ string
+ components []ArgumentMarshaling
+ input interface{}
+ err string
}{
- {"uint", big.NewInt(1), "unsupported arg type: uint"},
- {"int", big.NewInt(1), "unsupported arg type: int"},
- {"uint256", big.NewInt(1), ""},
- {"uint256[][3][]", [][3][]*big.Int{{{}}}, ""},
- {"uint256[][][3]", [3][][]*big.Int{{{}}}, ""},
- {"uint256[3][][]", [][][3]*big.Int{{{}}}, ""},
- {"uint256[3][3][3]", [3][3][3]*big.Int{{{}}}, ""},
- {"uint8[][]", [][]uint8{}, ""},
- {"int256", big.NewInt(1), ""},
- {"uint8", uint8(1), ""},
- {"uint16", uint16(1), ""},
- {"uint32", uint32(1), ""},
- {"uint64", uint64(1), ""},
- {"int8", int8(1), ""},
- {"int16", int16(1), ""},
- {"int32", int32(1), ""},
- {"int64", int64(1), ""},
- {"uint24", big.NewInt(1), ""},
- {"uint40", big.NewInt(1), ""},
- {"uint48", big.NewInt(1), ""},
- {"uint56", big.NewInt(1), ""},
- {"uint72", big.NewInt(1), ""},
- {"uint80", big.NewInt(1), ""},
- {"uint88", big.NewInt(1), ""},
- {"uint96", big.NewInt(1), ""},
- {"uint104", big.NewInt(1), ""},
- {"uint112", big.NewInt(1), ""},
- {"uint120", big.NewInt(1), ""},
- {"uint128", big.NewInt(1), ""},
- {"uint136", big.NewInt(1), ""},
- {"uint144", big.NewInt(1), ""},
- {"uint152", big.NewInt(1), ""},
- {"uint160", big.NewInt(1), ""},
- {"uint168", big.NewInt(1), ""},
- {"uint176", big.NewInt(1), ""},
- {"uint184", big.NewInt(1), ""},
- {"uint192", big.NewInt(1), ""},
- {"uint200", big.NewInt(1), ""},
- {"uint208", big.NewInt(1), ""},
- {"uint216", big.NewInt(1), ""},
- {"uint224", big.NewInt(1), ""},
- {"uint232", big.NewInt(1), ""},
- {"uint240", big.NewInt(1), ""},
- {"uint248", big.NewInt(1), ""},
- {"int24", big.NewInt(1), ""},
- {"int40", big.NewInt(1), ""},
- {"int48", big.NewInt(1), ""},
- {"int56", big.NewInt(1), ""},
- {"int72", big.NewInt(1), ""},
- {"int80", big.NewInt(1), ""},
- {"int88", big.NewInt(1), ""},
- {"int96", big.NewInt(1), ""},
- {"int104", big.NewInt(1), ""},
- {"int112", big.NewInt(1), ""},
- {"int120", big.NewInt(1), ""},
- {"int128", big.NewInt(1), ""},
- {"int136", big.NewInt(1), ""},
- {"int144", big.NewInt(1), ""},
- {"int152", big.NewInt(1), ""},
- {"int160", big.NewInt(1), ""},
- {"int168", big.NewInt(1), ""},
- {"int176", big.NewInt(1), ""},
- {"int184", big.NewInt(1), ""},
- {"int192", big.NewInt(1), ""},
- {"int200", big.NewInt(1), ""},
- {"int208", big.NewInt(1), ""},
- {"int216", big.NewInt(1), ""},
- {"int224", big.NewInt(1), ""},
- {"int232", big.NewInt(1), ""},
- {"int240", big.NewInt(1), ""},
- {"int248", big.NewInt(1), ""},
- {"uint30", uint8(1), "abi: cannot use uint8 as type ptr as argument"},
- {"uint8", uint16(1), "abi: cannot use uint16 as type uint8 as argument"},
- {"uint8", uint32(1), "abi: cannot use uint32 as type uint8 as argument"},
- {"uint8", uint64(1), "abi: cannot use uint64 as type uint8 as argument"},
- {"uint8", int8(1), "abi: cannot use int8 as type uint8 as argument"},
- {"uint8", int16(1), "abi: cannot use int16 as type uint8 as argument"},
- {"uint8", int32(1), "abi: cannot use int32 as type uint8 as argument"},
- {"uint8", int64(1), "abi: cannot use int64 as type uint8 as argument"},
- {"uint16", uint16(1), ""},
- {"uint16", uint8(1), "abi: cannot use uint8 as type uint16 as argument"},
- {"uint16[]", []uint16{1, 2, 3}, ""},
- {"uint16[]", [3]uint16{1, 2, 3}, ""},
- {"uint16[]", []uint32{1, 2, 3}, "abi: cannot use []uint32 as type [0]uint16 as argument"},
- {"uint16[3]", [3]uint32{1, 2, 3}, "abi: cannot use [3]uint32 as type [3]uint16 as argument"},
- {"uint16[3]", [4]uint16{1, 2, 3}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"},
- {"uint16[3]", []uint16{1, 2, 3}, ""},
- {"uint16[3]", []uint16{1, 2, 3, 4}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"},
- {"address[]", []common.Address{{1}}, ""},
- {"address[1]", []common.Address{{1}}, ""},
- {"address[1]", [1]common.Address{{1}}, ""},
- {"address[2]", [1]common.Address{{1}}, "abi: cannot use [1]array as type [2]array as argument"},
- {"bytes32", [32]byte{}, ""},
- {"bytes31", [31]byte{}, ""},
- {"bytes30", [30]byte{}, ""},
- {"bytes29", [29]byte{}, ""},
- {"bytes28", [28]byte{}, ""},
- {"bytes27", [27]byte{}, ""},
- {"bytes26", [26]byte{}, ""},
- {"bytes25", [25]byte{}, ""},
- {"bytes24", [24]byte{}, ""},
- {"bytes23", [23]byte{}, ""},
- {"bytes22", [22]byte{}, ""},
- {"bytes21", [21]byte{}, ""},
- {"bytes20", [20]byte{}, ""},
- {"bytes19", [19]byte{}, ""},
- {"bytes18", [18]byte{}, ""},
- {"bytes17", [17]byte{}, ""},
- {"bytes16", [16]byte{}, ""},
- {"bytes15", [15]byte{}, ""},
- {"bytes14", [14]byte{}, ""},
- {"bytes13", [13]byte{}, ""},
- {"bytes12", [12]byte{}, ""},
- {"bytes11", [11]byte{}, ""},
- {"bytes10", [10]byte{}, ""},
- {"bytes9", [9]byte{}, ""},
- {"bytes8", [8]byte{}, ""},
- {"bytes7", [7]byte{}, ""},
- {"bytes6", [6]byte{}, ""},
- {"bytes5", [5]byte{}, ""},
- {"bytes4", [4]byte{}, ""},
- {"bytes3", [3]byte{}, ""},
- {"bytes2", [2]byte{}, ""},
- {"bytes1", [1]byte{}, ""},
- {"bytes32", [33]byte{}, "abi: cannot use [33]uint8 as type [32]uint8 as argument"},
- {"bytes32", common.Hash{1}, ""},
- {"bytes31", common.Hash{1}, "abi: cannot use common.Hash as type [31]uint8 as argument"},
- {"bytes31", [32]byte{}, "abi: cannot use [32]uint8 as type [31]uint8 as argument"},
- {"bytes", []byte{0, 1}, ""},
- {"bytes", [2]byte{0, 1}, "abi: cannot use array as type slice as argument"},
- {"bytes", common.Hash{1}, "abi: cannot use array as type slice as argument"},
- {"string", "hello world", ""},
- {"string", string(""), ""},
- {"string", []byte{}, "abi: cannot use slice as type string as argument"},
- {"bytes32[]", [][32]byte{{}}, ""},
- {"function", [24]byte{}, ""},
- {"bytes20", common.Address{}, ""},
- {"address", [20]byte{}, ""},
- {"address", common.Address{}, ""},
+ {"uint", nil, big.NewInt(1), "unsupported arg type: uint"},
+ {"int", nil, big.NewInt(1), "unsupported arg type: int"},
+ {"uint256", nil, big.NewInt(1), ""},
+ {"uint256[][3][]", nil, [][3][]*big.Int{{{}}}, ""},
+ {"uint256[][][3]", nil, [3][][]*big.Int{{{}}}, ""},
+ {"uint256[3][][]", nil, [][][3]*big.Int{{{}}}, ""},
+ {"uint256[3][3][3]", nil, [3][3][3]*big.Int{{{}}}, ""},
+ {"uint8[][]", nil, [][]uint8{}, ""},
+ {"int256", nil, big.NewInt(1), ""},
+ {"uint8", nil, uint8(1), ""},
+ {"uint16", nil, uint16(1), ""},
+ {"uint32", nil, uint32(1), ""},
+ {"uint64", nil, uint64(1), ""},
+ {"int8", nil, int8(1), ""},
+ {"int16", nil, int16(1), ""},
+ {"int32", nil, int32(1), ""},
+ {"int64", nil, int64(1), ""},
+ {"uint24", nil, big.NewInt(1), ""},
+ {"uint40", nil, big.NewInt(1), ""},
+ {"uint48", nil, big.NewInt(1), ""},
+ {"uint56", nil, big.NewInt(1), ""},
+ {"uint72", nil, big.NewInt(1), ""},
+ {"uint80", nil, big.NewInt(1), ""},
+ {"uint88", nil, big.NewInt(1), ""},
+ {"uint96", nil, big.NewInt(1), ""},
+ {"uint104", nil, big.NewInt(1), ""},
+ {"uint112", nil, big.NewInt(1), ""},
+ {"uint120", nil, big.NewInt(1), ""},
+ {"uint128", nil, big.NewInt(1), ""},
+ {"uint136", nil, big.NewInt(1), ""},
+ {"uint144", nil, big.NewInt(1), ""},
+ {"uint152", nil, big.NewInt(1), ""},
+ {"uint160", nil, big.NewInt(1), ""},
+ {"uint168", nil, big.NewInt(1), ""},
+ {"uint176", nil, big.NewInt(1), ""},
+ {"uint184", nil, big.NewInt(1), ""},
+ {"uint192", nil, big.NewInt(1), ""},
+ {"uint200", nil, big.NewInt(1), ""},
+ {"uint208", nil, big.NewInt(1), ""},
+ {"uint216", nil, big.NewInt(1), ""},
+ {"uint224", nil, big.NewInt(1), ""},
+ {"uint232", nil, big.NewInt(1), ""},
+ {"uint240", nil, big.NewInt(1), ""},
+ {"uint248", nil, big.NewInt(1), ""},
+ {"int24", nil, big.NewInt(1), ""},
+ {"int40", nil, big.NewInt(1), ""},
+ {"int48", nil, big.NewInt(1), ""},
+ {"int56", nil, big.NewInt(1), ""},
+ {"int72", nil, big.NewInt(1), ""},
+ {"int80", nil, big.NewInt(1), ""},
+ {"int88", nil, big.NewInt(1), ""},
+ {"int96", nil, big.NewInt(1), ""},
+ {"int104", nil, big.NewInt(1), ""},
+ {"int112", nil, big.NewInt(1), ""},
+ {"int120", nil, big.NewInt(1), ""},
+ {"int128", nil, big.NewInt(1), ""},
+ {"int136", nil, big.NewInt(1), ""},
+ {"int144", nil, big.NewInt(1), ""},
+ {"int152", nil, big.NewInt(1), ""},
+ {"int160", nil, big.NewInt(1), ""},
+ {"int168", nil, big.NewInt(1), ""},
+ {"int176", nil, big.NewInt(1), ""},
+ {"int184", nil, big.NewInt(1), ""},
+ {"int192", nil, big.NewInt(1), ""},
+ {"int200", nil, big.NewInt(1), ""},
+ {"int208", nil, big.NewInt(1), ""},
+ {"int216", nil, big.NewInt(1), ""},
+ {"int224", nil, big.NewInt(1), ""},
+ {"int232", nil, big.NewInt(1), ""},
+ {"int240", nil, big.NewInt(1), ""},
+ {"int248", nil, big.NewInt(1), ""},
+ {"uint30", nil, uint8(1), "abi: cannot use uint8 as type ptr as argument"},
+ {"uint8", nil, uint16(1), "abi: cannot use uint16 as type uint8 as argument"},
+ {"uint8", nil, uint32(1), "abi: cannot use uint32 as type uint8 as argument"},
+ {"uint8", nil, uint64(1), "abi: cannot use uint64 as type uint8 as argument"},
+ {"uint8", nil, int8(1), "abi: cannot use int8 as type uint8 as argument"},
+ {"uint8", nil, int16(1), "abi: cannot use int16 as type uint8 as argument"},
+ {"uint8", nil, int32(1), "abi: cannot use int32 as type uint8 as argument"},
+ {"uint8", nil, int64(1), "abi: cannot use int64 as type uint8 as argument"},
+ {"uint16", nil, uint16(1), ""},
+ {"uint16", nil, uint8(1), "abi: cannot use uint8 as type uint16 as argument"},
+ {"uint16[]", nil, []uint16{1, 2, 3}, ""},
+ {"uint16[]", nil, [3]uint16{1, 2, 3}, ""},
+ {"uint16[]", nil, []uint32{1, 2, 3}, "abi: cannot use []uint32 as type [0]uint16 as argument"},
+ {"uint16[3]", nil, [3]uint32{1, 2, 3}, "abi: cannot use [3]uint32 as type [3]uint16 as argument"},
+ {"uint16[3]", nil, [4]uint16{1, 2, 3}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"},
+ {"uint16[3]", nil, []uint16{1, 2, 3}, ""},
+ {"uint16[3]", nil, []uint16{1, 2, 3, 4}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"},
+ {"address[]", nil, []common.Address{{1}}, ""},
+ {"address[1]", nil, []common.Address{{1}}, ""},
+ {"address[1]", nil, [1]common.Address{{1}}, ""},
+ {"address[2]", nil, [1]common.Address{{1}}, "abi: cannot use [1]array as type [2]array as argument"},
+ {"bytes32", nil, [32]byte{}, ""},
+ {"bytes31", nil, [31]byte{}, ""},
+ {"bytes30", nil, [30]byte{}, ""},
+ {"bytes29", nil, [29]byte{}, ""},
+ {"bytes28", nil, [28]byte{}, ""},
+ {"bytes27", nil, [27]byte{}, ""},
+ {"bytes26", nil, [26]byte{}, ""},
+ {"bytes25", nil, [25]byte{}, ""},
+ {"bytes24", nil, [24]byte{}, ""},
+ {"bytes23", nil, [23]byte{}, ""},
+ {"bytes22", nil, [22]byte{}, ""},
+ {"bytes21", nil, [21]byte{}, ""},
+ {"bytes20", nil, [20]byte{}, ""},
+ {"bytes19", nil, [19]byte{}, ""},
+ {"bytes18", nil, [18]byte{}, ""},
+ {"bytes17", nil, [17]byte{}, ""},
+ {"bytes16", nil, [16]byte{}, ""},
+ {"bytes15", nil, [15]byte{}, ""},
+ {"bytes14", nil, [14]byte{}, ""},
+ {"bytes13", nil, [13]byte{}, ""},
+ {"bytes12", nil, [12]byte{}, ""},
+ {"bytes11", nil, [11]byte{}, ""},
+ {"bytes10", nil, [10]byte{}, ""},
+ {"bytes9", nil, [9]byte{}, ""},
+ {"bytes8", nil, [8]byte{}, ""},
+ {"bytes7", nil, [7]byte{}, ""},
+ {"bytes6", nil, [6]byte{}, ""},
+ {"bytes5", nil, [5]byte{}, ""},
+ {"bytes4", nil, [4]byte{}, ""},
+ {"bytes3", nil, [3]byte{}, ""},
+ {"bytes2", nil, [2]byte{}, ""},
+ {"bytes1", nil, [1]byte{}, ""},
+ {"bytes32", nil, [33]byte{}, "abi: cannot use [33]uint8 as type [32]uint8 as argument"},
+ {"bytes32", nil, common.Hash{1}, ""},
+ {"bytes31", nil, common.Hash{1}, "abi: cannot use common.Hash as type [31]uint8 as argument"},
+ {"bytes31", nil, [32]byte{}, "abi: cannot use [32]uint8 as type [31]uint8 as argument"},
+ {"bytes", nil, []byte{0, 1}, ""},
+ {"bytes", nil, [2]byte{0, 1}, "abi: cannot use array as type slice as argument"},
+ {"bytes", nil, common.Hash{1}, "abi: cannot use array as type slice as argument"},
+ {"string", nil, "hello world", ""},
+ {"string", nil, string(""), ""},
+ {"string", nil, []byte{}, "abi: cannot use slice as type string as argument"},
+ {"bytes32[]", nil, [][32]byte{{}}, ""},
+ {"function", nil, [24]byte{}, ""},
+ {"bytes20", nil, common.Address{}, ""},
+ {"address", nil, [20]byte{}, ""},
+ {"address", nil, common.Address{}, ""},
+ {"bytes32[]]", nil, "", "invalid arg type in abi"},
+ {"invalidType", nil, "", "unsupported arg type: invalidType"},
+ {"invalidSlice[]", nil, "", "unsupported arg type: invalidSlice"},
+ // simple tuple
+ {"tuple", []ArgumentMarshaling{{Name: "a", Type: "uint256"}, {Name: "b", Type: "uint256"}}, struct {
+ A *big.Int
+ B *big.Int
+ }{}, ""},
+ // tuple slice
+ {"tuple[]", []ArgumentMarshaling{{Name: "a", Type: "uint256"}, {Name: "b", Type: "uint256"}}, []struct {
+ A *big.Int
+ B *big.Int
+ }{}, ""},
+ // tuple array
+ {"tuple[2]", []ArgumentMarshaling{{Name: "a", Type: "uint256"}, {Name: "b", Type: "uint256"}}, []struct {
+ A *big.Int
+ B *big.Int
+ }{{big.NewInt(0), big.NewInt(0)}, {big.NewInt(0), big.NewInt(0)}}, ""},
} {
- typ, err := NewType(test.typ)
+ typ, err := NewType(test.typ, "", test.components)
if err != nil && len(test.err) == 0 {
t.Fatal("unexpected parse error:", err)
} else if err != nil && len(test.err) != 0 {
@@ -278,3 +307,63 @@ func TestTypeCheck(t *testing.T) {
}
}
}
+
+func TestInternalType(t *testing.T) {
+ components := []ArgumentMarshaling{{Name: "a", Type: "int64"}}
+ internalType := "struct a.b[]"
+ kind := Type{
+ T: TupleTy,
+ TupleType: reflect.TypeOf(struct {
+ A int64 `json:"a"`
+ }{}),
+ stringKind: "(int64)",
+ TupleRawName: "ab[]",
+ TupleElems: []*Type{{T: IntTy, Size: 64, stringKind: "int64"}},
+ TupleRawNames: []string{"a"},
+ }
+
+ blob := "tuple"
+ typ, err := NewType(blob, internalType, components)
+ if err != nil {
+ t.Errorf("type %q: failed to parse type string: %v", blob, err)
+ }
+ if !reflect.DeepEqual(typ, kind) {
+ t.Errorf("type %q: parsed type mismatch:\nGOT %s\nWANT %s ", blob, spew.Sdump(typeWithoutStringer(typ)), spew.Sdump(typeWithoutStringer(kind)))
+ }
+}
+
+func TestGetTypeSize(t *testing.T) {
+ var testCases = []struct {
+ typ string
+ components []ArgumentMarshaling
+ typSize int
+ }{
+ // simple array
+ {"uint256[2]", nil, 32 * 2},
+ {"address[3]", nil, 32 * 3},
+ {"bytes32[4]", nil, 32 * 4},
+ // array array
+ {"uint256[2][3][4]", nil, 32 * (2 * 3 * 4)},
+ // array tuple
+ {"tuple[2]", []ArgumentMarshaling{{Name: "x", Type: "bytes32"}, {Name: "y", Type: "bytes32"}}, (32 * 2) * 2},
+ // simple tuple
+ {"tuple", []ArgumentMarshaling{{Name: "x", Type: "uint256"}, {Name: "y", Type: "uint256"}}, 32 * 2},
+ // tuple array
+ {"tuple", []ArgumentMarshaling{{Name: "x", Type: "bytes32[2]"}}, 32 * 2},
+ // tuple tuple
+ {"tuple", []ArgumentMarshaling{{Name: "x", Type: "tuple", Components: []ArgumentMarshaling{{Name: "x", Type: "bytes32"}}}}, 32},
+ {"tuple", []ArgumentMarshaling{{Name: "x", Type: "tuple", Components: []ArgumentMarshaling{{Name: "x", Type: "bytes32[2]"}, {Name: "y", Type: "uint256"}}}}, 32 * (2 + 1)},
+ }
+
+ for i, data := range testCases {
+ typ, err := NewType(data.typ, "", data.components)
+ if err != nil {
+ t.Errorf("type %q: failed to parse type string: %v", data.typ, err)
+ }
+
+ result := getTypeSize(typ)
+ if result != data.typSize {
+ t.Errorf("case %d type %q: get type size error: actual: %d expected: %d", i, data.typ, result, data.typSize)
+ }
+ }
+}
diff --git a/accounts/abi/unpack.go b/accounts/abi/unpack.go
index 208486349..927f43afe 100644
--- a/accounts/abi/unpack.go
+++ b/accounts/abi/unpack.go
@@ -25,31 +25,55 @@ import (
"github.com/tomochain/tomochain/common"
)
-// reads the integer based on its kind
-func readInteger(kind reflect.Kind, b []byte) interface{} {
- switch kind {
- case reflect.Uint8:
- return b[len(b)-1]
- case reflect.Uint16:
- return binary.BigEndian.Uint16(b[len(b)-2:])
- case reflect.Uint32:
- return binary.BigEndian.Uint32(b[len(b)-4:])
- case reflect.Uint64:
- return binary.BigEndian.Uint64(b[len(b)-8:])
- case reflect.Int8:
+var (
+ // MaxUint256 is the maximum value that can be represented by a uint256.
+ MaxUint256 = new(big.Int).Sub(new(big.Int).Lsh(common.Big1, 256), common.Big1)
+ // MaxInt256 is the maximum value that can be represented by an int256.
+ MaxInt256 = new(big.Int).Sub(new(big.Int).Lsh(common.Big1, 255), common.Big1)
+)
+
+// ReadInteger reads the integer based on its type and returns the appropriate value.
+func ReadInteger(typ Type, b []byte) interface{} {
+ if typ.T == UintTy {
+ switch typ.Size {
+ case 8:
+ return b[len(b)-1]
+ case 16:
+ return binary.BigEndian.Uint16(b[len(b)-2:])
+ case 32:
+ return binary.BigEndian.Uint32(b[len(b)-4:])
+ case 64:
+ return binary.BigEndian.Uint64(b[len(b)-8:])
+ default:
+ // the only case left for an unsigned integer is uint256.
+ return new(big.Int).SetBytes(b)
+ }
+ }
+ switch typ.Size {
+ case 8:
return int8(b[len(b)-1])
- case reflect.Int16:
+ case 16:
return int16(binary.BigEndian.Uint16(b[len(b)-2:]))
- case reflect.Int32:
+ case 32:
return int32(binary.BigEndian.Uint32(b[len(b)-4:]))
- case reflect.Int64:
+ case 64:
return int64(binary.BigEndian.Uint64(b[len(b)-8:]))
default:
- return new(big.Int).SetBytes(b)
+ // the only case left for a signed integer is int256.
+ // big.Int.SetBytes can't tell on its own whether a number is negative or positive.
+ // On the EVM, if the returned number > max int256, it is negative.
+ // A number is > max int256 if the bit at position 255 is set.
+ ret := new(big.Int).SetBytes(b)
+ if ret.Bit(255) == 1 {
+ ret.Add(MaxUint256, new(big.Int).Neg(ret))
+ ret.Add(ret, common.Big1)
+ ret.Neg(ret)
+ }
+ return ret
}
}
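
The sign reconstruction above can be checked with a self-contained sketch (illustrative, not part of this diff): 32 bytes of 0xff decode to int256(-1).

    package main

    import (
    	"fmt"
    	"math/big"
    	"strings"
    )

    func main() {
    	// A word with the sign bit (bit 255) set is a negative int256.
    	ret, _ := new(big.Int).SetString(strings.Repeat("ff", 32), 16)
    	maxUint256 := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 256), big.NewInt(1))

    	if ret.Bit(255) == 1 {
    		// Same steps as ReadInteger: compute -(2^256 - ret).
    		ret.Add(maxUint256, new(big.Int).Neg(ret))
    		ret.Add(ret, big.NewInt(1))
    		ret.Neg(ret)
    	}
    	fmt.Println(ret) // -1
    }
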
-// reads a bool
+// readBool reads a bool.
func readBool(word []byte) (bool, error) {
for _, b := range word[:31] {
if b != 0 {
@@ -67,7 +91,8 @@ func readBool(word []byte) (bool, error) {
}
// A function type is simply the address with the function selection signature at the end.
-// This enforces that standard by always presenting it as a 24-array (address + sig = 24 bytes)
+//
+// readFunctionType enforces that standard by always presenting it as a 24-byte array (address + sig = 24 bytes)
func readFunctionType(t Type, word []byte) (funcTy [24]byte, err error) {
if t.T != FunctionTy {
return [24]byte{}, fmt.Errorf("abi: invalid type in call to make function type byte array")
@@ -80,31 +105,20 @@ func readFunctionType(t Type, word []byte) (funcTy [24]byte, err error) {
return
}
-// through reflection, creates a fixed array to be read from
-func readFixedBytes(t Type, word []byte) (interface{}, error) {
+// ReadFixedBytes uses reflection to create a fixed array to be read from.
+func ReadFixedBytes(t Type, word []byte) (interface{}, error) {
if t.T != FixedBytesTy {
return nil, fmt.Errorf("abi: invalid type in call to make fixed byte array")
}
// convert
- array := reflect.New(t.Type).Elem()
+ array := reflect.New(t.GetType()).Elem()
reflect.Copy(array, reflect.ValueOf(word[0:t.Size]))
return array.Interface(), nil
}
-func getFullElemSize(elem *Type) int {
- //all other should be counted as 32 (slices have pointers to respective elements)
- size := 32
- //arrays wrap it, each element being the same size
- for elem.T == ArrayTy {
- size *= elem.Size
- elem = elem.Elem
- }
- return size
-}
-
-// iteratively unpack elements
+// forEachUnpack iteratively unpacks elements.
func forEachUnpack(t Type, output []byte, start, size int) (interface{}, error) {
if size < 0 {
return nil, fmt.Errorf("cannot marshal input to array, size is negative (%d)", size)
@@ -118,23 +132,19 @@ func forEachUnpack(t Type, output []byte, start, size int) (interface{}, error)
if t.T == SliceTy {
// declare our slice
- refSlice = reflect.MakeSlice(t.Type, size, size)
+ refSlice = reflect.MakeSlice(t.GetType(), size, size)
} else if t.T == ArrayTy {
// declare our array
- refSlice = reflect.New(t.Type).Elem()
+ refSlice = reflect.New(t.GetType()).Elem()
} else {
return nil, fmt.Errorf("abi: invalid type in array/slice unpacking stage")
}
// Arrays have packed elements, resulting in longer unpack steps.
// Slices have just 32 bytes per element (pointing to the contents).
- elemSize := 32
- if t.T == ArrayTy {
- elemSize = getFullElemSize(t.Elem)
- }
+ elemSize := getTypeSize(*t.Elem)
for i, j := start, 0; j < size; i, j = i+elemSize, j+1 {
-
inter, err := toGoType(i, *t.Elem, output)
if err != nil {
return nil, err
@@ -148,6 +158,36 @@ func forEachUnpack(t Type, output []byte, start, size int) (interface{}, error)
return refSlice.Interface(), nil
}
+func forTupleUnpack(t Type, output []byte) (interface{}, error) {
+ retval := reflect.New(t.GetType()).Elem()
+ virtualArgs := 0
+ for index, elem := range t.TupleElems {
+ marshalledValue, err := toGoType((index+virtualArgs)*32, *elem, output)
+ if elem.T == ArrayTy && !isDynamicType(*elem) {
+ // If we have a static array, like [3]uint256, these are coded as
+ // just like uint256,uint256,uint256.
+ // This means that we need to add two 'virtual' arguments when
+ // we count the index from now on.
+ //
+ // Array values nested multiple levels deep are also encoded inline:
+ // [2][3]uint256: uint256,uint256,uint256,uint256,uint256,uint256
+ //
+ // Calculate the full array size to get the correct offset for the next argument.
+ // Decrement it by 1, as the normal index increment is still applied.
+ virtualArgs += getTypeSize(*elem)/32 - 1
+ } else if elem.T == TupleTy && !isDynamicType(*elem) {
+ // If we have a static tuple, like (uint256, bool, uint256), these are
+ // coded as just like uint256,bool,uint256
+ virtualArgs += getTypeSize(*elem)/32 - 1
+ }
+ if err != nil {
+ return nil, err
+ }
+ retval.Field(index).Set(reflect.ValueOf(marshalledValue))
+ }
+ return retval.Interface(), nil
+}
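
A worked layout for the virtual-argument bookkeeping above (illustrative): unpacking the static tuple (uint256 a, uint256[3] b, uint256 c), whose array elements are inlined:

    word 0       a        index 0, read at (0+0)*32
    words 1..3   b[0..2]  index 1, read at (1+0)*32; afterwards virtualArgs += 3 - 1 = 2
    word 4       c        index 2, read at (2+2)*32

Without the virtualArgs adjustment, c would be read from word 2, i.e. from the middle of b.
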
+
// toGoType parses the output bytes and recursively assigns the value of these bytes
// into a go type with accordance with the ABI spec.
func toGoType(index int, t Type, output []byte) (interface{}, error) {
@@ -156,14 +196,14 @@ func toGoType(index int, t Type, output []byte) (interface{}, error) {
}
var (
- returnOutput []byte
- begin, end int
- err error
+ returnOutput []byte
+ begin, length int
+ err error
)
// if we require a length prefix, find the beginning word and size returned.
if t.requiresLengthPrefix() {
- begin, end, err = lengthPrefixPointsTo(index, output)
+ begin, length, err = lengthPrefixPointsTo(index, output)
if err != nil {
return nil, err
}
@@ -172,14 +212,30 @@ func toGoType(index int, t Type, output []byte) (interface{}, error) {
}
switch t.T {
+ case TupleTy:
+ if isDynamicType(t) {
+ begin, err := tuplePointsTo(index, output)
+ if err != nil {
+ return nil, err
+ }
+ return forTupleUnpack(t, output[begin:])
+ }
+ return forTupleUnpack(t, output[index:])
case SliceTy:
- return forEachUnpack(t, output, begin, end)
+ return forEachUnpack(t, output[begin:], 0, length)
case ArrayTy:
- return forEachUnpack(t, output, index, t.Size)
+ if isDynamicType(*t.Elem) {
+ offset := binary.BigEndian.Uint64(returnOutput[len(returnOutput)-8:])
+ if offset > uint64(len(output)) {
+ return nil, fmt.Errorf("abi: toGoType offset greater than output length: offset: %d, len(output): %d", offset, len(output))
+ }
+ return forEachUnpack(t, output[offset:], 0, t.Size)
+ }
+ return forEachUnpack(t, output[index:], 0, t.Size)
case StringTy: // variable arrays are written at the end of the return bytes
- return string(output[begin : begin+end]), nil
+ return string(output[begin : begin+length]), nil
case IntTy, UintTy:
- return readInteger(t.Kind, returnOutput), nil
+ return ReadInteger(t, returnOutput), nil
case BoolTy:
return readBool(returnOutput)
case AddressTy:
@@ -187,9 +243,9 @@ func toGoType(index int, t Type, output []byte) (interface{}, error) {
case HashTy:
return common.BytesToHash(returnOutput), nil
case BytesTy:
- return output[begin : begin+end], nil
+ return output[begin : begin+length], nil
case FixedBytesTy:
- return readFixedBytes(t, returnOutput)
+ return ReadFixedBytes(t, returnOutput)
case FunctionTy:
return readFunctionType(t, returnOutput)
default:
@@ -197,7 +253,7 @@ func toGoType(index int, t Type, output []byte) (interface{}, error) {
}
}
-// interprets a 32 byte slice as an offset and then determines which indice to look to decode the type.
+// lengthPrefixPointsTo interprets a 32-byte slice as an offset and then determines which indices to read when decoding the type.
func lengthPrefixPointsTo(index int, output []byte) (start int, length int, err error) {
bigOffsetEnd := big.NewInt(0).SetBytes(output[index : index+32])
bigOffsetEnd.Add(bigOffsetEnd, common.Big32)
@@ -218,7 +274,7 @@ func lengthPrefixPointsTo(index int, output []byte) (start int, length int, err
totalSize.Add(totalSize, bigOffsetEnd)
totalSize.Add(totalSize, lengthBig)
if totalSize.BitLen() > 63 {
- return 0, 0, fmt.Errorf("abi length larger than int64: %v", totalSize)
+ return 0, 0, fmt.Errorf("abi: length larger than int64: %v", totalSize)
}
if totalSize.Cmp(outputLength) > 0 {
@@ -228,3 +284,17 @@ func lengthPrefixPointsTo(index int, output []byte) (start int, length int, err
length = int(lengthBig.Uint64())
return
}
+
+// tuplePointsTo resolves the location reference for a dynamic tuple.
+func tuplePointsTo(index int, output []byte) (start int, err error) {
+ offset := big.NewInt(0).SetBytes(output[index : index+32])
+ outputLen := big.NewInt(int64(len(output)))
+
+ if offset.Cmp(outputLen) > 0 {
+ return 0, fmt.Errorf("abi: cannot marshal in to go slice: offset %v would go over slice boundary (len=%v)", offset, outputLen)
+ }
+ if offset.BitLen() > 63 {
+ return 0, fmt.Errorf("abi: offset larger than int64: %v", offset)
+ }
+ return int(offset.Uint64()), nil
+}
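
To see both offset hops in one place, consider a method returning the dynamic tuple (uint256 a, string b); its return data is laid out as follows (illustrative):

    word 0   0x20     top-level offset to the tuple content (resolved by tuplePointsTo)
    word 1   a        tuple head: static field encoded in place
    word 2   0x40     tuple head: offset of b, relative to the tuple start (word 1)
    word 3   len(b)   string length prefix (resolved by lengthPrefixPointsTo)
    word 4   b ...    string bytes, right-padded to 32 bytes

toGoType first follows word 0 to the tuple, then forTupleUnpack follows word 2 to the string's length-prefixed tail.
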
diff --git a/accounts/abi/unpack_test.go b/accounts/abi/unpack_test.go
index 24b5dc0f1..9af0a666f 100644
--- a/accounts/abi/unpack_test.go
+++ b/accounts/abi/unpack_test.go
@@ -27,9 +27,36 @@ import (
"testing"
"github.com/stretchr/testify/require"
+
"github.com/tomochain/tomochain/common"
)
+// TestUnpack runs the general pack/unpack test cases defined in packing_test.go.
+func TestUnpack(t *testing.T) {
+ for i, test := range packUnpackTests {
+ t.Run(strconv.Itoa(i)+" "+test.def, func(t *testing.T) {
+ // Unpack
+ def := fmt.Sprintf(`[{ "name" : "method", "type": "function", "outputs": %s}]`, test.def)
+ abi, err := JSON(strings.NewReader(def))
+ if err != nil {
+ t.Fatalf("invalid ABI definition %s: %v", def, err)
+ }
+ encb, err := hex.DecodeString(test.packed)
+ if err != nil {
+ t.Fatalf("invalid hex %s: %v", test.packed, err)
+ }
+ out, err := abi.Unpack("method", encb)
+ if err != nil {
+ t.Errorf("test %d (%v) failed: %v", i, test.def, err)
+ return
+ }
+ if !reflect.DeepEqual(test.unpacked, ConvertType(out[0], test.unpacked)) {
+ t.Errorf("test %d (%v) failed: expected %v, got %v", i, test.def, test.unpacked, out[0])
+ }
+ })
+ }
+}
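
ConvertType, as used above, copies a decoded interface{} value into the Go type of a prototype value so the two can be compared or assigned directly. A hedged usage sketch, with parsedABI and returnData as hypothetical placeholders:

    out, err := parsedABI.Unpack("method", returnData) // out is []interface{}
    if err != nil {
    	// handle the decode failure
    }
    // Convert the first output into a concrete *big.Int.
    val := ConvertType(out[0], new(big.Int)).(*big.Int)
    _ = val
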
+
type unpackTest struct {
def string // ABI definition JSON
enc string // evm return data
@@ -51,16 +78,20 @@ func (test unpackTest) checkError(err error) error {
}
var unpackTests = []unpackTest{
+ // Bools
{
def: `[{ "type": "bool" }]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000001",
- want: true,
+ enc: "0000000000000000000000000000000000000000000000000001000000000001",
+ want: false,
+ err: "abi: improperly encoded boolean value",
},
{
- def: `[{"type": "uint32"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000001",
- want: uint32(1),
+ def: `[{ "type": "bool" }]`,
+ enc: "0000000000000000000000000000000000000000000000000000000000000003",
+ want: false,
+ err: "abi: improperly encoded boolean value",
},
+ // Integers
{
def: `[{"type": "uint32"}]`,
enc: "0000000000000000000000000000000000000000000000000000000000000001",
@@ -73,16 +104,6 @@ var unpackTests = []unpackTest{
want: uint16(0),
err: "abi: cannot unmarshal *big.Int in to uint16",
},
- {
- def: `[{"type": "uint17"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000001",
- want: big.NewInt(1),
- },
- {
- def: `[{"type": "int32"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000001",
- want: int32(1),
- },
{
def: `[{"type": "int32"}]`,
enc: "0000000000000000000000000000000000000000000000000000000000000001",
@@ -95,31 +116,10 @@ var unpackTests = []unpackTest{
want: int16(0),
err: "abi: cannot unmarshal *big.Int in to int16",
},
- {
- def: `[{"type": "int17"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000001",
- want: big.NewInt(1),
- },
- {
- def: `[{"type": "address"}]`,
- enc: "0000000000000000000000000100000000000000000000000000000000000000",
- want: common.Address{1},
- },
- {
- def: `[{"type": "bytes32"}]`,
- enc: "0100000000000000000000000000000000000000000000000000000000000000",
- want: [32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
- },
- {
- def: `[{"type": "bytes"}]`,
- enc: "000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000200100000000000000000000000000000000000000000000000000000000000000",
- want: common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"),
- },
{
def: `[{"type": "bytes"}]`,
enc: "000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000200100000000000000000000000000000000000000000000000000000000000000",
- want: [32]byte{},
- err: "abi: cannot unmarshal []uint8 in to [32]uint8",
+ want: [32]byte{1},
},
{
def: `[{"type": "bytes32"}]`,
@@ -128,150 +128,21 @@ var unpackTests = []unpackTest{
err: "abi: cannot unmarshal [32]uint8 in to []uint8",
},
{
- def: `[{"type": "bytes32"}]`,
- enc: "0100000000000000000000000000000000000000000000000000000000000000",
- want: [32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
- },
- {
- def: `[{"type": "function"}]`,
- enc: "0100000000000000000000000000000000000000000000000000000000000000",
- want: [24]byte{1},
- },
- // slices
- {
- def: `[{"type": "uint8[]"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: []uint8{1, 2},
- },
- {
- def: `[{"type": "uint8[2]"}]`,
- enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: [2]uint8{1, 2},
- },
- // multi dimensional, if these pass, all types that don't require length prefix should pass
- {
- def: `[{"type": "uint8[][]"}]`,
- enc: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000008000000000000000000000000000000000000000000000000000000000000000E0000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: [][]uint8{{1, 2}, {1, 2}},
- },
- {
- def: `[{"type": "uint8[2][2]"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: [2][2]uint8{{1, 2}, {1, 2}},
- },
- {
- def: `[{"type": "uint8[][2]"}]`,
- enc: "000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000001",
- want: [2][]uint8{{1}, {1}},
- },
- {
- def: `[{"type": "uint8[2][]"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: [][2]uint8{{1, 2}},
- },
- {
- def: `[{"type": "uint16[]"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: []uint16{1, 2},
- },
- {
- def: `[{"type": "uint16[2]"}]`,
- enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: [2]uint16{1, 2},
- },
- {
- def: `[{"type": "uint32[]"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: []uint32{1, 2},
- },
- {
- def: `[{"type": "uint32[2]"}]`,
- enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: [2]uint32{1, 2},
- },
- {
- def: `[{"type": "uint32[2][3][4]"}]`,
- enc: "000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000000050000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000700000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000d000000000000000000000000000000000000000000000000000000000000000e000000000000000000000000000000000000000000000000000000000000000f000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000110000000000000000000000000000000000000000000000000000000000000012000000000000000000000000000000000000000000000000000000000000001300000000000000000000000000000000000000000000000000000000000000140000000000000000000000000000000000000000000000000000000000000015000000000000000000000000000000000000000000000000000000000000001600000000000000000000000000000000000000000000000000000000000000170000000000000000000000000000000000000000000000000000000000000018",
- want: [4][3][2]uint32{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}, {{13, 14}, {15, 16}, {17, 18}}, {{19, 20}, {21, 22}, {23, 24}}},
- },
- {
- def: `[{"type": "uint64[]"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: []uint64{1, 2},
- },
- {
- def: `[{"type": "uint64[2]"}]`,
- enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: [2]uint64{1, 2},
- },
- {
- def: `[{"type": "uint256[]"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: []*big.Int{big.NewInt(1), big.NewInt(2)},
- },
- {
- def: `[{"type": "uint256[3]"}]`,
- enc: "000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000003",
- want: [3]*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3)},
- },
- {
- def: `[{"type": "int8[]"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: []int8{1, 2},
- },
- {
- def: `[{"type": "int8[2]"}]`,
- enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: [2]int8{1, 2},
- },
- {
- def: `[{"type": "int16[]"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: []int16{1, 2},
- },
- {
- def: `[{"type": "int16[2]"}]`,
- enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: [2]int16{1, 2},
- },
- {
- def: `[{"type": "int32[]"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: []int32{1, 2},
- },
- {
- def: `[{"type": "int32[2]"}]`,
- enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: [2]int32{1, 2},
- },
- {
- def: `[{"type": "int64[]"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: []int64{1, 2},
- },
- {
- def: `[{"type": "int64[2]"}]`,
- enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: [2]int64{1, 2},
- },
- {
- def: `[{"type": "int256[]"}]`,
- enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
- want: []*big.Int{big.NewInt(1), big.NewInt(2)},
- },
- {
- def: `[{"type": "int256[3]"}]`,
- enc: "000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000003",
- want: [3]*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3)},
+ def: `[{"name":"___","type":"int256"}]`,
+ enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
+ want: struct {
+ IntOne *big.Int
+ Intone *big.Int
+ }{IntOne: big.NewInt(1)},
},
- // struct outputs
{
- def: `[{"name":"int1","type":"int256"},{"name":"int2","type":"int256"}]`,
+ def: `[{"name":"int_one","type":"int256"},{"name":"IntOne","type":"int256"}]`,
enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
want: struct {
Int1 *big.Int
Int2 *big.Int
- }{big.NewInt(1), big.NewInt(2)},
+ }{},
+ err: "abi: multiple outputs mapping to the same struct field 'IntOne'",
},
{
def: `[{"name":"int","type":"int256"},{"name":"Int","type":"int256"}]`,
@@ -309,22 +180,47 @@ var unpackTests = []unpackTest{
}{},
err: "abi: purely underscored output cannot unpack to struct",
},
+ // Make sure only the first argument is consumed
+ {
+ def: `[{"name":"int_one","type":"int256"}]`,
+ enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
+ want: struct {
+ IntOne *big.Int
+ }{big.NewInt(1)},
+ },
+ {
+ def: `[{"name":"int__one","type":"int256"}]`,
+ enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
+ want: struct {
+ IntOne *big.Int
+ }{big.NewInt(1)},
+ },
+ {
+ def: `[{"name":"int_one_","type":"int256"}]`,
+ enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
+ want: struct {
+ IntOne *big.Int
+ }{big.NewInt(1)},
+ },
}
-func TestUnpack(t *testing.T) {
+// TestLocalUnpackTests runs tests specially designed only for unpacking.
+// All test cases that can be used to test both packing and unpacking should be moved to packing_test.go
+func TestLocalUnpackTests(t *testing.T) {
for i, test := range unpackTests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
- def := fmt.Sprintf(`[{ "name" : "method", "outputs": %s}]`, test.def)
+ // Unpack
+ def := fmt.Sprintf(`[{ "name" : "method", "type": "function", "outputs": %s}]`, test.def)
abi, err := JSON(strings.NewReader(def))
if err != nil {
t.Fatalf("invalid ABI definition %s: %v", def, err)
}
encb, err := hex.DecodeString(test.enc)
if err != nil {
- t.Fatalf("invalid hex: %s" + test.enc)
+ t.Fatalf("invalid hex %s: %v", test.enc, err)
}
outptr := reflect.New(reflect.TypeOf(test.want))
- err = abi.Unpack(outptr.Interface(), "method", encb)
+ err = abi.UnpackIntoInterface(outptr.Interface(), "method", encb)
if err := test.checkError(err); err != nil {
t.Errorf("test %d (%v) failed: %v", i, test.def, err)
return
@@ -337,6 +233,55 @@ func TestUnpack(t *testing.T) {
}
}
+func TestUnpackIntoInterfaceSetDynamicArrayOutput(t *testing.T) {
+ abi, err := JSON(strings.NewReader(`[{"constant":true,"inputs":[],"name":"testDynamicFixedBytes15","outputs":[{"name":"","type":"bytes15[]"}],"payable":false,"stateMutability":"view","type":"function"},{"constant":true,"inputs":[],"name":"testDynamicFixedBytes32","outputs":[{"name":"","type":"bytes32[]"}],"payable":false,"stateMutability":"view","type":"function"}]`))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var (
+ marshalledReturn32 = common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000230783132333435363738393000000000000000000000000000000000000000003078303938373635343332310000000000000000000000000000000000000000")
+ marshalledReturn15 = common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000230783031323334350000000000000000000000000000000000000000000000003078393837363534000000000000000000000000000000000000000000000000")
+
+ out32 [][32]byte
+ out15 [][15]byte
+ )
+
+ // test 32
+ err = abi.UnpackIntoInterface(&out32, "testDynamicFixedBytes32", marshalledReturn32)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(out32) != 2 {
+ t.Fatalf("expected array with 2 values, got %d", len(out32))
+ }
+ expected := common.Hex2Bytes("3078313233343536373839300000000000000000000000000000000000000000")
+ if !bytes.Equal(out32[0][:], expected) {
+ t.Errorf("expected %x, got %x\n", expected, out32[0])
+ }
+ expected = common.Hex2Bytes("3078303938373635343332310000000000000000000000000000000000000000")
+ if !bytes.Equal(out32[1][:], expected) {
+ t.Errorf("expected %x, got %x\n", expected, out32[1])
+ }
+
+ // test 15
+ err = abi.UnpackIntoInterface(&out15, "testDynamicFixedBytes15", marshalledReturn15)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(out15) != 2 {
+ t.Fatalf("expected array with 2 values, got %d", len(out15))
+ }
+ expected = common.Hex2Bytes("307830313233343500000000000000")
+ if !bytes.Equal(out15[0][:], expected) {
+ t.Errorf("expected %x, got %x\n", expected, out15[0])
+ }
+ expected = common.Hex2Bytes("307839383736353400000000000000")
+ if !bytes.Equal(out15[1][:], expected) {
+ t.Errorf("expected %x, got %x\n", expected, out15[1])
+ }
+}
+
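Reviewer note: the Unpack -> UnpackIntoInterface renames running through the rest of this file are mechanical. A minimal before/after sketch, assuming some parsed ABI with a balanceOf(address) returns (uint256) method (names hypothetical, not part of this diff):

	var balance *big.Int
	// before: err := parsed.Unpack(&balance, "balanceOf", ret)
	err := parsed.UnpackIntoInterface(&balance, "balanceOf", ret)
	// or, when no typed container is needed, the new two-argument form:
	out, err2 := parsed.Unpack("balanceOf", ret) // []interface{} holding one *big.Int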
type methodMultiOutput struct {
Int *big.Int
String string
@@ -344,7 +289,7 @@ type methodMultiOutput struct {
func methodMultiReturn(require *require.Assertions) (ABI, []byte, methodMultiOutput) {
const definition = `[
- { "name" : "multi", "constant" : false, "outputs": [ { "name": "Int", "type": "uint256" }, { "name": "String", "type": "string" } ] }]`
+ { "name" : "multi", "type": "function", "outputs": [ { "name": "Int", "type": "uint256" }, { "name": "String", "type": "string" } ] }]`
var expected = methodMultiOutput{big.NewInt(1), "hello"}
abi, err := JSON(strings.NewReader(definition))
@@ -364,6 +309,11 @@ func TestMethodMultiReturn(t *testing.T) {
Int *big.Int
}
+ newInterfaceSlice := func(len int) interface{} {
+ slice := make([]interface{}, len)
+ return &slice
+ }
+
abi, data, expected := methodMultiReturn(require.New(t))
bigint := new(big.Int)
var testCases = []struct {
@@ -391,6 +341,16 @@ func TestMethodMultiReturn(t *testing.T) {
&[2]interface{}{&expected.Int, &expected.String},
"",
"Can unpack into an array",
+ }, {
+ &[2]interface{}{},
+ &[2]interface{}{expected.Int, expected.String},
+ "",
+ "Can unpack into interface array",
+ }, {
+ newInterfaceSlice(2),
+ &[]interface{}{expected.Int, expected.String},
+ "",
+ "Can unpack into interface slice",
}, {
&[]interface{}{new(int), new(int)},
&[]interface{}{&expected.Int, &expected.String},
@@ -399,14 +359,14 @@ func TestMethodMultiReturn(t *testing.T) {
}, {
&[]interface{}{new(int)},
&[]interface{}{},
- "abi: insufficient number of elements in the list/array for unpack, want 2, got 1",
+ "abi: insufficient number of arguments for unpack, want 2, got 1",
"Can not unpack into a slice with wrong types",
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
- err := abi.Unpack(tc.dest, "multi", data)
+ err := abi.UnpackIntoInterface(tc.dest, "multi", data)
if tc.error == "" {
require.Nil(err, "Should be able to unpack method outputs.")
require.Equal(tc.expected, tc.dest)
@@ -418,7 +378,7 @@ func TestMethodMultiReturn(t *testing.T) {
}
func TestMultiReturnWithArray(t *testing.T) {
- const definition = `[{"name" : "multi", "outputs": [{"type": "uint64[3]"}, {"type": "uint64"}]}]`
+ const definition = `[{"name" : "multi", "type": "function", "outputs": [{"type": "uint64[3]"}, {"type": "uint64"}]}]`
abi, err := JSON(strings.NewReader(definition))
if err != nil {
t.Fatal(err)
@@ -429,7 +389,7 @@ func TestMultiReturnWithArray(t *testing.T) {
ret1, ret1Exp := new([3]uint64), [3]uint64{9, 9, 9}
ret2, ret2Exp := new(uint64), uint64(8)
- if err := abi.Unpack(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil {
+ if err := abi.UnpackIntoInterface(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(*ret1, ret1Exp) {
@@ -440,12 +400,74 @@ func TestMultiReturnWithArray(t *testing.T) {
}
}
+func TestMultiReturnWithStringArray(t *testing.T) {
+ const definition = `[{"name" : "multi", "type": "function", "outputs": [{"name": "","type": "uint256[3]"},{"name": "","type": "address"},{"name": "","type": "string[2]"},{"name": "","type": "bool"}]}]`
+ abi, err := JSON(strings.NewReader(definition))
+ if err != nil {
+ t.Fatal(err)
+ }
+ buff := new(bytes.Buffer)
+ buff.Write(common.Hex2Bytes("000000000000000000000000000000000000000000000000000000005c1b78ea0000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000001a055690d9db80000000000000000000000000000ab1257528b3782fb40d7ed5f72e624b744dffb2f00000000000000000000000000000000000000000000000000000000000000c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000008457468657265756d000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001048656c6c6f2c20457468657265756d2100000000000000000000000000000000"))
+ temp, _ := big.NewInt(0).SetString("30000000000000000000", 10)
+ ret1, ret1Exp := new([3]*big.Int), [3]*big.Int{big.NewInt(1545304298), big.NewInt(6), temp}
+ ret2, ret2Exp := new(common.Address), common.HexToAddress("ab1257528b3782fb40d7ed5f72e624b744dffb2f")
+ ret3, ret3Exp := new([2]string), [2]string{"Ethereum", "Hello, Ethereum!"}
+ ret4, ret4Exp := new(bool), false
+ if err := abi.UnpackIntoInterface(&[]interface{}{ret1, ret2, ret3, ret4}, "multi", buff.Bytes()); err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(*ret1, ret1Exp) {
+ t.Error("big.Int array result", *ret1, "!= Expected", ret1Exp)
+ }
+ if !reflect.DeepEqual(*ret2, ret2Exp) {
+ t.Error("address result", *ret2, "!= Expected", ret2Exp)
+ }
+ if !reflect.DeepEqual(*ret3, ret3Exp) {
+ t.Error("string array result", *ret3, "!= Expected", ret3Exp)
+ }
+ if !reflect.DeepEqual(*ret4, ret4Exp) {
+ t.Error("bool result", *ret4, "!= Expected", ret4Exp)
+ }
+}
+
+func TestMultiReturnWithStringSlice(t *testing.T) {
+ const definition = `[{"name" : "multi", "type": "function", "outputs": [{"name": "","type": "string[]"},{"name": "","type": "uint256[]"}]}]`
+ abi, err := JSON(strings.NewReader(definition))
+ if err != nil {
+ t.Fatal(err)
+ }
+ buff := new(bytes.Buffer)
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) // output[0] offset
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000120")) // output[1] offset
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // output[0] length
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) // output[0][0] offset
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000080")) // output[0][1] offset
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000008")) // output[0][0] length
+ buff.Write(common.Hex2Bytes("657468657265756d000000000000000000000000000000000000000000000000")) // output[0][0] value
+ buff.Write(common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000b")) // output[0][1] length
+ buff.Write(common.Hex2Bytes("676f2d657468657265756d000000000000000000000000000000000000000000")) // output[0][1] value
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // output[1] length
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000064")) // output[1][0] value
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000065")) // output[1][1] value
+ ret1, ret1Exp := new([]string), []string{"ethereum", "go-ethereum"}
+ ret2, ret2Exp := new([]*big.Int), []*big.Int{big.NewInt(100), big.NewInt(101)}
+ if err := abi.UnpackIntoInterface(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(*ret1, ret1Exp) {
+ t.Error("string slice result", *ret1, "!= Expected", ret1Exp)
+ }
+ if !reflect.DeepEqual(*ret2, ret2Exp) {
+ t.Error("uint256 slice result", *ret2, "!= Expected", ret2Exp)
+ }
+}
+
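A note for anyone re-deriving the hand-built buffer above: with two dynamic outputs the head is 2 words (64 bytes), so output[0]'s tail starts at offset 0x40, matching the first word written. That tail spans 7 words (array length, two element offsets, then a length word plus one data word for each of the two strings), i.e. 224 bytes, which puts output[1]'s tail at 64 + 224 = 288 = 0x120, exactly the second offset word.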
func TestMultiReturnWithDeeplyNestedArray(t *testing.T) {
// Similar to TestMultiReturnWithArray, but with a special case in mind:
// values of nested static arrays count towards the size as well, and any element following
// after such nested array argument should be read with the correct offset,
// so that it does not read content from the previous array argument.
- const definition = `[{"name" : "multi", "outputs": [{"type": "uint64[3][2][4]"}, {"type": "uint64"}]}]`
+ const definition = `[{"name" : "multi", "type": "function", "outputs": [{"type": "uint64[3][2][4]"}, {"type": "uint64"}]}]`
abi, err := JSON(strings.NewReader(definition))
if err != nil {
t.Fatal(err)
@@ -469,7 +491,7 @@ func TestMultiReturnWithDeeplyNestedArray(t *testing.T) {
{{0x411, 0x412, 0x413}, {0x421, 0x422, 0x423}},
}
ret2, ret2Exp := new(uint64), uint64(0x9876)
- if err := abi.Unpack(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil {
+ if err := abi.UnpackIntoInterface(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(*ret1, ret1Exp) {
@@ -482,15 +504,15 @@ func TestMultiReturnWithDeeplyNestedArray(t *testing.T) {
func TestUnmarshal(t *testing.T) {
const definition = `[
- { "name" : "int", "constant" : false, "outputs": [ { "type": "uint256" } ] },
- { "name" : "bool", "constant" : false, "outputs": [ { "type": "bool" } ] },
- { "name" : "bytes", "constant" : false, "outputs": [ { "type": "bytes" } ] },
- { "name" : "fixed", "constant" : false, "outputs": [ { "type": "bytes32" } ] },
- { "name" : "multi", "constant" : false, "outputs": [ { "type": "bytes" }, { "type": "bytes" } ] },
- { "name" : "intArraySingle", "constant" : false, "outputs": [ { "type": "uint256[3]" } ] },
- { "name" : "addressSliceSingle", "constant" : false, "outputs": [ { "type": "address[]" } ] },
- { "name" : "addressSliceDouble", "constant" : false, "outputs": [ { "name": "a", "type": "address[]" }, { "name": "b", "type": "address[]" } ] },
- { "name" : "mixedBytes", "constant" : true, "outputs": [ { "name": "a", "type": "bytes" }, { "name": "b", "type": "bytes32" } ] }]`
+ { "name" : "int", "type": "function", "outputs": [ { "type": "uint256" } ] },
+ { "name" : "bool", "type": "function", "outputs": [ { "type": "bool" } ] },
+ { "name" : "bytes", "type": "function", "outputs": [ { "type": "bytes" } ] },
+ { "name" : "fixed", "type": "function", "outputs": [ { "type": "bytes32" } ] },
+ { "name" : "multi", "type": "function", "outputs": [ { "type": "bytes" }, { "type": "bytes" } ] },
+ { "name" : "intArraySingle", "type": "function", "outputs": [ { "type": "uint256[3]" } ] },
+ { "name" : "addressSliceSingle", "type": "function", "outputs": [ { "type": "address[]" } ] },
+ { "name" : "addressSliceDouble", "type": "function", "outputs": [ { "name": "a", "type": "address[]" }, { "name": "b", "type": "address[]" } ] },
+ { "name" : "mixedBytes", "type": "function", "stateMutability" : "view", "outputs": [ { "name": "a", "type": "bytes" }, { "name": "b", "type": "bytes32" } ] }]`
abi, err := JSON(strings.NewReader(definition))
if err != nil {
@@ -508,7 +530,7 @@ func TestUnmarshal(t *testing.T) {
buff.Write(common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000a"))
buff.Write(common.Hex2Bytes("0102000000000000000000000000000000000000000000000000000000000000"))
- err = abi.Unpack(&mixedBytes, "mixedBytes", buff.Bytes())
+ err = abi.UnpackIntoInterface(&mixedBytes, "mixedBytes", buff.Bytes())
if err != nil {
t.Error(err)
} else {
@@ -523,7 +545,7 @@ func TestUnmarshal(t *testing.T) {
// marshal int
var Int *big.Int
- err = abi.Unpack(&Int, "int", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"))
+ err = abi.UnpackIntoInterface(&Int, "int", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"))
if err != nil {
t.Error(err)
}
@@ -534,7 +556,7 @@ func TestUnmarshal(t *testing.T) {
// marshal bool
var Bool bool
- err = abi.Unpack(&Bool, "bool", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"))
+ err = abi.UnpackIntoInterface(&Bool, "bool", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001"))
if err != nil {
t.Error(err)
}
@@ -551,7 +573,7 @@ func TestUnmarshal(t *testing.T) {
buff.Write(bytesOut)
var Bytes []byte
- err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
+ err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes())
if err != nil {
t.Error(err)
}
@@ -567,7 +589,7 @@ func TestUnmarshal(t *testing.T) {
bytesOut = common.RightPadBytes([]byte("hello"), 64)
buff.Write(bytesOut)
- err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
+ err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes())
if err != nil {
t.Error(err)
}
@@ -583,7 +605,7 @@ func TestUnmarshal(t *testing.T) {
bytesOut = common.RightPadBytes([]byte("hello"), 64)
buff.Write(bytesOut)
- err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
+ err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes())
if err != nil {
t.Error(err)
}
@@ -593,7 +615,7 @@ func TestUnmarshal(t *testing.T) {
}
// marshal dynamic bytes output empty
- err = abi.Unpack(&Bytes, "bytes", nil)
+ err = abi.UnpackIntoInterface(&Bytes, "bytes", nil)
if err == nil {
t.Error("expected error")
}
@@ -604,7 +626,7 @@ func TestUnmarshal(t *testing.T) {
buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005"))
buff.Write(common.RightPadBytes([]byte("hello"), 32))
- err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
+ err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes())
if err != nil {
t.Error(err)
}
@@ -618,7 +640,7 @@ func TestUnmarshal(t *testing.T) {
buff.Write(common.RightPadBytes([]byte("hello"), 32))
var hash common.Hash
- err = abi.Unpack(&hash, "fixed", buff.Bytes())
+ err = abi.UnpackIntoInterface(&hash, "fixed", buff.Bytes())
if err != nil {
t.Error(err)
}
@@ -631,12 +653,12 @@ func TestUnmarshal(t *testing.T) {
// marshal error
buff.Reset()
buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020"))
- err = abi.Unpack(&Bytes, "bytes", buff.Bytes())
+ err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes())
if err == nil {
t.Error("expected error")
}
- err = abi.Unpack(&Bytes, "multi", make([]byte, 64))
+ err = abi.UnpackIntoInterface(&Bytes, "multi", make([]byte, 64))
if err == nil {
t.Error("expected error")
}
@@ -647,7 +669,7 @@ func TestUnmarshal(t *testing.T) {
buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000003"))
// marshal int array
var intArray [3]*big.Int
- err = abi.Unpack(&intArray, "intArraySingle", buff.Bytes())
+ err = abi.UnpackIntoInterface(&intArray, "intArraySingle", buff.Bytes())
if err != nil {
t.Error(err)
}
@@ -668,7 +690,7 @@ func TestUnmarshal(t *testing.T) {
buff.Write(common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000"))
var outAddr []common.Address
- err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes())
+ err = abi.UnpackIntoInterface(&outAddr, "addressSliceSingle", buff.Bytes())
if err != nil {
t.Fatal("didn't expect error:", err)
}
@@ -695,7 +717,7 @@ func TestUnmarshal(t *testing.T) {
A []common.Address
B []common.Address
}
- err = abi.Unpack(&outAddrStruct, "addressSliceDouble", buff.Bytes())
+ err = abi.UnpackIntoInterface(&outAddrStruct, "addressSliceDouble", buff.Bytes())
if err != nil {
t.Fatal("didn't expect error:", err)
}
@@ -723,12 +745,114 @@ func TestUnmarshal(t *testing.T) {
buff.Reset()
buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000100"))
- err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes())
+ err = abi.UnpackIntoInterface(&outAddr, "addressSliceSingle", buff.Bytes())
if err == nil {
t.Fatal("expected error:", err)
}
}
+func TestUnpackTuple(t *testing.T) {
+ const simpleTuple = `[{"name":"tuple","type":"function","outputs":[{"type":"tuple","name":"ret","components":[{"type":"int256","name":"a"},{"type":"int256","name":"b"}]}]}]`
+ abi, err := JSON(strings.NewReader(simpleTuple))
+ if err != nil {
+ t.Fatal(err)
+ }
+ buff := new(bytes.Buffer)
+
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // ret[a] = 1
+ buff.Write(common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")) // ret[b] = -1
+
+ // If the result is a single tuple, use a struct as the return value container directly.
+ type v struct {
+ A *big.Int
+ B *big.Int
+ }
+ type r struct {
+ Result v
+ }
+ var ret0 = new(r)
+ err = abi.UnpackIntoInterface(ret0, "tuple", buff.Bytes())
+
+ if err != nil {
+ t.Error(err)
+ } else {
+ if ret0.Result.A.Cmp(big.NewInt(1)) != 0 {
+ t.Errorf("unexpected value unpacked: want %x, got %x", 1, ret0.Result.A)
+ }
+ if ret0.Result.B.Cmp(big.NewInt(-1)) != 0 {
+ t.Errorf("unexpected value unpacked: want %x, got %x", -1, ret0.Result.B)
+ }
+ }
+
+ // Test nested tuple
+ const nestedTuple = `[{"name":"tuple","type":"function","outputs":[
+ {"type":"tuple","name":"s","components":[{"type":"uint256","name":"a"},{"type":"uint256[]","name":"b"},{"type":"tuple[]","name":"c","components":[{"name":"x", "type":"uint256"},{"name":"y","type":"uint256"}]}]},
+ {"type":"tuple","name":"t","components":[{"name":"x", "type":"uint256"},{"name":"y","type":"uint256"}]},
+ {"type":"uint256","name":"a"}
+ ]}]`
+
+ abi, err = JSON(strings.NewReader(nestedTuple))
+ if err != nil {
+ t.Fatal(err)
+ }
+ buff.Reset()
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000080")) // s offset
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000")) // t.X = 0
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // t.Y = 1
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // a = 1
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // s.A = 1
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000060")) // s.B offset
+ buff.Write(common.Hex2Bytes("00000000000000000000000000000000000000000000000000000000000000c0")) // s.C offset
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // s.B length
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // s.B[0] = 1
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // s.B[0] = 2
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // s.C length
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // s.C[0].X
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // s.C[0].Y
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // s.C[1].X
+ buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // s.C[1].Y
+
+ type T struct {
+ X *big.Int `abi:"x"`
+ Z *big.Int `abi:"y"` // Test whether the abi tag works.
+ }
+
+ type S struct {
+ A *big.Int
+ B []*big.Int
+ C []T
+ }
+
+ type Ret struct {
+ FieldS S `abi:"s"`
+ FieldT T `abi:"t"`
+ A *big.Int
+ }
+ var ret Ret
+ var expected = Ret{
+ FieldS: S{
+ A: big.NewInt(1),
+ B: []*big.Int{big.NewInt(1), big.NewInt(2)},
+ C: []T{
+ {big.NewInt(1), big.NewInt(2)},
+ {big.NewInt(2), big.NewInt(1)},
+ },
+ },
+ FieldT: T{
+ big.NewInt(0), big.NewInt(1),
+ },
+ A: big.NewInt(1),
+ }
+
+ err = abi.UnpackIntoInterface(&ret, "tuple", buff.Bytes())
+ if err != nil {
+ t.Error(err)
+ }
+ if !reflect.DeepEqual(ret, expected) {
+ t.Error("unexpected unpack value")
+ }
+}
+
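The abi:"y" tag exercised above is what lets an output component bind to a struct field with a different name. A minimal sketch, assuming a hypothetical single-tuple output getPoint() returns ((int256 x, int256 y)):

	type Point struct {
		X        *big.Int `abi:"x"`
		Vertical *big.Int `abi:"y"` // the tag, not the field name, selects the component
	}
	var out struct{ Ret Point }
	err := parsed.UnpackIntoInterface(&out, "getPoint", data)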
func TestOOMMaliciousInput(t *testing.T) {
oomTests := []unpackTest{
{
@@ -783,7 +907,7 @@ func TestOOMMaliciousInput(t *testing.T) {
},
}
for i, test := range oomTests {
- def := fmt.Sprintf(`[{ "name" : "method", "outputs": %s}]`, test.def)
+ def := fmt.Sprintf(`[{ "name" : "method", "type": "function", "outputs": %s}]`, test.def)
abi, err := JSON(strings.NewReader(def))
if err != nil {
t.Fatalf("invalid ABI definition %s: %v", def, err)
diff --git a/accounts/abi/utils.go b/accounts/abi/utils.go
new file mode 100644
index 000000000..f88d2ee2d
--- /dev/null
+++ b/accounts/abi/utils.go
@@ -0,0 +1,39 @@
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package abi
+
+import "fmt"
+
+// ResolveNameConflict returns the next available name for a given thing.
+// This helper can be used for many purposes:
+//
+// - Solidity supports function overloading, so this function can resolve
+//   the name conflicts of overloaded functions.
+// - In Go binding generation, parameter names (in function, event, error,
+//   and struct definitions) are converted to camel-case style, which
+//   may eventually lead to name conflicts.
+//
+// Name conflicts are mostly resolved by appending a number suffix. E.g. if the abi
+// already contains methods "send" and "send0", ResolveNameConflict returns "send1"
+// for input "send".
+func ResolveNameConflict(rawName string, used func(string) bool) string {
+ name := rawName
+ ok := used(name)
+ for idx := 0; ok; idx++ {
+ name = fmt.Sprintf("%s%d", rawName, idx)
+ ok = used(name)
+ }
+ return name
+}
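A quick illustration of the loop above; used stands in for whatever registry the caller keeps (hypothetical, not part of this file):

	used := map[string]bool{"send": true, "send0": true}
	next := abi.ResolveNameConflict("send", func(s string) bool { return used[s] })
	// next == "send1": "send" is taken, the loop then tries "send0" (taken) and "send1" (free).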
diff --git a/accounts/keystore/keystore_wallet.go b/accounts/keystore/keystore_wallet.go
index 01ffd75a8..91ac13878 100644
--- a/accounts/keystore/keystore_wallet.go
+++ b/accounts/keystore/keystore_wallet.go
@@ -90,7 +90,7 @@ func (w *keystoreWallet) SignHash(account accounts.Account, hash []byte) ([]byte
if account.URL != (accounts.URL{}) && account.URL != w.account.URL {
return nil, accounts.ErrUnknownAccount
}
- // Account seems valid, request the keystore to sign
+ // StateAccount seems valid, request the keystore to sign
return w.keystore.SignHash(account, hash)
}
@@ -106,7 +106,7 @@ func (w *keystoreWallet) SignTx(account accounts.Account, tx *types.Transaction,
if account.URL != (accounts.URL{}) && account.URL != w.account.URL {
return nil, accounts.ErrUnknownAccount
}
- // Account seems valid, request the keystore to sign
+ // StateAccount seems valid, request the keystore to sign
return w.keystore.SignTx(account, tx, chainID)
}
@@ -120,7 +120,7 @@ func (w *keystoreWallet) SignHashWithPassphrase(account accounts.Account, passph
if account.URL != (accounts.URL{}) && account.URL != w.account.URL {
return nil, accounts.ErrUnknownAccount
}
- // Account seems valid, request the keystore to sign
+ // StateAccount seems valid, request the keystore to sign
return w.keystore.SignHashWithPassphrase(account, passphrase, hash)
}
@@ -134,6 +134,6 @@ func (w *keystoreWallet) SignTxWithPassphrase(account accounts.Account, passphra
if account.URL != (accounts.URL{}) && account.URL != w.account.URL {
return nil, accounts.ErrUnknownAccount
}
- // Account seems valid, request the keystore to sign
+ // StateAccount seems valid, request the keystore to sign
return w.keystore.SignTxWithPassphrase(account, passphrase, tx, chainID)
}
diff --git a/accounts/usbwallet/wallet.go b/accounts/usbwallet/wallet.go
index d3cda1f21..2cb2ca2ae 100644
--- a/accounts/usbwallet/wallet.go
+++ b/accounts/usbwallet/wallet.go
@@ -319,7 +319,7 @@ func (w *wallet) selfDerive() {
// Termination requested
continue
case reqc = <-w.deriveReq:
- // Account discovery requested
+ // StateAccount discovery requested
}
// Derivation needs a chain and device access, skip if either unavailable
w.stateLock.RLock()
diff --git a/build/ci.go b/build/ci.go
index ea4481704..6af2b18af 100644
--- a/build/ci.go
+++ b/build/ci.go
@@ -14,6 +14,7 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+//go:build none
// +build none
/*
@@ -23,14 +24,13 @@ Usage: go run build/ci.go
Available commands are:
- install [ -arch architecture ] [ -cc compiler ] [ packages... ] -- builds packages and executables
- test [ -coverage ] [ packages... ] -- runs the tests
- lint -- runs certain pre-selected linters
- importkeys -- imports signing keys from env
- xgo [ -alltools ] [ options ] -- cross builds according to options
+ install [ -arch architecture ] [ -cc compiler ] [ packages... ] -- builds packages and executables
+ test [ -coverage ] [ packages... ] -- runs the tests
+ lint -- runs certain pre-selected linters
+ importkeys -- imports signing keys from env
+ xgo [ -alltools ] [ options ] -- cross builds according to options
For all commands, -n prevents execution of external programs (dry run mode).
-
*/
package main
@@ -62,6 +62,7 @@ var (
executablePath("rlpdump"),
executablePath("swarm"),
executablePath("wnode"),
+ executablePath("rlp/rlpgen"),
}
)
diff --git a/cmd/bootnode/main.go b/cmd/bootnode/main.go
index c582847c5..f3a827887 100644
--- a/cmd/bootnode/main.go
+++ b/cmd/bootnode/main.go
@@ -29,6 +29,7 @@ import (
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/p2p/discover"
"github.com/tomochain/tomochain/p2p/discv5"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/nat"
"github.com/tomochain/tomochain/p2p/netutil"
)
@@ -85,7 +86,7 @@ func main() {
}
if *writeAddr {
- fmt.Printf("%v\n", discover.PubkeyID(&nodeKey.PublicKey))
+ fmt.Printf("%v\n", enode.PubkeyToIDV4(&nodeKey.PublicKey))
os.Exit(0)
}
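The discover -> enode migration here repeats across the PR. A sketch of the new identity derivation, assuming the go-ethereum-style enode package this port pulls in (variable names invented):

	key, _ := crypto.GenerateKey()           // *ecdsa.PrivateKey
	id := enode.PubkeyToIDV4(&key.PublicKey) // v4 node ID derived from the public key
	fmt.Printf("%v\n", id)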
diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go
index 5d3b24289..75abb768c 100644
--- a/cmd/evm/runner.go
+++ b/cmd/evm/runner.go
@@ -20,24 +20,25 @@ import (
"bytes"
"encoding/json"
"fmt"
- "github.com/tomochain/tomochain/core/rawdb"
"io/ioutil"
"os"
+ goruntime "runtime"
"runtime/pprof"
"time"
- goruntime "runtime"
+ cli "gopkg.in/urfave/cli.v1"
"github.com/tomochain/tomochain/cmd/evm/internal/compiler"
"github.com/tomochain/tomochain/cmd/utils"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/vm"
"github.com/tomochain/tomochain/core/vm/runtime"
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/params"
- cli "gopkg.in/urfave/cli.v1"
+ "github.com/tomochain/tomochain/trie"
)
var runCommand = cli.Command{
@@ -83,6 +84,7 @@ func runCmd(ctx *cli.Context) error {
debugLogger *vm.StructLogger
statedb *state.StateDB
chainConfig *params.ChainConfig
+ preimages = ctx.Bool(DumpFlag.Name)
sender = common.StringToAddress("sender")
receiver = common.StringToAddress("receiver")
)
@@ -98,11 +100,11 @@ func runCmd(ctx *cli.Context) error {
gen := readGenesis(ctx.GlobalString(GenesisFlag.Name))
db := rawdb.NewMemoryDatabase()
genesis := gen.ToBlock(db)
- statedb, _ = state.New(genesis.Root(), state.NewDatabase(db))
+ statedb, _ = state.New(genesis.Root(), state.NewDatabaseWithConfig(db, &trie.Config{Preimages: preimages}), nil)
chainConfig = gen.Config
} else {
db := rawdb.NewMemoryDatabase()
- statedb, _ = state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ = state.New(common.Hash{}, state.NewDatabaseWithConfig(db, &trie.Config{Preimages: preimages}), nil)
}
if ctx.GlobalString(SenderFlag.Name) != "" {
sender = common.HexToAddress(ctx.GlobalString(SenderFlag.Name))
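Two API changes meet in the hunk above: state.NewDatabaseWithConfig threads a *trie.Config (here toggling preimage recording), and state.New takes a third argument for a snapshot tree. A sketch under the assumption that this port mirrors upstream go-ethereum, with db and root as placeholders:

	sdb := state.NewDatabaseWithConfig(db, &trie.Config{Preimages: true})
	statedb, err := state.New(root, sdb, nil) // nil snapshot tree: plain trie-backed reads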
diff --git a/cmd/evm/staterunner.go b/cmd/evm/staterunner.go
index 5499be696..018a7c526 100644
--- a/cmd/evm/staterunner.go
+++ b/cmd/evm/staterunner.go
@@ -94,7 +94,7 @@ func stateTestCmd(ctx *cli.Context) error {
for _, st := range test.Subtests() {
// Run the test and aggregate the result
result := &StatetestResult{Name: key, Fork: st.Fork, Pass: true}
- state, err := test.Run(st, cfg)
+ state, err := test.Run(st, cfg, false)
if err != nil {
// Test failed, mark as so and dump any state to aid debugging
result.Pass, result.Error = false, err.Error()
diff --git a/cmd/faucet/faucet.go b/cmd/faucet/faucet.go
index 6014f3c5a..33a17f35b 100644
--- a/cmd/faucet/faucet.go
+++ b/cmd/faucet/faucet.go
@@ -54,8 +54,8 @@ import (
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/node"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
"github.com/tomochain/tomochain/p2p/discv5"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/nat"
"github.com/tomochain/tomochain/params"
"golang.org/x/net/websocket"
@@ -200,7 +200,7 @@ type faucet struct {
index []byte // Index page to serve up on the web
keystore *keystore.KeyStore // Keystore containing the single signer
- account accounts.Account // Account funding user faucet requests
+ account accounts.Account // StateAccount funding user faucet requests
nonce uint64 // Current pending nonce of the faucet
price *big.Int // Current gas price to issue funds with
@@ -255,7 +255,7 @@ func newFaucet(genesis *core.Genesis, port int, enodes []*discv5.Node, network u
return nil, err
}
for _, boot := range enodes {
- old, _ := discover.ParseNode(boot.String())
+ old, _ := enode.ParseV4(boot.String())
stack.Server().AddPeer(old)
}
// Attach to the client and retrieve any interesting metadata
diff --git a/cmd/gc/main.go b/cmd/gc/main.go
index 567349ee4..8b3552dca 100644
--- a/cmd/gc/main.go
+++ b/cmd/gc/main.go
@@ -3,9 +3,6 @@ package main
import (
"flag"
"fmt"
- "github.com/tomochain/tomochain/core/rawdb"
- "github.com/tomochain/tomochain/ethdb"
- "github.com/tomochain/tomochain/ethdb/leveldb"
"os"
"os/signal"
"runtime"
@@ -13,13 +10,14 @@ import (
"sync/atomic"
"time"
- "github.com/hashicorp/golang-lru"
+ lru "github.com/hashicorp/golang-lru"
+
"github.com/tomochain/tomochain/cmd/utils"
"github.com/tomochain/tomochain/common"
- "github.com/tomochain/tomochain/core"
- "github.com/tomochain/tomochain/core/state"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/eth"
- "github.com/tomochain/tomochain/rlp"
+ "github.com/tomochain/tomochain/ethdb"
+ "github.com/tomochain/tomochain/ethdb/leveldb"
"github.com/tomochain/tomochain/trie"
)
@@ -54,15 +52,15 @@ func main() {
flag.Parse()
db, _ := leveldb.New(*dir, eth.DefaultConfig.DatabaseCache, utils.MakeDatabaseHandles(), "")
lddb := rawdb.NewDatabase(db)
- head := core.GetHeadBlockHash(lddb)
- currentHeader := core.GetHeader(lddb, head, core.GetBlockNumber(lddb, head))
+ head := rawdb.GetHeadBlockHash(lddb)
+ currentHeader := rawdb.GetHeader(lddb, head, rawdb.GetBlockNumber(lddb, head))
tridb := trie.NewDatabase(lddb)
catchEventInterupt(db)
cache, _ = lru.New(*cacheSize)
go func() {
for i := uint64(1); i <= currentHeader.Number.Uint64(); i++ {
- hash := core.GetCanonicalHash(lddb, i)
- root := core.GetHeader(lddb, hash, i).Root
+ hash := rawdb.GetCanonicalHash(lddb, i)
+ root := rawdb.GetHeader(lddb, hash, i).Root
trieRoot, err := trie.NewSecure(root, tridb)
if err != nil {
continue
@@ -81,9 +79,7 @@ func main() {
atomic.StoreInt32(&finish, 1)
if running {
for _, address := range cleanAddress {
- enc := trieRoot.trie.Get(address.Bytes())
- var data state.Account
- rlp.DecodeBytes(enc, &data)
+ data, _ := trieRoot.trie.GetAccount(address)
fmt.Println(time.Now().Format(time.RFC3339), "Start clean state address ", address.Hex(), " at block ", trieRoot.number)
signerRoot, err := resolveHash(data.Root[:], db)
if err != nil {
diff --git a/cmd/p2psim/main.go b/cmd/p2psim/main.go
index 7ae0b8b56..a39c5da3a 100644
--- a/cmd/p2psim/main.go
+++ b/cmd/p2psim/main.go
@@ -19,21 +19,20 @@
// Here is an example of creating a 2 node network with the first node
// connected to the second:
//
-// $ p2psim node create
-// Created node01
+// $ p2psim node create
+// Created node01
//
-// $ p2psim node start node01
-// Started node01
+// $ p2psim node start node01
+// Started node01
//
-// $ p2psim node create
-// Created node02
+// $ p2psim node create
+// Created node02
//
-// $ p2psim node start node02
-// Started node02
-//
-// $ p2psim node connect node01 node02
-// Connected node01 to node02
+// $ p2psim node start node02
+// Started node02
//
+// $ p2psim node connect node01 node02
+// Connected node01 to node02
package main
import (
@@ -47,7 +46,7 @@ import (
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/simulations"
"github.com/tomochain/tomochain/p2p/simulations/adapters"
"github.com/tomochain/tomochain/rpc"
@@ -283,7 +282,7 @@ func createNode(ctx *cli.Context) error {
if err != nil {
return err
}
- config.ID = discover.PubkeyID(&privKey.PublicKey)
+ config.ID = enode.PubkeyToIDV4(&privKey.PublicKey)
config.PrivateKey = privKey
}
if services := ctx.String("services"); services != "" {
diff --git a/cmd/swarm/config_test.go b/cmd/swarm/config_test.go
index 05b5eeb90..4a1e30db1 100644
--- a/cmd/swarm/config_test.go
+++ b/cmd/swarm/config_test.go
@@ -20,6 +20,7 @@ import (
"fmt"
"io"
"io/ioutil"
+ "net"
"os"
"os/exec"
"testing"
@@ -552,3 +553,16 @@ func TestValidateConfig(t *testing.T) {
}
}
}
+
+func assignTCPPort() (string, error) {
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ return "", err
+ }
+ l.Close()
+ _, port, err := net.SplitHostPort(l.Addr().String())
+ if err != nil {
+ return "", err
+ }
+ return port, nil
+}
diff --git a/cmd/swarm/main.go b/cmd/swarm/main.go
index ecd6aae79..83a2609df 100644
--- a/cmd/swarm/main.go
+++ b/cmd/swarm/main.go
@@ -39,7 +39,7 @@ import (
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/node"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/swarm"
bzzapi "github.com/tomochain/tomochain/swarm/api"
@@ -153,7 +153,7 @@ var (
}
)
-//declare a few constant error messages, useful for later error check comparisons in test
+// declare a few constant error messages, useful for later error check comparisons in test
var (
SWARM_ERR_NO_BZZACCOUNT = "bzzaccount option is required but not set; check your config file, command line or environment variables"
SWARM_ERR_SWAP_SET_NO_API = "SWAP is enabled but --swap-api is not set"
@@ -543,7 +543,7 @@ func getPassPhrase(prompt string, i int, passwords []string) string {
func injectBootnodes(srv *p2p.Server, nodes []string) {
for _, url := range nodes {
- n, err := discover.ParseNode(url)
+ n, err := enode.Parse(enode.ValidSchemes, url)
if err != nil {
log.Error("Invalid swarm bootnode", "err", err)
continue
diff --git a/cmd/swarm/run_test.go b/cmd/swarm/run_test.go
index 6c6d3d66e..3a53ef009 100644
--- a/cmd/swarm/run_test.go
+++ b/cmd/swarm/run_test.go
@@ -17,16 +17,22 @@
package main
import (
+ "context"
+ "crypto/ecdsa"
"fmt"
"io/ioutil"
"net"
"os"
+ "path"
"path/filepath"
"runtime"
+ "sync"
+ "syscall"
"testing"
"time"
"github.com/docker/docker/pkg/reexec"
+
"github.com/tomochain/tomochain/accounts"
"github.com/tomochain/tomochain/accounts/keystore"
"github.com/tomochain/tomochain/internal/cmdtest"
@@ -81,6 +87,7 @@ type testCluster struct {
//
// When starting more than one node, they are connected together using the
// admin SetPeer RPC method.
+
func newTestCluster(t *testing.T, size int) *testCluster {
cluster := &testCluster{}
defer func() {
@@ -96,27 +103,22 @@ func newTestCluster(t *testing.T, size int) *testCluster {
cluster.TmpDir = tmpdir
// start the nodes
- cluster.Nodes = make([]*testNode, 0, size)
- for i := 0; i < size; i++ {
- dir := filepath.Join(cluster.TmpDir, fmt.Sprintf("swarm%02d", i))
- if err := os.Mkdir(dir, 0700); err != nil {
- t.Fatal(err)
- }
-
- node := newTestNode(t, dir)
- node.Name = fmt.Sprintf("swarm%02d", i)
-
- cluster.Nodes = append(cluster.Nodes, node)
- }
+ cluster.StartNewNodes(t, size)
if size == 1 {
return cluster
}
// connect the nodes together
- for _, node := range cluster.Nodes {
- if err := node.Client.Call(nil, "admin_addPeer", cluster.Nodes[0].Enode); err != nil {
- t.Fatal(err)
+ for i, node := range cluster.Nodes {
+ // TODO(trinhdn2): only need to peer with cluster.Nodes[0], fix this later
+ for j := 0; j < size; j++ {
+ if i == j {
+ continue
+ }
+ if err := node.Client.Call(nil, "admin_addPeer", cluster.Nodes[j].Enode); err != nil {
+ t.Fatal(err)
+ }
}
}
@@ -145,14 +147,52 @@ func (c *testCluster) Shutdown() {
os.RemoveAll(c.TmpDir)
}
+func (c *testCluster) Stop() {
+ for _, node := range c.Nodes {
+ node.Shutdown()
+ }
+}
+
+func (c *testCluster) StartNewNodes(t *testing.T, size int) {
+ c.Nodes = make([]*testNode, 0, size)
+ for i := 0; i < size; i++ {
+ dir := filepath.Join(c.TmpDir, fmt.Sprintf("swarm%02d", i))
+ if err := os.Mkdir(dir, 0700); err != nil {
+ t.Fatal(err)
+ }
+
+ node := newTestNode(t, dir)
+ node.Name = fmt.Sprintf("swarm%02d", i)
+
+ c.Nodes = append(c.Nodes, node)
+ }
+}
+
+func (c *testCluster) StartExistingNodes(t *testing.T, size int, bzzaccount string) {
+ c.Nodes = make([]*testNode, 0, size)
+ for i := 0; i < size; i++ {
+ dir := filepath.Join(c.TmpDir, fmt.Sprintf("swarm%02d", i))
+ node := existingTestNode(t, dir, bzzaccount)
+ node.Name = fmt.Sprintf("swarm%02d", i)
+
+ c.Nodes = append(c.Nodes, node)
+ }
+}
+
+func (c *testCluster) Cleanup() {
+ os.RemoveAll(c.TmpDir)
+}
+
type testNode struct {
- Name string
- Addr string
- URL string
- Enode string
- Dir string
- Client *rpc.Client
- Cmd *cmdtest.TestCmd
+ Name string
+ Addr string
+ URL string
+ Enode string
+ Dir string
+ IpcPath string
+ PrivateKey *ecdsa.PrivateKey
+ Client *rpc.Client
+ Cmd *cmdtest.TestCmd
}
const testPassphrase = "swarm-test-passphrase"
@@ -181,24 +221,103 @@ func getTestAccount(t *testing.T, dir string) (conf *node.Config, account accoun
return conf, account
}
-func newTestNode(t *testing.T, dir string) *testNode {
-
- conf, account := getTestAccount(t, dir)
+func existingTestNode(t *testing.T, dir string, bzzaccount string) *testNode {
+ conf, _ := getTestAccount(t, dir)
node := &testNode{Dir: dir}
+ // use a unique IPCPath when running tests on Windows
+ if runtime.GOOS == "windows" {
+ conf.IPCPath = fmt.Sprintf("bzzd-%s.ipc", bzzaccount)
+ }
+
// assign ports
- httpPort, err := assignTCPPort()
+ ports, err := getAvailableTCPPorts(2)
if err != nil {
t.Fatal(err)
}
- p2pPort, err := assignTCPPort()
+ p2pPort := ports[0]
+ httpPort := ports[1]
+
+ // start the node
+ node.Cmd = runSwarm(t,
+ "--port", p2pPort,
+ "--nat", "extip:127.0.0.1",
+ "--nodiscover",
+ "--datadir", dir,
+ "--ipcpath", conf.IPCPath,
+ "--ens-api", "",
+ "--bzzaccount", bzzaccount,
+ "--bzznetworkid", "321",
+ "--bzzport", httpPort,
+ "--verbosity", "3",
+ )
+ node.Cmd.InputLine(testPassphrase)
+ defer func() {
+ if t.Failed() {
+ node.Shutdown()
+ }
+ }()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ // ensure that all ports have active listeners
+ // so that the next node will not get the same ones
+ // when calling getAvailableTCPPorts
+ err = waitTCPPorts(ctx, ports...)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // wait for the node to start
+ for start := time.Now(); time.Since(start) < 10*time.Second; time.Sleep(50 * time.Millisecond) {
+ node.Client, err = rpc.Dial(conf.IPCEndpoint())
+ if err == nil {
+ break
+ }
+ }
+ if node.Client == nil {
+ t.Fatal(err)
+ }
+
+ // load info
+ var info swarm.Info
+ if err := node.Client.Call(&info, "bzz_info"); err != nil {
+ t.Fatal(err)
+ }
+ node.Addr = net.JoinHostPort("127.0.0.1", info.Port)
+ node.URL = "http://" + node.Addr
+
+ var nodeInfo p2p.NodeInfo
+ if err := node.Client.Call(&nodeInfo, "admin_nodeInfo"); err != nil {
+ t.Fatal(err)
+ }
+ node.Enode = nodeInfo.Enode
+ node.IpcPath = conf.IPCPath
+ return node
+}
+
+func newTestNode(t *testing.T, dir string) *testNode {
+
+ conf, account := getTestAccount(t, dir)
+ ks := keystore.NewKeyStore(path.Join(dir, "keystore"), 1<<18, 1)
+
+ pk := decryptStoreAccount(ks, account.Address.Hex(), []string{testPassphrase})
+
+ node := &testNode{Dir: dir, PrivateKey: pk}
+
+ // assign ports
+ ports, err := getAvailableTCPPorts(2)
if err != nil {
t.Fatal(err)
}
+ p2pPort := ports[0]
+ httpPort := ports[1]
// start the node
node.Cmd = runSwarm(t,
"--port", p2pPort,
+ "--nat", "extip:127.0.0.1",
"--nodiscover",
"--datadir", dir,
"--ipcpath", conf.IPCPath,
@@ -206,7 +325,7 @@ func newTestNode(t *testing.T, dir string) *testNode {
"--bzzaccount", account.Address.String(),
"--bzznetworkid", "321",
"--bzzport", httpPort,
- "--verbosity", "6",
+ "--verbosity", "3",
)
node.Cmd.InputLine(testPassphrase)
defer func() {
@@ -215,6 +334,17 @@ func newTestNode(t *testing.T, dir string) *testNode {
}
}()
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ // ensure that all ports have active listeners
+ // so that the next node will not get the same ones
+ // when calling getAvailableTCPPorts
+ err = waitTCPPorts(ctx, ports...)
+ if err != nil {
+ t.Fatal(err)
+ }
+
// wait for the node to start
for start := time.Now(); time.Since(start) < 10*time.Second; time.Sleep(50 * time.Millisecond) {
node.Client, err = rpc.Dial(conf.IPCEndpoint())
@@ -238,8 +368,8 @@ func newTestNode(t *testing.T, dir string) *testNode {
if err := node.Client.Call(&nodeInfo, "admin_nodeInfo"); err != nil {
t.Fatal(err)
}
- node.Enode = fmt.Sprintf("enode://%s@127.0.0.1:%s", nodeInfo.ID, p2pPort)
-
+ node.Enode = nodeInfo.Enode
+ node.IpcPath = conf.IPCPath
return node
}
@@ -249,15 +379,92 @@ func (n *testNode) Shutdown() {
}
}
-func assignTCPPort() (string, error) {
- l, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- return "", err
+// getAvailableTCPPorts returns a set of ports that
+// nothing is listening on at the time.
+//
+// Function assignTCPPort cannot be called in sequence with a
+// guarantee that distinct ports are returned on each call, as the
+// listener is closed within the function itself, not after all
+// listeners are started and unique available ports are selected.
+func getAvailableTCPPorts(count int) (ports []string, err error) {
+ for i := 0; i < count; i++ {
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ return nil, err
+ }
+ // defer close in the loop to be sure the same port will not
+ // be selected in the next iteration
+ defer l.Close()
+
+ _, port, err := net.SplitHostPort(l.Addr().String())
+ if err != nil {
+ return nil, err
+ }
+ ports = append(ports, port)
}
- l.Close()
- _, port, err := net.SplitHostPort(l.Addr().String())
- if err != nil {
- return "", err
+ return ports, nil
+}
+
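Spelling out the hazard the comment above describes: the old helper closed its listener before returning, so nothing holds the port between calls and the OS may hand out the same ephemeral port twice. A sketch (hypothetical sequence):

	p1, _ := assignTCPPort() // listener opened and closed inside
	p2, _ := assignTCPPort() // nothing listens on p1 anymore, so p2 may equal p1
	// getAvailableTCPPorts(2) instead defers every Close until all ports are picked.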
+// waitTCPPorts blocks until tcp connections can be
+// established on all provided ports. It runs all
+// port dialers in parallel, and returns the first
+// encountered error.
+// See waitTCPPort also.
+func waitTCPPorts(ctx context.Context, ports ...string) error {
+ var err error
+ // mu locks err variable that is assigned in
+ // other goroutines
+ var mu sync.Mutex
+
+// cancel cancels all goroutines
+// when the first error is returned
+ // to prevent unnecessary waiting
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ var wg sync.WaitGroup
+ for _, port := range ports {
+ wg.Add(1)
+ go func(port string) {
+ defer wg.Done()
+
+ e := waitTCPPort(ctx, port)
+
+ mu.Lock()
+ defer mu.Unlock()
+ if e != nil && err == nil {
+ err = e
+ cancel()
+ }
+ }(port)
+ }
+ wg.Wait()
+
+ return err
+}
+
+// waitTCPPort blocks until a tcp connection can be established
+// on a provided port. It has a maximum timeout of 3 minutes
+// to prevent long waiting, which can be shortened with
+// a provided context instance. The dialer has a 10 second timeout
+// in every iteration, and connection refused errors are
+// retried in 100 millisecond periods.
+func waitTCPPort(ctx context.Context, port string) error {
+ ctx, cancel := context.WithTimeout(ctx, 3*time.Minute)
+ defer cancel()
+
+ for {
+ c, err := (&net.Dialer{Timeout: 10 * time.Second}).DialContext(ctx, "tcp", "127.0.0.1:"+port)
+ if err != nil {
+ if operr, ok := err.(*net.OpError); ok {
+ if syserr, ok := operr.Err.(*os.SyscallError); ok && syserr.Err == syscall.ECONNREFUSED {
+ time.Sleep(100 * time.Millisecond)
+ continue
+ }
+ }
+ return err
+ }
+ return c.Close()
}
- return port, nil
}
diff --git a/cmd/tomo/bugcmd.go b/cmd/tomo/bugcmd.go
index 3174f7388..5cec10ad4 100644
--- a/cmd/tomo/bugcmd.go
+++ b/cmd/tomo/bugcmd.go
@@ -105,5 +105,4 @@ const header = `Please answer these questions before submitting your issue. Than
#### What did you see instead?
-#### System details
-`
+#### System details`
diff --git a/cmd/tomo/chaincmd.go b/cmd/tomo/chaincmd.go
index dc0a274ba..e1c23c0cd 100644
--- a/cmd/tomo/chaincmd.go
+++ b/cmd/tomo/chaincmd.go
@@ -66,6 +66,7 @@ It expects the genesis file as argument.`,
utils.CacheFlag,
utils.LightModeFlag,
utils.GCModeFlag,
+ utils.SnapshotFlag,
utils.CacheDatabaseFlag,
utils.CacheGCFlag,
},
@@ -450,7 +451,7 @@ func dump(ctx *cli.Context) error {
fmt.Println("{}")
utils.Fatalf("block not found")
} else {
- state, err := state.New(block.Root(), state.NewDatabase(chainDb))
+ state, err := state.New(block.Root(), state.NewDatabase(chainDb), nil)
if err != nil {
utils.Fatalf("could not create new state: %v", err)
}
diff --git a/cmd/tomo/config.go b/cmd/tomo/config.go
index d8dffb4f5..55d44f8be 100644
--- a/cmd/tomo/config.go
+++ b/cmd/tomo/config.go
@@ -20,7 +20,6 @@ import (
"bufio"
"errors"
"fmt"
- "gopkg.in/urfave/cli.v1"
"io"
"math/big"
"os"
@@ -29,6 +28,8 @@ import (
"unicode"
"github.com/naoina/toml"
+ "gopkg.in/urfave/cli.v1"
+
"github.com/tomochain/tomochain/cmd/utils"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/eth"
@@ -166,6 +167,21 @@ func makeConfigNode(ctx *cli.Context) (*node.Node, tomoConfig) {
common.BlackListHFNumber = uint64(0)
}
+ // Check if devnet is enabled
+ if ctx.GlobalUint64(utils.NetworkIdFlag.Name) == 989898 {
+ cfg.Eth.NetworkId = 989898
+ common.IsDevnet = true
+ common.TIP2019Block = big.NewInt(300)
+ common.TIPSigning = big.NewInt(600)
+ common.TIPRandomize = big.NewInt(900)
+ common.TIPTomoX = big.NewInt(1200)
+ common.TIPTomoXLending = big.NewInt(1500)
+ common.TIPTomoXCancellationFee = big.NewInt(1800)
+ common.EpocBlockSecret = uint64(100)
+ common.EpocBlockOpening = uint64(125)
+ common.EpocBlockRandomize = uint64(150)
+ }
+
// Rewound
if rewound := ctx.GlobalInt(utils.RewoundFlag.Name); rewound != 0 {
common.Rewound = uint64(rewound)
diff --git a/cmd/tomo/consolecmd_test.go b/cmd/tomo/consolecmd_test.go
index 241373f52..894f55c69 100644
--- a/cmd/tomo/consolecmd_test.go
+++ b/cmd/tomo/consolecmd_test.go
@@ -52,7 +52,7 @@ func TestConsoleWelcome(t *testing.T) {
tomo.SetTemplateFunc("goarch", func() string { return runtime.GOARCH })
tomo.SetTemplateFunc("gover", runtime.Version)
tomo.SetTemplateFunc("tomover", func() string { return params.Version })
- tomo.SetTemplateFunc("niltime", func() string { return time.Unix(1544771829, 0).Format(time.RFC1123) })
+ tomo.SetTemplateFunc("niltime", func() string { return time.Unix(1544771829, 0).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (MST)") })
tomo.SetTemplateFunc("apis", func() string { return ipcAPIs })
// Verify the actual welcome message to the required template
@@ -137,7 +137,7 @@ func testAttachWelcome(t *testing.T, tomo *testtomo, endpoint, apis string) {
attach.SetTemplateFunc("gover", runtime.Version)
attach.SetTemplateFunc("tomover", func() string { return params.Version })
attach.SetTemplateFunc("etherbase", func() string { return tomo.Etherbase })
- attach.SetTemplateFunc("niltime", func() string { return time.Unix(1544771829, 0).Format(time.RFC1123) })
+ attach.SetTemplateFunc("niltime", func() string { return time.Unix(1544771829, 0).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (MST)") })
attach.SetTemplateFunc("ipc", func() bool { return strings.HasPrefix(endpoint, "ipc") })
attach.SetTemplateFunc("datadir", func() string { return tomo.Datadir })
attach.SetTemplateFunc("apis", func() string { return apis })
diff --git a/cmd/tomo/dao_test.go b/cmd/tomo/dao_test.go
index 773f1ed15..768a7bb76 100644
--- a/cmd/tomo/dao_test.go
+++ b/cmd/tomo/dao_test.go
@@ -17,7 +17,6 @@
package main
import (
- "github.com/tomochain/tomochain/core/rawdb"
"io/ioutil"
"math/big"
"os"
@@ -25,7 +24,7 @@ import (
"testing"
"github.com/tomochain/tomochain/common"
- "github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
)
// Genesis block for nodes which don't care about the DAO fork (i.e. not configured)
@@ -130,7 +129,7 @@ func testDAOForkBlockNewChain(t *testing.T, test int, genesis string, expectBloc
if genesis != "" {
genesisHash = daoGenesisHash
}
- config, err := core.GetChainConfig(db, genesisHash)
+ config, err := rawdb.GetChainConfig(db, genesisHash)
if err != nil {
t.Errorf("test %d: failed to retrieve chain config: %v", test, err)
return // we want to return here, the other checks can't make it past this point (nil panic).
diff --git a/cmd/tomo/main.go b/cmd/tomo/main.go
index 2a606fbb7..99851e86a 100644
--- a/cmd/tomo/main.go
+++ b/cmd/tomo/main.go
@@ -86,6 +86,7 @@ var (
utils.LightModeFlag,
utils.SyncModeFlag,
utils.GCModeFlag,
+ utils.SnapshotFlag,
//utils.LightServFlag,
//utils.LightPeersFlag,
//utils.LightKDFFlag,
@@ -93,6 +94,7 @@ var (
//utils.CacheDatabaseFlag,
//utils.CacheGCFlag,
//utils.TrieCacheGenFlag,
+ utils.CacheSnapshotFlag,
utils.ListenPortFlag,
utils.MaxPeersFlag,
utils.MaxPendingPeersFlag,
@@ -224,6 +226,10 @@ func tomo(ctx *cli.Context) error {
// it unlocks any requested accounts, and starts the RPC/IPC interfaces and the
// miner.
func startNode(ctx *cli.Context, stack *node.Node, cfg tomoConfig) {
+ if common.IsDevnet {
+ log.Info("DEVNET configuration applied")
+ }
+
// Start up the node itself
utils.StartNode(stack)
diff --git a/cmd/tomo/usage.go b/cmd/tomo/usage.go
index f166d9aae..8840af839 100644
--- a/cmd/tomo/usage.go
+++ b/cmd/tomo/usage.go
@@ -123,15 +123,15 @@ var AppHelpFlagGroups = []flagGroup{
// utils.TxPoolLifetimeFlag,
// },
//},
- //{
- // Name: "PERFORMANCE TUNING",
- // Flags: []cli.Flag{
- // utils.CacheFlag,
- // utils.CacheDatabaseFlag,
- // utils.CacheGCFlag,
- // utils.TrieCacheGenFlag,
- // },
- //},
+ {
+ Name: "PERFORMANCE TUNING",
+ Flags: []cli.Flag{
+ utils.CacheFlag,
+ utils.CacheDatabaseFlag,
+ utils.CacheGCFlag,
+ utils.CacheSnapshotFlag,
+ },
+ },
{
Name: "ACCOUNT",
Flags: []cli.Flag{
diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go
index a3787f731..667098e90 100644
--- a/cmd/utils/cmd.go
+++ b/cmd/utils/cmd.go
@@ -29,6 +29,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/ethdb"
@@ -271,7 +272,7 @@ func ImportPreimages(db ethdb.Database, fn string) error {
// Accumulate the preimages and flush when enough data is gathered
preimages[crypto.Keccak256Hash(blob)] = common.CopyBytes(blob)
if len(preimages) > 1024 {
- if err := core.WritePreimages(db, 0, preimages); err != nil {
+ if err := rawdb.WritePreimages(db, 0, preimages); err != nil {
return err
}
preimages = make(map[common.Hash][]byte)
@@ -279,7 +280,7 @@ func ImportPreimages(db ethdb.Database, fn string) error {
}
// Flush the last batch preimage data
if len(preimages) > 0 {
- return core.WritePreimages(db, 0, preimages)
+ return rawdb.WritePreimages(db, 0, preimages)
}
return nil
}
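Functionally nothing changes here; the preimage writer simply lives in core/rawdb now, with the same (db, number, preimages) signature and error return. A sketch of the batching pattern, wrapped as a standalone helper so it stays agnostic of how the database handle is obtained:

package example

import (
	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/rawdb"
	"github.com/tomochain/tomochain/crypto"
	"github.com/tomochain/tomochain/ethdb"
)

// flushPreimages mirrors ImportPreimages' batching: map keccak256(blob)
// to blob, then flush the batch through the relocated rawdb accessor.
func flushPreimages(db ethdb.Database, blobs [][]byte) error {
	preimages := make(map[common.Hash][]byte, len(blobs))
	for _, blob := range blobs {
		preimages[crypto.Keccak256Hash(blob)] = common.CopyBytes(blob)
	}
	return rawdb.WritePreimages(db, 0, preimages)
}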
diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go
index 59a0cdeaf..5f8542b29 100644
--- a/cmd/utils/flags.go
+++ b/cmd/utils/flags.go
@@ -46,8 +46,8 @@ import (
"github.com/tomochain/tomochain/metrics"
"github.com/tomochain/tomochain/node"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
"github.com/tomochain/tomochain/p2p/discv5"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/nat"
"github.com/tomochain/tomochain/p2p/netutil"
"github.com/tomochain/tomochain/params"
@@ -190,6 +190,10 @@ var (
Usage: `Blockchain garbage collection mode ("full", "archive")`,
Value: "full",
}
+ SnapshotFlag = cli.BoolFlag{
+ Name: "snapshot",
+ Usage: `Enables snapshot-database mode (experimental work-in-progress feature)`,
+ }
LightServFlag = cli.IntFlag{
Name: "lightserv",
Usage: "Maximum percentage of time allowed for serving LES requests (0-90)",
@@ -305,6 +309,11 @@ var (
Usage: "Percentage of cache memory allowance to use for trie pruning",
Value: 25,
}
+ CacheSnapshotFlag = cli.IntFlag{
+ Name: "cache.snapshot",
+ Usage: "Percentage of cache memory allowance to use for snapshot caching (default = 10% full mode, 20% archive mode)",
+ Value: 10,
+ }
// Miner settings
StakingEnabledFlag = cli.BoolFlag{
Name: "mine",
@@ -634,9 +643,9 @@ func setBootstrapNodes(ctx *cli.Context, cfg *p2p.Config) {
case ctx.GlobalBool(TomoTestnetFlag.Name):
urls = params.TestnetBootnodes
}
- cfg.BootstrapNodes = make([]*discover.Node, 0, len(urls))
+ cfg.BootstrapNodes = make([]*enode.Node, 0, len(urls))
for _, url := range urls {
- node, err := discover.ParseNode(url)
+ node, err := enode.ParseV4(url)
if err != nil {
log.Error("Bootstrap URL invalid", "enode", url, "err", err)
continue
diff --git a/cmd/wnode/main.go b/cmd/wnode/main.go
index 78c558bc1..b5795e715 100644
--- a/cmd/wnode/main.go
+++ b/cmd/wnode/main.go
@@ -41,7 +41,7 @@ import (
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/nat"
"github.com/tomochain/tomochain/whisper/mailserver"
whisper "github.com/tomochain/tomochain/whisper/whisperv6"
@@ -175,7 +175,7 @@ func initialize() {
log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*argVerbosity), log.StreamHandler(os.Stderr, log.TerminalFormat(false))))
done = make(chan struct{})
- var peers []*discover.Node
+ var peers []*enode.Node
var err error
if *generateKey {
@@ -203,7 +203,7 @@ func initialize() {
if len(*argEnode) == 0 {
argEnode = scanLineA("Please enter the peer's enode: ")
}
- peer := discover.MustParseNode(*argEnode)
+ peer := enode.MustParseV4(*argEnode)
peers = append(peers, peer)
}
@@ -750,11 +750,11 @@ func requestExpiredMessagesLoop() {
}
func extractIDFromEnode(s string) []byte {
- n, err := discover.ParseNode(s)
+ n, err := enode.ParseV4(s)
if err != nil {
utils.Fatalf("Failed to parse enode: %s", err)
}
- return n.ID[:]
+ return n.ID().Bytes()
}
// obfuscateBloom adds 16 random bits to the bloom
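The flags.go and wnode hunks above are the same migration: the legacy p2p/discover node type gives way to p2p/enode, ParseNode/MustParseNode become ParseV4/MustParseV4, and the node ID moves behind an ID() accessor. A round-trip sketch (assumes enode.NewV4 and crypto.GenerateKey, which are not shown in this diff):

package main

import (
	"fmt"
	"net"

	"github.com/tomochain/tomochain/crypto"
	"github.com/tomochain/tomochain/p2p/enode"
)

func main() {
	// Build a v4 node record, render it as an enode:// URL, and parse
	// it back with the new API used throughout this patch.
	key, err := crypto.GenerateKey()
	if err != nil {
		panic(err)
	}
	n := enode.NewV4(&key.PublicKey, net.ParseIP("10.0.0.1"), 30303, 30303)
	parsed, err := enode.ParseV4(n.String())
	if err != nil {
		panic(err)
	}
	// The node ID is a method now, not an array field (n.ID[:] -> n.ID().Bytes()).
	fmt.Printf("%x\n", parsed.ID().Bytes())
}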
diff --git a/common/bytes.go b/common/bytes.go
index ba00e8a4b..1801cb1ca 100644
--- a/common/bytes.go
+++ b/common/bytes.go
@@ -119,3 +119,25 @@ func LeftPadBytes(slice []byte, l int) []byte {
return padded
}
+
+// TrimLeftZeroes returns a subslice of s without leading zeroes
+func TrimLeftZeroes(s []byte) []byte {
+ idx := 0
+ for ; idx < len(s); idx++ {
+ if s[idx] != 0 {
+ break
+ }
+ }
+ return s[idx:]
+}
+
+// TrimRightZeroes returns a subslice of s without trailing zeroes
+func TrimRightZeroes(s []byte) []byte {
+ idx := len(s)
+ for ; idx > 0; idx-- {
+ if s[idx-1] != 0 {
+ break
+ }
+ }
+ return s[:idx]
+}
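Both helpers return a subslice of the input rather than a copy, so they allocate nothing and the result aliases s. A quick illustration of the boundary cases:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/common"
)

func main() {
	fmt.Println(common.TrimLeftZeroes([]byte{0, 0, 1, 2, 0}))  // [1 2 0]
	fmt.Println(common.TrimRightZeroes([]byte{0, 0, 1, 2, 0})) // [0 0 1 2]
	fmt.Println(common.TrimLeftZeroes([]byte{0, 0}))           // []
	fmt.Println(common.TrimRightZeroes(nil))                   // []
}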
diff --git a/common/constants.go b/common/constants.go
index af75a82e7..627880052 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -11,9 +11,6 @@ const (
HexSignMethod = "e341eaa4"
HexSetSecret = "34d38600"
HexSetOpening = "e11f5ba2"
- EpocBlockSecret = 800
- EpocBlockOpening = 850
- EpocBlockRandomize = 900
MaxMasternodes = 150
LimitPenaltyEpoch = 4
BlocksPerYear = uint64(15768000)
@@ -29,6 +26,11 @@ const (
var Rewound = uint64(0)
+// dynamic configs
+var EpocBlockSecret = uint64(800)
+var EpocBlockOpening = uint64(850)
+var EpocBlockRandomize = uint64(900)
+
// hardforks
var TIP2019Block = big.NewInt(1050000)
var TIPSigning = big.NewInt(3000000)
@@ -38,6 +40,7 @@ var TIPTomoX = big.NewInt(20581700)
var TIPTomoXLending = big.NewInt(21430200)
var TIPTomoXCancellationFee = big.NewInt(30915660)
var TIPTomoXTestnet = big.NewInt(0)
+var IsDevnet bool = false
var IsTestnet bool = false
var StoreRewardFolder string
var RollbackHash Hash
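The epoch offsets become package-level uint64 variables (and IsDevnet joins IsTestnet) precisely so the devnet branch in cmd/utils, earlier in this patch, can shrink them and the TIP* hardfork heights at startup. A hedged sketch of such an override; applyDevnetOverrides is a hypothetical helper and the values simply mirror that branch:

package example

import (
	"math/big"

	"github.com/tomochain/tomochain/common"
)

// applyDevnetOverrides mirrors the devnet branch in cmd/utils: hardfork
// heights and epoch offsets are mutable vars now, so a private network
// can lower them before the node starts. Hypothetical helper.
func applyDevnetOverrides() {
	common.IsDevnet = true
	common.TIPSigning = big.NewInt(600)
	common.TIPRandomize = big.NewInt(900)
	common.EpocBlockSecret = uint64(100)
	common.EpocBlockOpening = uint64(125)
	common.EpocBlockRandomize = uint64(150)
}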
diff --git a/common/math/big.go b/common/math/big.go
index 787278650..27068b228 100644
--- a/common/math/big.go
+++ b/common/math/big.go
@@ -176,13 +176,19 @@ func U256(x *big.Int) *big.Int {
return x.And(x, tt256m1)
}
+// U256Bytes converts a big.Int into a 256-bit EVM number.
+// This operation is destructive.
+func U256Bytes(n *big.Int) []byte {
+ return PaddedBigBytes(U256(n), 32)
+}
+
// S256 interprets x as a two's complement number.
// x must not exceed 256 bits (the result is undefined if it does) and is not modified.
//
-// S256(0) = 0
-// S256(1) = 1
-// S256(2**255) = -2**255
-// S256(2**256-1) = -1
+// S256(0) = 0
+// S256(1) = 1
+// S256(2**255) = -2**255
+// S256(2**256-1) = -1
func S256(x *big.Int) *big.Int {
if x.Cmp(tt255) < 0 {
return x
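U256Bytes is destructive: U256 reduces its argument modulo 2**256 in place before padding to exactly 32 bytes, so callers that still need the original value must pass a copy. For example:

package main

import (
	"fmt"
	"math/big"

	"github.com/tomochain/tomochain/common/math"
)

func main() {
	n := big.NewInt(-1)
	// -1 mod 2**256 == 2**256 - 1, i.e. 32 bytes of 0xff. Note that the
	// argument itself is mutated; new(big.Int).Set(n) keeps the original.
	fmt.Printf("%x\n", math.U256Bytes(new(big.Int).Set(n)))
	fmt.Printf("%x\n", math.U256Bytes(big.NewInt(5))) // 32 bytes, 0x00...05
}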
diff --git a/consensus/clique/clique.go b/consensus/clique/clique.go
index f63373e17..5c03e332c 100644
--- a/consensus/clique/clique.go
+++ b/consensus/clique/clique.go
@@ -40,6 +40,7 @@ import (
"github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/rlp"
"github.com/tomochain/tomochain/rpc"
+ "github.com/tomochain/tomochain/trie"
)
const (
@@ -575,7 +576,7 @@ func (c *Clique) Finalize(chain consensus.ChainReader, header *types.Header, sta
header.UncleHash = types.CalcUncleHash(nil)
// Assemble and return the final block for sealing
- return types.NewBlock(header, txs, nil, receipts), nil
+ return types.NewBlock(header, txs, nil, receipts, new(trie.StackTrie)), nil
}
// Authorize injects a private key into the consensus engine to mint new blocks
diff --git a/consensus/clique/snapshot.go b/consensus/clique/snapshot.go
index 3c2bf703d..9a1e9e884 100644
--- a/consensus/clique/snapshot.go
+++ b/consensus/clique/snapshot.go
@@ -32,7 +32,7 @@ import (
type Vote struct {
Signer common.Address `json:"signer"` // Authorized signer that cast this vote
Block uint64 `json:"block"` // Block number the vote was cast in (expire old votes)
Address common.Address `json:"address"` // Account being voted on to change its authorization
Authorize bool `json:"authorize"` // Whether to authorize or deauthorize the voted account
}
diff --git a/consensus/ethash/consensus.go b/consensus/ethash/consensus.go
index 12f63cfde..706456992 100644
--- a/consensus/ethash/consensus.go
+++ b/consensus/ethash/consensus.go
@@ -32,6 +32,7 @@ import (
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/params"
+ "github.com/tomochain/tomochain/trie"
)
// Ethash proof-of-work protocol constants.
@@ -519,7 +520,7 @@ func (ethash *Ethash) Finalize(chain consensus.ChainReader, header *types.Header
header.Root = state.IntermediateRoot(chain.Config().IsEIP158(header.Number))
// Header seems complete, assemble into a block and return
- return types.NewBlock(header, txs, uncles, receipts), nil
+ return types.NewBlock(header, txs, uncles, receipts, new(trie.StackTrie)), nil
}
// Some weird constants to avoid constant memory allocs for them.
diff --git a/consensus/posv/posv.go b/consensus/posv/posv.go
index 002710497..71f9e52f7 100644
--- a/consensus/posv/posv.go
+++ b/consensus/posv/posv.go
@@ -21,9 +21,6 @@ import (
"encoding/json"
"errors"
"fmt"
- "github.com/tomochain/tomochain/tomox/tradingstate"
- "github.com/tomochain/tomochain/tomoxlending/lendingstate"
- "gopkg.in/karalabe/cookiejar.v2/collections/prque"
"io/ioutil"
"math/big"
"math/rand"
@@ -35,6 +32,7 @@ import (
"time"
lru "github.com/hashicorp/golang-lru"
+
"github.com/tomochain/tomochain/accounts"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/hexutil"
@@ -50,6 +48,10 @@ import (
"github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/rlp"
"github.com/tomochain/tomochain/rpc"
+ "github.com/tomochain/tomochain/tomox/tradingstate"
+ "github.com/tomochain/tomochain/tomoxlending/lendingstate"
+ "github.com/tomochain/tomochain/trie"
+ "gopkg.in/karalabe/cookiejar.v2/collections/prque"
)
const (
@@ -66,6 +68,7 @@ type Masternode struct {
type TradingService interface {
GetTradingStateRoot(block *types.Block, author common.Address) (common.Hash, error)
GetTradingState(block *types.Block, author common.Address) (*tradingstate.TradingStateDB, error)
+ GetEmptyTradingState() (*tradingstate.TradingStateDB, error)
HasTradingState(block *types.Block, author common.Address) bool
GetStateCache() tradingstate.Database
GetTriegc() *prque.Prque
@@ -181,7 +184,7 @@ var (
// SignerFn is a signer callback function to request a hash to be signed by a
// backing account.
//type SignerFn func(accounts.Account, []byte) ([]byte, error)
// sigHash returns the hash which is used as input for the proof-of-stake-voting
// signing. It is the hash of the entire header apart from the 65 byte signature
@@ -985,7 +988,7 @@ func (c *Posv) Finalize(chain consensus.ChainReader, header *types.Header, state
header.UncleHash = types.CalcUncleHash(nil)
// Assemble and return the final block for sealing
- return types.NewBlock(header, txs, nil, receipts), nil
+ return types.NewBlock(header, txs, nil, receipts, new(trie.StackTrie)), nil
}
// Authorize injects a private key into the consensus engine to mint new blocks
@@ -1146,7 +1149,7 @@ func (c *Posv) CacheData(header *types.Header, txs []*types.Transaction, receipt
signTxs := []*types.Transaction{}
for _, tx := range txs {
if tx.IsSigningTransaction() {
- var b uint
+ var b uint64
for _, r := range receipts {
if r.TxHash == tx.Hash() {
if len(r.PostState) > 0 {
diff --git a/consensus/posv/snapshot.go b/consensus/posv/snapshot.go
index aef9e2a39..01f9d50e4 100644
--- a/consensus/posv/snapshot.go
+++ b/consensus/posv/snapshot.go
@@ -32,7 +32,7 @@ import (
//type Vote struct {
// Signer common.Address `json:"signer"` // Authorized signer that cast this vote
// Block uint64 `json:"block"` // Block number the vote was cast in (expire old votes)
// Address common.Address `json:"address"` // Account being voted on to change its authorization
// Authorize bool `json:"authorize"` // Whether to authorize or deauthorize the voted account
//}
diff --git a/console/console_test.go b/console/console_test.go
index 22527f4dd..98f85c4b4 100644
--- a/console/console_test.go
+++ b/console/console_test.go
@@ -19,8 +19,6 @@ package console
import (
"bytes"
"errors"
- "github.com/tomochain/tomochain/tomox"
- "github.com/tomochain/tomochain/tomoxlending"
"io/ioutil"
"os"
"strings"
@@ -29,10 +27,13 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus/ethash"
+ "github.com/tomochain/tomochain/console/prompt"
"github.com/tomochain/tomochain/core"
"github.com/tomochain/tomochain/eth"
"github.com/tomochain/tomochain/internal/jsre"
"github.com/tomochain/tomochain/node"
+ "github.com/tomochain/tomochain/tomox"
+ "github.com/tomochain/tomochain/tomoxlending"
)
const (
@@ -67,10 +68,10 @@ func (p *hookedPrompter) PromptPassword(prompt string) (string, error) {
func (p *hookedPrompter) PromptConfirm(prompt string) (bool, error) {
return false, errors.New("not implemented")
}
-func (p *hookedPrompter) SetHistory(history []string) {}
-func (p *hookedPrompter) AppendHistory(command string) {}
-func (p *hookedPrompter) ClearHistory() {}
-func (p *hookedPrompter) SetWordCompleter(completer WordCompleter) {}
+func (p *hookedPrompter) SetHistory(history []string) {}
+func (p *hookedPrompter) AppendHistory(command string) {}
+func (p *hookedPrompter) ClearHistory() {}
+func (p *hookedPrompter) SetWordCompleter(completer prompt.WordCompleter) {}
// tester is a console test environment for the console tests to operate on.
type tester struct {
@@ -262,7 +263,7 @@ func TestPrettyError(t *testing.T) {
defer tester.Close(t)
tester.console.Evaluate("throw 'hello'")
- want := jsre.ErrorColor("hello") + "\n"
+ want := jsre.ErrorColor("hello") + "\n\tat :1:1(1)\n\n"
if output := tester.output.String(); output != want {
t.Fatalf("pretty error mismatch: have %s, want %s", output, want)
}
diff --git a/contracts/utils.go b/contracts/utils.go
index 4468b5de9..eede7c43f 100644
--- a/contracts/utils.go
+++ b/contracts/utils.go
@@ -39,6 +39,7 @@ import (
"github.com/tomochain/tomochain/contracts/blocksigner/contract"
randomizeContract "github.com/tomochain/tomochain/contracts/randomize/contract"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
stateDatabase "github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
@@ -336,7 +337,7 @@ func GetRewardForCheckpoint(c *posv.Posv, chain consensus.ChainReader, header *t
block := chain.GetBlock(header.Hash(), i)
txs := block.Transactions()
if !chain.Config().IsTIPSigning(header.Number) {
- receipts := core.GetBlockReceipts(c.GetDb(), header.Hash(), i)
+ receipts := rawdb.GetBlockReceipts(c.GetDb(), header.Hash(), i, chain.Config())
signData = c.CacheData(header, txs, receipts)
} else {
signData = c.CacheSigner(header.Hash(), txs)
diff --git a/contracts/validator/validator_test.go b/contracts/validator/validator_test.go
index c7a452d75..9cdb8bec8 100644
--- a/contracts/validator/validator_test.go
+++ b/contracts/validator/validator_test.go
@@ -60,10 +60,7 @@ func TestValidator(t *testing.T) {
d := time.Now().Add(1000 * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
- code, _ := contractBackend.CodeAt(ctx, validatorAddress, nil)
- t.Log("contract code", common.ToHex(code))
f := func(key, val common.Hash) bool {
- t.Log(key.Hex(), val.Hex())
return true
}
contractBackend.ForEachStorageAt(ctx, validatorAddress, nil, f)
diff --git a/core/bench_test.go b/core/bench_test.go
index 137b57f03..5380398f9 100644
--- a/core/bench_test.go
+++ b/core/bench_test.go
@@ -18,7 +18,6 @@ package core
import (
"crypto/ecdsa"
- "github.com/tomochain/tomochain/core/rawdb"
"io/ioutil"
"math/big"
"os"
@@ -27,6 +26,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/math"
"github.com/tomochain/tomochain/consensus/ethash"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
"github.com/tomochain/tomochain/crypto"
@@ -235,13 +235,13 @@ func makeChainForBench(db ethdb.Database, full bool, count uint64) {
ReceiptHash: types.EmptyRootHash,
}
hash = header.Hash()
- WriteHeader(db, header)
- WriteCanonicalHash(db, hash, n)
- WriteTd(db, hash, n, big.NewInt(int64(n+1)))
+ rawdb.WriteHeader(db, header)
+ rawdb.WriteCanonicalHash(db, hash, n)
+ rawdb.WriteTd(db, hash, n, big.NewInt(int64(n+1)))
if full || n == 0 {
block := types.NewBlockWithHeader(header)
- WriteBody(db, hash, n, block.Body())
- WriteBlockReceipts(db, hash, n, nil)
+ rawdb.WriteBody(db, hash, n, block.Body())
+ rawdb.WriteBlockReceipts(db, hash, n, nil)
}
}
}
@@ -275,6 +275,8 @@ func benchReadChain(b *testing.B, full bool, count uint64) {
}
makeChainForBench(db, full, count)
db.Close()
+ cacheConfig := *defaultCacheConfig
+ cacheConfig.Disabled = true
b.ReportAllocs()
b.ResetTimer()
@@ -284,7 +286,7 @@ func benchReadChain(b *testing.B, full bool, count uint64) {
if err != nil {
b.Fatalf("error opening database at %v: %v", dir, err)
}
- chain, err := NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
+ chain, err := NewBlockChain(db, &cacheConfig, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
if err != nil {
b.Fatalf("error creating chain: %v", err)
}
@@ -293,8 +295,8 @@ func benchReadChain(b *testing.B, full bool, count uint64) {
header := chain.GetHeaderByNumber(n)
if full {
hash := header.Hash()
- GetBody(db, hash, n)
- GetBlockReceipts(db, hash, n)
+ rawdb.GetBody(db, hash, n)
+ rawdb.GetBlockReceipts(db, hash, n, params.TestChainConfig)
}
}
diff --git a/core/block_validator.go b/core/block_validator.go
index 34fde4ced..63e3f5438 100644
--- a/core/block_validator.go
+++ b/core/block_validator.go
@@ -18,6 +18,7 @@ package core
import (
"fmt"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus"
"github.com/tomochain/tomochain/consensus/posv"
@@ -27,6 +28,7 @@ import (
"github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/tomox/tradingstate"
"github.com/tomochain/tomochain/tomoxlending/lendingstate"
+ "github.com/tomochain/tomochain/trie"
)
// BlockValidator is responsible for validating block headers, uncles and
@@ -71,7 +73,7 @@ func (v *BlockValidator) ValidateBody(block *types.Block) error {
if hash := types.CalcUncleHash(block.Uncles()); hash != header.UncleHash {
return fmt.Errorf("uncle root hash mismatch: have %x, want %x", hash, header.UncleHash)
}
- if hash := types.DeriveSha(block.Transactions()); hash != header.TxHash {
+ if hash := types.DeriveSha(block.Transactions(), new(trie.StackTrie)); hash != header.TxHash {
return fmt.Errorf("transaction root hash mismatch: have %x, want %x", hash, header.TxHash)
}
return nil
@@ -93,7 +95,7 @@ func (v *BlockValidator) ValidateState(block, parent *types.Block, statedb *stat
return fmt.Errorf("invalid bloom (remote: %x local: %x)", header.Bloom, rbloom)
}
// The receipt Trie's root (R = (Tr [[H1, R1], ..., [Hn, Rn]]))
- receiptSha := types.DeriveSha(receipts)
+ receiptSha := types.DeriveSha(receipts, new(trie.StackTrie))
if receiptSha != header.ReceiptHash {
return fmt.Errorf("invalid receipt root hash (remote: %x local: %x)", header.ReceiptHash, receiptSha)
}
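types.DeriveSha now takes the hasher as an explicit argument (and types.NewBlock threads one through, per the consensus hunks earlier). The StackTrie hashes the list in a single streaming pass instead of materialising a full trie. The call shape, with arbitrary transaction values:

package main

import (
	"fmt"
	"math/big"

	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/types"
	"github.com/tomochain/tomochain/trie"
)

func main() {
	to := common.HexToAddress("0x0000000000000000000000000000000000000001")
	txs := types.Transactions{
		types.NewTransaction(0, to, big.NewInt(1), 21000, big.NewInt(1), nil),
		types.NewTransaction(1, to, big.NewInt(2), 21000, big.NewInt(1), nil),
	}
	// Same call shape as ValidateBody above: the hasher is supplied by
	// the caller rather than constructed inside DeriveSha.
	root := types.DeriveSha(txs, new(trie.StackTrie))
	fmt.Println(root.Hex())
}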
diff --git a/core/blockchain.go b/core/blockchain.go
index f763189be..2a912a4ce 100644
--- a/core/blockchain.go
+++ b/core/blockchain.go
@@ -28,18 +28,18 @@ import (
"sync/atomic"
"time"
- "github.com/tomochain/tomochain/tomoxlending/lendingstate"
+ lru "github.com/hashicorp/golang-lru"
+ "gopkg.in/karalabe/cookiejar.v2/collections/prque"
"github.com/tomochain/tomochain/accounts/abi/bind"
- "github.com/tomochain/tomochain/tomox/tradingstate"
-
- lru "github.com/hashicorp/golang-lru"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/mclock"
"github.com/tomochain/tomochain/consensus"
"github.com/tomochain/tomochain/consensus/posv"
contractValidator "github.com/tomochain/tomochain/contracts/validator/contract"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
+ "github.com/tomochain/tomochain/core/state/snapshot"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
"github.com/tomochain/tomochain/crypto"
@@ -50,14 +50,40 @@ import (
"github.com/tomochain/tomochain/metrics"
"github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/rlp"
+ "github.com/tomochain/tomochain/tomox/tradingstate"
+ "github.com/tomochain/tomochain/tomoxlending/lendingstate"
"github.com/tomochain/tomochain/trie"
- "gopkg.in/karalabe/cookiejar.v2/collections/prque"
)
var (
- blockInsertTimer = metrics.NewRegisteredTimer("chain/inserts", nil)
- CheckpointCh = make(chan int)
- ErrNoGenesis = errors.New("Genesis not found in chain")
+ accountReadTimer = metrics.NewRegisteredTimer("chain/account/reads", nil)
+ accountHashTimer = metrics.NewRegisteredTimer("chain/account/hashes", nil)
+ accountUpdateTimer = metrics.NewRegisteredTimer("chain/account/updates", nil)
+ accountCommitTimer = metrics.NewRegisteredTimer("chain/account/commits", nil)
+
+ storageReadTimer = metrics.NewRegisteredTimer("chain/storage/reads", nil)
+ storageHashTimer = metrics.NewRegisteredTimer("chain/storage/hashes", nil)
+ storageUpdateTimer = metrics.NewRegisteredTimer("chain/storage/updates", nil)
+ storageCommitTimer = metrics.NewRegisteredTimer("chain/storage/commits", nil)
+
+ snapshotAccountReadTimer = metrics.NewRegisteredTimer("chain/snapshot/account/reads", nil)
+ snapshotStorageReadTimer = metrics.NewRegisteredTimer("chain/snapshot/storage/reads", nil)
+ snapshotCommitTimer = metrics.NewRegisteredTimer("chain/snapshot/commits", nil)
+
+ blockInsertTimer = metrics.NewRegisteredTimer("chain/inserts", nil)
+ blockValidationTimer = metrics.NewRegisteredTimer("chain/validation", nil)
+ blockExecutionTimer = metrics.NewRegisteredTimer("chain/execution", nil)
+ blockWriteTimer = metrics.NewRegisteredTimer("chain/write", nil)
+ blockReorgAddMeter = metrics.NewRegisteredMeter("chain/reorg/add", nil)
+ blockReorgDropMeter = metrics.NewRegisteredMeter("chain/reorg/drop", nil)
+
+ blockPrefetchExecuteTimer = metrics.NewRegisteredTimer("chain/prefetch/executes", nil)
+ blockPrefetchInterruptMeter = metrics.NewRegisteredMeter("chain/prefetch/interrupts", nil)
+
+ errInsertionInterrupted = errors.New("insertion is interrupted")
+
+ CheckpointCh = make(chan int)
+ ErrNoGenesis = errors.New("Genesis not found in chain")
)
const (
@@ -81,7 +107,18 @@ type CacheConfig struct {
Disabled bool // Whether to disable trie write caching (archive node)
TrieNodeLimit int // Memory limit (MB) at which to flush the current in-memory trie to disk
TrieTimeLimit time.Duration // Time limit after which to flush the current in-memory trie to disk
+ SnapshotLimit int // Memory allowance (MB) to use for caching snapshot entries in memory
+
+ SnapshotWait bool // Wait for snapshot construction on startup. TODO(karalabe): This is a dirty hack for testing, nuke it
+}
+
+// defaultCacheConfig contains the default caching values used if none are specified by the
+// user (also used during testing).
+var defaultCacheConfig = &CacheConfig{
+ TrieNodeLimit: 256,
+ TrieTimeLimit: 5 * time.Minute,
}
+
type ResultProcessBlock struct {
logs []*types.Log
receipts []*types.Receipt
@@ -112,8 +149,9 @@ type BlockChain struct {
db ethdb.Database // Low level persistent database to store final content in
tomoxDb ethdb.TomoxDatabase
- triegc *prque.Prque // Priority queue mapping block numbers to tries to gc
- gcproc time.Duration // Accumulates canonical block processing for trie dumping
+ snaps *snapshot.Tree // Snapshot tree for fast trie leaf access
+ triegc *prque.Prque // Priority queue mapping block numbers to tries to gc
+ gcproc time.Duration // Accumulates canonical block processing for trie dumping
hc *HeaderChain
rmLogsFeed event.Feed
@@ -175,6 +213,8 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
cacheConfig = &CacheConfig{
TrieNodeLimit: 256 * 1024 * 1024,
TrieTimeLimit: 5 * time.Minute,
+ SnapshotLimit: 256,
+ SnapshotWait: true,
}
}
bodyCache, _ := lru.New(bodyCacheLimit)
@@ -247,6 +287,10 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
}
}
}
+ // Load any existing snapshot, regenerating it if loading failed
+ if bc.cacheConfig.SnapshotLimit > 0 {
+ bc.snaps = snapshot.New(bc.db, bc.stateCache.TrieDB(), bc.cacheConfig.SnapshotLimit, bc.CurrentBlock().Root(), !bc.cacheConfig.SnapshotWait)
+ }
// Take ownership of this particular state
go bc.update()
return bc, nil
@@ -276,7 +320,7 @@ func (bc *BlockChain) addTomoxDb(tomoxDb ethdb.TomoxDatabase) {
// assumes that the chain manager mutex is held.
func (bc *BlockChain) loadLastState() error {
// Restore the last known head block
- head := GetHeadBlockHash(bc.db)
+ head := rawdb.GetHeadBlockHash(bc.db)
if head == (common.Hash{}) {
// Corrupt or empty database, init from scratch
log.Warn("Empty database, resetting chain")
@@ -289,13 +333,25 @@ func (bc *BlockChain) loadLastState() error {
log.Warn("Head block missing, resetting chain", "hash", head)
return bc.Reset()
}
+ // Make sure the state associated with the block is available
+ if _, err := state.New(currentBlock.Root(), bc.stateCache, bc.snaps); err != nil {
+ // Dangling block without a state associated, init from scratch
+ log.Warn("Head state missing, repairing chain", "number", currentBlock.Number(), "hash", currentBlock.Hash())
+ if err := bc.repair(&currentBlock); err != nil {
+ return err
+ }
+ rawdb.WriteHeadBlockHash(bc.db, currentBlock.Hash())
+ }
+
+ // Everything seems to be fine, set as the head block
+ bc.currentBlock.Store(currentBlock)
+
repair := false
if common.Rewound != uint64(0) {
repair = true
}
// Make sure the state associated with the block is available
- _, err := state.New(currentBlock.Root(), bc.stateCache)
- if err != nil {
+ if _, err := state.New(currentBlock.Root(), bc.stateCache, bc.snaps); err != nil {
repair = true
} else {
engine, ok := bc.Engine().(*posv.Posv)
@@ -344,7 +400,7 @@ func (bc *BlockChain) loadLastState() error {
// Restore the last known head header
currentHeader := currentBlock.Header()
- if head := GetHeadHeaderHash(bc.db); head != (common.Hash{}) {
+ if head := rawdb.GetHeadHeaderHash(bc.db); head != (common.Hash{}) {
if header := bc.GetHeaderByHash(head); header != nil {
currentHeader = header
}
@@ -353,7 +409,7 @@ func (bc *BlockChain) loadLastState() error {
// Restore the last known head fast block
bc.currentFastBlock.Store(currentBlock)
- if head := GetHeadFastBlockHash(bc.db); head != (common.Hash{}) {
+ if head := rawdb.GetHeadFastBlockHash(bc.db); head != (common.Hash{}) {
if block := bc.GetBlockByHash(head); block != nil {
bc.currentFastBlock.Store(block)
}
@@ -385,7 +441,7 @@ func (bc *BlockChain) SetHead(head uint64) error {
// Rewind the header chain, deleting all block bodies until then
delFn := func(hash common.Hash, num uint64) {
- DeleteBody(bc.db, hash, num)
+ rawdb.DeleteBody(bc.db, hash, num)
}
bc.hc.SetHead(head, delFn)
currentHeader := bc.hc.CurrentHeader()
@@ -402,7 +458,7 @@ func (bc *BlockChain) SetHead(head uint64) error {
bc.currentBlock.Store(bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64()))
}
if currentBlock := bc.CurrentBlock(); currentBlock != nil {
- if _, err := state.New(currentBlock.Root(), bc.stateCache); err != nil {
+ if _, err := state.New(currentBlock.Root(), bc.stateCache, bc.snaps); err != nil {
// Rewound state missing, rolled back to before pivot, reset to genesis
bc.currentBlock.Store(bc.genesisBlock)
}
@@ -420,10 +476,10 @@ func (bc *BlockChain) SetHead(head uint64) error {
}
currentBlock := bc.CurrentBlock()
currentFastBlock := bc.CurrentFastBlock()
- if err := WriteHeadBlockHash(bc.db, currentBlock.Hash()); err != nil {
+ if err := rawdb.WriteHeadBlockHash(bc.db, currentBlock.Hash()); err != nil {
log.Crit("Failed to reset head full block", "err", err)
}
- if err := WriteHeadFastBlockHash(bc.db, currentFastBlock.Hash()); err != nil {
+ if err := rawdb.WriteHeadFastBlockHash(bc.db, currentFastBlock.Hash()); err != nil {
log.Crit("Failed to reset head fast block", "err", err)
}
return bc.loadLastState()
@@ -445,6 +501,11 @@ func (bc *BlockChain) FastSyncCommitHead(hash common.Hash) error {
bc.currentBlock.Store(block)
bc.mu.Unlock()
+ // Destroy any existing state snapshot and regenerate it in the background
+ if bc.snaps != nil {
+ log.Info("Destroy any existing state snapshot and regenerate it in the background", "Snapshot", bc.snaps)
+ bc.snaps.Rebuild(block.Root())
+ }
log.Info("Committed new head block", "number", block.Number(), "hash", hash)
return nil
}
@@ -501,7 +562,7 @@ func (bc *BlockChain) State() (*state.StateDB, error) {
// StateAt returns a new mutable state based on a particular point in time.
func (bc *BlockChain) StateAt(root common.Hash) (*state.StateDB, error) {
- return state.New(root, bc.stateCache)
+ return state.New(root, bc.stateCache, bc.snaps)
}
// OrderStateAt returns a new mutable state based on a particular point in time.
@@ -518,6 +579,13 @@ func (bc *BlockChain) OrderStateAt(block *types.Block) (*tradingstate.TradingSta
} else {
return nil, err
}
+ } else {
+ tomoxState, err := tomoXService.GetEmptyTradingState()
+ if err != nil {
+ return nil, err
+ }
+ return tomoxState, nil
}
}
return nil, errors.New("Get tomox state fail")
@@ -562,7 +630,7 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error {
if err := bc.hc.WriteTd(genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil {
log.Crit("Failed to write genesis block TD", "err", err)
}
- if err := WriteBlock(bc.db, genesis); err != nil {
+ if err := rawdb.WriteBlock(bc.db, genesis); err != nil {
log.Crit("Failed to write genesis block", "err", err)
}
bc.genesisBlock = genesis
@@ -585,7 +653,7 @@ func (bc *BlockChain) repair(head **types.Block) error {
for {
// Abort if we've rewound to a head block that does have associated state
if (common.Rewound == uint64(0)) || ((*head).Number().Uint64() < common.Rewound) {
- if _, err := state.New((*head).Root(), bc.stateCache); err == nil {
+ if _, err := state.New((*head).Root(), bc.stateCache, bc.snaps); err == nil {
log.Info("Rewound blockchain to past state", "number", (*head).Number(), "hash", (*head).Hash())
engine, ok := bc.Engine().(*posv.Posv)
if ok {
@@ -658,13 +726,13 @@ func (bc *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error {
// Note, this function assumes that the `mu` mutex is held!
func (bc *BlockChain) insert(block *types.Block) {
// If the block is on a side chain or an unknown one, force other heads onto it too
- updateHeads := GetCanonicalHash(bc.db, block.NumberU64()) != block.Hash()
+ updateHeads := rawdb.GetCanonicalHash(bc.db, block.NumberU64()) != block.Hash()
// Add the block to the canonical chain number scheme and mark as the head
- if err := WriteCanonicalHash(bc.db, block.Hash(), block.NumberU64()); err != nil {
+ if err := rawdb.WriteCanonicalHash(bc.db, block.Hash(), block.NumberU64()); err != nil {
log.Crit("Failed to insert block number", "err", err)
}
- if err := WriteHeadBlockHash(bc.db, block.Hash()); err != nil {
+ if err := rawdb.WriteHeadBlockHash(bc.db, block.Hash()); err != nil {
log.Crit("Failed to insert head block hash", "err", err)
}
bc.currentBlock.Store(block)
@@ -681,7 +749,7 @@ func (bc *BlockChain) insert(block *types.Block) {
if updateHeads {
bc.hc.SetCurrentHeader(block.Header())
- if err := WriteHeadFastBlockHash(bc.db, block.Hash()); err != nil {
+ if err := rawdb.WriteHeadFastBlockHash(bc.db, block.Hash()); err != nil {
log.Crit("Failed to insert head fast block hash", "err", err)
}
bc.currentFastBlock.Store(block)
@@ -701,7 +769,7 @@ func (bc *BlockChain) GetBody(hash common.Hash) *types.Body {
body := cached.(*types.Body)
return body
}
- body := GetBody(bc.db, hash, bc.hc.GetBlockNumber(hash))
+ body := rawdb.GetBody(bc.db, hash, bc.hc.GetBlockNumber(hash))
if body == nil {
return nil
}
@@ -717,7 +785,7 @@ func (bc *BlockChain) GetBodyRLP(hash common.Hash) rlp.RawValue {
if cached, ok := bc.bodyRLPCache.Get(hash); ok {
return cached.(rlp.RawValue)
}
- body := GetBodyRLP(bc.db, hash, bc.hc.GetBlockNumber(hash))
+ body := rawdb.GetBodyRLP(bc.db, hash, bc.hc.GetBlockNumber(hash))
if len(body) == 0 {
return nil
}
@@ -731,7 +799,7 @@ func (bc *BlockChain) HasBlock(hash common.Hash, number uint64) bool {
if bc.blockCache.Contains(hash) {
return true
}
- ok, _ := bc.db.Has(blockBodyKey(hash, number))
+ ok, _ := bc.db.Has(rawdb.BlockBodyKey(number, hash))
return ok
}
@@ -774,7 +842,7 @@ func (bc *BlockChain) GetBlock(hash common.Hash, number uint64) *types.Block {
if block, ok := bc.blockCache.Get(hash); ok {
return block.(*types.Block)
}
- block := GetBlock(bc.db, hash, number)
+ block := rawdb.GetBlock(bc.db, hash, number)
if block == nil {
return nil
}
@@ -791,7 +859,7 @@ func (bc *BlockChain) GetBlockByHash(hash common.Hash) *types.Block {
// GetBlockByNumber retrieves a block from the database by number, caching it
// (associated with its hash) if found.
func (bc *BlockChain) GetBlockByNumber(number uint64) *types.Block {
- hash := GetCanonicalHash(bc.db, number)
+ hash := rawdb.GetCanonicalHash(bc.db, number)
if hash == (common.Hash{}) {
return nil
}
@@ -800,7 +868,7 @@ func (bc *BlockChain) GetBlockByNumber(number uint64) *types.Block {
// GetReceiptsByHash retrieves the receipts for all transactions in a given block.
func (bc *BlockChain) GetReceiptsByHash(hash common.Hash) types.Receipts {
- return GetBlockReceipts(bc.db, hash, GetBlockNumber(bc.db, hash))
+ return rawdb.GetBlockReceipts(bc.db, hash, rawdb.GetBlockNumber(bc.db, hash), bc.chainConfig)
}
// GetBlocksFromHash returns the block corresponding to hash and up to n-1 ancestors.
@@ -867,18 +935,28 @@ func (bc *BlockChain) SaveData() {
// Make sure no inconsistent state is leaked during insertion
bc.mu.Lock()
defer bc.mu.Unlock()
+ // Ensure that the entirety of the state snapshot is journalled to disk.
+ var snapBase common.Hash
+ if bc.snaps != nil {
+ var err error
+ if snapBase, err = bc.snaps.Journal(bc.CurrentBlock().Root()); err != nil {
+ log.Error("Failed to journal state snapshot", "err", err)
+ }
+ }
// Ensure the state of a recent block is also stored to disk before exiting.
// We're writing three different states to catch different restart scenarios:
// - HEAD: So we don't need to reprocess any blocks in the general case
// - HEAD-1: So we don't do large reorgs if our HEAD becomes an uncle
// - HEAD-127: So we have a hard limit on the number of blocks reexecuted
if !bc.cacheConfig.Disabled {
- var tradingTriedb *trie.Database
- var lendingTriedb *trie.Database
+ var (
+ tradingTriedb *trie.Database
+ lendingTriedb *trie.Database
+ tradingService posv.TradingService
+ lendingService posv.LendingService
+ )
engine, _ := bc.Engine().(*posv.Posv)
triedb := bc.stateCache.TrieDB()
- var tradingService posv.TradingService
- var lendingService posv.LendingService
if bc.Config().IsTIPTomoX(bc.CurrentBlock().Number()) && bc.chainConfig.Posv != nil && bc.CurrentBlock().NumberU64() > bc.chainConfig.Posv.Epoch && engine != nil {
tradingService = engine.GetTomoXService()
if tradingService != nil && tradingService.GetStateCache() != nil {
@@ -918,6 +996,12 @@ func (bc *BlockChain) SaveData() {
}
}
}
+ if snapBase != (common.Hash{}) {
+ log.Info("Writing snapshot state to disk", "root", snapBase)
+ if err := triedb.Commit(snapBase, true); err != nil {
+ log.Error("Failed to commit recent state trie", "err", err)
+ }
+ }
for !bc.triegc.Empty() {
triedb.Dereference(bc.triegc.PopItem().(common.Hash))
}
@@ -996,12 +1080,12 @@ func (bc *BlockChain) Rollback(chain []common.Hash) {
if currentFastBlock := bc.CurrentFastBlock(); currentFastBlock.Hash() == hash {
newFastBlock := bc.GetBlock(currentFastBlock.ParentHash(), currentFastBlock.NumberU64()-1)
bc.currentFastBlock.Store(newFastBlock)
- WriteHeadFastBlockHash(bc.db, newFastBlock.Hash())
+ rawdb.WriteHeadFastBlockHash(bc.db, newFastBlock.Hash())
}
if currentBlock := bc.CurrentBlock(); currentBlock.Hash() == hash {
newBlock := bc.GetBlock(currentBlock.ParentHash(), currentBlock.NumberU64()-1)
bc.currentBlock.Store(newBlock)
- WriteHeadBlockHash(bc.db, newBlock.Hash())
+ rawdb.WriteHeadBlockHash(bc.db, newBlock.Hash())
}
}
}
@@ -1086,13 +1170,13 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
return i, fmt.Errorf("failed to set receipts data: %v", err)
}
// Write all the data out into the database
- if err := WriteBody(batch, block.Hash(), block.NumberU64(), block.Body()); err != nil {
+ if err := rawdb.WriteBody(batch, block.Hash(), block.NumberU64(), block.Body()); err != nil {
return i, fmt.Errorf("failed to write block body: %v", err)
}
- if err := WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil {
+ if err := rawdb.WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil {
return i, fmt.Errorf("failed to write block receipts: %v", err)
}
- if err := WriteTxLookupEntries(batch, block); err != nil {
+ if err := rawdb.WriteTxLookupEntries(batch, block); err != nil {
return i, fmt.Errorf("failed to write lookup metadata: %v", err)
}
stats.processed++
@@ -1118,7 +1202,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
if td := bc.GetTd(head.Hash(), head.NumberU64()); td != nil { // Rewind may have occurred, skip in that case
currentFastBlock := bc.CurrentFastBlock()
if bc.GetTd(currentFastBlock.Hash(), currentFastBlock.NumberU64()).Cmp(td) < 0 {
- if err := WriteHeadFastBlockHash(bc.db, head.Hash()); err != nil {
+ if err := rawdb.WriteHeadFastBlockHash(bc.db, head.Hash()); err != nil {
log.Crit("Failed to update head fast block hash", "err", err)
}
bc.currentFastBlock.Store(head)
@@ -1148,7 +1232,7 @@ func (bc *BlockChain) WriteBlockWithoutState(block *types.Block, td *big.Int) (e
if err := bc.hc.WriteTd(block.Hash(), block.NumberU64(), td); err != nil {
return err
}
- if err := WriteBlock(bc.db, block); err != nil {
+ if err := rawdb.WriteBlock(bc.db, block); err != nil {
return err
}
return nil
@@ -1178,7 +1262,7 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.
}
// Write other block data using a batch.
batch := bc.db.NewBatch()
- if err := WriteBlock(batch, block); err != nil {
+ if err := rawdb.WriteBlock(batch, block); err != nil {
return NonStatTy, err
}
root, err := state.Commit(bc.chainConfig.IsEIP158(block.Number()))
@@ -1324,7 +1408,7 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.
}
}
}
- if err := WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil {
+ if err := rawdb.WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil {
return NonStatTy, err
}
// If the total difficulty is higher than our known, add it to the canonical chain
@@ -1344,11 +1428,11 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.
}
}
// Write the positional metadata for transaction and receipt lookups
- if err := WriteTxLookupEntries(batch, block); err != nil {
+ if err := rawdb.WriteTxLookupEntries(batch, block); err != nil {
return NonStatTy, err
}
// Write hash preimages
- if err := WritePreimages(bc.db, block.NumberU64(), state.Preimages()); err != nil {
+ if err := rawdb.WritePreimages(bc.db, block.NumberU64(), state.Preimages()); err != nil {
return NonStatTy, err
}
status = CanonStatTy
@@ -1521,7 +1605,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
} else {
parent = chain[i-1]
}
- statedb, err := state.New(parent.Root(), bc.stateCache)
+ statedb, err := state.New(parent.Root(), bc.stateCache, bc.snaps)
if err != nil {
return i, events, coalescedLogs, err
}
@@ -1532,11 +1616,13 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
}
parentAuthor, _ := bc.Engine().Author(parent.Header())
// clear the previous dry-run cache
- var tradingState *tradingstate.TradingStateDB
- var lendingState *lendingstate.LendingStateDB
- var tradingService posv.TradingService
- var lendingService posv.LendingService
- isSDKNode := false
+ var (
+ tradingState *tradingstate.TradingStateDB
+ lendingState *lendingstate.LendingStateDB
+ tradingService posv.TradingService
+ lendingService posv.LendingService
+ isSDKNode = false
+ )
if bc.Config().IsTIPTomoX(block.Number()) && bc.chainConfig.Posv != nil && engine != nil && block.NumberU64() > bc.chainConfig.Posv.Epoch {
tradingService = engine.GetTomoXService()
lendingService = engine.GetLendingService()
@@ -1627,6 +1713,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
}
feeCapacity := state.GetTRC21FeeCapacityFromStateWithCache(parent.Root(), statedb)
// Process block using the parent state as reference point.
+ substart := time.Now()
receipts, logs, usedGas, err := bc.processor.Process(block, statedb, tradingState, bc.vmConfig, feeCapacity)
if err != nil {
bc.reportBlock(block, receipts, err)
@@ -1638,12 +1725,34 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
bc.reportBlock(block, receipts, err)
return i, events, coalescedLogs, err
}
+ // Update the metrics touched during block processing
+ accountReadTimer.Update(statedb.AccountReads) // Account reads are complete, we can mark them
+ storageReadTimer.Update(statedb.StorageReads) // Storage reads are complete, we can mark them
+ accountUpdateTimer.Update(statedb.AccountUpdates) // Account updates are complete, we can mark them
+ storageUpdateTimer.Update(statedb.StorageUpdates) // Storage updates are complete, we can mark them
+ snapshotAccountReadTimer.Update(statedb.SnapshotAccountReads) // Account reads are complete, we can mark them
+ snapshotStorageReadTimer.Update(statedb.SnapshotStorageReads) // Storage reads are complete, we can mark them
+
+ triehash := statedb.AccountHashes + statedb.StorageHashes // Save to not double count in validation
+ trieproc := statedb.SnapshotAccountReads + statedb.AccountReads + statedb.AccountUpdates
+ trieproc += statedb.SnapshotStorageReads + statedb.StorageReads + statedb.StorageUpdates
+
+ blockExecutionTimer.Update(time.Since(substart) - trieproc - triehash)
+
proctime := time.Since(bstart)
// Write the block to the chain and get the status.
+ substart = time.Now()
status, err := bc.WriteBlockWithState(block, receipts, statedb, tradingState, lendingState)
if err != nil {
return i, events, coalescedLogs, err
}
+
+ // Update the metrics touched during block commit
+ accountCommitTimer.Update(statedb.AccountCommits) // Account commits are complete, we can mark them
+ storageCommitTimer.Update(statedb.StorageCommits) // Storage commits are complete, we can mark them
+ snapshotCommitTimer.Update(statedb.SnapshotCommits) // Snapshot commits are complete, we can mark them
+
+ blockWriteTimer.Update(time.Since(substart) - statedb.AccountCommits - statedb.StorageCommits - statedb.SnapshotCommits)
+
if bc.chainConfig.Posv != nil {
c := bc.engine.(*posv.Posv)
coinbase := c.Signer()
@@ -1813,7 +1922,7 @@ func (bc *BlockChain) getResultBlock(block *types.Block, verifiedM2 bool) (*Resu
// Create a new statedb using the parent block and report an
// error if it fails.
var parent = bc.GetBlock(block.ParentHash(), block.NumberU64()-1)
- statedb, err := state.New(parent.Root(), bc.stateCache)
+ statedb, err := state.New(parent.Root(), bc.stateCache, bc.snaps)
if err != nil {
return nil, err
}
@@ -2120,7 +2229,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error {
// These logs are later announced as deleted.
collectLogs = func(h common.Hash) {
// Coalesce logs and set 'Removed'.
- receipts := GetBlockReceipts(bc.db, h, bc.hc.GetBlockNumber(h))
+ receipts := rawdb.GetBlockReceipts(bc.db, h, bc.hc.GetBlockNumber(h), bc.chainConfig)
for _, receipt := range receipts {
for _, log := range receipt.Logs {
del := *log
@@ -2189,7 +2298,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error {
// insert the block in the canonical way, re-writing history
bc.insert(newChain[i])
// write lookup entries for hash based transaction/receipt searches
- if err := WriteTxLookupEntries(bc.db, newChain[i]); err != nil {
+ if err := rawdb.WriteTxLookupEntries(bc.db, newChain[i]); err != nil {
return err
}
addedTxs = append(addedTxs, newChain[i].Transactions()...)
@@ -2199,7 +2308,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error {
// When transactions get deleted from the database that means the
// receipts that were created in the fork must also be deleted
for _, tx := range diff {
- DeleteTxLookupEntry(bc.db, tx.Hash())
+ rawdb.DeleteTxLookupEntry(bc.db, tx.Hash())
}
if len(deletedLogs) > 0 {
go bc.rmLogsFeed.Send(RemovedLogsEvent{deletedLogs})
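The recurring pattern in this file: state.New grows a third parameter, the snapshot tree, so account and storage reads can be served from the flat snapshot before falling back to the trie. Passing nil (as the tests and chain_makers below do) keeps the old behavior. A minimal sketch, kept as a helper so it does not assume any particular database constructor:

package example

import (
	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/state"
	"github.com/tomochain/tomochain/ethdb"
)

// openEmptyState opens a fresh state over db with no snapshot tree; the
// nil third argument is the new knob introduced by this patch.
func openEmptyState(db ethdb.Database) (*state.StateDB, error) {
	return state.New(common.Hash{}, state.NewDatabase(db), nil)
}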
diff --git a/core/blockchain_test.go b/core/blockchain_test.go
index 686092411..ac911ca3f 100644
--- a/core/blockchain_test.go
+++ b/core/blockchain_test.go
@@ -18,7 +18,6 @@ package core
import (
"fmt"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"math/rand"
"sync"
@@ -27,11 +26,13 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus/ethash"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/params"
+ "github.com/tomochain/tomochain/trie"
)
// Test fork of length N starting from block i
@@ -113,7 +114,7 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error {
}
return err
}
- statedb, err := state.New(blockchain.GetBlockByHash(block.ParentHash()).Root(), blockchain.stateCache)
+ statedb, err := state.New(blockchain.GetBlockByHash(block.ParentHash()).Root(), blockchain.stateCache, nil)
if err != nil {
return err
}
@@ -128,8 +129,8 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error {
return err
}
blockchain.mu.Lock()
- WriteTd(blockchain.db, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash())))
- WriteBlock(blockchain.db, block)
+ rawdb.WriteTd(blockchain.db, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash())))
+ rawdb.WriteBlock(blockchain.db, block)
statedb.Commit(true)
blockchain.mu.Unlock()
}
@@ -146,8 +147,8 @@ func testHeaderChainImport(chain []*types.Header, blockchain *BlockChain) error
}
// Manually insert the header into the database, but don't reorganise (allows subsequent testing)
blockchain.mu.Lock()
- WriteTd(blockchain.db, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash)))
- WriteHeader(blockchain.db, header)
+ rawdb.WriteTd(blockchain.db, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash)))
+ rawdb.WriteHeader(blockchain.db, header)
blockchain.mu.Unlock()
}
return nil
@@ -173,7 +174,7 @@ func TestLastBlock(t *testing.T) {
if _, err := blockchain.InsertChain(blocks); err != nil {
t.Fatalf("Failed to insert block: %v", err)
}
- if blocks[len(blocks)-1].Hash() != GetHeadBlockHash(blockchain.db) {
+ if blocks[len(blocks)-1].Hash() != rawdb.GetHeadBlockHash(blockchain.db) {
t.Fatalf("Write/Get HeadBlockHash failed")
}
}
@@ -617,18 +618,18 @@ func TestFastVsFullChains(t *testing.T) {
}
if fblock, ablock := fast.GetBlockByHash(hash), archive.GetBlockByHash(hash); fblock.Hash() != ablock.Hash() {
t.Errorf("block #%d [%x]: block mismatch: have %v, want %v", num, hash, fblock, ablock)
- } else if types.DeriveSha(fblock.Transactions()) != types.DeriveSha(ablock.Transactions()) {
+ } else if types.DeriveSha(fblock.Transactions(), new(trie.StackTrie)) != types.DeriveSha(ablock.Transactions(), new(trie.StackTrie)) {
t.Errorf("block #%d [%x]: transactions mismatch: have %v, want %v", num, hash, fblock.Transactions(), ablock.Transactions())
} else if types.CalcUncleHash(fblock.Uncles()) != types.CalcUncleHash(ablock.Uncles()) {
t.Errorf("block #%d [%x]: uncles mismatch: have %v, want %v", num, hash, fblock.Uncles(), ablock.Uncles())
}
- if freceipts, areceipts := GetBlockReceipts(fastDb, hash, GetBlockNumber(fastDb, hash)), GetBlockReceipts(archiveDb, hash, GetBlockNumber(archiveDb, hash)); types.DeriveSha(freceipts) != types.DeriveSha(areceipts) {
+ if freceipts, areceipts := rawdb.GetBlockReceipts(fastDb, hash, rawdb.GetBlockNumber(fastDb, hash), fast.Config()), rawdb.GetBlockReceipts(archiveDb, hash, rawdb.GetBlockNumber(archiveDb, hash), fast.Config()); types.DeriveSha(freceipts, trie.NewStackTrie(nil)) != types.DeriveSha(areceipts, trie.NewStackTrie(nil)) {
t.Errorf("block #%d [%x]: receipts mismatch: have %v, want %v", num, hash, freceipts, areceipts)
}
}
// Check that the canonical chains are the same between the databases
for i := 0; i < len(blocks)+1; i++ {
- if fhash, ahash := GetCanonicalHash(fastDb, uint64(i)), GetCanonicalHash(archiveDb, uint64(i)); fhash != ahash {
+ if fhash, ahash := rawdb.GetCanonicalHash(fastDb, uint64(i)), rawdb.GetCanonicalHash(archiveDb, uint64(i)); fhash != ahash {
t.Errorf("block #%d: canonical hash mismatch: have %v, want %v", i, fhash, ahash)
}
}
@@ -804,28 +805,28 @@ func TestChainTxReorgs(t *testing.T) {
// removed tx
for i, tx := range (types.Transactions{pastDrop, freshDrop}) {
- if txn, _, _, _ := GetTransaction(db, tx.Hash()); txn != nil {
+ if txn, _, _, _ := rawdb.GetTransaction(db, tx.Hash()); txn != nil {
t.Errorf("drop %d: tx %v found while shouldn't have been", i, txn)
}
- if rcpt, _, _, _ := GetReceipt(db, tx.Hash()); rcpt != nil {
+ if rcpt, _, _, _ := rawdb.GetReceipt(db, tx.Hash(), blockchain.Config()); rcpt != nil {
t.Errorf("drop %d: receipt %v found while shouldn't have been", i, rcpt)
}
}
// added tx
for i, tx := range (types.Transactions{pastAdd, freshAdd, futureAdd}) {
- if txn, _, _, _ := GetTransaction(db, tx.Hash()); txn == nil {
+ if txn, _, _, _ := rawdb.GetTransaction(db, tx.Hash()); txn == nil {
t.Errorf("add %d: expected tx to be found", i)
}
- if rcpt, _, _, _ := GetReceipt(db, tx.Hash()); rcpt == nil {
+ if rcpt, _, _, _ := rawdb.GetReceipt(db, tx.Hash(), blockchain.Config()); rcpt == nil {
t.Errorf("add %d: expected receipt to be found", i)
}
}
// shared tx
for i, tx := range (types.Transactions{postponed, swapped}) {
- if txn, _, _, _ := GetTransaction(db, tx.Hash()); txn == nil {
+ if txn, _, _, _ := rawdb.GetTransaction(db, tx.Hash()); txn == nil {
t.Errorf("share %d: expected tx to be found", i)
}
- if rcpt, _, _, _ := GetReceipt(db, tx.Hash()); rcpt == nil {
+ if rcpt, _, _, _ := rawdb.GetReceipt(db, tx.Hash(), blockchain.Config()); rcpt == nil {
t.Errorf("share %d: expected receipt to be found", i)
}
}
@@ -980,14 +981,14 @@ func TestCanonicalBlockRetrieval(t *testing.T) {
// try to retrieve a block by its canonical hash and see if the block data can be retrieved.
for {
- ch := GetCanonicalHash(blockchain.db, block.NumberU64())
+ ch := rawdb.GetCanonicalHash(blockchain.db, block.NumberU64())
if ch == (common.Hash{}) {
continue // busy wait for canonical hash to be written
}
if ch != block.Hash() {
t.Fatalf("unknown canonical hash, want %s, got %s", block.Hash().Hex(), ch.Hex())
}
- fb := GetBlock(blockchain.db, ch, block.NumberU64())
+ fb := rawdb.GetBlock(blockchain.db, ch, block.NumberU64())
if fb == nil {
t.Fatalf("unable to retrieve block %d for canonical hash: %s", block.NumberU64(), ch.Hex())
}
diff --git a/core/chain_indexer.go b/core/chain_indexer.go
index 95190eea9..41f391990 100644
--- a/core/chain_indexer.go
+++ b/core/chain_indexer.go
@@ -24,6 +24,7 @@ import (
"time"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/event"
@@ -206,7 +207,7 @@ func (c *ChainIndexer) eventLoop(currentHeader *types.Header, events chan ChainE
// TODO(karalabe): This operation is expensive and might block, causing the event system to
// potentially also lock up. We need to do with on a different thread somehow.
- if h := FindCommonAncestor(c.chainDb, prevHeader, header); h != nil {
+ if h := rawdb.FindCommonAncestor(c.chainDb, prevHeader, header); h != nil {
c.newHead(h.Number.Uint64(), true)
}
}
@@ -349,11 +350,11 @@ func (c *ChainIndexer) processSection(section uint64, lastHead common.Hash) (com
}
for number := section * c.sectionSize; number < (section+1)*c.sectionSize; number++ {
- hash := GetCanonicalHash(c.chainDb, number)
+ hash := rawdb.GetCanonicalHash(c.chainDb, number)
if hash == (common.Hash{}) {
return common.Hash{}, fmt.Errorf("canonical block #%d unknown", number)
}
- header := GetHeader(c.chainDb, hash, number)
+ header := rawdb.GetHeader(c.chainDb, hash, number)
if header == nil {
return common.Hash{}, fmt.Errorf("block #%d [%x…] not found", number, hash[:4])
} else if header.ParentHash != lastHead {
diff --git a/core/chain_indexer_test.go b/core/chain_indexer_test.go
index a954c062d..3a50819b9 100644
--- a/core/chain_indexer_test.go
+++ b/core/chain_indexer_test.go
@@ -18,13 +18,13 @@ package core
import (
"fmt"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"math/rand"
"testing"
"time"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
)
@@ -92,10 +92,10 @@ func testChainIndexer(t *testing.T, count int) {
inject := func(number uint64) {
header := &types.Header{Number: big.NewInt(int64(number)), Extra: big.NewInt(rand.Int63()).Bytes()}
if number > 0 {
- header.ParentHash = GetCanonicalHash(db, number-1)
+ header.ParentHash = rawdb.GetCanonicalHash(db, number-1)
}
- WriteHeader(db, header)
- WriteCanonicalHash(db, header.Hash(), number)
+ rawdb.WriteHeader(db, header)
+ rawdb.WriteCanonicalHash(db, header.Hash(), number)
}
// Start indexer with an already existing chain
for i := uint64(0); i <= 100; i++ {
diff --git a/core/chain_makers.go b/core/chain_makers.go
index ac7c311fd..1e2aeb88b 100644
--- a/core/chain_makers.go
+++ b/core/chain_makers.go
@@ -18,12 +18,12 @@ package core
import (
"fmt"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus"
"github.com/tomochain/tomochain/consensus/misc"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
@@ -115,6 +115,15 @@ func (b *BlockGen) AddTxWithChain(bc *BlockChain, tx *types.Transaction) {
}
}
+// AddUncheckedTx forcefully adds a transaction to the block without any
+// validation.
+//
+// AddUncheckedTx will cause consensus failures when used during real
+// chain processing. This is best used in conjunction with raw block insertion.
+func (b *BlockGen) AddUncheckedTx(tx *types.Transaction) {
+ b.txs = append(b.txs, tx)
+}
+
// Number returns the block number of the block being generated.
func (b *BlockGen) Number() *big.Int {
return new(big.Int).Set(b.header.Number)
@@ -225,7 +234,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse
return nil, nil
}
for i := 0; i < n; i++ {
- statedb, err := state.New(parent.Root(), state.NewDatabase(db))
+ statedb, err := state.New(parent.Root(), state.NewDatabase(db), nil)
if err != nil {
panic(err)
}
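The new AddUncheckedTx hook added further up in this file skips the usual sender and nonce checks, which is handy for building deliberately invalid or pre-signed blocks in tests. A hedged usage sketch inside a GenerateChain callback; it assumes the stock Genesis.MustCommit helper and leaves the database constructor to the caller:

package example

import (
	"math/big"

	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/consensus/ethash"
	"github.com/tomochain/tomochain/core"
	"github.com/tomochain/tomochain/core/types"
	"github.com/tomochain/tomochain/ethdb"
	"github.com/tomochain/tomochain/params"
)

// oneUncheckedBlock generates a single block carrying a transaction that
// bypassed validation via AddUncheckedTx. Useful for raw-insertion tests;
// importing such a block through normal processing would fail consensus.
func oneUncheckedBlock(db ethdb.Database) []*types.Block {
	gspec := &core.Genesis{Config: params.TestChainConfig}
	genesis := gspec.MustCommit(db)
	blocks, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), db, 1,
		func(i int, gen *core.BlockGen) {
			tx := types.NewTransaction(0, common.Address{}, big.NewInt(0), 21000, big.NewInt(1), nil)
			gen.AddUncheckedTx(tx)
		})
	return blocks
}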
diff --git a/core/error.go b/core/error.go
index 63be6ab83..14177d13f 100644
--- a/core/error.go
+++ b/core/error.go
@@ -33,9 +33,23 @@ var (
// next one expected based on the local chain.
ErrNonceTooHigh = errors.New("nonce too high")
+ // ErrNonceMax is returned if the nonce of a transaction sender account has
+ // the maximum allowed value and would become invalid if incremented.
+ ErrNonceMax = errors.New("nonce has max value")
+
ErrNotPoSV = errors.New("Posv not found in config")
ErrNotFoundM1 = errors.New("list M1 not found ")
ErrStopPreparingBlock = errors.New("stop calculating a block not verified by M2")
+
+ // ErrSenderNoEOA is returned if the sender of a transaction is a contract.
+ ErrSenderNoEOA = errors.New("sender not an eoa")
+
+ // ErrGasUintOverflow is returned when calculating gas usage.
+ ErrGasUintOverflow = errors.New("gas uint64 overflow")
+
+ // ErrInsufficientFundsForTransfer is returned if the transaction sender doesn't
+ // have enough funds for transfer (topmost call only).
+ ErrInsufficientFundsForTransfer = errors.New("insufficient funds for transfer")
)
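
These additions are package-level sentinel values, so call sites should compare with errors.Is instead of matching message strings. A hedged sketch of a typical consumer, where ApplyMessage stands in for whichever execution entry point surfaces these errors:

    // Classify an execution failure by sentinel; evm, msg and gasPool are
    // placeholders for the caller's state.
    if _, err := core.ApplyMessage(evm, msg, gasPool); err != nil {
        switch {
        case errors.Is(err, core.ErrNonceMax):
            // the sender's nonce is already at the uint64 maximum
        case errors.Is(err, core.ErrSenderNoEOA):
            // the sender has code; only externally owned accounts may send
        case errors.Is(err, core.ErrInsufficientFundsForTransfer):
            // the balance covers gas but not the transferred value
        }
    }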
diff --git a/core/evm.go b/core/evm.go
index 04636999b..f3ac62a73 100644
--- a/core/evm.go
+++ b/core/evm.go
@@ -26,7 +26,7 @@ import (
)
// NewEVMContext creates a new context for use in the EVM.
-func NewEVMContext(msg Message, header *types.Header, chain consensus.ChainContext, author *common.Address) vm.Context {
+func NewEVMContext(msg *Message, header *types.Header, chain consensus.ChainContext, author *common.Address) vm.Context {
// If we don't have an explicit author (i.e. not mining), extract from the header
var beneficiary common.Address
if author == nil {
@@ -38,13 +38,13 @@ func NewEVMContext(msg Message, header *types.Header, chain consensus.ChainConte
CanTransfer: CanTransfer,
Transfer: Transfer,
GetHash: GetHashFn(header, chain),
- Origin: msg.From(),
+ Origin: msg.From,
Coinbase: beneficiary,
BlockNumber: new(big.Int).Set(header.Number),
Time: new(big.Int).Set(header.Time),
Difficulty: new(big.Int).Set(header.Difficulty),
GasLimit: header.GasLimit,
- GasPrice: new(big.Int).Set(msg.GasPrice()),
+ GasPrice: new(big.Int).Set(msg.GasPrice),
}
}
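
The signature change tracks Message turning from an interface into a concrete struct: From and GasPrice are now exported fields read directly rather than accessor methods, and the builder takes the struct by pointer. A sketch of the new call shape, assuming only the two fields visible in the hunk above; sender, header and chain are caller-supplied placeholders:

    msg := &core.Message{From: sender, GasPrice: big.NewInt(1)}
    ctx := core.NewEVMContext(msg, header, chain, nil) // reads msg.From, msg.GasPrice
    _ = ctx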
diff --git a/core/genesis.go b/core/genesis.go
index e1b7185a4..77970085f 100644
--- a/core/genesis.go
+++ b/core/genesis.go
@@ -22,10 +22,11 @@ import (
"encoding/json"
"errors"
"fmt"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"strings"
+ "github.com/tomochain/tomochain/core/rawdb"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/hexutil"
"github.com/tomochain/tomochain/common/math"
@@ -35,6 +36,7 @@ import (
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/rlp"
+ "github.com/tomochain/tomochain/trie"
)
//go:generate gencodec -type Genesis -field-override genesisSpecMarshaling -out gen_genesis.go
@@ -140,10 +142,10 @@ func (e *GenesisMismatchError) Error() string {
// SetupGenesisBlock writes or updates the genesis block in db.
// The block that will be used is:
//
-// genesis == nil genesis != nil
-// +------------------------------------------
-// db has no genesis | main-net default | genesis
-// db has genesis | from DB | genesis (if compatible)
+// genesis == nil genesis != nil
+// +------------------------------------------
+// db has no genesis | main-net default | genesis
+// db has genesis | from DB | genesis (if compatible)
//
// The stored chain configuration will be updated if it is compatible (i.e. does not
// specify a fork block below the local head block). In case of a conflict, the
@@ -156,7 +158,7 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig
}
// Just commit the new block if there is no stored genesis block.
- stored := GetCanonicalHash(db, 0)
+ stored := rawdb.GetCanonicalHash(db, 0)
if (stored == common.Hash{}) {
if genesis == nil {
log.Info("Writing default main-net genesis block")
@@ -178,12 +180,12 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig
// Get the existing chain configuration.
newcfg := genesis.configOrDefault(stored)
- storedcfg, err := GetChainConfig(db, stored)
+ storedcfg, err := rawdb.GetChainConfig(db, stored)
if err != nil {
- if err == ErrChainConfigNotFound {
+ if err == rawdb.ErrChainConfigNotFound {
// This case happens if a genesis write was interrupted.
log.Warn("Found genesis block without chain config")
- err = WriteChainConfig(db, stored, newcfg)
+ err = rawdb.WriteChainConfig(db, stored, newcfg)
}
return newcfg, stored, err
}
@@ -196,15 +198,15 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig
// Check config compatibility and write the config. Compatibility errors
// are returned to the caller unless we're already at block zero.
- height := GetBlockNumber(db, GetHeadHeaderHash(db))
- if height == missingNumber {
+ height := rawdb.GetBlockNumber(db, rawdb.GetHeadHeaderHash(db))
+ if height == rawdb.MissingNumber {
return newcfg, stored, fmt.Errorf("missing block number for head header hash")
}
compatErr := storedcfg.CheckCompatible(newcfg, height)
if compatErr != nil && height != 0 && compatErr.RewindTo != 0 {
return newcfg, stored, compatErr
}
- return newcfg, stored, WriteChainConfig(db, stored, newcfg)
+ return newcfg, stored, rawdb.WriteChainConfig(db, stored, newcfg)
}
func (g *Genesis) configOrDefault(ghash common.Hash) *params.ChainConfig {
@@ -226,7 +228,7 @@ func (g *Genesis) ToBlock(db ethdb.Database) *types.Block {
if db == nil {
db = rawdb.NewMemoryDatabase()
}
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
for addr, account := range g.Alloc {
statedb.AddBalance(addr, account.Balance)
statedb.SetCode(addr, account.Code)
@@ -258,7 +260,7 @@ func (g *Genesis) ToBlock(db ethdb.Database) *types.Block {
statedb.Commit(false)
statedb.Database().TrieDB().Commit(root, true)
- return types.NewBlock(head, nil, nil, nil)
+ return types.NewBlock(head, nil, nil, nil, new(trie.StackTrie))
}
// Commit writes the block and state of a genesis specification to the database.
@@ -268,29 +270,29 @@ func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) {
if block.Number().Sign() != 0 {
return nil, fmt.Errorf("can't commit genesis block with number > 0")
}
- if err := WriteTd(db, block.Hash(), block.NumberU64(), g.Difficulty); err != nil {
+ if err := rawdb.WriteTd(db, block.Hash(), block.NumberU64(), g.Difficulty); err != nil {
return nil, err
}
- if err := WriteBlock(db, block); err != nil {
+ if err := rawdb.WriteBlock(db, block); err != nil {
return nil, err
}
- if err := WriteBlockReceipts(db, block.Hash(), block.NumberU64(), nil); err != nil {
+ if err := rawdb.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), nil); err != nil {
return nil, err
}
- if err := WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil {
+ if err := rawdb.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil {
return nil, err
}
- if err := WriteHeadBlockHash(db, block.Hash()); err != nil {
+ if err := rawdb.WriteHeadBlockHash(db, block.Hash()); err != nil {
return nil, err
}
- if err := WriteHeadHeaderHash(db, block.Hash()); err != nil {
+ if err := rawdb.WriteHeadHeaderHash(db, block.Hash()); err != nil {
return nil, err
}
config := g.Config
if config == nil {
config = params.AllEthashProtocolChanges
}
- return block, WriteChainConfig(db, block.Hash(), config)
+ return block, rawdb.WriteChainConfig(db, block.Hash(), config)
}
// MustCommit writes the genesis block and state to db, panicking on error.
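
The decision table above reduces to one call that is safe to repeat across restarts. A minimal usage sketch, assuming the usual geth-style DefaultGenesisBlock constructor for the main-net spec:

    db := rawdb.NewMemoryDatabase()

    // Empty database: commits the supplied spec (or the main-net default
    // when the genesis argument is nil) and returns config and block hash.
    config, hash, err := core.SetupGenesisBlock(db, core.DefaultGenesisBlock())
    if err != nil {
        log.Crit("genesis setup failed", "err", err)
    }
    // Repeat calls return the stored block; the stored config is rewritten
    // only when the new one is fork-compatible with the local head.
    _, _ = config, hash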
diff --git a/core/genesis_test.go b/core/genesis_test.go
index 177798a5d..ee32b6705 100644
--- a/core/genesis_test.go
+++ b/core/genesis_test.go
@@ -17,7 +17,6 @@
package core
import (
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"reflect"
"testing"
@@ -25,6 +24,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus/ethash"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/vm"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/params"
@@ -155,7 +155,7 @@ func TestSetupGenesis(t *testing.T) {
t.Errorf("%s: returned hash %s, want %s", test.name, hash.Hex(), test.wantHash.Hex())
} else if err == nil {
// Check database content.
- stored := GetBlock(db, test.wantHash, 0)
+ stored := rawdb.GetBlock(db, test.wantHash, 0)
if stored.Hash() != test.wantHash {
t.Errorf("%s: block in DB has hash %s, want %s", test.name, stored.Hash(), test.wantHash)
}
diff --git a/core/headerchain.go b/core/headerchain.go
index 8365f2127..feed409ca 100644
--- a/core/headerchain.go
+++ b/core/headerchain.go
@@ -26,9 +26,11 @@ import (
"sync/atomic"
"time"
- "github.com/hashicorp/golang-lru"
+ lru "github.com/hashicorp/golang-lru"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/log"
@@ -66,9 +68,9 @@ type HeaderChain struct {
}
// NewHeaderChain creates a new HeaderChain structure.
-// getValidator should return the parent's validator
-// procInterrupt points to the parent's interrupt semaphore
-// wg points to the parent's shutdown wait group
+// getValidator should return the parent's validator
+// procInterrupt points to the parent's interrupt semaphore
+// wg points to the parent's shutdown wait group
func NewHeaderChain(chainDb ethdb.Database, config *params.ChainConfig, engine consensus.Engine, procInterrupt func() bool) (*HeaderChain, error) {
headerCache, _ := lru.New(headerCacheLimit)
tdCache, _ := lru.New(tdCacheLimit)
@@ -97,7 +99,7 @@ func NewHeaderChain(chainDb ethdb.Database, config *params.ChainConfig, engine c
}
hc.currentHeader.Store(hc.genesisHeader)
- if head := GetHeadBlockHash(chainDb); head != (common.Hash{}) {
+ if head := rawdb.GetHeadBlockHash(chainDb); head != (common.Hash{}) {
if chead := hc.GetHeaderByHash(head); chead != nil {
hc.currentHeader.Store(chead)
}
@@ -113,8 +115,8 @@ func (hc *HeaderChain) GetBlockNumber(hash common.Hash) uint64 {
if cached, ok := hc.numberCache.Get(hash); ok {
return cached.(uint64)
}
- number := GetBlockNumber(hc.chainDb, hash)
- if number != missingNumber {
+ number := rawdb.GetBlockNumber(hc.chainDb, hash)
+ if number != rawdb.MissingNumber {
hc.numberCache.Add(hash, number)
}
return number
@@ -147,7 +149,7 @@ func (hc *HeaderChain) WriteHeader(header *types.Header) (status WriteStatus, er
if err := hc.WriteTd(hash, number, externTd); err != nil {
log.Crit("Failed to write header total difficulty", "err", err)
}
- if err := WriteHeader(hc.chainDb, header); err != nil {
+ if err := rawdb.WriteHeader(hc.chainDb, header); err != nil {
log.Crit("Failed to write header content", "err", err)
}
// If the total difficulty is higher than our known, add it to the canonical chain
@@ -156,11 +158,11 @@ func (hc *HeaderChain) WriteHeader(header *types.Header) (status WriteStatus, er
if externTd.Cmp(localTd) > 0 || (externTd.Cmp(localTd) == 0 && mrand.Float64() < 0.5) {
// Delete any canonical number assignments above the new head
for i := number + 1; ; i++ {
- hash := GetCanonicalHash(hc.chainDb, i)
+ hash := rawdb.GetCanonicalHash(hc.chainDb, i)
if hash == (common.Hash{}) {
break
}
- DeleteCanonicalHash(hc.chainDb, i)
+ rawdb.DeleteCanonicalHash(hc.chainDb, i)
}
// Overwrite any stale canonical number assignments
var (
@@ -168,18 +170,18 @@ func (hc *HeaderChain) WriteHeader(header *types.Header) (status WriteStatus, er
headNumber = header.Number.Uint64() - 1
headHeader = hc.GetHeader(headHash, headNumber)
)
- for GetCanonicalHash(hc.chainDb, headNumber) != headHash {
- WriteCanonicalHash(hc.chainDb, headHash, headNumber)
+ for rawdb.GetCanonicalHash(hc.chainDb, headNumber) != headHash {
+ rawdb.WriteCanonicalHash(hc.chainDb, headHash, headNumber)
headHash = headHeader.ParentHash
headNumber = headHeader.Number.Uint64() - 1
headHeader = hc.GetHeader(headHash, headNumber)
}
// Extend the canonical chain with the new header
- if err := WriteCanonicalHash(hc.chainDb, hash, number); err != nil {
+ if err := rawdb.WriteCanonicalHash(hc.chainDb, hash, number); err != nil {
log.Crit("Failed to insert header number", "err", err)
}
- if err := WriteHeadHeaderHash(hc.chainDb, hash); err != nil {
+ if err := rawdb.WriteHeadHeaderHash(hc.chainDb, hash); err != nil {
log.Crit("Failed to insert head header hash", "err", err)
}
hc.currentHeaderHash = hash
@@ -316,7 +318,7 @@ func (hc *HeaderChain) GetTd(hash common.Hash, number uint64) *big.Int {
if cached, ok := hc.tdCache.Get(hash); ok {
return cached.(*big.Int)
}
- td := GetTd(hc.chainDb, hash, number)
+ td := rawdb.GetTd(hc.chainDb, hash, number)
if td == nil {
return nil
}
@@ -334,7 +336,7 @@ func (hc *HeaderChain) GetTdByHash(hash common.Hash) *big.Int {
// WriteTd stores a block's total difficulty into the database, also caching it
// along the way.
func (hc *HeaderChain) WriteTd(hash common.Hash, number uint64, td *big.Int) error {
- if err := WriteTd(hc.chainDb, hash, number, td); err != nil {
+ if err := rawdb.WriteTd(hc.chainDb, hash, number, td); err != nil {
return err
}
hc.tdCache.Add(hash, new(big.Int).Set(td))
@@ -348,7 +350,7 @@ func (hc *HeaderChain) GetHeader(hash common.Hash, number uint64) *types.Header
if header, ok := hc.headerCache.Get(hash); ok {
return header.(*types.Header)
}
- header := GetHeader(hc.chainDb, hash, number)
+ header := rawdb.GetHeader(hc.chainDb, hash, number)
if header == nil {
return nil
}
@@ -368,14 +370,14 @@ func (hc *HeaderChain) HasHeader(hash common.Hash, number uint64) bool {
if hc.numberCache.Contains(hash) || hc.headerCache.Contains(hash) {
return true
}
- ok, _ := hc.chainDb.Has(headerKey(hash, number))
+ ok, _ := hc.chainDb.Has(rawdb.HeaderKey(number, hash))
return ok
}
// GetHeaderByNumber retrieves a block header from the database by number,
// caching it (associated with its hash) if found.
func (hc *HeaderChain) GetHeaderByNumber(number uint64) *types.Header {
- hash := GetCanonicalHash(hc.chainDb, number)
+ hash := rawdb.GetCanonicalHash(hc.chainDb, number)
if hash == (common.Hash{}) {
return nil
}
@@ -390,7 +392,7 @@ func (hc *HeaderChain) CurrentHeader() *types.Header {
// SetCurrentHeader sets the current head header of the canonical chain.
func (hc *HeaderChain) SetCurrentHeader(head *types.Header) {
- if err := WriteHeadHeaderHash(hc.chainDb, head.Hash()); err != nil {
+ if err := rawdb.WriteHeadHeaderHash(hc.chainDb, head.Hash()); err != nil {
log.Crit("Failed to insert head header hash", "err", err)
}
hc.currentHeader.Store(head)
@@ -416,13 +418,13 @@ func (hc *HeaderChain) SetHead(head uint64, delFn DeleteCallback) {
if delFn != nil {
delFn(hash, num)
}
- DeleteHeader(hc.chainDb, hash, num)
- DeleteTd(hc.chainDb, hash, num)
+ rawdb.DeleteHeader(hc.chainDb, hash, num)
+ rawdb.DeleteTd(hc.chainDb, hash, num)
hc.currentHeader.Store(hc.GetHeader(hdr.ParentHash, hdr.Number.Uint64()-1))
}
// Roll back the canonical chain numbering
for i := height; i > head; i-- {
- DeleteCanonicalHash(hc.chainDb, i)
+ rawdb.DeleteCanonicalHash(hc.chainDb, i)
}
// Clear out any stale content from the caches
hc.headerCache.Purge()
@@ -434,7 +436,7 @@ func (hc *HeaderChain) SetHead(head uint64, delFn DeleteCallback) {
}
hc.currentHeaderHash = hc.CurrentHeader().Hash()
- if err := WriteHeadHeaderHash(hc.chainDb, hc.currentHeaderHash); err != nil {
+ if err := rawdb.WriteHeadHeaderHash(hc.chainDb, hc.currentHeaderHash); err != nil {
log.Crit("Failed to reset head header hash", "err", err)
}
}
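
With the accessors extracted, the header chain's canonical lookup is a two-step rawdb walk: number to canonical hash, then hash plus number to header. A condensed sketch of the path GetHeaderByNumber now takes (caching omitted):

    func headerByNumber(db rawdb.DatabaseReader, number uint64) *types.Header {
        hash := rawdb.GetCanonicalHash(db, number) // number -> canonical hash
        if hash == (common.Hash{}) {
            return nil // no canonical block at this height
        }
        return rawdb.GetHeader(db, hash, number) // (hash, number) -> header
    }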
diff --git a/core/database_util.go b/core/rawdb/accessors_chain.go
similarity index 51%
rename from core/database_util.go
rename to core/rawdb/accessors_chain.go
index a5ab18687..40a5b1d3e 100644
--- a/core/database_util.go
+++ b/core/rawdb/accessors_chain.go
@@ -14,22 +14,17 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-package core
+package rawdb
import (
"bytes"
"encoding/binary"
- "encoding/json"
- "errors"
- "fmt"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/log"
- "github.com/tomochain/tomochain/metrics"
"github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/rlp"
)
@@ -44,46 +39,6 @@ type DatabaseDeleter interface {
Delete(key []byte) error
}
-var (
- headHeaderKey = []byte("LastHeader")
- headBlockKey = []byte("LastBlock")
- headFastKey = []byte("LastFast")
- trieSyncKey = []byte("TrieSync")
-
- // Data item prefixes (use single byte to avoid mixing data types, avoid `i`).
- headerPrefix = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header
- tdSuffix = []byte("t") // headerPrefix + num (uint64 big endian) + hash + tdSuffix -> td
- numSuffix = []byte("n") // headerPrefix + num (uint64 big endian) + numSuffix -> hash
- blockHashPrefix = []byte("H") // blockHashPrefix + hash -> num (uint64 big endian)
- bodyPrefix = []byte("b") // bodyPrefix + num (uint64 big endian) + hash -> block body
- blockReceiptsPrefix = []byte("r") // blockReceiptsPrefix + num (uint64 big endian) + hash -> block receipts
- lookupPrefix = []byte("l") // lookupPrefix + hash -> transaction/receipt lookup metadata
- bloomBitsPrefix = []byte("B") // bloomBitsPrefix + bit (uint16 big endian) + section (uint64 big endian) + hash -> bloom bits
-
- preimagePrefix = "secure-key-" // preimagePrefix + hash -> preimage
- configPrefix = []byte("ethereum-config-") // config prefix for the db
-
- // Chain index prefixes (use `i` + single byte to avoid mixing data types).
- BloomBitsIndexPrefix = []byte("iB") // BloomBitsIndexPrefix is the data table of a chain indexer to track its progress
-
- // used by old db, now only used for conversion
- oldReceiptsPrefix = []byte("receipts-")
- oldTxMetaSuffix = []byte{0x01}
-
- ErrChainConfigNotFound = errors.New("ChainConfig not found") // general config not found error
-
- preimageCounter = metrics.NewRegisteredCounter("db/preimage/total", nil)
- preimageHitCounter = metrics.NewRegisteredCounter("db/preimage/hits", nil)
-)
-
-// TxLookupEntry is a positional metadata to help looking up the data content of
-// a transaction or receipt given only its hash.
-type TxLookupEntry struct {
- BlockHash common.Hash
- BlockIndex uint64
- Index uint64
-}
-
// encodeBlockNumber encodes a block number as big endian uint64
func encodeBlockNumber(number uint64) []byte {
enc := make([]byte, 8)
@@ -93,23 +48,23 @@ func encodeBlockNumber(number uint64) []byte {
// GetCanonicalHash retrieves a hash assigned to a canonical block number.
func GetCanonicalHash(db DatabaseReader, number uint64) common.Hash {
- data, _ := db.Get(append(append(headerPrefix, encodeBlockNumber(number)...), numSuffix...))
+ data, _ := db.Get(headerHashKey(number))
if len(data) == 0 {
return common.Hash{}
}
return common.BytesToHash(data)
}
-// missingNumber is returned by GetBlockNumber if no header with the
+// MissingNumber is returned by GetBlockNumber if no header with the
// given block hash has been stored in the database
-const missingNumber = uint64(0xffffffffffffffff)
+const MissingNumber = uint64(0xffffffffffffffff)
// GetBlockNumber returns the block number assigned to a block hash
// if the corresponding header is present in the database
func GetBlockNumber(db DatabaseReader, hash common.Hash) uint64 {
- data, _ := db.Get(append(blockHashPrefix, hash.Bytes()...))
+ data, _ := db.Get(headerNumberKey(hash))
if len(data) != 8 {
- return missingNumber
+ return MissingNumber
}
return binary.BigEndian.Uint64(data)
}
@@ -149,7 +104,7 @@ func GetHeadFastBlockHash(db DatabaseReader) common.Hash {
}
// GetTrieSyncProgress retrieves the number of tries nodes fast synced to allow
-// reportinc correct numbers across restarts.
+// reporting correct numbers across restarts.
func GetTrieSyncProgress(db DatabaseReader) uint64 {
data, _ := db.Get(trieSyncKey)
if len(data) == 0 {
@@ -161,7 +116,7 @@ func GetTrieSyncProgress(db DatabaseReader) uint64 {
// GetHeaderRLP retrieves a block header in its raw RLP database encoding, or nil
// if the header's not found.
func GetHeaderRLP(db DatabaseReader, hash common.Hash, number uint64) rlp.RawValue {
- data, _ := db.Get(headerKey(hash, number))
+ data, _ := db.Get(HeaderKey(number, hash))
return data
}
@@ -182,19 +137,11 @@ func GetHeader(db DatabaseReader, hash common.Hash, number uint64) *types.Header
// GetBodyRLP retrieves the block body (transactions and uncles) in RLP encoding.
func GetBodyRLP(db DatabaseReader, hash common.Hash, number uint64) rlp.RawValue {
- data, _ := db.Get(blockBodyKey(hash, number))
+ data, _ := db.Get(BlockBodyKey(number, hash))
return data
}
-func headerKey(hash common.Hash, number uint64) []byte {
- return append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...)
-}
-
-func blockBodyKey(hash common.Hash, number uint64) []byte {
- return append(append(bodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...)
-}
-
-// GetBody retrieves the block body (transactons, uncles) corresponding to the
+// GetBody retrieves the block body (transactions, uncles) corresponding to the
// hash, nil if none found.
func GetBody(db DatabaseReader, hash common.Hash, number uint64) *types.Body {
data := GetBodyRLP(db, hash, number)
@@ -212,7 +159,7 @@ func GetBody(db DatabaseReader, hash common.Hash, number uint64) *types.Body {
// GetTd retrieves a block's total difficulty corresponding to the hash, nil if
// none found.
func GetTd(db DatabaseReader, hash common.Hash, number uint64) *big.Int {
- data, _ := db.Get(append(append(append(headerPrefix, encodeBlockNumber(number)...), hash[:]...), tdSuffix...))
+ data, _ := db.Get(headerTDKey(number, hash))
if len(data) == 0 {
return nil
}
@@ -244,14 +191,25 @@ func GetBlock(db DatabaseReader, hash common.Hash, number uint64) *types.Block {
return types.NewBlockWithHeader(header).WithBody(body.Transactions, body.Uncles)
}
-// GetBlockReceipts retrieves the receipts generated by the transactions included
-// in a block given by its hash.
-func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64) types.Receipts {
+// ReadReceiptsRLP retrieves all the transaction receipts belonging to a block in RLP encoding.
+func ReadReceiptsRLP(db DatabaseReader, hash common.Hash, number uint64) rlp.RawValue {
data, _ := db.Get(append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash[:]...))
if len(data) == 0 {
return nil
}
- storageReceipts := []*types.ReceiptForStorage{}
+ return data
+}
+
+// ReadRawReceipts retrieves all the transaction receipts belonging to a block.
+// The receipt metadata fields are not guaranteed to be populated, so they
+// should not be used. Use ReadReceipts instead if the metadata is needed.
+func ReadRawReceipts(db DatabaseReader, hash common.Hash, number uint64) types.Receipts {
+ // Retrieve the flattened receipt slice
+ data := ReadReceiptsRLP(db, hash, number)
+ if len(data) == 0 {
+ return nil
+ }
+ var storageReceipts []*types.ReceiptForStorage
if err := rlp.DecodeBytes(data, &storageReceipts); err != nil {
log.Error("Invalid receipt array RLP", "hash", hash, "err", err)
return nil
@@ -263,100 +221,30 @@ func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64) types.
return receipts
}
-// GetTxLookupEntry retrieves the positional metadata associated with a transaction
-// hash to allow retrieving the transaction or receipt by hash.
-func GetTxLookupEntry(db DatabaseReader, hash common.Hash) (common.Hash, uint64, uint64) {
- // Load the positional metadata from disk and bail if it fails
- data, _ := db.Get(append(lookupPrefix, hash.Bytes()...))
- if len(data) == 0 {
- return common.Hash{}, 0, 0
- }
- // Parse and return the contents of the lookup entry
- var entry TxLookupEntry
- if err := rlp.DecodeBytes(data, &entry); err != nil {
- log.Error("Invalid lookup entry RLP", "hash", hash, "err", err)
- return common.Hash{}, 0, 0
- }
- return entry.BlockHash, entry.BlockIndex, entry.Index
-}
-
-// GetTransaction retrieves a specific transaction from the database, along with
-// its added positional metadata.
-func GetTransaction(db DatabaseReader, hash common.Hash) (*types.Transaction, common.Hash, uint64, uint64) {
- // Retrieve the lookup metadata and resolve the transaction from the body
- blockHash, blockNumber, txIndex := GetTxLookupEntry(db, hash)
-
- if blockHash != (common.Hash{}) {
- body := GetBody(db, blockHash, blockNumber)
- if body == nil || len(body.Transactions) <= int(txIndex) {
- log.Error("Transaction referenced missing", "number", blockNumber, "hash", blockHash, "index", txIndex)
- return nil, common.Hash{}, 0, 0
- }
- return body.Transactions[txIndex], blockHash, blockNumber, txIndex
- }
- // Old transaction representation, load the transaction and it's metadata separately
- data, _ := db.Get(hash.Bytes())
- if len(data) == 0 {
- return nil, common.Hash{}, 0, 0
- }
- var tx types.Transaction
- if err := rlp.DecodeBytes(data, &tx); err != nil {
- return nil, common.Hash{}, 0, 0
- }
- // Retrieve the blockchain positional metadata
- data, _ = db.Get(append(hash.Bytes(), oldTxMetaSuffix...))
- if len(data) == 0 {
- return nil, common.Hash{}, 0, 0
- }
- var entry TxLookupEntry
- if err := rlp.DecodeBytes(data, &entry); err != nil {
- return nil, common.Hash{}, 0, 0
- }
- return &tx, entry.BlockHash, entry.BlockIndex, entry.Index
-}
-
-// GetReceipt retrieves a specific transaction receipt from the database, along with
-// its added positional metadata.
-func GetReceipt(db DatabaseReader, hash common.Hash) (*types.Receipt, common.Hash, uint64, uint64) {
- // Retrieve the lookup metadata and resolve the receipt from the receipts
- blockHash, blockNumber, receiptIndex := GetTxLookupEntry(db, hash)
-
- if blockHash != (common.Hash{}) {
- receipts := GetBlockReceipts(db, blockHash, blockNumber)
- if len(receipts) <= int(receiptIndex) {
- log.Error("Receipt refereced missing", "number", blockNumber, "hash", blockHash, "index", receiptIndex)
- return nil, common.Hash{}, 0, 0
- }
- return receipts[receiptIndex], blockHash, blockNumber, receiptIndex
+// GetBlockReceipts retrieves the receipts generated by the transactions included
+// in a block given by its hash.
+func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64, config *params.ChainConfig) types.Receipts {
+ // We're deriving many fields from the block body, so retrieve it alongside the receipts
+ receipts := ReadRawReceipts(db, hash, number)
+ if receipts == nil {
+ return nil
}
- // Old receipt representation, load the receipt and set an unknown metadata
- data, _ := db.Get(append(oldReceiptsPrefix, hash[:]...))
- if len(data) == 0 {
- return nil, common.Hash{}, 0, 0
+ body := GetBody(db, hash, number)
+ if body == nil {
+ log.Error("Missing body but have receipt", "hash", hash, "number", number)
+ return nil
}
- var receipt types.ReceiptForStorage
- err := rlp.DecodeBytes(data, &receipt)
- if err != nil {
- log.Error("Invalid receipt RLP", "hash", hash, "err", err)
+ if err := receipts.DeriveFields(config, hash, number, body.Transactions); err != nil {
+ log.Error("Failed to derive block receipts fields", "hash", hash, "number", number, "err", err)
+ return nil
}
- return (*types.Receipt)(&receipt), common.Hash{}, 0, 0
-}
-
-// GetBloomBits retrieves the compressed bloom bit vector belonging to the given
-// section and bit index from the.
-func GetBloomBits(db DatabaseReader, bit uint, section uint64, head common.Hash) ([]byte, error) {
- key := append(append(bloomBitsPrefix, make([]byte, 10)...), head.Bytes()...)
- binary.BigEndian.PutUint16(key[1:], uint16(bit))
- binary.BigEndian.PutUint64(key[3:], section)
-
- return db.Get(key)
+ return receipts
}
// WriteCanonicalHash stores the canonical hash for the given block number.
func WriteCanonicalHash(db ethdb.KeyValueWriter, hash common.Hash, number uint64) error {
- key := append(append(headerPrefix, encodeBlockNumber(number)...), numSuffix...)
- if err := db.Put(key, hash.Bytes()); err != nil {
+ if err := db.Put(headerHashKey(number), hash.Bytes()); err != nil {
log.Crit("Failed to store number to hash mapping", "err", err)
}
return nil
@@ -401,15 +289,13 @@ func WriteHeader(db ethdb.KeyValueWriter, header *types.Header) error {
if err != nil {
return err
}
- hash := header.Hash().Bytes()
+ hash := header.Hash()
num := header.Number.Uint64()
encNum := encodeBlockNumber(num)
- key := append(blockHashPrefix, hash...)
- if err := db.Put(key, encNum); err != nil {
+ if err := db.Put(headerNumberKey(hash), encNum); err != nil {
log.Crit("Failed to store hash to number mapping", "err", err)
}
- key = append(append(headerPrefix, encNum...), hash...)
- if err := db.Put(key, data); err != nil {
+ if err := db.Put(headerKey(num, hash), data); err != nil {
log.Crit("Failed to store header", "err", err)
}
return nil
@@ -426,8 +312,7 @@ func WriteBody(db ethdb.KeyValueWriter, hash common.Hash, number uint64, body *t
// WriteBodyRLP writes a serialized body of a block into the database.
func WriteBodyRLP(db ethdb.KeyValueWriter, hash common.Hash, number uint64, rlp rlp.RawValue) error {
- key := append(append(bodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...)
- if err := db.Put(key, rlp); err != nil {
+ if err := db.Put(BlockBodyKey(number, hash), rlp); err != nil {
log.Crit("Failed to store block body", "err", err)
}
return nil
@@ -439,8 +324,7 @@ func WriteTd(db ethdb.KeyValueWriter, hash common.Hash, number uint64, td *big.I
if err != nil {
return err
}
- key := append(append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...), tdSuffix...)
- if err := db.Put(key, data); err != nil {
+ if err := db.Put(headerTDKey(number, hash), data); err != nil {
log.Crit("Failed to store block total difficulty", "err", err)
}
return nil
@@ -473,66 +357,31 @@ func WriteBlockReceipts(db ethdb.KeyValueWriter, hash common.Hash, number uint64
return err
}
// Store the flattened receipt slice
- key := append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash.Bytes()...)
- if err := db.Put(key, bytes); err != nil {
+ if err := db.Put(blockReceiptsKey(number, hash), bytes); err != nil {
log.Crit("Failed to store block receipts", "err", err)
}
return nil
}
-// WriteTxLookupEntries stores a positional metadata for every transaction from
-// a block, enabling hash based transaction and receipt lookups.
-func WriteTxLookupEntries(db ethdb.KeyValueWriter, block *types.Block) error {
- // Iterate over each transaction and encode its metadata
- for i, tx := range block.Transactions() {
- entry := TxLookupEntry{
- BlockHash: block.Hash(),
- BlockIndex: block.NumberU64(),
- Index: uint64(i),
- }
- data, err := rlp.EncodeToBytes(entry)
- if err != nil {
- return err
- }
- if err := db.Put(append(lookupPrefix, tx.Hash().Bytes()...), data); err != nil {
- return err
- }
- }
- return nil
-}
-
-// WriteBloomBits writes the compressed bloom bits vector belonging to the given
-// section and bit index.
-func WriteBloomBits(db ethdb.KeyValueWriter, bit uint, section uint64, head common.Hash, bits []byte) {
- key := append(append(bloomBitsPrefix, make([]byte, 10)...), head.Bytes()...)
-
- binary.BigEndian.PutUint16(key[1:], uint16(bit))
- binary.BigEndian.PutUint64(key[3:], section)
-
- if err := db.Put(key, bits); err != nil {
- log.Crit("Failed to store bloom bits", "err", err)
- }
-}
-
// DeleteCanonicalHash removes the number to hash canonical mapping.
func DeleteCanonicalHash(db DatabaseDeleter, number uint64) {
- db.Delete(append(append(headerPrefix, encodeBlockNumber(number)...), numSuffix...))
+ db.Delete(headerHashKey(number))
}
// DeleteHeader removes all block header data associated with a hash.
func DeleteHeader(db DatabaseDeleter, hash common.Hash, number uint64) {
- db.Delete(append(blockHashPrefix, hash.Bytes()...))
- db.Delete(append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...))
+ db.Delete(headerNumberKey(hash))
+ db.Delete(headerKey(number, hash))
}
// DeleteBody removes all block body data associated with a hash.
func DeleteBody(db DatabaseDeleter, hash common.Hash, number uint64) {
- db.Delete(append(append(bodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...))
+ db.Delete(BlockBodyKey(number, hash))
}
// DeleteTd removes all block total difficulty data associated with a hash.
func DeleteTd(db DatabaseDeleter, hash common.Hash, number uint64) {
- db.Delete(append(append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...), tdSuffix...))
+ db.Delete(headerTDKey(number, hash))
}
// DeleteBlock removes all block data associated with a hash.
@@ -545,84 +394,7 @@ func DeleteBlock(db DatabaseDeleter, hash common.Hash, number uint64) {
// DeleteBlockReceipts removes all receipt data associated with a block hash.
func DeleteBlockReceipts(db DatabaseDeleter, hash common.Hash, number uint64) {
- db.Delete(append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash.Bytes()...))
-}
-
-// DeleteTxLookupEntry removes all transaction data associated with a hash.
-func DeleteTxLookupEntry(db DatabaseDeleter, hash common.Hash) {
- db.Delete(append(lookupPrefix, hash.Bytes()...))
-}
-
-// PreimageTable returns a Database instance with the key prefix for preimage entries.
-func PreimageTable(db ethdb.Database) ethdb.Database {
- return rawdb.NewTable(db, preimagePrefix)
-}
-
-// WritePreimages writes the provided set of preimages to the database. `number` is the
-// current block number, and is used for debug messages only.
-func WritePreimages(db ethdb.Database, number uint64, preimages map[common.Hash][]byte) error {
- table := PreimageTable(db)
- batch := table.NewBatch()
- hitCount := 0
- for hash, preimage := range preimages {
- if _, err := table.Get(hash.Bytes()); err != nil {
- batch.Put(hash.Bytes(), preimage)
- hitCount++
- }
- }
- preimageCounter.Inc(int64(len(preimages)))
- preimageHitCounter.Inc(int64(hitCount))
- if hitCount > 0 {
- if err := batch.Write(); err != nil {
- return fmt.Errorf("preimage write fail for block %d: %v", number, err)
- }
- }
- return nil
-}
-
-// GetBlockChainVersion reads the version number from db.
-func GetBlockChainVersion(db DatabaseReader) int {
- var vsn uint
- enc, _ := db.Get([]byte("BlockchainVersion"))
- rlp.DecodeBytes(enc, &vsn)
- return int(vsn)
-}
-
-// WriteBlockChainVersion writes vsn as the version number to db.
-func WriteBlockChainVersion(db ethdb.KeyValueWriter, vsn int) {
- enc, _ := rlp.EncodeToBytes(uint(vsn))
- db.Put([]byte("BlockchainVersion"), enc)
-}
-
-// WriteChainConfig writes the chain config settings to the database.
-func WriteChainConfig(db ethdb.KeyValueWriter, hash common.Hash, cfg *params.ChainConfig) error {
- // short circuit and ignore if nil config. GetChainConfig
- // will return a default.
- if cfg == nil {
- return nil
- }
-
- jsonChainConfig, err := json.Marshal(cfg)
- if err != nil {
- return err
- }
-
- return db.Put(append(configPrefix, hash[:]...), jsonChainConfig)
-}
-
-// GetChainConfig will fetch the network settings based on the given hash.
-func GetChainConfig(db DatabaseReader, hash common.Hash) (*params.ChainConfig, error) {
- jsonChainConfig, _ := db.Get(append(configPrefix, hash[:]...))
- if len(jsonChainConfig) == 0 {
- return nil, ErrChainConfigNotFound
- }
-
- var config params.ChainConfig
- if err := json.Unmarshal(jsonChainConfig, &config); err != nil {
- return nil, err
- }
-
- return &config, nil
+ db.Delete(blockReceiptsKey(number, hash))
}
// FindCommonAncestor returns the last common ancestor of two block headers
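
The literal append chains deleted above now hide behind key helpers (headerKey, headerHashKey, headerNumberKey, headerTDKey, BlockBodyKey, blockReceiptsKey) defined in the package's schema file, which this diff does not show. Their shapes can be reconstructed from the removed code; a sketch:

    // Inferred from the deleted append chains; the canonical definitions
    // live in the rawdb schema file outside this hunk.
    func headerKey(number uint64, hash common.Hash) []byte { // header body
        return append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...)
    }
    func headerHashKey(number uint64) []byte { // canonical number -> hash
        return append(append(headerPrefix, encodeBlockNumber(number)...), numSuffix...)
    }
    func headerNumberKey(hash common.Hash) []byte { // hash -> number
        return append(blockHashPrefix, hash.Bytes()...)
    }
    func headerTDKey(number uint64, hash common.Hash) []byte { // total difficulty
        return append(headerKey(number, hash), tdSuffix...)
    }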
diff --git a/core/database_util_test.go b/core/rawdb/accessors_chain_test.go
similarity index 83%
rename from core/database_util_test.go
rename to core/rawdb/accessors_chain_test.go
index f28ca160a..c85bae235 100644
--- a/core/database_util_test.go
+++ b/core/rawdb/accessors_chain_test.go
@@ -14,23 +14,27 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-package core
+package rawdb
import (
"bytes"
- "github.com/tomochain/tomochain/core/rawdb"
+ "encoding/hex"
+ "fmt"
"math/big"
"testing"
+ "github.com/tomochain/tomochain/params"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/crypto/sha3"
+ "github.com/tomochain/tomochain/internal/blocktest"
"github.com/tomochain/tomochain/rlp"
)
// Tests block header storage and retrieval operations.
func TestHeaderStorage(t *testing.T) {
- db := rawdb.NewMemoryDatabase()
+ db := NewMemoryDatabase()
// Create a test header to move around the database and make sure it's really new
header := &types.Header{Number: big.NewInt(42), Extra: []byte("test header")}
@@ -65,7 +69,7 @@ func TestHeaderStorage(t *testing.T) {
// Tests block body storage and retrieval operations.
func TestBodyStorage(t *testing.T) {
- db := rawdb.NewMemoryDatabase()
+ db := NewMemoryDatabase()
// Create a test body to move around the database and make sure it's really new
body := &types.Body{Uncles: []*types.Header{{Extra: []byte("test header")}}}
@@ -83,7 +87,7 @@ func TestBodyStorage(t *testing.T) {
}
if entry := GetBody(db, hash, 0); entry == nil {
t.Fatalf("Stored body not found")
- } else if types.DeriveSha(types.Transactions(entry.Transactions)) != types.DeriveSha(types.Transactions(body.Transactions)) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(body.Uncles) {
+ } else if types.DeriveSha(types.Transactions(entry.Transactions), blocktest.NewHasher()) != types.DeriveSha(types.Transactions(body.Transactions), blocktest.NewHasher()) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(body.Uncles) {
t.Fatalf("Retrieved body mismatch: have %v, want %v", entry, body)
}
if entry := GetBodyRLP(db, hash, 0); entry == nil {
@@ -105,7 +109,7 @@ func TestBodyStorage(t *testing.T) {
// Tests block storage and retrieval operations.
func TestBlockStorage(t *testing.T) {
- db := rawdb.NewMemoryDatabase()
+ db := NewMemoryDatabase()
// Create a test block to move around the database and make sure it's really new
block := types.NewBlockWithHeader(&types.Header{
@@ -139,7 +143,7 @@ func TestBlockStorage(t *testing.T) {
}
if entry := GetBody(db, block.Hash(), block.NumberU64()); entry == nil {
t.Fatalf("Stored body not found")
- } else if types.DeriveSha(types.Transactions(entry.Transactions)) != types.DeriveSha(block.Transactions()) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(block.Uncles()) {
+ } else if types.DeriveSha(types.Transactions(entry.Transactions), blocktest.NewHasher()) != types.DeriveSha(block.Transactions(), blocktest.NewHasher()) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(block.Uncles()) {
t.Fatalf("Retrieved body mismatch: have %v, want %v", entry, block.Body())
}
// Delete the block and verify the execution
@@ -157,7 +161,7 @@ func TestBlockStorage(t *testing.T) {
// Tests that partial block contents don't get reassembled into full blocks.
func TestPartialBlockStorage(t *testing.T) {
- db := rawdb.NewMemoryDatabase()
+ db := NewMemoryDatabase()
block := types.NewBlockWithHeader(&types.Header{
Extra: []byte("test block"),
UncleHash: types.EmptyUncleHash,
@@ -198,7 +202,7 @@ func TestPartialBlockStorage(t *testing.T) {
// Tests block total difficulty storage and retrieval operations.
func TestTdStorage(t *testing.T) {
- db := rawdb.NewMemoryDatabase()
+ db := NewMemoryDatabase()
// Create a test TD to move around the database and make sure it's really new
hash, td := common.Hash{}, big.NewInt(314)
@@ -223,7 +227,7 @@ func TestTdStorage(t *testing.T) {
// Tests that canonical numbers can be mapped to hashes and retrieved.
func TestCanonicalMappingStorage(t *testing.T) {
- db := rawdb.NewMemoryDatabase()
+ db := NewMemoryDatabase()
// Create a test canonical number and assigned hash to move around
hash, number := common.Hash{0: 0xff}, uint64(314)
@@ -248,7 +252,7 @@ func TestCanonicalMappingStorage(t *testing.T) {
// Tests that head headers and head blocks can be assigned, individually.
func TestHeadStorage(t *testing.T) {
- db := rawdb.NewMemoryDatabase()
+ db := NewMemoryDatabase()
blockHead := types.NewBlockWithHeader(&types.Header{Extra: []byte("test block header")})
blockFull := types.NewBlockWithHeader(&types.Header{Extra: []byte("test block full")})
@@ -288,14 +292,14 @@ func TestHeadStorage(t *testing.T) {
// Tests that positional lookup metadata can be stored and retrieved.
func TestLookupStorage(t *testing.T) {
- db := rawdb.NewMemoryDatabase()
+ db := NewMemoryDatabase()
tx1 := types.NewTransaction(1, common.BytesToAddress([]byte{0x11}), big.NewInt(111), 1111, big.NewInt(11111), []byte{0x11, 0x11, 0x11})
tx2 := types.NewTransaction(2, common.BytesToAddress([]byte{0x22}), big.NewInt(222), 2222, big.NewInt(22222), []byte{0x22, 0x22, 0x22})
tx3 := types.NewTransaction(3, common.BytesToAddress([]byte{0x33}), big.NewInt(333), 3333, big.NewInt(33333), []byte{0x33, 0x33, 0x33})
txs := []*types.Transaction{tx1, tx2, tx3}
- block := types.NewBlock(&types.Header{Number: big.NewInt(314)}, txs, nil, nil)
+ block := types.NewBlock(&types.Header{Number: big.NewInt(314)}, txs, nil, nil, blocktest.NewHasher())
// Check that no transactions entries are in a pristine database
for i, tx := range txs {
@@ -333,8 +337,15 @@ func TestLookupStorage(t *testing.T) {
// Tests that receipts associated with a single block can be stored and retrieved.
func TestBlockReceiptStorage(t *testing.T) {
- db := rawdb.NewMemoryDatabase()
+ db := NewMemoryDatabase()
+
+ // Create a live block since we need metadata to reconstruct the receipt
+ tx1 := types.NewTransaction(1, common.HexToAddress("0x1"), big.NewInt(1), 1, big.NewInt(1), nil)
+ tx2 := types.NewTransaction(2, common.HexToAddress("0x2"), big.NewInt(2), 2, big.NewInt(2), nil)
+
+ body := &types.Body{Transactions: types.Transactions{tx1, tx2}}
+ // Create the two receipts to manage afterwards
receipt1 := &types.Receipt{
Status: types.ReceiptStatusFailed,
CumulativeGasUsed: 1,
@@ -342,10 +353,12 @@ func TestBlockReceiptStorage(t *testing.T) {
{Address: common.BytesToAddress([]byte{0x11})},
{Address: common.BytesToAddress([]byte{0x01, 0x11})},
},
- TxHash: common.BytesToHash([]byte{0x11, 0x11}),
+ TxHash: tx1.Hash(),
ContractAddress: common.BytesToAddress([]byte{0x01, 0x11, 0x11}),
GasUsed: 111111,
}
+ receipt1.Bloom = types.CreateBloom(types.Receipts{receipt1})
+
receipt2 := &types.Receipt{
PostState: common.Hash{2}.Bytes(),
CumulativeGasUsed: 2,
@@ -353,36 +366,64 @@ func TestBlockReceiptStorage(t *testing.T) {
{Address: common.BytesToAddress([]byte{0x22})},
{Address: common.BytesToAddress([]byte{0x02, 0x22})},
},
- TxHash: common.BytesToHash([]byte{0x22, 0x22}),
+ TxHash: tx2.Hash(),
ContractAddress: common.BytesToAddress([]byte{0x02, 0x22, 0x22}),
GasUsed: 222222,
}
+ receipt2.Bloom = types.CreateBloom(types.Receipts{receipt2})
receipts := []*types.Receipt{receipt1, receipt2}
// Check that no receipt entries are in a pristine database
hash := common.BytesToHash([]byte{0x03, 0x14})
- if rs := GetBlockReceipts(db, hash, 0); len(rs) != 0 {
+ if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) != 0 {
t.Fatalf("non existent receipts returned: %v", rs)
}
+ // Insert the body that corresponds to the receipts
+ WriteBody(db, hash, 0, body)
+
// Insert the receipt slice into the database and check presence
- if err := WriteBlockReceipts(db, hash, 0, receipts); err != nil {
- t.Fatalf("failed to write block receipts: %v", err)
- }
- if rs := GetBlockReceipts(db, hash, 0); len(rs) == 0 {
+ WriteBlockReceipts(db, hash, 0, receipts)
+ if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) == 0 {
t.Fatalf("no receipts returned")
} else {
- for i := 0; i < len(receipts); i++ {
- rlpHave, _ := rlp.EncodeToBytes(rs[i])
- rlpWant, _ := rlp.EncodeToBytes(receipts[i])
-
- if !bytes.Equal(rlpHave, rlpWant) {
- t.Fatalf("receipt #%d: receipt mismatch: have %v, want %v", i, rs[i], receipts[i])
- }
+ if err := checkReceiptsRLP(rs, receipts); err != nil {
+ t.Fatal(err)
}
}
- // Delete the receipt slice and check purge
+ // Delete the body and ensure that the receipts are no longer returned (metadata can't be recomputed)
+ DeleteBody(db, hash, 0)
+ if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); rs != nil {
+ t.Fatalf("receipts returned when body was deleted: %v", rs)
+ }
+ // Ensure that receipts without metadata can be returned without the block body too
+ if err := checkReceiptsRLP(ReadRawReceipts(db, hash, 0), receipts); err != nil {
+ t.Fatal(err)
+ }
+ // Sanity check that body alone without the receipt is a full purge
+ WriteBody(db, hash, 0, body)
+
DeleteBlockReceipts(db, hash, 0)
- if rs := GetBlockReceipts(db, hash, 0); len(rs) != 0 {
+ if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) != 0 {
t.Fatalf("deleted receipts returned: %v", rs)
}
}
+
+func checkReceiptsRLP(have, want types.Receipts) error {
+ if len(have) != len(want) {
+ return fmt.Errorf("receipts sizes mismatch: have %d, want %d", len(have), len(want))
+ }
+ for i := 0; i < len(want); i++ {
+ rlpHave, err := rlp.EncodeToBytes(have[i])
+ if err != nil {
+ return err
+ }
+ rlpWant, err := rlp.EncodeToBytes(want[i])
+ if err != nil {
+ return err
+ }
+ if !bytes.Equal(rlpHave, rlpWant) {
+ return fmt.Errorf("receipt #%d: receipt mismatch: have %s, want %s", i, hex.EncodeToString(rlpHave), hex.EncodeToString(rlpWant))
+ }
+ }
+ return nil
+}
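
The reworked test encodes the new storage contract: receipts are persisted without their derived fields, so reading them back needs the block body to recompute TxHash, GasUsed and friends via DeriveFields. A condensed round-trip sketch with caller-supplied db, hash, number, body and receipts:

    rawdb.WriteBody(db, hash, number, body) // required for field derivation
    rawdb.WriteBlockReceipts(db, hash, number, receipts)

    // Full receipts: returns nil if the body is missing.
    full := rawdb.GetBlockReceipts(db, hash, number, params.TestChainConfig)

    // Raw receipts: readable without the body, but metadata-free.
    raw := rawdb.ReadRawReceipts(db, hash, number)
    _, _ = full, raw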
diff --git a/core/rawdb/accessors_indexes.go b/core/rawdb/accessors_indexes.go
new file mode 100644
index 000000000..0bb54d65b
--- /dev/null
+++ b/core/rawdb/accessors_indexes.go
@@ -0,0 +1,145 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rawdb
+
+import (
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/types"
+ "github.com/tomochain/tomochain/ethdb"
+ "github.com/tomochain/tomochain/log"
+ "github.com/tomochain/tomochain/params"
+ "github.com/tomochain/tomochain/rlp"
+)
+
+// GetTxLookupEntry retrieves the positional metadata associated with a transaction
+// hash to allow retrieving the transaction or receipt by hash.
+func GetTxLookupEntry(db DatabaseReader, hash common.Hash) (common.Hash, uint64, uint64) {
+ // Load the positional metadata from disk and bail if it fails
+ data, _ := db.Get(txLookupKey(hash))
+ if len(data) == 0 {
+ return common.Hash{}, 0, 0
+ }
+ // Parse and return the contents of the lookup entry
+ var entry TxLookupEntry
+ if err := rlp.DecodeBytes(data, &entry); err != nil {
+ log.Error("Invalid lookup entry RLP", "hash", hash, "err", err)
+ return common.Hash{}, 0, 0
+ }
+ return entry.BlockHash, entry.BlockIndex, entry.Index
+}
+
+// WriteTxLookupEntries stores positional metadata for every transaction from
+// a block, enabling hash based transaction and receipt lookups.
+func WriteTxLookupEntries(db ethdb.KeyValueWriter, block *types.Block) error {
+ // Iterate over each transaction and encode its metadata
+ for i, tx := range block.Transactions() {
+ entry := TxLookupEntry{
+ BlockHash: block.Hash(),
+ BlockIndex: block.NumberU64(),
+ Index: uint64(i),
+ }
+ data, err := rlp.EncodeToBytes(entry)
+ if err != nil {
+ return err
+ }
+ if err := db.Put(txLookupKey(tx.Hash()), data); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// DeleteTxLookupEntry removes all transaction data associated with a hash.
+func DeleteTxLookupEntry(db DatabaseDeleter, hash common.Hash) {
+ db.Delete(txLookupKey(hash))
+}
+
+// GetTransaction retrieves a specific transaction from the database, along with
+// its added positional metadata.
+func GetTransaction(db DatabaseReader, hash common.Hash) (*types.Transaction, common.Hash, uint64, uint64) {
+ // Retrieve the lookup metadata and resolve the transaction from the body
+ blockHash, blockNumber, txIndex := GetTxLookupEntry(db, hash)
+
+ if blockHash != (common.Hash{}) {
+ body := GetBody(db, blockHash, blockNumber)
+ if body == nil || len(body.Transactions) <= int(txIndex) {
+ log.Error("Transaction referenced missing", "number", blockNumber, "hash", blockHash, "index", txIndex)
+ return nil, common.Hash{}, 0, 0
+ }
+ return body.Transactions[txIndex], blockHash, blockNumber, txIndex
+ }
+ // Old transaction representation, load the transaction and its metadata separately
+ data, _ := db.Get(hash.Bytes())
+ if len(data) == 0 {
+ return nil, common.Hash{}, 0, 0
+ }
+ var tx types.Transaction
+ if err := rlp.DecodeBytes(data, &tx); err != nil {
+ return nil, common.Hash{}, 0, 0
+ }
+ // Retrieve the blockchain positional metadata
+ data, _ = db.Get(oldTxMetaKey(hash))
+ if len(data) == 0 {
+ return nil, common.Hash{}, 0, 0
+ }
+ var entry TxLookupEntry
+ if err := rlp.DecodeBytes(data, &entry); err != nil {
+ return nil, common.Hash{}, 0, 0
+ }
+ return &tx, entry.BlockHash, entry.BlockIndex, entry.Index
+}
+
+// GetReceipt retrieves a specific transaction receipt from the database, along with
+// its added positional metadata.
+func GetReceipt(db DatabaseReader, hash common.Hash, config *params.ChainConfig) (*types.Receipt, common.Hash, uint64, uint64) {
+ // Retrieve the lookup metadata and resolve the receipt from the receipts
+ blockHash, blockNumber, receiptIndex := GetTxLookupEntry(db, hash)
+
+ if blockHash != (common.Hash{}) {
+ receipts := GetBlockReceipts(db, blockHash, blockNumber, config)
+ if len(receipts) <= int(receiptIndex) {
+ log.Error("Receipt refereced missing", "number", blockNumber, "hash", blockHash, "index", receiptIndex)
+ return nil, common.Hash{}, 0, 0
+ }
+ return receipts[receiptIndex], blockHash, blockNumber, receiptIndex
+ }
+ // Old receipt representation, load the receipt and set an unknown metadata
+ data, _ := db.Get(append(oldReceiptsPrefix, hash[:]...))
+ if len(data) == 0 {
+ return nil, common.Hash{}, 0, 0
+ }
+ var receipt types.ReceiptForStorage
+ err := rlp.DecodeBytes(data, &receipt)
+ if err != nil {
+ log.Error("Invalid receipt RLP", "hash", hash, "err", err)
+ }
+ return (*types.Receipt)(&receipt), common.Hash{}, 0, 0
+}
+
+// GetBloomBits retrieves the compressed bloom bit vector belonging to the given
+// bit index and section index.
+func GetBloomBits(db DatabaseReader, bit uint, section uint64, head common.Hash) ([]byte, error) {
+ return db.Get(bloomBitsKey(bit, section, head))
+}
+
+// WriteBloomBits writes the compressed bloom bits vector belonging to the given
+// section and bit index.
+func WriteBloomBits(db ethdb.KeyValueWriter, bit uint, section uint64, head common.Hash, bits []byte) {
+ if err := db.Put(bloomBitsKey(bit, section, head), bits); err != nil {
+ log.Crit("Failed to store bloom bits", "err", err)
+ }
+}
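
Both lookups consult the TxLookupEntry index first and only then fall back to the legacy per-hash layout. A consumer-side sketch, with db, txHash and chainConfig as placeholders:

    // Zero values signal "not found" for both accessors.
    tx, blockHash, blockNumber, txIndex := rawdb.GetTransaction(db, txHash)
    if tx == nil {
        // neither an index entry nor a legacy record exists for this hash
    }
    // Receipts additionally need the chain config to re-derive metadata.
    receipt, _, _, _ := rawdb.GetReceipt(db, txHash, chainConfig)
    _, _, _, _ = blockHash, blockNumber, txIndex, receipt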
diff --git a/core/rawdb/accessors_metadata.go b/core/rawdb/accessors_metadata.go
new file mode 100644
index 000000000..16fbbd77b
--- /dev/null
+++ b/core/rawdb/accessors_metadata.go
@@ -0,0 +1,71 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rawdb
+
+import (
+ "encoding/json"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/ethdb"
+ "github.com/tomochain/tomochain/params"
+ "github.com/tomochain/tomochain/rlp"
+)
+
+// GetBlockChainVersion reads the version number from db.
+func GetBlockChainVersion(db DatabaseReader) int {
+ var vsn uint
+ enc, _ := db.Get([]byte("BlockchainVersion"))
+ rlp.DecodeBytes(enc, &vsn)
+ return int(vsn)
+}
+
+// WriteBlockChainVersion writes vsn as the version number to db.
+func WriteBlockChainVersion(db ethdb.KeyValueWriter, vsn int) {
+ enc, _ := rlp.EncodeToBytes(uint(vsn))
+ db.Put([]byte("BlockchainVersion"), enc)
+}
+
+// WriteChainConfig writes the chain config settings to the database.
+func WriteChainConfig(db ethdb.KeyValueWriter, hash common.Hash, cfg *params.ChainConfig) error {
+ // short circuit and ignore if nil config. GetChainConfig
+ // will return a default.
+ if cfg == nil {
+ return nil
+ }
+
+ jsonChainConfig, err := json.Marshal(cfg)
+ if err != nil {
+ return err
+ }
+
+ return db.Put(configKey(hash), jsonChainConfig)
+}
+
+// GetChainConfig will fetch the network settings based on the given hash.
+func GetChainConfig(db DatabaseReader, hash common.Hash) (*params.ChainConfig, error) {
+ jsonChainConfig, _ := db.Get(configKey(hash))
+ if len(jsonChainConfig) == 0 {
+ return nil, ErrChainConfigNotFound
+ }
+
+ var config params.ChainConfig
+ if err := json.Unmarshal(jsonChainConfig, &config); err != nil {
+ return nil, err
+ }
+
+ return &config, nil
+}
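
Chain configs are stored as JSON keyed by genesis hash, and a missing entry is reported through the package's sentinel error. A small round-trip sketch with caller-supplied db, genesisHash and config:

    if err := rawdb.WriteChainConfig(db, genesisHash, config); err != nil {
        log.Crit("failed to write chain config", "err", err)
    }
    stored, err := rawdb.GetChainConfig(db, genesisHash)
    if err == rawdb.ErrChainConfigNotFound {
        // the genesis block exists but its config write was interrupted
    }
    _ = stored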
diff --git a/core/rawdb/accessors_snapshot.go b/core/rawdb/accessors_snapshot.go
new file mode 100644
index 000000000..6ef285019
--- /dev/null
+++ b/core/rawdb/accessors_snapshot.go
@@ -0,0 +1,135 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rawdb
+
+import (
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/ethdb"
+ "github.com/tomochain/tomochain/log"
+)
+
+// ReadSnapshotRoot retrieves the root of the block whose state is contained in
+// the persisted snapshot.
+func ReadSnapshotRoot(db ethdb.KeyValueReader) common.Hash {
+ data, _ := db.Get(snapshotRootKey)
+ if len(data) != common.HashLength {
+ return common.Hash{}
+ }
+ return common.BytesToHash(data)
+}
+
+// WriteSnapshotRoot stores the root of the block whose state is contained in
+// the persisted snapshot.
+func WriteSnapshotRoot(db ethdb.KeyValueWriter, root common.Hash) {
+ if err := db.Put(snapshotRootKey, root[:]); err != nil {
+ log.Crit("Failed to store snapshot root", "err", err)
+ }
+}
+
+// DeleteSnapshotRoot deletes the hash of the block whose state is contained in
+// the persisted snapshot. Since snapshots are not immutable, this method can
+// be used during updates, so a crash or failure will mark the entire snapshot
+// invalid.
+func DeleteSnapshotRoot(db ethdb.KeyValueWriter) {
+ if err := db.Delete(snapshotRootKey); err != nil {
+ log.Crit("Failed to remove snapshot root", "err", err)
+ }
+}
+
+// ReadAccountSnapshot retrieves the snapshot entry of an account trie leaf.
+func ReadAccountSnapshot(db ethdb.KeyValueReader, hash common.Hash) []byte {
+ data, _ := db.Get(accountSnapshotKey(hash))
+ return data
+}
+
+// WriteAccountSnapshot stores the snapshot entry of an account trie leaf.
+func WriteAccountSnapshot(db ethdb.KeyValueWriter, hash common.Hash, entry []byte) {
+ if err := db.Put(accountSnapshotKey(hash), entry); err != nil {
+ log.Crit("Failed to store account snapshot", "err", err)
+ }
+}
+
+// DeleteAccountSnapshot removes the snapshot entry of an account trie leaf.
+func DeleteAccountSnapshot(db ethdb.KeyValueWriter, hash common.Hash) {
+ if err := db.Delete(accountSnapshotKey(hash)); err != nil {
+ log.Crit("Failed to delete account snapshot", "err", err)
+ }
+}
+
+// ReadStorageSnapshot retrieves the snapshot entry of a storage trie leaf.
+func ReadStorageSnapshot(db ethdb.KeyValueReader, accountHash, storageHash common.Hash) []byte {
+ data, _ := db.Get(storageSnapshotKey(accountHash, storageHash))
+ return data
+}
+
+// WriteStorageSnapshot stores the snapshot entry of a storage trie leaf.
+func WriteStorageSnapshot(db ethdb.KeyValueWriter, accountHash, storageHash common.Hash, entry []byte) {
+ if err := db.Put(storageSnapshotKey(accountHash, storageHash), entry); err != nil {
+ log.Crit("Failed to store storage snapshot", "err", err)
+ }
+}
+
+// DeleteStorageSnapshot removes the snapshot entry of a storage trie leaf.
+func DeleteStorageSnapshot(db ethdb.KeyValueWriter, accountHash, storageHash common.Hash) {
+ if err := db.Delete(storageSnapshotKey(accountHash, storageHash)); err != nil {
+ log.Crit("Failed to delete storage snapshot", "err", err)
+ }
+}
+
+// IterateStorageSnapshots returns an iterator for walking the entire storage
+// space of a specific account.
+func IterateStorageSnapshots(db ethdb.Iteratee, accountHash common.Hash) ethdb.Iterator {
+ return NewKeyLengthIterator(db.NewIterator(storageSnapshotsKey(accountHash), nil), len(SnapshotStoragePrefix)+2*common.HashLength)
+}
+
+// ReadSnapshotJournal retrieves the serialized in-memory diff layers saved at
+// the last shutdown. The blob is expected to be at most a few tens of megabytes.
+func ReadSnapshotJournal(db ethdb.KeyValueReader) []byte {
+ data, _ := db.Get(snapshotJournalKey)
+ return data
+}
+
+// WriteSnapshotJournal stores the serialized in-memory diff layers to save at
+// shutdown. The blob is expected to be at most a few tens of megabytes.
+func WriteSnapshotJournal(db ethdb.KeyValueWriter, journal []byte) {
+ if err := db.Put(snapshotJournalKey, journal); err != nil {
+ log.Crit("Failed to store snapshot journal", "err", err)
+ }
+}
+
+// DeleteSnapshotJournal deletes the serialized in-memory diff layers saved at
+// the last shutdown.
+func DeleteSnapshotJournal(db ethdb.KeyValueWriter) {
+ if err := db.Delete(snapshotJournalKey); err != nil {
+ log.Crit("Failed to remove snapshot journal", "err", err)
+ }
+}
+
+// ReadSnapshotGenerator retrieves the serialized snapshot generator saved at
+// the last shutdown.
+func ReadSnapshotGenerator(db ethdb.KeyValueReader) []byte {
+ data, _ := db.Get(snapshotGeneratorKey)
+ return data
+}
+
+// WriteSnapshotGenerator stores the serialized snapshot generator to save at
+// shutdown.
+func WriteSnapshotGenerator(db ethdb.KeyValueWriter, generator []byte) {
+ if err := db.Put(snapshotGeneratorKey, generator); err != nil {
+ log.Crit("Failed to store snapshot generator", "err", err)
+ }
+}
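
A hedged round-trip sketch of the snapshot root accessors above, assuming rawdb.NewMemoryDatabase (used by the tests elsewhere in this change) satisfies the key-value reader/writer interfaces:

    package main

    import (
        "fmt"

        "github.com/tomochain/tomochain/common"
        "github.com/tomochain/tomochain/core/rawdb"
    )

    func main() {
        // In-memory database, as used by the tests in this change.
        db := rawdb.NewMemoryDatabase()

        root := common.HexToHash("0x01")
        rawdb.WriteSnapshotRoot(db, root)               // persist the snapshot root
        fmt.Println(rawdb.ReadSnapshotRoot(db) == root) // true

        // Deleting the root first is how updates stay crash-safe: a
        // half-written snapshot is then treated as absent on restart.
        rawdb.DeleteSnapshotRoot(db)
        fmt.Println(rawdb.ReadSnapshotRoot(db) == common.Hash{}) // true
    }
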
diff --git a/core/rawdb/accessors_state.go b/core/rawdb/accessors_state.go
new file mode 100644
index 000000000..28bba40f3
--- /dev/null
+++ b/core/rawdb/accessors_state.go
@@ -0,0 +1,58 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rawdb
+
+import (
+ "fmt"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/ethdb"
+)
+
+// PreimageTable returns a Database instance with the key prefix for preimage entries.
+func PreimageTable(db ethdb.Database) ethdb.Database {
+ return NewTable(db, PreimagePrefix)
+}
+
+// ReadPreimage retrieves a single preimage of the provided hash.
+func ReadPreimage(db ethdb.Database, hash common.Hash) []byte {
+ table := PreimageTable(db)
+ data, _ := table.Get(hash.Bytes())
+ return data
+}
+
+// WritePreimages writes the provided set of preimages to the database. `number` is the
+// current block number, and is used for debug messages only.
+func WritePreimages(db ethdb.Database, number uint64, preimages map[common.Hash][]byte) error {
+ table := PreimageTable(db)
+ batch := table.NewBatch()
+ hitCount := 0
+ for hash, preimage := range preimages {
+ if _, err := table.Get(hash.Bytes()); err != nil {
+ batch.Put(hash.Bytes(), preimage)
+ hitCount++
+ }
+ }
+ preimageCounter.Inc(int64(len(preimages)))
+ preimageHitCounter.Inc(int64(hitCount))
+ if hitCount > 0 {
+ if err := batch.Write(); err != nil {
+ return fmt.Errorf("preimage write fail for block %d: %v", number, err)
+ }
+ }
+ return nil
+}
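
For illustration, a hedged in-package sketch of recording the preimage of a hashed trie key; db is any ethdb.Database and 42 is a placeholder block number:

    // Sketch: WritePreimages only inserts hashes that are not already present
    // (the Get miss triggers the batch Put), so repeated calls with the same
    // map are cheap.
    key := []byte("storage-slot-key") // placeholder preimage
    preimages := map[common.Hash][]byte{crypto.Keccak256Hash(key): key}
    if err := WritePreimages(db, 42, preimages); err != nil {
        panic(err) // placeholder error handling
    }
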
diff --git a/core/rawdb/accessors_trie.go b/core/rawdb/accessors_trie.go
new file mode 100644
index 000000000..7e1bbcaa2
--- /dev/null
+++ b/core/rawdb/accessors_trie.go
@@ -0,0 +1,64 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rawdb
+
+import (
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/ethdb"
+ "github.com/tomochain/tomochain/log"
+)
+
+// HashScheme is the legacy hash-based state scheme with which trie nodes are
+// stored in the disk with node hash as the database key. The advantage of this
+// scheme is that different versions of trie nodes can be stored in disk, which
+// is very beneficial for constructing archive nodes. The drawback is it will
+// store different trie nodes on the same path to different locations on the disk
+// with no data locality, and it's unfriendly for designing state pruning.
+//
+// Now this scheme is still kept for backward compatibility, and it will be used
+// for archive nodes and some other tries (e.g. the light trie).
+const HashScheme = "hashScheme"
+
+// ReadLegacyTrieNode retrieves the legacy trie node with the given
+// associated node hash.
+func ReadLegacyTrieNode(db ethdb.KeyValueReader, hash common.Hash) []byte {
+ data, err := db.Get(hash.Bytes())
+ if err != nil {
+ return nil
+ }
+ return data
+}
+
+// HasLegacyTrieNode checks if the trie node with the provided hash is present in db.
+func HasLegacyTrieNode(db ethdb.KeyValueReader, hash common.Hash) bool {
+ ok, _ := db.Has(hash.Bytes())
+ return ok
+}
+
+// WriteLegacyTrieNode writes the provided legacy trie node to database.
+func WriteLegacyTrieNode(db ethdb.KeyValueWriter, hash common.Hash, node []byte) {
+ if err := db.Put(hash.Bytes(), node); err != nil {
+ log.Crit("Failed to store legacy trie node", "err", err)
+ }
+}
+
+// DeleteLegacyTrieNode deletes the specified legacy trie node from database.
+func DeleteLegacyTrieNode(db ethdb.KeyValueWriter, hash common.Hash) {
+ if err := db.Delete(hash.Bytes()); err != nil {
+ log.Crit("Failed to delete legacy trie node", "err", err)
+ }
+}
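
A hedged in-package round-trip sketch of the legacy accessors; under HashScheme the database key is just the node hash (db and the node blob are placeholders):

    node := []byte{0xc0}               // placeholder RLP-encoded trie node
    hash := crypto.Keccak256Hash(node) // hash-scheme database key
    WriteLegacyTrieNode(db, hash, node)
    fmt.Println(HasLegacyTrieNode(db, hash))                     // true
    fmt.Println(bytes.Equal(ReadLegacyTrieNode(db, hash), node)) // true
    DeleteLegacyTrieNode(db, hash)
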
diff --git a/core/rawdb/database.go b/core/rawdb/database.go
index 1183a74f5..ea1dfe234 100644
--- a/core/rawdb/database.go
+++ b/core/rawdb/database.go
@@ -17,10 +17,17 @@
package rawdb
import (
+ "bytes"
"fmt"
+ "os"
+ "time"
+
+ "github.com/olekukonko/tablewriter"
+ "github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/ethdb/leveldb"
"github.com/tomochain/tomochain/ethdb/memorydb"
+ "github.com/tomochain/tomochain/log"
)
// freezerdb is a database wrapper that enabled freezer data retrievals.
@@ -108,3 +115,177 @@ func NewLevelDBDatabase(file string, cache int, handles int, namespace string) (
}
return NewDatabase(db), nil
}
+
+type counter uint64
+
+func (c counter) String() string {
+ return fmt.Sprintf("%d", c)
+}
+
+func (c counter) Percentage(current uint64) string {
+ return fmt.Sprintf("%d", current*100/uint64(c))
+}
+
+// stat stores sizes and count for a parameter
+type stat struct {
+ size common.StorageSize
+ count counter
+}
+
+// Add size to the stat and increase the counter by 1
+func (s *stat) Add(size common.StorageSize) {
+ s.size += size
+ s.count++
+}
+
+func (s *stat) Size() string {
+ return s.size.String()
+}
+
+func (s *stat) Count() string {
+ return s.count.String()
+}
+
+// InspectDatabase traverses the entire database and checks the size
+// of all different categories of data.
+func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error {
+ it := db.NewIterator(keyPrefix, keyStart)
+ defer it.Release()
+
+ var (
+ count int64
+ start = time.Now()
+ logged = time.Now()
+
+ // Key-value store statistics
+ headers stat
+ bodies stat
+ receipts stat
+ tds stat
+ numHashPairings stat
+ hashNumPairings stat
+ tries stat
+ codes stat
+ txLookups stat
+ accountSnaps stat
+ storageSnaps stat
+ preimages stat
+ bloomBits stat
+ cliqueSnaps stat
+
+ // Ancient store statistics
+ ancientHeadersSize common.StorageSize
+ ancientBodiesSize common.StorageSize
+ ancientReceiptsSize common.StorageSize
+ ancientTdsSize common.StorageSize
+ ancientHashesSize common.StorageSize
+
+ // LES statistics
+ chtTrieNodes stat
+ bloomTrieNodes stat
+
+ // Meta- and unaccounted data
+ metadata stat
+ unaccounted stat
+
+ // Totals
+ total common.StorageSize
+ )
+ // Inspect key-value database first.
+ for it.Next() {
+ var (
+ key = it.Key()
+ size = common.StorageSize(len(key) + len(it.Value()))
+ )
+ total += size
+ switch {
+ case bytes.HasPrefix(key, headerPrefix) && len(key) == (len(headerPrefix)+8+common.HashLength):
+ headers.Add(size)
+ case bytes.HasPrefix(key, blockBodyPrefix) && len(key) == (len(blockBodyPrefix)+8+common.HashLength):
+ bodies.Add(size)
+ case bytes.HasPrefix(key, blockReceiptsPrefix) && len(key) == (len(blockReceiptsPrefix)+8+common.HashLength):
+ receipts.Add(size)
+ case bytes.HasPrefix(key, headerPrefix) && bytes.HasSuffix(key, headerTDSuffix):
+ tds.Add(size)
+ case bytes.HasPrefix(key, headerPrefix) && bytes.HasSuffix(key, headerHashSuffix):
+ numHashPairings.Add(size)
+ case bytes.HasPrefix(key, headerNumberPrefix) && len(key) == (len(headerNumberPrefix)+common.HashLength):
+ hashNumPairings.Add(size)
+ case len(key) == common.HashLength:
+ tries.Add(size)
+ case bytes.HasPrefix(key, txLookupPrefix) && len(key) == (len(txLookupPrefix)+common.HashLength):
+ txLookups.Add(size)
+ case bytes.HasPrefix(key, SnapshotAccountPrefix) && len(key) == (len(SnapshotAccountPrefix)+common.HashLength):
+ accountSnaps.Add(size)
+ case bytes.HasPrefix(key, SnapshotStoragePrefix) && len(key) == (len(SnapshotStoragePrefix)+2*common.HashLength):
+ storageSnaps.Add(size)
+ case bytes.HasPrefix(key, []byte(PreimagePrefix)) && len(key) == (len(PreimagePrefix)+common.HashLength):
+ preimages.Add(size)
+ case bytes.HasPrefix(key, bloomBitsPrefix) && len(key) == (len(bloomBitsPrefix)+10+common.HashLength):
+ bloomBits.Add(size)
+ case bytes.HasPrefix(key, []byte("clique-")) && len(key) == 7+common.HashLength:
+ cliqueSnaps.Add(size)
+ case bytes.HasPrefix(key, []byte("cht-")) && len(key) == 4+common.HashLength:
+ chtTrieNodes.Add(size)
+ case bytes.HasPrefix(key, []byte("blt-")) && len(key) == 4+common.HashLength:
+ bloomTrieNodes.Add(size)
+ default:
+ var accounted bool
+ for _, meta := range [][]byte{databaseVersionKey, headHeaderKey, headBlockKey, headFastBlockKey, fastTrieProgressKey} {
+ if bytes.Equal(key, meta) {
+ metadata.Add(size)
+ accounted = true
+ break
+ }
+ }
+ if !accounted {
+ unaccounted.Add(size)
+ }
+ }
+ count += 1
+ if count%1000 == 0 && time.Since(logged) > 8*time.Second {
+ log.Info("Inspecting database", "count", count, "elapsed", common.PrettyDuration(time.Since(start)))
+ logged = time.Now()
+ }
+ }
+ // Get number of ancient rows inside the freezer
+ ancients := counter(0)
+ if count, err := db.Ancients(); err == nil {
+ ancients = counter(count)
+ }
+ // Display the database statistic.
+ stats := [][]string{
+ {"Key-Value store", "Headers", headers.Size(), headers.Count()},
+ {"Key-Value store", "Bodies", bodies.Size(), bodies.Count()},
+ {"Key-Value store", "Receipt lists", receipts.Size(), receipts.Count()},
+ {"Key-Value store", "Difficulties", tds.Size(), tds.Count()},
+ {"Key-Value store", "Block number->hash", numHashPairings.Size(), numHashPairings.Count()},
+ {"Key-Value store", "Block hash->number", hashNumPairings.Size(), hashNumPairings.Count()},
+ {"Key-Value store", "Transaction index", txLookups.Size(), txLookups.Count()},
+ {"Key-Value store", "Bloombit index", bloomBits.Size(), bloomBits.Count()},
+ {"Key-Value store", "Contract codes", codes.Size(), codes.Count()},
+ {"Key-Value store", "Trie nodes", tries.Size(), tries.Count()},
+ {"Key-Value store", "Trie preimages", preimages.Size(), preimages.Count()},
+ {"Key-Value store", "Account snapshot", accountSnaps.Size(), accountSnaps.Count()},
+ {"Key-Value store", "Storage snapshot", storageSnaps.Size(), storageSnaps.Count()},
+ {"Key-Value store", "Clique snapshots", cliqueSnaps.Size(), cliqueSnaps.Count()},
+ {"Key-Value store", "Singleton metadata", metadata.Size(), metadata.Count()},
+ {"Ancient store", "Headers", ancientHeadersSize.String(), ancients.String()},
+ {"Ancient store", "Bodies", ancientBodiesSize.String(), ancients.String()},
+ {"Ancient store", "Receipt lists", ancientReceiptsSize.String(), ancients.String()},
+ {"Ancient store", "Difficulties", ancientTdsSize.String(), ancients.String()},
+ {"Ancient store", "Block number->hash", ancientHashesSize.String(), ancients.String()},
+ {"Light client", "CHT trie nodes", chtTrieNodes.Size(), chtTrieNodes.Count()},
+ {"Light client", "Bloom trie nodes", bloomTrieNodes.Size(), bloomTrieNodes.Count()},
+ }
+ table := tablewriter.NewWriter(os.Stdout)
+ table.SetHeader([]string{"Database", "Category", "Size", "Items"})
+ table.SetFooter([]string{"", "Total", total.String(), " "})
+ table.AppendBulk(stats)
+ table.Render()
+
+ if unaccounted.size > 0 {
+ log.Error("Database contains unaccounted data", "size", unaccounted.size, "count", unaccounted.count)
+ }
+ return nil
+}
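
InspectDatabase walks every key once, so it can take a while on a full node's store; a minimal invocation sketch (the path and the cache/handle numbers are placeholders):

    db, err := rawdb.NewLevelDBDatabase("/path/to/chaindata", 512, 1024, "chaindata")
    if err != nil {
        panic(err)
    }
    defer db.Close()
    // A nil prefix and nil start iterate the whole key space.
    if err := rawdb.InspectDatabase(db, nil, nil); err != nil {
        panic(err)
    }
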
diff --git a/core/rawdb/key_length_iterator.go b/core/rawdb/key_length_iterator.go
new file mode 100644
index 000000000..9e24f0ec3
--- /dev/null
+++ b/core/rawdb/key_length_iterator.go
@@ -0,0 +1,47 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rawdb
+
+import "github.com/tomochain/tomochain/ethdb"
+
+// KeyLengthIterator is a wrapper for a database iterator that ensures only key-value pairs
+// with a specific key length will be returned.
+type KeyLengthIterator struct {
+ requiredKeyLength int
+ ethdb.Iterator
+}
+
+// NewKeyLengthIterator returns a wrapped version of the iterator that will only
+// return key-value pairs whose keys are of the given length.
+func NewKeyLengthIterator(it ethdb.Iterator, keyLen int) ethdb.Iterator {
+ return &KeyLengthIterator{
+ Iterator: it,
+ requiredKeyLength: keyLen,
+ }
+}
+
+func (it *KeyLengthIterator) Next() bool {
+ // Return true as soon as a key with the required key length is discovered
+ for it.Iterator.Next() {
+ if len(it.Iterator.Key()) == it.requiredKeyLength {
+ return true
+ }
+ }
+
+ // Return false when we exhaust the keys in the underlying iterator.
+ return false
+}
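
A usage sketch for the filter above: IterateStorageSnapshots uses it to skip the shorter account-only snapshot keys that share the same prefix, and the same pattern works for any fixed-width key family (db is assumed):

    // Only yield keys that are exactly one hash wide (e.g. legacy trie nodes).
    it := rawdb.NewKeyLengthIterator(db.NewIterator(nil, nil), common.HashLength)
    defer it.Release()
    for it.Next() {
        fmt.Printf("%x: %d value bytes\n", it.Key(), len(it.Value()))
    }
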
diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go
new file mode 100644
index 000000000..b49e238ab
--- /dev/null
+++ b/core/rawdb/schema.go
@@ -0,0 +1,166 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package rawdb contains a collection of low level database accessors.
+package rawdb
+
+import (
+ "encoding/binary"
+ "errors"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/metrics"
+)
+
+var (
+ // databaseVersionKey tracks the current database version.
+ databaseVersionKey = []byte("DatabaseVersion")
+
+ // headFastBlockKey tracks the latest known incomplete block's hash during fast sync.
+ headFastBlockKey = []byte("LastFast")
+
+ // fastTrieProgressKey tracks the number of trie entries imported during fast sync.
+ fastTrieProgressKey = []byte("TrieSync")
+
+ // snapshotRootKey tracks the hash of the last snapshot.
+ snapshotRootKey = []byte("SnapshotRoot")
+
+ // snapshotJournalKey tracks the in-memory diff layers across restarts.
+ snapshotJournalKey = []byte("SnapshotJournal")
+
+ // snapshotGeneratorKey tracks the snapshot generation marker across restarts.
+ snapshotGeneratorKey = []byte("SnapshotGenerator")
+
+ headHeaderKey = []byte("LastHeader")
+ headBlockKey = []byte("LastBlock")
+ headFastKey = []byte("LastFast")
+ trieSyncKey = []byte("TrieSync")
+
+ // Data item prefixes (use single byte to avoid mixing data types, avoid `i`).
+ headerPrefix = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header
+ headerTDSuffix = []byte("t") // headerPrefix + num (uint64 big endian) + hash + headerTDSuffix -> td
+ headerHashSuffix = []byte("n") // headerPrefix + num (uint64 big endian) + headerHashSuffix -> hash
+ headerNumberPrefix = []byte("H") // headerNumberPrefix + hash -> num (uint64 big endian)
+ blockBodyPrefix = []byte("b") // blockBodyPrefix + num (uint64 big endian) + hash -> block body
+ blockReceiptsPrefix = []byte("r") // blockReceiptsPrefix + num (uint64 big endian) + hash -> block receipts
+ txLookupPrefix = []byte("l") // txLookupPrefix + hash -> transaction/receipt lookup metadata
+ bloomBitsPrefix = []byte("B") // bloomBitsPrefix + bit (uint16 big endian) + section (uint64 big endian) + hash -> bloom bits
+ SnapshotAccountPrefix = []byte("a") // SnapshotAccountPrefix + account hash -> account trie value
+ SnapshotStoragePrefix = []byte("o") // SnapshotStoragePrefix + account hash + storage hash -> storage trie value
+
+ PreimagePrefix = "secure-key-" // PreimagePrefix + hash -> preimage
+ configPrefix = []byte("ethereum-config-") // config prefix for the db
+
+ // BloomBitsIndexPrefix is the data table of a chain indexer to track its progress
+ BloomBitsIndexPrefix = []byte("iB") // BloomBitsIndexPrefix is the data table of a chain indexer to track its progress
+
+ // used by old db, now only used for conversion
+ oldReceiptsPrefix = []byte("receipts-")
+ oldTxMetaSuffix = []byte{0x01}
+
+ ErrChainConfigNotFound = errors.New("ChainConfig not found") // general config not found error
+
+ preimageCounter = metrics.NewRegisteredCounter("db/preimage/total", nil)
+ preimageHitCounter = metrics.NewRegisteredCounter("db/preimage/hits", nil)
+)
+
+// TxLookupEntry is a positional metadata to help looking up the data content of
+// a transaction or receipt given only its hash.
+type TxLookupEntry struct {
+ BlockHash common.Hash
+ BlockIndex uint64
+ Index uint64
+}
+
+// configKey = configPrefix + hash
+func configKey(hash common.Hash) []byte {
+ return append(configPrefix, hash.Bytes()...)
+}
+
+// headerKey = headerPrefix + num (uint64 big endian) + hash
+func headerKey(number uint64, hash common.Hash) []byte {
+ return append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...)
+}
+
+// headerTDKey = headerPrefix + num (uint64 big endian) + hash + headerTDSuffix
+func headerTDKey(number uint64, hash common.Hash) []byte {
+ return append(HeaderKey(number, hash), headerTDSuffix...)
+}
+
+// headerHashKey = headerPrefix + num (uint64 big endian) + headerHashSuffix
+func headerHashKey(number uint64) []byte {
+ return append(append(headerPrefix, encodeBlockNumber(number)...), headerHashSuffix...)
+}
+
+// HeaderKey = headerPrefix + num (uint64 big endian) + hash
+func HeaderKey(number uint64, hash common.Hash) []byte {
+ return append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...)
+}
+
+// headerNumberKey = headerNumberPrefix + hash
+func headerNumberKey(hash common.Hash) []byte {
+ return append(headerNumberPrefix, hash.Bytes()...)
+}
+
+// BlockBodyKey = blockBodyPrefix + num (uint64 big endian) + hash
+func BlockBodyKey(number uint64, hash common.Hash) []byte {
+ return append(append(blockBodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...)
+}
+
+// blockReceiptsKey = blockReceiptsPrefix + num (uint64 big endian) + hash
+func blockReceiptsKey(number uint64, hash common.Hash) []byte {
+ return append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash.Bytes()...)
+}
+
+// txLookupKey = txLookupPrefix + hash
+func txLookupKey(hash common.Hash) []byte {
+ return append(txLookupPrefix, hash.Bytes()...)
+}
+
+// bloomBitsKey = bloomBitsPrefix + bit (uint16 big endian) + section (uint64 big endian) + hash
+func bloomBitsKey(bit uint, section uint64, hash common.Hash) []byte {
+ key := append(append(bloomBitsPrefix, make([]byte, 10)...), hash.Bytes()...)
+
+ binary.BigEndian.PutUint16(key[1:], uint16(bit))
+ binary.BigEndian.PutUint64(key[3:], section)
+
+ return key
+}
+
+// oldTxMetaKey = hash + oldTxMetaSuffix
+func oldTxMetaKey(hash common.Hash) []byte {
+ return append(hash.Bytes(), oldTxMetaSuffix...)
+}
+
+// oldReceiptsKey = oldReceiptsPrefix + hash
+func oldReceiptsKey(hash common.Hash) []byte {
+ return append(oldReceiptsPrefix, hash[:]...)
+}
+
+// accountSnapshotKey = SnapshotAccountPrefix + hash
+func accountSnapshotKey(hash common.Hash) []byte {
+ return append(SnapshotAccountPrefix, hash.Bytes()...)
+}
+
+// storageSnapshotKey = SnapshotStoragePrefix + account hash + storage hash
+func storageSnapshotKey(accountHash, storageHash common.Hash) []byte {
+ return append(append(SnapshotStoragePrefix, accountHash.Bytes()...), storageHash.Bytes()...)
+}
+
+// storageSnapshotsKey = SnapshotStoragePrefix + account hash
+func storageSnapshotsKey(accountHash common.Hash) []byte {
+ return append(SnapshotStoragePrefix, accountHash.Bytes()...)
+}
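
The composite keys above line up with the length checks in InspectDatabase; a small sketch verifying one of them (encodeBlockNumber is the package's uint64 big-endian helper, defined outside this hunk):

    key := rawdb.HeaderKey(1000000, common.HexToHash("0xbeef"))
    // 1 prefix byte + 8-byte big-endian block number + 32-byte hash.
    fmt.Println(len(key) == 1+8+common.HashLength) // true
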
diff --git a/core/state/database.go b/core/state/database.go
index b57f134db..8f47c3396 100644
--- a/core/state/database.go
+++ b/core/state/database.go
@@ -21,6 +21,7 @@ import (
lru "github.com/hashicorp/golang-lru"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/trie"
)
@@ -59,20 +60,40 @@ type Trie interface {
// TODO(fjl): remove this when SecureTrie is removed
GetKey([]byte) []byte
- // TryGet returns the value for key stored in the trie. The value bytes must
- // not be modified by the caller. If a node was not found in the database, a
- // trie.MissingNodeError is returned.
- TryGet(key []byte) ([]byte, error)
-
- // TryUpdate associates key with value in the trie. If value has length zero, any
- // existing value is deleted from the trie. The value bytes must not be modified
+ // GetStorage returns the value for key stored in the trie. The value bytes
+ // must not be modified by the caller. If a node was not found in the database,
+ // a trie.MissingNodeError is returned.
+ GetStorage(addr common.Address, key []byte) ([]byte, error)
+
+ // GetAccount abstracts an account read from the trie. It retrieves the
+ // account blob from the trie with provided account address and decodes it
+ // with associated decoding algorithm. If the specified account is not in
+ // the trie, nil will be returned. If the trie is corrupted(e.g. some nodes
+ // are missing or the account blob is incorrect for decoding), an error will
+ // be returned.
+ GetAccount(address common.Address) (*types.StateAccount, error)
+
+ // UpdateStorage associates key with value in the trie. If value has length zero,
+ // any existing value is deleted from the trie. The value bytes must not be modified
// by the caller while they are stored in the trie. If a node was not found in the
// database, a trie.MissingNodeError is returned.
- TryUpdate(key, value []byte) error
+ UpdateStorage(addr common.Address, key, value []byte) error
+
+ // UpdateAccount abstracts an account write to the trie. It encodes the
+ // provided account object with associated algorithm and then updates it
+ // in the trie with provided address.
+ UpdateAccount(address common.Address, account *types.StateAccount) error
+
+ // UpdateContractCode abstracts code write to the trie. It is expected
+ // to be moved to the stateWriter interface when the latter is ready.
+ UpdateContractCode(address common.Address, codeHash common.Hash, code []byte) error
+
+ // DeleteStorage removes any existing value for key from the trie. If a node
+ // was not found in the database, a trie.MissingNodeError is returned.
+ DeleteStorage(addr common.Address, key []byte) error
- // TryDelete removes any existing value for key from the trie. If a node was not
- // found in the database, a trie.MissingNodeError is returned.
- TryDelete(key []byte) error
+ // DeleteAccount abstracts an account deletion from the trie.
+ DeleteAccount(address common.Address) error
// Hash returns the root hash of the trie. It does not write to the database and
// can be used even if the trie doesn't have one.
@@ -98,18 +119,18 @@ type Trie interface {
// NewDatabase creates a backing store for state. The returned database is safe for
// concurrent use, but does not retain any recent trie nodes in memory. To keep some
-// historical state in memory, use the NewDatabaseWithCache constructor.
+// historical state in memory, use the NewDatabaseWithConfig constructor.
func NewDatabase(db ethdb.Database) Database {
- return NewDatabaseWithCache(db, 0)
+ return NewDatabaseWithConfig(db, nil)
}
-// NewDatabaseWithCache creates a backing store for state. The returned database
+// NewDatabaseWithConfig creates a backing store for state. The returned database
// is safe for concurrent use and retains a lot of collapsed RLP trie nodes in a
// large memory cache.
-func NewDatabaseWithCache(db ethdb.Database, cache int) Database {
+func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database {
csc, _ := lru.New(codeSizeCacheSize)
return &cachingDB{
- db: trie.NewDatabaseWithCache(db, cache),
+ db: trie.NewDatabaseWithConfig(db, config),
codeSizeCache: csc,
}
}
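
A hedged usage sketch of the new constructor: a nil config reproduces the old NewDatabase behaviour, while a populated trie.Config enables the node cache (the Cache field name follows upstream go-ethereum and is an assumption here):

    // Plain database, no trie node cache retained.
    sdb := state.NewDatabase(rawdb.NewMemoryDatabase())

    // Database with a 256 MB collapsed-node cache (assumed field name).
    cached := state.NewDatabaseWithConfig(rawdb.NewMemoryDatabase(), &trie.Config{Cache: 256})
    _, _ = sdb, cached
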
diff --git a/core/state/dump.go b/core/state/dump.go
index f08c6e7df..6d8994462 100644
--- a/core/state/dump.go
+++ b/core/state/dump.go
@@ -21,6 +21,7 @@ import (
"fmt"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/rlp"
"github.com/tomochain/tomochain/trie"
)
@@ -39,40 +40,40 @@ type Dump struct {
Accounts map[string]DumpAccount `json:"accounts"`
}
-func (self *StateDB) RawDump() Dump {
+func (s *StateDB) RawDump() Dump {
dump := Dump{
- Root: fmt.Sprintf("%x", self.trie.Hash()),
+ Root: fmt.Sprintf("%x", s.trie.Hash()),
Accounts: make(map[string]DumpAccount),
}
- it := trie.NewIterator(self.trie.NodeIterator(nil))
+ it := trie.NewIterator(s.trie.NodeIterator(nil))
for it.Next() {
- addr := self.trie.GetKey(it.Key)
- var data Account
+ addr := s.trie.GetKey(it.Key)
+ var data types.StateAccount
if err := rlp.DecodeBytes(it.Value, &data); err != nil {
panic(err)
}
- obj := newObject(nil, common.BytesToAddress(addr), data, nil)
+ obj := newObject(nil, common.BytesToAddress(addr), &data)
account := DumpAccount{
Balance: data.Balance.String(),
Nonce: data.Nonce,
Root: common.Bytes2Hex(data.Root[:]),
CodeHash: common.Bytes2Hex(data.CodeHash),
- Code: common.Bytes2Hex(obj.Code(self.db)),
+ Code: common.Bytes2Hex(obj.Code(s.db)),
Storage: make(map[string]string),
}
- storageIt := trie.NewIterator(obj.getTrie(self.db).NodeIterator(nil))
+ storageIt := trie.NewIterator(obj.getTrie(s.db).NodeIterator(nil))
for storageIt.Next() {
- account.Storage[common.Bytes2Hex(self.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value)
+ account.Storage[common.Bytes2Hex(s.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value)
}
dump.Accounts[common.Bytes2Hex(addr)] = account
}
return dump
}
-func (self *StateDB) Dump() []byte {
- json, err := json.MarshalIndent(self.RawDump(), "", " ")
+func (s *StateDB) Dump() []byte {
+ json, err := json.MarshalIndent(s.RawDump(), "", " ")
if err != nil {
fmt.Println("dump err", err)
}
diff --git a/core/state/iterator.go b/core/state/iterator.go
index 3cfc592ec..d69321f36 100644
--- a/core/state/iterator.go
+++ b/core/state/iterator.go
@@ -21,6 +21,7 @@ import (
"fmt"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/rlp"
"github.com/tomochain/tomochain/trie"
)
@@ -104,7 +105,7 @@ func (it *NodeIterator) step() error {
return nil
}
// Otherwise we've reached an account node, initiate data iteration
- var account Account
+ var account types.StateAccount
if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil {
return err
}
diff --git a/core/state/iterator_test.go b/core/state/iterator_test.go
index 20864e076..b5b728702 100644
--- a/core/state/iterator_test.go
+++ b/core/state/iterator_test.go
@@ -29,7 +29,7 @@ func TestNodeIteratorCoverage(t *testing.T) {
// Create some arbitrary test state to iterate
db, root, _ := makeTestState()
- state, err := New(root, db)
+ state, err := New(root, db, nil)
if err != nil {
t.Fatalf("failed to create state trie at %x: %v", root, err)
}
diff --git a/core/state/journal.go b/core/state/journal.go
index 1ac5cdbf2..2c75c9dbe 100644
--- a/core/state/journal.go
+++ b/core/state/journal.go
@@ -22,11 +22,67 @@ import (
"github.com/tomochain/tomochain/common"
)
+// journalEntry is a modification entry in the state change journal that can be
+// reverted on demand.
type journalEntry interface {
- undo(*StateDB)
+ // revert undoes the changes introduced by this journal entry.
+ revert(*StateDB)
+
+ // dirtied returns the Ethereum address modified by this journal entry.
+ dirtied() *common.Address
+}
+
+// journal contains the list of state modifications applied since the last state
+// commit. These are tracked to be able to be reverted in case of an execution
+// exception or revertal request.
+type journal struct {
+ entries []journalEntry // Current changes tracked by the journal
+ dirties map[common.Address]int // Dirty accounts and the number of changes
}
-type journal []journalEntry
+// newJournal creates a new initialized journal.
+func newJournal() *journal {
+ return &journal{
+ dirties: make(map[common.Address]int),
+ }
+}
+
+// append inserts a new modification entry to the end of the change journal.
+func (j *journal) append(entry journalEntry) {
+ j.entries = append(j.entries, entry)
+ if addr := entry.dirtied(); addr != nil {
+ j.dirties[*addr]++
+ }
+}
+
+// revert undoes a batch of journalled modifications along with any reverted
+// dirty handling too.
+func (j *journal) revert(statedb *StateDB, snapshot int) {
+ for i := len(j.entries) - 1; i >= snapshot; i-- {
+ // Undo the changes made by the operation
+ j.entries[i].revert(statedb)
+
+ // Drop any dirty tracking induced by the change
+ if addr := j.entries[i].dirtied(); addr != nil {
+ if j.dirties[*addr]--; j.dirties[*addr] == 0 {
+ delete(j.dirties, *addr)
+ }
+ }
+ }
+ j.entries = j.entries[:snapshot]
+}
+
+// dirty explicitly sets an address to dirty, even if the change entries would
+// otherwise suggest it as clean. This method is an ugly hack to handle the RIPEMD
+// precompile consensus exception.
+func (j *journal) dirty(addr common.Address) {
+ j.dirties[addr]++
+}
+
+// length returns the current number of entries in the journal.
+func (j *journal) length() int {
+ return len(j.entries)
+}
type (
// Changes to the account trie.
@@ -34,7 +90,8 @@ type (
account *common.Address
}
resetObjectChange struct {
- prev *stateObject
+ prev *stateObject
+ prevdestruct bool
}
suicideChange struct {
account *common.Address
@@ -77,16 +134,27 @@ type (
}
)
-func (ch createObjectChange) undo(s *StateDB) {
+func (ch createObjectChange) revert(s *StateDB) {
delete(s.stateObjects, *ch.account)
delete(s.stateObjectsDirty, *ch.account)
}
-func (ch resetObjectChange) undo(s *StateDB) {
+func (ch createObjectChange) dirtied() *common.Address {
+ return ch.account
+}
+
+func (ch resetObjectChange) revert(s *StateDB) {
s.setStateObject(ch.prev)
+ if !ch.prevdestruct && s.snap != nil {
+ delete(s.snapDestructs, ch.prev.addrHash)
+ }
+}
+
+func (ch resetObjectChange) dirtied() *common.Address {
+ return nil
}
-func (ch suicideChange) undo(s *StateDB) {
+func (ch suicideChange) revert(s *StateDB) {
obj := s.getStateObject(*ch.account)
if obj != nil {
obj.suicided = ch.prev
@@ -94,38 +162,60 @@ func (ch suicideChange) undo(s *StateDB) {
}
}
+func (ch suicideChange) dirtied() *common.Address {
+ return ch.account
+}
+
var ripemd = common.HexToAddress("0000000000000000000000000000000000000003")
-func (ch touchChange) undo(s *StateDB) {
- if !ch.prev && *ch.account != ripemd {
- s.getStateObject(*ch.account).touched = ch.prev
- if !ch.prevDirty {
- delete(s.stateObjectsDirty, *ch.account)
- }
- }
+func (ch touchChange) revert(s *StateDB) {
+}
+
+func (ch touchChange) dirtied() *common.Address {
+ return ch.account
}
-func (ch balanceChange) undo(s *StateDB) {
+func (ch balanceChange) revert(s *StateDB) {
s.getStateObject(*ch.account).setBalance(ch.prev)
}
-func (ch nonceChange) undo(s *StateDB) {
+func (ch balanceChange) dirtied() *common.Address {
+ return ch.account
+}
+
+func (ch nonceChange) revert(s *StateDB) {
s.getStateObject(*ch.account).setNonce(ch.prev)
}
-func (ch codeChange) undo(s *StateDB) {
+func (ch nonceChange) dirtied() *common.Address {
+ return ch.account
+}
+
+func (ch codeChange) revert(s *StateDB) {
s.getStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode)
}
-func (ch storageChange) undo(s *StateDB) {
+func (ch codeChange) dirtied() *common.Address {
+ return ch.account
+}
+
+func (ch storageChange) revert(s *StateDB) {
s.getStateObject(*ch.account).setState(ch.key, ch.prevalue)
}
-func (ch refundChange) undo(s *StateDB) {
+func (ch storageChange) dirtied() *common.Address {
+ return ch.account
+}
+
+func (ch refundChange) revert(s *StateDB) {
s.refund = ch.prev
}
-func (ch addLogChange) undo(s *StateDB) {
+func (ch refundChange) dirtied() *common.Address {
+ return nil
+}
+
+func (ch addLogChange) revert(s *StateDB) {
logs := s.logs[ch.txhash]
if len(logs) == 1 {
delete(s.logs, ch.txhash)
@@ -135,6 +225,14 @@ func (ch addLogChange) undo(s *StateDB) {
s.logSize--
}
-func (ch addPreimageChange) undo(s *StateDB) {
+func (ch addLogChange) dirtied() *common.Address {
+ return nil
+}
+
+func (ch addPreimageChange) revert(s *StateDB) {
delete(s.preimages, ch.hash)
}
+
+func (ch addPreimageChange) dirtied() *common.Address {
+ return nil
+}
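
Within the state package, the rewritten journal is what gives StateDB its snapshot/revert mechanics. A hedged in-package sketch of the pattern (statedb, addr and the previous values are assumptions; the change types mirror the ones defined above):

    j := newJournal()
    j.append(balanceChange{account: &addr, prev: prevBalance})
    mark := j.length() // snapshot marker
    j.append(nonceChange{account: &addr, prev: prevNonce})
    j.revert(statedb, mark) // undoes only the nonce change
    // j.dirties still records addr from the surviving balance change.
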
diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go
index 79220dc07..c4fa4937a 100644
--- a/core/state/managed_state_test.go
+++ b/core/state/managed_state_test.go
@@ -17,17 +17,17 @@
package state
import (
- "github.com/tomochain/tomochain/core/rawdb"
"testing"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
)
var addr = common.BytesToAddress([]byte("test"))
func create() (*ManagedState, *account) {
db := rawdb.NewMemoryDatabase()
- statedb, _ := New(common.Hash{}, NewDatabase(db))
+ statedb, _ := New(common.Hash{}, NewDatabase(db), nil)
ms := ManageState(statedb)
ms.StateDB.SetNonce(addr, 100)
ms.accounts[addr] = newAccount(ms.StateDB.getStateObject(addr))
diff --git a/core/state/snapshot/difflayer.go b/core/state/snapshot/difflayer.go
new file mode 100644
index 000000000..98214497c
--- /dev/null
+++ b/core/state/snapshot/difflayer.go
@@ -0,0 +1,535 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "encoding/binary"
+ "fmt"
+ "math"
+ "math/rand"
+ "sort"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ bloomfilter "github.com/holiman/bloomfilter/v2"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/types"
+ "github.com/tomochain/tomochain/rlp"
+)
+
+var (
+ // aggregatorMemoryLimit is the maximum size of the bottom-most diff layer
+ // that aggregates the writes from above until it's flushed into the disk
+ // layer.
+ //
+ // Note, bumping this up might drastically increase the size of the bloom
+ // filters that's stored in every diff layer. Don't do that without fully
+ // understanding all the implications.
+ aggregatorMemoryLimit = uint64(4 * 1024 * 1024)
+
+ // aggregatorItemLimit is an approximate number of items that will end up
+ // in the aggregator layer before it's flushed out to disk. A plain account
+ // weighs around 14B (+hash), a storage slot 32B (+hash), a deleted slot
+ // 0B (+hash). Slots are mostly set/unset in lockstep, so they average at
+ // 16B (+hash). All in all, the average entry seems to be 15+32=47B. Use a
+ // smaller number to be on the safe side.
+ aggregatorItemLimit = aggregatorMemoryLimit / 42
+
+ // bloomTargetError is the target false positive rate when the aggregator
+ // layer is at its fullest. The actual value will probably move around up
+ // and down from this number, it's mostly a ballpark figure.
+ //
+ // Note, dropping this down might drastically increase the size of the bloom
+ // filters that's stored in every diff layer. Don't do that without fully
+ // understanding all the implications.
+ bloomTargetError = 0.02
+
+ // bloomSize is the ideal bloom filter size given the maximum number of items
+ // it's expected to hold and the target false positive error rate.
+ bloomSize = math.Ceil(float64(aggregatorItemLimit) * math.Log(bloomTargetError) / math.Log(1/math.Pow(2, math.Log(2))))
+
+ // bloomFuncs is the ideal number of bits a single entry should set in the
+ // bloom filter to keep its size to a minimum (given its size and maximum
+ // entry count).
+ bloomFuncs = math.Round((bloomSize / float64(aggregatorItemLimit)) * math.Log(2))
+
+ // the bloom offsets are runtime constants which determine which part of the
+ // account/storage hash the hasher functions look at, to determine the
+ // bloom key for an account/slot. This is randomized at init(), so that the
+ // global population of nodes do not all display the exact same behaviour with
+ // regards to bloom content.
+ bloomDestructHasherOffset = 0
+ bloomAccountHasherOffset = 0
+ bloomStorageHasherOffset = 0
+)
+
+func init() {
+ // Init the bloom offsets in the range [0:24] (requires 8 bytes)
+ bloomDestructHasherOffset = rand.Intn(25)
+ bloomAccountHasherOffset = rand.Intn(25)
+ bloomStorageHasherOffset = rand.Intn(25)
+
+ // The destruct and account blooms must be different, as the storage slots
+ // will check for destruction too for every bloom miss. It should not collide
+ // with modified accounts.
+ for bloomAccountHasherOffset == bloomDestructHasherOffset {
+ bloomAccountHasherOffset = rand.Intn(25)
+ }
+}
+
+// diffLayer represents a collection of modifications made to a state snapshot
+// after running a block on top. It contains one sorted list for the account trie
+// and one sorted list for each storage trie.
+//
+// The goal of a diff layer is to act as a journal, tracking recent modifications
+// made to the state, that have not yet graduated into a semi-immutable state.
+type diffLayer struct {
+ origin *diskLayer // Base disk layer to directly use on bloom misses
+ parent snapshot // Parent snapshot modified by this one, never nil
+ memory uint64 // Approximate guess as to how much memory we use
+
+ root common.Hash // Root hash to which this snapshot diff belongs to
+ stale uint32 // Signals that the layer became stale (state progressed)
+
+ destructSet map[common.Hash]struct{} // Keyed markers for deleted (and potentially) recreated accounts
+ accountList []common.Hash // List of account for iteration. If it exists, it's sorted, otherwise it's nil
+ accountData map[common.Hash][]byte // Keyed accounts for direct retrieval (nil means deleted)
+ storageList map[common.Hash][]common.Hash // List of storage slots for iterated retrievals, one per account. Any existing lists are sorted if non-nil
+ storageData map[common.Hash]map[common.Hash][]byte // Keyed storage slots for direct retrieval, one per account (nil means deleted)
+
+ diffed *bloomfilter.Filter // Bloom filter tracking all the diffed items up to the disk layer
+
+ lock sync.RWMutex
+}
+
+// destructBloomHasher is a wrapper around a common.Hash to satisfy the interface
+// API requirements of the bloom library used. It's used to convert a destruct
+// event into a 64 bit mini hash.
+type destructBloomHasher common.Hash
+
+func (h destructBloomHasher) Write(p []byte) (n int, err error) { panic("not implemented") }
+func (h destructBloomHasher) Sum(b []byte) []byte { panic("not implemented") }
+func (h destructBloomHasher) Reset() { panic("not implemented") }
+func (h destructBloomHasher) BlockSize() int { panic("not implemented") }
+func (h destructBloomHasher) Size() int { return 8 }
+func (h destructBloomHasher) Sum64() uint64 {
+ return binary.BigEndian.Uint64(h[bloomDestructHasherOffset : bloomDestructHasherOffset+8])
+}
+
+// accountBloomHasher is a wrapper around a common.Hash to satisfy the interface
+// API requirements of the bloom library used. It's used to convert an account
+// hash into a 64 bit mini hash.
+type accountBloomHasher common.Hash
+
+func (h accountBloomHasher) Write(p []byte) (n int, err error) { panic("not implemented") }
+func (h accountBloomHasher) Sum(b []byte) []byte { panic("not implemented") }
+func (h accountBloomHasher) Reset() { panic("not implemented") }
+func (h accountBloomHasher) BlockSize() int { panic("not implemented") }
+func (h accountBloomHasher) Size() int { return 8 }
+func (h accountBloomHasher) Sum64() uint64 {
+ return binary.BigEndian.Uint64(h[bloomAccountHasherOffset : bloomAccountHasherOffset+8])
+}
+
+// storageBloomHasher is a wrapper around a [2]common.Hash to satisfy the interface
+// API requirements of the bloom library used. It's used to convert an account
+// hash into a 64 bit mini hash.
+type storageBloomHasher [2]common.Hash
+
+func (h storageBloomHasher) Write(p []byte) (n int, err error) { panic("not implemented") }
+func (h storageBloomHasher) Sum(b []byte) []byte { panic("not implemented") }
+func (h storageBloomHasher) Reset() { panic("not implemented") }
+func (h storageBloomHasher) BlockSize() int { panic("not implemented") }
+func (h storageBloomHasher) Size() int { return 8 }
+func (h storageBloomHasher) Sum64() uint64 {
+ return binary.BigEndian.Uint64(h[0][bloomStorageHasherOffset:bloomStorageHasherOffset+8]) ^
+ binary.BigEndian.Uint64(h[1][bloomStorageHasherOffset:bloomStorageHasherOffset+8])
+}
+
+// newDiffLayer creates a new diff on top of an existing snapshot, whether that's a low
+// level persistent database or a hierarchical diff already.
+func newDiffLayer(parent snapshot, root common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer {
+ // Create the new layer with some pre-allocated data segments
+ dl := &diffLayer{
+ parent: parent,
+ root: root,
+ destructSet: destructs,
+ accountData: accounts,
+ storageData: storage,
+ }
+ switch parent := parent.(type) {
+ case *diskLayer:
+ dl.rebloom(parent)
+ case *diffLayer:
+ dl.rebloom(parent.origin)
+ default:
+ panic("unknown parent type")
+ }
+ // Sanity check that accounts or storage slots are never nil
+ for accountHash, blob := range accounts {
+ if blob == nil {
+ panic(fmt.Sprintf("account %#x nil", accountHash))
+ }
+ }
+ for accountHash, slots := range storage {
+ if slots == nil {
+ panic(fmt.Sprintf("storage %#x nil", accountHash))
+ }
+ }
+ // Determine memory size and track the dirty writes
+ for _, data := range accounts {
+ dl.memory += uint64(common.HashLength + len(data))
+ snapshotDirtyAccountWriteMeter.Mark(int64(len(data)))
+ }
+ // Fill the storage hashes and sort them for the iterator
+ dl.storageList = make(map[common.Hash][]common.Hash)
+ for accountHash := range destructs {
+ dl.storageList[accountHash] = nil
+ }
+ // Determine memory size and track the dirty writes
+ for _, slots := range storage {
+ for _, data := range slots {
+ dl.memory += uint64(common.HashLength + len(data))
+ snapshotDirtyStorageWriteMeter.Mark(int64(len(data)))
+ }
+ }
+ dl.memory += uint64(len(dl.storageList) * common.HashLength)
+ return dl
+}
+
+// rebloom discards the layer's current bloom and rebuilds it from scratch based
+// on the parent's and the local diffs.
+func (dl *diffLayer) rebloom(origin *diskLayer) {
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+
+ defer func(start time.Time) {
+ snapshotBloomIndexTimer.Update(time.Since(start))
+ }(time.Now())
+
+ // Inject the new origin that triggered the rebloom
+ dl.origin = origin
+
+ // Retrieve the parent bloom or create a fresh empty one
+ if parent, ok := dl.parent.(*diffLayer); ok {
+ parent.lock.RLock()
+ dl.diffed, _ = parent.diffed.Copy()
+ parent.lock.RUnlock()
+ } else {
+ dl.diffed, _ = bloomfilter.New(uint64(bloomSize), uint64(bloomFuncs))
+ }
+ // Iterate over all the accounts and storage slots and index them
+ for hash := range dl.destructSet {
+ dl.diffed.Add(destructBloomHasher(hash))
+ }
+ for hash := range dl.accountData {
+ dl.diffed.Add(accountBloomHasher(hash))
+ }
+ for accountHash, slots := range dl.storageData {
+ for storageHash := range slots {
+ dl.diffed.Add(storageBloomHasher{accountHash, storageHash})
+ }
+ }
+ // Calculate the current false positive rate and update the error rate meter.
+ // This is a bit cheating because subsequent layers will overwrite it, but it
+ // should be fine, we're only interested in ballpark figures.
+ k := float64(dl.diffed.K())
+ n := float64(dl.diffed.N())
+ m := float64(dl.diffed.M())
+ snapshotBloomErrorGauge.Update(math.Pow(1.0-math.Exp((-k)*(n+0.5)/(m-1)), k))
+}
+
+// Root returns the root hash for which this snapshot was made.
+func (dl *diffLayer) Root() common.Hash {
+ return dl.root
+}
+
+// Parent returns the subsequent layer of a diff layer.
+func (dl *diffLayer) Parent() snapshot {
+ return dl.parent
+}
+
+// Stale returns whether this layer has become stale (was flattened across) or if
+// it's still live.
+func (dl *diffLayer) Stale() bool {
+ return atomic.LoadUint32(&dl.stale) != 0
+}
+
+// Account directly retrieves the account associated with a particular hash in
+// the snapshot slim data format.
+func (dl *diffLayer) Account(hash common.Hash) (*types.SlimAccount, error) {
+ data, err := dl.AccountRLP(hash)
+ if err != nil {
+ return nil, err
+ }
+ if len(data) == 0 { // can be both nil and []byte{}
+ return nil, nil
+ }
+ account := new(types.SlimAccount)
+ if err := rlp.DecodeBytes(data, account); err != nil {
+ panic(err)
+ }
+ return account, nil
+}
+
+// AccountRLP directly retrieves the account RLP associated with a particular
+// hash in the snapshot slim data format.
+func (dl *diffLayer) AccountRLP(hash common.Hash) ([]byte, error) {
+ // Check the bloom filter first whether there's even a point in reaching into
+ // all the maps in all the layers below
+ dl.lock.RLock()
+ hit := dl.diffed.Contains(accountBloomHasher(hash))
+ if !hit {
+ hit = dl.diffed.Contains(destructBloomHasher(hash))
+ }
+ dl.lock.RUnlock()
+
+ // If the bloom filter misses, don't even bother with traversing the memory
+ // diff layers, reach straight into the bottom persistent disk layer
+ if !hit {
+ snapshotBloomAccountMissMeter.Mark(1)
+ return dl.origin.AccountRLP(hash)
+ }
+ // The bloom filter hit, start poking in the internal maps
+ return dl.accountRLP(hash, 0)
+}
+
+// accountRLP is an internal version of AccountRLP that skips the bloom filter
+// checks and uses the internal maps to try and retrieve the data. It's meant
+// to be used if a higher layer's bloom filter hit already.
+func (dl *diffLayer) accountRLP(hash common.Hash, depth int) ([]byte, error) {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ // If the layer was flattened into, consider it invalid (any live reference to
+ // the original should be marked as unusable).
+ if dl.Stale() {
+ return nil, ErrSnapshotStale
+ }
+ // If the account is known locally, return it
+ if data, ok := dl.accountData[hash]; ok {
+ snapshotDirtyAccountHitMeter.Mark(1)
+ snapshotDirtyAccountHitDepthHist.Update(int64(depth))
+ snapshotDirtyAccountReadMeter.Mark(int64(len(data)))
+ snapshotBloomAccountTrueHitMeter.Mark(1)
+ return data, nil
+ }
+ // If the account is known locally, but deleted, return it
+ if _, ok := dl.destructSet[hash]; ok {
+ snapshotDirtyAccountHitMeter.Mark(1)
+ snapshotDirtyAccountHitDepthHist.Update(int64(depth))
+ snapshotDirtyAccountInexMeter.Mark(1)
+ snapshotBloomAccountTrueHitMeter.Mark(1)
+ return nil, nil
+ }
+ // Account unknown to this diff, resolve from parent
+ if diff, ok := dl.parent.(*diffLayer); ok {
+ return diff.accountRLP(hash, depth+1)
+ }
+ // Failed to resolve through diff layers, mark a bloom error and use the disk
+ snapshotBloomAccountFalseHitMeter.Mark(1)
+ return dl.parent.AccountRLP(hash)
+}
+
+// Storage directly retrieves the storage data associated with a particular hash,
+// within a particular account. If the slot is unknown to this diff, its parent
+// is consulted.
+func (dl *diffLayer) Storage(accountHash, storageHash common.Hash) ([]byte, error) {
+ // Check the bloom filter first whether there's even a point in reaching into
+ // all the maps in all the layers below
+ dl.lock.RLock()
+ hit := dl.diffed.Contains(storageBloomHasher{accountHash, storageHash})
+ if !hit {
+ hit = dl.diffed.Contains(destructBloomHasher(accountHash))
+ }
+ dl.lock.RUnlock()
+
+ // If the bloom filter misses, don't even bother with traversing the memory
+ // diff layers, reach straight into the bottom persistent disk layer
+ if !hit {
+ snapshotBloomStorageMissMeter.Mark(1)
+ return dl.origin.Storage(accountHash, storageHash)
+ }
+ // The bloom filter hit, start poking in the internal maps
+ return dl.storage(accountHash, storageHash, 0)
+}
+
+// storage is an internal version of Storage that skips the bloom filter checks
+// and uses the internal maps to try and retrieve the data. It's meant to be
+// used if a higher layer's bloom filter hit already.
+func (dl *diffLayer) storage(accountHash, storageHash common.Hash, depth int) ([]byte, error) {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ // If the layer was flattened into, consider it invalid (any live reference to
+ // the original should be marked as unusable).
+ if dl.Stale() {
+ return nil, ErrSnapshotStale
+ }
+ // If the account is known locally, try to resolve the slot locally
+ if storage, ok := dl.storageData[accountHash]; ok {
+ if data, ok := storage[storageHash]; ok {
+ snapshotDirtyStorageHitMeter.Mark(1)
+ snapshotDirtyStorageHitDepthHist.Update(int64(depth))
+ if n := len(data); n > 0 {
+ snapshotDirtyStorageReadMeter.Mark(int64(n))
+ } else {
+ snapshotDirtyStorageInexMeter.Mark(1)
+ }
+ snapshotBloomStorageTrueHitMeter.Mark(1)
+ return data, nil
+ }
+ }
+ // If the account is known locally, but deleted, return an empty slot
+ if _, ok := dl.destructSet[accountHash]; ok {
+ snapshotDirtyStorageHitMeter.Mark(1)
+ snapshotDirtyStorageHitDepthHist.Update(int64(depth))
+ snapshotDirtyStorageInexMeter.Mark(1)
+ snapshotBloomStorageTrueHitMeter.Mark(1)
+ return nil, nil
+ }
+ // Storage slot unknown to this diff, resolve from parent
+ if diff, ok := dl.parent.(*diffLayer); ok {
+ return diff.storage(accountHash, storageHash, depth+1)
+ }
+ // Failed to resolve through diff layers, mark a bloom error and use the disk
+ snapshotBloomStorageFalseHitMeter.Mark(1)
+ return dl.parent.Storage(accountHash, storageHash)
+}
+
+// Update creates a new layer on top of the existing snapshot diff tree with
+// the specified data items.
+func (dl *diffLayer) Update(blockRoot common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer {
+ return newDiffLayer(dl, blockRoot, destructs, accounts, storage)
+}
+
+// flatten pushes all data from this point downwards, flattening everything into
+// a single diff at the bottom. Since usually the lowermost diff is the largest,
+// the flattening builds up from there in reverse.
+func (dl *diffLayer) flatten() snapshot {
+ // If the parent is not diff, we're the first in line, return unmodified
+ parent, ok := dl.parent.(*diffLayer)
+ if !ok {
+ return dl
+ }
+ // Parent is a diff, flatten it first (note, apart from weird corner cases,
+ // flatten will realistically only ever merge 1 layer, so there's no need to
+ // be smarter about grouping flattens together).
+ parent = parent.flatten().(*diffLayer)
+
+ parent.lock.Lock()
+ defer parent.lock.Unlock()
+
+ // Before actually writing all our data to the parent, first ensure that the
+ // parent hasn't been 'corrupted' by someone else already flattening into it
+ if atomic.SwapUint32(&parent.stale, 1) != 0 {
+ panic("parent diff layer is stale") // we've flattened into the same parent from two children, boo
+ }
+ // Overwrite all the updated accounts blindly, merge the sorted list
+ for hash := range dl.destructSet {
+ parent.destructSet[hash] = struct{}{}
+ delete(parent.accountData, hash)
+ delete(parent.storageData, hash)
+ }
+ for hash, data := range dl.accountData {
+ parent.accountData[hash] = data
+ }
+ // Overwrite all the updated storage slots (individually)
+ for accountHash, storage := range dl.storageData {
+ // If storage didn't exist (or was deleted) in the parent, overwrite blindly
+ if _, ok := parent.storageData[accountHash]; !ok {
+ parent.storageData[accountHash] = storage
+ continue
+ }
+ // Storage exists in both parent and child, merge the slots
+ comboData := parent.storageData[accountHash]
+ for storageHash, data := range storage {
+ comboData[storageHash] = data
+ }
+ parent.storageData[accountHash] = comboData
+ }
+ // Return the combo parent
+ return &diffLayer{
+ parent: parent.parent,
+ origin: parent.origin,
+ root: dl.root,
+ destructSet: parent.destructSet,
+ accountData: parent.accountData,
+ storageData: parent.storageData,
+ storageList: make(map[common.Hash][]common.Hash),
+ diffed: dl.diffed,
+ memory: parent.memory + dl.memory,
+ }
+}
+
+// AccountList returns a sorted list of all accounts in this difflayer, including
+// the deleted ones.
+//
+// Note, the returned slice is not a copy, so do not modify it.
+func (dl *diffLayer) AccountList() []common.Hash {
+ // If an old list already exists, return it
+ dl.lock.RLock()
+ list := dl.accountList
+ dl.lock.RUnlock()
+
+ if list != nil {
+ return list
+ }
+ // No old sorted account list exists, generate a new one
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+
+ dl.accountList = make([]common.Hash, 0, len(dl.destructSet)+len(dl.accountData))
+ for hash := range dl.accountData {
+ dl.accountList = append(dl.accountList, hash)
+ }
+ for hash := range dl.destructSet {
+ if _, ok := dl.accountData[hash]; !ok {
+ dl.accountList = append(dl.accountList, hash)
+ }
+ }
+ sort.Sort(hashes(dl.accountList))
+ return dl.accountList
+}
+
+// StorageList returns a sorted list of all storage slot hashes in this difflayer
+// for the given account.
+//
+// Note, the returned slice is not a copy, so do not modify it.
+func (dl *diffLayer) StorageList(accountHash common.Hash) []common.Hash {
+ // If an old list already exists, return it
+ dl.lock.RLock()
+ list := dl.storageList[accountHash]
+ dl.lock.RUnlock()
+
+ if list != nil {
+ return list
+ }
+ // No old sorted storage list exists, generate a new one
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+
+ storageMap := dl.storageData[accountHash]
+ storageList := make([]common.Hash, 0, len(storageMap))
+ for k := range storageMap {
+ storageList = append(storageList, k)
+ }
+ sort.Sort(hashes(storageList))
+ dl.storageList[accountHash] = storageList
+ return storageList
+}
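+
+// The sorted lists above allow O(log n) seeking; the account iterator in
+// iterator.go relies on exactly this pattern (sketch):
+//
+//	list := dl.AccountList()
+//	idx := sort.Search(len(list), func(i int) bool {
+//		return bytes.Compare(seek[:], list[i][:]) < 0
+//	})
+//	// list[idx:] holds the accounts strictly after seek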
diff --git a/core/state/snapshot/difflayer_test.go b/core/state/snapshot/difflayer_test.go
new file mode 100644
index 000000000..89814432b
--- /dev/null
+++ b/core/state/snapshot/difflayer_test.go
@@ -0,0 +1,399 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "math/rand"
+ "testing"
+
+ "github.com/VictoriaMetrics/fastcache"
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/ethdb/memorydb"
+)
+
+func copyDestructs(destructs map[common.Hash]struct{}) map[common.Hash]struct{} {
+ copy := make(map[common.Hash]struct{})
+ for hash := range destructs {
+ copy[hash] = struct{}{}
+ }
+ return copy
+}
+
+func copyAccounts(accounts map[common.Hash][]byte) map[common.Hash][]byte {
+ copy := make(map[common.Hash][]byte)
+ for hash, blob := range accounts {
+ copy[hash] = blob
+ }
+ return copy
+}
+
+func copyStorage(storage map[common.Hash]map[common.Hash][]byte) map[common.Hash]map[common.Hash][]byte {
+ copy := make(map[common.Hash]map[common.Hash][]byte)
+ for accHash, slots := range storage {
+ copy[accHash] = make(map[common.Hash][]byte)
+ for slotHash, blob := range slots {
+ copy[accHash][slotHash] = blob
+ }
+ }
+ return copy
+}
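+
+// Note (editorial): these deep copies matter because newDiffLayer retains the
+// maps it is handed (see the Update docs in disklayer.go); sharing one map
+// across layers would alias mutations between them.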
+
+// TestMergeBasics tests some simple merges
+func TestMergeBasics(t *testing.T) {
+ var (
+ destructs = make(map[common.Hash]struct{})
+ accounts = make(map[common.Hash][]byte)
+ storage = make(map[common.Hash]map[common.Hash][]byte)
+ )
+ // Fill up a parent
+ for i := 0; i < 100; i++ {
+ h := randomHash()
+ data := randomAccount()
+
+ accounts[h] = data
+ if rand.Intn(4) == 0 {
+ destructs[h] = struct{}{}
+ }
+ if rand.Intn(2) == 0 {
+ accStorage := make(map[common.Hash][]byte)
+ value := make([]byte, 32)
+ rand.Read(value)
+ accStorage[randomHash()] = value
+ storage[h] = accStorage
+ }
+ }
+ // Add some (identical) layers on top
+ parent := newDiffLayer(emptyLayer(), common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage))
+ child := newDiffLayer(parent, common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage))
+ child = newDiffLayer(child, common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage))
+ child = newDiffLayer(child, common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage))
+ child = newDiffLayer(child, common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage))
+ // And flatten
+ merged := (child.flatten()).(*diffLayer)
+
+ { // Check account lists
+ if have, want := len(merged.accountList), 0; have != want {
+ t.Errorf("accountList wrong: have %v, want %v", have, want)
+ }
+ if have, want := len(merged.AccountList()), len(accounts); have != want {
+ t.Errorf("AccountList() wrong: have %v, want %v", have, want)
+ }
+ if have, want := len(merged.accountList), len(accounts); have != want {
+ t.Errorf("accountList [2] wrong: have %v, want %v", have, want)
+ }
+ }
+ { // Check account drops
+ if have, want := len(merged.destructSet), len(destructs); have != want {
+ t.Errorf("accountDrop wrong: have %v, want %v", have, want)
+ }
+ }
+ { // Check storage lists
+ i := 0
+ for aHash, sMap := range storage {
+ if have, want := len(merged.storageList), i; have != want {
+ t.Errorf("[1] storageList wrong: have %v, want %v", have, want)
+ }
+ if have, want := len(merged.StorageList(aHash)), len(sMap); have != want {
+ t.Errorf("[2] StorageList() wrong: have %v, want %v", have, want)
+ }
+ if have, want := len(merged.storageList[aHash]), len(sMap); have != want {
+ t.Errorf("storageList wrong: have %v, want %v", have, want)
+ }
+ i++
+ }
+ }
+}
+
+// TestMergeDelete tests merges that involve deletions (destructs)
+func TestMergeDelete(t *testing.T) {
+ var (
+ storage = make(map[common.Hash]map[common.Hash][]byte)
+ )
+ // Fill up a parent
+ h1 := common.HexToHash("0x01")
+ h2 := common.HexToHash("0x02")
+
+ flipDrops := func() map[common.Hash]struct{} {
+ return map[common.Hash]struct{}{
+ h2: struct{}{},
+ }
+ }
+ flipAccs := func() map[common.Hash][]byte {
+ return map[common.Hash][]byte{
+ h1: randomAccount(),
+ }
+ }
+ flopDrops := func() map[common.Hash]struct{} {
+ return map[common.Hash]struct{}{
+ h1: struct{}{},
+ }
+ }
+ flopAccs := func() map[common.Hash][]byte {
+ return map[common.Hash][]byte{
+ h2: randomAccount(),
+ }
+ }
+ // Add some flipAccs-flopping layers on top
+ parent := newDiffLayer(emptyLayer(), common.Hash{}, flipDrops(), flipAccs(), storage)
+ child := parent.Update(common.Hash{}, flopDrops(), flopAccs(), storage)
+ child = child.Update(common.Hash{}, flipDrops(), flipAccs(), storage)
+ child = child.Update(common.Hash{}, flopDrops(), flopAccs(), storage)
+ child = child.Update(common.Hash{}, flipDrops(), flipAccs(), storage)
+ child = child.Update(common.Hash{}, flopDrops(), flopAccs(), storage)
+ child = child.Update(common.Hash{}, flipDrops(), flipAccs(), storage)
+
+ if data, _ := child.Account(h1); data == nil {
+ t.Errorf("last diff layer: expected %x account to be non-nil", h1)
+ }
+ if data, _ := child.Account(h2); data != nil {
+ t.Errorf("last diff layer: expected %x account to be nil", h2)
+ }
+ if _, ok := child.destructSet[h1]; ok {
+ t.Errorf("last diff layer: expected %x drop to be missing", h1)
+ }
+ if _, ok := child.destructSet[h2]; !ok {
+ t.Errorf("last diff layer: expected %x drop to be present", h1)
+ }
+ // And flatten
+ merged := (child.flatten()).(*diffLayer)
+
+ if data, _ := merged.Account(h1); data == nil {
+ t.Errorf("merged layer: expected %x account to be non-nil", h1)
+ }
+ if data, _ := merged.Account(h2); data != nil {
+ t.Errorf("merged layer: expected %x account to be nil", h2)
+ }
+ if _, ok := merged.destructSet[h1]; !ok { // Note, drops stay alive until persisted to disk!
+ t.Errorf("merged diff layer: expected %x drop to be present", h1)
+ }
+ if _, ok := merged.destructSet[h2]; !ok { // Note, drops stay alive until persisted to disk!
+ t.Errorf("merged diff layer: expected %x drop to be present", h1)
+ }
+ // If we add more granular metering of memory, we can enable this again,
+ // but it's not implemented for now
+ //if have, want := merged.memory, child.memory; have != want {
+ // t.Errorf("mem wrong: have %d, want %d", have, want)
+ //}
+}
+
+// This test verifies that creating a new account with a storage slot and then
+// flattening the layers produces correct account and storage lists.
+func TestInsertAndMerge(t *testing.T) {
+ // Fill up a parent
+ var (
+ acc = common.HexToHash("0x01")
+ slot = common.HexToHash("0x02")
+ parent *diffLayer
+ child *diffLayer
+ )
+ {
+ var (
+ destructs = make(map[common.Hash]struct{})
+ accounts = make(map[common.Hash][]byte)
+ storage = make(map[common.Hash]map[common.Hash][]byte)
+ )
+ parent = newDiffLayer(emptyLayer(), common.Hash{}, destructs, accounts, storage)
+ }
+ {
+ var (
+ destructs = make(map[common.Hash]struct{})
+ accounts = make(map[common.Hash][]byte)
+ storage = make(map[common.Hash]map[common.Hash][]byte)
+ )
+ accounts[acc] = randomAccount()
+ storage[acc] = make(map[common.Hash][]byte)
+ storage[acc][slot] = []byte{0x01}
+ child = newDiffLayer(parent, common.Hash{}, destructs, accounts, storage)
+ }
+ // And flatten
+ merged := (child.flatten()).(*diffLayer)
+ { // Check that slot value is present
+ have, _ := merged.Storage(acc, slot)
+ if want := []byte{0x01}; !bytes.Equal(have, want) {
+ t.Errorf("merged slot value wrong: have %x, want %x", have, want)
+ }
+ }
+}
+
+func emptyLayer() *diskLayer {
+ return &diskLayer{
+ diskdb: memorydb.New(),
+ cache: fastcache.New(500 * 1024),
+ }
+}
+
+// BenchmarkSearch checks how long it takes to find a non-existing key
+// BenchmarkSearch-6 200000 10481 ns/op (1K per layer)
+// BenchmarkSearch-6 200000 10760 ns/op (10K per layer)
+// BenchmarkSearch-6 100000 17866 ns/op
+//
+// BenchmarkSearch-6 500000 3723 ns/op (10k per layer, only top-level RLock())
+func BenchmarkSearch(b *testing.B) {
+ // First, we set up 128 diff layers, with 10K items each
+ fill := func(parent snapshot) *diffLayer {
+ var (
+ destructs = make(map[common.Hash]struct{})
+ accounts = make(map[common.Hash][]byte)
+ storage = make(map[common.Hash]map[common.Hash][]byte)
+ )
+ for i := 0; i < 10000; i++ {
+ accounts[randomHash()] = randomAccount()
+ }
+ return newDiffLayer(parent, common.Hash{}, destructs, accounts, storage)
+ }
+ var layer snapshot
+ layer = emptyLayer()
+ for i := 0; i < 128; i++ {
+ layer = fill(layer)
+ }
+ key := crypto.Keccak256Hash([]byte{0x13, 0x38})
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ layer.AccountRLP(key)
+ }
+}
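+
+// To reproduce numbers like the ones quoted above, an invocation along these
+// lines should work (illustrative):
+//
+//	go test ./core/state/snapshot -run NONE -bench 'BenchmarkSearch$'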
+
+// BenchmarkSearchSlot checks how long it takes to find a non-existing key
+// - Number of layers: 128
+// - Each layers contains the account, with a couple of storage slots
+// BenchmarkSearchSlot-6 100000 14554 ns/op
+// BenchmarkSearchSlot-6 100000 22254 ns/op (when checking parent root using mutex)
+// BenchmarkSearchSlot-6 100000 14551 ns/op (when checking parent number using atomic)
+// With bloom filter:
+// BenchmarkSearchSlot-6 3467835 351 ns/op
+func BenchmarkSearchSlot(b *testing.B) {
+ // First, we set up 128 diff layers, each containing the target account and a few storage slots
+ accountKey := crypto.Keccak256Hash([]byte{0x13, 0x37})
+ storageKey := crypto.Keccak256Hash([]byte{0x13, 0x37})
+ accountRLP := randomAccount()
+ fill := func(parent snapshot) *diffLayer {
+ var (
+ destructs = make(map[common.Hash]struct{})
+ accounts = make(map[common.Hash][]byte)
+ storage = make(map[common.Hash]map[common.Hash][]byte)
+ )
+ accounts[accountKey] = accountRLP
+
+ accStorage := make(map[common.Hash][]byte)
+ for i := 0; i < 5; i++ {
+ value := make([]byte, 32)
+ rand.Read(value)
+ accStorage[randomHash()] = value
+ storage[accountKey] = accStorage
+ }
+ return newDiffLayer(parent, common.Hash{}, destructs, accounts, storage)
+ }
+ var layer snapshot
+ layer = emptyLayer()
+ for i := 0; i < 128; i++ {
+ layer = fill(layer)
+ }
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ layer.Storage(accountKey, storageKey)
+ }
+}
+
+// With accountList and sorting
+// BenchmarkFlatten-6 50 29890856 ns/op
+//
+// Without sorting and tracking accountlist
+// BenchmarkFlatten-6 300 5511511 ns/op
+func BenchmarkFlatten(b *testing.B) {
+ fill := func(parent snapshot) *diffLayer {
+ var (
+ destructs = make(map[common.Hash]struct{})
+ accounts = make(map[common.Hash][]byte)
+ storage = make(map[common.Hash]map[common.Hash][]byte)
+ )
+ for i := 0; i < 100; i++ {
+ accountKey := randomHash()
+ accounts[accountKey] = randomAccount()
+
+ accStorage := make(map[common.Hash][]byte)
+ for i := 0; i < 20; i++ {
+ value := make([]byte, 32)
+ rand.Read(value)
+ accStorage[randomHash()] = value
+
+ }
+ storage[accountKey] = accStorage
+ }
+ return newDiffLayer(parent, common.Hash{}, destructs, accounts, storage)
+ }
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ b.StopTimer()
+ var layer snapshot
+ layer = emptyLayer()
+ for i := 1; i < 128; i++ {
+ layer = fill(layer)
+ }
+ b.StartTimer()
+
+ for i := 1; i < 128; i++ {
+ dl, ok := layer.(*diffLayer)
+ if !ok {
+ break
+ }
+ layer = dl.flatten()
+ }
+ b.StopTimer()
+ }
+}
+
+// This benchmark writes ~324M of diff layers to disk, spread over
+// - 128 individual layers,
+// - each with 200 accounts
+// - containing 200 slots
+//
+// BenchmarkJournal-6 1 1471373923 ns/op
+// BenchmarkJournal-6 1 1208083335 ns/op // bufio writer
+func BenchmarkJournal(b *testing.B) {
+ fill := func(parent snapshot) *diffLayer {
+ var (
+ destructs = make(map[common.Hash]struct{})
+ accounts = make(map[common.Hash][]byte)
+ storage = make(map[common.Hash]map[common.Hash][]byte)
+ )
+ for i := 0; i < 200; i++ {
+ accountKey := randomHash()
+ accounts[accountKey] = randomAccount()
+
+ accStorage := make(map[common.Hash][]byte)
+ for i := 0; i < 200; i++ {
+ value := make([]byte, 32)
+ rand.Read(value)
+ accStorage[randomHash()] = value
+
+ }
+ storage[accountKey] = accStorage
+ }
+ return newDiffLayer(parent, common.Hash{}, destructs, accounts, storage)
+ }
+ layer := snapshot(new(diskLayer))
+ for i := 1; i < 128; i++ {
+ layer = fill(layer)
+ }
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ layer.Journal(new(bytes.Buffer))
+ }
+}
diff --git a/core/state/snapshot/disklayer.go b/core/state/snapshot/disklayer.go
new file mode 100644
index 000000000..febb3e675
--- /dev/null
+++ b/core/state/snapshot/disklayer.go
@@ -0,0 +1,168 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "sync"
+
+ "github.com/VictoriaMetrics/fastcache"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
+ "github.com/tomochain/tomochain/core/types"
+ "github.com/tomochain/tomochain/ethdb"
+ "github.com/tomochain/tomochain/rlp"
+ "github.com/tomochain/tomochain/trie"
+)
+
+// diskLayer is a low level persistent snapshot built on top of a key-value store.
+type diskLayer struct {
+ diskdb ethdb.KeyValueStore // Key-value store containing the base snapshot
+ triedb *trie.Database // Trie node cache for reconstruction purposes
+ cache *fastcache.Cache // Cache to avoid hitting the disk for direct access
+
+ root common.Hash // Root hash of the base snapshot
+ stale bool // Signals that the layer became stale (state progressed)
+
+ genMarker []byte // Marker for the state that's indexed during initial layer generation
+ genPending chan struct{} // Notification channel when generation is done (test synchronicity)
+ genAbort chan chan *generatorStats // Notification channel to abort generating the snapshot in this layer
+
+ lock sync.RWMutex
+}
+
+// Root returns root hash for which this snapshot was made.
+func (dl *diskLayer) Root() common.Hash {
+ return dl.root
+}
+
+// Parent always returns nil as there's no layer below the disk.
+func (dl *diskLayer) Parent() snapshot {
+ return nil
+}
+
+// Stale returns whether this layer has become stale (was flattened across) or
+// it's still live.
+func (dl *diskLayer) Stale() bool {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ return dl.stale
+}
+
+// Account directly retrieves the account associated with a particular hash in
+// the snapshot slim data format.
+func (dl *diskLayer) Account(hash common.Hash) (*types.SlimAccount, error) {
+ data, err := dl.AccountRLP(hash)
+ if err != nil {
+ return nil, err
+ }
+ if len(data) == 0 { // can be both nil and []byte{}
+ return nil, nil
+ }
+ account := new(types.SlimAccount)
+ if err := rlp.DecodeBytes(data, account); err != nil {
+ panic(err)
+ }
+ return account, nil
+}
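+
+// Editorial note: the slim format above omits the storage root and code hash
+// when they equal the empty-trie root / empty-code hash; consumers of
+// types.SlimAccount are expected to restore those defaults for nil fields.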
+
+// AccountRLP directly retrieves the account RLP associated with a particular
+// hash in the snapshot slim data format.
+func (dl *diskLayer) AccountRLP(hash common.Hash) ([]byte, error) {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ // If the layer was flattened into, consider it invalid (any live reference to
+ // the original should be marked as unusable).
+ if dl.stale {
+ return nil, ErrSnapshotStale
+ }
+ // If the layer is being generated, ensure the requested hash has already been
+ // covered by the generator.
+ if dl.genMarker != nil && bytes.Compare(hash[:], dl.genMarker) > 0 {
+ return nil, ErrNotCoveredYet
+ }
+ // If we're in the disk layer, all diff layers missed
+ snapshotDirtyAccountMissMeter.Mark(1)
+
+ // Try to retrieve the account from the memory cache
+ if blob, found := dl.cache.HasGet(nil, hash[:]); found {
+ snapshotCleanAccountHitMeter.Mark(1)
+ snapshotCleanAccountReadMeter.Mark(int64(len(blob)))
+ return blob, nil
+ }
+ // Cache doesn't contain account, pull from disk and cache for later
+ blob := rawdb.ReadAccountSnapshot(dl.diskdb, hash)
+ dl.cache.Set(hash[:], blob)
+
+ snapshotCleanAccountMissMeter.Mark(1)
+ if n := len(blob); n > 0 {
+ snapshotCleanAccountWriteMeter.Mark(int64(n))
+ } else {
+ snapshotCleanAccountInexMeter.Mark(1)
+ }
+ return blob, nil
+}
+
+// Storage directly retrieves the storage data associated with a particular hash,
+// within a particular account.
+func (dl *diskLayer) Storage(accountHash, storageHash common.Hash) ([]byte, error) {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ // If the layer was flattened into, consider it invalid (any live reference to
+ // the original should be marked as unusable).
+ if dl.stale {
+ return nil, ErrSnapshotStale
+ }
+ key := append(accountHash[:], storageHash[:]...)
+
+ // If the layer is being generated, ensure the requested hash has already been
+ // covered by the generator.
+ if dl.genMarker != nil && bytes.Compare(key, dl.genMarker) > 0 {
+ return nil, ErrNotCoveredYet
+ }
+ // If we're in the disk layer, all diff layers missed
+ snapshotDirtyStorageMissMeter.Mark(1)
+
+ // Try to retrieve the storage slot from the memory cache
+ if blob, found := dl.cache.HasGet(nil, key); found {
+ snapshotCleanStorageHitMeter.Mark(1)
+ snapshotCleanStorageReadMeter.Mark(int64(len(blob)))
+ return blob, nil
+ }
+ // Cache doesn't contain storage slot, pull from disk and cache for later
+ blob := rawdb.ReadStorageSnapshot(dl.diskdb, accountHash, storageHash)
+ dl.cache.Set(key, blob)
+
+ snapshotCleanStorageMissMeter.Mark(1)
+ if n := len(blob); n > 0 {
+ snapshotCleanStorageWriteMeter.Mark(int64(n))
+ } else {
+ snapshotCleanStorageInexMeter.Mark(1)
+ }
+ return blob, nil
+}
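+
+// Note on key layout: accounts are cached under the 32-byte account hash and
+// storage slots under the 64-byte accountHash++storageHash concatenation,
+// which is also the shape genMarker takes while storage generation is in
+// progress (see generate.go).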
+
+// Update creates a new layer on top of the existing snapshot diff tree with
+// the specified data items. Note, the maps are retained by the method to avoid
+// copying everything.
+func (dl *diskLayer) Update(blockHash common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer {
+ return newDiffLayer(dl, blockHash, destructs, accounts, storage)
+}
diff --git a/core/state/snapshot/disklayer_test.go b/core/state/snapshot/disklayer_test.go
new file mode 100644
index 000000000..652e531b2
--- /dev/null
+++ b/core/state/snapshot/disklayer_test.go
@@ -0,0 +1,435 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/VictoriaMetrics/fastcache"
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
+ "github.com/tomochain/tomochain/ethdb/memorydb"
+)
+
+// reverse reverses the contents of a byte slice. It's used to update random
+// accounts with deterministic changes.
+func reverse(blob []byte) []byte {
+ res := make([]byte, len(blob))
+ for i, b := range blob {
+ res[len(blob)-1-i] = b
+ }
+ return res
+}
+
+// Tests that merging something into a disk layer persists it into the database
+// and invalidates any previously written and cached values.
+func TestDiskMerge(t *testing.T) {
+ // Create some accounts in the disk layer
+ db := memorydb.New()
+
+ var (
+ accNoModNoCache = common.Hash{0x1}
+ accNoModCache = common.Hash{0x2}
+ accModNoCache = common.Hash{0x3}
+ accModCache = common.Hash{0x4}
+ accDelNoCache = common.Hash{0x5}
+ accDelCache = common.Hash{0x6}
+ conNoModNoCache = common.Hash{0x7}
+ conNoModNoCacheSlot = common.Hash{0x70}
+ conNoModCache = common.Hash{0x8}
+ conNoModCacheSlot = common.Hash{0x80}
+ conModNoCache = common.Hash{0x9}
+ conModNoCacheSlot = common.Hash{0x90}
+ conModCache = common.Hash{0xa}
+ conModCacheSlot = common.Hash{0xa0}
+ conDelNoCache = common.Hash{0xb}
+ conDelNoCacheSlot = common.Hash{0xb0}
+ conDelCache = common.Hash{0xc}
+ conDelCacheSlot = common.Hash{0xc0}
+ conNukeNoCache = common.Hash{0xd}
+ conNukeNoCacheSlot = common.Hash{0xd0}
+ conNukeCache = common.Hash{0xe}
+ conNukeCacheSlot = common.Hash{0xe0}
+ baseRoot = randomHash()
+ diffRoot = randomHash()
+ )
+
+ rawdb.WriteAccountSnapshot(db, accNoModNoCache, accNoModNoCache[:])
+ rawdb.WriteAccountSnapshot(db, accNoModCache, accNoModCache[:])
+ rawdb.WriteAccountSnapshot(db, accModNoCache, accModNoCache[:])
+ rawdb.WriteAccountSnapshot(db, accModCache, accModCache[:])
+ rawdb.WriteAccountSnapshot(db, accDelNoCache, accDelNoCache[:])
+ rawdb.WriteAccountSnapshot(db, accDelCache, accDelCache[:])
+
+ rawdb.WriteAccountSnapshot(db, conNoModNoCache, conNoModNoCache[:])
+ rawdb.WriteStorageSnapshot(db, conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:])
+ rawdb.WriteAccountSnapshot(db, conNoModCache, conNoModCache[:])
+ rawdb.WriteStorageSnapshot(db, conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+ rawdb.WriteAccountSnapshot(db, conModNoCache, conModNoCache[:])
+ rawdb.WriteStorageSnapshot(db, conModNoCache, conModNoCacheSlot, conModNoCacheSlot[:])
+ rawdb.WriteAccountSnapshot(db, conModCache, conModCache[:])
+ rawdb.WriteStorageSnapshot(db, conModCache, conModCacheSlot, conModCacheSlot[:])
+ rawdb.WriteAccountSnapshot(db, conDelNoCache, conDelNoCache[:])
+ rawdb.WriteStorageSnapshot(db, conDelNoCache, conDelNoCacheSlot, conDelNoCacheSlot[:])
+ rawdb.WriteAccountSnapshot(db, conDelCache, conDelCache[:])
+ rawdb.WriteStorageSnapshot(db, conDelCache, conDelCacheSlot, conDelCacheSlot[:])
+
+ rawdb.WriteAccountSnapshot(db, conNukeNoCache, conNukeNoCache[:])
+ rawdb.WriteStorageSnapshot(db, conNukeNoCache, conNukeNoCacheSlot, conNukeNoCacheSlot[:])
+ rawdb.WriteAccountSnapshot(db, conNukeCache, conNukeCache[:])
+ rawdb.WriteStorageSnapshot(db, conNukeCache, conNukeCacheSlot, conNukeCacheSlot[:])
+
+ rawdb.WriteSnapshotRoot(db, baseRoot)
+
+ // Create a disk layer based on the above and cache in some data
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ baseRoot: &diskLayer{
+ diskdb: db,
+ cache: fastcache.New(500 * 1024),
+ root: baseRoot,
+ },
+ },
+ }
+ base := snaps.Snapshot(baseRoot)
+ base.AccountRLP(accNoModCache)
+ base.AccountRLP(accModCache)
+ base.AccountRLP(accDelCache)
+ base.Storage(conNoModCache, conNoModCacheSlot)
+ base.Storage(conModCache, conModCacheSlot)
+ base.Storage(conDelCache, conDelCacheSlot)
+ base.Storage(conNukeCache, conNukeCacheSlot)
+
+ // Modify or delete some accounts, flatten everything onto disk
+ if err := snaps.Update(diffRoot, baseRoot, map[common.Hash]struct{}{
+ accDelNoCache: struct{}{},
+ accDelCache: struct{}{},
+ conNukeNoCache: struct{}{},
+ conNukeCache: struct{}{},
+ }, map[common.Hash][]byte{
+ accModNoCache: reverse(accModNoCache[:]),
+ accModCache: reverse(accModCache[:]),
+ }, map[common.Hash]map[common.Hash][]byte{
+ conModNoCache: {conModNoCacheSlot: reverse(conModNoCacheSlot[:])},
+ conModCache: {conModCacheSlot: reverse(conModCacheSlot[:])},
+ conDelNoCache: {conDelNoCacheSlot: nil},
+ conDelCache: {conDelCacheSlot: nil},
+ }); err != nil {
+ t.Fatalf("failed to update snapshot tree: %v", err)
+ }
+ if err := snaps.Cap(diffRoot, 0); err != nil {
+ t.Fatalf("failed to flatten snapshot tree: %v", err)
+ }
+ // Retrieve all the data through the disk layer and validate it
+ base = snaps.Snapshot(diffRoot)
+ if _, ok := base.(*diskLayer); !ok {
+ t.Fatalf("update not flattend into the disk layer")
+ }
+
+ // assertAccount ensures that an account matches the given blob.
+ assertAccount := func(account common.Hash, data []byte) {
+ t.Helper()
+ blob, err := base.AccountRLP(account)
+ if err != nil {
+ t.Errorf("account access (%x) failed: %v", account, err)
+ } else if !bytes.Equal(blob, data) {
+ t.Errorf("account access (%x) mismatch: have %x, want %x", account, blob, data)
+ }
+ }
+ assertAccount(accNoModNoCache, accNoModNoCache[:])
+ assertAccount(accNoModCache, accNoModCache[:])
+ assertAccount(accModNoCache, reverse(accModNoCache[:]))
+ assertAccount(accModCache, reverse(accModCache[:]))
+ assertAccount(accDelNoCache, nil)
+ assertAccount(accDelCache, nil)
+
+ // assertStorage ensures that a storage slot matches the given blob.
+ assertStorage := func(account common.Hash, slot common.Hash, data []byte) {
+ t.Helper()
+ blob, err := base.Storage(account, slot)
+ if err != nil {
+ t.Errorf("storage access (%x:%x) failed: %v", account, slot, err)
+ } else if !bytes.Equal(blob, data) {
+ t.Errorf("storage access (%x:%x) mismatch: have %x, want %x", account, slot, blob, data)
+ }
+ }
+ assertStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:])
+ assertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+ assertStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:]))
+ assertStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:]))
+ assertStorage(conDelNoCache, conDelNoCacheSlot, nil)
+ assertStorage(conDelCache, conDelCacheSlot, nil)
+ assertStorage(conNukeNoCache, conNukeNoCacheSlot, nil)
+ assertStorage(conNukeCache, conNukeCacheSlot, nil)
+
+ // Retrieve all the data directly from the database and validate it
+
+ // assertDatabaseAccount ensures that an account from the database matches the given blob.
+ assertDatabaseAccount := func(account common.Hash, data []byte) {
+ t.Helper()
+ if blob := rawdb.ReadAccountSnapshot(db, account); !bytes.Equal(blob, data) {
+ t.Errorf("account database access (%x) mismatch: have %x, want %x", account, blob, data)
+ }
+ }
+ assertDatabaseAccount(accNoModNoCache, accNoModNoCache[:])
+ assertDatabaseAccount(accNoModCache, accNoModCache[:])
+ assertDatabaseAccount(accModNoCache, reverse(accModNoCache[:]))
+ assertDatabaseAccount(accModCache, reverse(accModCache[:]))
+ assertDatabaseAccount(accDelNoCache, nil)
+ assertDatabaseAccount(accDelCache, nil)
+
+ // assertDatabaseStorage ensures that a storage slot from the database matches the given blob.
+ assertDatabaseStorage := func(account common.Hash, slot common.Hash, data []byte) {
+ t.Helper()
+ if blob := rawdb.ReadStorageSnapshot(db, account, slot); !bytes.Equal(blob, data) {
+ t.Errorf("storage database access (%x:%x) mismatch: have %x, want %x", account, slot, blob, data)
+ }
+ }
+ assertDatabaseStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:])
+ assertDatabaseStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+ assertDatabaseStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:]))
+ assertDatabaseStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:]))
+ assertDatabaseStorage(conDelNoCache, conDelNoCacheSlot, nil)
+ assertDatabaseStorage(conDelCache, conDelCacheSlot, nil)
+ assertDatabaseStorage(conNukeNoCache, conNukeNoCacheSlot, nil)
+ assertDatabaseStorage(conNukeCache, conNukeCacheSlot, nil)
+}
+
+// Tests that merging something into a disk layer persists it into the database
+// and invalidates any previously written and cached values, discarding anything
+// after the in-progress generation marker.
+func TestDiskPartialMerge(t *testing.T) {
+ // Iterate the test many times to ensure we pick various internal orderings
+ // for the data slots as well as the progress marker.
+ for i := 0; i < 1024; i++ {
+ // Create some accounts in the disk layer
+ db := memorydb.New()
+
+ var (
+ accNoModNoCache = randomHash()
+ accNoModCache = randomHash()
+ accModNoCache = randomHash()
+ accModCache = randomHash()
+ accDelNoCache = randomHash()
+ accDelCache = randomHash()
+ conNoModNoCache = randomHash()
+ conNoModNoCacheSlot = randomHash()
+ conNoModCache = randomHash()
+ conNoModCacheSlot = randomHash()
+ conModNoCache = randomHash()
+ conModNoCacheSlot = randomHash()
+ conModCache = randomHash()
+ conModCacheSlot = randomHash()
+ conDelNoCache = randomHash()
+ conDelNoCacheSlot = randomHash()
+ conDelCache = randomHash()
+ conDelCacheSlot = randomHash()
+ conNukeNoCache = randomHash()
+ conNukeNoCacheSlot = randomHash()
+ conNukeCache = randomHash()
+ conNukeCacheSlot = randomHash()
+ baseRoot = randomHash()
+ diffRoot = randomHash()
+ genMarker = append(randomHash().Bytes(), randomHash().Bytes()...)
+ )
+
+ // insertAccount injects an account into the database if it's covered by
+ // the generator marker (at or before it), drops the op otherwise. This is
+ // needed to seed the database with a valid starting snapshot.
+ insertAccount := func(account common.Hash, data []byte) {
+ if bytes.Compare(account[:], genMarker) <= 0 {
+ rawdb.WriteAccountSnapshot(db, account, data[:])
+ }
+ }
+ insertAccount(accNoModNoCache, accNoModNoCache[:])
+ insertAccount(accNoModCache, accNoModCache[:])
+ insertAccount(accModNoCache, accModNoCache[:])
+ insertAccount(accModCache, accModCache[:])
+ insertAccount(accDelNoCache, accDelNoCache[:])
+ insertAccount(accDelCache, accDelCache[:])
+
+ // insertStorage injects a storage slot into the database if it's covered
+ // by the generator marker (at or before it), drops the op otherwise. This
+ // is needed to seed the database with a valid starting snapshot.
+ insertStorage := func(account common.Hash, slot common.Hash, data []byte) {
+ if bytes.Compare(append(account[:], slot[:]...), genMarker) <= 0 {
+ rawdb.WriteStorageSnapshot(db, account, slot, data[:])
+ }
+ }
+ insertAccount(conNoModNoCache, conNoModNoCache[:])
+ insertStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:])
+ insertAccount(conNoModCache, conNoModCache[:])
+ insertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+ insertAccount(conModNoCache, conModNoCache[:])
+ insertStorage(conModNoCache, conModNoCacheSlot, conModNoCacheSlot[:])
+ insertAccount(conModCache, conModCache[:])
+ insertStorage(conModCache, conModCacheSlot, conModCacheSlot[:])
+ insertAccount(conDelNoCache, conDelNoCache[:])
+ insertStorage(conDelNoCache, conDelNoCacheSlot, conDelNoCacheSlot[:])
+ insertAccount(conDelCache, conDelCache[:])
+ insertStorage(conDelCache, conDelCacheSlot, conDelCacheSlot[:])
+
+ insertAccount(conNukeNoCache, conNukeNoCache[:])
+ insertStorage(conNukeNoCache, conNukeNoCacheSlot, conNukeNoCacheSlot[:])
+ insertAccount(conNukeCache, conNukeCache[:])
+ insertStorage(conNukeCache, conNukeCacheSlot, conNukeCacheSlot[:])
+
+ rawdb.WriteSnapshotRoot(db, baseRoot)
+
+ // Create a disk layer based on the above using a random progress marker
+ // and cache in some data.
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ baseRoot: &diskLayer{
+ diskdb: db,
+ cache: fastcache.New(500 * 1024),
+ root: baseRoot,
+ },
+ },
+ }
+ snaps.layers[baseRoot].(*diskLayer).genMarker = genMarker
+ base := snaps.Snapshot(baseRoot)
+
+ // assertAccount ensures that an account matches the given blob if it's
+ // already covered by the disk snapshot, and errors out otherwise.
+ assertAccount := func(account common.Hash, data []byte) {
+ t.Helper()
+ blob, err := base.AccountRLP(account)
+ if bytes.Compare(account[:], genMarker) > 0 && err != ErrNotCoveredYet {
+ t.Fatalf("test %d: post-marker (%x) account access (%x) succeeded: %x", i, genMarker, account, blob)
+ }
+ if bytes.Compare(account[:], genMarker) <= 0 && !bytes.Equal(blob, data) {
+ t.Fatalf("test %d: pre-marker (%x) account access (%x) mismatch: have %x, want %x", i, genMarker, account, blob, data)
+ }
+ }
+ assertAccount(accNoModCache, accNoModCache[:])
+ assertAccount(accModCache, accModCache[:])
+ assertAccount(accDelCache, accDelCache[:])
+
+ // assertStorage ensures that a storage slot matches the given blob if
+ // it's already covered by the disk snapshot, and errors out otherwise.
+ assertStorage := func(account common.Hash, slot common.Hash, data []byte) {
+ t.Helper()
+ blob, err := base.Storage(account, slot)
+ if bytes.Compare(append(account[:], slot[:]...), genMarker) > 0 && err != ErrNotCoveredYet {
+ t.Fatalf("test %d: post-marker (%x) storage access (%x:%x) succeeded: %x", i, genMarker, account, slot, blob)
+ }
+ if bytes.Compare(append(account[:], slot[:]...), genMarker) <= 0 && !bytes.Equal(blob, data) {
+ t.Fatalf("test %d: pre-marker (%x) storage access (%x:%x) mismatch: have %x, want %x", i, genMarker, account, slot, blob, data)
+ }
+ }
+ assertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+ assertStorage(conModCache, conModCacheSlot, conModCacheSlot[:])
+ assertStorage(conDelCache, conDelCacheSlot, conDelCacheSlot[:])
+ assertStorage(conNukeCache, conNukeCacheSlot, conNukeCacheSlot[:])
+
+ // Modify or delete some accounts, flatten everything onto disk
+ if err := snaps.Update(diffRoot, baseRoot, map[common.Hash]struct{}{
+ accDelNoCache: struct{}{},
+ accDelCache: struct{}{},
+ conNukeNoCache: struct{}{},
+ conNukeCache: struct{}{},
+ }, map[common.Hash][]byte{
+ accModNoCache: reverse(accModNoCache[:]),
+ accModCache: reverse(accModCache[:]),
+ }, map[common.Hash]map[common.Hash][]byte{
+ conModNoCache: {conModNoCacheSlot: reverse(conModNoCacheSlot[:])},
+ conModCache: {conModCacheSlot: reverse(conModCacheSlot[:])},
+ conDelNoCache: {conDelNoCacheSlot: nil},
+ conDelCache: {conDelCacheSlot: nil},
+ }); err != nil {
+ t.Fatalf("test %d: failed to update snapshot tree: %v", i, err)
+ }
+ if err := snaps.Cap(diffRoot, 0); err != nil {
+ t.Fatalf("test %d: failed to flatten snapshot tree: %v", i, err)
+ }
+ // Retrieve all the data through the disk layer and validate it
+ base = snaps.Snapshot(diffRoot)
+ if _, ok := base.(*diskLayer); !ok {
+ t.Fatalf("test %d: update not flattend into the disk layer", i)
+ }
+ assertAccount(accNoModNoCache, accNoModNoCache[:])
+ assertAccount(accNoModCache, accNoModCache[:])
+ assertAccount(accModNoCache, reverse(accModNoCache[:]))
+ assertAccount(accModCache, reverse(accModCache[:]))
+ assertAccount(accDelNoCache, nil)
+ assertAccount(accDelCache, nil)
+
+ assertStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:])
+ assertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+ assertStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:]))
+ assertStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:]))
+ assertStorage(conDelNoCache, conDelNoCacheSlot, nil)
+ assertStorage(conDelCache, conDelCacheSlot, nil)
+ assertStorage(conNukeNoCache, conNukeNoCacheSlot, nil)
+ assertStorage(conNukeCache, conNukeCacheSlot, nil)
+
+ // Retrieve all the data directly from the database and validate it
+
+ // assertDatabaseAccount ensures that an account inside the database matches
+ // the given blob if it's already covered by the disk snapshot, and does not
+ // exist otherwise.
+ assertDatabaseAccount := func(account common.Hash, data []byte) {
+ t.Helper()
+ blob := rawdb.ReadAccountSnapshot(db, account)
+ if bytes.Compare(account[:], genMarker) > 0 && blob != nil {
+ t.Fatalf("test %d: post-marker (%x) account database access (%x) succeeded: %x", i, genMarker, account, blob)
+ }
+ if bytes.Compare(account[:], genMarker) <= 0 && !bytes.Equal(blob, data) {
+ t.Fatalf("test %d: pre-marker (%x) account database access (%x) mismatch: have %x, want %x", i, genMarker, account, blob, data)
+ }
+ }
+ assertDatabaseAccount(accNoModNoCache, accNoModNoCache[:])
+ assertDatabaseAccount(accNoModCache, accNoModCache[:])
+ assertDatabaseAccount(accModNoCache, reverse(accModNoCache[:]))
+ assertDatabaseAccount(accModCache, reverse(accModCache[:]))
+ assertDatabaseAccount(accDelNoCache, nil)
+ assertDatabaseAccount(accDelCache, nil)
+
+ // assertDatabaseStorage ensures that a storage slot inside the database
+ // matches the given blob if it's already covered by the disk snapshot,
+ // and does not exist otherwise.
+ assertDatabaseStorage := func(account common.Hash, slot common.Hash, data []byte) {
+ t.Helper()
+ blob := rawdb.ReadStorageSnapshot(db, account, slot)
+ if bytes.Compare(append(account[:], slot[:]...), genMarker) > 0 && blob != nil {
+ t.Fatalf("test %d: post-marker (%x) storage database access (%x:%x) succeeded: %x", i, genMarker, account, slot, blob)
+ }
+ if bytes.Compare(append(account[:], slot[:]...), genMarker) <= 0 && !bytes.Equal(blob, data) {
+ t.Fatalf("test %d: pre-marker (%x) storage database access (%x:%x) mismatch: have %x, want %x", i, genMarker, account, slot, blob, data)
+ }
+ }
+ assertDatabaseStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:])
+ assertDatabaseStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+ assertDatabaseStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:]))
+ assertDatabaseStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:]))
+ assertDatabaseStorage(conDelNoCache, conDelNoCacheSlot, nil)
+ assertDatabaseStorage(conDelCache, conDelCacheSlot, nil)
+ assertDatabaseStorage(conNukeNoCache, conNukeNoCacheSlot, nil)
+ assertDatabaseStorage(conNukeCache, conNukeCacheSlot, nil)
+ }
+}
+
+// Tests that merging something into a disk layer persists it into the database
+// and invalidates any previously written and cached values, discarding anything
+// after the in-progress generation marker.
+//
+// This test case is a tiny specialized case of TestDiskPartialMerge, which tests
+// some very specific corner cases that random tests won't ever trigger.
+func TestDiskMidAccountPartialMerge(t *testing.T) {
+}
diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go
new file mode 100644
index 000000000..ea5b59a72
--- /dev/null
+++ b/core/state/snapshot/generate.go
@@ -0,0 +1,286 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math/big"
+ "time"
+
+ "github.com/VictoriaMetrics/fastcache"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/common/math"
+ "github.com/tomochain/tomochain/core/rawdb"
+ "github.com/tomochain/tomochain/core/types"
+ "github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/ethdb"
+ "github.com/tomochain/tomochain/log"
+ "github.com/tomochain/tomochain/rlp"
+ "github.com/tomochain/tomochain/trie"
+)
+
+var (
+ // emptyRoot is the known root hash of an empty trie.
+ emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
+
+ // emptyCode is the known hash of the empty EVM bytecode.
+ emptyCode = crypto.Keccak256Hash(nil)
+)
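+
+// For reference, both constants derive from keccak256 (a quick sanity check):
+//
+//	crypto.Keccak256Hash(nil)          // == emptyCode
+//	crypto.Keccak256Hash([]byte{0x80}) // == emptyRoot (RLP of the empty string)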
+
+// generatorStats is a collection of statistics gathered by the snapshot generator
+// for logging purposes.
+type generatorStats struct {
+ origin uint64 // Origin prefix where generation started
+ start time.Time // Timestamp when generation started
+ accounts uint64 // Number of accounts indexed
+ slots uint64 // Number of storage slots indexed
+ storage common.StorageSize // Account and storage slot size
+}
+
+// Log creates a contextual log with the given message and the context pulled
+// from the internally maintained statistics.
+func (gs *generatorStats) Log(msg string, marker []byte) {
+ var ctx []interface{}
+
+ // Figure out whether we're after or within an account
+ switch len(marker) {
+ case common.HashLength:
+ ctx = append(ctx, []interface{}{"at", common.BytesToHash(marker)}...)
+ case 2 * common.HashLength:
+ ctx = append(ctx, []interface{}{
+ "in", common.BytesToHash(marker[:common.HashLength]),
+ "at", common.BytesToHash(marker[common.HashLength:]),
+ }...)
+ }
+ // Add the usual measurements
+ ctx = append(ctx, []interface{}{
+ "accounts", gs.accounts,
+ "slots", gs.slots,
+ "storage", gs.storage,
+ "elapsed", common.PrettyDuration(time.Since(gs.start)),
+ }...)
+ // Calculate the estimated indexing time based on current stats
+ if len(marker) > 0 {
+ if done := binary.BigEndian.Uint64(marker[:8]) - gs.origin; done > 0 {
+ left := math.MaxUint64 - binary.BigEndian.Uint64(marker[:8])
+
+ speed := done/uint64(time.Since(gs.start)/time.Millisecond+1) + 1 // +1 to avoid division by zero
+ ctx = append(ctx, []interface{}{
+ "eta", common.PrettyDuration(time.Duration(left/speed) * time.Millisecond),
+ }...)
+ }
+ }
+ log.Info(msg, ctx...)
+}
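+
+// Worked example of the ETA estimate above (hedged): if generation started at
+// origin 0 and the marker's first 8 bytes read half of the uint64 keyspace
+// after 10 minutes, then done == left, speed ≈ done/600000ms, and the logged
+// eta comes out at roughly another 10 minutes.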
+
+// generateSnapshot regenerates a brand new snapshot based on an existing state
+// database and head block asynchronously. The snapshot is returned immediately
+// and generation is continued in the background until done.
+func generateSnapshot(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root common.Hash) *diskLayer {
+ // Create a new disk layer with an initialized state marker at zero
+ var (
+ stats = &generatorStats{start: time.Now()}
+ batch = diskdb.NewBatch()
+ genMarker = []byte{} // Initialized but empty!
+ )
+ // Persist the snapshot root through the batch so restarts detect the base layer
+ rawdb.WriteSnapshotRoot(batch, root)
+ if err := batch.Write(); err != nil {
+ log.Crit("Failed to write initialized state marker", "err", err)
+ }
+ base := &diskLayer{
+ diskdb: diskdb,
+ triedb: triedb,
+ root: root,
+ cache: fastcache.New(cache * 1024 * 1024),
+ genMarker: genMarker, // Initialized but empty!
+ genPending: make(chan struct{}),
+ genAbort: make(chan chan *generatorStats),
+ }
+ go base.generate(stats)
+ log.Debug("Start snapshot generation", "root", root)
+ return base
+}
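+
+// A hedged sketch of how a caller stops an in-flight generation, mirroring
+// the abort protocol implemented by generate below:
+//
+//	stop := make(chan *generatorStats)
+//	base.genAbort <- stop
+//	stats := <-stop // nil if generation had already completed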
+
+// journalProgress persists the generator stats into the database to resume later.
+func journalProgress(db ethdb.KeyValueWriter, marker []byte, stats *generatorStats) {
+ // Write out the generator marker. Note it's a standalone disk layer generator
+ // which is not mixed with journal. It's ok if the generator is persisted while
+ // journal is not.
+ entry := journalGenerator{
+ Done: marker == nil,
+ Marker: marker,
+ }
+ if stats != nil {
+ entry.Accounts = stats.accounts
+ entry.Slots = stats.slots
+ entry.Storage = uint64(stats.storage)
+ }
+ blob, err := rlp.EncodeToBytes(entry)
+ if err != nil {
+ panic(err) // Cannot happen, here to catch dev errors
+ }
+ var logstr string
+ switch {
+ case marker == nil:
+ logstr = "done"
+ case bytes.Equal(marker, []byte{}):
+ logstr = "empty"
+ case len(marker) == common.HashLength:
+ logstr = fmt.Sprintf("%#x", marker)
+ default:
+ logstr = fmt.Sprintf("%#x:%#x", marker[:common.HashLength], marker[common.HashLength:])
+ }
+ log.Debug("Journalled generator progress", "progress", logstr)
+ rawdb.WriteSnapshotGenerator(db, blob)
+}
+
+// generate is a background thread that iterates over the state and storage tries,
+// constructing the state snapshot. All the arguments are purely for statistics
+// gathering and logging, since the method surfs the blocks as they arrive, often
+// being restarted.
+func (dl *diskLayer) generate(stats *generatorStats) {
+ // Create an account and state iterator pointing to the current generator marker
+ accTrie, err := trie.NewSecure(dl.root, dl.triedb)
+ if err != nil {
+ // The account trie is missing (GC), surf the chain until one becomes available
+ stats.Log("Trie missing, state snapshotting paused", dl.genMarker)
+
+ abort := <-dl.genAbort
+ abort <- stats
+ return
+ }
+ stats.Log("Resuming state snapshot generation", dl.genMarker)
+
+ var accMarker []byte
+ if len(dl.genMarker) > 0 { // []byte{} is the start, use nil for that
+ accMarker = dl.genMarker[:common.HashLength]
+ }
+ accIt := trie.NewIterator(accTrie.NodeIterator(accMarker))
+ batch := dl.diskdb.NewBatch()
+
+ // Iterate from the previous marker and continue generating the state snapshot
+ logged := time.Now()
+ for accIt.Next() {
+ // Retrieve the current account and flatten it into the internal format
+ accountHash := common.BytesToHash(accIt.Key)
+
+ var acc struct {
+ Nonce uint64
+ Balance *big.Int
+ Root common.Hash
+ CodeHash []byte
+ }
+ if err := rlp.DecodeBytes(accIt.Value, &acc); err != nil {
+ log.Crit("Invalid account encountered during snapshot creation", "err", err)
+ }
+ data := types.SlimAccountRLP(acc)
+
+ // If the account is not yet in-progress, write it out
+ if accMarker == nil || !bytes.Equal(accountHash[:], accMarker) {
+ rawdb.WriteAccountSnapshot(batch, accountHash, data)
+ stats.storage += common.StorageSize(1 + common.HashLength + len(data))
+ stats.accounts++
+ }
+ // If we've exceeded our batch allowance or termination was requested, flush to disk
+ var abort chan *generatorStats
+ select {
+ case abort = <-dl.genAbort:
+ default:
+ }
+ if batch.ValueSize() > ethdb.IdealBatchSize || abort != nil {
+ // Only write and set the marker if we actually did something useful
+ if batch.ValueSize() > 0 {
+ batch.Write()
+ batch.Reset()
+
+ dl.lock.Lock()
+ dl.genMarker = accountHash[:]
+ dl.lock.Unlock()
+ }
+ if abort != nil {
+ stats.Log("Aborting state snapshot generation", accountHash[:])
+ abort <- stats
+ return
+ }
+ }
+ // If the account is in-progress, continue where we left off (otherwise iterate all)
+ if acc.Root != emptyRoot {
+ storeTrie, err := trie.NewSecure(acc.Root, dl.triedb)
+ if err != nil {
+ log.Crit("Storage trie inaccessible for snapshot generation", "err", err)
+ }
+ var storeMarker []byte
+ if accMarker != nil && bytes.Equal(accountHash[:], accMarker) && len(dl.genMarker) > common.HashLength {
+ storeMarker = dl.genMarker[common.HashLength:]
+ }
+ storeIt := trie.NewIterator(storeTrie.NodeIterator(storeMarker))
+ for storeIt.Next() {
+ rawdb.WriteStorageSnapshot(batch, accountHash, common.BytesToHash(storeIt.Key), storeIt.Value)
+ stats.storage += common.StorageSize(1 + 2*common.HashLength + len(storeIt.Value))
+ stats.slots++
+
+ // If we've exceeded our batch allowance or termination was requested, flush to disk
+ var abort chan *generatorStats
+ select {
+ case abort = <-dl.genAbort:
+ default:
+ }
+ if batch.ValueSize() > ethdb.IdealBatchSize || abort != nil {
+ // Only write and set the marker if we actually did something useful
+ if batch.ValueSize() > 0 {
+ batch.Write()
+ batch.Reset()
+
+ dl.lock.Lock()
+ dl.genMarker = append(accountHash[:], storeIt.Key...)
+ dl.lock.Unlock()
+ }
+ if abort != nil {
+ stats.Log("Aborting state snapshot generation", append(accountHash[:], storeIt.Key...))
+ abort <- stats
+ return
+ }
+ }
+ }
+ }
+ if time.Since(logged) > 8*time.Second {
+ stats.Log("Generating state snapshot", accIt.Key)
+ logged = time.Now()
+ }
+ // Some account processed, unmark the marker
+ accMarker = nil
+ }
+ // Snapshot fully generated, set the marker to nil
+ if batch.ValueSize() > 0 {
+ batch.Write()
+ }
+ log.Info("Generated state snapshot", "accounts", stats.accounts, "slots", stats.slots,
+ "storage", stats.storage, "elapsed", common.PrettyDuration(time.Since(stats.start)))
+
+ dl.lock.Lock()
+ dl.genMarker = nil
+ close(dl.genPending)
+ dl.lock.Unlock()
+
+ // Someone will be looking for us, wait it out
+ abort := <-dl.genAbort
+ abort <- nil
+}
diff --git a/core/state/snapshot/iterator.go b/core/state/snapshot/iterator.go
new file mode 100644
index 000000000..b62fb30e3
--- /dev/null
+++ b/core/state/snapshot/iterator.go
@@ -0,0 +1,221 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
+ "github.com/tomochain/tomochain/ethdb"
+)
+
+// Iterator is an iterator to step over all the accounts or the specific
+// storage in a snapshot which may or may not be composed of multiple layers.
+type Iterator interface {
+ // Next steps the iterator forward one element, returning false if exhausted,
+ // or an error if iteration failed for some reason (e.g. root being iterated
+ // becomes stale and garbage collected).
+ Next() bool
+
+ // Error returns any failure that occurred during iteration, which might have
+ // caused a premature iteration exit (e.g. snapshot stack becoming stale).
+ Error() error
+
+ // Hash returns the hash of the account or storage slot the iterator is
+ // currently at.
+ Hash() common.Hash
+
+ // Release releases associated resources. Release should always succeed and
+ // can be called multiple times without causing error.
+ Release()
+}
+
+// AccountIterator is an iterator to step over all the accounts in a snapshot,
+// which may or may not be composed of multiple layers.
+type AccountIterator interface {
+ Iterator
+
+ // Account returns the RLP encoded slim account the iterator is currently at.
+ // An error will be returned if the iterator becomes invalid
+ Account() []byte
+}
+
+// diffAccountIterator is an account iterator that steps over the accounts (both
+// live and deleted) contained within a single diff layer. Higher order iterators
+// will use the deleted accounts to skip deeper iterators.
+type diffAccountIterator struct {
+ // curHash is the current hash the iterator is positioned on. The field is
+ // explicitly tracked since the referenced diff layer might go stale after
+ // the iterator was positioned and we don't want to fail accessing the old
+ // hash as long as the iterator is not touched any more.
+ curHash common.Hash
+
+ layer *diffLayer // Live layer to retrieve values from
+ keys []common.Hash // Keys left in the layer to iterate
+ fail error // Any failures encountered (stale)
+}
+
+// StorageIterator is an iterator to step over the specific storage in a snapshot,
+// which may or may not be composed of multiple layers.
+type StorageIterator interface {
+ Iterator
+
+ // Slot returns the storage slot the iterator is currently at. An error will
+ // be returned if the iterator becomes invalid
+ Slot() []byte
+}
+
+// AccountIterator creates an account iterator over a single diff layer.
+func (dl *diffLayer) AccountIterator(seek common.Hash) AccountIterator {
+ // Seek out the requested starting account
+ hashes := dl.AccountList()
+ index := sort.Search(len(hashes), func(i int) bool {
+ return bytes.Compare(seek[:], hashes[i][:]) < 0
+ })
+ // Assemble and return the already-seeked iterator
+ return &diffAccountIterator{
+ layer: dl,
+ keys: hashes[index:],
+ }
+}
+
+// Next steps the iterator forward one element, returning false if exhausted.
+func (it *diffAccountIterator) Next() bool {
+ // If the iterator was already stale, consider it a programmer error. Although
+ // we could just return false here, triggering this path would probably mean
+ // somebody forgot to check for Error, so let's blow up instead of undefined
+ // behavior that's hard to debug.
+ if it.fail != nil {
+ panic(fmt.Sprintf("called Next of failed iterator: %v", it.fail))
+ }
+ // Stop iterating if all keys were exhausted
+ if len(it.keys) == 0 {
+ return false
+ }
+ if it.layer.Stale() {
+ it.fail, it.keys = ErrSnapshotStale, nil
+ return false
+ }
+ // Iterator seems to be still alive, retrieve and cache the live hash
+ it.curHash = it.keys[0]
+ // key cached, shift the iterator and notify the user of success
+ it.keys = it.keys[1:]
+ return true
+}
+
+// Error returns any failure that occurred during iteration, which might have
+// caused a premature iteration exit (e.g. snapshot stack becoming stale).
+func (it *diffAccountIterator) Error() error {
+ return it.fail
+}
+
+// Hash returns the hash of the account the iterator is currently at.
+func (it *diffAccountIterator) Hash() common.Hash {
+ return it.curHash
+}
+
+// Account returns the RLP encoded slim account the iterator is currently at.
+// This method may _fail_ if the underlying layer has been flattened between
+// the call to Next and Account. That type of error will set it.fail.
+// This method assumes that flattening does not delete elements from
+// the accountdata mapping (writing nil into it is fine though), and will panic
+// if elements have been deleted.
+func (it *diffAccountIterator) Account() []byte {
+ it.layer.lock.RLock()
+ blob, ok := it.layer.accountData[it.curHash]
+ if !ok {
+ if _, ok := it.layer.destructSet[it.curHash]; ok {
+ it.layer.lock.RUnlock()
+ return nil
+ }
+ panic(fmt.Sprintf("iterator referenced non-existent account: %x", it.curHash))
+ }
+ it.layer.lock.RUnlock()
+ if it.layer.Stale() {
+ it.fail, it.keys = ErrSnapshotStale, nil
+ }
+ return blob
+}
+
+// Release is a noop for diff account iterators as there are no held resources.
+func (it *diffAccountIterator) Release() {}
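+
+// A typical (sketched) consumption loop for these iterators:
+//
+//	it := layer.AccountIterator(common.Hash{})
+//	defer it.Release()
+//	for it.Next() {
+//		process(it.Hash(), it.Account()) // process is hypothetical
+//	}
+//	if err := it.Error(); err != nil {
+//		// the layer went stale mid-iteration; retry on a fresh snapshot
+//	}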
+
+// diskAccountIterator is an account iterator that steps over the live accounts
+// contained within a disk layer.
+type diskAccountIterator struct {
+ layer *diskLayer
+ it ethdb.Iterator
+}
+
+// AccountIterator creates an account iterator over a disk layer.
+func (dl *diskLayer) AccountIterator(seek common.Hash) AccountIterator {
+ pos := common.TrimRightZeroes(seek[:])
+ return &diskAccountIterator{
+ layer: dl,
+ it: dl.diskdb.NewIterator(rawdb.SnapshotAccountPrefix, pos),
+ }
+}
+
+// Next steps the iterator forward one element, returning false if exhausted.
+func (it *diskAccountIterator) Next() bool {
+ // If the iterator was already exhausted, don't bother
+ if it.it == nil {
+ return false
+ }
+ // Try to advance the iterator and release it if we reached the end
+ for {
+ if !it.it.Next() || !bytes.HasPrefix(it.it.Key(), rawdb.SnapshotAccountPrefix) {
+ it.it.Release()
+ it.it = nil
+ return false
+ }
+ if len(it.it.Key()) == len(rawdb.SnapshotAccountPrefix)+common.HashLength {
+ break
+ }
+ }
+ return true
+}
+
+ // Error returns any failure that occurred during iteration, which might have
+ // caused a premature iteration exit (e.g. snapshot stack becoming stale).
+ //
+ // A disk iterator is released automatically once exhausted, in which case
+ // iteration completed without error.
+ func (it *diskAccountIterator) Error() error {
+ if it.it == nil {
+ return nil // Iterator is exhausted and released
+ }
+ return it.it.Error()
+ }
+
+// Hash returns the hash of the account the iterator is currently at.
+func (it *diskAccountIterator) Hash() common.Hash {
+ return common.BytesToHash(it.it.Key()) // The prefix will be truncated
+}
+
+// Account returns the RLP encoded slim account the iterator is currently at.
+func (it *diskAccountIterator) Account() []byte {
+ return it.it.Value()
+}
+
+// Release releases the database snapshot held during iteration.
+func (it *diskAccountIterator) Release() {
+ // The iterator is auto-released on exhaustion, so make sure it's still alive
+ if it.it != nil {
+ it.it.Release()
+ it.it = nil
+ }
+}
diff --git a/core/state/snapshot/iterator_binary.go b/core/state/snapshot/iterator_binary.go
new file mode 100644
index 000000000..d8df968ea
--- /dev/null
+++ b/core/state/snapshot/iterator_binary.go
@@ -0,0 +1,115 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+ // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+
+ "github.com/tomochain/tomochain/common"
+)
+
+ // binaryAccountIterator is a simplistic iterator to step over the accounts in
+ // a snapshot, which may or may not be composed of multiple layers. Performance-
+ // wise this iterator is slow; it's meant for cross-validating the fast one.
+type binaryAccountIterator struct {
+ a *diffAccountIterator
+ b AccountIterator
+ aDone bool
+ bDone bool
+ k common.Hash
+ fail error
+}
+
+// newBinaryAccountIterator creates a simplistic account iterator to step over
+ // all the accounts in a slow, but easily verifiable way.
+func (dl *diffLayer) newBinaryAccountIterator() AccountIterator {
+ parent, ok := dl.parent.(*diffLayer)
+ if !ok {
+ // parent is the disk layer
+ return dl.AccountIterator(common.Hash{})
+ }
+ l := &binaryAccountIterator{
+ a: dl.AccountIterator(common.Hash{}).(*diffAccountIterator),
+ b: parent.newBinaryAccountIterator(),
+ }
+ l.aDone = !l.a.Next()
+ l.bDone = !l.b.Next()
+ return l
+}
+
+ // Next steps the iterator forward one element, returning false if exhausted.
+ // Any failure (e.g. the root being iterated becoming stale and garbage
+ // collected) is surfaced through the Error method.
+func (it *binaryAccountIterator) Next() bool {
+ if it.aDone && it.bDone {
+ return false
+ }
+ nextB := it.b.Hash()
+first:
+ nextA := it.a.Hash()
+ if it.aDone {
+ it.bDone = !it.b.Next()
+ it.k = nextB
+ return true
+ }
+ if it.bDone {
+ it.aDone = !it.a.Next()
+ it.k = nextA
+ return true
+ }
+ if diff := bytes.Compare(nextA[:], nextB[:]); diff < 0 {
+ it.aDone = !it.a.Next()
+ it.k = nextA
+ return true
+ } else if diff == 0 {
+ // Now we need to advance one of them
+ it.aDone = !it.a.Next()
+ goto first
+ }
+ it.bDone = !it.b.Next()
+ it.k = nextB
+ return true
+}
+
+// Error returns any failure that occurred during iteration, which might have
+// caused a premature iteration exit (e.g. snapshot stack becoming stale).
+func (it *binaryAccountIterator) Error() error {
+ return it.fail
+}
+
+// Hash returns the hash of the account the iterator is currently at.
+func (it *binaryAccountIterator) Hash() common.Hash {
+ return it.k
+}
+
+// Account returns the RLP encoded slim account the iterator is currently at, or
+// nil if the iterated snapshot stack became stale (you can check Error after
+// to see if it failed or not).
+func (it *binaryAccountIterator) Account() []byte {
+ blob, err := it.a.layer.AccountRLP(it.k)
+ if err != nil {
+ it.fail = err
+ return nil
+ }
+ return blob
+}
+
+// Release recursively releases all the iterators in the stack.
+func (it *binaryAccountIterator) Release() {
+ it.a.Release()
+ it.b.Release()
+}
diff --git a/core/state/snapshot/iterator_fast.go b/core/state/snapshot/iterator_fast.go
new file mode 100644
index 000000000..afbe70c2b
--- /dev/null
+++ b/core/state/snapshot/iterator_fast.go
@@ -0,0 +1,302 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+ // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+
+ "github.com/tomochain/tomochain/common"
+)
+
+// weightedAccountIterator is an account iterator with an assigned weight. It is
+// used to prioritise which account is the correct one if multiple iterators find
+// the same one (modified in multiple consecutive blocks).
+type weightedAccountIterator struct {
+ it AccountIterator
+ priority int
+}
+
+// weightedAccountIterators is a set of iterators implementing the sort.Interface.
+type weightedAccountIterators []*weightedAccountIterator
+
+// Len implements sort.Interface, returning the number of active iterators.
+func (its weightedAccountIterators) Len() int { return len(its) }
+
+// Less implements sort.Interface, returning which of two iterators in the stack
+// is before the other.
+func (its weightedAccountIterators) Less(i, j int) bool {
+ // Order the iterators primarily by the account hashes
+ hashI := its[i].it.Hash()
+ hashJ := its[j].it.Hash()
+
+ switch bytes.Compare(hashI[:], hashJ[:]) {
+ case -1:
+ return true
+ case 1:
+ return false
+ }
+ // Same account in multiple layers, split by priority
+ return its[i].priority < its[j].priority
+}
+
+// Swap implements sort.Interface, swapping two entries in the iterator stack.
+func (its weightedAccountIterators) Swap(i, j int) {
+ its[i], its[j] = its[j], its[i]
+}
+
+// fastAccountIterator is a more optimized multi-layer iterator which maintains a
+// direct mapping of all iterators leading down to the bottom layer.
+type fastAccountIterator struct {
+ tree *Tree // Snapshot tree to reinitialize stale sub-iterators with
+ root common.Hash // Root hash to reinitialize stale sub-iterators through
+ curAccount []byte
+
+ iterators weightedAccountIterators
+ initiated bool
+ fail error
+}
+
+ // newFastAccountIterator creates a new hierarchical account iterator with one
+// element per diff layer. The returned combo iterator can be used to walk over
+// the entire snapshot diff stack simultaneously.
+func newFastAccountIterator(tree *Tree, root common.Hash, seek common.Hash) (AccountIterator, error) {
+ snap := tree.Snapshot(root)
+ if snap == nil {
+ return nil, fmt.Errorf("unknown snapshot: %x", root)
+ }
+ fi := &fastAccountIterator{
+ tree: tree,
+ root: root,
+ }
+ current := snap.(snapshot)
+ for depth := 0; current != nil; depth++ {
+ fi.iterators = append(fi.iterators, &weightedAccountIterator{
+ it: current.AccountIterator(seek),
+ priority: depth,
+ })
+ current = current.Parent()
+ }
+ fi.init()
+ return fi, nil
+}
+
+// init walks over all the iterators and resolves any clashes between them, after
+// which it prepares the stack for step-by-step iteration.
+func (fi *fastAccountIterator) init() {
+ // Track which account hashes the iterators are positioned on
+ var positioned = make(map[common.Hash]int)
+
+ // Position all iterators and track how many remain live
+ for i := 0; i < len(fi.iterators); i++ {
+ // Retrieve the first element and if it clashes with a previous iterator,
+ // advance either the current one or the old one. Repeat until nothing is
+ // clashing any more.
+ it := fi.iterators[i]
+ for {
+ // If the iterator is exhausted, drop it off the end
+ if !it.it.Next() {
+ it.it.Release()
+ last := len(fi.iterators) - 1
+
+ fi.iterators[i] = fi.iterators[last]
+ fi.iterators[last] = nil
+ fi.iterators = fi.iterators[:last]
+
+ i--
+ break
+ }
+ // The iterator is still alive, check for collisions with previous ones
+ hash := it.it.Hash()
+ if other, exist := positioned[hash]; !exist {
+ positioned[hash] = i
+ break
+ } else {
+ // Iterators collide, one needs to be progressed, use priority to
+ // determine which.
+ //
+ // This whole else-block can be avoided, if we instead
+ // do an initial priority-sort of the iterators. If we do that,
+ // then we'll only wind up here if a lower-priority (preferred) iterator
+ // has the same value, and then we will always just continue.
+ // However, it costs an extra sort, so it's probably not better
+ if fi.iterators[other].priority < it.priority {
+ // The 'it' should be progressed
+ continue
+ } else {
+ // The 'other' should be progressed, swap them
+ it = fi.iterators[other]
+ fi.iterators[other], fi.iterators[i] = fi.iterators[i], fi.iterators[other]
+ continue
+ }
+ }
+ }
+ }
+ // Re-sort the entire list
+ sort.Sort(fi.iterators)
+ fi.initiated = false
+}
+
+// Next steps the iterator forward one element, returning false if exhausted.
+func (fi *fastAccountIterator) Next() bool {
+ if len(fi.iterators) == 0 {
+ return false
+ }
+ if !fi.initiated {
+ // Don't forward first time -- we had to 'Next' once in order to
+ // do the sorting already
+ fi.initiated = true
+ fi.curAccount = fi.iterators[0].it.Account()
+ if innerErr := fi.iterators[0].it.Error(); innerErr != nil {
+ fi.fail = innerErr
+ return false
+ }
+ if fi.curAccount != nil {
+ return true
+ }
+ // Implicit else: we've hit a nil-account, and need to fall through to the
+ // loop below to land on something non-nil
+ }
+ // If an account is deleted in one of the layers, the key will still be there,
+ // but the actual value will be nil. However, the iterator should not
+ // export nil-values (but instead simply omit the key), so we need to loop
+ // here until we either
+ // - get a non-nil value,
+ // - hit an error,
+ // - or exhaust the iterator
+ for {
+ if !fi.next(0) {
+ return false // exhausted
+ }
+ fi.curAccount = fi.iterators[0].it.Account()
+ if innerErr := fi.iterators[0].it.Error(); innerErr != nil {
+ fi.fail = innerErr
+ return false // error
+ }
+ if fi.curAccount != nil {
+ break // non-nil value found
+ }
+ }
+ return true
+}
+
+// next handles the next operation internally and should be invoked when we know
+// that two elements in the list may have the same value.
+//
+// For example, if the iterated hashes become [2,3,5,5,8,9,10], then we should
+// invoke next(3), which will call Next on elem 3 (the second '5') and will
+// cascade along the list, applying the same operation if needed.
+func (fi *fastAccountIterator) next(idx int) bool {
+ // If this particular iterator got exhausted, remove it and return true (the
+ // next one is surely not exhausted yet, otherwise it would have been removed
+ // already).
+ if it := fi.iterators[idx].it; !it.Next() {
+ it.Release()
+
+ fi.iterators = append(fi.iterators[:idx], fi.iterators[idx+1:]...)
+ return len(fi.iterators) > 0
+ }
+ // If there's no one left to cascade into, return
+ if idx == len(fi.iterators)-1 {
+ return true
+ }
+ // We next-ed the iterator at 'idx', now we may have to re-sort that element
+ var (
+ cur, next = fi.iterators[idx], fi.iterators[idx+1]
+ curHash, nextHash = cur.it.Hash(), next.it.Hash()
+ )
+ if diff := bytes.Compare(curHash[:], nextHash[:]); diff < 0 {
+ // It is still in correct place
+ return true
+ } else if diff == 0 && cur.priority < next.priority {
+ // So still in correct place, but we need to iterate on the next
+ fi.next(idx + 1)
+ return true
+ }
+ // At this point, the iterator is in the wrong location, but the remaining
+ // list is sorted. Find out where to move the item.
+ clash := -1
+ index := sort.Search(len(fi.iterators), func(n int) bool {
+ // The iterator always advances forward, so anything before the old slot
+ // is known to be behind us, so just skip them altogether. This actually
+ // is an important clause since the sort order got invalidated.
+ if n < idx {
+ return false
+ }
+ if n == len(fi.iterators)-1 {
+ // Can always place an elem last
+ return true
+ }
+ nextHash := fi.iterators[n+1].it.Hash()
+ if diff := bytes.Compare(curHash[:], nextHash[:]); diff < 0 {
+ return true
+ } else if diff > 0 {
+ return false
+ }
+ // The elem we're placing it next to has the same value,
+ // so whichever winds up on n+1 will need further iteration
+ clash = n + 1
+
+ return cur.priority < fi.iterators[n+1].priority
+ })
+ fi.move(idx, index)
+ if clash != -1 {
+ fi.next(clash)
+ }
+ return true
+}
+
+// move advances an iterator to another position in the list.
+func (fi *fastAccountIterator) move(index, newpos int) {
+ elem := fi.iterators[index]
+ copy(fi.iterators[index:], fi.iterators[index+1:newpos+1])
+ fi.iterators[newpos] = elem
+}
+
+// Error returns any failure that occurred during iteration, which might have
+// caused a premature iteration exit (e.g. snapshot stack becoming stale).
+func (fi *fastAccountIterator) Error() error {
+ return fi.fail
+}
+
+ // Hash returns the hash of the account the iterator is currently at.
+func (fi *fastAccountIterator) Hash() common.Hash {
+ return fi.iterators[0].it.Hash()
+}
+
+ // Account returns the RLP encoded slim account the iterator is currently at.
+func (fi *fastAccountIterator) Account() []byte {
+ return fi.curAccount
+}
+
+// Release iterates over all the remaining live layer iterators and releases each
+ // of them individually.
+func (fi *fastAccountIterator) Release() {
+ for _, it := range fi.iterators {
+ it.it.Release()
+ }
+ fi.iterators = nil
+}
+
+ // Debug is a convenience helper during testing
+func (fi *fastAccountIterator) Debug() {
+ for _, it := range fi.iterators {
+ fmt.Printf("[p=%v v=%v] ", it.priority, it.it.Hash()[0])
+ }
+ fmt.Println()
+}
diff --git a/core/state/snapshot/iterator_test.go b/core/state/snapshot/iterator_test.go
new file mode 100644
index 000000000..613bd9955
--- /dev/null
+++ b/core/state/snapshot/iterator_test.go
@@ -0,0 +1,658 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+ // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math/rand"
+ "testing"
+
+ "github.com/VictoriaMetrics/fastcache"
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
+)
+
+// TestAccountIteratorBasics tests some simple single-layer iteration
+func TestAccountIteratorBasics(t *testing.T) {
+ var (
+ destructs = make(map[common.Hash]struct{})
+ accounts = make(map[common.Hash][]byte)
+ storage = make(map[common.Hash]map[common.Hash][]byte)
+ )
+ // Fill up a parent
+ for i := 0; i < 100; i++ {
+ h := randomHash()
+ data := randomAccount()
+
+ accounts[h] = data
+ if rand.Intn(4) == 0 {
+ destructs[h] = struct{}{}
+ }
+ if rand.Intn(2) == 0 {
+ accStorage := make(map[common.Hash][]byte)
+ value := make([]byte, 32)
+ rand.Read(value)
+ accStorage[randomHash()] = value
+ storage[h] = accStorage
+ }
+ }
+ // Assemble a diff layer on top of an empty base and iterate it
+ parent := newDiffLayer(emptyLayer(), common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage))
+ it := parent.AccountIterator(common.Hash{})
+ verifyIterator(t, 100, it)
+}
+
+type testIterator struct {
+ values []byte
+}
+
+func newTestIterator(values ...byte) *testIterator {
+ return &testIterator{values}
+}
+
+func (ti *testIterator) Seek(common.Hash) {
+ panic("implement me")
+}
+
+func (ti *testIterator) Next() bool {
+ ti.values = ti.values[1:]
+ return len(ti.values) > 0
+}
+
+func (ti *testIterator) Error() error {
+ return nil
+}
+
+func (ti *testIterator) Hash() common.Hash {
+ return common.BytesToHash([]byte{ti.values[0]})
+}
+
+func (ti *testIterator) Account() []byte {
+ return nil
+}
+
+func (ti *testIterator) Release() {}
+
+func TestFastIteratorBasics(t *testing.T) {
+ type testCase struct {
+ lists [][]byte
+ expKeys []byte
+ }
+ for i, tc := range []testCase{
+ {lists: [][]byte{{0, 1, 8}, {1, 2, 8}, {2, 9}, {4},
+ {7, 14, 15}, {9, 13, 15, 16}},
+ expKeys: []byte{0, 1, 2, 4, 7, 8, 9, 13, 14, 15, 16}},
+ {lists: [][]byte{{0, 8}, {1, 2, 8}, {7, 14, 15}, {8, 9},
+ {9, 10}, {10, 13, 15, 16}},
+ expKeys: []byte{0, 1, 2, 7, 8, 9, 10, 13, 14, 15, 16}},
+ } {
+ var iterators []*weightedAccountIterator
+ for i, data := range tc.lists {
+ it := newTestIterator(data...)
+ iterators = append(iterators, &weightedAccountIterator{it, i})
+
+ }
+ fi := &fastAccountIterator{
+ iterators: iterators,
+ initiated: false,
+ }
+ count := 0
+ for fi.Next() {
+ if got, exp := fi.Hash()[31], tc.expKeys[count]; exp != got {
+ t.Errorf("tc %d, [%d]: got %d exp %d", i, count, got, exp)
+ }
+ count++
+ }
+ }
+}
+
+func verifyIterator(t *testing.T, expCount int, it AccountIterator) {
+ t.Helper()
+
+ var (
+ count = 0
+ last = common.Hash{}
+ )
+ for it.Next() {
+ hash := it.Hash()
+ if bytes.Compare(last[:], hash[:]) >= 0 {
+ t.Errorf("wrong order: %x >= %x", last, hash)
+ }
+ if it.Account() == nil {
+ t.Errorf("iterator returned nil-value for hash %x", hash)
+ }
+ count++
+ }
+ if count != expCount {
+ t.Errorf("iterator count mismatch: have %d, want %d", count, expCount)
+ }
+ if err := it.Error(); err != nil {
+ t.Errorf("iterator failed: %v", err)
+ }
+}
+
+// TestAccountIteratorTraversal tests some simple multi-layer iteration.
+func TestAccountIteratorTraversal(t *testing.T) {
+ // Create an empty base layer and a snapshot tree out of it
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: common.HexToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ }
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ base.root: base,
+ },
+ }
+ // Stack three diff layers on top with various overlaps
+ snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
+ randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil)
+
+ snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
+ randomAccountSet("0xbb", "0xdd", "0xf0"), nil)
+
+ snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
+ randomAccountSet("0xcc", "0xf0", "0xff"), nil)
+
+ // Verify the single and multi-layer iterators
+ head := snaps.Snapshot(common.HexToHash("0x04"))
+
+ verifyIterator(t, 3, head.(snapshot).AccountIterator(common.Hash{}))
+ verifyIterator(t, 7, head.(*diffLayer).newBinaryAccountIterator())
+
+ it, _ := snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{})
+ defer it.Release()
+
+ verifyIterator(t, 7, it)
+}
+
+// TestAccountIteratorTraversalValues tests some multi-layer iteration, where we
+// also expect the correct values to show up.
+func TestAccountIteratorTraversalValues(t *testing.T) {
+ // Create an empty base layer and a snapshot tree out of it
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: common.HexToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ }
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ base.root: base,
+ },
+ }
+ // Create a batch of account sets to seed subsequent layers with
+ var (
+ a = make(map[common.Hash][]byte)
+ b = make(map[common.Hash][]byte)
+ c = make(map[common.Hash][]byte)
+ d = make(map[common.Hash][]byte)
+ e = make(map[common.Hash][]byte)
+ f = make(map[common.Hash][]byte)
+ g = make(map[common.Hash][]byte)
+ h = make(map[common.Hash][]byte)
+ )
+ for i := byte(2); i < 0xff; i++ {
+ a[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 0, i))
+ if i > 20 && i%2 == 0 {
+ b[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 1, i))
+ }
+ if i%4 == 0 {
+ c[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 2, i))
+ }
+ if i%7 == 0 {
+ d[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 3, i))
+ }
+ if i%8 == 0 {
+ e[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 4, i))
+ }
+ if i > 50 || i < 85 {
+ f[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 5, i))
+ }
+ if i%64 == 0 {
+ g[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 6, i))
+ }
+ if i%128 == 0 {
+ h[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 7, i))
+ }
+ }
+ // Assemble a stack of snapshots from the account layers
+ snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, a, nil)
+ snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, b, nil)
+ snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, c, nil)
+ snaps.Update(common.HexToHash("0x05"), common.HexToHash("0x04"), nil, d, nil)
+ snaps.Update(common.HexToHash("0x06"), common.HexToHash("0x05"), nil, e, nil)
+ snaps.Update(common.HexToHash("0x07"), common.HexToHash("0x06"), nil, f, nil)
+ snaps.Update(common.HexToHash("0x08"), common.HexToHash("0x07"), nil, g, nil)
+ snaps.Update(common.HexToHash("0x09"), common.HexToHash("0x08"), nil, h, nil)
+
+ it, _ := snaps.AccountIterator(common.HexToHash("0x09"), common.Hash{})
+ defer it.Release()
+
+ head := snaps.Snapshot(common.HexToHash("0x09"))
+ for it.Next() {
+ hash := it.Hash()
+ want, err := head.AccountRLP(hash)
+ if err != nil {
+ t.Fatalf("failed to retrieve expected account: %v", err)
+ }
+ if have := it.Account(); !bytes.Equal(want, have) {
+ t.Fatalf("hash %x: account mismatch: have %x, want %x", hash, have, want)
+ }
+ }
+}
+
+ // This testcase is notorious: all layers contain the exact same 200 accounts.
+func TestAccountIteratorLargeTraversal(t *testing.T) {
+ // Create a custom account factory to recreate the same addresses
+ makeAccounts := func(num int) map[common.Hash][]byte {
+ accounts := make(map[common.Hash][]byte)
+ for i := 0; i < num; i++ {
+ h := common.Hash{}
+ binary.BigEndian.PutUint64(h[:], uint64(i+1))
+ accounts[h] = randomAccount()
+ }
+ return accounts
+ }
+ // Build up a large stack of snapshots
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: common.HexToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ }
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ base.root: base,
+ },
+ }
+ for i := 1; i < 128; i++ {
+ snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(200), nil)
+ }
+ // Iterate the entire stack and ensure everything is hit only once
+ head := snaps.Snapshot(common.HexToHash("0x80"))
+ verifyIterator(t, 200, head.(snapshot).AccountIterator(common.Hash{}))
+ verifyIterator(t, 200, head.(*diffLayer).newBinaryAccountIterator())
+
+ it, _ := snaps.AccountIterator(common.HexToHash("0x80"), common.Hash{})
+ defer it.Release()
+
+ verifyIterator(t, 200, it)
+}
+
+// TestAccountIteratorFlattening tests what happens when we
+// - have a live iterator on child C (parent C1 -> C2 .. CN)
+ // - flatten C2 all the way into CN
+ // - continue iterating
+func TestAccountIteratorFlattening(t *testing.T) {
+ // Create an empty base layer and a snapshot tree out of it
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: common.HexToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ }
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ base.root: base,
+ },
+ }
+ // Create a stack of diffs on top
+ snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
+ randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil)
+
+ snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
+ randomAccountSet("0xbb", "0xdd", "0xf0"), nil)
+
+ snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
+ randomAccountSet("0xcc", "0xf0", "0xff"), nil)
+
+ // Create an iterator and flatten the data from underneath it
+ it, _ := snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{})
+ defer it.Release()
+
+ if err := snaps.Cap(common.HexToHash("0x04"), 1); err != nil {
+ t.Fatalf("failed to flatten snapshot stack: %v", err)
+ }
+ //verifyIterator(t, 7, it)
+}
+
+func TestAccountIteratorSeek(t *testing.T) {
+ // Create a snapshot stack with some initial data
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: common.HexToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ }
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ base.root: base,
+ },
+ }
+ snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
+ randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil)
+
+ snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
+ randomAccountSet("0xbb", "0xdd", "0xf0"), nil)
+
+ snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
+ randomAccountSet("0xcc", "0xf0", "0xff"), nil)
+
+ // Construct various iterators and ensure their traversal is correct
+ it, _ := snaps.AccountIterator(common.HexToHash("0x02"), common.HexToHash("0xdd"))
+ defer it.Release()
+ verifyIterator(t, 3, it) // expected: ee, f0, ff
+
+ it, _ = snaps.AccountIterator(common.HexToHash("0x02"), common.HexToHash("0xaa"))
+ defer it.Release()
+ verifyIterator(t, 3, it) // expected: ee, f0, ff
+
+ it, _ = snaps.AccountIterator(common.HexToHash("0x02"), common.HexToHash("0xff"))
+ defer it.Release()
+ verifyIterator(t, 0, it) // expected: nothing
+
+ it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xbb"))
+ defer it.Release()
+ verifyIterator(t, 5, it) // expected: cc, dd, ee, f0, ff
+
+ it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xef"))
+ defer it.Release()
+ verifyIterator(t, 2, it) // expected: f0, ff
+
+ it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xf0"))
+ defer it.Release()
+ verifyIterator(t, 1, it) // expected: ff
+
+ it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xff"))
+ defer it.Release()
+ verifyIterator(t, 0, it) // expected: nothing
+}
+
+ // TestIteratorDeletions tests that the iterator behaves correctly when there are
+// deleted accounts (where the Account() value is nil). The iterator
+// should not output any accounts or nil-values for those cases.
+func TestIteratorDeletions(t *testing.T) {
+ // Create an empty base layer and a snapshot tree out of it
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: common.HexToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ }
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ base.root: base,
+ },
+ }
+ // Stack three diff layers on top with various overlaps
+ snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"),
+ nil, randomAccountSet("0x11", "0x22", "0x33"), nil)
+
+ deleted := common.HexToHash("0x22")
+ destructed := map[common.Hash]struct{}{
+ deleted: {},
+ }
+ snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"),
+ destructed, randomAccountSet("0x11", "0x33"), nil)
+
+ snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"),
+ nil, randomAccountSet("0x33", "0x44", "0x55"), nil)
+
+ // The output should be 11,33,44,55
+ it, _ := snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{})
+ // Do a quick check
+ verifyIterator(t, 4, it)
+ it.Release()
+
+ // And a more detailed verification that we indeed do not see '0x22'
+ it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{})
+ defer it.Release()
+ for it.Next() {
+ hash := it.Hash()
+ if it.Account() == nil {
+ t.Errorf("iterator returned nil-value for hash %x", hash)
+ }
+ if hash == deleted {
+ t.Errorf("expected deleted elem %x to not be returned by iterator", deleted)
+ }
+ }
+}
+
+ // BenchmarkAccountIteratorTraversal is a bit notorious -- all layers contain the
+// exact same 200 accounts. That means that we need to process 2000 items, but
+// only spit out 200 values eventually.
+//
+// The value-fetching benchmark is easy on the binary iterator, since it never has to reach
+ // down at any depth for retrieving the values -- all are on the topmost layer
+//
+// BenchmarkAccountIteratorTraversal/binary_iterator_keys-6 2239 483674 ns/op
+// BenchmarkAccountIteratorTraversal/binary_iterator_values-6 2403 501810 ns/op
+// BenchmarkAccountIteratorTraversal/fast_iterator_keys-6 1923 677966 ns/op
+// BenchmarkAccountIteratorTraversal/fast_iterator_values-6 1741 649967 ns/op
+func BenchmarkAccountIteratorTraversal(b *testing.B) {
+ // Create a custom account factory to recreate the same addresses
+ makeAccounts := func(num int) map[common.Hash][]byte {
+ accounts := make(map[common.Hash][]byte)
+ for i := 0; i < num; i++ {
+ h := common.Hash{}
+ binary.BigEndian.PutUint64(h[:], uint64(i+1))
+ accounts[h] = randomAccount()
+ }
+ return accounts
+ }
+ // Build up a large stack of snapshots
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: common.HexToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ }
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ base.root: base,
+ },
+ }
+ for i := 1; i <= 100; i++ {
+ snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(200), nil)
+ }
+ // We call this once before the benchmark, so the creation of the
+ // sorted accountlists is not included in the results.
+ head := snaps.Snapshot(common.HexToHash("0x65"))
+ head.(*diffLayer).newBinaryAccountIterator()
+
+ b.Run("binary iterator keys", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ got := 0
+ it := head.(*diffLayer).newBinaryAccountIterator()
+ for it.Next() {
+ got++
+ }
+ if exp := 200; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+ b.Run("binary iterator values", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ got := 0
+ it := head.(*diffLayer).newBinaryAccountIterator()
+ for it.Next() {
+ got++
+ head.(*diffLayer).accountRLP(it.Hash(), 0)
+ }
+ if exp := 200; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+ b.Run("fast iterator keys", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{})
+ defer it.Release()
+
+ got := 0
+ for it.Next() {
+ got++
+ }
+ if exp := 200; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+ b.Run("fast iterator values", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{})
+ defer it.Release()
+
+ got := 0
+ for it.Next() {
+ got++
+ it.Account()
+ }
+ if exp := 200; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+}
+
+// BenchmarkAccountIteratorLargeBaselayer is a pretty realistic benchmark, where
+// the baselayer is a lot larger than the upper layer.
+//
+ // This is heavy on the binary iterator, which in most cases has to recurse
+ // through all 100 layers for the majority of the values
+//
+// BenchmarkAccountIteratorLargeBaselayer/binary_iterator_(keys)-6 514 1971999 ns/op
+// BenchmarkAccountIteratorLargeBaselayer/binary_iterator_(values)-6 61 18997492 ns/op
+// BenchmarkAccountIteratorLargeBaselayer/fast_iterator_(keys)-6 10000 114385 ns/op
+// BenchmarkAccountIteratorLargeBaselayer/fast_iterator_(values)-6 4047 296823 ns/op
+func BenchmarkAccountIteratorLargeBaselayer(b *testing.B) {
+ // Create a custom account factory to recreate the same addresses
+ makeAccounts := func(num int) map[common.Hash][]byte {
+ accounts := make(map[common.Hash][]byte)
+ for i := 0; i < num; i++ {
+ h := common.Hash{}
+ binary.BigEndian.PutUint64(h[:], uint64(i+1))
+ accounts[h] = randomAccount()
+ }
+ return accounts
+ }
+ // Build up a large stack of snapshots
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: common.HexToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ }
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ base.root: base,
+ },
+ }
+ snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, makeAccounts(2000), nil)
+ for i := 2; i <= 100; i++ {
+ snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(20), nil)
+ }
+ // We call this once before the benchmark, so the creation of the
+ // sorted accountlists is not included in the results.
+ head := snaps.Snapshot(common.HexToHash("0x65"))
+ head.(*diffLayer).newBinaryAccountIterator()
+
+ b.Run("binary iterator (keys)", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ got := 0
+ it := head.(*diffLayer).newBinaryAccountIterator()
+ for it.Next() {
+ got++
+ }
+ if exp := 2000; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+ b.Run("binary iterator (values)", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ got := 0
+ it := head.(*diffLayer).newBinaryAccountIterator()
+ for it.Next() {
+ got++
+ v := it.Hash()
+ head.(*diffLayer).accountRLP(v, 0)
+ }
+ if exp := 2000; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+ b.Run("fast iterator (keys)", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{})
+ defer it.Release()
+
+ got := 0
+ for it.Next() {
+ got++
+ }
+ if exp := 2000; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+ b.Run("fast iterator (values)", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{})
+ defer it.Release()
+
+ got := 0
+ for it.Next() {
+ it.Account()
+ got++
+ }
+ if exp := 2000; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+}
+
+/*
+func BenchmarkBinaryAccountIteration(b *testing.B) {
+ benchmarkAccountIteration(b, func(snap snapshot) AccountIterator {
+ return snap.(*diffLayer).newBinaryAccountIterator()
+ })
+}
+func BenchmarkFastAccountIteration(b *testing.B) {
+ benchmarkAccountIteration(b, newFastAccountIterator)
+}
+func benchmarkAccountIteration(b *testing.B, iterator func(snap snapshot) AccountIterator) {
+ // Create a diff stack and randomize the accounts across them
+ layers := make([]map[common.Hash][]byte, 128)
+ for i := 0; i < len(layers); i++ {
+ layers[i] = make(map[common.Hash][]byte)
+ }
+ for i := 0; i < b.N; i++ {
+ depth := rand.Intn(len(layers))
+ layers[depth][randomHash()] = randomAccount()
+ }
+ stack := snapshot(emptyLayer())
+ for _, layer := range layers {
+ stack = stack.Update(common.Hash{}, layer, nil, nil)
+ }
+ // Reset the timers and report all the stats
+ it := iterator(stack)
+ b.ResetTimer()
+ b.ReportAllocs()
+ for it.Next() {
+ }
+}
+*/
diff --git a/core/state/snapshot/journal.go b/core/state/snapshot/journal.go
new file mode 100644
index 000000000..0c0e3a960
--- /dev/null
+++ b/core/state/snapshot/journal.go
@@ -0,0 +1,243 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+ // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "time"
+
+ "github.com/VictoriaMetrics/fastcache"
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
+ "github.com/tomochain/tomochain/ethdb"
+ "github.com/tomochain/tomochain/log"
+ "github.com/tomochain/tomochain/rlp"
+ "github.com/tomochain/tomochain/trie"
+)
+
+// journalGenerator is a disk layer entry containing the generator progress marker.
+type journalGenerator struct {
+ Wiping bool // Whether the database was in progress of being wiped
+ Done bool // Whether the generator finished creating the snapshot
+ Marker []byte // Progress marker of the snapshot generation
+ Accounts uint64 // Number of accounts indexed so far
+ Slots uint64 // Number of storage slots indexed so far
+ Storage uint64 // Amount of account and storage slot data indexed so far
+}
+
+// journalDestruct is an account deletion entry in a diffLayer's disk journal.
+type journalDestruct struct {
+ Hash common.Hash
+}
+
+// journalAccount is an account entry in a diffLayer's disk journal.
+type journalAccount struct {
+ Hash common.Hash
+ Blob []byte
+}
+
+// journalStorage is an account's storage map in a diffLayer's disk journal.
+type journalStorage struct {
+ Hash common.Hash
+ Keys []common.Hash
+ Vals [][]byte
+}
+
+// loadSnapshot loads a pre-existing state snapshot backed by a key-value store.
+func loadSnapshot(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root common.Hash) (snapshot, error) {
+ // Retrieve the block number and hash of the snapshot, failing if no snapshot
+ // is present in the database (or crashed mid-update).
+ baseRoot := rawdb.ReadSnapshotRoot(diskdb)
+ if baseRoot == (common.Hash{}) {
+ return nil, errors.New("missing or corrupted snapshot")
+ }
+ base := &diskLayer{
+ diskdb: diskdb,
+ triedb: triedb,
+ cache: fastcache.New(cache * 1024 * 1024),
+ root: baseRoot,
+ }
+ // Retrieve the journal, it must exist, since even with no diff layers it
+ // records whether the snapshot was fully generated or is still in progress
+ journal := rawdb.ReadSnapshotJournal(diskdb)
+ if len(journal) == 0 {
+ return nil, errors.New("missing or corrupted snapshot journal")
+ }
+ r := rlp.NewStream(bytes.NewReader(journal), 0)
+
+ // Read the snapshot generation progress for the disk layer
+ var generator journalGenerator
+ if err := r.Decode(&generator); err != nil {
+ return nil, fmt.Errorf("failed to load snapshot progress marker: %v", err)
+ }
+ // Load all the snapshot diffs from the journal
+ snapshot, err := loadDiffLayer(base, r)
+ if err != nil {
+ return nil, err
+ }
+ // Entire snapshot journal loaded, sanity check the head and return
+ if head := snapshot.Root(); head != root {
+ return nil, fmt.Errorf("head doesn't match snapshot: have %#x, want %#x", head, root)
+ }
+ // Everything loaded correctly, resume any suspended operations
+ if !generator.Done {
+ // Whether or not wiping was in progress, load any generator progress too
+ base.genMarker = generator.Marker
+ if base.genMarker == nil {
+ base.genMarker = []byte{}
+ }
+ base.genPending = make(chan struct{})
+ base.genAbort = make(chan chan *generatorStats)
+
+ var origin uint64
+ if len(generator.Marker) >= 8 {
+ origin = binary.BigEndian.Uint64(generator.Marker)
+ }
+ go base.generate(&generatorStats{
+ origin: origin,
+ start: time.Now(),
+ accounts: generator.Accounts,
+ slots: generator.Slots,
+ storage: common.StorageSize(generator.Storage),
+ })
+ }
+ return snapshot, nil
+}
+
+// loadDiffLayer reads the next sections of a snapshot journal, reconstructing a new
+// diff and verifying that it can be linked to the requested parent.
+func loadDiffLayer(parent snapshot, r *rlp.Stream) (snapshot, error) {
+ // Read the next diff journal entry
+ var root common.Hash
+ if err := r.Decode(&root); err != nil {
+ // The first read may fail with EOF, marking the end of the journal
+ if err == io.EOF {
+ return parent, nil
+ }
+ return nil, fmt.Errorf("load diff root: %v", err)
+ }
+ var destructs []journalDestruct
+ if err := r.Decode(&destructs); err != nil {
+ return nil, fmt.Errorf("load diff destructs: %v", err)
+ }
+ destructSet := make(map[common.Hash]struct{})
+ for _, entry := range destructs {
+ destructSet[entry.Hash] = struct{}{}
+ }
+ var accounts []journalAccount
+ if err := r.Decode(&accounts); err != nil {
+ return nil, fmt.Errorf("load diff accounts: %v", err)
+ }
+ accountData := make(map[common.Hash][]byte)
+ for _, entry := range accounts {
+ accountData[entry.Hash] = entry.Blob
+ }
+ var storage []journalStorage
+ if err := r.Decode(&storage); err != nil {
+ return nil, fmt.Errorf("load diff storage: %v", err)
+ }
+ storageData := make(map[common.Hash]map[common.Hash][]byte)
+ for _, entry := range storage {
+ slots := make(map[common.Hash][]byte)
+ for i, key := range entry.Keys {
+ slots[key] = entry.Vals[i]
+ }
+ storageData[entry.Hash] = slots
+ }
+ return loadDiffLayer(newDiffLayer(parent, root, destructSet, accountData, storageData), r)
+}
+
+// Journal writes the persistent layer generator stats into a buffer to be stored
+// in the database as the snapshot journal.
+func (dl *diskLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) {
+ // If the snapshot is currently being generated, abort it
+ var stats *generatorStats
+ if dl.genAbort != nil {
+ abort := make(chan *generatorStats)
+ dl.genAbort <- abort
+
+ if stats = <-abort; stats != nil {
+ stats.Log("Journalling in-progress snapshot", dl.genMarker)
+ }
+ }
+ // Ensure the layer didn't get stale
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ if dl.stale {
+ return common.Hash{}, ErrSnapshotStale
+ }
+ // Ensure the generator stats are written even if none were run this cycle
+ journalProgress(dl.diskdb, dl.genMarker, stats)
+
+ log.Debug("Journalled disk layer", "root", dl.root)
+ return dl.root, nil
+}
+
+// Journal writes the memory layer contents into a buffer to be stored in the
+// database as the snapshot journal.
+func (dl *diffLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) {
+ // Journal the parent first
+ base, err := dl.parent.Journal(buffer)
+ if err != nil {
+ return common.Hash{}, err
+ }
+ // Ensure the layer didn't get stale
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ if dl.Stale() {
+ return common.Hash{}, ErrSnapshotStale
+ }
+ // Everything below was journalled, persist this layer too
+ if err := rlp.Encode(buffer, dl.root); err != nil {
+ return common.Hash{}, err
+ }
+ destructs := make([]journalDestruct, 0, len(dl.destructSet))
+ for hash := range dl.destructSet {
+ destructs = append(destructs, journalDestruct{Hash: hash})
+ }
+ if err := rlp.Encode(buffer, destructs); err != nil {
+ return common.Hash{}, err
+ }
+ accounts := make([]journalAccount, 0, len(dl.accountData))
+ for hash, blob := range dl.accountData {
+ accounts = append(accounts, journalAccount{Hash: hash, Blob: blob})
+ }
+ if err := rlp.Encode(buffer, accounts); err != nil {
+ return common.Hash{}, err
+ }
+ storage := make([]journalStorage, 0, len(dl.storageData))
+ for hash, slots := range dl.storageData {
+ keys := make([]common.Hash, 0, len(slots))
+ vals := make([][]byte, 0, len(slots))
+ for key, val := range slots {
+ keys = append(keys, key)
+ vals = append(vals, val)
+ }
+ storage = append(storage, journalStorage{Hash: hash, Keys: keys, Vals: vals})
+ }
+ if err := rlp.Encode(buffer, storage); err != nil {
+ return common.Hash{}, err
+ }
+ return base, nil
+}
diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go
new file mode 100644
index 000000000..82cc1addc
--- /dev/null
+++ b/core/state/snapshot/snapshot.go
@@ -0,0 +1,598 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+ // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package snapshot implements a journalled, dynamic state dump.
+package snapshot
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
+ "github.com/tomochain/tomochain/core/types"
+ "github.com/tomochain/tomochain/ethdb"
+ "github.com/tomochain/tomochain/log"
+ "github.com/tomochain/tomochain/metrics"
+ "github.com/tomochain/tomochain/trie"
+)
+
+var (
+ snapshotCleanAccountHitMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/hit", nil)
+ snapshotCleanAccountMissMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/miss", nil)
+ snapshotCleanAccountInexMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/inex", nil)
+ snapshotCleanAccountReadMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/read", nil)
+ snapshotCleanAccountWriteMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/write", nil)
+
+ snapshotCleanStorageHitMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/hit", nil)
+ snapshotCleanStorageMissMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/miss", nil)
+ snapshotCleanStorageInexMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/inex", nil)
+ snapshotCleanStorageReadMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/read", nil)
+ snapshotCleanStorageWriteMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/write", nil)
+
+ snapshotDirtyAccountHitMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/hit", nil)
+ snapshotDirtyAccountMissMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/miss", nil)
+ snapshotDirtyAccountInexMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/inex", nil)
+ snapshotDirtyAccountReadMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/read", nil)
+ snapshotDirtyAccountWriteMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/write", nil)
+
+ snapshotDirtyStorageHitMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/hit", nil)
+ snapshotDirtyStorageMissMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/miss", nil)
+ snapshotDirtyStorageInexMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/inex", nil)
+ snapshotDirtyStorageReadMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/read", nil)
+ snapshotDirtyStorageWriteMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/write", nil)
+
+ snapshotDirtyAccountHitDepthHist = metrics.NewRegisteredHistogram("state/snapshot/dirty/account/hit/depth", nil, metrics.NewExpDecaySample(1028, 0.015))
+ snapshotDirtyStorageHitDepthHist = metrics.NewRegisteredHistogram("state/snapshot/dirty/storage/hit/depth", nil, metrics.NewExpDecaySample(1028, 0.015))
+
+ snapshotFlushAccountItemMeter = metrics.NewRegisteredMeter("state/snapshot/flush/account/item", nil)
+ snapshotFlushAccountSizeMeter = metrics.NewRegisteredMeter("state/snapshot/flush/account/size", nil)
+ snapshotFlushStorageItemMeter = metrics.NewRegisteredMeter("state/snapshot/flush/storage/item", nil)
+ snapshotFlushStorageSizeMeter = metrics.NewRegisteredMeter("state/snapshot/flush/storage/size", nil)
+
+ snapshotBloomIndexTimer = metrics.NewRegisteredResettingTimer("state/snapshot/bloom/index", nil)
+ snapshotBloomErrorGauge = metrics.NewRegisteredGaugeFloat64("state/snapshot/bloom/error", nil)
+
+ snapshotBloomAccountTrueHitMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/account/truehit", nil)
+ snapshotBloomAccountFalseHitMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/account/falsehit", nil)
+ snapshotBloomAccountMissMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/account/miss", nil)
+
+ snapshotBloomStorageTrueHitMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/storage/truehit", nil)
+ snapshotBloomStorageFalseHitMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/storage/falsehit", nil)
+ snapshotBloomStorageMissMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/storage/miss", nil)
+
+ // ErrSnapshotStale is returned from data accessors if the underlying snapshot
+ // layer had been invalidated due to the chain progressing forward far enough
+ // to not maintain the layer's original state.
+ ErrSnapshotStale = errors.New("snapshot stale")
+
+ // ErrNotCoveredYet is returned from data accessors if the underlying snapshot
+ // is being generated currently and the requested data item is not yet in the
+ // range of accounts covered.
+ ErrNotCoveredYet = errors.New("not covered yet")
+
+ // errSnapshotCycle is returned if a snapshot is attempted to be inserted
+ // that forms a cycle in the snapshot tree.
+ errSnapshotCycle = errors.New("snapshot cycle")
+)
+
+// Snapshot represents the functionality supported by a snapshot storage layer.
+type Snapshot interface {
+ // Root returns the root hash for which this snapshot was made.
+ Root() common.Hash
+
+ // Account directly retrieves the account associated with a particular hash in
+ // the snapshot slim data format.
+ Account(hash common.Hash) (*types.SlimAccount, error)
+
+ // AccountRLP directly retrieves the account RLP associated with a particular
+ // hash in the snapshot slim data format.
+ AccountRLP(hash common.Hash) ([]byte, error)
+
+ // Storage directly retrieves the storage data associated with a particular hash,
+ // within a particular account.
+ Storage(accountHash, storageHash common.Hash) ([]byte, error)
+}
+
+// snapshot is the internal version of the snapshot data layer that supports some
+// additional methods compared to the public API.
+type snapshot interface {
+ Snapshot
+
+ // Parent returns the subsequent layer of a snapshot, or nil if the base was
+ // reached.
+ //
+ // Note, the method is an internal helper to avoid type switching between the
+ // disk and diff layers. There is no locking involved.
+ Parent() snapshot
+
+ // Update creates a new layer on top of the existing snapshot diff tree with
+ // the specified data items.
+ //
+ // Note, the maps are retained by the method to avoid copying everything.
+ Update(blockRoot common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer
+
+ // Journal commits an entire diff hierarchy to disk into a single journal entry.
+ // This is meant to be used during shutdown to persist the snapshot without
+ // flattening everything down (bad for reorgs).
+ Journal(buffer *bytes.Buffer) (common.Hash, error)
+
+ // Stale returns whether this layer has become stale (was flattened across) or
+ // if it's still live.
+ Stale() bool
+
+ // AccountIterator creates an account iterator over an arbitrary layer.
+ AccountIterator(seek common.Hash) AccountIterator
+}
+
+ // Tree is an Ethereum state snapshot tree. It consists of one persistent base
+ // layer backed by a key-value store, on top of which arbitrarily many in-memory
+ // diff layers are stacked. The memory diffs can form a tree with branching, but
+ // the disk layer is a singleton and common to all. If a reorg goes deeper than
+ // the disk layer, everything needs to be deleted.
+//
+// The goal of a state snapshot is twofold: to allow direct access to account and
+// storage data to avoid expensive multi-level trie lookups; and to allow sorted,
+// cheap iteration of the account/storage tries for sync aid.
+type Tree struct {
+ diskdb ethdb.KeyValueStore // Persistent database to store the snapshot
+ triedb *trie.Database // In-memory cache to access the trie through
+ cache int // Megabytes permitted to use for read caches
+ layers map[common.Hash]snapshot // Collection of all known layers
+ lock sync.RWMutex
+}
+
+// New attempts to load an already existing snapshot from a persistent key-value
+// store (with a number of memory layers from a journal), ensuring that the head
+// of the snapshot matches the expected one.
+//
+// If the snapshot is missing or inconsistent, the entirety is deleted and will
+// be reconstructed from scratch based on the tries in the key-value store, on a
+// background thread.
+func New(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root common.Hash, async bool) *Tree {
+ // Create a new, empty snapshot tree
+ snap := &Tree{
+ diskdb: diskdb,
+ triedb: triedb,
+ cache: cache,
+ layers: make(map[common.Hash]snapshot),
+ }
+ if !async {
+ defer snap.waitBuild()
+ }
+ // Attempt to load a previously persisted snapshot and rebuild one if failed
+ head, err := loadSnapshot(diskdb, triedb, cache, root)
+ if err != nil {
+ log.Warn("Failed to load snapshot, regenerating", "err", err)
+ snap.Rebuild(root)
+ return snap
+ }
+ // Existing snapshot loaded, seed all the layers
+ for head != nil {
+ snap.layers[head.Root()] = head
+ head = head.Parent()
+ }
+ return snap
+}
+
+// waitBuild blocks until the snapshot finishes rebuilding. This method is meant
+// to be used by tests to ensure we're testing what we believe we are.
+func (t *Tree) waitBuild() {
+ // Find the rebuild termination channel
+ var done chan struct{}
+
+ t.lock.RLock()
+ for _, layer := range t.layers {
+ if layer, ok := layer.(*diskLayer); ok {
+ done = layer.genPending
+ break
+ }
+ }
+ t.lock.RUnlock()
+
+ // Wait until the snapshot is generated
+ if done != nil {
+ <-done
+ }
+}
+
+// Snapshot retrieves a snapshot belonging to the given block root, or nil if no
+// snapshot is maintained for that block.
+func (t *Tree) Snapshot(blockRoot common.Hash) Snapshot {
+ t.lock.RLock()
+ defer t.lock.RUnlock()
+
+ return t.layers[blockRoot]
+}
+
+// Update adds a new snapshot into the tree, if that can be linked to an existing
+// old parent. It is disallowed to insert a disk layer (the origin of all).
+func (t *Tree) Update(blockRoot common.Hash, parentRoot common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) error {
+ // Reject noop updates to avoid self-loops in the snapshot tree. This is a
+ // special case that can only happen for Clique networks where empty blocks
+ // don't modify the state (0 block subsidy).
+ //
+ // Although we could silently ignore this internally, it should be the caller's
+ // responsibility to avoid even attempting to insert such a snapshot.
+ if blockRoot == parentRoot {
+ return errSnapshotCycle
+ }
+ // Generate a new snapshot on top of the parent
+ // Do the nil check before type asserting; a missing parent yields a nil
+ // interface which would make the assertion panic.
+ parent := t.Snapshot(parentRoot)
+ if parent == nil {
+ return fmt.Errorf("parent [%#x] snapshot missing", parentRoot)
+ }
+ snap := parent.(snapshot).Update(blockRoot, destructs, accounts, storage)
+
+ // Save the new snapshot for later
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ t.layers[snap.root] = snap
+ return nil
+}
+
+// Cap traverses downwards the snapshot tree from a head block hash until the
+ // number of allowed layers is crossed. All layers beyond the permitted number
+// are flattened downwards.
+func (t *Tree) Cap(root common.Hash, layers int) error {
+ // Retrieve the head snapshot to cap from
+ snap := t.Snapshot(root)
+ if snap == nil {
+ return fmt.Errorf("snapshot [%#x] missing", root)
+ }
+ diff, ok := snap.(*diffLayer)
+ if !ok {
+ return fmt.Errorf("snapshot [%#x] is disk layer", root)
+ }
+ // Run the internal capping and discard all stale layers
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ // Flattening the bottom-most diff layer requires special casing since there's
+ // no child to rewire to the grandparent. In that case we can fake a temporary
+ // child for the capping and then remove it.
+ var persisted *diskLayer
+
+ switch layers {
+ case 0:
+ // If full commit was requested, flatten the diffs and merge onto disk
+ diff.lock.RLock()
+ base := diffToDisk(diff.flatten().(*diffLayer))
+ diff.lock.RUnlock()
+
+ // Replace the entire snapshot tree with the flat base
+ t.layers = map[common.Hash]snapshot{base.root: base}
+ return nil
+
+ case 1:
+ // If full flattening was requested, flatten the diffs but only merge if the
+ // memory limit was reached
+ var (
+ bottom *diffLayer
+ base *diskLayer
+ )
+ diff.lock.RLock()
+ bottom = diff.flatten().(*diffLayer)
+ if bottom.memory >= aggregatorMemoryLimit {
+ base = diffToDisk(bottom)
+ }
+ diff.lock.RUnlock()
+
+ // If all diff layers were removed, replace the entire snapshot tree
+ if base != nil {
+ t.layers = map[common.Hash]snapshot{base.root: base}
+ return nil
+ }
+ // Merge the new aggregated layer into the snapshot tree, clean stales below
+ t.layers[bottom.root] = bottom
+
+ default:
+ // Many layers requested to be retained, cap normally
+ persisted = t.cap(diff, layers)
+ }
+ // Remove any layer that is stale or links into a stale layer
+ children := make(map[common.Hash][]common.Hash)
+ for root, snap := range t.layers {
+ if diff, ok := snap.(*diffLayer); ok {
+ parent := diff.parent.Root()
+ children[parent] = append(children[parent], root)
+ }
+ }
+ var remove func(root common.Hash)
+ remove = func(root common.Hash) {
+ delete(t.layers, root)
+ for _, child := range children[root] {
+ remove(child)
+ }
+ delete(children, root)
+ }
+ for root, snap := range t.layers {
+ if snap.Stale() {
+ remove(root)
+ }
+ }
+ // If the disk layer was modified, regenerate all the cumulative blooms
+ if persisted != nil {
+ var rebloom func(root common.Hash)
+ rebloom = func(root common.Hash) {
+ if diff, ok := t.layers[root].(*diffLayer); ok {
+ diff.rebloom(persisted)
+ }
+ for _, child := range children[root] {
+ rebloom(child)
+ }
+ }
+ rebloom(persisted.root)
+ }
+ return nil
+}
+
+// cap traverses downwards the diff tree until the number of allowed layers is
+// crossed. All diffs beyond the permitted number are flattened downwards. If the
+// layer limit is reached, memory cap is also enforced (but not before).
+//
+// The method returns the new disk layer if diffs were persisted into it.
+func (t *Tree) cap(diff *diffLayer, layers int) *diskLayer {
+ // Dive until we run out of layers or reach the persistent database
+ for ; layers > 2; layers-- {
+ // If we still have diff layers below, continue down
+ if parent, ok := diff.parent.(*diffLayer); ok {
+ diff = parent
+ } else {
+ // Diff stack too shallow, return without modifications
+ return nil
+ }
+ }
+ // We're out of layers, flatten anything below, stopping if it's the disk or if
+ // the memory limit is not yet exceeded.
+ switch parent := diff.parent.(type) {
+ case *diskLayer:
+ return nil
+
+ case *diffLayer:
+ // Flatten the parent into the grandparent. The flattening internally obtains a
+ // write lock on grandparent.
+ flattened := parent.flatten().(*diffLayer)
+ t.layers[flattened.root] = flattened
+
+ diff.lock.Lock()
+ defer diff.lock.Unlock()
+
+ diff.parent = flattened
+ if flattened.memory < aggregatorMemoryLimit {
+ // Accumulator layer is smaller than the limit, so we can abort, unless
+ // there's a snapshot being generated currently. In that case, the trie
+ // will move from underneath the generator so we **must** merge all the
+ // partial data down into the snapshot and restart the generation.
+ if flattened.parent.(*diskLayer).genAbort == nil {
+ return nil
+ }
+ }
+ default:
+ panic(fmt.Sprintf("unknown data layer: %T", parent))
+ }
+ // If the bottom-most layer is larger than our memory cap, persist to disk
+ bottom := diff.parent.(*diffLayer)
+
+ bottom.lock.RLock()
+ base := diffToDisk(bottom)
+ bottom.lock.RUnlock()
+
+ t.layers[base.root] = base
+ diff.parent = base
+ return base
+}
+
+// diffToDisk merges a bottom-most diff into the persistent disk layer underneath
+// it. The method will panic if called on a non-bottom-most diff layer.
+func diffToDisk(bottom *diffLayer) *diskLayer {
+ var (
+ base = bottom.parent.(*diskLayer)
+ batch = base.diskdb.NewBatch()
+ stats *generatorStats
+ )
+ // If the disk layer is running a snapshot generator, abort it
+ if base.genAbort != nil {
+ abort := make(chan *generatorStats)
+ base.genAbort <- abort
+ stats = <-abort
+ }
+ // Start by temporarily deleting the current snapshot block marker. This
+ // ensures that in the case of a crash, the entire snapshot is invalidated.
+ rawdb.DeleteSnapshotRoot(batch)
+
+ // Mark the original base as stale as we're going to create a new wrapper
+ base.lock.Lock()
+ if base.stale {
+ panic("parent disk layer is stale") // we've committed into the same base from two children, boo
+ }
+ base.stale = true
+ base.lock.Unlock()
+
+ // Destroy all the destructed accounts from the database
+ for hash := range bottom.destructSet {
+ // Skip any account not covered yet by the snapshot
+ if base.genMarker != nil && bytes.Compare(hash[:], base.genMarker) > 0 {
+ continue
+ }
+ // Remove all storage slots
+ rawdb.DeleteAccountSnapshot(batch, hash)
+ base.cache.Set(hash[:], nil)
+
+ it := rawdb.IterateStorageSnapshots(base.diskdb, hash)
+ for it.Next() {
+ if key := it.Key(); len(key) == 65 { // TODO(karalabe): Yuck, we should move this into the iterator
+ batch.Delete(key)
+ base.cache.Del(key[1:])
+
+ snapshotFlushStorageItemMeter.Mark(1)
+ }
+ }
+ it.Release()
+ }
+ // Push all updated accounts into the database
+ for hash, data := range bottom.accountData {
+ // Skip any account not covered yet by the snapshot
+ if base.genMarker != nil && bytes.Compare(hash[:], base.genMarker) > 0 {
+ continue
+ }
+ // Push the account to disk
+ rawdb.WriteAccountSnapshot(batch, hash, data)
+ base.cache.Set(hash[:], data)
+ snapshotCleanAccountWriteMeter.Mark(int64(len(data)))
+
+ if batch.ValueSize() > ethdb.IdealBatchSize {
+ if err := batch.Write(); err != nil {
+ log.Crit("Failed to write account snapshot", "err", err)
+ }
+ batch.Reset()
+ }
+ snapshotFlushAccountItemMeter.Mark(1)
+ snapshotFlushAccountSizeMeter.Mark(int64(len(data)))
+ }
+ // Push all the storage slots into the database
+ for accountHash, storage := range bottom.storageData {
+ // Skip any account not covered yet by the snapshot
+ if base.genMarker != nil && bytes.Compare(accountHash[:], base.genMarker) > 0 {
+ continue
+ }
+ // Generation might be mid-account, track that case too
+ midAccount := base.genMarker != nil && bytes.Equal(accountHash[:], base.genMarker[:common.HashLength])
+
+ for storageHash, data := range storage {
+ // Skip any slot not covered yet by the snapshot
+ if midAccount && bytes.Compare(storageHash[:], base.genMarker[common.HashLength:]) > 0 {
+ continue
+ }
+ if len(data) > 0 {
+ rawdb.WriteStorageSnapshot(batch, accountHash, storageHash, data)
+ base.cache.Set(append(accountHash[:], storageHash[:]...), data)
+ snapshotCleanStorageWriteMeter.Mark(int64(len(data)))
+ } else {
+ rawdb.DeleteStorageSnapshot(batch, accountHash, storageHash)
+ base.cache.Set(append(accountHash[:], storageHash[:]...), nil)
+ }
+ snapshotFlushStorageItemMeter.Mark(1)
+ snapshotFlushStorageSizeMeter.Mark(int64(len(data)))
+ }
+ if batch.ValueSize() > ethdb.IdealBatchSize {
+ if err := batch.Write(); err != nil {
+ log.Crit("Failed to write storage snapshot", "err", err)
+ }
+ batch.Reset()
+ }
+ }
+ // Update the snapshot block marker and write any remainder data
+ rawdb.WriteSnapshotRoot(batch, bottom.root)
+ if err := batch.Write(); err != nil {
+ log.Crit("Failed to write leftover snapshot", "err", err)
+ }
+ res := &diskLayer{
+ root: bottom.root,
+ cache: base.cache,
+ diskdb: base.diskdb,
+ triedb: base.triedb,
+ genMarker: base.genMarker,
+ genPending: base.genPending,
+ }
+ // If snapshot generation hasn't finished yet, port over all the stats and
+ // continue where the previous round left off.
+ //
+ // Note, the `base.genAbort` comparison is not used normally, it's checked
+ // to allow the tests to play with the marker without triggering this path.
+ if base.genMarker != nil && base.genAbort != nil {
+ res.genMarker = base.genMarker
+ res.genAbort = make(chan chan *generatorStats)
+ go res.generate(stats)
+ }
+ return res
+}
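
The recurring batch.ValueSize/Write/Reset dance above is the standard idiom for bounding batch memory; isolated, it looks like the following generic sketch (flushKVs is hypothetical, not part of this change):

```go
package main

import (
	"github.com/tomochain/tomochain/core/rawdb"
	"github.com/tomochain/tomochain/ethdb"
	"github.com/tomochain/tomochain/log"
)

// flushKVs mirrors diffToDisk's batching idiom: queue writes and flush the
// batch whenever it grows past IdealBatchSize, so flattening a huge layer
// never accumulates one giant in-memory batch.
func flushKVs(db ethdb.KeyValueStore, kvs map[string][]byte) {
	batch := db.NewBatch()
	for k, v := range kvs {
		if err := batch.Put([]byte(k), v); err != nil {
			log.Crit("Failed to queue write", "err", err)
		}
		if batch.ValueSize() > ethdb.IdealBatchSize {
			if err := batch.Write(); err != nil {
				log.Crit("Failed to write batch", "err", err)
			}
			batch.Reset()
		}
	}
	// Flush whatever remains below the threshold.
	if err := batch.Write(); err != nil {
		log.Crit("Failed to write final batch", "err", err)
	}
}

func main() {
	flushKVs(rawdb.NewMemoryDatabase(), map[string][]byte{"key": []byte("value")})
}
```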
+
+// Journal commits an entire diff hierarchy to disk into a single journal entry.
+// This is meant to be used during shutdown to persist the snapshot without
+// flattening everything down (bad for reorgs).
+//
+// The method returns the root hash of the base layer that needs to be persisted
+// to disk as a trie too to allow continuing any pending generation op.
+func (t *Tree) Journal(root common.Hash) (common.Hash, error) {
+ // Retrieve the head snapshot to journal from
+ snap := t.Snapshot(root)
+ if snap == nil {
+ return common.Hash{}, fmt.Errorf("snapshot [%#x] missing", root)
+ }
+ // Run the journaling
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ journal := new(bytes.Buffer)
+ base, err := snap.(snapshot).Journal(journal)
+ if err != nil {
+ return common.Hash{}, err
+ }
+ // Store the journal into the database and return
+ rawdb.WriteSnapshotJournal(t.diskdb, journal.Bytes())
+ return base, nil
+}
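
A sketch of the intended shutdown flow (persistOnShutdown is hypothetical): journal the diff hierarchy in one entry rather than flattening it, and remember the returned base root, which the caller must also commit as a trie so a pending generation can resume:

```go
package main

import (
	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/state/snapshot"
	"github.com/tomochain/tomochain/log"
)

// persistOnShutdown journals the whole diff hierarchy in a single entry
// instead of flattening it down, which would be bad for reorgs on restart.
func persistOnShutdown(snaps *snapshot.Tree, headRoot common.Hash) {
	base, err := snaps.Journal(headRoot)
	if err != nil {
		log.Error("Failed to journal snapshot", "err", err)
		return
	}
	// The caller is expected to also persist the trie at `base` so any
	// interrupted generation can continue from it after startup.
	log.Info("Snapshot journaled", "base", base)
}
```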
+
+// Rebuild wipes all available snapshot data from the persistent database and
+// discard all caches and diff layers. Afterwards, it starts a new snapshot
+// generator with the given root hash.
+func (t *Tree) Rebuild(root common.Hash) {
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ // Iterate over and mark all layers stale
+ for _, layer := range t.layers {
+ switch layer := layer.(type) {
+ case *diskLayer:
+ // If the base layer is generating, abort it and save
+ if layer.genAbort != nil {
+ abort := make(chan *generatorStats)
+ layer.genAbort <- abort
+ <-abort
+ }
+ // Layer should be inactive now, mark it as stale
+ layer.lock.Lock()
+ layer.stale = true
+ layer.lock.Unlock()
+
+ case *diffLayer:
+ // If the layer is a simple diff, simply mark as stale
+ layer.lock.Lock()
+ atomic.StoreUint32(&layer.stale, 1)
+ layer.lock.Unlock()
+
+ default:
+ panic(fmt.Sprintf("unknown layer type: %T", layer))
+ }
+ }
+ // Start generating a new snapshot from scratch on a background thread. The
+ // generator will run a wiper first if there's not one running right now.
+ log.Info("Rebuilding state snapshot")
+ t.layers = map[common.Hash]snapshot{
+ root: generateSnapshot(t.diskdb, t.triedb, t.cache, root),
+ }
+}
+
+// AccountIterator creates a new account iterator for the specified root hash and
+// seeks to a starting account hash.
+func (t *Tree) AccountIterator(root common.Hash, seek common.Hash) (AccountIterator, error) {
+ return newFastAccountIterator(t, root, seek)
+}
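
Assuming the AccountIterator interface introduced elsewhere in this change exposes Next/Hash/Account/Error/Release (a hedged sketch; dumpAccounts is hypothetical), a full account walk looks like:

```go
package main

import (
	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/state/snapshot"
)

// dumpAccounts walks every account visible at the given root, merged across
// all diff layers and the disk layer by the fast iterator.
func dumpAccounts(snaps *snapshot.Tree, root common.Hash) (map[common.Hash][]byte, error) {
	it, err := snaps.AccountIterator(root, common.Hash{}) // seek from the zero hash
	if err != nil {
		return nil, err
	}
	defer it.Release()

	accounts := make(map[common.Hash][]byte)
	for it.Next() {
		accounts[it.Hash()] = it.Account() // slim-RLP encoded account blob
	}
	return accounts, it.Error()
}
```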
diff --git a/core/state/snapshot/snapshot_test.go b/core/state/snapshot/snapshot_test.go
new file mode 100644
index 000000000..35fe62c83
--- /dev/null
+++ b/core/state/snapshot/snapshot_test.go
@@ -0,0 +1,350 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "fmt"
+ "math/big"
+ "math/rand"
+ "testing"
+
+ "github.com/VictoriaMetrics/fastcache"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
+ "github.com/tomochain/tomochain/core/types"
+ "github.com/tomochain/tomochain/rlp"
+)
+
+// randomHash generates a random blob of data and returns it as a hash.
+func randomHash() common.Hash {
+ var hash common.Hash
+ if n, err := rand.Read(hash[:]); n != common.HashLength || err != nil {
+ panic(err)
+ }
+ return hash
+}
+
+// randomAccount generates a random account and returns it RLP encoded.
+func randomAccount() []byte {
+ root := randomHash()
+ a := types.SlimAccount{
+ Balance: big.NewInt(rand.Int63()),
+ Nonce: rand.Uint64(),
+ Root: root[:],
+ CodeHash: emptyCode[:],
+ }
+ data, _ := rlp.EncodeToBytes(a)
+ return data
+}
+
+// randomAccountSet generates a set of random accounts with the given strings as
+// the account address hashes.
+func randomAccountSet(hashes ...string) map[common.Hash][]byte {
+ accounts := make(map[common.Hash][]byte)
+ for _, hash := range hashes {
+ accounts[common.HexToHash(hash)] = randomAccount()
+ }
+ return accounts
+}
+
+// Tests that if a disk layer becomes stale, no active external references will
+// be returned with junk data. This version of the test flattens every diff layer
+// to check an internal corner case around the bottom-most memory accumulator.
+func TestDiskLayerExternalInvalidationFullFlatten(t *testing.T) {
+ // Create an empty base layer and a snapshot tree out of it
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: common.HexToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ }
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ base.root: base,
+ },
+ }
+ // Retrieve a reference to the base and commit a diff on top
+ ref := snaps.Snapshot(base.root)
+
+ accounts := map[common.Hash][]byte{
+ common.HexToHash("0xa1"): randomAccount(),
+ }
+ if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
+ t.Fatalf("failed to create a diff layer: %v", err)
+ }
+ if n := len(snaps.layers); n != 2 {
+ t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 2)
+ }
+ // Commit the diff layer onto the disk and ensure it's persisted
+ if err := snaps.Cap(common.HexToHash("0x02"), 0); err != nil {
+ t.Fatalf("failed to merge diff layer onto disk: %v", err)
+ }
+ // Since the base layer was modified, ensure that data retrievals on the external reference fail
+ if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale {
+ t.Errorf("stale reference returned account: %#x (err: %v)", acc, err)
+ }
+ if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale {
+ t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err)
+ }
+ if n := len(snaps.layers); n != 1 {
+ t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 1)
+ fmt.Println(snaps.layers)
+ }
+}
+
+// Tests that if a disk layer becomes stale, no active external references will
+// be returned with junk data. This version of the test retains the bottom diff
+// layer to check the usual mode of operation where the accumulator is retained.
+func TestDiskLayerExternalInvalidationPartialFlatten(t *testing.T) {
+ // Create an empty base layer and a snapshot tree out of it
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: common.HexToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ }
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ base.root: base,
+ },
+ }
+ // Retrieve a reference to the base and commit two diffs on top
+ ref := snaps.Snapshot(base.root)
+
+ accounts := map[common.Hash][]byte{
+ common.HexToHash("0xa1"): randomAccount(),
+ }
+ if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
+ t.Fatalf("failed to create a diff layer: %v", err)
+ }
+ if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil {
+ t.Fatalf("failed to create a diff layer: %v", err)
+ }
+ if n := len(snaps.layers); n != 3 {
+ t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 3)
+ }
+ // Commit the diff layer onto the disk and ensure it's persisted
+ defer func(memcap uint64) { aggregatorMemoryLimit = memcap }(aggregatorMemoryLimit)
+ aggregatorMemoryLimit = 0
+
+ if err := snaps.Cap(common.HexToHash("0x03"), 2); err != nil {
+ t.Fatalf("failed to merge diff layer onto disk: %v", err)
+ }
+ // Since the base layer was modified, ensure that data retrievals on the external reference fail
+ if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale {
+ t.Errorf("stale reference returned account: %#x (err: %v)", acc, err)
+ }
+ if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale {
+ t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err)
+ }
+ if n := len(snaps.layers); n != 2 {
+ t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 2)
+ fmt.Println(snaps.layers)
+ }
+}
+
+// Tests that if a diff layer becomes stale, no active external references will
+// be returned with junk data. This version of the test flattens every diff layer
+// to check an internal corner case around the bottom-most memory accumulator.
+func TestDiffLayerExternalInvalidationFullFlatten(t *testing.T) {
+ // Create an empty base layer and a snapshot tree out of it
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: common.HexToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ }
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ base.root: base,
+ },
+ }
+ // Commit two diffs on top and retrieve a reference to the bottommost
+ accounts := map[common.Hash][]byte{
+ common.HexToHash("0xa1"): randomAccount(),
+ }
+ if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
+ t.Fatalf("failed to create a diff layer: %v", err)
+ }
+ if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil {
+ t.Fatalf("failed to create a diff layer: %v", err)
+ }
+ if n := len(snaps.layers); n != 3 {
+ t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 3)
+ }
+ ref := snaps.Snapshot(common.HexToHash("0x02"))
+
+ // Flatten the diff layer into the bottom accumulator
+ if err := snaps.Cap(common.HexToHash("0x03"), 1); err != nil {
+ t.Fatalf("failed to flatten diff layer into accumulator: %v", err)
+ }
+ // Since the accumulator diff layer was modified, ensure that data retrievals on the external reference fail
+ if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale {
+ t.Errorf("stale reference returned account: %#x (err: %v)", acc, err)
+ }
+ if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale {
+ t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err)
+ }
+ if n := len(snaps.layers); n != 2 {
+ t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 2)
+ fmt.Println(snaps.layers)
+ }
+}
+
+// Tests that if a diff layer becomes stale, no active external references will
+// be returned with junk data. This version of the test retains the bottom diff
+// layer to check the usual mode of operation where the accumulator is retained.
+func TestDiffLayerExternalInvalidationPartialFlatten(t *testing.T) {
+ // Create an empty base layer and a snapshot tree out of it
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: common.HexToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ }
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ base.root: base,
+ },
+ }
+ // Commit three diffs on top and retrieve a reference to the bottommost
+ accounts := map[common.Hash][]byte{
+ common.HexToHash("0xa1"): randomAccount(),
+ }
+ if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
+ t.Fatalf("failed to create a diff layer: %v", err)
+ }
+ if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil {
+ t.Fatalf("failed to create a diff layer: %v", err)
+ }
+ if err := snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, accounts, nil); err != nil {
+ t.Fatalf("failed to create a diff layer: %v", err)
+ }
+ if n := len(snaps.layers); n != 4 {
+ t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 4)
+ }
+ ref := snaps.Snapshot(common.HexToHash("0x02"))
+
+ // Doing a Cap operation with many allowed layers should be a no-op
+ exp := len(snaps.layers)
+ if err := snaps.Cap(common.HexToHash("0x04"), 2000); err != nil {
+ t.Fatalf("failed to flatten diff layer into accumulator: %v", err)
+ }
+ if got := len(snaps.layers); got != exp {
+ t.Errorf("layers modified, got %d exp %d", got, exp)
+ }
+ // Flatten the diff layer into the bottom accumulator
+ if err := snaps.Cap(common.HexToHash("0x04"), 2); err != nil {
+ t.Fatalf("failed to flatten diff layer into accumulator: %v", err)
+ }
+ // Since the accumulator diff layer was modified, ensure that data retrievals on the external reference fail
+ if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale {
+ t.Errorf("stale reference returned account: %#x (err: %v)", acc, err)
+ }
+ if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale {
+ t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err)
+ }
+ if n := len(snaps.layers); n != 3 {
+ t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 3)
+ fmt.Println(snaps.layers)
+ }
+}
+
+// TestPostCapBasicDataAccess tests that accounts in parent and sibling layers
+// remain accessible (or correctly error) after capping/flattening.
+func TestPostCapBasicDataAccess(t *testing.T) {
+ // setAccount is a helper to construct a random account entry and assign it to
+ // an account slot in a snapshot
+ setAccount := func(accKey string) map[common.Hash][]byte {
+ return map[common.Hash][]byte{
+ common.HexToHash(accKey): randomAccount(),
+ }
+ }
+ // Create a starting base layer and a snapshot tree out of it
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: common.HexToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ }
+ snaps := &Tree{
+ layers: map[common.Hash]snapshot{
+ base.root: base,
+ },
+ }
+ // The lowest difflayer
+ snaps.Update(common.HexToHash("0xa1"), common.HexToHash("0x01"), nil, setAccount("0xa1"), nil)
+ snaps.Update(common.HexToHash("0xa2"), common.HexToHash("0xa1"), nil, setAccount("0xa2"), nil)
+ snaps.Update(common.HexToHash("0xb2"), common.HexToHash("0xa1"), nil, setAccount("0xb2"), nil)
+
+ snaps.Update(common.HexToHash("0xa3"), common.HexToHash("0xa2"), nil, setAccount("0xa3"), nil)
+ snaps.Update(common.HexToHash("0xb3"), common.HexToHash("0xb2"), nil, setAccount("0xb3"), nil)
+
+ // checkExist verifies that an account exists in a snapshot
+ checkExist := func(layer *diffLayer, key string) error {
+ if data, _ := layer.Account(common.HexToHash(key)); data == nil {
+ return fmt.Errorf("expected %x to exist, got nil", common.HexToHash(key))
+ }
+ return nil
+ }
+ // shouldErr checks that an account access errors as expected
+ shouldErr := func(layer *diffLayer, key string) error {
+ if data, err := layer.Account(common.HexToHash(key)); err == nil {
+ return fmt.Errorf("expected error, got data %x", data)
+ }
+ return nil
+ }
+ // check basics
+ snap := snaps.Snapshot(common.HexToHash("0xb3")).(*diffLayer)
+
+ if err := checkExist(snap, "0xa1"); err != nil {
+ t.Error(err)
+ }
+ if err := checkExist(snap, "0xb2"); err != nil {
+ t.Error(err)
+ }
+ if err := checkExist(snap, "0xb3"); err != nil {
+ t.Error(err)
+ }
+ // Cap to a bad root should fail
+ if err := snaps.Cap(common.HexToHash("0x1337"), 0); err == nil {
+ t.Errorf("expected error, got none")
+ }
+ // Now, merge the a-chain
+ snaps.Cap(common.HexToHash("0xa3"), 0)
+
+ // At this point, a2 got merged into a1. Thus, a1 is now modified, and as a1 is
+ // the parent of b2, b2 should no longer be able to iterate into parent.
+
+ // These should still be accessible
+ if err := checkExist(snap, "0xb2"); err != nil {
+ t.Error(err)
+ }
+ if err := checkExist(snap, "0xb3"); err != nil {
+ t.Error(err)
+ }
+ // But these would need iteration into the modified parent
+ if err := shouldErr(snap, "0xa1"); err != nil {
+ t.Error(err)
+ }
+ if err := shouldErr(snap, "0xa2"); err != nil {
+ t.Error(err)
+ }
+ if err := shouldErr(snap, "0xa3"); err != nil {
+ t.Error(err)
+ }
+ // Now, merge it again, just for fun. It should now error, since a3
+ // is a disk layer
+ if err := snaps.Cap(common.HexToHash("0xa3"), 0); err == nil {
+ t.Error("expected error capping the disk layer, got none")
+ }
+}
diff --git a/core/state/snapshot/sort.go b/core/state/snapshot/sort.go
new file mode 100644
index 000000000..dc877911a
--- /dev/null
+++ b/core/state/snapshot/sort.go
@@ -0,0 +1,36 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+
+ "github.com/tomochain/tomochain/common"
+)
+
+// hashes is a helper to implement sort.Interface.
+type hashes []common.Hash
+
+// Len is the number of elements in the collection.
+func (hs hashes) Len() int { return len(hs) }
+
+// Less reports whether the element with index i should sort before the element
+// with index j.
+func (hs hashes) Less(i, j int) bool { return bytes.Compare(hs[i][:], hs[j][:]) < 0 }
+
+// Swap swaps the elements with indexes i and j.
+func (hs hashes) Swap(i, j int) { hs[i], hs[j] = hs[j], hs[i] }
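
The helper plugs straight into the standard library; a self-contained sketch (the type is re-declared locally because the original is unexported):

```go
package main

import (
	"bytes"
	"fmt"
	"sort"

	"github.com/tomochain/tomochain/common"
)

// hashes re-declares the unexported helper so the example compiles standalone.
type hashes []common.Hash

func (hs hashes) Len() int           { return len(hs) }
func (hs hashes) Less(i, j int) bool { return bytes.Compare(hs[i][:], hs[j][:]) < 0 }
func (hs hashes) Swap(i, j int)      { hs[i], hs[j] = hs[j], hs[i] }

func main() {
	hs := hashes{
		common.HexToHash("0x03"),
		common.HexToHash("0x01"),
		common.HexToHash("0x02"),
	}
	// Byte-wise ordering gives the deterministic iteration order the
	// snapshot iterators rely on.
	sort.Sort(hs)
	fmt.Println(hs)
}
```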
diff --git a/core/state/state_object.go b/core/state/state_object.go
index b03231e23..b0dc8da14 100644
--- a/core/state/state_object.go
+++ b/core/state/state_object.go
@@ -21,9 +21,12 @@ import (
"fmt"
"io"
"math/big"
+ "time"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/metrics"
"github.com/tomochain/tomochain/rlp"
)
@@ -31,23 +34,23 @@ var emptyCodeHash = crypto.Keccak256(nil)
type Code []byte
-func (self Code) String() string {
- return string(self) //strings.Join(Disassemble(self), " ")
+func (c Code) String() string {
+ return string(c) //strings.Join(Disassemble(c), " ")
}
type Storage map[common.Hash]common.Hash
-func (self Storage) String() (str string) {
- for key, value := range self {
+func (s Storage) String() (str string) {
+ for key, value := range s {
str += fmt.Sprintf("%X : %X\n", key, value)
}
return
}
-func (self Storage) Copy() Storage {
+func (s Storage) Copy() Storage {
cpy := make(Storage)
- for key, value := range self {
+ for key, value := range s {
cpy[key] = value
}
@@ -63,7 +66,7 @@ func (self Storage) Copy() Storage {
type stateObject struct {
address common.Address
addrHash common.Hash // hash of ethereum address of the account
- data Account
+ data types.StateAccount
db *StateDB
// DB error.
@@ -77,17 +80,17 @@ type stateObject struct {
trie Trie // storage trie, which becomes non-nil on first access
code Code // contract bytecode, which gets set when code is loaded
- cachedStorage Storage // Storage entry cache to avoid duplicate reads
- dirtyStorage Storage // Storage entries that need to be flushed to disk
+ originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction
+ pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block
+ dirtyStorage Storage // Storage entries that have been modified in the current transaction execution
+ fakeStorage Storage // Fake storage constructed by the caller for debugging purposes.
// Cache flags.
- // When an object is marked suicided it will be delete from the trie
+ // When an object is marked suicided it will be deleted from the trie
// during the "update" phase of the state transition.
dirtyCode bool // true if the code was updated
suicided bool
- touched bool
deleted bool
- onDirty func(addr common.Address) // Callback method to mark a state object newly dirty
}
// empty returns whether the account is considered empty.
@@ -95,231 +98,329 @@ func (s *stateObject) empty() bool {
return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, emptyCodeHash)
}
-// Account is the Ethereum consensus representation of accounts.
-// These objects are stored in the main account trie.
-type Account struct {
- Nonce uint64
- Balance *big.Int
- Root common.Hash // merkle root of the storage trie
- CodeHash []byte
-}
-
// newObject creates a state object.
-func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *stateObject {
+func newObject(db *StateDB, address common.Address, data *types.StateAccount) *stateObject {
if data.Balance == nil {
data.Balance = new(big.Int)
}
if data.CodeHash == nil {
data.CodeHash = emptyCodeHash
}
+ if data.Root == (common.Hash{}) {
+ data.Root = emptyRoot
+ }
return &stateObject{
- db: db,
- address: address,
- addrHash: crypto.Keccak256Hash(address[:]),
- data: data,
- cachedStorage: make(Storage),
- dirtyStorage: make(Storage),
- onDirty: onDirty,
+ db: db,
+ address: address,
+ addrHash: crypto.Keccak256Hash(address[:]),
+ data: *data,
+ originStorage: make(Storage),
+ pendingStorage: make(Storage),
+ dirtyStorage: make(Storage),
}
}
// EncodeRLP implements rlp.Encoder.
-func (c *stateObject) EncodeRLP(w io.Writer) error {
- return rlp.Encode(w, c.data)
+func (s *stateObject) EncodeRLP(w io.Writer) error {
+ return rlp.Encode(w, s.data)
}
// setError remembers the first non-nil error it is called with.
-func (self *stateObject) setError(err error) {
- if self.dbErr == nil {
- self.dbErr = err
+func (s *stateObject) setError(err error) {
+ if s.dbErr == nil {
+ s.dbErr = err
}
}
-func (self *stateObject) markSuicided() {
- self.suicided = true
- if self.onDirty != nil {
- self.onDirty(self.Address())
- self.onDirty = nil
- }
+func (s *stateObject) markSuicided() {
+ s.suicided = true
}
-func (c *stateObject) touch() {
- c.db.journal = append(c.db.journal, touchChange{
- account: &c.address,
- prev: c.touched,
- prevDirty: c.onDirty == nil,
+func (s *stateObject) touch() {
+ s.db.journal.append(touchChange{
+ account: &s.address,
})
- if c.onDirty != nil {
- c.onDirty(c.Address())
- c.onDirty = nil
+ if s.address == ripemd {
+ // Explicitly put it in the dirty-cache, which is otherwise generated from
+ // flattened journals.
+ s.db.journal.dirty(s.address)
}
- c.touched = true
}
-func (c *stateObject) getTrie(db Database) Trie {
- if c.trie == nil {
+func (s *stateObject) getTrie(db Database) Trie {
+ if s.trie == nil {
var err error
- c.trie, err = db.OpenStorageTrie(c.addrHash, c.data.Root)
+ s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root)
if err != nil {
- c.trie, _ = db.OpenStorageTrie(c.addrHash, common.Hash{})
- c.setError(fmt.Errorf("can't create storage trie: %v", err))
+ s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{})
+ s.setError(fmt.Errorf("can't create storage trie: %v", err))
}
}
- return c.trie
+ return s.trie
}
-func (self *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash {
- value := common.Hash{}
- // Load from DB in case it is missing.
- enc, err := self.getTrie(db).TryGet(key[:])
- if err != nil {
- self.setError(err)
- return common.Hash{}
+// GetState retrieves a value from the account storage trie.
+func (s *stateObject) GetState(db Database, key common.Hash) common.Hash {
+ // If the fake storage is set, only lookup the state here (in the debugging mode)
+ if s.fakeStorage != nil {
+ return s.fakeStorage[key]
}
- if len(enc) > 0 {
- _, content, _, err := rlp.Split(enc)
- if err != nil {
- self.setError(err)
- }
- value.SetBytes(content)
+ // If we have a dirty value for this state entry, return it
+ value, dirty := s.dirtyStorage[key]
+ if dirty {
+ return value
}
- return value
+ // Otherwise return the entry's original value
+ return s.GetCommittedState(db, key)
}
-func (self *stateObject) GetState(db Database, key common.Hash) common.Hash {
- value, exists := self.cachedStorage[key]
- if exists {
+// GetCommittedState retrieves a value from the committed account storage trie.
+func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash {
+ // If the fake storage is set, only lookup the state here (in the debugging mode)
+ if s.fakeStorage != nil {
+ return s.fakeStorage[key]
+ }
+ // If we have a pending write or clean cached, return that
+ if value, pending := s.pendingStorage[key]; pending {
return value
}
- // Load from DB in case it is missing.
- enc, err := self.getTrie(db).TryGet(key[:])
- if err != nil {
- self.setError(err)
- return common.Hash{}
+ if value, cached := s.originStorage[key]; cached {
+ return value
}
- if len(enc) > 0 {
- _, content, _, err := rlp.Split(enc)
- if err != nil {
- self.setError(err)
+ // If no live objects are available, attempt to use snapshots
+ var (
+ enc []byte
+ err error
+ value common.Hash
+ )
+ if s.db.snap != nil {
+ if metrics.EnabledExpensive {
+ defer func(start time.Time) { s.db.SnapshotStorageReads += time.Since(start) }(time.Now())
+ }
+ // If the object was destructed in *this* block (and potentially resurrected),
+ // the storage has been cleared out, and we should *not* consult the previous
+ // snapshot about any storage values. The only possible alternatives are:
+ // 1) resurrect happened, and new slot values were set -- those should
+ // have been handled via pendingStorage above.
+ // 2) we don't have new values, and can deliver empty response back
+ if _, destructed := s.db.snapDestructs[s.addrHash]; destructed {
+ return common.Hash{}
+ }
+ enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key[:]))
+ if len(enc) > 0 {
+ _, content, _, err := rlp.Split(enc)
+ if err != nil {
+ s.setError(err)
+ }
+ value.SetBytes(content)
}
- value.SetBytes(content)
}
- if (value != common.Hash{}) {
- self.cachedStorage[key] = value
+ // If snapshot unavailable or reading from it failed, load from the database
+ if s.db.snap == nil || err != nil {
+ start := time.Now()
+ val, err := s.getTrie(db).GetStorage(s.address, key.Bytes())
+ if metrics.EnabledExpensive {
+ s.db.StorageReads += time.Since(start)
+ }
+ if err != nil {
+ s.setError(err)
+ return common.Hash{}
+ }
+ value.SetBytes(val)
}
+ s.originStorage[key] = value
return value
}
// SetState updates a value in account storage.
-func (self *stateObject) SetState(db Database, key, value common.Hash) {
- self.db.journal = append(self.db.journal, storageChange{
- account: &self.address,
+func (s *stateObject) SetState(db Database, key, value common.Hash) {
+ // If the fake storage is set, put the temporary state update here.
+ if s.fakeStorage != nil {
+ s.fakeStorage[key] = value
+ return
+ }
+ // If the new value is the same as old, don't set
+ prev := s.GetState(db, key)
+ if prev == value {
+ return
+ }
+ // New value is different, update and journal the change
+ s.db.journal.append(storageChange{
+ account: &s.address,
key: key,
- prevalue: self.GetState(db, key),
+ prevalue: prev,
})
- self.setState(key, value)
+ s.setState(key, value)
}
-func (self *stateObject) setState(key, value common.Hash) {
- self.cachedStorage[key] = value
- self.dirtyStorage[key] = value
+// SetStorage replaces the entire state storage with the given one.
+//
+// After this function is called, all original state will be ignored and state
+// lookup only happens in the fake state storage.
+//
+// Note this function should only be used for debugging purposes.
+func (s *stateObject) SetStorage(storage map[common.Hash]common.Hash) {
+ // Allocate fake storage if it's nil.
+ if s.fakeStorage == nil {
+ s.fakeStorage = make(Storage)
+ }
+ for key, value := range storage {
+ s.fakeStorage[key] = value
+ }
+ // Don't bother journaling since this function should only be used for
+ // debugging and the `fake` storage won't be committed to database.
+}
+
+func (s *stateObject) setState(key, value common.Hash) {
+ s.dirtyStorage[key] = value
+}
- if self.onDirty != nil {
- self.onDirty(self.Address())
- self.onDirty = nil
+// finalise moves all dirty storage slots into the pending area to be hashed or
+// committed later. It is invoked at the end of every transaction.
+func (s *stateObject) finalise() {
+ for key, value := range s.dirtyStorage {
+ s.pendingStorage[key] = value
+ }
+ if len(s.dirtyStorage) > 0 {
+ s.dirtyStorage = make(Storage)
}
}
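
To make the new three-tier storage lifecycle concrete, here is a toy model (plain maps, not the real types): dirtyStorage collects writes within a transaction, finalise promotes them to pendingStorage at transaction end, and commit-time writes skip slots whose pending value equals the originStorage baseline:

```go
package main

import "fmt"

// tiers is a toy stand-in for stateObject's origin/pending/dirty storage maps.
type tiers struct {
	origin  map[string]string // committed baseline, used to dedup rewrites
	pending map[string]string // finalised per transaction, flushed per block
	dirty   map[string]string // mutated by the currently executing transaction
}

func (t *tiers) set(k, v string) { t.dirty[k] = v }

// finalise mimics stateObject.finalise: promote dirty slots to pending.
func (t *tiers) finalise() {
	for k, v := range t.dirty {
		t.pending[k] = v
	}
	t.dirty = map[string]string{}
}

// commit mimics updateTrie's skip-noop logic: only genuinely changed slots
// are written out.
func (t *tiers) commit() []string {
	var written []string
	for k, v := range t.pending {
		if t.origin[k] == v {
			continue // noop rewrite, no trie update needed
		}
		t.origin[k] = v
		written = append(written, k)
	}
	t.pending = map[string]string{}
	return written
}

func main() {
	t := &tiers{
		origin:  map[string]string{"a": "1"},
		pending: map[string]string{},
		dirty:   map[string]string{},
	}
	t.set("a", "1") // rewrite of the committed value
	t.set("b", "2") // genuine change
	t.finalise()
	fmt.Println(t.commit()) // prints [b]; the noop write to "a" is skipped
}
```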
// updateTrie writes cached storage modifications into the object's storage trie.
-func (self *stateObject) updateTrie(db Database) Trie {
- tr := self.getTrie(db)
- for key, value := range self.dirtyStorage {
- delete(self.dirtyStorage, key)
- if (value == common.Hash{}) {
- self.setError(tr.TryDelete(key[:]))
+// It will return nil if the trie has not been loaded and no changes have been made
+func (s *stateObject) updateTrie(db Database) Trie {
+ // Make sure all dirty slots are finalized into the pending storage area
+ s.finalise()
+ if len(s.pendingStorage) == 0 {
+ return s.trie
+ }
+ // Track the amount of time wasted on updating the storage trie
+ if metrics.EnabledExpensive {
+ defer func(start time.Time) { s.db.StorageUpdates += time.Since(start) }(time.Now())
+ }
+ // Retrieve the snapshot storage map for the object
+ var storage map[common.Hash][]byte
+ if s.db.snap != nil {
+ // Retrieve the old storage map, if available, create a new one otherwise
+ storage = s.db.snapStorage[s.addrHash]
+ if storage == nil {
+ storage = make(map[common.Hash][]byte)
+ s.db.snapStorage[s.addrHash] = storage
+ }
+ }
+ // Insert all the pending updates into the trie
+ tr := s.getTrie(db)
+ for key, value := range s.pendingStorage {
+ // Skip noop changes, persist actual changes
+ if value == s.originStorage[key] {
continue
}
- // Encoding []byte cannot fail, ok to ignore the error.
- v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00"))
- self.setError(tr.TryUpdate(key[:], v))
+ s.originStorage[key] = value
+
+ var v []byte
+ if (value == common.Hash{}) {
+ s.setError(tr.DeleteStorage(s.address, key.Bytes()))
+ } else {
+ // Encoding []byte cannot fail, ok to ignore the error.
+ v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:]))
+ s.setError(tr.UpdateStorage(s.address, key.Bytes(), v))
+ }
+ // If state snapshotting is active, cache the data til commit
+ if storage != nil {
+ storage[crypto.Keccak256Hash(key[:])] = v // v will be nil if value is 0x00
+ }
+ }
+ if len(s.pendingStorage) > 0 {
+ s.pendingStorage = make(Storage)
}
return tr
}
// UpdateRoot sets the trie root to the current root hash of
-func (self *stateObject) updateRoot(db Database) {
- self.updateTrie(db)
- self.data.Root = self.trie.Hash()
+func (s *stateObject) updateRoot(db Database) {
+ // If nothing changed, don't bother with hashing anything
+ if s.updateTrie(db) == nil {
+ return
+ }
+ // Track the amount of time wasted on hashing the storage trie
+ if metrics.EnabledExpensive {
+ defer func(start time.Time) { s.db.StorageHashes += time.Since(start) }(time.Now())
+ }
+ s.data.Root = s.trie.Hash()
}
-// CommitTrie the storage trie of the object to dwb.
+// CommitTrie writes the storage trie of the object to db.
// This updates the trie root.
-func (self *stateObject) CommitTrie(db Database) error {
- self.updateTrie(db)
- if self.dbErr != nil {
- return self.dbErr
+func (s *stateObject) CommitTrie(db Database) error {
+ // If nothing changed, don't bother with hashing anything
+ if s.updateTrie(db) == nil {
+ return nil
+ }
+ if s.dbErr != nil {
+ return s.dbErr
+ }
+ // Track the amount of time wasted on committing the storage trie
+ if metrics.EnabledExpensive {
+ defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now())
}
- root, err := self.trie.Commit(nil)
+ root, err := s.trie.Commit(nil)
if err == nil {
- self.data.Root = root
+ s.data.Root = root
}
return err
}
// AddBalance removes amount from c's balance.
// It is used to add funds to the destination account of a transfer.
-func (c *stateObject) AddBalance(amount *big.Int) {
+func (s *stateObject) AddBalance(amount *big.Int) {
// EIP158: We must check emptiness for the objects such that the account
// clearing (0,0,0 objects) can take effect.
if amount.Sign() == 0 {
- if c.empty() {
- c.touch()
+ if s.empty() {
+ s.touch()
}
return
}
- c.SetBalance(new(big.Int).Add(c.Balance(), amount))
+ s.SetBalance(new(big.Int).Add(s.Balance(), amount))
}
// SubBalance removes amount from c's balance.
// It is used to remove funds from the origin account of a transfer.
-func (c *stateObject) SubBalance(amount *big.Int) {
+func (s *stateObject) SubBalance(amount *big.Int) {
if amount.Sign() == 0 {
return
}
- c.SetBalance(new(big.Int).Sub(c.Balance(), amount))
+ s.SetBalance(new(big.Int).Sub(s.Balance(), amount))
}
-func (self *stateObject) SetBalance(amount *big.Int) {
- self.db.journal = append(self.db.journal, balanceChange{
- account: &self.address,
- prev: new(big.Int).Set(self.data.Balance),
+func (s *stateObject) SetBalance(amount *big.Int) {
+ s.db.journal.append(balanceChange{
+ account: &s.address,
+ prev: new(big.Int).Set(s.data.Balance),
})
- self.setBalance(amount)
+ s.setBalance(amount)
}
-func (self *stateObject) setBalance(amount *big.Int) {
- self.data.Balance = amount
- if self.onDirty != nil {
- self.onDirty(self.Address())
- self.onDirty = nil
- }
+func (s *stateObject) setBalance(amount *big.Int) {
+ s.data.Balance = amount
}
-// Return the gas back to the origin. Used by the Virtual machine or Closures
-func (c *stateObject) ReturnGas(gas *big.Int) {}
+// ReturnGas returns the gas back to the origin. Used by the Virtual machine or Closures
+func (s *stateObject) ReturnGas(gas *big.Int) {}
-func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject {
- stateObject := newObject(db, self.address, self.data, onDirty)
- if self.trie != nil {
- stateObject.trie = db.db.CopyTrie(self.trie)
+func (s *stateObject) deepCopy(db *StateDB) *stateObject {
+ stateObject := newObject(db, s.address, &s.data)
+ if s.trie != nil {
+ stateObject.trie = db.db.CopyTrie(s.trie)
}
- stateObject.code = self.code
- stateObject.dirtyStorage = self.dirtyStorage.Copy()
- stateObject.cachedStorage = self.dirtyStorage.Copy()
- stateObject.suicided = self.suicided
- stateObject.dirtyCode = self.dirtyCode
- stateObject.deleted = self.deleted
+ stateObject.code = s.code
+ stateObject.dirtyStorage = s.dirtyStorage.Copy()
+ stateObject.originStorage = s.originStorage.Copy()
+ stateObject.pendingStorage = s.pendingStorage.Copy()
+ stateObject.suicided = s.suicided
+ stateObject.dirtyCode = s.dirtyCode
+ stateObject.deleted = s.deleted
return stateObject
}
@@ -327,78 +428,70 @@ func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)
// Attribute accessors
//
-// Returns the address of the contract/account
-func (c *stateObject) Address() common.Address {
- return c.address
+// Address returns the address of the contract/account
+func (s *stateObject) Address() common.Address {
+ return s.address
}
// Code returns the contract code associated with this object, if any.
-func (self *stateObject) Code(db Database) []byte {
- if self.code != nil {
- return self.code
+func (s *stateObject) Code(db Database) []byte {
+ if s.code != nil {
+ return s.code
}
- if bytes.Equal(self.CodeHash(), emptyCodeHash) {
+ if bytes.Equal(s.CodeHash(), emptyCodeHash) {
return nil
}
- code, err := db.ContractCode(self.addrHash, common.BytesToHash(self.CodeHash()))
+ code, err := db.ContractCode(s.addrHash, common.BytesToHash(s.CodeHash()))
if err != nil {
- self.setError(fmt.Errorf("can't load code hash %x: %v", self.CodeHash(), err))
+ s.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err))
}
- self.code = code
+ s.code = code
return code
}
-func (self *stateObject) SetCode(codeHash common.Hash, code []byte) {
- prevcode := self.Code(self.db.db)
- self.db.journal = append(self.db.journal, codeChange{
- account: &self.address,
- prevhash: self.CodeHash(),
+func (s *stateObject) SetCode(codeHash common.Hash, code []byte) {
+ prevcode := s.Code(s.db.db)
+ s.db.journal.append(codeChange{
+ account: &s.address,
+ prevhash: s.CodeHash(),
prevcode: prevcode,
})
- self.setCode(codeHash, code)
+ s.setCode(codeHash, code)
}
-func (self *stateObject) setCode(codeHash common.Hash, code []byte) {
- self.code = code
- self.data.CodeHash = codeHash[:]
- self.dirtyCode = true
- if self.onDirty != nil {
- self.onDirty(self.Address())
- self.onDirty = nil
- }
+func (s *stateObject) setCode(codeHash common.Hash, code []byte) {
+ s.code = code
+ s.data.CodeHash = codeHash[:]
+ s.dirtyCode = true
}
-func (self *stateObject) SetNonce(nonce uint64) {
- self.db.journal = append(self.db.journal, nonceChange{
- account: &self.address,
- prev: self.data.Nonce,
+func (s *stateObject) SetNonce(nonce uint64) {
+ s.db.journal.append(nonceChange{
+ account: &s.address,
+ prev: s.data.Nonce,
})
- self.setNonce(nonce)
+ s.setNonce(nonce)
}
-func (self *stateObject) setNonce(nonce uint64) {
- self.data.Nonce = nonce
- if self.onDirty != nil {
- self.onDirty(self.Address())
- self.onDirty = nil
- }
+func (s *stateObject) setNonce(nonce uint64) {
+ s.data.Nonce = nonce
}
-func (self *stateObject) CodeHash() []byte {
- return self.data.CodeHash
+func (s *stateObject) CodeHash() []byte {
+ return s.data.CodeHash
}
-func (self *stateObject) Balance() *big.Int {
- return self.data.Balance
+func (s *stateObject) Balance() *big.Int {
+ return s.data.Balance
}
-func (self *stateObject) Nonce() uint64 {
- return self.data.Nonce
+func (s *stateObject) Nonce() uint64 {
+ return s.data.Nonce
}
-// Never called, but must be present to allow stateObject to be used
+// Value is never called, but must be present to allow stateObject to be used
// as a vm.Account interface that also satisfies the vm.ContractRef
// interface. Interfaces are awesome.
-func (self *stateObject) Value() *big.Int {
+func (s *stateObject) Value() *big.Int {
panic("Value on stateObject should never be called")
}
diff --git a/core/state/state_test.go b/core/state/state_test.go
index 30cca6c36..17ecf3b19 100644
--- a/core/state/state_test.go
+++ b/core/state/state_test.go
@@ -18,14 +18,16 @@ package state
import (
"bytes"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"testing"
+ checker "gopkg.in/check.v1"
+
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/ethdb"
- checker "gopkg.in/check.v1"
+ "github.com/tomochain/tomochain/trie"
)
type StateSuite struct {
@@ -88,8 +90,9 @@ func (s *StateSuite) TestDump(c *checker.C) {
}
func (s *StateSuite) SetUpTest(c *checker.C) {
- s.db= rawdb.NewMemoryDatabase()
- s.state, _ = New(common.Hash{}, NewDatabase(s.db))
+ s.db = rawdb.NewMemoryDatabase()
+ tdb := NewDatabaseWithConfig(s.db, &trie.Config{Preimages: true})
+ s.state, _ = New(common.Hash{}, tdb, nil)
}
func (s *StateSuite) TestNull(c *checker.C) {
@@ -135,7 +138,7 @@ func (s *StateSuite) TestSnapshotEmpty(c *checker.C) {
// printing/logging in tests (-check.vv does not work)
func TestSnapshot2(t *testing.T) {
db := rawdb.NewMemoryDatabase()
- state, _ := New(common.Hash{}, NewDatabase(db))
+ state, _ := New(common.Hash{}, NewDatabase(db), nil)
stateobjaddr0 := toAddr([]byte("so0"))
stateobjaddr1 := toAddr([]byte("so1"))
@@ -210,24 +213,30 @@ func compareStateObjects(so0, so1 *stateObject, t *testing.T) {
t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code)
}
- if len(so1.cachedStorage) != len(so0.cachedStorage) {
- t.Errorf("Storage size mismatch: have %d, want %d", len(so1.cachedStorage), len(so0.cachedStorage))
+ if len(so1.dirtyStorage) != len(so0.dirtyStorage) {
+ t.Errorf("Dirty storage size mismatch: have %d, want %d", len(so1.dirtyStorage), len(so0.dirtyStorage))
}
- for k, v := range so1.cachedStorage {
- if so0.cachedStorage[k] != v {
- t.Errorf("Storage key %x mismatch: have %v, want %v", k, so0.cachedStorage[k], v)
+ for k, v := range so1.dirtyStorage {
+ if so0.dirtyStorage[k] != v {
+ t.Errorf("Dirty storage key %x mismatch: have %v, want %v", k, so0.dirtyStorage[k], v)
}
}
- for k, v := range so0.cachedStorage {
- if so1.cachedStorage[k] != v {
- t.Errorf("Storage key %x mismatch: have %v, want none.", k, v)
+ for k, v := range so0.dirtyStorage {
+ if so1.dirtyStorage[k] != v {
+ t.Errorf("Dirty storage key %x mismatch: have %v, want none.", k, v)
}
}
-
- if so0.suicided != so1.suicided {
- t.Fatalf("suicided mismatch: have %v, want %v", so0.suicided, so1.suicided)
+ if len(so1.originStorage) != len(so0.originStorage) {
+ t.Errorf("Origin storage size mismatch: have %d, want %d", len(so1.originStorage), len(so0.originStorage))
+ }
+ for k, v := range so1.originStorage {
+ if so0.originStorage[k] != v {
+ t.Errorf("Origin storage key %x mismatch: have %v, want %v", k, so0.originStorage[k], v)
+ }
}
- if so0.deleted != so1.deleted {
- t.Fatalf("Deleted mismatch: have %v, want %v", so0.deleted, so1.deleted)
+ for k, v := range so0.originStorage {
+ if so1.originStorage[k] != v {
+ t.Errorf("Origin storage key %x mismatch: have %v, want none.", k, v)
+ }
}
}
diff --git a/core/state/statedb.go b/core/state/statedb.go
index 7a3357b3e..f264ede6d 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -22,11 +22,13 @@ import (
"math/big"
"sort"
"sync"
+ "time"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/state/snapshot"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/crypto"
- "github.com/tomochain/tomochain/log"
+ "github.com/tomochain/tomochain/metrics"
"github.com/tomochain/tomochain/rlp"
"github.com/tomochain/tomochain/trie"
)
@@ -37,6 +39,9 @@ type revision struct {
}
var (
+ // emptyRoot is the known root hash of an empty trie.
+ emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
+
// emptyState is the known hash of an empty state trie entry.
emptyState = crypto.Keccak256Hash(nil)
@@ -44,7 +49,7 @@ var (
emptyCode = crypto.Keccak256Hash(nil)
)
-// StateDBs within the ethereum protocol are used to store anything
+// StateDB structs within the ethereum protocol are used to store anything
// within the merkle trie. StateDBs take care of caching and storing
// nested states. It's the general query interface to retrieve:
// * Contracts
@@ -53,6 +58,12 @@ type StateDB struct {
db Database
trie Trie
+ snaps *snapshot.Tree
+ snap snapshot.Snapshot
+ snapDestructs map[common.Hash]struct{}
+ snapAccounts map[common.Hash][]byte
+ snapStorage map[common.Hash]map[common.Hash][]byte
+
// This map holds 'live' objects, which will get modified while processing a state transition.
stateObjects map[common.Address]*stateObject
stateObjectsDirty map[common.Address]struct{}
@@ -76,144 +87,179 @@ type StateDB struct {
// Journal of state modifications. This is the backbone of
// Snapshot and RevertToSnapshot.
- journal journal
+ journal *journal
validRevisions []revision
nextRevisionId int
+ // Measurements gathered during execution for debugging purposes
+ AccountReads time.Duration
+ AccountHashes time.Duration
+ AccountUpdates time.Duration
+ AccountCommits time.Duration
+ StorageReads time.Duration
+ StorageHashes time.Duration
+ StorageUpdates time.Duration
+ StorageCommits time.Duration
+ SnapshotAccountReads time.Duration
+ SnapshotStorageReads time.Duration
+ SnapshotCommits time.Duration
+
lock sync.Mutex
}
-func (self *StateDB) SubRefund(gas uint64) {
- self.journal = append(self.journal, refundChange{
- prev: self.refund})
- if gas > self.refund {
- panic(fmt.Sprintf("Refund counter below zero (gas: %d > refund: %d)", gas, self.refund))
+func (s *StateDB) SubRefund(gas uint64) {
+ s.journal.append(refundChange{
+ prev: s.refund})
+ if gas > s.refund {
+ panic(fmt.Sprintf("Refund counter below zero (gas: %d > refund: %d)", gas, s.refund))
}
- self.refund -= gas
+ s.refund -= gas
}
-func (self *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash {
- stateObject := self.getStateObject(addr)
+func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash {
+ stateObject := s.getStateObject(addr)
if stateObject != nil {
- return stateObject.GetCommittedState(self.db, hash)
+ return stateObject.GetCommittedState(s.db, hash)
}
return common.Hash{}
}
-// Create a new state from a given trie.
-func New(root common.Hash, db Database) (*StateDB, error) {
+// New creates a new state from a given trie.
+func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) {
tr, err := db.OpenTrie(root)
if err != nil {
return nil, err
}
- return &StateDB{
+ sdb := &StateDB{
db: db,
trie: tr,
+ snaps: snaps,
stateObjects: make(map[common.Address]*stateObject),
stateObjectsDirty: make(map[common.Address]struct{}),
logs: make(map[common.Hash][]*types.Log),
preimages: make(map[common.Hash][]byte),
- }, nil
+ journal: newJournal(),
+ }
+ if sdb.snaps != nil {
+ if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil {
+ sdb.snapDestructs = make(map[common.Hash]struct{})
+ sdb.snapAccounts = make(map[common.Hash][]byte)
+ sdb.snapStorage = make(map[common.Hash]map[common.Hash][]byte)
+ }
+ }
+ return sdb, nil
}
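
The new third parameter is optional; a nil tree preserves the old trie-only behaviour, as the updated tests show. A minimal sketch:

```go
package main

import (
	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/rawdb"
	"github.com/tomochain/tomochain/core/state"
)

func main() {
	db := rawdb.NewMemoryDatabase()

	// With snaps == nil every read goes through the trie; supplying a
	// *snapshot.Tree containing the root enables the snapshot fast paths.
	statedb, err := state.New(common.Hash{}, state.NewDatabase(db), nil)
	if err != nil {
		panic(err)
	}
	_ = statedb
}
```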
// setError remembers the first non-nil error it is called with.
-func (self *StateDB) setError(err error) {
- if self.dbErr == nil {
- self.dbErr = err
+func (s *StateDB) setError(err error) {
+ if s.dbErr == nil {
+ s.dbErr = err
}
}
-func (self *StateDB) Error() error {
- return self.dbErr
+func (s *StateDB) Error() error {
+ return s.dbErr
}
// Reset clears out all ephemeral state objects from the state db, but keeps
// the underlying state trie to avoid reloading data for the next operations.
-func (self *StateDB) Reset(root common.Hash) error {
- tr, err := self.db.OpenTrie(root)
+func (s *StateDB) Reset(root common.Hash) error {
+ tr, err := s.db.OpenTrie(root)
if err != nil {
return err
}
- self.trie = tr
- self.stateObjects = make(map[common.Address]*stateObject)
- self.stateObjectsDirty = make(map[common.Address]struct{})
- self.thash = common.Hash{}
- self.bhash = common.Hash{}
- self.txIndex = 0
- self.logs = make(map[common.Hash][]*types.Log)
- self.logSize = 0
- self.preimages = make(map[common.Hash][]byte)
- self.clearJournalAndRefund()
+ s.trie = tr
+ s.stateObjects = make(map[common.Address]*stateObject)
+ s.stateObjectsDirty = make(map[common.Address]struct{})
+ s.thash = common.Hash{}
+ s.bhash = common.Hash{}
+ s.txIndex = 0
+ s.logs = make(map[common.Hash][]*types.Log)
+ s.logSize = 0
+ s.preimages = make(map[common.Hash][]byte)
+ s.clearJournalAndRefund()
+
+ if s.snaps != nil {
+ s.snapAccounts, s.snapDestructs, s.snapStorage = nil, nil, nil
+ if s.snap = s.snaps.Snapshot(root); s.snap != nil {
+ s.snapDestructs = make(map[common.Hash]struct{})
+ s.snapAccounts = make(map[common.Hash][]byte)
+ s.snapStorage = make(map[common.Hash]map[common.Hash][]byte)
+ }
+ }
return nil
}
-func (self *StateDB) AddLog(log *types.Log) {
- self.journal = append(self.journal, addLogChange{txhash: self.thash})
+func (s *StateDB) AddLog(log *types.Log) {
+ s.journal.append(addLogChange{txhash: s.thash})
- log.TxHash = self.thash
- log.BlockHash = self.bhash
- log.TxIndex = uint(self.txIndex)
- log.Index = self.logSize
- self.logs[self.thash] = append(self.logs[self.thash], log)
- self.logSize++
+ log.TxHash = s.thash
+ log.BlockHash = s.bhash
+ log.TxIndex = uint(s.txIndex)
+ log.Index = s.logSize
+ s.logs[s.thash] = append(s.logs[s.thash], log)
+ s.logSize++
}
-func (self *StateDB) GetLogs(hash common.Hash) []*types.Log {
- return self.logs[hash]
+func (s *StateDB) GetLogs(hash common.Hash) []*types.Log {
+ return s.logs[hash]
}
-func (self *StateDB) Logs() []*types.Log {
+func (s *StateDB) Logs() []*types.Log {
var logs []*types.Log
- for _, lgs := range self.logs {
+ for _, lgs := range s.logs {
logs = append(logs, lgs...)
}
return logs
}
// AddPreimage records a SHA3 preimage seen by the VM.
-func (self *StateDB) AddPreimage(hash common.Hash, preimage []byte) {
- if _, ok := self.preimages[hash]; !ok {
- self.journal = append(self.journal, addPreimageChange{hash: hash})
+func (s *StateDB) AddPreimage(hash common.Hash, preimage []byte) {
+ if _, ok := s.preimages[hash]; !ok {
+ s.journal.append(addPreimageChange{hash: hash})
pi := make([]byte, len(preimage))
copy(pi, preimage)
- self.preimages[hash] = pi
+ s.preimages[hash] = pi
}
}
// Preimages returns a list of SHA3 preimages that have been submitted.
-func (self *StateDB) Preimages() map[common.Hash][]byte {
- return self.preimages
+func (s *StateDB) Preimages() map[common.Hash][]byte {
+ return s.preimages
}
-func (self *StateDB) AddRefund(gas uint64) {
- self.journal = append(self.journal, refundChange{prev: self.refund})
- self.refund += gas
+func (s *StateDB) AddRefund(gas uint64) {
+ s.journal.append(refundChange{prev: s.refund})
+ s.refund += gas
}
// Exist reports whether the given account address exists in the state.
// Notably this also returns true for suicided accounts.
-func (self *StateDB) Exist(addr common.Address) bool {
- return self.getStateObject(addr) != nil
+func (s *StateDB) Exist(addr common.Address) bool {
+ return s.getStateObject(addr) != nil
}
// Empty returns whether the state object is either non-existent
// or empty according to the EIP161 specification (balance = nonce = code = 0)
-func (self *StateDB) Empty(addr common.Address) bool {
- so := self.getStateObject(addr)
+func (s *StateDB) Empty(addr common.Address) bool {
+ so := s.getStateObject(addr)
return so == nil || so.empty()
}
-// Retrieve the balance from the given address or 0 if object not found
-func (self *StateDB) GetBalance(addr common.Address) *big.Int {
- stateObject := self.getStateObject(addr)
+// GetBalance retrieves the balance from the given address or 0 if object not found
+func (s *StateDB) GetBalance(addr common.Address) *big.Int {
+ stateObject := s.getStateObject(addr)
if stateObject != nil {
return stateObject.Balance()
}
return common.Big0
}
-func (self *StateDB) GetNonce(addr common.Address) uint64 {
- stateObject := self.getStateObject(addr)
+func (s *StateDB) GetNonce(addr common.Address) uint64 {
+ stateObject := s.getStateObject(addr)
if stateObject != nil {
return stateObject.Nonce()
}
@@ -221,63 +267,63 @@ func (self *StateDB) GetNonce(addr common.Address) uint64 {
return 0
}
-func (self *StateDB) GetCode(addr common.Address) []byte {
- stateObject := self.getStateObject(addr)
+func (s *StateDB) GetCode(addr common.Address) []byte {
+ stateObject := s.getStateObject(addr)
if stateObject != nil {
- return stateObject.Code(self.db)
+ return stateObject.Code(s.db)
}
return nil
}
-func (self *StateDB) GetCodeSize(addr common.Address) int {
- stateObject := self.getStateObject(addr)
+func (s *StateDB) GetCodeSize(addr common.Address) int {
+ stateObject := s.getStateObject(addr)
if stateObject == nil {
return 0
}
if stateObject.code != nil {
return len(stateObject.code)
}
- size, err := self.db.ContractCodeSize(stateObject.addrHash, common.BytesToHash(stateObject.CodeHash()))
+ size, err := s.db.ContractCodeSize(stateObject.addrHash, common.BytesToHash(stateObject.CodeHash()))
if err != nil {
- self.setError(err)
+ s.setError(err)
}
return size
}
-func (self *StateDB) GetCodeHash(addr common.Address) common.Hash {
- stateObject := self.getStateObject(addr)
+func (s *StateDB) GetCodeHash(addr common.Address) common.Hash {
+ stateObject := s.getStateObject(addr)
if stateObject == nil {
return common.Hash{}
}
return common.BytesToHash(stateObject.CodeHash())
}
-func (self *StateDB) GetState(addr common.Address, bhash common.Hash) common.Hash {
- stateObject := self.getStateObject(addr)
+func (s *StateDB) GetState(addr common.Address, bhash common.Hash) common.Hash {
+ stateObject := s.getStateObject(addr)
if stateObject != nil {
- return stateObject.GetState(self.db, bhash)
+ return stateObject.GetState(s.db, bhash)
}
return common.Hash{}
}
// Database retrieves the low level database supporting the lower level trie ops.
-func (self *StateDB) Database() Database {
- return self.db
+func (s *StateDB) Database() Database {
+ return s.db
}
// StorageTrie returns the storage trie of an account.
// The return value is a copy and is nil for non-existent accounts.
-func (self *StateDB) StorageTrie(addr common.Address) Trie {
- stateObject := self.getStateObject(addr)
+func (s *StateDB) StorageTrie(addr common.Address) Trie {
+ stateObject := s.getStateObject(addr)
if stateObject == nil {
return nil
}
- cpy := stateObject.deepCopy(self, nil)
- return cpy.updateTrie(self.db)
+ cpy := stateObject.deepCopy(s)
+ return cpy.updateTrie(s.db)
}
-func (self *StateDB) HasSuicided(addr common.Address) bool {
- stateObject := self.getStateObject(addr)
+func (s *StateDB) HasSuicided(addr common.Address) bool {
+ stateObject := s.getStateObject(addr)
if stateObject != nil {
return stateObject.suicided
}
@@ -289,46 +335,46 @@ func (self *StateDB) HasSuicided(addr common.Address) bool {
*/
// AddBalance adds amount to the account associated with addr.
-func (self *StateDB) AddBalance(addr common.Address, amount *big.Int) {
- stateObject := self.GetOrNewStateObject(addr)
+func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) {
+ stateObject := s.GetOrNewStateObject(addr)
if stateObject != nil {
stateObject.AddBalance(amount)
}
}
// SubBalance subtracts amount from the account associated with addr.
-func (self *StateDB) SubBalance(addr common.Address, amount *big.Int) {
- stateObject := self.GetOrNewStateObject(addr)
+func (s *StateDB) SubBalance(addr common.Address, amount *big.Int) {
+ stateObject := s.GetOrNewStateObject(addr)
if stateObject != nil {
stateObject.SubBalance(amount)
}
}
-func (self *StateDB) SetBalance(addr common.Address, amount *big.Int) {
- stateObject := self.GetOrNewStateObject(addr)
+func (s *StateDB) SetBalance(addr common.Address, amount *big.Int) {
+ stateObject := s.GetOrNewStateObject(addr)
if stateObject != nil {
stateObject.SetBalance(amount)
}
}
-func (self *StateDB) SetNonce(addr common.Address, nonce uint64) {
- stateObject := self.GetOrNewStateObject(addr)
+func (s *StateDB) SetNonce(addr common.Address, nonce uint64) {
+ stateObject := s.GetOrNewStateObject(addr)
if stateObject != nil {
stateObject.SetNonce(nonce)
}
}
-func (self *StateDB) SetCode(addr common.Address, code []byte) {
- stateObject := self.GetOrNewStateObject(addr)
+func (s *StateDB) SetCode(addr common.Address, code []byte) {
+ stateObject := s.GetOrNewStateObject(addr)
if stateObject != nil {
stateObject.SetCode(crypto.Keccak256Hash(code), code)
}
}
-func (self *StateDB) SetState(addr common.Address, key, value common.Hash) {
- stateObject := self.GetOrNewStateObject(addr)
+func (s *StateDB) SetState(addr common.Address, key, value common.Hash) {
+ stateObject := s.GetOrNewStateObject(addr)
if stateObject != nil {
- stateObject.SetState(self.db, key, value)
+ stateObject.SetState(s.db, key, value)
}
}
@@ -337,12 +383,12 @@ func (self *StateDB) SetState(addr common.Address, key, value common.Hash) {
//
// The account's state object is still available until the state is committed,
// getStateObject will return a non-nil account after Suicide.
-func (self *StateDB) Suicide(addr common.Address) bool {
- stateObject := self.getStateObject(addr)
+func (s *StateDB) Suicide(addr common.Address) bool {
+ stateObject := s.getStateObject(addr)
if stateObject == nil {
return false
}
- self.journal = append(self.journal, suicideChange{
+ s.journal.append(suicideChange{
account: &addr,
prev: stateObject.suicided,
prevbalance: new(big.Int).Set(stateObject.Balance()),
@@ -358,34 +404,43 @@ func (self *StateDB) Suicide(addr common.Address) bool {
//
// updateStateObject writes the given object to the trie.
-func (self *StateDB) updateStateObject(stateObject *stateObject) {
+func (s *StateDB) updateStateObject(stateObject *stateObject) {
addr := stateObject.Address()
- data, err := rlp.EncodeToBytes(stateObject)
- if err != nil {
- panic(fmt.Errorf("can't encode object at %x: %v", addr[:], err))
+ if err := s.trie.UpdateAccount(addr, &stateObject.data); err != nil {
+ s.setError(fmt.Errorf("updateStateObject (%x) error: %v", addr[:], err))
}
- self.setError(self.trie.TryUpdate(addr[:], data))
+
+ // If state snapshotting is active, cache the data til commit. Note, this
+ // update mechanism is not symmetric to the deletion, because whereas it is
+ // enough to track account updates at commit time, deletions need tracking
+ // at transaction boundary level to ensure we capture state clearing.
+ if s.snap != nil {
+ s.snapAccounts[stateObject.addrHash] = types.SlimAccountRLP(stateObject.data)
+ }
}
// deleteStateObject removes the given object from the state trie.
-func (self *StateDB) deleteStateObject(stateObject *stateObject) {
+func (s *StateDB) deleteStateObject(stateObject *stateObject) {
stateObject.deleted = true
addr := stateObject.Address()
- self.setError(self.trie.TryDelete(addr[:]))
+ if err := s.trie.DeleteAccount(addr); err != nil {
+ s.setError(fmt.Errorf("deleteStateObject (%x) error: %v", addr[:], err))
+ }
}
// DeleteAddress removes the address from the state trie.
-func (self *StateDB) DeleteAddress(addr common.Address) {
- stateObject := self.getStateObject(addr)
+func (s *StateDB) DeleteAddress(addr common.Address) {
+ stateObject := s.getStateObject(addr)
if stateObject != nil && !stateObject.deleted {
- self.deleteStateObject(stateObject)
+ s.deleteStateObject(stateObject)
}
}
// Retrieve a state object given by the address. Returns nil if not found.
-func (self *StateDB) getStateObject(addr common.Address) (stateObject *stateObject) {
+func (s *StateDB) getStateObject(addr common.Address) (stateObject *stateObject) {
// Prefer 'live' objects.
- if obj := self.stateObjects[addr]; obj != nil {
+ if obj := s.stateObjects[addr]; obj != nil {
if obj.deleted {
return nil
}
@@ -393,53 +448,95 @@ func (self *StateDB) getStateObject(addr common.Address) (stateObject *stateObje
}
// Load the object from the database.
- enc, err := self.trie.TryGet(addr[:])
- if len(enc) == 0 {
- self.setError(err)
+ data, err := s.trie.GetAccount(addr)
+ if err != nil {
+ s.setError(fmt.Errorf("getDeleteStateObject (%x) error: %w", addr.Bytes(), err))
return nil
}
- var data Account
- if err := rlp.DecodeBytes(enc, &data); err != nil {
- log.Error("Failed to decode state object", "addr", addr, "err", err)
+ if data == nil {
return nil
}
// Insert into the live set.
- obj := newObject(self, addr, data, self.MarkStateObjectDirty)
- self.setStateObject(obj)
+ obj := newObject(s, addr, data)
+ s.setStateObject(obj)
+ return obj
+}
+
+// getDeletedStateObject is similar to getStateObject, but instead of returning
+// nil for a deleted state object, it returns the actual object with the deleted
+// flag set. This is needed by the state journal to revert to the correct self-
+// destructed object instead of wiping all knowledge about the state object.
+func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject {
+ // Prefer live objects if any is available
+ if obj := s.stateObjects[addr]; obj != nil {
+ return obj
+ }
+ // If no live objects are available, attempt to use snapshots
+ var (
+ data *types.StateAccount
+ err error
+ )
+ if s.snap != nil {
+ if metrics.EnabledExpensive {
+ defer func(start time.Time) { s.SnapshotAccountReads += time.Since(start) }(time.Now())
+ }
+ var acc *types.SlimAccount
+ if acc, err = s.snap.Account(crypto.Keccak256Hash(addr[:])); err == nil {
+ if acc == nil {
+ return nil
+ }
+ // The snapshot stores a slim account; rebuild the full state account.
+ data = &types.StateAccount{
+ Nonce: acc.Nonce,
+ Balance: acc.Balance,
+ CodeHash: acc.CodeHash,
+ }
+ if len(data.CodeHash) == 0 {
+ data.CodeHash = emptyCodeHash
+ }
+ data.Root = common.BytesToHash(acc.Root)
+ if data.Root == (common.Hash{}) {
+ data.Root = emptyRoot
+ }
+ }
+ }
+ // If snapshot unavailable or reading from it failed, load from the database
+ if s.snap == nil || err != nil {
+ if metrics.EnabledExpensive {
+ defer func(start time.Time) { s.AccountReads += time.Since(start) }(time.Now())
+ }
+ data, err = s.trie.GetAccount(addr)
+ if err != nil {
+ s.setError(err)
+ return nil
+ }
+ }
+ // Insert into the live set
+ obj := newObject(s, addr, data)
+ s.setStateObject(obj)
return obj
}
-func (self *StateDB) setStateObject(object *stateObject) {
- self.stateObjects[object.Address()] = object
+func (s *StateDB) setStateObject(object *stateObject) {
+ s.stateObjects[object.Address()] = object
}
-// Retrieve a state object or create a new state object if nil.
-func (self *StateDB) GetOrNewStateObject(addr common.Address) *stateObject {
- stateObject := self.getStateObject(addr)
+// GetOrNewStateObject retrieves a state object or create a new state object if nil.
+func (s *StateDB) GetOrNewStateObject(addr common.Address) *stateObject {
+ stateObject := s.getStateObject(addr)
if stateObject == nil || stateObject.deleted {
- stateObject, _ = self.createObject(addr)
+ stateObject, _ = s.createObject(addr)
}
return stateObject
}
-// MarkStateObjectDirty adds the specified object to the dirty map to avoid costly
-// state object cache iteration to find a handful of modified ones.
-func (self *StateDB) MarkStateObjectDirty(addr common.Address) {
- self.stateObjectsDirty[addr] = struct{}{}
-}
-
// createObject creates a new state object. If there is an existing account with
// the given address, it is overwritten and returned as the second return value.
-func (self *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) {
- prev = self.getStateObject(addr)
- newobj = newObject(self, addr, Account{}, self.MarkStateObjectDirty)
+func (s *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) {
+ prev = s.getStateObject(addr)
+ newobj = newObject(s, addr, &types.StateAccount{})
newobj.setNonce(0) // sets the object to dirty
if prev == nil {
- self.journal = append(self.journal, createObjectChange{account: &addr})
+ s.journal.append(createObjectChange{account: &addr})
} else {
- self.journal = append(self.journal, resetObjectChange{prev: prev})
+ s.journal.append(resetObjectChange{prev: prev})
}
- self.setStateObject(newobj)
+ s.setStateObject(newobj)
return newobj, prev
}
@@ -449,34 +546,43 @@ func (self *StateDB) createObject(addr common.Address) (newobj, prev *stateObjec
// CreateAccount is called during the EVM CREATE operation. The situation might arise that
// a contract does the following:
//
-// 1. sends funds to sha(account ++ (nonce + 1))
-// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1)
+// 1. sends funds to sha(account ++ (nonce + 1))
+// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1)
//
// Carrying over the balance ensures that Ether doesn't disappear.
-func (self *StateDB) CreateAccount(addr common.Address) {
- new, prev := self.createObject(addr)
+func (s *StateDB) CreateAccount(addr common.Address) {
+ new, prev := s.createObject(addr)
if prev != nil {
new.setBalance(prev.data.Balance)
}
}
-func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error {
- so := db.getStateObject(addr)
+func (s *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error {
+ so := s.getStateObject(addr)
if so == nil {
return nil
}
+ tr := so.getTrie(s.db)
+ trieIt := tr.NodeIterator(nil)
+ it := trie.NewIterator(trieIt)
- // When iterating over the storage check the cache first
- for h, value := range so.cachedStorage {
- cb(h, value)
- }
-
- it := trie.NewIterator(so.getTrie(db.db).NodeIterator(nil))
for it.Next() {
- // ignore cached values
- key := common.BytesToHash(db.trie.GetKey(it.Key))
- if _, ok := so.cachedStorage[key]; !ok {
- cb(key, common.BytesToHash(it.Value))
+ key := common.BytesToHash(s.trie.GetKey(it.Key))
+ if value, dirty := so.dirtyStorage[key]; dirty {
+ if !cb(key, value) {
+ return nil
+ }
+ continue
+ }
+
+ if len(it.Value) > 0 {
+ _, content, _, err := rlp.Split(it.Value)
+ if err != nil {
+ return err
+ }
+ if !cb(key, common.BytesToHash(content)) {
+ return nil
+ }
}
}
return nil
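
The rewritten iterator consults dirtyStorage before the committed trie and now honours the callback's boolean return, so a caller can stop the walk early. A small sketch under those semantics (the state and common imports are assumed to be in scope):

    // firstStorageSlots collects at most max slots of addr's storage,
    // relying on a false return value to terminate ForEachStorage early.
    func firstStorageSlots(sdb *state.StateDB, addr common.Address, max int) (map[common.Hash]common.Hash, error) {
        slots := make(map[common.Hash]common.Hash, max)
        err := sdb.ForEachStorage(addr, func(key, value common.Hash) bool {
            slots[key] = value
            return len(slots) < max // false stops the iteration
        })
        return slots, err
    }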
@@ -484,81 +590,91 @@ func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common
// Copy creates a deep, independent copy of the state.
// Snapshots of the copied state cannot be applied to the copy.
-func (self *StateDB) Copy() *StateDB {
- self.lock.Lock()
- defer self.lock.Unlock()
+func (s *StateDB) Copy() *StateDB {
+ s.lock.Lock()
+ defer s.lock.Unlock()
// Copy all the basic fields, initialize the memory ones
state := &StateDB{
- db: self.db,
- trie: self.db.CopyTrie(self.trie),
- stateObjects: make(map[common.Address]*stateObject, len(self.stateObjectsDirty)),
- stateObjectsDirty: make(map[common.Address]struct{}, len(self.stateObjectsDirty)),
- refund: self.refund,
- logs: make(map[common.Hash][]*types.Log, len(self.logs)),
- logSize: self.logSize,
+ db: s.db,
+ trie: s.db.CopyTrie(s.trie),
+ stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)),
+ stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)),
+ refund: s.refund,
+ logs: make(map[common.Hash][]*types.Log, len(s.logs)),
+ logSize: s.logSize,
preimages: make(map[common.Hash][]byte),
+ journal: newJournal(),
}
// Copy the dirty states, logs, and preimages
- for addr := range self.stateObjectsDirty {
- state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state, state.MarkStateObjectDirty)
- state.stateObjectsDirty[addr] = struct{}{}
+ for addr := range s.journal.dirties {
+ // As documented [here](https://github.com/ethereum/go-ethereum/pull/16485#issuecomment-380438527),
+ // and in the Finalise-method, there is a case where an object is in the journal but not
+ // in the stateObjects: OOG after touch on ripeMD prior to Byzantium. Thus, we need to check for
+ // nil
+ if object, exist := s.stateObjects[addr]; exist {
+ // Even though the original object is dirty, we are not copying the journal,
+ // so we need to make sure that any side effect the journal would have caused
+ // during a commit (or similar op) is already applied to the copy.
+ state.stateObjects[addr] = object.deepCopy(state)
+ state.stateObjectsDirty[addr] = struct{}{} // Mark the copy dirty to force internal (code/state) commits
+ }
}
- for hash, logs := range self.logs {
+ for hash, logs := range s.logs {
state.logs[hash] = make([]*types.Log, len(logs))
copy(state.logs[hash], logs)
}
- for hash, preimage := range self.preimages {
+ for hash, preimage := range s.preimages {
state.preimages[hash] = preimage
}
return state
}
// Snapshot returns an identifier for the current revision of the state.
-func (self *StateDB) Snapshot() int {
- id := self.nextRevisionId
- self.nextRevisionId++
- self.validRevisions = append(self.validRevisions, revision{id, len(self.journal)})
+func (s *StateDB) Snapshot() int {
+ id := s.nextRevisionId
+ s.nextRevisionId++
+ s.validRevisions = append(s.validRevisions, revision{id, s.journal.length()})
return id
}
// RevertToSnapshot reverts all state changes made since the given revision.
-func (self *StateDB) RevertToSnapshot(revid int) {
+func (s *StateDB) RevertToSnapshot(revid int) {
// Find the snapshot in the stack of valid snapshots.
- idx := sort.Search(len(self.validRevisions), func(i int) bool {
- return self.validRevisions[i].id >= revid
+ idx := sort.Search(len(s.validRevisions), func(i int) bool {
+ return s.validRevisions[i].id >= revid
})
- if idx == len(self.validRevisions) || self.validRevisions[idx].id != revid {
+ if idx == len(s.validRevisions) || s.validRevisions[idx].id != revid {
panic(fmt.Errorf("revision id %v cannot be reverted", revid))
}
- snapshot := self.validRevisions[idx].journalIndex
-
- // Replay the journal to undo changes.
- for i := len(self.journal) - 1; i >= snapshot; i-- {
- self.journal[i].undo(self)
- }
- self.journal = self.journal[:snapshot]
+ snapshot := s.validRevisions[idx].journalIndex
- // Remove invalidated snapshots from the stack.
- self.validRevisions = self.validRevisions[:idx]
+ // Replay the journal to undo changes and remove invalidated snapshots
+ s.journal.revert(s, snapshot)
+ s.validRevisions = s.validRevisions[:idx]
}
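
Revisions now map to journal lengths rather than indices into the old flat journal slice, but the caller-facing pattern is unchanged. A hedged sketch of the snapshot/revert pair:

    // tentativeCredit applies a balance change and rolls it back when
    // apply is false; RevertToSnapshot replays the journal entries
    // appended after the matching Snapshot call.
    func tentativeCredit(sdb *state.StateDB, addr common.Address, amount *big.Int, apply bool) {
        rev := sdb.Snapshot() // records the current journal length
        sdb.AddBalance(addr, amount)
        if !apply {
            sdb.RevertToSnapshot(rev) // journal.revert undoes AddBalance
        }
    }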
// GetRefund returns the current value of the refund counter.
-func (self *StateDB) GetRefund() uint64 {
- return self.refund
+func (s *StateDB) GetRefund() uint64 {
+ return s.refund
}
// Finalise finalises the state by removing the self destructed objects
// and clears the journal as well as the refunds.
func (s *StateDB) Finalise(deleteEmptyObjects bool) {
- for addr := range s.stateObjectsDirty {
- stateObject := s.stateObjects[addr]
+ for addr := range s.journal.dirties {
+ stateObject, exist := s.stateObjects[addr]
+ if !exist {
+ continue
+ }
+
if stateObject.suicided || (deleteEmptyObjects && stateObject.empty()) {
s.deleteStateObject(stateObject)
} else {
stateObject.updateRoot(s.db)
s.updateStateObject(stateObject)
}
+ s.stateObjectsDirty[addr] = struct{}{}
}
// Invalidate journal because reverting across transactions is not allowed.
s.clearJournalAndRefund()
@@ -574,10 +690,10 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash {
// Prepare sets the current transaction hash and index and block hash which is
// used when the EVM emits new state logs.
-func (self *StateDB) Prepare(thash, bhash common.Hash, ti int) {
- self.thash = thash
- self.bhash = bhash
- self.txIndex = ti
+func (s *StateDB) Prepare(thash, bhash common.Hash, ti int) {
+ s.thash = thash
+ s.bhash = bhash
+ s.txIndex = ti
}
// DeleteSuicides flags the suicided objects for deletion so that it
@@ -602,7 +718,7 @@ func (s *StateDB) DeleteSuicides() {
}
func (s *StateDB) clearJournalAndRefund() {
- s.journal = nil
+ s.journal = newJournal()
s.validRevisions = s.validRevisions[:0]
s.refund = 0
}
@@ -611,6 +727,10 @@ func (s *StateDB) clearJournalAndRefund() {
func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) {
defer s.clearJournalAndRefund()
+ for addr := range s.journal.dirties {
+ s.stateObjectsDirty[addr] = struct{}{}
+ }
+
// Commit objects to the trie.
for addr, stateObject := range s.stateObjects {
_, isDirty := s.stateObjectsDirty[addr]
@@ -636,7 +756,7 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error)
}
// Write trie changes.
root, err = s.trie.Commit(func(leaf []byte, parent common.Hash) error {
- var account Account
+ var account types.StateAccount
if err := rlp.DecodeBytes(leaf, &account); err != nil {
return nil
}
diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go
index 865ee073b..085d8f7c2 100644
--- a/core/state/statedb_test.go
+++ b/core/state/statedb_test.go
@@ -20,7 +20,6 @@ import (
"bytes"
"encoding/binary"
"fmt"
- "github.com/tomochain/tomochain/core/rawdb"
"math"
"math/big"
"math/rand"
@@ -29,6 +28,8 @@ import (
"testing"
"testing/quick"
+ "github.com/tomochain/tomochain/core/rawdb"
+
check "gopkg.in/check.v1"
"github.com/tomochain/tomochain/common"
@@ -40,7 +41,7 @@ import (
func TestUpdateLeaks(t *testing.T) {
// Create an empty state database
db := rawdb.NewMemoryDatabase()
- state, _ := New(common.Hash{}, NewDatabase(db))
+ state, _ := New(common.Hash{}, NewDatabase(db), nil)
// Update it with some accounts
for i := byte(0); i < 255; i++ {
@@ -70,8 +71,8 @@ func TestIntermediateLeaks(t *testing.T) {
// Create two state databases, one transitioning to the final state, the other final from the beginning
transDb := rawdb.NewMemoryDatabase()
finalDb := rawdb.NewMemoryDatabase()
- transState, _ := New(common.Hash{}, NewDatabase(transDb))
- finalState, _ := New(common.Hash{}, NewDatabase(finalDb))
+ transState, _ := New(common.Hash{}, NewDatabase(transDb), nil)
+ finalState, _ := New(common.Hash{}, NewDatabase(finalDb), nil)
modify := func(state *StateDB, addr common.Address, i, tweak byte) {
state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
@@ -129,7 +130,7 @@ func TestIntermediateLeaks(t *testing.T) {
func TestCopy(t *testing.T) {
// Create a random state test to copy and modify "independently"
db := rawdb.NewMemoryDatabase()
- orig, _ := New(common.Hash{}, NewDatabase(db))
+ orig, _ := New(common.Hash{}, NewDatabase(db), nil)
for i := byte(0); i < 255; i++ {
obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
@@ -341,7 +342,7 @@ func (test *snapshotTest) run() bool {
// Run all actions and create snapshots.
var (
db = rawdb.NewMemoryDatabase()
- state, _ = New(common.Hash{}, NewDatabase(db))
+ state, _ = New(common.Hash{}, NewDatabase(db), nil)
snapshotRevs = make([]int, len(test.snapshots))
sindex = 0
)
@@ -355,7 +356,7 @@ func (test *snapshotTest) run() bool {
// Revert all snapshots in reverse order. Each revert must yield a state
// that is equivalent to fresh state with all actions up the snapshot applied.
for sindex--; sindex >= 0; sindex-- {
- checkstate, _ := New(common.Hash{}, state.Database())
+ checkstate, _ := New(common.Hash{}, state.Database(), nil)
for _, action := range test.actions[:test.snapshots[sindex]] {
action.fn(action, checkstate)
}
@@ -415,15 +416,21 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
func (s *StateSuite) TestTouchDelete(c *check.C) {
s.state.GetOrNewStateObject(common.Address{})
root, _ := s.state.Commit(false)
- s.state.Reset(root)
+ s.state, _ = New(root, s.state.db, s.state.snaps)
snapshot := s.state.Snapshot()
s.state.AddBalance(common.Address{}, new(big.Int))
- if len(s.state.stateObjectsDirty) != 1 {
+ if len(s.state.journal.dirties) != 1 {
+ c.Fatal("expected one dirty state object")
+ }
+ if s.state.journal.dirties[common.Address{}] != 1 {
c.Fatal("expected one dirty state object")
}
s.state.RevertToSnapshot(snapshot)
- if len(s.state.stateObjectsDirty) != 0 {
+ if len(s.state.journal.dirties) != 0 {
+ c.Fatal("expected no dirty state object")
+ }
+ if s.state.journal.dirties[common.Address{}] != 0 {
c.Fatal("expected no dirty state object")
}
}
diff --git a/core/state/sync.go b/core/state/sync.go
index 95f29b287..e26281c7d 100644
--- a/core/state/sync.go
+++ b/core/state/sync.go
@@ -20,6 +20,7 @@ import (
"bytes"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/rlp"
"github.com/tomochain/tomochain/trie"
@@ -29,7 +30,7 @@ import (
func NewStateSync(root common.Hash, database ethdb.KeyValueReader, bloom *trie.SyncBloom) *trie.Sync {
var syncer *trie.Sync
callback := func(leaf []byte, parent common.Hash) error {
- var obj Account
+ var obj types.StateAccount
if err := rlp.Decode(bytes.NewReader(leaf), &obj); err != nil {
return err
}
diff --git a/core/state/sync_test.go b/core/state/sync_test.go
index 19fefb654..69c6491f0 100644
--- a/core/state/sync_test.go
+++ b/core/state/sync_test.go
@@ -41,7 +41,7 @@ type testAccount struct {
func makeTestState() (Database, common.Hash, []*testAccount) {
// Create an empty state
db := NewDatabase(rawdb.NewMemoryDatabase())
- state, _ := New(common.Hash{}, db)
+ state, _ := New(common.Hash{}, db, nil)
// Fill it with some arbitrary data
accounts := []*testAccount{}
@@ -72,7 +72,7 @@ func makeTestState() (Database, common.Hash, []*testAccount) {
// account array.
func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accounts []*testAccount) {
// Check root availability and state contents
- state, err := New(root, NewDatabase(db))
+ state, err := New(root, NewDatabase(db), nil)
if err != nil {
t.Fatalf("failed to create state trie at %x: %v", root, err)
}
@@ -113,7 +113,7 @@ func checkStateConsistency(db ethdb.Database, root common.Hash) error {
if _, err := db.Get(root.Bytes()); err != nil {
return nil // Consider a non existent state consistent.
}
- state, err := New(root, NewDatabase(db))
+ state, err := New(root, NewDatabase(db), nil)
if err != nil {
return err
}
diff --git a/core/state_processor.go b/core/state_processor.go
index 035c15f2b..b0aeb04b3 100644
--- a/core/state_processor.go
+++ b/core/state_processor.go
@@ -18,9 +18,6 @@ package core
import (
"fmt"
-
- "github.com/tomochain/tomochain/tomox/tradingstate"
- "github.com/tomochain/tomochain/log"
"math/big"
"runtime"
"strings"
@@ -33,7 +30,9 @@ import (
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
"github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/params"
+ "github.com/tomochain/tomochain/tomox/tradingstate"
)
// StateProcessor is a basic Processor, which takes care of transitioning
@@ -243,7 +242,7 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]*
balanceFee = value
}
}
- msg, err := tx.AsMessage(types.MakeSigner(config, header.Number), balanceFee, header.Number)
+ msg, err := TransactionToMessage(tx, types.MakeSigner(config, header.Number), balanceFee, header.Number)
if err != nil {
return nil, 0, err, false
}
@@ -391,7 +390,7 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]*
blockMap[9147453] = "0x3538a544021c07869c16b764424c5987409cba48"
blockMap[9147459] = "0xe187cf86c2274b1f16e8225a7da9a75aba4f1f5f"
- addrFrom := msg.From().Hex()
+ addrFrom := msg.From.Hex()
currentBlockNumber := header.Number.Int64()
if addr, ok := blockMap[currentBlockNumber]; ok {
@@ -408,7 +407,7 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]*
// End Bypass blacklist address
// Apply the transaction to the current state (included in the env)
- _, gas, failed, err := ApplyMessage(vmenv, msg, gp, coinbaseOwner)
+ result, err := ApplyMessage(vmenv, msg, gp, coinbaseOwner)
if err != nil {
return nil, 0, err, false
@@ -420,24 +419,24 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]*
} else {
root = statedb.IntermediateRoot(config.IsEIP158(header.Number)).Bytes()
}
- *usedGas += gas
+ *usedGas += result.UsedGas
// Create a new receipt for the transaction, storing the intermediate root and gas used by the tx
// based on the eip phase, we're passing whether the root touch-delete accounts.
- receipt := types.NewReceipt(root, failed, *usedGas)
+ receipt := types.NewReceipt(root, result.Failed(), *usedGas)
receipt.TxHash = tx.Hash()
- receipt.GasUsed = gas
+ receipt.GasUsed = result.UsedGas
// if the transaction created a contract, store the creation address in the receipt.
- if msg.To() == nil {
+ if msg.To == nil {
receipt.ContractAddress = crypto.CreateAddress(vmenv.Context.Origin, tx.Nonce())
}
// Set the receipt logs and create a bloom for filtering
receipt.Logs = statedb.GetLogs(tx.Hash())
receipt.Bloom = types.CreateBloom(types.Receipts{receipt})
- if balanceFee != nil && failed {
- state.PayFeeWithTRC21TxFail(statedb, msg.From(), *tx.To())
+ if balanceFee != nil && result.Failed() {
+ state.PayFeeWithTRC21TxFail(statedb, msg.From, *tx.To())
}
- return receipt, gas, err, balanceFee != nil
+ return receipt, result.UsedGas, err, balanceFee != nil
}
func ApplySignTransaction(config *params.ChainConfig, statedb *state.StateDB, header *types.Header, tx *types.Transaction, usedGas *uint64) (*types.Receipt, uint64, error, bool) {
diff --git a/core/state_transition.go b/core/state_transition.go
index 9a2b07924..d6b63f4ba 100644
--- a/core/state_transition.go
+++ b/core/state_transition.go
@@ -18,12 +18,13 @@ package core
import (
"errors"
+ "fmt"
"math"
"math/big"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
- "github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/params"
)
@@ -42,15 +43,17 @@ The state transitioning model does all the necessary work to work out a vali
3) Create a new state object if the recipient is \0*32
4) Value transfer
== If contract creation ==
- 4a) Attempt to run transaction data
- 4b) If valid, use result as code for the new state object
+
+ 4a) Attempt to run transaction data
+ 4b) If valid, use result as code for the new state object
+
== end ==
5) Run Script section
6) Derive new state root
*/
type StateTransition struct {
gp *GasPool
- msg Message
+ msg *Message
gas uint64
gasPrice *big.Int
initialGas uint64
@@ -60,20 +63,56 @@ type StateTransition struct {
evm *vm.EVM
}
-// Message represents a message sent to a contract.
-type Message interface {
- From() common.Address
- //FromFrontier() (common.Address, error)
- To() *common.Address
+// A Message contains the data derived from a single transaction that is relevant to state
+// processing.
+type Message struct {
+ To *common.Address
+ From common.Address
+ Nonce uint64
+ Value *big.Int
+ GasLimit uint64
+ GasPrice *big.Int
+ Data []byte
+ BalanceTokenFee *big.Int
+
+ // When SkipAccountChecks is true, the message nonce is not checked against the
+ // account nonce in state. It also disables checking that the sender is an EOA.
+ // This field will be set to true for operations like RPC eth_call.
+ SkipAccountChecks bool
+}
+
+// ExecutionResult includes all output after executing a given evm
+// message, no matter whether the execution itself succeeded or not.
+type ExecutionResult struct {
+ UsedGas uint64 // Total gas used, including the refunded gas
+ Err error // Any error encountered during the execution (listed in core/vm/errors.go)
+ ReturnData []byte // Returned data from the evm (function result or data supplied with a revert opcode)
+}
+
+// Unwrap returns the internal evm error, which allows for further
+// analysis outside.
+func (result *ExecutionResult) Unwrap() error {
+ return result.Err
+}
- GasPrice() *big.Int
- Gas() uint64
- Value() *big.Int
+// Failed returns whether the execution was successful or not.
+func (result *ExecutionResult) Failed() bool { return result.Err != nil }
- Nonce() uint64
- CheckNonce() bool
- Data() []byte
- BalanceTokenFee() *big.Int
+// Return is a helper function to help the caller distinguish between revert reason
+// and function return. Return returns the data after execution if no error occurs.
+func (result *ExecutionResult) Return() []byte {
+ if result.Err != nil {
+ return nil
+ }
+ return common.CopyBytes(result.ReturnData)
+}
+
+// Revert returns the concrete revert reason if the execution is aborted by `REVERT`
+// opcode. Note the reason can be nil if no data supplied with revert opcode.
+func (result *ExecutionResult) Revert() []byte {
+ if result.Err != vm.ErrExecutionReverted {
+ return nil
+ }
+ return common.CopyBytes(result.ReturnData)
}
// IntrinsicGas computes the 'intrinsic gas' for a message with the given data.
@@ -96,13 +135,13 @@ func IntrinsicGas(data []byte, contractCreation, homestead bool) (uint64, error)
}
// Make sure we don't exceed uint64 for all data combinations
if (math.MaxUint64-gas)/params.TxDataNonZeroGas < nz {
- return 0, vm.ErrOutOfGas
+ return 0, ErrGasUintOverflow
}
gas += nz * params.TxDataNonZeroGas
z := uint64(len(data)) - nz
if (math.MaxUint64-gas)/params.TxDataZeroGas < z {
- return 0, vm.ErrOutOfGas
+ return 0, ErrGasUintOverflow
}
gas += z * params.TxDataZeroGas
}
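
A worked example under the pre-Istanbul gas constants this code appears to use (TxGas = 21000, TxDataNonZeroGas = 68, TxDataZeroGas = 4 are assumptions taken from params at this vintage):

    func exampleIntrinsicGas() uint64 {
        data := make([]byte, 100)
        for i := 0; i < 10; i++ {
            data[i] = 0xff // 10 non-zero bytes, 90 zero bytes
        }
        gas, _ := IntrinsicGas(data, false, true) // plain call, homestead rules
        // gas = TxGas + 10*TxDataNonZeroGas + 90*TxDataZeroGas
        //     = 21000 + 10*68 + 90*4 = 22040
        return gas
    }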
@@ -110,18 +149,42 @@ func IntrinsicGas(data []byte, contractCreation, homestead bool) (uint64, error)
}
// NewStateTransition initialises and returns a new state transition object.
-func NewStateTransition(evm *vm.EVM, msg Message, gp *GasPool) *StateTransition {
+func NewStateTransition(evm *vm.EVM, msg *Message, gp *GasPool) *StateTransition {
return &StateTransition{
gp: gp,
evm: evm,
msg: msg,
- gasPrice: msg.GasPrice(),
- value: msg.Value(),
- data: msg.Data(),
+ gasPrice: msg.GasPrice,
+ value: msg.Value,
+ data: msg.Data,
state: evm.StateDB,
}
}
+// TransactionToMessage converts a transaction into a Message.
+func TransactionToMessage(tx *types.Transaction, s types.Signer, balanceFee *big.Int, number *big.Int) (*Message, error) {
+ msg := &Message{
+ Nonce: tx.Nonce(),
+ GasLimit: tx.Gas(),
+ GasPrice: new(big.Int).Set(tx.GasPrice()),
+ To: tx.To(),
+ Value: tx.Value(),
+ Data: tx.Data(),
+ SkipAccountChecks: false,
+ BalanceTokenFee: balanceFee,
+ }
+ var err error
+ msg.From, err = types.Sender(s, tx)
+ if balanceFee != nil {
+ if number.Cmp(common.TIPTRC21Fee) > 0 {
+ msg.GasPrice = common.TRC21GasPrice
+ } else {
+ msg.GasPrice = common.TRC21GasPriceBefore
+ }
+ }
+ return msg, err
+}
+
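
The ApplyTransaction hunk above already shows the call site; for clarity, a minimal sketch of deriving a Message from a signed transaction (a nil balanceFee leaves the transaction's own gas price in force):

    // toMessage recovers the sender and flattens tx fields into a Message,
    // as ApplyTransaction now does before execution.
    func toMessage(tx *types.Transaction, config *params.ChainConfig, header *types.Header) (*Message, error) {
        signer := types.MakeSigner(config, header.Number)
        return TransactionToMessage(tx, signer, nil, header.Number)
    }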
// ApplyMessage computes the new state by applying the given message
// against the old state within the environment.
//
@@ -129,12 +192,12 @@ func NewStateTransition(evm *vm.EVM, msg Message, gp *GasPool) *StateTransition
// the gas used (which includes gas refunds) and an error if it failed. An error always
// indicates a core error meaning that the message would always fail for that particular
// state and would never be accepted within a block.
-func ApplyMessage(evm *vm.EVM, msg Message, gp *GasPool, owner common.Address) ([]byte, uint64, bool, error) {
+func ApplyMessage(evm *vm.EVM, msg *Message, gp *GasPool, owner common.Address) (*ExecutionResult, error) {
return NewStateTransition(evm, msg, gp).TransitionDb(owner)
}
func (st *StateTransition) from() vm.AccountRef {
- f := st.msg.From()
+ f := st.msg.From
if !st.state.Exist(f) {
st.state.CreateAccount(f)
}
@@ -142,14 +205,14 @@ func (st *StateTransition) from() vm.AccountRef {
}
func (st *StateTransition) balanceTokenFee() *big.Int {
- return st.msg.BalanceTokenFee()
+ return st.msg.BalanceTokenFee
}
func (st *StateTransition) to() vm.AccountRef {
if st.msg == nil {
return vm.AccountRef{}
}
- to := st.msg.To()
+ to := st.msg.To
if to == nil {
return vm.AccountRef{} // contract creation
}
@@ -161,22 +224,13 @@ func (st *StateTransition) to() vm.AccountRef {
return reference
}
-func (st *StateTransition) useGas(amount uint64) error {
- if st.gas < amount {
- return vm.ErrOutOfGas
- }
- st.gas -= amount
-
- return nil
-}
-
func (st *StateTransition) buyGas() error {
var (
state = st.state
balanceTokenFee = st.balanceTokenFee()
from = st.from()
)
- mgval := new(big.Int).Mul(new(big.Int).SetUint64(st.msg.Gas()), st.gasPrice)
+ mgval := new(big.Int).Mul(new(big.Int).SetUint64(st.msg.GasLimit), st.gasPrice)
if balanceTokenFee == nil {
if state.GetBalance(from.Address()).Cmp(mgval) < 0 {
return errInsufficientBalanceForGas
@@ -184,12 +238,12 @@ func (st *StateTransition) buyGas() error {
} else if balanceTokenFee.Cmp(mgval) < 0 {
return errInsufficientBalanceForGas
}
- if err := st.gp.SubGas(st.msg.Gas()); err != nil {
+ if err := st.gp.SubGas(st.msg.GasLimit); err != nil {
return err
}
- st.gas += st.msg.Gas()
+ st.gas += st.msg.GasLimit
- st.initialGas = st.msg.Gas()
+ st.initialGas = st.msg.GasLimit
if balanceTokenFee == nil {
state.SubBalance(from.Address(), mgval)
}
@@ -197,72 +251,95 @@ func (st *StateTransition) buyGas() error {
}
func (st *StateTransition) preCheck() error {
+ // Only check transactions that are not fake
msg := st.msg
- sender := st.from()
-
- // Make sure this transaction's nonce is correct
- if msg.CheckNonce() {
- nonce := st.state.GetNonce(sender.Address())
- if nonce < msg.Nonce() {
- return ErrNonceTooHigh
- } else if nonce > msg.Nonce() {
- return ErrNonceTooLow
+ if !msg.SkipAccountChecks {
+ // Make sure this transaction's nonce is correct.
+ stNonce := st.state.GetNonce(msg.From)
+ if msgNonce := msg.Nonce; stNonce < msgNonce {
+ return fmt.Errorf("%w: address %v, tx: %d state: %d", ErrNonceTooHigh,
+ msg.From.Hex(), msgNonce, stNonce)
+ } else if stNonce > msgNonce {
+ return fmt.Errorf("%w: address %v, tx: %d state: %d", ErrNonceTooLow,
+ msg.From.Hex(), msgNonce, stNonce)
+ } else if stNonce+1 < stNonce {
+ return fmt.Errorf("%w: address %v, nonce: %d", ErrNonceMax,
+ msg.From.Hex(), stNonce)
+ }
+ // Make sure the sender is an EOA
+ codeHash := st.state.GetCodeHash(msg.From)
+ if codeHash != (common.Hash{}) && codeHash != types.EmptyCodeHash {
+ return fmt.Errorf("%w: address %v, codehash: %s", ErrSenderNoEOA,
+ msg.From.Hex(), codeHash)
}
}
+
return st.buyGas()
}
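
Per the Message comment earlier in this diff, simulated calls can opt out of these checks. A hedged sketch of an eth_call-style message (the field values are illustrative, not from this change):

    // simulationMessage builds a message that preCheck will not subject
    // to nonce or EOA validation, mirroring RPC eth_call behaviour.
    func simulationMessage(from common.Address, to *common.Address, input []byte) *Message {
        return &Message{
            From:              from,
            To:                to,
            GasLimit:          1000000,
            GasPrice:          big.NewInt(0),
            Value:             new(big.Int),
            Data:              input,
            SkipAccountChecks: true,
        }
    }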
// TransitionDb will transition the state by applying the current message and
-// returning the result including the the used gas. It returns an error if it
-// failed. An error indicates a consensus issue.
-func (st *StateTransition) TransitionDb(owner common.Address) (ret []byte, usedGas uint64, failed bool, err error) {
- if err = st.preCheck(); err != nil {
- return
+// returning the evm execution result with the following fields.
+//
+// - used gas:
+// total gas used (including gas being refunded)
+// - returndata:
+// the returned data from evm
+// - concrete execution error:
+// various **EVM** errors which abort the execution,
+// e.g. ErrOutOfGas, ErrExecutionReverted
+//
+// However, if any consensus issue is encountered, the error is returned
+// directly with a nil evm execution result.
+func (st *StateTransition) TransitionDb(owner common.Address) (*ExecutionResult, error) {
+ // First check this message satisfies all consensus rules before
+ // applying the message. The rules include these clauses
+ //
+ // 1. the nonce of the message caller is correct
+ // 2. caller has enough balance to cover transaction fee (gaslimit * gasprice)
+ // 3. the amount of gas required is available in the block
+ // 4. the purchased gas is enough to cover intrinsic usage
+ // 5. there is no overflow when calculating intrinsic gas
+ // 6. caller has enough balance to cover asset transfer for **topmost** call
+
+ // Check clauses 1-3, buy gas if everything is correct
+ if err := st.preCheck(); err != nil {
+ return nil, err
}
msg := st.msg
sender := st.from() // err checked in preCheck
homestead := st.evm.ChainConfig().IsHomestead(st.evm.BlockNumber)
- contractCreation := msg.To() == nil
+ contractCreation := msg.To == nil
- // Pay intrinsic gas
+ // Check clauses 4-5, subtract intrinsic gas if everything is correct
gas, err := IntrinsicGas(st.data, contractCreation, homestead)
if err != nil {
- return nil, 0, false, err
+ return nil, err
}
- if err = st.useGas(gas); err != nil {
- return nil, 0, false, err
+ if st.gas < gas {
+ return nil, ErrIntrinsicGas
+ }
+ st.gas -= gas
+
+ // check clause 6
+ if msg.Value.Sign() > 0 && !st.evm.CanTransfer(st.state, msg.From, msg.Value) {
+ return nil, ErrInsufficientFundsForTransfer
}
var (
- evm = st.evm
- // vm errors do not effect consensus and are therefor
- // not assigned to err, except for insufficient balance
- // error.
+ ret []byte
vmerr error
)
// for debugging purpose
// TODO: clean it after fixing the issue https://github.com/tomochain/tomochain/issues/401
- var contractAction string
nonce := uint64(1)
if contractCreation {
- ret, _, st.gas, vmerr = evm.Create(sender, st.data, st.gas, st.value)
- contractAction = "contract creation"
+ ret, _, st.gas, vmerr = st.evm.Create(sender, st.data, st.gas, st.value)
} else {
// Increment the nonce for the next transaction
nonce = st.state.GetNonce(sender.Address()) + 1
st.state.SetNonce(sender.Address(), nonce)
- ret, st.gas, vmerr = evm.Call(sender, st.to().Address(), st.data, st.gas, st.value)
- contractAction = "contract call"
- }
- if vmerr != nil {
- log.Debug("VM returned with error", "action", contractAction, "contract address", st.to().Address(), "gas", st.gas, "gasPrice", st.gasPrice, "nonce", nonce, "err", vmerr)
- // The only possible consensus-error would be if there wasn't
- // sufficient balance to make the transfer happen. The first
- // balance transfer may never fail.
- if vmerr == vm.ErrInsufficientBalance {
- return nil, 0, false, vmerr
- }
+ ret, st.gas, vmerr = st.evm.Call(sender, st.to().Address(), st.data, st.gas, st.value)
}
st.refundGas()
@@ -274,7 +351,11 @@ func (st *StateTransition) TransitionDb(owner common.Address) (ret []byte, usedG
st.state.AddBalance(st.evm.Coinbase, new(big.Int).Mul(new(big.Int).SetUint64(st.gasUsed()), st.gasPrice))
}
- return ret, st.gasUsed(), vmerr != nil, err
+ return &ExecutionResult{
+ UsedGas: st.gasUsed(),
+ Err: vmerr,
+ ReturnData: ret,
+ }, err
}
func (st *StateTransition) refundGas() {
diff --git a/core/token_validator.go b/core/token_validator.go
index 485ff05c5..61324f272 100644
--- a/core/token_validator.go
+++ b/core/token_validator.go
@@ -17,7 +17,11 @@ package core
import (
"fmt"
- ethereum "github.com/tomochain/tomochain"
+ "math/big"
+ "math/rand"
+ "strings"
+
+ tomochain "github.com/tomochain/tomochain"
"github.com/tomochain/tomochain/accounts/abi"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus"
@@ -25,9 +29,6 @@ import (
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/vm"
"github.com/tomochain/tomochain/log"
- "math/big"
- "math/rand"
- "strings"
)
const (
@@ -38,7 +39,7 @@ const (
// callmsg implements core.Message to allow passing it as a transaction simulator.
type callmsg struct {
- ethereum.CallMsg
+ tomochain.CallMsg
}
func (m callmsg) From() common.Address { return m.CallMsg.From }
@@ -52,7 +53,7 @@ func (m callmsg) Data() []byte { return m.CallMsg.Data }
func (m callmsg) BalanceTokenFee() *big.Int { return m.CallMsg.BalanceTokenFee }
type SimulatedBackend interface {
- CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error)
+ CallContractWithState(call tomochain.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error)
}
// GetTokenAbi return token abi
@@ -72,22 +73,22 @@ func RunContract(chain consensus.ChainContext, statedb *state.StateDB, contractA
}
fakeCaller := common.HexToAddress("0x0000000000000000000000000000000000000001")
statedb.SetBalance(fakeCaller, common.BasePrice)
- msg := ethereum.CallMsg{To: &contractAddr, Data: input, From: fakeCaller}
+ msg := tomochain.CallMsg{To: &contractAddr, Data: input, From: fakeCaller}
result, err := CallContractWithState(msg, chain, statedb)
if err != nil {
return nil, err
}
var unpackResult interface{}
- err = abi.Unpack(&unpackResult, method, result)
+ err = abi.UnpackIntoInterface(&unpackResult, method, result)
if err != nil {
return nil, err
}
return unpackResult, nil
}
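
The two-argument abi.Unpack is gone; decoding into a Go value now goes through UnpackIntoInterface. An illustrative call (the method name and output type here are assumptions, not from this diff):

    // decodeBalance unpacks a single big.Int return value.
    func decodeBalance(parsed abi.ABI, returnData []byte) (*big.Int, error) {
        var balance *big.Int
        if err := parsed.UnpackIntoInterface(&balance, "balanceOf", returnData); err != nil {
            return nil, err
        }
        return balance, nil
    }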
-//FIXME: please use copyState for this function
+// FIXME: please use copyState for this function
// CallContractWithState executes a contract call at the given state.
-func CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) {
+func CallContractWithState(call tomochain.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) {
// Ensure message is initialized properly.
call.GasPrice = big.NewInt(0)
@@ -98,11 +99,19 @@ func CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext,
call.Value = new(big.Int)
}
// Execute the call.
- msg := callmsg{call}
+ msg := &Message{
+ To: call.To,
+ From: call.From,
+ Value: call.Value,
+ GasLimit: call.Gas,
+ GasPrice: call.GasPrice,
+ Data: call.Data,
+ SkipAccountChecks: false,
+ }
feeCapacity := state.GetTRC21FeeCapacityFromState(statedb)
- if msg.To() != nil {
- if value, ok := feeCapacity[*msg.To()]; ok {
- msg.CallMsg.BalanceTokenFee = value
+ if msg.To != nil {
+ if value, ok := feeCapacity[*msg.To]; ok {
+ msg.BalanceTokenFee = value
}
}
evmContext := NewEVMContext(msg, chain.CurrentHeader(), chain, nil)
@@ -111,11 +120,11 @@ func CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext,
vmenv := vm.NewEVM(evmContext, statedb, nil, chain.Config(), vm.Config{})
gaspool := new(GasPool).AddGas(1000000)
owner := common.Address{}
- rval, _, _, err := NewStateTransition(vmenv, msg, gaspool).TransitionDb(owner)
+ result, err := NewStateTransition(vmenv, msg, gaspool).TransitionDb(owner)
if err != nil {
return nil, err
}
- return rval, err
+ return result.Return(), err
}
// make sure that balance of token is at slot 0
diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go
index 4c6a31111..adfd81314 100644
--- a/core/tx_pool_test.go
+++ b/core/tx_pool_test.go
@@ -19,8 +19,6 @@ package core
import (
"crypto/ecdsa"
"fmt"
- "github.com/tomochain/tomochain/consensus"
- "github.com/tomochain/tomochain/core/rawdb"
"io/ioutil"
"math/big"
"math/rand"
@@ -29,11 +27,14 @@ import (
"time"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/consensus"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/event"
"github.com/tomochain/tomochain/params"
+ "github.com/tomochain/tomochain/trie"
)
// testTxPoolConfig is a transaction pool configuration without stateful disk
@@ -70,7 +71,7 @@ func (bc *testBlockChain) Config() *params.ChainConfig {
func (bc *testBlockChain) CurrentBlock() *types.Block {
return types.NewBlock(&types.Header{
GasLimit: bc.gasLimit,
- }, nil, nil, nil)
+ }, nil, nil, nil, new(trie.StackTrie))
}
func (bc *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block {
@@ -96,7 +97,7 @@ func pricedTransaction(nonce uint64, gaslimit uint64, gasprice *big.Int, key *ec
func setupTxPool() (*TxPool, *ecdsa.PrivateKey) {
diskdb := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(diskdb))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(diskdb), nil)
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
key, _ := crypto.GenerateKey()
@@ -176,7 +177,7 @@ func (c *testChain) State() (*state.StateDB, error) {
stdb := c.statedb
if *c.trigger {
db := rawdb.NewMemoryDatabase()
- c.statedb, _ = state.New(common.Hash{}, state.NewDatabase(db))
+ c.statedb, _ = state.New(common.Hash{}, state.NewDatabase(db), nil)
// simulate that the new head block included tx0 and tx1
c.statedb.SetNonce(c.address, 2)
c.statedb.SetBalance(c.address, new(big.Int).SetUint64(params.Ether))
@@ -195,7 +196,7 @@ func TestStateChangeDuringTransactionPoolReset(t *testing.T) {
db = rawdb.NewMemoryDatabase()
key, _ = crypto.GenerateKey()
address = crypto.PubkeyToAddress(key.PublicKey)
- statedb, _ = state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ = state.New(common.Hash{}, state.NewDatabase(db), nil)
trigger = false
)
@@ -355,7 +356,7 @@ func TestTransactionChainFork(t *testing.T) {
addr := crypto.PubkeyToAddress(key.PublicKey)
resetState := func() {
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
statedb.AddBalance(addr, big.NewInt(100000000000000))
pool.chain = &testBlockChain{statedb, 1000000, new(event.Feed)}
@@ -385,7 +386,7 @@ func TestTransactionDoubleNonce(t *testing.T) {
addr := crypto.PubkeyToAddress(key.PublicKey)
resetState := func() {
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
statedb.AddBalance(addr, big.NewInt(100000000000000))
pool.chain = &testBlockChain{statedb, 1000000, new(event.Feed)}
@@ -576,7 +577,7 @@ func TestTransactionPostponing(t *testing.T) {
// Create the pool to test the postponing with
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain)
@@ -792,7 +793,7 @@ func testTransactionQueueGlobalLimiting(t *testing.T, nolocals bool) {
// Create the pool to test the limit enforcement with
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
config := testTxPoolConfig
@@ -872,8 +873,10 @@ func testTransactionQueueGlobalLimiting(t *testing.T, nolocals bool) {
//
// This logic should not hold for local transactions, unless the local tracking
// mechanism is disabled.
-func TestTransactionQueueTimeLimiting(t *testing.T) { testTransactionQueueTimeLimiting(t, false) }
-func TestTransactionQueueTimeLimitingNoLocals(t *testing.T) { testTransactionQueueTimeLimiting(t, true) }
+func TestTransactionQueueTimeLimiting(t *testing.T) { testTransactionQueueTimeLimiting(t, false) }
+func TestTransactionQueueTimeLimitingNoLocals(t *testing.T) {
+ testTransactionQueueTimeLimiting(t, true)
+}
func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) {
common.MinGasPrice = big.NewInt(0)
@@ -883,7 +886,7 @@ func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) {
// Create the pool to test the non-expiration enforcement
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
config := testTxPoolConfig
@@ -981,8 +984,10 @@ func TestTransactionPendingLimiting(t *testing.T) {
// Tests that the transaction limits are enforced the same way irrelevant whether
// the transactions are added one by one or in batches.
-func TestTransactionQueueLimitingEquivalency(t *testing.T) { testTransactionLimitingEquivalency(t, 1) }
-func TestTransactionPendingLimitingEquivalency(t *testing.T) { testTransactionLimitingEquivalency(t, 0) }
+func TestTransactionQueueLimitingEquivalency(t *testing.T) { testTransactionLimitingEquivalency(t, 1) }
+func TestTransactionPendingLimitingEquivalency(t *testing.T) {
+ testTransactionLimitingEquivalency(t, 0)
+}
func testTransactionLimitingEquivalency(t *testing.T, origin uint64) {
t.Parallel()
@@ -1038,7 +1043,7 @@ func TestTransactionPendingGlobalLimiting(t *testing.T) {
// Create the pool to test the limit enforcement with
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
config := testTxPoolConfig
@@ -1085,7 +1090,7 @@ func TestTransactionCapClearsFromAll(t *testing.T) {
// Create the pool to test the limit enforcement with
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
config := testTxPoolConfig
@@ -1120,7 +1125,7 @@ func TestTransactionPendingMinimumAllowance(t *testing.T) {
// Create the pool to test the limit enforcement with
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
config := testTxPoolConfig
@@ -1170,7 +1175,7 @@ func TestTransactionPoolRepricing(t *testing.T) {
// Create the pool to test the pricing enforcement with
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain)
@@ -1292,7 +1297,7 @@ func TestTransactionPoolRepricingKeepsLocals(t *testing.T) {
// Create the pool to test the pricing enforcement with
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain)
@@ -1355,7 +1360,7 @@ func TestTransactionPoolUnderpricing(t *testing.T) {
// Create the pool to test the pricing enforcement with
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
config := testTxPoolConfig
@@ -1457,7 +1462,7 @@ func TestTransactionReplacement(t *testing.T) {
// Create the pool to test the pricing enforcement with
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain)
@@ -1552,7 +1557,7 @@ func testTransactionJournaling(t *testing.T, nolocals bool) {
// Create the original pool to inject transaction into the journal
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
config := testTxPoolConfig
@@ -1651,7 +1656,7 @@ func TestTransactionStatusCheck(t *testing.T) {
// Create the pool to test the status retrievals with
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain)
diff --git a/core/types/block.go b/core/types/block.go
index a055ced14..66baecf2c 100644
--- a/core/types/block.go
+++ b/core/types/block.go
@@ -33,11 +33,6 @@ import (
"github.com/tomochain/tomochain/rlp"
)
-var (
- EmptyRootHash = DeriveSha(Transactions{})
- EmptyUncleHash = CalcUncleHash(nil)
-)
-
// A BlockNonce is a 64-bit hash which proves (combined with the
// mix-hash) that a sufficient amount of computation has been carried
// out on a block.
@@ -225,14 +220,14 @@ type storageblock struct {
// The values of TxHash, UncleHash, ReceiptHash and Bloom in header
// are ignored and set to values derived from the given txs, uncles
// and receipts.
-func NewBlock(header *Header, txs []*Transaction, uncles []*Header, receipts []*Receipt) *Block {
+func NewBlock(header *Header, txs []*Transaction, uncles []*Header, receipts []*Receipt, hasher Hasher) *Block {
b := &Block{header: CopyHeader(header), td: new(big.Int)}
// TODO: panic if len(txs) != len(receipts)
if len(txs) == 0 {
b.header.TxHash = EmptyRootHash
} else {
- b.header.TxHash = DeriveSha(Transactions(txs))
+ b.header.TxHash = DeriveSha(Transactions(txs), hasher)
b.transactions = make(Transactions, len(txs))
copy(b.transactions, txs)
}
@@ -240,7 +235,7 @@ func NewBlock(header *Header, txs []*Transaction, uncles []*Header, receipts []*
if len(receipts) == 0 {
b.header.ReceiptHash = EmptyRootHash
} else {
- b.header.ReceiptHash = DeriveSha(Receipts(receipts))
+ b.header.ReceiptHash = DeriveSha(Receipts(receipts), hasher)
b.header.Bloom = CreateBloom(receipts)
}
diff --git a/core/types/block_test.go b/core/types/block_test.go
index 9b78b653c..e93ae02de 100644
--- a/core/types/block_test.go
+++ b/core/types/block_test.go
@@ -17,13 +17,15 @@
package types
import (
+ "bytes"
+ "hash"
"math/big"
+ "reflect"
"testing"
- "bytes"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/rlp"
- "reflect"
+ "golang.org/x/crypto/sha3"
)
// from bcValidBlockTest.json, "SimpleTx"
@@ -59,3 +61,39 @@ func TestBlockEncoding(t *testing.T) {
t.Errorf("encoded block mismatch:\ngot: %x\nwant: %x", ourBlockEnc, blockEnc)
}
}
+
+func TestUncleHash(t *testing.T) {
+ uncles := make([]*Header, 0)
+ h := CalcUncleHash(uncles)
+ exp := common.HexToHash("1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347")
+ if h != exp {
+ t.Fatalf("empty uncle hash is wrong, got %x != %x", h, exp)
+ }
+}
+
+var benchBuffer = bytes.NewBuffer(make([]byte, 0, 32000))
+
+// testHasher is a helper for transaction/receipt list hashing. The
+// production hasher is a trie; tests use this stand-in instead, in order
+// to avoid an import cycle with the trie package.
+type testHasher struct {
+ hasher hash.Hash
+}
+
+func newHasher() *testHasher {
+ return &testHasher{hasher: sha3.NewLegacyKeccak256()}
+}
+
+func (h *testHasher) Reset() {
+ h.hasher.Reset()
+}
+
+func (h *testHasher) Update(key, val []byte) error {
+ h.hasher.Write(key)
+ h.hasher.Write(val)
+ return nil
+}
+
+func (h *testHasher) Hash() common.Hash {
+ return common.BytesToHash(h.hasher.Sum(nil))
+}
diff --git a/core/types/derive_sha.go b/core/types/derive_sha.go
index 2731c39cb..210ee26e0 100644
--- a/core/types/derive_sha.go
+++ b/core/types/derive_sha.go
@@ -21,21 +21,58 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/rlp"
- "github.com/tomochain/tomochain/trie"
)
+// DerivableList is the input to DeriveSha. It is implemented by list types such as Transactions and Receipts.
type DerivableList interface {
Len() int
- GetRlp(i int) []byte
+ EncodeIndex(int, *bytes.Buffer)
}
-func DeriveSha(list DerivableList) common.Hash {
- keybuf := new(bytes.Buffer)
- trie := new(trie.Trie)
- for i := 0; i < list.Len(); i++ {
- keybuf.Reset()
- rlp.Encode(keybuf, uint(i))
- trie.Update(keybuf.Bytes(), list.GetRlp(i))
+// Hasher is the tool used to calculate the hash of a derivable list.
+type Hasher interface {
+ Reset()
+ Update([]byte, []byte) error
+ Hash() common.Hash
+}
+
+func encodeForDerive(list DerivableList, i int, buf *bytes.Buffer) []byte {
+ buf.Reset()
+ list.EncodeIndex(i, buf)
+ // It's really unfortunate that we need to perform this copy.
+ // StackTrie holds onto the values until Hash is called, so the values
+ // written to it must not alias.
+ return common.CopyBytes(buf.Bytes())
+}
+
+// DeriveSha creates the tree hashes of transactions and receipts in a block header.
+func DeriveSha(list DerivableList, hasher Hasher) common.Hash {
+ hasher.Reset()
+
+ valueBuf := encodeBufferPool.Get().(*bytes.Buffer)
+ defer encodeBufferPool.Put(valueBuf)
+
+ // StackTrie requires values to be inserted in increasing hash order, which is not the
+ // order that `list` provides hashes in. This insertion sequence ensures that the
+ // order is correct.
+ //
+ // The error returned by hasher is omitted because hasher will produce an incorrect
+ // hash in case any error occurs.
+ var indexBuf []byte
+ for i := 1; i < list.Len() && i <= 0x7f; i++ {
+ indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i))
+ value := encodeForDerive(list, i, valueBuf)
+ hasher.Update(indexBuf, value)
+ }
+ if list.Len() > 0 {
+ indexBuf = rlp.AppendUint64(indexBuf[:0], 0)
+ value := encodeForDerive(list, 0, valueBuf)
+ hasher.Update(indexBuf, value)
+ }
+ for i := 0x80; i < list.Len(); i++ {
+ indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i))
+ value := encodeForDerive(list, i, valueBuf)
+ hasher.Update(indexBuf, value)
}
- return trie.Hash()
+ return hasher.Hash()
}
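
Note on the insertion order above: RLP encodes the integers 1..0x7f as a single byte equal to the value, 0 as 0x80, and anything from 0x80 up as a length-prefixed string (0x81 0x80, 0x81 0x81, ...). Inserting indices 1..0x7f first, then 0, then 0x80 onwards therefore feeds the stack trie its keys in strictly increasing byte order. A minimal standalone sketch (using only the repository's rlp package) that prints the encoded keys:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/rlp"
)

// Prints the RLP-encoded trie keys for a few list indices, showing why
// DeriveSha inserts 1..0x7f first, then 0, then 0x80 onwards.
func main() {
	for _, i := range []uint64{0, 1, 0x7f, 0x80, 0x81} {
		fmt.Printf("index %#x -> key %x\n", i, rlp.AppendUint64(nil, i))
	}
	// index 0x0 -> key 80
	// index 0x1 -> key 01
	// index 0x7f -> key 7f
	// index 0x80 -> key 8180
	// index 0x81 -> key 8181
}
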
diff --git a/core/types/gen_header_rlp.go b/core/types/gen_header_rlp.go
new file mode 100644
index 000000000..1422cf6b1
--- /dev/null
+++ b/core/types/gen_header_rlp.go
@@ -0,0 +1,58 @@
+// Code generated by rlpgen. DO NOT EDIT.
+
+//go:build !norlpgen
+// +build !norlpgen
+
+package types
+
+import (
+ "io"
+
+ "github.com/tomochain/tomochain/rlp"
+)
+
+func (obj *Header) EncodeRLP(_w io.Writer) error {
+ w := rlp.NewEncoderBuffer(_w)
+ _tmp0 := w.List()
+ w.WriteBytes(obj.ParentHash[:])
+ w.WriteBytes(obj.UncleHash[:])
+ w.WriteBytes(obj.Coinbase[:])
+ w.WriteBytes(obj.Root[:])
+ w.WriteBytes(obj.TxHash[:])
+ w.WriteBytes(obj.ReceiptHash[:])
+ w.WriteBytes(obj.Bloom[:])
+ if obj.Difficulty == nil {
+ w.Write(rlp.EmptyString)
+ } else {
+ if obj.Difficulty.Sign() == -1 {
+ return rlp.ErrNegativeBigInt
+ }
+ w.WriteBigInt(obj.Difficulty)
+ }
+ if obj.Number == nil {
+ w.Write(rlp.EmptyString)
+ } else {
+ if obj.Number.Sign() == -1 {
+ return rlp.ErrNegativeBigInt
+ }
+ w.WriteBigInt(obj.Number)
+ }
+ w.WriteUint64(obj.GasLimit)
+ w.WriteUint64(obj.GasUsed)
+ if obj.Time == nil {
+ w.Write(rlp.EmptyString)
+ } else {
+ if obj.Time.Sign() == -1 {
+ return rlp.ErrNegativeBigInt
+ }
+ w.WriteBigInt(obj.Time)
+ }
+ w.WriteBytes(obj.Extra)
+ w.WriteBytes(obj.MixDigest[:])
+ w.WriteBytes(obj.Nonce[:])
+ w.WriteBytes(obj.Validators)
+ w.WriteBytes(obj.Validator)
+ w.WriteBytes(obj.Penalties)
+ w.ListEnd(_tmp0)
+ return w.Flush()
+}
diff --git a/core/types/gen_log_json.go b/core/types/gen_log_json.go
index 759ff8814..ae61caf6b 100644
--- a/core/types/gen_log_json.go
+++ b/core/types/gen_log_json.go
@@ -12,6 +12,7 @@ import (
var _ = (*logMarshaling)(nil)
+// MarshalJSON marshals as JSON.
func (l Log) MarshalJSON() ([]byte, error) {
type Log struct {
Address common.Address `json:"address" gencodec:"required"`
@@ -37,6 +38,7 @@ func (l Log) MarshalJSON() ([]byte, error) {
return json.Marshal(&enc)
}
+// UnmarshalJSON unmarshals from JSON.
func (l *Log) UnmarshalJSON(input []byte) error {
type Log struct {
Address *common.Address `json:"address" gencodec:"required"`
diff --git a/core/types/gen_receipt_json.go b/core/types/gen_receipt_json.go
index ffc851f2d..03494c8a6 100644
--- a/core/types/gen_receipt_json.go
+++ b/core/types/gen_receipt_json.go
@@ -5,6 +5,7 @@ package types
import (
"encoding/json"
"errors"
+ "math/big"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/hexutil"
@@ -12,39 +13,50 @@ import (
var _ = (*receiptMarshaling)(nil)
+// MarshalJSON marshals as JSON.
func (r Receipt) MarshalJSON() ([]byte, error) {
type Receipt struct {
PostState hexutil.Bytes `json:"root"`
- Status hexutil.Uint `json:"status"`
+ Status hexutil.Uint64 `json:"status"`
CumulativeGasUsed hexutil.Uint64 `json:"cumulativeGasUsed" gencodec:"required"`
Bloom Bloom `json:"logsBloom" gencodec:"required"`
Logs []*Log `json:"logs" gencodec:"required"`
TxHash common.Hash `json:"transactionHash" gencodec:"required"`
ContractAddress common.Address `json:"contractAddress"`
GasUsed hexutil.Uint64 `json:"gasUsed" gencodec:"required"`
+ BlockHash common.Hash `json:"blockHash,omitempty"`
+ BlockNumber *hexutil.Big `json:"blockNumber,omitempty"`
+ TransactionIndex hexutil.Uint `json:"transactionIndex"`
}
var enc Receipt
enc.PostState = r.PostState
- enc.Status = hexutil.Uint(r.Status)
+ enc.Status = hexutil.Uint64(r.Status)
enc.CumulativeGasUsed = hexutil.Uint64(r.CumulativeGasUsed)
enc.Bloom = r.Bloom
enc.Logs = r.Logs
enc.TxHash = r.TxHash
enc.ContractAddress = r.ContractAddress
enc.GasUsed = hexutil.Uint64(r.GasUsed)
+ enc.BlockHash = r.BlockHash
+ enc.BlockNumber = (*hexutil.Big)(r.BlockNumber)
+ enc.TransactionIndex = hexutil.Uint(r.TransactionIndex)
return json.Marshal(&enc)
}
+// UnmarshalJSON unmarshals from JSON.
func (r *Receipt) UnmarshalJSON(input []byte) error {
type Receipt struct {
PostState *hexutil.Bytes `json:"root"`
- Status *hexutil.Uint `json:"status"`
+ Status *hexutil.Uint64 `json:"status"`
CumulativeGasUsed *hexutil.Uint64 `json:"cumulativeGasUsed" gencodec:"required"`
Bloom *Bloom `json:"logsBloom" gencodec:"required"`
Logs []*Log `json:"logs" gencodec:"required"`
TxHash *common.Hash `json:"transactionHash" gencodec:"required"`
ContractAddress *common.Address `json:"contractAddress"`
GasUsed *hexutil.Uint64 `json:"gasUsed" gencodec:"required"`
+ BlockHash *common.Hash `json:"blockHash,omitempty"`
+ BlockNumber *hexutil.Big `json:"blockNumber,omitempty"`
+ TransactionIndex *hexutil.Uint `json:"transactionIndex"`
}
var dec Receipt
if err := json.Unmarshal(input, &dec); err != nil {
@@ -54,7 +66,7 @@ func (r *Receipt) UnmarshalJSON(input []byte) error {
r.PostState = *dec.PostState
}
if dec.Status != nil {
- r.Status = uint(*dec.Status)
+ r.Status = uint64(*dec.Status)
}
if dec.CumulativeGasUsed == nil {
return errors.New("missing required field 'cumulativeGasUsed' for Receipt")
@@ -79,5 +91,14 @@ func (r *Receipt) UnmarshalJSON(input []byte) error {
return errors.New("missing required field 'gasUsed' for Receipt")
}
r.GasUsed = uint64(*dec.GasUsed)
+ if dec.BlockHash != nil {
+ r.BlockHash = *dec.BlockHash
+ }
+ if dec.BlockNumber != nil {
+ r.BlockNumber = (*big.Int)(dec.BlockNumber)
+ }
+ if dec.TransactionIndex != nil {
+ r.TransactionIndex = uint(*dec.TransactionIndex)
+ }
return nil
}
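
For illustration, a small sketch (assuming the Receipt type as modified in this PR) showing that the regenerated marshaller now emits the inclusion fields next to the consensus fields:

package main

import (
	"encoding/json"
	"fmt"
	"math/big"

	"github.com/tomochain/tomochain/core/types"
)

// Marshals a minimal receipt; the JSON output now carries blockNumber and
// transactionIndex (and blockHash when set) alongside the consensus fields.
func main() {
	r := types.Receipt{
		Status:            types.ReceiptStatusSuccessful,
		CumulativeGasUsed: 21000,
		GasUsed:           21000,
		Logs:              []*types.Log{},
		BlockNumber:       big.NewInt(42),
	}
	out, err := json.Marshal(&r)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}
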
diff --git a/core/types/hashes.go b/core/types/hashes.go
new file mode 100644
index 000000000..35fc6dc9f
--- /dev/null
+++ b/core/types/hashes.go
@@ -0,0 +1,39 @@
+// Copyright 2023 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package types
+
+import (
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/crypto"
+)
+
+var (
+ // EmptyRootHash is the known root hash of an empty trie.
+ EmptyRootHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
+
+ // EmptyUncleHash is the known hash of the empty uncle set.
+ EmptyUncleHash = rlpHash([]*Header(nil)) // 1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347
+
+ // EmptyCodeHash is the known hash of the empty EVM bytecode.
+ EmptyCodeHash = crypto.Keccak256Hash(nil) // c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470
+
+ // EmptyTxsHash is the known hash of the empty transaction set.
+ EmptyTxsHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
+
+ // EmptyReceiptsHash is the known hash of the empty receipt set.
+ EmptyReceiptsHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
+)
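
The shared 0x56e81f17... constant is not arbitrary: it is keccak256 of the RLP encoding of the empty string (a single 0x80 byte), i.e. the root of an empty trie, which is why EmptyRootHash, EmptyTxsHash and EmptyReceiptsHash coincide. A quick standalone check against the packages in this PR:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/core/types"
	"github.com/tomochain/tomochain/crypto"
)

// Recomputes the empty trie root: keccak256(rlp("")) = keccak256(0x80).
func main() {
	h := crypto.Keccak256Hash([]byte{0x80})
	fmt.Println(h == types.EmptyRootHash)     // true
	fmt.Println(h == types.EmptyTxsHash)      // true
	fmt.Println(h == types.EmptyReceiptsHash) // true
}
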
diff --git a/core/types/hashing.go b/core/types/hashing.go
new file mode 100644
index 000000000..8b9cb92b9
--- /dev/null
+++ b/core/types/hashing.go
@@ -0,0 +1,11 @@
+package types
+
+import (
+ "bytes"
+ "sync"
+)
+
+// encodeBufferPool holds temporary encoder buffers for DeriveSha and TX encoding.
+var encodeBufferPool = sync.Pool{
+ New: func() interface{} { return new(bytes.Buffer) },
+}
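
The pool matters because DeriveSha and transaction encoding run hot; the idiom is get, reset, use, copy out, put back. A standalone sketch of the same pattern (hypothetical names, not part of the PR):

package main

import (
	"bytes"
	"fmt"
	"sync"
)

// bufPool mirrors encodeBufferPool: buffers are recycled across calls.
var bufPool = sync.Pool{New: func() interface{} { return new(bytes.Buffer) }}

// encodeWithPool writes payload through a pooled buffer. The result is
// copied out because the buffer's backing array is reused after Put.
func encodeWithPool(payload []byte) []byte {
	buf := bufPool.Get().(*bytes.Buffer)
	defer bufPool.Put(buf)
	buf.Reset()
	buf.Write(payload)
	out := make([]byte, buf.Len())
	copy(out, buf.Bytes())
	return out
}

func main() {
	fmt.Printf("%x\n", encodeWithPool([]byte{0xde, 0xad, 0xbe, 0xef}))
}
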
diff --git a/core/types/hashing_test.go b/core/types/hashing_test.go
new file mode 100644
index 000000000..d2f2781a6
--- /dev/null
+++ b/core/types/hashing_test.go
@@ -0,0 +1,79 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package types_test
+
+import (
+ "math/big"
+ "testing"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
+ "github.com/tomochain/tomochain/core/types"
+ "github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/trie"
+)
+
+func BenchmarkDeriveSha200(b *testing.B) {
+ txs, err := genTxs(200)
+ if err != nil {
+ b.Fatal(err)
+ }
+ var exp common.Hash
+ var got common.Hash
+ b.Run("std_trie", func(b *testing.B) {
+ b.ResetTimer()
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase()))
+ exp = types.DeriveSha(txs, tr)
+ }
+ })
+
+ b.Run("stack_trie", func(b *testing.B) {
+ b.ResetTimer()
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ got = types.DeriveSha(txs, trie.NewStackTrie(nil))
+ }
+ })
+ if got != exp {
+ b.Errorf("got %x exp %x", got, exp)
+ }
+}
+
+func genTxs(num uint64) (types.Transactions, error) {
+ key, err := crypto.HexToECDSA("deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef")
+ if err != nil {
+ return nil, err
+ }
+ var addr = crypto.PubkeyToAddress(key.PublicKey)
+ newTx := func(i uint64) (*types.Transaction, error) {
+ signer := types.NewEIP155Signer(big.NewInt(18))
+ utx := types.NewTransaction(i, addr, new(big.Int), 0, new(big.Int).SetUint64(10000000), nil)
+ tx, err := types.SignTx(utx, signer, key)
+ return tx, err
+ }
+ var txs types.Transactions
+ for i := uint64(0); i < num; i++ {
+ tx, err := newTx(i)
+ if err != nil {
+ return nil, err
+ }
+ txs = append(txs, tx)
+ }
+ return txs, nil
+}
diff --git a/core/types/lending_transaction.go b/core/types/lending_transaction.go
index e33826829..715246111 100644
--- a/core/types/lending_transaction.go
+++ b/core/types/lending_transaction.go
@@ -17,6 +17,7 @@
package types
import (
+ "bytes"
"container/heap"
"errors"
"io"
@@ -319,10 +320,12 @@ func (s LendingTransactions) Len() int { return len(s) }
// Swap swaps the i'th and the j'th element in s.
func (s LendingTransactions) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
-// GetRlp implements Rlpable and returns the i'th element of s in rlp.
-func (s LendingTransactions) GetRlp(i int) []byte {
- enc, _ := rlp.EncodeToBytes(s[i])
- return enc
+// EncodeIndex encodes the i'th lending transaction to w. Note that this does not check
+// for errors because we assume that LendingTransactions will only ever contain valid
+// txs that were either constructed by decoding or via the public API in this package.
+func (s LendingTransactions) EncodeIndex(i int, w *bytes.Buffer) {
+ tx := s[i]
+ rlp.Encode(w, tx.data)
}
// LendingTxDifference returns a new set t which is the difference between a to b.
@@ -363,7 +366,7 @@ func (s *LendingTxByNonce) Pop() interface{} {
return x
}
-//LendingTransactionByNonce sort transaction by nonce
+// LendingTransactionByNonce sorts transactions by nonce
type LendingTransactionByNonce struct {
txs map[common.Address]LendingTransactions
heads LendingTxByNonce
diff --git a/core/types/log.go b/core/types/log.go
index af8e515ea..93567b1e6 100644
--- a/core/types/log.go
+++ b/core/types/log.go
@@ -25,7 +25,7 @@ import (
"github.com/tomochain/tomochain/rlp"
)
-//go:generate gencodec -type Log -field-override logMarshaling -out gen_log_json.go
+//go:generate go run github.com/fjl/gencodec -type Log -field-override logMarshaling -out gen_log_json.go
// Log represents a contract log event. These events are generated by the LOG opcode and
// stored/indexed by the node.
@@ -63,6 +63,9 @@ type logMarshaling struct {
Index hexutil.Uint
}
+//go:generate go run ../../rlp/rlpgen -type rlpLog -out gen_log_rlp.go
+
+// rlpLog is used to RLP-encode both the consensus and storage formats.
type rlpLog struct {
Address common.Address
Topics []common.Hash
diff --git a/core/types/order_transaction.go b/core/types/order_transaction.go
index d51884e3f..e7150b991 100644
--- a/core/types/order_transaction.go
+++ b/core/types/order_transaction.go
@@ -17,6 +17,7 @@
package types
import (
+ "bytes"
"container/heap"
"errors"
"io"
@@ -250,10 +251,12 @@ func (s OrderTransactions) Len() int { return len(s) }
// Swap swaps the i'th and the j'th element in s.
func (s OrderTransactions) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
-// GetRlp implements Rlpable and returns the i'th element of s in rlp.
-func (s OrderTransactions) GetRlp(i int) []byte {
- enc, _ := rlp.EncodeToBytes(s[i])
- return enc
+// EncodeIndex encodes the i'th order transaction to w. Note that this does not check
+// for errors because we assume that OrderTransactions will only ever contain valid
+// txs that were either constructed by decoding or via the public API in this package.
+func (s OrderTransactions) EncodeIndex(i int, w *bytes.Buffer) {
+ tx := s[i]
+ rlp.Encode(w, tx.data)
}
// OrderTxDifference returns a new set t which is the difference between a to b.
diff --git a/core/types/receipt.go b/core/types/receipt.go
index 3c55c1224..3235af79c 100644
--- a/core/types/receipt.go
+++ b/core/types/receipt.go
@@ -18,16 +18,20 @@ package types
import (
"bytes"
+ "errors"
"fmt"
"io"
+ "math/big"
"unsafe"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/hexutil"
+ "github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/rlp"
)
-//go:generate gencodec -type Receipt -field-override receiptMarshaling -out gen_receipt_json.go
+//go:generate go run github.com/fjl/gencodec -type Receipt -field-override receiptMarshaling -out gen_receipt_json.go
var (
receiptStatusFailedRLP = []byte{}
@@ -36,17 +40,17 @@ var (
const (
// ReceiptStatusFailed is the status code of a transaction if execution failed.
- ReceiptStatusFailed = uint(0)
+ ReceiptStatusFailed = uint64(0)
// ReceiptStatusSuccessful is the status code of a transaction if execution succeeded.
- ReceiptStatusSuccessful = uint(1)
+ ReceiptStatusSuccessful = uint64(1)
)
// Receipt represents the results of a transaction.
type Receipt struct {
// Consensus fields
PostState []byte `json:"root"`
- Status uint `json:"status"`
+ Status uint64 `json:"status"`
CumulativeGasUsed uint64 `json:"cumulativeGasUsed" gencodec:"required"`
Bloom Bloom `json:"logsBloom" gencodec:"required"`
Logs []*Log `json:"logs" gencodec:"required"`
@@ -55,13 +59,21 @@ type Receipt struct {
TxHash common.Hash `json:"transactionHash" gencodec:"required"`
ContractAddress common.Address `json:"contractAddress"`
GasUsed uint64 `json:"gasUsed" gencodec:"required"`
+
+ // Inclusion information: These fields provide information about the inclusion of the
+ // transaction corresponding to this receipt.
+ BlockHash common.Hash `json:"blockHash,omitempty"`
+ BlockNumber *big.Int `json:"blockNumber,omitempty"`
+ TransactionIndex uint `json:"transactionIndex"`
}
type receiptMarshaling struct {
PostState hexutil.Bytes
- Status hexutil.Uint
+ Status hexutil.Uint64
CumulativeGasUsed hexutil.Uint64
GasUsed hexutil.Uint64
+ BlockNumber *hexutil.Big
+ TransactionIndex hexutil.Uint
}
// receiptRLP is the consensus encoding of a receipt.
@@ -72,7 +84,14 @@ type receiptRLP struct {
Logs []*Log
}
-type receiptStorageRLP struct {
+// StoredReceiptRLP is the storage encoding of a receipt.
+type StoredReceiptRLP struct {
+ PostStateOrStatus []byte
+ CumulativeGasUsed uint64
+ Logs []*Log
+}
+
+type legacyStoredReceiptRLP struct {
PostStateOrStatus []byte
CumulativeGasUsed uint64
Bloom Bloom
@@ -141,7 +160,6 @@ func (r *Receipt) statusEncoding() []byte {
// to approximate and limit the memory consumption of various caches.
func (r *Receipt) Size() common.StorageSize {
size := common.StorageSize(unsafe.Sizeof(*r)) + common.StorageSize(len(r.PostState))
-
size += common.StorageSize(len(r.Logs)) * common.StorageSize(unsafe.Sizeof(Log{}))
for _, log := range r.Logs {
size += common.StorageSize(len(log.Topics)*common.HashLength + len(log.Data))
@@ -163,50 +181,136 @@ type ReceiptForStorage Receipt
// EncodeRLP implements rlp.Encoder, and flattens all content fields of a receipt
// into an RLP stream.
-func (r *ReceiptForStorage) EncodeRLP(w io.Writer) error {
- enc := &receiptStorageRLP{
- PostStateOrStatus: (*Receipt)(r).statusEncoding(),
- CumulativeGasUsed: r.CumulativeGasUsed,
- Bloom: r.Bloom,
- TxHash: r.TxHash,
- ContractAddress: r.ContractAddress,
- Logs: make([]*LogForStorage, len(r.Logs)),
- GasUsed: r.GasUsed,
- }
- for i, log := range r.Logs {
- enc.Logs[i] = (*LogForStorage)(log)
+func (r *ReceiptForStorage) EncodeRLP(_w io.Writer) error {
+ w := rlp.NewEncoderBuffer(_w)
+ outerList := w.List()
+ w.WriteBytes((*Receipt)(r).statusEncoding())
+ w.WriteUint64(r.CumulativeGasUsed)
+ logList := w.List()
+ for _, log := range r.Logs {
+ if err := rlp.Encode(w, log); err != nil {
+ return err
+ }
}
- return rlp.Encode(w, enc)
+ w.ListEnd(logList)
+ w.ListEnd(outerList)
+ return w.Flush()
}
// DecodeRLP implements rlp.Decoder, and loads both consensus and implementation
// fields of a receipt from an RLP stream.
func (r *ReceiptForStorage) DecodeRLP(s *rlp.Stream) error {
- var dec receiptStorageRLP
- if err := s.Decode(&dec); err != nil {
+ // Retrieve the entire receipt blob as we need to try multiple decoders
+ blob, err := s.Raw()
+ if err != nil {
+ return err
+ }
+ // Try decoding from the newest format for future proofness, then fall back
+ // to the legacy format written by nodes that have just upgraded and still
+ // carry receipts in the previous storage encoding.
+ if err := decodeStoredReceiptRLP(r, blob); err == nil {
+ return nil
+ }
+ return decodeLegacyStoredReceiptRLP(r, blob)
+}
+
+func decodeStoredReceiptRLP(r *ReceiptForStorage, blob []byte) error {
+ var stored StoredReceiptRLP
+ if err := rlp.DecodeBytes(blob, &stored); err != nil {
return err
}
- if err := (*Receipt)(r).setStatus(dec.PostStateOrStatus); err != nil {
+ if err := (*Receipt)(r).setStatus(stored.PostStateOrStatus); err != nil {
return err
}
- // Assign the consensus fields
- r.CumulativeGasUsed, r.Bloom = dec.CumulativeGasUsed, dec.Bloom
- r.Logs = make([]*Log, len(dec.Logs))
- for i, log := range dec.Logs {
+ r.CumulativeGasUsed = stored.CumulativeGasUsed
+ r.Logs = stored.Logs
+ r.Bloom = CreateBloom(Receipts{(*Receipt)(r)})
+
+ return nil
+}
+
+func decodeLegacyStoredReceiptRLP(r *ReceiptForStorage, blob []byte) error {
+ var stored legacyStoredReceiptRLP
+ if err := rlp.DecodeBytes(blob, &stored); err != nil {
+ return err
+ }
+ if err := (*Receipt)(r).setStatus(stored.PostStateOrStatus); err != nil {
+ return err
+ }
+ r.CumulativeGasUsed = stored.CumulativeGasUsed
+ r.TxHash = stored.TxHash
+ r.ContractAddress = stored.ContractAddress
+ r.GasUsed = stored.GasUsed
+ r.Logs = make([]*Log, len(stored.Logs))
+ for i, log := range stored.Logs {
r.Logs[i] = (*Log)(log)
}
- // Assign the implementation fields
- r.TxHash, r.ContractAddress, r.GasUsed = dec.TxHash, dec.ContractAddress, dec.GasUsed
+ r.Bloom = CreateBloom(Receipts{(*Receipt)(r)})
+
return nil
}
-// Receipts is a wrapper around a Receipt array to implement DerivableList.
+// Receipts implements DerivableList for receipts.
type Receipts []*Receipt
// Len returns the number of receipts in this list.
func (r Receipts) Len() int { return len(r) }
-// GetRlp returns the RLP encoding of one receipt from the list.
+// EncodeIndex encodes the i'th receipt to w.
+func (rs Receipts) EncodeIndex(i int, w *bytes.Buffer) {
+ r := rs[i]
+ data := &receiptRLP{r.statusEncoding(), r.CumulativeGasUsed, r.Bloom, r.Logs}
+ rlp.Encode(w, data)
+}
+
+// DeriveFields fills the receipts with their computed fields based on consensus
+// data and contextual infos like containing block and transactions.
+func (rs Receipts) DeriveFields(config *params.ChainConfig, hash common.Hash, number uint64, txs []*Transaction) error {
+ signer := MakeSigner(config, new(big.Int).SetUint64(number))
+
+ logIndex := uint(0)
+ if len(txs) != len(rs) {
+ return errors.New("transaction and receipt count mismatch")
+ }
+ for i := 0; i < len(rs); i++ {
+ // The transaction type and hash can be retrieved from the transaction itself
+ rs[i].TxHash = txs[i].Hash()
+
+ // block location fields
+ rs[i].BlockHash = hash
+ rs[i].BlockNumber = new(big.Int).SetUint64(number)
+ rs[i].TransactionIndex = uint(i)
+
+ // The contract address can be derived from the transaction itself
+ if txs[i].To() == nil {
+ // Deriving the signer is expensive, only do if it's actually needed
+ from, _ := Sender(signer, txs[i])
+ rs[i].ContractAddress = crypto.CreateAddress(from, txs[i].Nonce())
+ } else {
+ rs[i].ContractAddress = common.Address{}
+ }
+
+ // The used gas can be calculated based on the previous receipt's cumulative gas used
+ if i == 0 {
+ rs[i].GasUsed = rs[i].CumulativeGasUsed
+ } else {
+ rs[i].GasUsed = rs[i].CumulativeGasUsed - rs[i-1].CumulativeGasUsed
+ }
+
+ // The derived log fields can simply be set from the block and transaction
+ for j := 0; j < len(rs[i].Logs); j++ {
+ rs[i].Logs[j].BlockNumber = number
+ rs[i].Logs[j].BlockHash = hash
+ rs[i].Logs[j].TxHash = rs[i].TxHash
+ rs[i].Logs[j].TxIndex = uint(i)
+ rs[i].Logs[j].Index = logIndex
+ logIndex++
+ }
+ }
+ return nil
+}
+
+// GetRlp returns the RLP encoding of one receipt from the list.
func (r Receipts) GetRlp(i int) []byte {
bytes, err := rlp.EncodeToBytes(r[i])
if err != nil {
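
The gas arithmetic in DeriveFields is worth spelling out: receipts persist only cumulative gas, so each transaction's own usage is the delta between consecutive cumulative values, and the first receipt's usage is its cumulative value itself. A standalone worked example with made-up numbers:

package main

import "fmt"

// Cumulative gas [21000, 63000, 100000] yields per-tx gas [21000, 42000, 37000].
func main() {
	cumulative := []uint64{21000, 63000, 100000}
	for i, c := range cumulative {
		used := c
		if i > 0 {
			used = c - cumulative[i-1]
		}
		fmt.Printf("tx %d used %d gas\n", i, used)
	}
}
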
diff --git a/core/types/state_account.go b/core/types/state_account.go
new file mode 100644
index 000000000..01c552a04
--- /dev/null
+++ b/core/types/state_account.go
@@ -0,0 +1,103 @@
+package types
+
+import (
+ "bytes"
+ "math/big"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/rlp"
+)
+
+// StateAccount is the Ethereum consensus representation of accounts.
+// These objects are stored in the main account trie.
+type StateAccount struct {
+ Nonce uint64
+ Balance *big.Int
+ Root common.Hash // merkle root of the storage trie
+ CodeHash []byte
+}
+
+// NewEmptyStateAccount constructs an empty state account.
+func NewEmptyStateAccount() *StateAccount {
+ return &StateAccount{
+ Balance: new(big.Int),
+ Root: EmptyRootHash,
+ CodeHash: EmptyCodeHash.Bytes(),
+ }
+}
+
+// Copy returns a deep-copied state account object.
+func (acct *StateAccount) Copy() *StateAccount {
+ var balance *big.Int
+ if acct.Balance != nil {
+ balance = new(big.Int).Set(acct.Balance)
+ }
+ return &StateAccount{
+ Nonce: acct.Nonce,
+ Balance: balance,
+ Root: acct.Root,
+ CodeHash: common.CopyBytes(acct.CodeHash),
+ }
+}
+
+// SlimAccount is a modified version of a StateAccount, where the root is replaced
+// with a byte slice. This format can represent either the full consensus format or
+// the slim format, which replaces an empty root or code hash with a nil byte slice.
+type SlimAccount struct {
+ Nonce uint64
+ Balance *big.Int
+ Root []byte // Nil if the root equals types.EmptyRootHash
+ CodeHash []byte // Nil if the hash equals types.EmptyCodeHash
+}
+
+// SlimAccountRLP encodes the state account in 'slim RLP' format.
+func SlimAccountRLP(account StateAccount) []byte {
+ slim := SlimAccount{
+ Nonce: account.Nonce,
+ Balance: account.Balance,
+ }
+ if account.Root != EmptyRootHash {
+ slim.Root = account.Root[:]
+ }
+ if !bytes.Equal(account.CodeHash, EmptyCodeHash[:]) {
+ slim.CodeHash = account.CodeHash
+ }
+ data, err := rlp.EncodeToBytes(slim)
+ if err != nil {
+ panic(err)
+ }
+ return data
+}
+
+// FullAccount decodes data in the 'slim RLP' format and returns
+// the consensus format account.
+func FullAccount(data []byte) (*StateAccount, error) {
+ var slim SlimAccount
+ if err := rlp.DecodeBytes(data, &slim); err != nil {
+ return nil, err
+ }
+ var account StateAccount
+ account.Nonce, account.Balance = slim.Nonce, slim.Balance
+
+ // Interpret the storage root and code hash in slim format.
+ if len(slim.Root) == 0 {
+ account.Root = EmptyRootHash
+ } else {
+ account.Root = common.BytesToHash(slim.Root)
+ }
+ if len(slim.CodeHash) == 0 {
+ account.CodeHash = EmptyCodeHash[:]
+ } else {
+ account.CodeHash = slim.CodeHash
+ }
+ return &account, nil
+}
+
+// FullAccountRLP converts data in the 'slim RLP' format into the full RLP format.
+func FullAccountRLP(data []byte) ([]byte, error) {
+ account, err := FullAccount(data)
+ if err != nil {
+ return nil, err
+ }
+ return rlp.EncodeToBytes(account)
+}
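
A short round-trip sketch of the slim encoding (assuming the API introduced above): an empty account's root and code hash are dropped to nil in slim RLP and restored to the canonical empty values by FullAccount.

package main

import (
	"bytes"
	"fmt"

	"github.com/tomochain/tomochain/core/types"
)

// Encodes an empty account in slim format and decodes it back.
func main() {
	acct := types.NewEmptyStateAccount()
	slim := types.SlimAccountRLP(*acct)

	back, err := types.FullAccount(slim)
	if err != nil {
		panic(err)
	}
	fmt.Println(back.Root == types.EmptyRootHash)                        // true
	fmt.Println(bytes.Equal(back.CodeHash, types.EmptyCodeHash.Bytes())) // true
}
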
diff --git a/core/types/transaction.go b/core/types/transaction.go
index cf546c442..d0d2ce215 100644
--- a/core/types/transaction.go
+++ b/core/types/transaction.go
@@ -17,6 +17,7 @@
package types
import (
+ "bytes"
"container/heap"
"errors"
"fmt"
@@ -242,34 +243,6 @@ func (tx *Transaction) Size() common.StorageSize {
return common.StorageSize(c)
}
-// AsMessage returns the transaction as a core.Message.
-//
-// AsMessage requires a signer to derive the sender.
-//
-// XXX Rename message to something less arbitrary?
-func (tx *Transaction) AsMessage(s Signer, balanceFee *big.Int, number *big.Int) (Message, error) {
- msg := Message{
- nonce: tx.data.AccountNonce,
- gasLimit: tx.data.GasLimit,
- gasPrice: new(big.Int).Set(tx.data.Price),
- to: tx.data.Recipient,
- amount: tx.data.Amount,
- data: tx.data.Payload,
- checkNonce: true,
- balanceTokenFee: balanceFee,
- }
- var err error
- msg.from, err = Sender(s, tx)
- if balanceFee != nil {
- if number.Cmp(common.TIPTRC21Fee) > 0 {
- msg.gasPrice = common.TRC21GasPrice
- } else {
- msg.gasPrice = common.TRC21GasPriceBefore
- }
- }
- return msg, err
-}
-
// WithSignature returns a new transaction with the given signature.
// This signature needs to be formatted as described in the yellow paper (v+27).
func (tx *Transaction) WithSignature(signer Signer, sig []byte) (*Transaction, error) {
@@ -523,15 +496,17 @@ type Transactions []*Transaction
// Len returns the length of s.
func (s Transactions) Len() int { return len(s) }
+// EncodeIndex encodes the i'th transaction to w. Note that this does not check for errors
+// because we assume that *Transaction will only ever contain valid txs that were either
+// constructed by decoding or via public API in this package.
+func (s Transactions) EncodeIndex(i int, w *bytes.Buffer) {
+ tx := s[i]
+ rlp.Encode(w, tx.data)
+}
+
// Swap swaps the i'th and the j'th element in s.
func (s Transactions) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
-// GetRlp implements Rlpable and returns the i'th element of s in rlp.
-func (s Transactions) GetRlp(i int) []byte {
- enc, _ := rlp.EncodeToBytes(s[i])
- return enc
-}
-
// TxDifference returns a new set t which is the difference between a to b.
func TxDifference(a, b Transactions) (keep Transactions) {
keep = make(Transactions, 0, len(a))
@@ -680,45 +655,3 @@ func (t *TransactionsByPriceAndNonce) Shift() {
func (t *TransactionsByPriceAndNonce) Pop() {
heap.Pop(&t.heads)
}
-
-// Message is a fully derived transaction and implements core.Message
-//
-// NOTE: In a future PR this will be removed.
-type Message struct {
- to *common.Address
- from common.Address
- nonce uint64
- amount *big.Int
- gasLimit uint64
- gasPrice *big.Int
- data []byte
- checkNonce bool
- balanceTokenFee *big.Int
-}
-
-func NewMessage(from common.Address, to *common.Address, nonce uint64, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte, checkNonce bool, balanceTokenFee *big.Int) Message {
- if balanceTokenFee != nil {
- gasPrice = common.TRC21GasPrice
- }
- return Message{
- from: from,
- to: to,
- nonce: nonce,
- amount: amount,
- gasLimit: gasLimit,
- gasPrice: gasPrice,
- data: data,
- checkNonce: checkNonce,
- balanceTokenFee: balanceTokenFee,
- }
-}
-
-func (m Message) From() common.Address { return m.from }
-func (m Message) BalanceTokenFee() *big.Int { return m.balanceTokenFee }
-func (m Message) To() *common.Address { return m.to }
-func (m Message) GasPrice() *big.Int { return m.gasPrice }
-func (m Message) Value() *big.Int { return m.amount }
-func (m Message) Gas() uint64 { return m.gasLimit }
-func (m Message) Nonce() uint64 { return m.nonce }
-func (m Message) Data() []byte { return m.data }
-func (m Message) CheckNonce() bool { return m.checkNonce }
diff --git a/core/types/types_test.go b/core/types/types_test.go
new file mode 100644
index 000000000..03c29a159
--- /dev/null
+++ b/core/types/types_test.go
@@ -0,0 +1,111 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package types
+
+import (
+ "math/big"
+ "testing"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/rlp"
+)
+
+type devnull struct{ len int }
+
+func (d *devnull) Write(p []byte) (int, error) {
+ d.len += len(p)
+ return len(p), nil
+}
+
+func BenchmarkEncodeRLP(b *testing.B) {
+ benchRLP(b, true)
+}
+
+func BenchmarkDecodeRLP(b *testing.B) {
+ benchRLP(b, false)
+}
+
+func benchRLP(b *testing.B, encode bool) {
+ key, _ := crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
+ to := common.HexToAddress("0x00000000000000000000000000000000deadbeef")
+ signer := NewEIP155Signer(big.NewInt(1337))
+ tx := NewTransaction(1, to, big.NewInt(1), 1000000, big.NewInt(500), nil)
+ signedTx, err := SignTx(tx, signer, key)
+ if err != nil {
+ b.Fatal("cannot sign transaction for benchmarking")
+ }
+ for _, tc := range []struct {
+ name string
+ obj interface{}
+ }{
+ {
+ "header",
+ &Header{
+ Difficulty: big.NewInt(10000000000),
+ Number: big.NewInt(1000),
+ GasLimit: 8_000_000,
+ GasUsed: 8_000_000,
+ Time: big.NewInt(555),
+ Extra: make([]byte, 32),
+ },
+ },
+ {
+ "receipt-for-storage",
+ &ReceiptForStorage{
+ Status: ReceiptStatusSuccessful,
+ CumulativeGasUsed: 0x888888888,
+ Logs: make([]*Log, 0),
+ },
+ },
+ {
+ "receipt-full",
+ &Receipt{
+ Status: ReceiptStatusSuccessful,
+ CumulativeGasUsed: 0x888888888,
+ Logs: make([]*Log, 0),
+ },
+ },
+ {
+ "transaction",
+ signedTx,
+ },
+ } {
+ if encode {
+ b.Run(tc.name, func(b *testing.B) {
+ b.ReportAllocs()
+ var null = &devnull{}
+ for i := 0; i < b.N; i++ {
+ rlp.Encode(null, tc.obj)
+ }
+ b.SetBytes(int64(null.len / b.N))
+ })
+ } else {
+ data, _ := rlp.EncodeToBytes(tc.obj)
+ // Test decoding
+ b.Run(tc.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ if err := rlp.DecodeBytes(data, tc.obj); err != nil {
+ b.Fatal(err)
+ }
+ }
+ b.SetBytes(int64(len(data)))
+ })
+ }
+ }
+}
diff --git a/core/vm/gas_table_test.go b/core/vm/gas_table_test.go
index ba31cf494..7e7df4f89 100644
--- a/core/vm/gas_table_test.go
+++ b/core/vm/gas_table_test.go
@@ -17,11 +17,12 @@
package vm
import (
- "github.com/tomochain/tomochain/core/rawdb"
"math"
"math/big"
"testing"
+ "github.com/tomochain/tomochain/core/rawdb"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/hexutil"
"github.com/tomochain/tomochain/core/state"
@@ -81,7 +82,7 @@ func TestEIP2200(t *testing.T) {
for i, tt := range eip2200Tests {
address := common.BytesToAddress([]byte("contract"))
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
statedb.CreateAccount(address)
statedb.SetCode(address, hexutil.MustDecode(tt.input))
statedb.SetState(address, common.Hash{}, common.BytesToHash([]byte{tt.original}))
@@ -91,7 +92,7 @@ func TestEIP2200(t *testing.T) {
CanTransfer: func(StateDB, common.Address, *big.Int) bool { return true },
Transfer: func(StateDB, common.Address, common.Address, *big.Int) {},
}
- vmenv := NewEVM(vmctx, statedb, nil,params.AllEthashProtocolChanges, Config{ExtraEips: []int{2200}})
+ vmenv := NewEVM(vmctx, statedb, nil, params.AllEthashProtocolChanges, Config{ExtraEips: []int{2200}})
_, gas, err := vmenv.Call(AccountRef(common.Address{}), address, nil, tt.gaspool, new(big.Int))
if err != tt.failure {
diff --git a/core/vm/instructions.go b/core/vm/instructions.go
index 16f368585..ab962bd65 100644
--- a/core/vm/instructions.go
+++ b/core/vm/instructions.go
@@ -17,13 +17,13 @@
package vm
import (
- "github.com/tomochain/tomochain/params"
"math/big"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/math"
"github.com/tomochain/tomochain/core/types"
- "golang.org/x/crypto/sha3"
+ "github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/params"
)
var (
@@ -381,7 +381,7 @@ func opSha3(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]by
data := callContext.memory.GetPtr(offset.Int64(), size.Int64())
if interpreter.hasher == nil {
- interpreter.hasher = sha3.NewLegacyKeccak256().(keccakState)
+ interpreter.hasher = crypto.NewKeccakState()
} else {
interpreter.hasher.Reset()
}
@@ -513,16 +513,21 @@ func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx
// opExtCodeHash returns the code hash of a specified account.
// There are several cases when the function is called, while we can relay everything
// to `state.GetCodeHash` function to ensure the correctness.
-// (1) Caller tries to get the code hash of a normal contract account, state
+//
+// (1) Caller tries to get the code hash of a normal contract account, state
+//
// should return the relative code hash and set it as the result.
//
-// (2) Caller tries to get the code hash of a non-existent account, state should
+// (2) Caller tries to get the code hash of a non-existent account, state should
+//
// return common.Hash{} and zero will be set as the result.
//
-// (3) Caller tries to get the code hash for an account without contract code,
+// (3) Caller tries to get the code hash for an account without contract code,
+//
// state should return emptyCodeHash(0xc5d246...) as the result.
//
-// (4) Caller tries to get the code hash of a precompiled account, the result
+// (4) Caller tries to get the code hash of a precompiled account, the result
+//
// should be zero or emptyCodeHash.
//
// It is worth noting that in order to avoid unnecessary create and clean,
@@ -531,10 +536,12 @@ func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx
// If the precompile account is not transferred any amount on a private or
// customized chain, the return value will be zero.
//
-// (5) Caller tries to get the code hash for an account which is marked as suicided
+// (5) Caller tries to get the code hash for an account which is marked as suicided
+//
// in the current transaction, the code hash of this account should be returned.
//
-// (6) Caller tries to get the code hash for an account which is marked as deleted,
+// (6) Caller tries to get the code hash for an account which is marked as deleted,
+//
// this account should be regarded as a non-existent account and zero should be returned.
func opExtCodeHash(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) {
slot := callContext.stack.peek()
diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go
index fc5b17a4f..36027be79 100644
--- a/core/vm/interpreter.go
+++ b/core/vm/interpreter.go
@@ -17,11 +17,11 @@
package vm
import (
- "hash"
"sync/atomic"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/math"
+ "github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/log"
)
@@ -70,14 +70,6 @@ type callCtx struct {
contract *Contract
}
-// keccakState wraps sha3.state. In addition to the usual hash methods, it also supports
-// Read to get a variable amount of data from the hash state. Read is faster than Sum
-// because it doesn't copy the internal state, but also modifies the internal state.
-type keccakState interface {
- hash.Hash
- Read([]byte) (int, error)
-}
-
// EVMInterpreter represents an EVM interpreter
type EVMInterpreter struct {
evm *EVM
@@ -85,8 +77,8 @@ type EVMInterpreter struct {
intPool *intPool
- hasher keccakState // Keccak256 hasher instance shared across opcodes
- hasherBuf common.Hash // Keccak256 hasher result array shared aross opcodes
+ hasher crypto.KeccakState // Keccak256 hasher instance shared across opcodes
+ hasherBuf common.Hash // Keccak256 hasher result array shared across opcodes
readOnly bool // Whether to throw on stateful modifications
returnData []byte // Last CALL's return data for subsequent reuse
diff --git a/core/vm/runtime/runtime.go b/core/vm/runtime/runtime.go
index 683cad1d1..9a13d3d6f 100644
--- a/core/vm/runtime/runtime.go
+++ b/core/vm/runtime/runtime.go
@@ -17,11 +17,12 @@
package runtime
import (
- "github.com/tomochain/tomochain/core/rawdb"
"math"
"math/big"
"time"
+ "github.com/tomochain/tomochain/core/rawdb"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/vm"
@@ -100,7 +101,7 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) {
if cfg.State == nil {
db := rawdb.NewMemoryDatabase()
- cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db))
+ cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db), nil)
}
var (
address = common.BytesToAddress([]byte("contract"))
@@ -131,7 +132,7 @@ func Create(input []byte, cfg *Config) ([]byte, common.Address, uint64, error) {
if cfg.State == nil {
db := rawdb.NewMemoryDatabase()
- cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db))
+ cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db), nil)
}
var (
vmenv = NewEnv(cfg)
diff --git a/core/vm/runtime/runtime_test.go b/core/vm/runtime/runtime_test.go
index e430c2b2a..0b95751d3 100644
--- a/core/vm/runtime/runtime_test.go
+++ b/core/vm/runtime/runtime_test.go
@@ -17,11 +17,12 @@
package runtime
import (
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"strings"
"testing"
+ "github.com/tomochain/tomochain/core/rawdb"
+
"github.com/tomochain/tomochain/accounts/abi"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus"
@@ -99,7 +100,7 @@ func TestExecute(t *testing.T) {
func TestCall(t *testing.T) {
db := rawdb.NewMemoryDatabase()
- state, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ state, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
address := common.HexToAddress("0x0a")
state.SetCode(address, []byte{
byte(vm.PUSH1), 10,
@@ -156,7 +157,7 @@ func BenchmarkCall(b *testing.B) {
func benchmarkEVM_Create(bench *testing.B, code string) {
var (
db = rawdb.NewMemoryDatabase()
- statedb, _ = state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ = state.New(common.Hash{}, state.NewDatabase(db), nil)
sender = common.BytesToAddress([]byte("sender"))
receiver = common.BytesToAddress([]byte("receiver"))
)
diff --git a/crypto/crypto.go b/crypto/crypto.go
index 18386f85c..9154a5e9a 100644
--- a/crypto/crypto.go
+++ b/crypto/crypto.go
@@ -23,6 +23,7 @@ import (
"encoding/hex"
"errors"
"fmt"
+ "hash"
"io"
"io/ioutil"
"math/big"
@@ -30,38 +31,72 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/math"
- "github.com/tomochain/tomochain/crypto/sha3"
"github.com/tomochain/tomochain/rlp"
+ "golang.org/x/crypto/sha3"
)
+// SignatureLength indicates the byte length required to carry a signature with recovery id.
+const SignatureLength = 64 + 1 // 64 bytes ECDSA signature + 1 byte recovery id
+
+// RecoveryIDOffset points to the byte offset within the signature that contains the recovery id.
+const RecoveryIDOffset = 64
+
+// DigestLength sets the signature digest exact length
+const DigestLength = 32
+
var (
secp256k1_N, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16)
secp256k1_halfN = new(big.Int).Div(secp256k1_N, big.NewInt(2))
)
+var errInvalidPubkey = errors.New("invalid secp256k1 public key")
+
+// KeccakState wraps sha3.state. In addition to the usual hash methods, it also supports
+// Read to get a variable amount of data from the hash state. Read is faster than Sum
+// because it doesn't copy the internal state, but also modifies the internal state.
+type KeccakState interface {
+ hash.Hash
+ Read([]byte) (int, error)
+}
+
+// NewKeccakState creates a new KeccakState
+func NewKeccakState() KeccakState {
+ return sha3.NewLegacyKeccak256().(KeccakState)
+}
+
+// HashData hashes the provided data using the KeccakState and returns a 32 byte hash
+func HashData(kh KeccakState, data []byte) (h common.Hash) {
+ kh.Reset()
+ kh.Write(data)
+ kh.Read(h[:])
+ return h
+}
+
// Keccak256 calculates and returns the Keccak256 hash of the input data.
func Keccak256(data ...[]byte) []byte {
- d := sha3.NewKeccak256()
+ b := make([]byte, 32)
+ d := NewKeccakState()
for _, b := range data {
d.Write(b)
}
- return d.Sum(nil)
+ d.Read(b)
+ return b
}
// Keccak256Hash calculates and returns the Keccak256 hash of the input data,
// converting it to an internal Hash data structure.
func Keccak256Hash(data ...[]byte) (h common.Hash) {
- d := sha3.NewKeccak256()
+ d := NewKeccakState()
for _, b := range data {
d.Write(b)
}
- d.Sum(h[:0])
+ d.Read(h[:])
return h
}
// Keccak512 calculates and returns the Keccak512 hash of the input data.
func Keccak512(data ...[]byte) []byte {
- d := sha3.NewKeccak512()
+ d := sha3.NewLegacyKeccak512()
for _, b := range data {
d.Write(b)
}
@@ -128,6 +163,15 @@ func FromECDSA(priv *ecdsa.PrivateKey) []byte {
return math.PaddedBigBytes(priv.D, priv.Params().BitSize/8)
}
+// UnmarshalPubkey converts bytes to a secp256k1 public key.
+func UnmarshalPubkey(pub []byte) (*ecdsa.PublicKey, error) {
+ x, y := elliptic.Unmarshal(S256(), pub)
+ if x == nil {
+ return nil, errInvalidPubkey
+ }
+ return &ecdsa.PublicKey{Curve: S256(), X: x, Y: y}, nil
+}
+
func ToECDSAPub(pub []byte) *ecdsa.PublicKey {
if len(pub) == 0 {
return nil
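
A standalone sketch contrasting the one-shot helper with a reused KeccakState; Read is preferred over Sum because it avoids copying the sponge's internal state (it does mutate it, which is why HashData calls Reset first):

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/crypto"
)

// Hashes the same input via the one-shot helper and a reused KeccakState.
func main() {
	data := []byte("hello")

	oneShot := crypto.Keccak256Hash(data)

	st := crypto.NewKeccakState()
	reused := crypto.HashData(st, data) // Reset + Write + Read internally

	fmt.Println(oneShot == reused) // true
}
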
diff --git a/eth/api.go b/eth/api.go
index 76a466a49..e885f6d60 100644
--- a/eth/api.go
+++ b/eth/api.go
@@ -28,6 +28,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/hexutil"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/log"
@@ -343,7 +344,7 @@ func NewPrivateDebugAPI(config *params.ChainConfig, eth *Ethereum) *PrivateDebug
// Preimage is a debug API function that returns the preimage for a sha3 hash, if known.
func (api *PrivateDebugAPI) Preimage(ctx context.Context, hash common.Hash) (hexutil.Bytes, error) {
- db := core.PreimageTable(api.eth.ChainDb())
+ db := rawdb.PreimageTable(api.eth.ChainDb())
return db.Get(hash.Bytes())
}
@@ -494,11 +495,10 @@ func (api *PublicEthereumAPI) ChainId() hexutil.Uint64 {
}
// GetOwnerByCoinbase returns the masternode owner of the given coinbase address
-func (api *PublicEthereumAPI) GetOwnerByCoinbase(ctx context.Context, coinbase common.Address, blockNr rpc.BlockNumber) (common.Address, error) {
+func (api *PublicEthereumAPI) GetOwnerByCoinbase(ctx context.Context, coinbase common.Address, blockNr rpc.BlockNumber) (common.Address, error) {
statedb, _, err := api.e.ApiBackend.StateAndHeaderByNumber(ctx, blockNr)
if err != nil {
return common.Address{}, err
}
return statedb.GetOwner(coinbase), nil
}
-
diff --git a/eth/api_backend.go b/eth/api_backend.go
index 67554b448..ddd44455e 100644
--- a/eth/api_backend.go
+++ b/eth/api_backend.go
@@ -21,23 +21,19 @@ import (
"encoding/json"
"errors"
"fmt"
- "github.com/tomochain/tomochain/tomox/tradingstate"
- "github.com/tomochain/tomochain/tomoxlending"
"io/ioutil"
"math/big"
"path/filepath"
- "github.com/tomochain/tomochain/tomox"
-
- "github.com/tomochain/tomochain/consensus/posv"
-
"github.com/tomochain/tomochain/accounts"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/math"
"github.com/tomochain/tomochain/consensus"
+ "github.com/tomochain/tomochain/consensus/posv"
"github.com/tomochain/tomochain/contracts"
"github.com/tomochain/tomochain/core"
"github.com/tomochain/tomochain/core/bloombits"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
stateDatabase "github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
@@ -50,6 +46,9 @@ import (
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/rpc"
+ "github.com/tomochain/tomochain/tomox"
+ "github.com/tomochain/tomochain/tomox/tradingstate"
+ "github.com/tomochain/tomochain/tomoxlending"
)
// EthApiBackend implements ethapi.Backend for full nodes
@@ -117,11 +116,11 @@ func (b *EthApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*t
}
func (b *EthApiBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) {
- return core.GetBlockReceipts(b.eth.chainDb, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)), nil
+ return rawdb.GetBlockReceipts(b.eth.chainDb, blockHash, rawdb.GetBlockNumber(b.eth.chainDb, blockHash), b.ChainConfig()), nil
}
func (b *EthApiBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) {
- receipts := core.GetBlockReceipts(b.eth.chainDb, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash))
+ receipts := rawdb.GetBlockReceipts(b.eth.chainDb, blockHash, rawdb.GetBlockNumber(b.eth.chainDb, blockHash), b.ChainConfig())
if receipts == nil {
return nil, nil
}
@@ -136,8 +135,8 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int {
return b.eth.blockchain.GetTdByHash(blockHash)
}
-func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) {
- state.SetBalance(msg.From(), math.MaxBig256)
+func (b *EthApiBackend) GetEVM(ctx context.Context, msg *core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) {
+ state.SetBalance(msg.From, math.MaxBig256)
vmError := func() error { return nil }
context := core.NewEVMContext(msg, header, b.eth.BlockChain(), nil)
diff --git a/eth/api_test.go b/eth/api_test.go
index f9f2fc43d..f0a48df52 100644
--- a/eth/api_test.go
+++ b/eth/api_test.go
@@ -17,10 +17,11 @@
package eth
import (
- "github.com/tomochain/tomochain/core/rawdb"
"reflect"
"testing"
+ "github.com/tomochain/tomochain/core/rawdb"
+
"github.com/davecgh/go-spew/spew"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core/state"
@@ -32,7 +33,7 @@ func TestStorageRangeAt(t *testing.T) {
// Create a state where account 0x010000... has a few storage entries.
var (
db = rawdb.NewMemoryDatabase()
- state, _ = state.New(common.Hash{}, state.NewDatabase(db))
+ state, _ = state.New(common.Hash{}, state.NewDatabase(db), nil)
addr = common.Address{0x01}
keys = []common.Hash{ // hashes of Keys of storage
common.HexToHash("340dd630ad21bf010b4e676dbfa9ba9a02175262d1fa356232cfde6cb5b47ef2"),
diff --git a/eth/api_tracer.go b/eth/api_tracer.go
index e1744dc2c..b94159727 100644
--- a/eth/api_tracer.go
+++ b/eth/api_tracer.go
@@ -21,7 +21,6 @@ import (
"context"
"errors"
"fmt"
- "github.com/tomochain/tomochain/tomox/tradingstate"
"io/ioutil"
"math/big"
"runtime"
@@ -31,6 +30,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/hexutil"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
@@ -39,6 +39,7 @@ import (
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/rlp"
"github.com/tomochain/tomochain/rpc"
+ "github.com/tomochain/tomochain/tomox/tradingstate"
"github.com/tomochain/tomochain/trie"
)
@@ -144,7 +145,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
return nil, fmt.Errorf("parent block #%d not found", number-1)
}
}
- statedb, err := state.New(start.Root(), database)
+ statedb, err := state.New(start.Root(), database, nil)
var tomoxState *tradingstate.TradingStateDB
if err != nil {
// If the starting state is missing, allow some number of blocks to be reexecuted
@@ -158,7 +159,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
if start == nil {
break
}
- if statedb, err = state.New(start.Root(), database); err == nil {
+ if statedb, err = state.New(start.Root(), database, nil); err == nil {
tomoxState, err = tradingstate.New(start.Root(), tradingstate.NewDatabase(api.eth.TomoX.GetLevelDB()))
if err == nil {
break
@@ -198,13 +199,13 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
feeCapacity := state.GetTRC21FeeCapacityFromState(task.statedb)
// Trace all the transactions contained within
for i, tx := range task.block.Transactions() {
- var balacne *big.Int
+ var balanceFee *big.Int
if tx.To() != nil {
if value, ok := feeCapacity[*tx.To()]; ok {
- balacne = value
+ balanceFee = value
}
}
- msg, _ := tx.AsMessage(signer, balacne, task.block.Number())
+ msg, _ := core.TransactionToMessage(tx, signer, balanceFee, task.block.Number())
vmctx := core.NewEVMContext(msg, task.block.Header(), api.eth.blockchain, nil)
res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config)
@@ -438,13 +439,13 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block,
// Fetch and execute the next transaction trace tasks
for task := range jobs {
feeCapacity := state.GetTRC21FeeCapacityFromState(task.statedb)
- var balacne *big.Int
+ var balanceFee *big.Int
if txs[task.index].To() != nil {
if value, ok := feeCapacity[*txs[task.index].To()]; ok {
- balacne = value
+ balanceFee = value
}
}
- msg, _ := txs[task.index].AsMessage(signer, balacne, block.Number())
+ msg, _ := core.TransactionToMessage(txs[task.index], signer, balanceFee, block.Number())
vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil)
res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config)
@@ -462,19 +463,19 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block,
for i, tx := range txs {
// Send the trace task over for execution
jobs <- &txTraceTask{statedb: statedb.Copy(), index: i}
- var balacne *big.Int
+ var balanceFee *big.Int
if tx.To() != nil {
if value, ok := feeCapacity[*tx.To()]; ok {
- balacne = value
+ balanceFee = value
}
}
// Generate the next state snapshot fast without tracing
- msg, _ := tx.AsMessage(signer, balacne, block.Number())
+ msg, _ := core.TransactionToMessage(tx, signer, balanceFee, block.Number())
vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil)
vmenv := vm.NewEVM(vmctx, statedb, tomoxState, api.config, vm.Config{})
owner := common.Address{}
- if _, _, _, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(msg.Gas()), owner); err != nil {
+ if _, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(msg.GasLimit), owner); err != nil {
failed = err
break
}
@@ -513,7 +514,7 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (*
if block == nil {
break
}
- if statedb, err = state.New(block.Root(), database); err == nil {
+ if statedb, err = state.New(block.Root(), database, nil); err == nil {
tomoxState, err = tradingstate.New(block.Root(), tradingstate.NewDatabase(api.eth.TomoX.GetLevelDB()))
if err == nil {
break
@@ -567,14 +568,14 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (*
}
size, _ := database.TrieDB().Size()
log.Info("Historical state regenerated", "block", block.NumberU64(), "elapsed", time.Since(start), "size", size)
- return statedb,tomoxState, nil
+ return statedb, tomoxState, nil
}
// TraceTransaction returns the structured logs created during the execution of EVM
// and returns them as a JSON object.
func (api *PrivateDebugAPI) TraceTransaction(ctx context.Context, hash common.Hash, config *TraceConfig) (interface{}, error) {
// Retrieve the transaction and assemble its EVM context
- tx, blockHash, _, index := core.GetTransaction(api.eth.ChainDb(), hash)
+ tx, blockHash, _, index := rawdb.GetTransaction(api.eth.ChainDb(), hash)
if tx == nil {
return nil, fmt.Errorf("transaction %x not found", hash)
}
@@ -593,7 +594,7 @@ func (api *PrivateDebugAPI) TraceTransaction(ctx context.Context, hash common.Ha
// traceTx configures a new tracer according to the provided configuration, and
// executes the given message in the provided environment. The return value will
// be tracer dependent.
-func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, vmctx vm.Context, statedb *state.StateDB, config *TraceConfig) (interface{}, error) {
+func (api *PrivateDebugAPI) traceTx(ctx context.Context, message *core.Message, vmctx vm.Context, statedb *state.StateDB, config *TraceConfig) (interface{}, error) {
// Assemble the structured logger or the JavaScript tracer
var (
tracer vm.Tracer
@@ -630,7 +631,7 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v
vmenv := vm.NewEVM(vmctx, statedb, nil, api.config, vm.Config{Debug: true, Tracer: tracer})
owner := common.Address{}
- ret, gas, failed, err := core.ApplyMessage(vmenv, message, new(core.GasPool).AddGas(message.Gas()), owner)
+ result, err := core.ApplyMessage(vmenv, message, new(core.GasPool).AddGas(message.GasLimit), owner)
if err != nil {
return nil, fmt.Errorf("tracing failed: %v", err)
}
@@ -638,9 +639,9 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v
switch tracer := tracer.(type) {
case *vm.StructLogger:
return &ethapi.ExecutionResult{
- Gas: gas,
- Failed: failed,
- ReturnValue: fmt.Sprintf("%x", ret),
+ Gas: result.UsedGas,
+ Failed: result.Failed(),
+ ReturnValue: fmt.Sprintf("%x", result.Return()),
StructLogs: ethapi.FormatLogs(tracer.StructLogs()),
}, nil
@@ -653,7 +654,7 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v
}
// computeTxEnv returns the execution environment of a certain transaction.
-func (api *PrivateDebugAPI) computeTxEnv(blockHash common.Hash, txIndex int, reexec uint64) (core.Message, vm.Context, *state.StateDB, error) {
+func (api *PrivateDebugAPI) computeTxEnv(blockHash common.Hash, txIndex int, reexec uint64) (*core.Message, vm.Context, *state.StateDB, error) {
// Create the parent state database
block := api.eth.blockchain.GetBlockByHash(blockHash)
if block == nil {
@@ -687,7 +688,7 @@ func (api *PrivateDebugAPI) computeTxEnv(blockHash common.Hash, txIndex int, ree
balanceFee = value
}
}
- msg, err := tx.AsMessage(types.MakeSigner(api.config, block.Header().Number), balanceFee, block.Number())
+ msg, err := core.TransactionToMessage(tx, types.MakeSigner(api.config, block.Header().Number), balanceFee, block.Number())
if err != nil {
return nil, vm.Context{}, nil, fmt.Errorf("tx %x failed: %v", tx.Hash(), err)
}
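The tracer hunks above migrate from the old four-value `core.ApplyMessage` API to the new single `*core.ExecutionResult` flow, and move message construction from `tx.AsMessage` into `core.TransactionToMessage`. A minimal sketch of the new call shapes, assuming an already-configured EVM and signed transaction; the helper name `traceOne` and its wiring are illustrative, not part of this diff:

package example

import (
	"fmt"
	"math/big"

	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core"
	"github.com/tomochain/tomochain/core/types"
	"github.com/tomochain/tomochain/core/vm"
)

// traceOne shows the post-refactor call shapes: core.TransactionToMessage
// replaces tx.AsMessage, and ApplyMessage now returns one result value
// instead of (ret, gas, failed, err).
func traceOne(evm *vm.EVM, tx *types.Transaction, signer types.Signer, blockNumber *big.Int) error {
	msg, err := core.TransactionToMessage(tx, signer, nil, blockNumber)
	if err != nil {
		return err
	}
	result, err := core.ApplyMessage(evm, msg, new(core.GasPool).AddGas(msg.GasLimit), common.Address{})
	if err != nil {
		return err
	}
	// Return data, gas used and the failure flag now live on the result.
	fmt.Printf("gas=%d failed=%v ret=%x\n", result.UsedGas, result.Failed(), result.Return())
	return nil
}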
diff --git a/eth/backend.go b/eth/backend.go
index 412c67d23..1d3e51a84 100644
--- a/eth/backend.go
+++ b/eth/backend.go
@@ -18,6 +18,7 @@
package eth
import (
+ "bytes"
"errors"
"fmt"
"math/big"
@@ -27,18 +28,10 @@ import (
"sync/atomic"
"time"
- "github.com/tomochain/tomochain/tomoxlending"
-
- "github.com/tomochain/tomochain/accounts/abi/bind"
- "github.com/tomochain/tomochain/common/hexutil"
- "github.com/tomochain/tomochain/core/state"
- "github.com/tomochain/tomochain/eth/filters"
- "github.com/tomochain/tomochain/rlp"
-
- "bytes"
-
"github.com/tomochain/tomochain/accounts"
+ "github.com/tomochain/tomochain/accounts/abi/bind"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/common/hexutil"
"github.com/tomochain/tomochain/consensus"
"github.com/tomochain/tomochain/consensus/ethash"
"github.com/tomochain/tomochain/consensus/posv"
@@ -46,11 +39,12 @@ import (
contractValidator "github.com/tomochain/tomochain/contracts/validator/contract"
"github.com/tomochain/tomochain/core"
"github.com/tomochain/tomochain/core/bloombits"
-
- //"github.com/tomochain/tomochain/core/state"
+ "github.com/tomochain/tomochain/core/rawdb"
+ "github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
"github.com/tomochain/tomochain/eth/downloader"
+ "github.com/tomochain/tomochain/eth/filters"
"github.com/tomochain/tomochain/eth/gasprice"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/event"
@@ -60,8 +54,10 @@ import (
"github.com/tomochain/tomochain/node"
"github.com/tomochain/tomochain/p2p"
"github.com/tomochain/tomochain/params"
+ "github.com/tomochain/tomochain/rlp"
"github.com/tomochain/tomochain/rpc"
"github.com/tomochain/tomochain/tomox"
+ "github.com/tomochain/tomochain/tomoxlending"
)
type LesServer interface {
@@ -125,6 +121,7 @@ func New(ctx *node.ServiceContext, config *Config, tomoXServ *tomox.TomoX, lendi
if !config.SyncMode.IsValid() {
return nil, fmt.Errorf("invalid sync mode %d", config.SyncMode)
}
+
chainDb, err := CreateDB(ctx, config, "chaindata")
if err != nil {
return nil, err
@@ -160,15 +157,20 @@ func New(ctx *node.ServiceContext, config *Config, tomoXServ *tomox.TomoX, lendi
log.Info("Initialising Ethereum protocol", "versions", ProtocolVersions, "network", config.NetworkId)
if !config.SkipBcVersionCheck {
- bcVersion := core.GetBlockChainVersion(chainDb)
+ bcVersion := rawdb.GetBlockChainVersion(chainDb)
if bcVersion != core.BlockChainVersion && bcVersion != 0 {
return nil, fmt.Errorf("Blockchain DB version mismatch (%d / %d). Run geth upgradedb.\n", bcVersion, core.BlockChainVersion)
}
- core.WriteBlockChainVersion(chainDb, core.BlockChainVersion)
+ rawdb.WriteBlockChainVersion(chainDb, core.BlockChainVersion)
}
var (
vmConfig = vm.Config{EnablePreimageRecording: config.EnablePreimageRecording}
- cacheConfig = &core.CacheConfig{Disabled: config.NoPruning, TrieNodeLimit: config.TrieCache, TrieTimeLimit: config.TrieTimeout}
+ cacheConfig = &core.CacheConfig{
+ Disabled: config.NoPruning,
+ TrieNodeLimit: config.TrieCache,
+ TrieTimeLimit: config.TrieTimeout,
+ SnapshotLimit: config.SnapshotCache,
+ }
)
if eth.chainConfig.Posv != nil {
c := eth.engine.(*posv.Posv)
@@ -187,7 +189,7 @@ func New(ctx *node.ServiceContext, config *Config, tomoXServ *tomox.TomoX, lendi
if compat, ok := genesisErr.(*params.ConfigCompatError); ok {
log.Warn("Rewinding chain to upgrade configuration", "err", compat)
eth.blockchain.SetHead(compat.RewindTo)
- core.WriteChainConfig(chainDb, genesisHash, chainConfig)
+ rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig)
}
eth.bloomIndexer.Start(eth.blockchain)
@@ -557,7 +559,7 @@ func New(ctx *node.ServiceContext, config *Config, tomoXServ *tomox.TomoX, lendi
// Hook verifies masternodes set
c.HookVerifyMNs = func(header *types.Header, signers []common.Address) error {
number := header.Number.Int64()
- if number > 0 && number%common.EpocBlockRandomize == 0 {
+ if number > 0 && uint64(number)%common.EpocBlockRandomize == 0 {
start := time.Now()
validators, err := GetValidators(eth.blockchain, signers)
log.Debug("Time Calculated HookVerifyMNs ", "block", header.Number.Uint64(), "time", common.PrettyDuration(time.Since(start)))
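backend.go above also wires a new SnapshotLimit into the cache config, which matches the extra third argument that state.New takes throughout this diff — presumably the snapshot tree, with nil opting out. A sketch of the updated signature, using nil exactly as every call site in this diff does:

package example

import (
	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/state"
)

// openState wraps the new three-argument state.New; the added parameter
// (nil here) disables the snapshot layer, as all call sites in this diff do.
func openState(root common.Hash, db state.Database) (*state.StateDB, error) {
	return state.New(root, db, nil)
}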
diff --git a/eth/bloombits.go b/eth/bloombits.go
index abe8c5d67..39695f43e 100644
--- a/eth/bloombits.go
+++ b/eth/bloombits.go
@@ -17,13 +17,13 @@
package eth
import (
- "github.com/tomochain/tomochain/core/rawdb"
"time"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/bitutil"
"github.com/tomochain/tomochain/core"
"github.com/tomochain/tomochain/core/bloombits"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/params"
@@ -61,8 +61,8 @@ func (eth *Ethereum) startBloomHandlers() {
task := <-request
task.Bitsets = make([][]byte, len(task.Sections))
for i, section := range task.Sections {
- head := core.GetCanonicalHash(eth.chainDb, (section+1)*params.BloomBitsBlocks-1)
- if compVector, err := core.GetBloomBits(eth.chainDb, task.Bit, section, head); err == nil {
+ head := rawdb.GetCanonicalHash(eth.chainDb, (section+1)*params.BloomBitsBlocks-1)
+ if compVector, err := rawdb.GetBloomBits(eth.chainDb, task.Bit, section, head); err == nil {
if blob, err := bitutil.DecompressBytes(compVector, int(params.BloomBitsBlocks)/8); err == nil {
task.Bitsets[i] = blob
} else {
@@ -108,7 +108,7 @@ func NewBloomIndexer(db ethdb.Database, size uint64) *core.ChainIndexer {
db: db,
size: size,
}
- table := rawdb.NewTable(db, string(core.BloomBitsIndexPrefix))
+ table := rawdb.NewTable(db, string(rawdb.BloomBitsIndexPrefix))
return core.NewChainIndexer(db, table, backend, size, bloomConfirms, bloomThrottling, "bloombits")
}
@@ -138,7 +138,7 @@ func (b *BloomIndexer) Commit() error {
if err != nil {
return err
}
- core.WriteBloomBits(batch, uint(i), b.section, b.head, bitutil.CompressBytes(bits))
+ rawdb.WriteBloomBits(batch, uint(i), b.section, b.head, bitutil.CompressBytes(bits))
}
return batch.Write()
}
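The bloombits changes are part of a broader move of database accessors from core into the rawdb package. A small sketch of the read path, using only accessors that appear in this diff:

package example

import (
	"github.com/tomochain/tomochain/core/rawdb"
	"github.com/tomochain/tomochain/core/types"
	"github.com/tomochain/tomochain/ethdb"
)

// headHeader resolves the current head header via the rawdb accessors that
// replace the old core.Get* helpers in this diff.
func headHeader(db ethdb.Database) *types.Header {
	hash := rawdb.GetHeadBlockHash(db)
	number := rawdb.GetBlockNumber(db, hash)
	return rawdb.GetHeader(db, hash, number)
}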
diff --git a/eth/config.go b/eth/config.go
index a86f08456..8b62ab7e4 100644
--- a/eth/config.go
+++ b/eth/config.go
@@ -48,6 +48,7 @@ var DefaultConfig = Config{
DatabaseCache: 768,
TrieCache: 256,
TrieTimeout: 5 * time.Minute,
+ SnapshotCache: 256,
GasPrice: big.NewInt(0.25 * params.Shannon),
TxPool: core.DefaultTxPoolConfig,
@@ -93,6 +94,7 @@ type Config struct {
DatabaseCache int
TrieCache int
TrieTimeout time.Duration
+ SnapshotCache int
// Mining-related options
Etherbase common.Address `toml:",omitempty"`
diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go
index eba7fad77..f9faf2ff6 100644
--- a/eth/downloader/downloader.go
+++ b/eth/downloader/downloader.go
@@ -27,7 +27,7 @@ import (
"github.com/tomochain/tomochain"
"github.com/tomochain/tomochain/common"
- "github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/event"
@@ -225,7 +225,7 @@ func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, chain BlockC
stateCh: make(chan dataPack),
stateSyncStart: make(chan *stateSync),
syncStatsState: stateSyncStats{
- processed: core.GetTrieSyncProgress(stateDb),
+ processed: rawdb.GetTrieSyncProgress(stateDb),
},
trackStateReq: make(chan *stateReq),
}
@@ -975,22 +975,22 @@ func (d *Downloader) fetchReceipts(from uint64) error {
// various callbacks to handle the slight differences between processing them.
//
// The instrumentation parameters:
-// - errCancel: error type to return if the fetch operation is cancelled (mostly makes logging nicer)
-// - deliveryCh: channel from which to retrieve downloaded data packets (merged from all concurrent peers)
-// - deliver: processing callback to deliver data packets into type specific download queues (usually within `queue`)
-// - wakeCh: notification channel for waking the fetcher when new tasks are available (or sync completed)
-// - expire: task callback method to abort requests that took too long and return the faulty peers (traffic shaping)
-// - pending: task callback for the number of requests still needing download (detect completion/non-completability)
-// - inFlight: task callback for the number of in-progress requests (wait for all active downloads to finish)
-// - throttle: task callback to check if the processing queue is full and activate throttling (bound memory use)
-// - reserve: task callback to reserve new download tasks to a particular peer (also signals partial completions)
-// - fetchHook: tester callback to notify of new tasks being initiated (allows testing the scheduling logic)
-// - fetch: network callback to actually send a particular download request to a physical remote peer
-// - cancel: task callback to abort an in-flight download request and allow rescheduling it (in case of lost peer)
-// - capacity: network callback to retrieve the estimated type-specific bandwidth capacity of a peer (traffic shaping)
-// - idle: network callback to retrieve the currently (type specific) idle peers that can be assigned tasks
-// - setIdle: network callback to set a peer back to idle and update its estimated capacity (traffic shaping)
-// - kind: textual label of the type being downloaded to display in log mesages
+// - errCancel: error type to return if the fetch operation is cancelled (mostly makes logging nicer)
+// - deliveryCh: channel from which to retrieve downloaded data packets (merged from all concurrent peers)
+// - deliver: processing callback to deliver data packets into type specific download queues (usually within `queue`)
+// - wakeCh: notification channel for waking the fetcher when new tasks are available (or sync completed)
+// - expire: task callback method to abort requests that took too long and return the faulty peers (traffic shaping)
+// - pending: task callback for the number of requests still needing download (detect completion/non-completability)
+// - inFlight: task callback for the number of in-progress requests (wait for all active downloads to finish)
+// - throttle: task callback to check if the processing queue is full and activate throttling (bound memory use)
+// - reserve: task callback to reserve new download tasks to a particular peer (also signals partial completions)
+// - fetchHook: tester callback to notify of new tasks being initiated (allows testing the scheduling logic)
+// - fetch: network callback to actually send a particular download request to a physical remote peer
+// - cancel: task callback to abort an in-flight download request and allow rescheduling it (in case of lost peer)
+// - capacity: network callback to retrieve the estimated type-specific bandwidth capacity of a peer (traffic shaping)
+// - idle: network callback to retrieve the currently (type specific) idle peers that can be assigned tasks
+// - setIdle: network callback to set a peer back to idle and update its estimated capacity (traffic shaping)
+// - kind: textual label of the type being downloaded to display in log messages
func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliver func(dataPack) (int, error), wakeCh chan bool,
expire func() map[string]int, pending func() int, inFlight func() bool, throttle func() bool, reserve func(*peerConnection, int) (*fetchRequest, bool, error),
fetchHook func([]*types.Header), fetch func(*peerConnection, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peerConnection) int,
diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go
index af39f9856..470819224 100644
--- a/eth/downloader/downloader_test.go
+++ b/eth/downloader/downloader_test.go
@@ -94,7 +94,7 @@ func newTester() *downloadTester {
peerChainTds: make(map[string]map[common.Hash]*big.Int),
peerMissingStates: make(map[string]map[common.Hash]bool),
}
- tester.stateDb= rawdb.NewMemoryDatabase()
+ tester.stateDb = rawdb.NewMemoryDatabase()
tester.stateDb.Put(genesis.Root().Bytes(), []byte{0x00})
tester.downloader = New(FullSync, tester.stateDb, new(event.TypeMux), tester, nil, tester.dropPeer)
@@ -160,7 +160,7 @@ func (dl *downloadTester) makeChainFork(n, f int, parent *types.Block, parentRec
// Create the common suffix
hashes, headers, blocks, receipts := dl.makeChain(n-f, 0, parent, parentReceipts, false)
- // Create the forks, making the second heavyer if non balanced forks were requested
+ // Create the forks, making the second heavier if non-balanced forks were requested
hashes1, headers1, blocks1, receipts1 := dl.makeChain(f, 1, blocks[hashes[0]], receipts[hashes[0]], false)
hashes1 = append(hashes1, hashes[1:]...)
@@ -663,12 +663,14 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng
// Tests that simple synchronization against a canonical chain works correctly.
// In this test common ancestor lookup should be short circuited and not require
// binary searching.
-func TestCanonicalSynchronisation62(t *testing.T) { testCanonicalSynchronisation(t, 62, FullSync) }
-func TestCanonicalSynchronisation63Full(t *testing.T) { testCanonicalSynchronisation(t, 63, FullSync) }
-func TestCanonicalSynchronisation63Fast(t *testing.T) { testCanonicalSynchronisation(t, 63, FastSync) }
-func TestCanonicalSynchronisation64Full(t *testing.T) { testCanonicalSynchronisation(t, 64, FullSync) }
-func TestCanonicalSynchronisation64Fast(t *testing.T) { testCanonicalSynchronisation(t, 64, FastSync) }
-func TestCanonicalSynchronisation64Light(t *testing.T) { testCanonicalSynchronisation(t, 64, LightSync) }
+func TestCanonicalSynchronisation62(t *testing.T) { testCanonicalSynchronisation(t, 62, FullSync) }
+func TestCanonicalSynchronisation63Full(t *testing.T) { testCanonicalSynchronisation(t, 63, FullSync) }
+func TestCanonicalSynchronisation63Fast(t *testing.T) { testCanonicalSynchronisation(t, 63, FastSync) }
+func TestCanonicalSynchronisation64Full(t *testing.T) { testCanonicalSynchronisation(t, 64, FullSync) }
+func TestCanonicalSynchronisation64Fast(t *testing.T) { testCanonicalSynchronisation(t, 64, FastSync) }
+func TestCanonicalSynchronisation64Light(t *testing.T) {
+ testCanonicalSynchronisation(t, 64, LightSync)
+}
func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) {
t.Parallel()
@@ -1357,8 +1359,8 @@ func testBlockHeaderAttackerDropping(t *testing.T, protocol int) {
}
}
-//Tests that synchronisation progress (origin block number, current block number
-//and highest block number) is tracked and updated correctly.
+// Tests that synchronisation progress (origin block number, current block number
+// and highest block number) is tracked and updated correctly.
func TestSyncProgress62(t *testing.T) { testSyncProgress(t, 62, FullSync) }
func TestSyncProgress63Full(t *testing.T) { testSyncProgress(t, 63, FullSync) }
func TestSyncProgress63Fast(t *testing.T) { testSyncProgress(t, 63, FastSync) }
diff --git a/eth/downloader/fakepeer.go b/eth/downloader/fakepeer.go
index 4d7c5ac28..5858a0549 100644
--- a/eth/downloader/fakepeer.go
+++ b/eth/downloader/fakepeer.go
@@ -21,6 +21,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
)
@@ -126,7 +127,7 @@ func (p *FakePeer) RequestBodies(hashes []common.Hash) error {
uncles [][]*types.Header
)
for _, hash := range hashes {
- block := core.GetBlock(p.db, hash, p.hc.GetBlockNumber(hash))
+ block := rawdb.GetBlock(p.db, hash, p.hc.GetBlockNumber(hash))
txs = append(txs, block.Transactions())
uncles = append(uncles, block.Uncles())
@@ -140,7 +141,7 @@ func (p *FakePeer) RequestBodies(hashes []common.Hash) error {
func (p *FakePeer) RequestReceipts(hashes []common.Hash) error {
var receipts [][]*types.Receipt
for _, hash := range hashes {
- receipts = append(receipts, core.GetBlockReceipts(p.db, hash, p.hc.GetBlockNumber(hash)))
+ receipts = append(receipts, rawdb.GetBlockReceipts(p.db, hash, p.hc.GetBlockNumber(hash), p.hc.Config()))
}
p.dl.DeliverReceipts(p.id, receipts)
return nil
diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go
index 0ed4e75fa..43569da2d 100644
--- a/eth/downloader/queue.go
+++ b/eth/downloader/queue.go
@@ -29,6 +29,7 @@ import (
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/metrics"
+ "github.com/tomochain/tomochain/trie"
"gopkg.in/karalabe/cookiejar.v2/collections/prque"
)
@@ -767,7 +768,7 @@ func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, uncleLi
defer q.lock.Unlock()
reconstruct := func(header *types.Header, index int, result *fetchResult) error {
- if types.DeriveSha(types.Transactions(txLists[index])) != header.TxHash || types.CalcUncleHash(uncleLists[index]) != header.UncleHash {
+ if types.DeriveSha(types.Transactions(txLists[index]), new(trie.StackTrie)) != header.TxHash || types.CalcUncleHash(uncleLists[index]) != header.UncleHash {
return errInvalidBody
}
result.Transactions = txLists[index]
@@ -785,7 +786,7 @@ func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt) (int,
defer q.lock.Unlock()
reconstruct := func(header *types.Header, index int, result *fetchResult) error {
- if types.DeriveSha(types.Receipts(receiptList[index])) != header.ReceiptHash {
+ if types.DeriveSha(types.Receipts(receiptList[index]), new(trie.StackTrie)) != header.ReceiptHash {
return errInvalidReceipt
}
result.Receipts = receiptList[index]
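Both reconstruct callbacks above reflect a signature change: types.DeriveSha now takes an explicit hasher, and the queue supplies a StackTrie. The body-validation check in isolation, as a minimal sketch:

package example

import (
	"github.com/tomochain/tomochain/core/types"
	"github.com/tomochain/tomochain/trie"
)

// bodyMatchesHeader re-derives the transaction root the way the updated
// queue does, passing a StackTrie as the explicit hasher.
func bodyMatchesHeader(header *types.Header, txs []*types.Transaction) bool {
	return types.DeriveSha(types.Transactions(txs), new(trie.StackTrie)) == header.TxHash
}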
diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go
index 3809a0c57..747c9f9cf 100644
--- a/eth/downloader/statesync.go
+++ b/eth/downloader/statesync.go
@@ -18,16 +18,16 @@ package downloader
import (
"fmt"
- "github.com/tomochain/tomochain/ethdb/memorydb"
"hash"
"sync"
"time"
"github.com/tomochain/tomochain/common"
- "github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/crypto/sha3"
"github.com/tomochain/tomochain/ethdb"
+ "github.com/tomochain/tomochain/ethdb/memorydb"
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/trie"
)
@@ -470,6 +470,6 @@ func (s *stateSync) updateStats(written, duplicate, unexpected int, duration tim
log.Info("Imported new state entries", "count", written, "elapsed", common.PrettyDuration(duration), "processed", s.d.syncStatsState.processed, "pending", s.d.syncStatsState.pending, "retry", len(s.tasks), "duplicate", s.d.syncStatsState.duplicate, "unexpected", s.d.syncStatsState.unexpected)
}
if written > 0 {
- core.WriteTrieSyncProgress(s.d.stateDB, s.d.syncStatsState.processed)
+ rawdb.WriteTrieSyncProgress(s.d.stateDB, s.d.syncStatsState.processed)
}
}
diff --git a/eth/fetcher/fetcher.go b/eth/fetcher/fetcher.go
index 65b15094d..d1bc108fd 100644
--- a/eth/fetcher/fetcher.go
+++ b/eth/fetcher/fetcher.go
@@ -19,14 +19,16 @@ package fetcher
import (
"errors"
- "github.com/hashicorp/golang-lru"
"math/rand"
"time"
+ lru "github.com/hashicorp/golang-lru"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/log"
+ "github.com/tomochain/tomochain/trie"
"gopkg.in/karalabe/cookiejar.v2/collections/prque"
)
@@ -468,7 +470,7 @@ func (f *Fetcher) loop() {
announce.time = task.time
// If the block is empty (header only), short circuit into the final import queue
- if header.TxHash == types.DeriveSha(types.Transactions{}) && header.UncleHash == types.CalcUncleHash([]*types.Header{}) {
+ if header.TxHash == types.EmptyRootHash && header.UncleHash == types.CalcUncleHash([]*types.Header{}) {
log.Trace("Block empty, skipping body retrieval", "peer", announce.origin, "number", header.Number, "hash", header.Hash())
block := types.NewBlockWithHeader(header)
@@ -530,7 +532,7 @@ func (f *Fetcher) loop() {
for hash, announce := range f.completing {
if f.queued[hash] == nil {
- txnHash := types.DeriveSha(types.Transactions(task.transactions[i]))
+ txnHash := types.DeriveSha(types.Transactions(task.transactions[i]), new(trie.StackTrie))
uncleHash := types.CalcUncleHash(task.uncles[i])
if txnHash == announce.header.TxHash && uncleHash == announce.header.UncleHash && announce.origin == task.peer {
diff --git a/eth/fetcher/fetcher_test.go b/eth/fetcher/fetcher_test.go
index ab7e03aaa..951b2fcd6 100644
--- a/eth/fetcher/fetcher_test.go
+++ b/eth/fetcher/fetcher_test.go
@@ -18,7 +18,6 @@ package fetcher
import (
"errors"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"sync"
"sync/atomic"
@@ -28,9 +27,11 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus/ethash"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/params"
+ "github.com/tomochain/tomochain/trie"
)
var (
@@ -38,7 +39,7 @@ var (
testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
testAddress = crypto.PubkeyToAddress(testKey.PublicKey)
genesis = core.GenesisBlockForTesting(testdb, testAddress, big.NewInt(1000000000))
- unknownBlock = types.NewBlock(&types.Header{GasLimit: params.GenesisGasLimit}, nil, nil, nil)
+ unknownBlock = types.NewBlock(&types.Header{GasLimit: params.GenesisGasLimit}, nil, nil, nil, new(trie.StackTrie))
)
// makeChain creates a chain of n blocks starting at and including parent.
diff --git a/eth/filters/bench_test.go b/eth/filters/bench_test.go
index 3648a3db2..9822a85e4 100644
--- a/eth/filters/bench_test.go
+++ b/eth/filters/bench_test.go
@@ -20,14 +20,13 @@ import (
"bytes"
"context"
"fmt"
- "github.com/tomochain/tomochain/core/rawdb"
"testing"
"time"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/bitutil"
- "github.com/tomochain/tomochain/core"
"github.com/tomochain/tomochain/core/bloombits"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/event"
@@ -68,18 +67,18 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) {
benchDataDir := node.DefaultDataDir() + "/geth/chaindata"
fmt.Println("Running bloombits benchmark section size:", sectionSize)
- db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024,"")
+ db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024, "")
if err != nil {
b.Fatalf("error opening database at %v: %v", benchDataDir, err)
}
- head := core.GetHeadBlockHash(db)
+ head := rawdb.GetHeadBlockHash(db)
if head == (common.Hash{}) {
b.Fatalf("chain data not found at %v", benchDataDir)
}
clearBloomBits(db)
fmt.Println("Generating bloombits data...")
- headNum := core.GetBlockNumber(db, head)
+ headNum := rawdb.GetBlockNumber(db, head)
if headNum < sectionSize+512 {
b.Fatalf("not enough blocks for running a benchmark")
}
@@ -94,14 +93,14 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) {
}
var header *types.Header
for i := sectionIdx * sectionSize; i < (sectionIdx+1)*sectionSize; i++ {
- hash := core.GetCanonicalHash(db, i)
- header = core.GetHeader(db, hash, i)
+ hash := rawdb.GetCanonicalHash(db, i)
+ header = rawdb.GetHeader(db, hash, i)
if header == nil {
b.Fatalf("Error creating bloomBits data")
}
bc.AddBloom(uint(i-sectionIdx*sectionSize), header.Bloom)
}
- sectionHead := core.GetCanonicalHash(db, (sectionIdx+1)*sectionSize-1)
+ sectionHead := rawdb.GetCanonicalHash(db, (sectionIdx+1)*sectionSize-1)
for i := 0; i < types.BloomBitLength; i++ {
data, err := bc.Bitset(uint(i))
if err != nil {
@@ -110,7 +109,7 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) {
comp := bitutil.CompressBytes(data)
dataSize += uint64(len(data))
compSize += uint64(len(comp))
- core.WriteBloomBits(db, uint(i), sectionIdx, sectionHead, comp)
+ rawdb.WriteBloomBits(db, uint(i), sectionIdx, sectionHead, comp)
}
//if sectionIdx%50 == 0 {
// fmt.Println(" section", sectionIdx, "/", cnt)
@@ -130,7 +129,7 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) {
for i := 0; i < benchFilterCnt; i++ {
if i%20 == 0 {
db.Close()
- db, _ = rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024,"")
+ db, _ = rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024, "")
backend = &testBackend{mux, db, cnt, new(event.Feed), new(event.Feed), new(event.Feed), new(event.Feed)}
}
var addr common.Address
@@ -148,7 +147,7 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) {
}
func forEachKey(db ethdb.Database, startPrefix, endPrefix []byte, fn func(key []byte)) {
- it := db.NewIterator(startPrefix,nil)
+ it := db.NewIterator(startPrefix, nil)
for it.Next() {
key := it.Key()
cmpLen := len(key)
@@ -176,15 +175,15 @@ func clearBloomBits(db ethdb.Database) {
func BenchmarkNoBloomBits(b *testing.B) {
benchDataDir := node.DefaultDataDir() + "/geth/chaindata"
fmt.Println("Running benchmark without bloombits")
- db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024,"")
+ db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024, "")
if err != nil {
b.Fatalf("error opening database at %v: %v", benchDataDir, err)
}
- head := core.GetHeadBlockHash(db)
+ head := rawdb.GetHeadBlockHash(db)
if head == (common.Hash{}) {
b.Fatalf("chain data not found at %v", benchDataDir)
}
- headNum := core.GetBlockNumber(db, head)
+ headNum := rawdb.GetBlockNumber(db, head)
clearBloomBits(db)
diff --git a/eth/filters/filter_system.go b/eth/filters/filter_system.go
index 3d92fc1ac..75c3c5e41 100644
--- a/eth/filters/filter_system.go
+++ b/eth/filters/filter_system.go
@@ -28,6 +28,7 @@ import (
ethereum "github.com/tomochain/tomochain"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/event"
"github.com/tomochain/tomochain/rpc"
@@ -348,11 +349,11 @@ func (es *EventSystem) lightFilterNewHead(newHeader *types.Header, callBack func
for oldh.Hash() != newh.Hash() {
if oldh.Number.Uint64() >= newh.Number.Uint64() {
oldHeaders = append(oldHeaders, oldh)
- oldh = core.GetHeader(es.backend.ChainDb(), oldh.ParentHash, oldh.Number.Uint64()-1)
+ oldh = rawdb.GetHeader(es.backend.ChainDb(), oldh.ParentHash, oldh.Number.Uint64()-1)
}
if oldh.Number.Uint64() < newh.Number.Uint64() {
newHeaders = append(newHeaders, newh)
- newh = core.GetHeader(es.backend.ChainDb(), newh.ParentHash, newh.Number.Uint64()-1)
+ newh = rawdb.GetHeader(es.backend.ChainDb(), newh.ParentHash, newh.Number.Uint64()-1)
if newh == nil {
// happens when CHT syncing, nothing to do
newh = oldh
diff --git a/eth/filters/filter_system_test.go b/eth/filters/filter_system_test.go
index d947a672a..077a9c41b 100644
--- a/eth/filters/filter_system_test.go
+++ b/eth/filters/filter_system_test.go
@@ -19,7 +19,6 @@ package filters
import (
"context"
"fmt"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"math/rand"
"reflect"
@@ -31,6 +30,7 @@ import (
"github.com/tomochain/tomochain/consensus/ethash"
"github.com/tomochain/tomochain/core"
"github.com/tomochain/tomochain/core/bloombits"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/event"
@@ -48,6 +48,10 @@ type testBackend struct {
chainFeed *event.Feed
}
+func (b *testBackend) ChainConfig() *params.ChainConfig {
+ return params.TestChainConfig
+}
+
func (b *testBackend) ChainDb() ethdb.Database {
return b.db
}
@@ -60,23 +64,23 @@ func (b *testBackend) HeaderByNumber(ctx context.Context, blockNr rpc.BlockNumbe
var hash common.Hash
var num uint64
if blockNr == rpc.LatestBlockNumber {
- hash = core.GetHeadBlockHash(b.db)
- num = core.GetBlockNumber(b.db, hash)
+ hash = rawdb.GetHeadBlockHash(b.db)
+ num = rawdb.GetBlockNumber(b.db, hash)
} else {
num = uint64(blockNr)
- hash = core.GetCanonicalHash(b.db, num)
+ hash = rawdb.GetCanonicalHash(b.db, num)
}
- return core.GetHeader(b.db, hash, num), nil
+ return rawdb.GetHeader(b.db, hash, num), nil
}
func (b *testBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) {
- number := core.GetBlockNumber(b.db, blockHash)
- return core.GetBlockReceipts(b.db, blockHash, number), nil
+ number := rawdb.GetBlockNumber(b.db, blockHash)
+ return rawdb.GetBlockReceipts(b.db, blockHash, number, b.ChainConfig()), nil
}
func (b *testBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) {
- number := core.GetBlockNumber(b.db, blockHash)
- receipts := core.GetBlockReceipts(b.db, blockHash, number)
+ number := rawdb.GetBlockNumber(b.db, blockHash)
+ receipts := rawdb.GetBlockReceipts(b.db, blockHash, number, b.ChainConfig())
logs := make([][]*types.Log, len(receipts))
for i, receipt := range receipts {
@@ -122,8 +126,8 @@ func (b *testBackend) ServiceFilter(ctx context.Context, session *bloombits.Matc
task.Bitsets = make([][]byte, len(task.Sections))
for i, section := range task.Sections {
if rand.Int()%4 != 0 { // Handle occasional missing deliveries
- head := core.GetCanonicalHash(b.db, (section+1)*params.BloomBitsBlocks-1)
- task.Bitsets[i], _ = core.GetBloomBits(b.db, task.Bit, section, head)
+ head := rawdb.GetCanonicalHash(b.db, (section+1)*params.BloomBitsBlocks-1)
+ task.Bitsets[i], _ = rawdb.GetBloomBits(b.db, task.Bit, section, head)
}
}
request <- task
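The test backend above also picks up the new chain-config parameter on rawdb.GetBlockReceipts (needed when deriving receipt fields). A sketch of the lookup on its own, grounded in the calls shown above:

package example

import (
	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/rawdb"
	"github.com/tomochain/tomochain/core/types"
	"github.com/tomochain/tomochain/ethdb"
	"github.com/tomochain/tomochain/params"
)

// receiptsByHash mirrors the updated lookup: receipts are read through
// rawdb and now require the chain config.
func receiptsByHash(db ethdb.Database, hash common.Hash, cfg *params.ChainConfig) types.Receipts {
	number := rawdb.GetBlockNumber(db, hash)
	return rawdb.GetBlockReceipts(db, hash, number, cfg)
}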
diff --git a/eth/filters/filter_test.go b/eth/filters/filter_test.go
index bdfb6e37f..a5ddb00db 100644
--- a/eth/filters/filter_test.go
+++ b/eth/filters/filter_test.go
@@ -18,7 +18,6 @@ package filters
import (
"context"
- "github.com/tomochain/tomochain/core/rawdb"
"io/ioutil"
"math/big"
"os"
@@ -27,6 +26,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus/ethash"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/event"
@@ -50,7 +50,7 @@ func BenchmarkFilters(b *testing.B) {
defer os.RemoveAll(dir)
var (
- db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0,"")
+ db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0, "")
mux = new(event.TypeMux)
txFeed = new(event.Feed)
rmLogsFeed = new(event.Feed)
@@ -84,14 +84,14 @@ func BenchmarkFilters(b *testing.B) {
}
})
for i, block := range chain {
- core.WriteBlock(db, block)
- if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil {
+ rawdb.WriteBlock(db, block)
+ if err := rawdb.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil {
b.Fatalf("failed to insert block number: %v", err)
}
- if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil {
+ if err := rawdb.WriteHeadBlockHash(db, block.Hash()); err != nil {
b.Fatalf("failed to insert block number: %v", err)
}
- if err := core.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), receipts[i]); err != nil {
+ if err := rawdb.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), receipts[i]); err != nil {
b.Fatal("error writing block receipts:", err)
}
}
@@ -115,7 +115,7 @@ func TestFilters(t *testing.T) {
defer os.RemoveAll(dir)
var (
- db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0,"")
+ db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0, "")
mux = new(event.TypeMux)
txFeed = new(event.Feed)
rmLogsFeed = new(event.Feed)
@@ -144,6 +144,7 @@ func TestFilters(t *testing.T) {
},
}
gen.AddUncheckedReceipt(receipt)
+ gen.AddUncheckedTx(types.NewTransaction(999, common.HexToAddress("0x999"), big.NewInt(999), 999, nil, nil))
case 2:
receipt := types.NewReceipt(nil, false, 0)
receipt.Logs = []*types.Log{
@@ -153,6 +154,7 @@ func TestFilters(t *testing.T) {
},
}
gen.AddUncheckedReceipt(receipt)
+ gen.AddUncheckedTx(types.NewTransaction(999, common.HexToAddress("0x999"), big.NewInt(999), 999, nil, nil))
case 998:
receipt := types.NewReceipt(nil, false, 0)
receipt.Logs = []*types.Log{
@@ -162,6 +164,7 @@ func TestFilters(t *testing.T) {
},
}
gen.AddUncheckedReceipt(receipt)
+ gen.AddUncheckedTx(types.NewTransaction(999, common.HexToAddress("0x999"), big.NewInt(999), 999, nil, nil))
case 999:
receipt := types.NewReceipt(nil, false, 0)
receipt.Logs = []*types.Log{
@@ -171,17 +174,19 @@ func TestFilters(t *testing.T) {
},
}
gen.AddUncheckedReceipt(receipt)
+ gen.AddUncheckedTx(types.NewTransaction(999, common.HexToAddress("0x999"), big.NewInt(999), 999, nil, nil))
}
})
+
for i, block := range chain {
- core.WriteBlock(db, block)
- if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil {
+ rawdb.WriteBlock(db, block)
+ if err := rawdb.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil {
t.Fatalf("failed to insert block number: %v", err)
}
- if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil {
+ if err := rawdb.WriteHeadBlockHash(db, block.Hash()); err != nil {
t.Fatalf("failed to insert block number: %v", err)
}
- if err := core.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), receipts[i]); err != nil {
+ if err := rawdb.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), receipts[i]); err != nil {
t.Fatal("error writing block receipts:", err)
}
}
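On the write side, the filter tests switch to the rawdb write helpers. A sketch combining them, using only the calls and error conventions visible above:

package example

import (
	"github.com/tomochain/tomochain/core/rawdb"
	"github.com/tomochain/tomochain/core/types"
	"github.com/tomochain/tomochain/ethdb"
)

// writeCanonicalBlock persists a block and its receipts the way the updated
// tests do: rawdb.WriteBlock has no error return, while the canonical-hash,
// head-hash and receipt writers do.
func writeCanonicalBlock(db ethdb.Database, block *types.Block, receipts types.Receipts) error {
	rawdb.WriteBlock(db, block)
	if err := rawdb.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil {
		return err
	}
	if err := rawdb.WriteHeadBlockHash(db, block.Hash()); err != nil {
		return err
	}
	return rawdb.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), receipts)
}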
diff --git a/eth/handler.go b/eth/handler.go
index eae95a9b1..2f377fe8e 100644
--- a/eth/handler.go
+++ b/eth/handler.go
@@ -38,7 +38,7 @@ import (
"github.com/tomochain/tomochain/event"
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/rlp"
)
@@ -178,7 +178,7 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne
NodeInfo: func() interface{} {
return manager.NodeInfo()
},
- PeerInfo: func(id discover.NodeID) interface{} {
+ PeerInfo: func(id enode.ID) interface{} {
if p := manager.peers.Peer(fmt.Sprintf("%x", id[:8])); p != nil {
return p.Info()
}
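handler.go swaps discover.NodeID for enode.ID; peer ids keep rendering as the first eight bytes in hex. A one-line sketch of that rendering:

package example

import (
	"fmt"

	"github.com/tomochain/tomochain/p2p/enode"
)

// shortID renders the eight-byte peer-id prefix the same way the handler
// and the rewritten newPeer do after the enode.ID migration.
func shortID(id enode.ID) string {
	return fmt.Sprintf("%x", id[:8])
}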
diff --git a/eth/handler_test.go b/eth/handler_test.go
index d8d2f0097..bee29ea90 100644
--- a/eth/handler_test.go
+++ b/eth/handler_test.go
@@ -17,13 +17,14 @@
package eth
import (
- "github.com/tomochain/tomochain/core/rawdb"
"math"
"math/big"
"math/rand"
"testing"
"time"
+ "github.com/tomochain/tomochain/core/rawdb"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus/ethash"
"github.com/tomochain/tomochain/core"
@@ -343,9 +344,9 @@ func testGetNodeData(t *testing.T, protocol int) {
// Fetch for now the entire chain db
hashes := []common.Hash{}
- it:=db.NewIterator(nil,nil)
+ it := db.NewIterator(nil, nil)
for it.Next() {
- key:=it.Key()
+ key := it.Key()
if len(key) == len(common.Hash{}) {
hashes = append(hashes, common.BytesToHash(key))
}
@@ -374,7 +375,7 @@ func testGetNodeData(t *testing.T, protocol int) {
}
accounts := []common.Address{testBank, acc1Addr, acc2Addr}
for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ {
- trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb))
+ trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb), nil)
for j, acc := range accounts {
state, _ := pm.blockchain.State()
@@ -470,7 +471,7 @@ func testDAOChallenge(t *testing.T, localForked, remoteForked bool, timeout bool
var (
evmux = new(event.TypeMux)
pow = ethash.NewFaker()
- db = rawdb.NewMemoryDatabase()
+ db = rawdb.NewMemoryDatabase()
config = &params.ChainConfig{DAOForkBlock: big.NewInt(1), DAOForkSupport: localForked}
gspec = &core.Genesis{Config: config}
genesis = gspec.MustCommit(db)
diff --git a/eth/helper_test.go b/eth/helper_test.go
index 6ea65856f..a4da7286e 100644
--- a/eth/helper_test.go
+++ b/eth/helper_test.go
@@ -22,7 +22,6 @@ package eth
import (
"crypto/ecdsa"
"crypto/rand"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"sort"
"sync"
@@ -31,6 +30,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus/ethash"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
"github.com/tomochain/tomochain/crypto"
@@ -38,7 +38,7 @@ import (
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/event"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/params"
)
@@ -149,7 +149,7 @@ func newTestPeer(name string, version int, pm *ProtocolManager, shake bool) (*te
app, net := p2p.MsgPipe()
// Generate a random id and create the peer
- var id discover.NodeID
+ var id enode.ID
rand.Read(id[:])
peer := pm.newPeer(version, p2p.NewPeer(id, name, nil), net)
diff --git a/eth/peer.go b/eth/peer.go
index 694267885..314c38458 100644
--- a/eth/peer.go
+++ b/eth/peer.go
@@ -24,6 +24,7 @@ import (
"time"
mapset "github.com/deckarep/golang-set"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/p2p"
@@ -38,10 +39,26 @@ var (
const (
maxKnownTxs = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS)
+ maxKnownBlocks = 1024 // Maximum block hashes to keep in the known list (prevent DOS)
maxKnownOrderTxs = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS)
maxKnownLendingTxs = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS)
- maxKnownBlocks = 1024 // Maximum block hashes to keep in the known list (prevent DOS)
- handshakeTimeout = 5 * time.Second
+
+ // maxQueuedTxs is the maximum number of transaction lists to queue up before
+ // dropping broadcasts. This is a sensitive number as a transaction list might
+ // contain a single transaction, or thousands.
+ maxQueuedTxs = 128
+
+ // maxQueuedProps is the maximum number of block propagations to queue up before
+ // dropping broadcasts. There's not much point in queueing stale blocks, so a few
+ // that might cover uncles should be enough.
+ maxQueuedProps = 4
+
+ // maxQueuedAnns is the maximum number of block announcements to queue up before
+ // dropping broadcasts. Similarly to block propagations, there's no point to queue
+ // above some healthy uncle limit, so use that.
+ maxQueuedAnns = 4
+
+ handshakeTimeout = 5 * time.Second
)
// PeerInfo represents a short summary of the Ethereum sub-protocol metadata known
@@ -52,12 +69,17 @@ type PeerInfo struct {
Head string `json:"head"` // SHA3 hash of the peer's best owned block
}
+// propEvent is a block propagation, waiting for its turn in the broadcast queue.
+type propEvent struct {
+ block *types.Block
+ td *big.Int
+}
+
type peer struct {
id string
*p2p.Peer
- rw p2p.MsgReadWriter
- pairRw p2p.MsgReadWriter
+ rw p2p.MsgReadWriter
version int // Protocol version negotiated
forkDrop *time.Timer // Timed connection dropper if forks aren't validated in time
@@ -66,27 +88,66 @@ type peer struct {
td *big.Int
lock sync.RWMutex
- knownTxs mapset.Set // Set of transaction hashes known to be known by this peer
- knownBlocks mapset.Set // Set of block hashes known to be known by this peer
- knownOrderTxs mapset.Set // Set of order transaction hashes known to be known by this peer
- knownLendingTxs mapset.Set // Set of lending transaction hashes known to be known by this peer
+ knownTxs mapset.Set // Set of transaction hashes known to be known by this peer
+ knownBlocks mapset.Set // Set of block hashes known to be known by this peer
+ knownOrderTxs mapset.Set // Set of order transaction hashes known to be known by this peer
+ knownLendingTxs mapset.Set // Set of lending transaction hashes known to be known by this peer
+ queuedTxs chan []*types.Transaction // Queue of transactions to broadcast to the peer
+ queuedProps chan *propEvent // Queue of blocks to broadcast to the peer
+ queuedAnns chan *types.Block // Queue of blocks to announce to the peer
+ term chan struct{} // Termination channel to stop the broadcaster
}
func newPeer(version int, p *p2p.Peer, rw p2p.MsgReadWriter) *peer {
- id := p.ID()
-
return &peer{
- Peer: p,
- rw: rw,
- version: version,
- id: fmt.Sprintf("%x", id[:8]),
- knownTxs: mapset.NewSet(),
- knownBlocks: mapset.NewSet(),
- knownOrderTxs: mapset.NewSet(),
- knownLendingTxs: mapset.NewSet(),
+ Peer: p,
+ rw: rw,
+ version: version,
+ id: fmt.Sprintf("%x", p.ID().Bytes()[:8]),
+ knownTxs: mapset.NewSet(),
+ knownBlocks: mapset.NewSet(),
+ queuedTxs: make(chan []*types.Transaction, maxQueuedTxs),
+ queuedProps: make(chan *propEvent, maxQueuedProps),
+ queuedAnns: make(chan *types.Block, maxQueuedAnns),
+ term: make(chan struct{}),
+ }
+}
+
+// broadcast is a write loop that multiplexes block propagations, announcements
+// and transaction broadcasts into the remote peer. The goal is to have an async
+// writer that does not lock up node internals.
+func (p *peer) broadcast() {
+ for {
+ select {
+ case txs := <-p.queuedTxs:
+ if err := p.SendTransactions(txs); err != nil {
+ return
+ }
+ p.Log().Trace("Broadcast transactions", "count", len(txs))
+
+ case prop := <-p.queuedProps:
+ if err := p.SendNewBlock(prop.block, prop.td); err != nil {
+ return
+ }
+ p.Log().Trace("Propagated block", "number", prop.block.Number(), "hash", prop.block.Hash(), "td", prop.td)
+
+ case block := <-p.queuedAnns:
+ if err := p.SendNewBlockHashes([]common.Hash{block.Hash()}, []uint64{block.NumberU64()}); err != nil {
+ return
+ }
+ p.Log().Trace("Announced block", "number", block.Number(), "hash", block.Hash())
+
+ case <-p.term:
+ return
+ }
}
}
+// close signals the broadcast goroutine to terminate.
+func (p *peer) close() {
+ close(p.term)
+}
+
// Info gathers and returns a collection of metadata known about a peer.
func (p *peer) Info() *PeerInfo {
hash, td := p.Head()
@@ -184,6 +245,19 @@ func (p *peer) SendLendingTransactions(txs types.LendingTransactions) error {
return p2p.Send(p.rw, LendingTxMsg, txs)
}
+// AsyncSendTransactions queues list of transactions propagation to a remote
+// peer. If the peer's broadcast queue is full, the event is silently dropped.
+func (p *peer) AsyncSendTransactions(txs []*types.Transaction) {
+ select {
+ case p.queuedTxs <- txs:
+ for _, tx := range txs {
+ p.knownTxs.Add(tx.Hash())
+ }
+ default:
+ p.Log().Debug("Dropping transaction propagation", "count", len(txs))
+ }
+}
+
// SendNewBlockHashes announces the availability of a number of blocks through
// a hash notification.
func (p *peer) SendNewBlockHashes(hashes []common.Hash, numbers []uint64) error {
@@ -198,127 +272,102 @@ func (p *peer) SendNewBlockHashes(hashes []common.Hash, numbers []uint64) error
return p2p.Send(p.rw, NewBlockHashesMsg, request)
}
+// AsyncSendNewBlockHash queues the availability of a block for propagation to a
+// remote peer. If the peer's broadcast queue is full, the event is silently
+// dropped.
+func (p *peer) AsyncSendNewBlockHash(block *types.Block) {
+ select {
+ case p.queuedAnns <- block:
+ p.knownBlocks.Add(block.Hash())
+ default:
+ p.Log().Debug("Dropping block announcement", "number", block.NumberU64(), "hash", block.Hash())
+ }
+}
+
// SendNewBlock propagates an entire block to a remote peer.
func (p *peer) SendNewBlock(block *types.Block, td *big.Int) error {
p.knownBlocks.Add(block.Hash())
- if p.pairRw != nil {
- return p2p.Send(p.pairRw, NewBlockMsg, []interface{}{block, td})
- } else {
- return p2p.Send(p.rw, NewBlockMsg, []interface{}{block, td})
+ return p2p.Send(p.rw, NewBlockMsg, []interface{}{block, td})
+}
+
+// AsyncSendNewBlock queues an entire block for propagation to a remote peer. If
+// the peer's broadcast queue is full, the event is silently dropped.
+func (p *peer) AsyncSendNewBlock(block *types.Block, td *big.Int) {
+ select {
+ case p.queuedProps <- &propEvent{block: block, td: td}:
+ p.knownBlocks.Add(block.Hash())
+ default:
+ p.Log().Debug("Dropping block propagation", "number", block.NumberU64(), "hash", block.Hash())
}
}
// SendBlockHeaders sends a batch of block headers to the remote peer.
func (p *peer) SendBlockHeaders(headers []*types.Header) error {
- if p.pairRw != nil {
- return p2p.Send(p.pairRw, BlockHeadersMsg, headers)
- } else {
- return p2p.Send(p.rw, BlockHeadersMsg, headers)
- }
+ return p2p.Send(p.rw, BlockHeadersMsg, headers)
}
// SendBlockBodies sends a batch of block contents to the remote peer.
func (p *peer) SendBlockBodies(bodies []*blockBody) error {
- if p.pairRw != nil {
- return p2p.Send(p.pairRw, BlockBodiesMsg, blockBodiesData(bodies))
- } else {
- return p2p.Send(p.rw, BlockBodiesMsg, blockBodiesData(bodies))
- }
+ return p2p.Send(p.rw, BlockBodiesMsg, blockBodiesData(bodies))
}
// SendBlockBodiesRLP sends a batch of block contents to the remote peer from
// an already RLP encoded format.
func (p *peer) SendBlockBodiesRLP(bodies []rlp.RawValue) error {
- if p.pairRw != nil {
- return p2p.Send(p.pairRw, BlockBodiesMsg, bodies)
- } else {
- return p2p.Send(p.rw, BlockBodiesMsg, bodies)
- }
+ return p2p.Send(p.rw, BlockBodiesMsg, bodies)
}
// SendNodeDataRLP sends a batch of arbitrary internal data, corresponding to the
// hashes requested.
func (p *peer) SendNodeData(data [][]byte) error {
- if p.pairRw != nil {
- return p2p.Send(p.pairRw, NodeDataMsg, data)
- } else {
- return p2p.Send(p.rw, NodeDataMsg, data)
- }
+ return p2p.Send(p.rw, NodeDataMsg, data)
}
// SendReceiptsRLP sends a batch of transaction receipts, corresponding to the
// ones requested from an already RLP encoded format.
func (p *peer) SendReceiptsRLP(receipts []rlp.RawValue) error {
- if p.pairRw != nil {
- return p2p.Send(p.pairRw, ReceiptsMsg, receipts)
- } else {
- return p2p.Send(p.rw, ReceiptsMsg, receipts)
- }
+ return p2p.Send(p.rw, ReceiptsMsg, receipts)
}
// RequestOneHeader is a wrapper around the header query functions to fetch a
// single header. It is used solely by the fetcher.
func (p *peer) RequestOneHeader(hash common.Hash) error {
p.Log().Debug("Fetching single header", "hash", hash)
- if p.pairRw != nil {
- return p2p.Send(p.pairRw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: uint64(1), Skip: uint64(0), Reverse: false})
- } else {
- return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: uint64(1), Skip: uint64(0), Reverse: false})
- }
+ return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: uint64(1), Skip: uint64(0), Reverse: false})
}
// RequestHeadersByHash fetches a batch of blocks' headers corresponding to the
// specified header query, based on the hash of an origin block.
func (p *peer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error {
p.Log().Debug("Fetching batch of headers", "count", amount, "fromhash", origin, "skip", skip, "reverse", reverse)
- if p.pairRw != nil {
- return p2p.Send(p.pairRw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse})
- } else {
- return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse})
- }
+ return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse})
}
// RequestHeadersByNumber fetches a batch of blocks' headers corresponding to the
// specified header query, based on the number of an origin block.
func (p *peer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
p.Log().Debug("Fetching batch of headers", "count", amount, "fromnum", origin, "skip", skip, "reverse", reverse)
- if p.pairRw != nil {
- return p2p.Send(p.pairRw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse})
- } else {
- return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse})
- }
+ return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse})
}
// RequestBodies fetches a batch of blocks' bodies corresponding to the hashes
// specified.
func (p *peer) RequestBodies(hashes []common.Hash) error {
p.Log().Debug("Fetching batch of block bodies", "count", len(hashes))
- if p.pairRw != nil {
- return p2p.Send(p.pairRw, GetBlockBodiesMsg, hashes)
- } else {
- return p2p.Send(p.rw, GetBlockBodiesMsg, hashes)
- }
+ return p2p.Send(p.rw, GetBlockBodiesMsg, hashes)
}
// RequestNodeData fetches a batch of arbitrary data from a node's known state
// data, corresponding to the specified hashes.
func (p *peer) RequestNodeData(hashes []common.Hash) error {
p.Log().Debug("Fetching batch of state data", "count", len(hashes))
- if p.pairRw != nil {
- return p2p.Send(p.pairRw, GetNodeDataMsg, hashes)
- } else {
- return p2p.Send(p.rw, GetNodeDataMsg, hashes)
- }
+ return p2p.Send(p.rw, GetNodeDataMsg, hashes)
}
// RequestReceipts fetches a batch of transaction receipts from a remote node.
func (p *peer) RequestReceipts(hashes []common.Hash) error {
p.Log().Debug("Fetching batch of receipts", "count", len(hashes))
- if p.pairRw != nil {
- return p2p.Send(p.pairRw, GetReceiptsMsg, hashes)
- } else {
- return p2p.Send(p.rw, GetReceiptsMsg, hashes)
- }
+ return p2p.Send(p.rw, GetReceiptsMsg, hashes)
}
// Handshake executes the eth protocol handshake, negotiating version number,
@@ -406,7 +455,8 @@ func newPeerSet() *peerSet {
}
// Register injects a new peer into the working set, or returns an error if the
-// peer is already known.
+// peer is already known. If a new peer is registered, its broadcast loop is also
+// started.
func (ps *peerSet) Register(p *peer) error {
ps.lock.Lock()
defer ps.lock.Unlock()
@@ -414,16 +464,12 @@ func (ps *peerSet) Register(p *peer) error {
if ps.closed {
return errClosed
}
- if existPeer, ok := ps.peers[p.id]; ok {
- if existPeer.pairRw != nil {
- return errAlreadyRegistered
- }
- existPeer.PairPeer = p.Peer
- existPeer.pairRw = p.rw
- p.PairPeer = existPeer.Peer
- return p2p.ErrAddPairPeer
+ if _, ok := ps.peers[p.id]; ok {
+ return errAlreadyRegistered
}
ps.peers[p.id] = p
+ go p.broadcast()
+
return nil
}
@@ -433,10 +479,13 @@ func (ps *peerSet) Unregister(id string) error {
ps.lock.Lock()
defer ps.lock.Unlock()
- if _, ok := ps.peers[id]; !ok {
+ p, ok := ps.peers[id]
+ if !ok {
return errNotRegistered
}
delete(ps.peers, id)
+ p.close()
+
return nil
}
@@ -486,7 +535,7 @@ func (ps *peerSet) PeersWithoutTx(hash common.Hash) []*peer {
return list
}
-// PeersWithoutTx retrieves a list of peers that do not have a given transaction
+// OrderPeersWithoutTx retrieves a list of peers that do not have a given transaction
// in their set of known hashes.
func (ps *peerSet) OrderPeersWithoutTx(hash common.Hash) []*peer {
ps.lock.RLock()
diff --git a/eth/sync.go b/eth/sync.go
index a1224b3ca..ae95ad8d5 100644
--- a/eth/sync.go
+++ b/eth/sync.go
@@ -25,7 +25,7 @@ import (
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/eth/downloader"
"github.com/tomochain/tomochain/log"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
)
const (
@@ -64,7 +64,7 @@ func (pm *ProtocolManager) syncTransactions(p *peer) {
// the transactions in small packs to one peer at a time.
func (pm *ProtocolManager) txsyncLoop() {
var (
- pending = make(map[discover.NodeID]*txsync)
+ pending = make(map[enode.ID]*txsync)
sending = false // whether a send is active
pack = new(txsync) // the pack that is being sent
done = make(chan error, 1) // result of the send
diff --git a/eth/sync_test.go b/eth/sync_test.go
index 9b447f2a1..491a7513c 100644
--- a/eth/sync_test.go
+++ b/eth/sync_test.go
@@ -23,7 +23,7 @@ import (
"github.com/tomochain/tomochain/eth/downloader"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
)
// Tests that fast sync gets disabled as soon as a real block is successfully
@@ -42,8 +42,8 @@ func TestFastSyncDisabling(t *testing.T) {
// Sync up the two peers
io1, io2 := p2p.MsgPipe()
- go pmFull.handle(pmFull.newPeer(63, p2p.NewPeer(discover.NodeID{}, "empty", nil), io2))
- go pmEmpty.handle(pmEmpty.newPeer(63, p2p.NewPeer(discover.NodeID{}, "full", nil), io1))
+ go pmFull.handle(pmFull.newPeer(63, p2p.NewPeer(enode.ID{}, "empty", nil), io2))
+ go pmEmpty.handle(pmEmpty.newPeer(63, p2p.NewPeer(enode.ID{}, "full", nil), io1))
time.Sleep(250 * time.Millisecond)
pmEmpty.synchronise(pmEmpty.peers.BestPeer())
diff --git a/eth/tracers/tracers_test.go b/eth/tracers/tracers_test.go
index 38d407517..2764f7034 100644
--- a/eth/tracers/tracers_test.go
+++ b/eth/tracers/tracers_test.go
@@ -20,17 +20,18 @@ import (
"crypto/ecdsa"
"crypto/rand"
"encoding/json"
- "github.com/tomochain/tomochain/core/rawdb"
"io/ioutil"
"math/big"
"path/filepath"
"reflect"
"strings"
"testing"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/hexutil"
"github.com/tomochain/tomochain/common/math"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
"github.com/tomochain/tomochain/crypto"
@@ -169,21 +170,21 @@ func TestPrestateTracerCreate2(t *testing.T) {
Balance: big.NewInt(500000000000000),
}
db := rawdb.NewMemoryDatabase()
- statedb := tests.MakePreState(db, alloc)
+ statedb := tests.MakePreState(db, alloc, false)
// Create the tracer, the EVM environment and run it
tracer, err := New("prestateTracer")
if err != nil {
t.Fatalf("failed to create call tracer: %v", err)
}
- evm := vm.NewEVM(context, statedb, nil, params.MainnetChainConfig, vm.Config{Debug: true, Tracer: tracer})
+ evm := vm.NewEVM(context, statedb, nil, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer})
- msg, err := tx.AsMessage(signer, nil, nil)
+ msg, err := core.TransactionToMessage(tx, signer, nil, nil)
if err != nil {
t.Fatalf("failed to prepare transaction for tracing: %v", err)
}
st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas()))
- if _, _, _, err = st.TransitionDb(common.Address{}); err != nil {
+ if _, err = st.TransitionDb(common.Address{}); err != nil {
t.Fatalf("failed to execute transaction: %v", err)
}
// Retrieve the trace result and compare against the etalon
@@ -244,7 +245,7 @@ func TestCallTracer(t *testing.T) {
GasPrice: tx.GasPrice(),
}
db := rawdb.NewMemoryDatabase()
- statedb := tests.MakePreState(db, test.Genesis.Alloc)
+ statedb := tests.MakePreState(db, test.Genesis.Alloc, false)
// Create the tracer, the EVM environment and run it
tracer, err := New("callTracer")
@@ -253,12 +254,12 @@ func TestCallTracer(t *testing.T) {
}
evm := vm.NewEVM(context, statedb, nil, test.Genesis.Config, vm.Config{Debug: true, Tracer: tracer})
- msg, err := tx.AsMessage(signer, nil, common.Big0)
+ msg, err := core.TransactionToMessage(tx, signer, nil, common.Big0)
if err != nil {
t.Fatalf("failed to prepare transaction for tracing: %v", err)
}
st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas()))
- if _, _, _, err = st.TransitionDb(common.Address{}); err != nil {
+ if _, err = st.TransitionDb(common.Address{}); err != nil {
t.Fatalf("failed to execute transaction: %v", err)
}
// Retrieve the trace result and compare against the etalon
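The test updates above track the new state-transition API: TransitionDb now returns a single result plus an error instead of the old (ret, gas, failed, err) tuple. A sketch of how such a result is typically inspected, using a local stand-in struct with the fields this diff consumes elsewhere (UsedGas, Err, Return) rather than the core package's actual definition:

package main

import "fmt"

// executionResult is a local stand-in for core.ExecutionResult, carrying the
// fields this diff relies on elsewhere (UsedGas, Err, return data).
type executionResult struct {
	UsedGas    uint64
	Err        error // consensus-level failure such as a revert
	ReturnData []byte
}

// Failed mirrors how result.Failed() is used during gas estimation.
func (r *executionResult) Failed() bool { return r.Err != nil }

// Return mirrors Return(): data is only meaningful on success.
func (r *executionResult) Return() []byte {
	if r.Err != nil {
		return nil
	}
	return r.ReturnData
}

func main() {
	// Old call sites destructured (ret, gas, failed, err); new ones
	// receive a single result plus the transaction-invalid error.
	res := &executionResult{UsedGas: 21000, ReturnData: []byte{0x01}}
	if res.Failed() {
		fmt.Println("execution failed:", res.Err)
		return
	}
	fmt.Printf("used %d gas, returned 0x%x\n", res.UsedGas, res.Return())
}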
diff --git a/go.mod b/go.mod
index 15d820f80..6c1690598 100644
--- a/go.mod
+++ b/go.mod
@@ -4,44 +4,45 @@ go 1.19
require (
bazil.org/fuse v0.0.0-20180421153158-65cc252bf669
- github.com/VictoriaMetrics/fastcache v1.5.7
+ github.com/VictoriaMetrics/fastcache v1.6.0
github.com/aristanetworks/goarista v0.0.0-20191023202215-f096da5361bb
github.com/btcsuite/btcd v0.0.0-20171128150713-2e60448ffcc6
github.com/cespare/cp v1.1.1
github.com/davecgh/go-spew v1.1.1
github.com/deckarep/golang-set v0.0.0-20180603214616-504e848d77ea
- github.com/docker/docker v1.4.2-0.20180625184442-8e610b2b55bf
+ github.com/docker/docker v1.6.2
github.com/dop251/goja v0.0.0-20230531210528-d7324b2d74f7
github.com/edsrzf/mmap-go v1.0.0
- github.com/fatih/color v1.6.0
+ github.com/fatih/color v1.7.0
github.com/gizak/termui v2.2.0+incompatible
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8
- github.com/go-stack/stack v1.8.0
- github.com/golang/protobuf v1.3.2
- github.com/golang/snappy v0.0.1
+ github.com/go-stack/stack v1.8.1
+ github.com/golang/protobuf v1.5.2
+ github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb
github.com/hashicorp/golang-lru v0.5.3
- github.com/huin/goupnp v1.0.0
+ github.com/holiman/uint256 v1.2.2
+ github.com/huin/goupnp v1.0.3
github.com/influxdata/influxdb v1.7.9
- github.com/jackpal/go-nat-pmp v1.0.2-0.20160603034137-1fa385a6f458
+ github.com/jackpal/go-nat-pmp v1.0.2
github.com/julienschmidt/httprouter v1.3.0
github.com/karalabe/hid v1.0.0
- github.com/mattn/go-colorable v0.1.0
+ github.com/mattn/go-colorable v0.1.13
github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416
- github.com/olekukonko/tablewriter v0.0.2-0.20190409134802-7e037d187b0c
+ github.com/olekukonko/tablewriter v0.0.5
github.com/pborman/uuid v1.2.0
github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7
- github.com/pkg/errors v0.8.1
+ github.com/pkg/errors v0.9.1
github.com/prometheus/prometheus v1.7.2-0.20170814170113-3101606756c5
github.com/rjeczalik/notify v0.9.2
- github.com/rs/cors v1.6.0
+ github.com/rs/cors v1.7.0
github.com/steakknife/bloomfilter v0.0.0-20180922174646-6819c0d2a570
- github.com/stretchr/testify v1.4.0
- github.com/syndtr/goleveldb v1.0.1-0.20190923125748-758128399b1d
- golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
- golang.org/x/net v0.0.0-20220722155237-a158d28d115b
- golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4
- golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f
- golang.org/x/tools v0.1.12
+ github.com/stretchr/testify v1.8.1
+ github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7
+ golang.org/x/crypto v0.1.0
+ golang.org/x/net v0.8.0
+ golang.org/x/sync v0.1.0
+ golang.org/x/sys v0.7.0
+ golang.org/x/tools v0.7.0
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c
gopkg.in/karalabe/cookiejar.v2 v2.0.0-20150724131613-8dcd6a7f4951
gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce
@@ -50,27 +51,30 @@ require (
)
require (
- github.com/cespare/xxhash/v2 v2.1.1 // indirect
+ github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/dlclark/regexp2 v1.7.0 // indirect
+ github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
- github.com/google/go-cmp v0.3.1 // indirect
+ github.com/google/go-cmp v0.5.9 // indirect
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect
- github.com/google/uuid v1.0.0 // indirect
- github.com/kr/pretty v0.3.0 // indirect
+ github.com/google/uuid v1.3.0 // indirect
+ github.com/holiman/bloomfilter/v2 v2.0.3
+ github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/maruel/panicparse v0.0.0-20160720141634-ad661195ed0e // indirect
github.com/maruel/ut v1.0.2 // indirect
- github.com/mattn/go-isatty v0.0.5-0.20180830101745-3fb116b82035 // indirect
- github.com/mattn/go-runewidth v0.0.4 // indirect
+ github.com/mattn/go-isatty v0.0.16 // indirect
+ github.com/mattn/go-runewidth v0.0.9 // indirect
github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect
github.com/naoina/go-stringutil v0.1.0 // indirect
github.com/nsf/termbox-go v0.0.0-20170211012700-3540b76b9c77 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
- github.com/rogpeppe/go-internal v1.6.1 // indirect
+ github.com/rogpeppe/go-internal v1.9.0 // indirect
github.com/steakknife/hamming v0.0.0-20180906055917-c99c65617cd3 // indirect
- golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
- golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect
- golang.org/x/text v0.3.8 // indirect
- gopkg.in/yaml.v2 v2.4.0 // indirect
- gotest.tools v2.2.0+incompatible // indirect
+ golang.org/x/mod v0.11.0 // indirect
+ golang.org/x/term v0.6.0 // indirect
+ golang.org/x/text v0.8.0 // indirect
+ golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df // indirect
+ google.golang.org/protobuf v1.28.1 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index 2fff78c90..693c5630c 100644
--- a/go.sum
+++ b/go.sum
@@ -5,8 +5,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
github.com/DataDog/zstd v1.3.6-0.20190409195224-796139022798/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
github.com/Shopify/sarama v1.23.1/go.mod h1:XLH1GYJnLVE0XCr6KdJGVJRTwY30moWNJ4sERjXX6fs=
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
-github.com/VictoriaMetrics/fastcache v1.5.7 h1:4y6y0G8PRzszQUYIQHHssv/jgPHAb5qQuuDNdCbyAgw=
-github.com/VictoriaMetrics/fastcache v1.5.7/go.mod h1:ptDBkNMQI4RtmVo8VS/XwRY6RoTu1dAWCbrk+6WsEM8=
+github.com/VictoriaMetrics/fastcache v1.6.0 h1:C/3Oi3EiBCqufydp1neRZkqcwmEiuRT9c3fqvvgKm5o=
+github.com/VictoriaMetrics/fastcache v1.6.0/go.mod h1:0qHz5QP0GMX4pfmMA/zt5RgfNuXJrTP0zS7DqpHGGTw=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8=
@@ -23,8 +23,9 @@ github.com/btcsuite/btcd v0.0.0-20171128150713-2e60448ffcc6 h1:Eey/GGQ/E5Xp1P2Ly
github.com/btcsuite/btcd v0.0.0-20171128150713-2e60448ffcc6/go.mod h1:Dmm/EzmjnCiweXmzRIAiUWCInVmPgjkzgv5k4tVyXiQ=
github.com/cespare/cp v1.1.1 h1:nCb6ZLdB7NRaqsm91JtQTAme2SKJzXVsdPIPkyJr1MU=
github.com/cespare/cp v1.1.1/go.mod h1:SOGHArjBr4JWaSDEVpWpo/hNg6RoKrls6Oh40hiwW+s=
-github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
+github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY=
github.com/chzyer/readline v1.5.0/go.mod h1:x22KAscuvRqlLoK9CsoYsmxoXZMMFVyOl86cAH8qUic=
github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
@@ -38,8 +39,8 @@ github.com/deckarep/golang-set v0.0.0-20180603214616-504e848d77ea/go.mod h1:93vs
github.com/dlclark/regexp2 v1.4.1-0.20201116162257-a2a8dda75c91/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc=
github.com/dlclark/regexp2 v1.7.0 h1:7lJfhqlPssTb1WQx4yvTHN0uElPEv52sbaECrAQxjAo=
github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
-github.com/docker/docker v1.4.2-0.20180625184442-8e610b2b55bf h1:sh8rkQZavChcmakYiSlqu2425CHyFXLZZnvm7PDpU8M=
-github.com/docker/docker v1.4.2-0.20180625184442-8e610b2b55bf/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
+github.com/docker/docker v1.6.2 h1:HlFGsy+9/xrgMmhmN+NGhCc5SHGJ7I+kHosRR1xc/aI=
+github.com/docker/docker v1.6.2/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/dop251/goja v0.0.0-20211022113120-dc8c55024d06/go.mod h1:R9ET47fwRVRPZnOGvHxxhuZcbrMCuiqOz3Rlrh4KSnk=
github.com/dop251/goja v0.0.0-20230531210528-d7324b2d74f7 h1:cVGkvrdHgyBkYeB6kMCaF5j2d9Bg4trgbIpcUrKrvk4=
github.com/dop251/goja v0.0.0-20230531210528-d7324b2d74f7/go.mod h1:QMWlm50DNe14hD7t24KEqZuUdC9sOTy8W6XbCU1mlw4=
@@ -50,9 +51,12 @@ github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
github.com/edsrzf/mmap-go v1.0.0 h1:CEBF7HpRnUCSJgGUb5h1Gm7e3VkmVDrR8lvWVLtrOFw=
github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M=
-github.com/fatih/color v1.6.0 h1:66qjqZk8kalYAvDRtM1AdAJQI0tj4Wrue3Eq3B3pmFU=
-github.com/fatih/color v1.6.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
+github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys=
+github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
+github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
+github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
+github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
github.com/gizak/termui v2.2.0+incompatible h1:qvZU9Xll/Xd/Xr/YO+HfBKXhy8a8/94ao6vV9DSXzUE=
github.com/gizak/termui v2.2.0+incompatible/go.mod h1:PkJoWUt/zacQKysNfQtcw1RW+eK2SxkieVBtl+4ovLA=
@@ -63,40 +67,59 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9
github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU=
github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg=
-github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
+github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw=
+github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
+github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
+github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
+github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
+github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
+github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
+github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
+github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
+github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
+github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
+github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
+github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
+github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb h1:PBC98N2aIaM3XXiurYmW7fx4GZkL8feAMVq7nEjURHk=
+github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
-github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
+github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
+github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U=
github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg=
-github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
+github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/golang-lru v0.5.3 h1:YPkqC67at8FYaadspW/6uE0COsBxS2656RLEr8Bppgk=
github.com/hashicorp/golang-lru v0.5.3/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
-github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
+github.com/holiman/bloomfilter/v2 v2.0.3 h1:73e0e/V0tCydx14a0SCYS/EWCxgwLZ18CZcZKVu0fao=
+github.com/holiman/bloomfilter/v2 v2.0.3/go.mod h1:zpoh+gs7qcpqrHr3dB55AMiJwo0iURXE7ZOP9L9hSkA=
+github.com/holiman/uint256 v1.2.2 h1:TXKcSGc2WaxPD2+bmzAsVthL4+pEN0YwXcL5qED83vk=
+github.com/holiman/uint256 v1.2.2/go.mod h1:SC8Ryt4n+UBbPbIBKaG9zbbDlp4jOru9xFZmPzLUTxw=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
-github.com/huin/goupnp v1.0.0 h1:wg75sLpL6DZqwHQN6E1Cfk6mtfzS45z8OV+ic+DtHRo=
-github.com/huin/goupnp v1.0.0/go.mod h1:n9v9KO1tAxYH82qOn+UTIFQDmx5n1Zxd/ClZDMX7Bnc=
+github.com/huin/goupnp v1.0.3 h1:N8No57ls+MnjlB+JPiCVSOyy/ot7MJTqlo7rn+NYSqQ=
+github.com/huin/goupnp v1.0.3/go.mod h1:ZxNlw5WqJj6wSsRK5+YfflQGXYfccj5VgQsMNixHM7Y=
github.com/huin/goutil v0.0.0-20170803182201-1ca381bf3150/go.mod h1:PpLOETDnJ0o3iZrZfqZzyLl6l7F3c6L1oWn7OICBi6o=
github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
github.com/influxdata/influxdb v1.7.9 h1:uSeBTNO4rBkbp1Be5FKRsAmglM9nlx25TzVQRQt1An4=
github.com/influxdata/influxdb v1.7.9/go.mod h1:qZna6X/4elxqT3yI9iZYdZrWWdeFOOprn86kgg4+IzY=
github.com/influxdata/influxdb1-client v0.0.0-20190809212627-fc22c7df067e/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo=
-github.com/jackpal/go-nat-pmp v1.0.2-0.20160603034137-1fa385a6f458 h1:6OvNmYgJyexcZ3pYbTI9jWx5tHo1Dee/tWbLMfPe2TA=
-github.com/jackpal/go-nat-pmp v1.0.2-0.20160603034137-1fa385a6f458/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc=
+github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus=
+github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc=
github.com/jcmturner/gofork v0.0.0-20190328161633-dc7c13fece03/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
@@ -111,8 +134,9 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxv
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
-github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
+github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
+github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@@ -123,13 +147,13 @@ github.com/maruel/panicparse v0.0.0-20160720141634-ad661195ed0e h1:e2z/lz9pvtRrE
github.com/maruel/panicparse v0.0.0-20160720141634-ad661195ed0e/go.mod h1:nty42YY5QByNC5MM7q/nj938VbgPU7avs45z6NClpxI=
github.com/maruel/ut v1.0.2 h1:mQTlQk3jubTbdTcza+hwoZQWhzcvE4L6K6RTtAFlA1k=
github.com/maruel/ut v1.0.2/go.mod h1:RV8PwPD9dd2KFlnlCc/DB2JVvkXmyaalfc5xvmSrRSs=
-github.com/mattn/go-colorable v0.1.0 h1:v2XXALHHh6zHfYTJ+cSkwtyffnaOyR1MXaA91mTrb8o=
-github.com/mattn/go-colorable v0.1.0/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
-github.com/mattn/go-isatty v0.0.5-0.20180830101745-3fb116b82035 h1:USWjF42jDCSEeikX/G1g40ZWnsPXN5WkZ4jMHZWyBK4=
-github.com/mattn/go-isatty v0.0.5-0.20180830101745-3fb116b82035/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
+github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
+github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
+github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ=
+github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
-github.com/mattn/go-runewidth v0.0.4 h1:2BvfKmzob6Bmd4YsL0zygOqfdFnK7GR4QL06Do4/p7Y=
-github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
+github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0=
+github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 h1:DpOJ2HYzCv8LZP15IdmG+YdwD2luVPHITV96TkirNBM=
github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo=
@@ -144,15 +168,19 @@ github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416 h1:shk/vn9oCoOTmwcou
github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416/go.mod h1:NBIhNtsFMo3G2szEBne+bO4gS192HuIYRqfvOWb4i1E=
github.com/nsf/termbox-go v0.0.0-20170211012700-3540b76b9c77 h1:gKl78uP/I7JZ56OFtRf7nc4m1icV38hwV0In5pEGzeA=
github.com/nsf/termbox-go v0.0.0-20170211012700-3540b76b9c77/go.mod h1:IuKpRQcYE1Tfu+oAQqaLisqDeXgjyyltCfsaoYN18NQ=
-github.com/olekukonko/tablewriter v0.0.2-0.20190409134802-7e037d187b0c h1:1RHs3tNxjXGHeul8z2t6H2N2TlAqpKe5yryJztRx4Jk=
-github.com/olekukonko/tablewriter v0.0.2-0.20190409134802-7e037d187b0c/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo=
+github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78=
+github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
+github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
+github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
-github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
-github.com/onsi/ginkgo v1.10.1 h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo=
github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
-github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
-github.com/onsi/gomega v1.7.0 h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME=
+github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
+github.com/onsi/ginkgo v1.14.0 h1:2mOpI4JVVPBN+WQRa0WKH2eXR+Ey+uK4n7Zj0aYpIQA=
+github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY=
github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
+github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
+github.com/onsi/gomega v1.10.1 h1:o0+MgICZLuZ7xjH7Vx6zS/zcu93/BEp1VwkIW1mEXCE=
+github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/openconfig/gnmi v0.0.0-20190823184014-89b2bf29312c/go.mod h1:t+O9It+LKzfOAhKTT5O0ehDix+MTqbtT0T9t+7zzOvc=
github.com/openconfig/reference v0.0.0-20190727015836-8dfd928c9696/go.mod h1:ym2A+zigScwkSEb/cVQB0/ZMpU3rqiH6X7WRRsxgOGw=
github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g=
@@ -160,9 +188,10 @@ github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtP
github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7 h1:oYW+YCJ1pachXTQmzR3rNLYGGz4g/UgFcjb28p/viDM=
github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7/go.mod h1:CRroGNssyjTd/qIG2FyxByd2S8JEAZXBl4qUrZf8GS0=
github.com/pierrec/lz4 v0.0.0-20190327172049-315a67e90e41/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc=
+github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
-github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -181,10 +210,11 @@ github.com/prometheus/prometheus v1.7.2-0.20170814170113-3101606756c5/go.mod h1:
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rjeczalik/notify v0.9.2 h1:MiTWrPj55mNDHEiIX5YUSKefw/+lCQVoAFmD6oQm5w8=
github.com/rjeczalik/notify v0.9.2/go.mod h1:aErll2f0sUX9PXZnVNyeiObbmTlk5jnMoCa4QEjJeqM=
-github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
-github.com/rs/cors v1.6.0 h1:G9tHG9lebljV9mfp9SNPDL36nCDxmo3zTlAf1YgvzmI=
-github.com/rs/cors v1.6.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU=
+github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
+github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
+github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik=
+github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/steakknife/bloomfilter v0.0.0-20180922174646-6819c0d2a570 h1:gIlAHnH1vJb5vwEjIp5kBj/eu99p/bl0Ay2goiPe5xE=
@@ -193,12 +223,16 @@ github.com/steakknife/hamming v0.0.0-20180906055917-c99c65617cd3 h1:njlZPzLwU639
github.com/steakknife/hamming v0.0.0-20180906055917-c99c65617cd3/go.mod h1:hpGUWaI9xL8pRQCTXQgocU38Qw1g0Us7n5PxxTwTCYU=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
-github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
-github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
-github.com/syndtr/goleveldb v1.0.1-0.20190923125748-758128399b1d h1:gZZadD8H+fF+n9CmNhYL1Y0dJB+kLOmKd7FbPJLeGHs=
-github.com/syndtr/goleveldb v1.0.1-0.20190923125748-758128399b1d/go.mod h1:9OrXJhf154huy1nPWmuSrkgjPUtUNhA+Zmy+6AESzuA=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
+github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY=
+github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc=
github.com/templexxx/cpufeat v0.0.0-20180724012125-cef66df7f161/go.mod h1:wM7WEvslTq+iOEAMDLSzhVuOt5BRZ05WirO+b09GHQU=
github.com/templexxx/xor v0.0.0-20181023030647-4e92f724b73b/go.mod h1:5XA7W9S6mni3h5uvOC75dA3m9CCCaS83lltmc0ukdi4=
github.com/tjfoc/gmsm v1.0.1/go.mod h1:XxO4hdhhrzAd+G4CjDqaOkd0hUzmtPR/d3EiBBMn/wc=
@@ -210,64 +244,100 @@ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5t
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190404164418-38d8ce5564a5/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
-golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg=
+golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU=
+golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
-golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
+golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
+golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20181011144130-49bb7cea24b1/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190912160710-24e19bdeb0f2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
+golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
+golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
+golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
+golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180926160741-c2ed4eda69e7/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190801041406-cbf593c0f2f3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190912141932-bc967efca4b8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
+golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
+golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
+golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
-golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
+golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68=
+golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190912185636-87d9f09c5d89/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
+golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4=
+golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df h1:5Pf6pFKu98ODmgnpvkJ3kFUOQGGLIzLIkbzUHp47618=
+golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
+google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
+google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
+google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
+google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
+google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
+google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
+google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
+google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
+google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
+google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/bsm/ratelimit.v1 v1.0.0-20160220154919-db14e161995a/go.mod h1:KF9sEfUPAXdG8Oev9e99iLGnl2uJMjc5B+4y3O7x610=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -275,7 +345,6 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
-gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/jcmturner/aescts.v1 v1.0.1/go.mod h1:nsR8qBOg+OucoIW+WMhB3GspUQXq9XorLnQb9XtvcOo=
gopkg.in/jcmturner/dnsutils.v1 v1.0.1/go.mod h1:m3v+5svpVOhtFAP/wSz+yzh4Mc0Fg7eRhxkJMWSIz9Q=
@@ -295,8 +364,11 @@ gopkg.in/urfave/cli.v1 v1.20.0 h1:NdAVW6RYxDif9DhDHaAortIu956m2c0v+09AZBPTbE0=
gopkg.in/urfave/cli.v1 v1.20.0/go.mod h1:vuBzUtMdQeixQj8LVd+/98pzhxNGQoyuPBlsXHOQNO0=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
-gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo=
-gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
diff --git a/internal/blocktest/test_hash.go b/internal/blocktest/test_hash.go
new file mode 100644
index 000000000..37d979e31
--- /dev/null
+++ b/internal/blocktest/test_hash.go
@@ -0,0 +1,60 @@
+// Copyright 2023 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package blocktest provides helpers for testing block-related hashing.
+//
+// This package exists to avoid an import cycle with the trie package: tests
+// that need a transaction/receipt list hasher can use this lightweight
+// stand-in instead of the trie-based one.
+
+package blocktest
+
+import (
+ "hash"
+
+ "golang.org/x/crypto/sha3"
+
+ "github.com/tomochain/tomochain/common"
+)
+
+// testHasher is a helper for transaction/receipt list hashing. The original
+// hasher is the trie; to get rid of the import cycle, tests use this hasher
+// instead.
+type testHasher struct {
+ hasher hash.Hash
+}
+
+// NewHasher returns a new testHasher instance.
+func NewHasher() *testHasher {
+ return &testHasher{hasher: sha3.NewLegacyKeccak256()}
+}
+
+// Reset resets the hash state.
+func (h *testHasher) Reset() {
+ h.hasher.Reset()
+}
+
+// Update updates the hash state with the given key and value.
+func (h *testHasher) Update(key, val []byte) error {
+ h.hasher.Write(key)
+ h.hasher.Write(val)
+ return nil
+}
+
+// Hash returns the hash value.
+func (h *testHasher) Hash() common.Hash {
+ return common.BytesToHash(h.hasher.Sum(nil))
+}
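A usage sketch for the new helper, hashing a short key/value list the way the test-only transaction/receipt hashers are driven (the key bytes here are illustrative; real callers feed RLP-encoded list indices):

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/internal/blocktest"
)

func main() {
	h := blocktest.NewHasher()
	h.Reset()
	// Feed (key, value) pairs; production callers pass RLP-encoded
	// list indices and items, this input is just for illustration.
	_ = h.Update([]byte{0x80}, []byte("first item"))
	_ = h.Update([]byte{0x01}, []byte("second item"))
	fmt.Println("list hash:", h.Hash().Hex())
}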
diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go
index 33376a107..366de5e25 100644
--- a/internal/ethapi/api.go
+++ b/internal/ethapi/api.go
@@ -21,17 +21,15 @@ import (
"context"
"errors"
"fmt"
- "github.com/tomochain/tomochain/tomoxlending/lendingstate"
"math/big"
"sort"
"strings"
"time"
- "github.com/tomochain/tomochain/tomox/tradingstate"
-
"github.com/syndtr/goleveldb/leveldb"
"github.com/syndtr/goleveldb/leveldb/util"
"github.com/tomochain/tomochain/accounts"
+ "github.com/tomochain/tomochain/accounts/abi"
"github.com/tomochain/tomochain/accounts/abi/bind"
"github.com/tomochain/tomochain/accounts/keystore"
"github.com/tomochain/tomochain/common"
@@ -41,6 +39,7 @@ import (
"github.com/tomochain/tomochain/consensus/posv"
contractValidator "github.com/tomochain/tomochain/contracts/validator/contract"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
@@ -50,6 +49,8 @@ import (
"github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/rlp"
"github.com/tomochain/tomochain/rpc"
+ "github.com/tomochain/tomochain/tomox/tradingstate"
+ "github.com/tomochain/tomochain/tomoxlending/lendingstate"
)
const (
@@ -424,7 +425,8 @@ func (s *PrivateAccountAPI) SignTransaction(ctx context.Context, args SendTxArgs
// safely used to calculate a signature from.
//
// The hash is calculated as
-// keccak256("\x19Ethereum Signed Message:\n"${message length}${message}).
+//
+// keccak256("\x19Ethereum Signed Message:\n"${message length}${message}).
//
// This gives context to the signed message and prevents signing of transactions.
func signHash(data []byte) []byte {
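For reference, the conventional body of signHash implements exactly the formula in the reflowed comment; a minimal self-contained sketch using a generic Keccak-256 (the production code uses the crypto package's helper):

package main

import (
	"fmt"

	"golang.org/x/crypto/sha3"
)

// signHash prefixes the data per the eth_sign scheme and hashes it:
// keccak256("\x19Ethereum Signed Message:\n" + len(message) + message).
func signHash(data []byte) []byte {
	msg := fmt.Sprintf("\x19Ethereum Signed Message:\n%d%s", len(data), data)
	h := sha3.NewLegacyKeccak256()
	h.Write([]byte(msg))
	return h.Sum(nil)
}

func main() {
	fmt.Printf("%x\n", signHash([]byte("hello")))
}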
@@ -509,7 +511,7 @@ func (s *PublicBlockChainAPI) BlockNumber() *big.Int {
return header.Number
}
-// BlockNumber returns the block number of the chain head.
+// GetRewardByHash returns the block reward by block hash.
func (s *PublicBlockChainAPI) GetRewardByHash(hash common.Hash) map[string]map[string]map[string]*big.Int {
return s.b.GetRewardByHash(hash)
}
@@ -1025,17 +1027,21 @@ type CallArgs struct {
Data hexutil.Bytes `json:"data"`
}
-func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber, vmCfg vm.Config, timeout time.Duration) ([]byte, uint64, bool, error) {
+func DoCall(ctx context.Context, b Backend, args CallArgs, blockNr rpc.BlockNumber, vmCfg vm.Config, timeout time.Duration) (*core.ExecutionResult, error) {
defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now())
- statedb, header, err := s.b.StateAndHeaderByNumber(ctx, blockNr)
- if statedb == nil || err != nil {
- return nil, 0, false, err
+ state, header, err := b.StateAndHeaderByNumber(ctx, blockNr)
+ if state == nil || err != nil {
+ return nil, err
}
+
+ return doCall(ctx, b, args, state, header, timeout)
+}
+func doCall(ctx context.Context, b Backend, args CallArgs, state *state.StateDB, header *types.Header, timeout time.Duration) (*core.ExecutionResult, error) {
// Set sender address or use a default if none specified
addr := args.From
if addr == (common.Address{}) {
- if wallets := s.b.AccountManager().Wallets(); len(wallets) > 0 {
+ if wallets := b.AccountManager().Wallets(); len(wallets) > 0 {
if accounts := wallets[0].Accounts(); len(accounts) > 0 {
addr = accounts[0].Address
}
@@ -1052,7 +1058,17 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr
balanceTokenFee := big.NewInt(0).SetUint64(gas)
balanceTokenFee = balanceTokenFee.Mul(balanceTokenFee, gasPrice)
// Create new call message
- msg := types.NewMessage(addr, args.To, 0, args.Value.ToInt(), gas, gasPrice, args.Data, false, balanceTokenFee)
+ msg := &core.Message{
+ To: args.To,
+ From: addr,
+ Nonce: 0,
+ Value: args.Value.ToInt(),
+ GasLimit: gas,
+ GasPrice: gasPrice,
+ Data: args.Data,
+ BalanceTokenFee: balanceTokenFee,
+ SkipAccountChecks: true,
+ }
// Set up the context so it may be cancelled when the call has completed
// or, in case of unmetered gas, set up a context with a timeout.
@@ -1066,22 +1082,22 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr
// this makes sure resources are cleaned up.
defer cancel()
- block, err := s.b.BlockByNumber(ctx, blockNr)
+ block, err := b.BlockByNumber(ctx, rpc.BlockNumber(header.Number.Int64()))
if err != nil {
- return nil, 0, false, err
+ return nil, err
}
- author, err := s.b.GetEngine().Author(block.Header())
+ author, err := b.GetEngine().Author(block.Header())
if err != nil {
- return nil, 0, false, err
+ return nil, err
}
- tomoxState, err := s.b.TomoxService().GetTradingState(block, author)
+ tomoxState, err := b.TomoxService().GetTradingState(block, author)
if err != nil {
- return nil, 0, false, err
+ return nil, err
}
// Get a new instance of the EVM.
- evm, vmError, err := s.b.GetEVM(ctx, msg, statedb, tomoxState, header, vmCfg)
+ evm, vmError, err := b.GetEVM(ctx, msg, state, tomoxState, header, vm.Config{})
if err != nil {
- return nil, 0, false, err
+ return nil, err
}
// Wait for the context to be done and cancel the evm. Even if the
// EVM has finished, cancelling may be done (repeatedly)
@@ -1094,69 +1110,182 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr
// and apply the message.
gp := new(core.GasPool).AddGas(math.MaxUint64)
owner := common.Address{}
- res, gas, failed, err := core.ApplyMessage(evm, msg, gp, owner)
+ result, err := core.ApplyMessage(evm, msg, gp, owner)
if err := vmError(); err != nil {
- return nil, 0, false, err
+ return nil, err
+ }
+ return result, err
+}
+
+func newRevertError(result *core.ExecutionResult) *revertError {
+ reason, errUnpack := abi.UnpackRevert(result.Revert())
+ err := errors.New("execution reverted")
+ if errUnpack == nil {
+ err = fmt.Errorf("execution reverted: %v", reason)
}
- return res, gas, failed, err
+ return &revertError{
+ error: err,
+ reason: hexutil.Encode(result.Revert()),
+ }
+}
+
+// revertError is an API error that encompasses an EVM revert with a JSON error
+// code and a binary data blob.
+type revertError struct {
+ error
+ reason string // revert reason hex encoded
+}
+
+// ErrorCode returns the JSON error code for a revert.
+// See: https://github.com/ethereum/wiki/wiki/JSON-RPC-Error-Codes-Improvement-Proposal
+func (e *revertError) ErrorCode() int {
+ return 3
+}
+
+// ErrorData returns the hex encoded revert reason.
+func (e *revertError) ErrorData() interface{} {
+ return e.reason
}
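newRevertError leans on abi.UnpackRevert to extract a human-readable reason from the revert blob, which for Solidity's revert("...") is the 4-byte Error(string) selector followed by an ABI-encoded string. A hand-rolled sketch of that layout, assuming the standard offset-word/length-word encoding (this is not the abi package's actual implementation):

package main

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
)

// unpackRevert decodes Solidity's Error(string) revert payload: the 4-byte
// selector 0x08c379a0, a 32-byte offset word, a 32-byte length word, then the
// string bytes. It is a stand-in for abi.UnpackRevert, not its real body.
func unpackRevert(data []byte) (string, error) {
	selector := []byte{0x08, 0xc3, 0x79, 0xa0}
	if len(data) < 4+64 || !bytes.Equal(data[:4], selector) {
		return "", errors.New("not an Error(string) revert")
	}
	body := data[4:]
	offset := binary.BigEndian.Uint64(body[24:32]) // low 8 bytes of word 1
	if offset+32 > uint64(len(body)) {
		return "", errors.New("truncated revert payload")
	}
	length := binary.BigEndian.Uint64(body[offset+24 : offset+32])
	start := offset + 32
	if uint64(len(body)) < start+length {
		return "", errors.New("truncated revert payload")
	}
	return string(body[start : start+length]), nil
}

func main() {
	// revert("no!") encoded by hand: selector, offset 0x20, length 3, bytes.
	payload := append([]byte{0x08, 0xc3, 0x79, 0xa0}, make([]byte, 96)...)
	payload[4+31] = 0x20
	payload[4+63] = 0x03
	copy(payload[4+64:], "no!")
	reason, err := unpackRevert(payload)
	fmt.Println(reason, err) // no! <nil>
}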
// Call executes the given transaction on the state for the given block number.
// It doesn't make any changes in the state/blockchain and is useful to execute and retrieve values.
func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber) (hexutil.Bytes, error) {
- result, _, _, err := s.doCall(ctx, args, blockNr, vm.Config{}, 5*time.Second)
- return (hexutil.Bytes)(result), err
+ result, err := DoCall(ctx, s.b, args, blockNr, vm.Config{}, 5*time.Second)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(result.Revert()) > 0 {
+ return nil, newRevertError(result)
+ }
+ return result.Return(), result.Err
+}
+
+// executeEstimate is a helper that executes the transaction under a given gas limit and returns
+// true if the transaction fails for a reason that might be related to not enough gas. A non-nil
+// error means execution failed due to reasons unrelated to the gas limit.
+func executeEstimate(ctx context.Context, b Backend, args CallArgs, state *state.StateDB, header *types.Header, gasLimit uint64) (bool, *core.ExecutionResult, error) {
+ args.Gas = (hexutil.Uint64)(gasLimit)
+ result, err := doCall(ctx, b, args, state, header, 0)
+ if err != nil {
+ if errors.Is(err, core.ErrIntrinsicGas) {
+ return true, nil, nil // Special case, raise gas limit
+ }
+ return true, nil, err // Bail out
+ }
+ return result.Failed(), result, nil
}
-// EstimateGas returns an estimate of the amount of gas needed to execute the
-// given transaction against the current pending block.
-func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (hexutil.Uint64, error) {
- // Binary search the gas requirement, as it may be higher than the amount used
+// DoEstimateGas returns the lowest possible gas limit that allows the transaction to run
+// successfully at block `blockNrOrHash`. It returns an error if the transaction would revert, or
+// if there are unexpected failures. The gas limit is capped by `args.Gas` (if non-nil &
+// non-zero), and otherwise by the gas limit of the target block.
+func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.BlockNumber) (hexutil.Uint64, error) {
+ // Binary search the gas limit, as it may need to be higher than the amount used
var (
- lo uint64 = params.TxGas - 1
- hi uint64
- cap uint64
+ lo uint64 // lowest-known gas limit where tx execution fails
+ hi uint64 // lowest-known gas limit where tx execution succeeds
)
+ // Determine the highest gas limit that can be used during the estimation.
if uint64(args.Gas) >= params.TxGas {
hi = uint64(args.Gas)
} else {
- // Retrieve the current pending block to act as the gas ceiling
- block, err := s.b.BlockByNumber(ctx, rpc.LatestBlockNumber)
+ // Retrieve the block to act as the gas ceiling
+ block, err := b.BlockByNumber(ctx, blockNrOrHash)
if err != nil {
return 0, err
}
+ if block == nil {
+ return 0, errors.New("block not found")
+ }
hi = block.GasLimit()
}
- cap = hi
+ // Normalize the max fee per gas the call is willing to spend.
+ feeCap := args.GasPrice.ToInt()
- // Create a helper to check if a gas allowance results in an executable transaction
- executable := func(gas uint64) bool {
- args.Gas = hexutil.Uint64(gas)
+ state, header, err := b.StateAndHeaderByNumber(ctx, blockNrOrHash)
+ if state == nil || err != nil {
+ return 0, err
+ }
+
+ // Recap the highest gas limit with account's available balance.
+ if feeCap.BitLen() != 0 {
+ balance := state.GetBalance(args.From) // from can't be nil
+ available := new(big.Int).Set(balance)
+ if args.Value.ToInt().Cmp(available) >= 0 {
+ return 0, core.ErrInsufficientFundsForTransfer
+ }
+ available.Sub(available, args.Value.ToInt())
+ allowance := new(big.Int).Div(available, feeCap)
+
+ // If the allowance is larger than maximum uint64, skip checking
+ if allowance.IsUint64() && hi > allowance.Uint64() {
+ transfer := args.Value
+ log.Warn("Gas estimation capped by limited funds", "original", hi, "balance", balance,
+ "sent", transfer.ToInt(), "maxFeePerGas", feeCap, "fundable", allowance)
+ hi = allowance.Uint64()
+ }
+ }
- _, _, failed, err := s.doCall(ctx, args, rpc.LatestBlockNumber, vm.Config{}, 0)
- if err != nil || failed {
- return false
+ // We first execute the transaction at the highest allowable gas limit, since if this fails we
+ // can return the error immediately.
+ failed, result, err := executeEstimate(ctx, b, args, state.Copy(), header, hi)
+ if err != nil {
+ return 0, err
+ }
+ if failed {
+ if result != nil && !errors.Is(result.Err, vm.ErrOutOfGas) {
+ if len(result.Revert()) > 0 {
+ return 0, newRevertError(result)
+ }
+ return 0, result.Err
}
- return true
+ return 0, fmt.Errorf("gas required exceeds allowance (%d)", hi)
}
- // Execute the binary search and hone in on an executable gas limit
+ // For almost any transaction, the gas consumed by the unconstrained execution above
+ // lower-bounds the gas limit required for it to succeed. One exception is those txs that
+ // explicitly check gas remaining in order to successfully execute within a given limit, but we
+ // probably don't want to return a lowest possible gas limit for these cases anyway.
+ lo = result.UsedGas - 1
+
+ // Binary search for the smallest gas limit that allows the tx to execute successfully.
for lo+1 < hi {
mid := (hi + lo) / 2
- if !executable(mid) {
+ if mid > lo*2 {
+ // Most txs don't need much higher gas limit than their gas used, and most txs don't
+ // require near the full block limit of gas, so the selection of where to bisect the
+ // range here is skewed to favor the low side.
+ mid = lo * 2
+ }
+ failed, _, err = executeEstimate(ctx, b, args, state.Copy(), header, mid)
+ if err != nil {
+ // This should not happen under normal conditions since if we make it this far the
+ // transaction had run without error at least once before.
+ log.Error("execution error in estimate gas", "err", err)
+ return 0, err
+ }
+ if failed {
lo = mid
} else {
hi = mid
}
}
- // Reject the transaction as invalid if it still fails at the highest allowance
- if hi == cap {
- if !executable(hi) {
- return 0, fmt.Errorf("gas required exceeds allowance or always failing transaction")
- }
- }
return hexutil.Uint64(hi), nil
}
+// EstimateGas returns the lowest possible gas limit that allows the transaction to run
+// successfully at block `blockNrOrHash`, or the latest block if `blockNrOrHash` is unspecified. It
+// returns an error if the transaction would revert or if there are unexpected failures. The returned
+// value is capped by both `args.Gas` (if non-nil & non-zero) and the backend's RPCGasCap
+// configuration (if non-zero).
+func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs, blockNrOrHash *rpc.BlockNumber) (hexutil.Uint64, error) {
+ bNrOrHash := rpc.LatestBlockNumber
+ if blockNrOrHash != nil {
+ bNrOrHash = *blockNrOrHash
+ }
+ return DoEstimateGas(ctx, s.b, args, bNrOrHash)
+}
+
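For readers tracing the new estimator, here is a minimal, self-contained Go sketch of the same search strategy: run once at the gas ceiling, seed `lo` from the gas actually consumed, then bisect with a midpoint skewed toward the low side. The `exec` predicate is a hypothetical stand-in for `executeEstimate`; only the shape of the loop mirrors the hunk above.

```go
package main

import "fmt"

// estimate returns the smallest gas limit in (lo, hi] at which exec succeeds.
// exec reports whether execution fails at the given limit and, on success,
// how much gas was consumed.
func estimate(hi uint64, exec func(gas uint64) (failed bool, used uint64)) (uint64, error) {
	// Run once at the ceiling; if even that fails, bail out immediately.
	failed, used := exec(hi)
	if failed {
		return 0, fmt.Errorf("gas required exceeds allowance (%d)", hi)
	}
	// For almost all transactions the consumed gas lower-bounds the answer.
	lo := used - 1
	for lo+1 < hi {
		mid := (hi + lo) / 2
		if mid > lo*2 {
			// Most txs need only slightly more than they consume, so skew
			// the bisection toward the low side.
			mid = lo * 2
		}
		if failed, _ := exec(mid); failed {
			lo = mid
		} else {
			hi = mid
		}
	}
	return hi, nil
}

func main() {
	// Toy predicate: the call needs at least 53000 gas and consumes 52000.
	got, _ := estimate(8_000_000, func(gas uint64) (bool, uint64) {
		return gas < 53_000, 52_000
	})
	fmt.Println(got) // 53000
}
```

Because `lo` starts just below the consumed gas, the low-skewed midpoint lets the search converge in a handful of iterations for ordinary transactions instead of a full log2 of the block gas limit.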
// ExecutionResult groups all structured logs emitted by the EVM
// while replaying a transaction in debug mode as well as transaction
// execution status, the amount of gas used and the return value
@@ -1305,8 +1434,8 @@ func (s *PublicBlockChainAPI) findNearestSignedBlock(ctx context.Context, b *typ
}
/*
- findFinalityOfBlock return finality of a block
- Use blocksHashCache for to keep track - refer core/blockchain.go for more detail
+findFinalityOfBlock returns the finality of a block.
+Use blocksHashCache to keep track - refer to core/blockchain.go for more detail
*/
func (s *PublicBlockChainAPI) findFinalityOfBlock(ctx context.Context, b *types.Block, masternodes []common.Address) (uint, error) {
engine, _ := s.b.GetEngine().(*posv.Posv)
@@ -1371,7 +1500,7 @@ func (s *PublicBlockChainAPI) findFinalityOfBlock(ctx context.Context, b *types.
}
/*
- Extract signers from block
+Extract signers from block
*/
func (s *PublicBlockChainAPI) getSigners(ctx context.Context, block *types.Block, engine *posv.Posv) ([]common.Address, error) {
var err error
@@ -1594,7 +1723,7 @@ func (s *PublicTransactionPoolAPI) GetTransactionCount(ctx context.Context, addr
// GetTransactionByHash returns the transaction for the given hash
func (s *PublicTransactionPoolAPI) GetTransactionByHash(ctx context.Context, hash common.Hash) *RPCTransaction {
// Try to return an already finalized transaction
- if tx, blockHash, blockNumber, index := core.GetTransaction(s.b.ChainDb(), hash); tx != nil {
+ if tx, blockHash, blockNumber, index := rawdb.GetTransaction(s.b.ChainDb(), hash); tx != nil {
return newRPCTransaction(tx, blockHash, blockNumber, index)
}
// No finalized transaction, try to retrieve it from the pool
@@ -1610,7 +1739,7 @@ func (s *PublicTransactionPoolAPI) GetRawTransactionByHash(ctx context.Context,
var tx *types.Transaction
// Retrieve a finalized transaction, or a pooled otherwise
- if tx, _, _, _ = core.GetTransaction(s.b.ChainDb(), hash); tx == nil {
+ if tx, _, _, _ = rawdb.GetTransaction(s.b.ChainDb(), hash); tx == nil {
if tx = s.b.GetPoolTransaction(hash); tx == nil {
// Transaction not found anywhere, abort
return nil, nil
@@ -1622,7 +1751,7 @@ func (s *PublicTransactionPoolAPI) GetRawTransactionByHash(ctx context.Context,
// GetTransactionReceipt returns the transaction receipt for the given transaction hash.
func (s *PublicTransactionPoolAPI) GetTransactionReceipt(ctx context.Context, hash common.Hash) (map[string]interface{}, error) {
- tx, blockHash, blockNumber, index := core.GetTransaction(s.b.ChainDb(), hash)
+ tx, blockHash, blockNumber, index := rawdb.GetTransaction(s.b.ChainDb(), hash)
if tx == nil {
return nil, nil
}
@@ -1867,7 +1996,7 @@ func (s *PublicTomoXTransactionPoolAPI) SendLendingRawTransaction(ctx context.Co
func (s *PublicTomoXTransactionPoolAPI) GetOrderTxMatchByHash(ctx context.Context, hash common.Hash) ([]*tradingstate.OrderItem, error) {
var tx *types.Transaction
orders := []*tradingstate.OrderItem{}
- if tx, _, _, _ = core.GetTransaction(s.b.ChainDb(), hash); tx == nil {
+ if tx, _, _, _ = rawdb.GetTransaction(s.b.ChainDb(), hash); tx == nil {
if tx = s.b.GetPoolTransaction(hash); tx == nil {
return []*tradingstate.OrderItem{}, nil
}
@@ -2598,7 +2727,7 @@ func (s *PublicTomoXTransactionPoolAPI) GetBorrows(ctx context.Context, lendingT
// GetLendingTxMatchByHash returns lendingItems which have been processed at tx of the given txhash
func (s *PublicTomoXTransactionPoolAPI) GetLendingTxMatchByHash(ctx context.Context, hash common.Hash) ([]*lendingstate.LendingItem, error) {
var tx *types.Transaction
- if tx, _, _, _ = core.GetTransaction(s.b.ChainDb(), hash); tx == nil {
+ if tx, _, _, _ = rawdb.GetTransaction(s.b.ChainDb(), hash); tx == nil {
if tx = s.b.GetPoolTransaction(hash); tx == nil {
return []*lendingstate.LendingItem{}, nil
}
@@ -2614,7 +2743,7 @@ func (s *PublicTomoXTransactionPoolAPI) GetLendingTxMatchByHash(ctx context.Cont
// GetLiquidatedTradesByTxHash returns trades which were closed by the TomoX protocol at the tx of the given hash
func (s *PublicTomoXTransactionPoolAPI) GetLiquidatedTradesByTxHash(ctx context.Context, hash common.Hash) (lendingstate.FinalizedResult, error) {
var tx *types.Transaction
- if tx, _, _, _ = core.GetTransaction(s.b.ChainDb(), hash); tx == nil {
+ if tx, _, _, _ = rawdb.GetTransaction(s.b.ChainDb(), hash); tx == nil {
if tx = s.b.GetPoolTransaction(hash); tx == nil {
return lendingstate.FinalizedResult{}, nil
}
@@ -2965,7 +3094,8 @@ func GetSignersFromBlocks(b Backend, blockNumber uint64, blockHash common.Hash,
// GetStakerROI estimates the ROI for stakers using the last epoch reward,
// then multiplies by the number of epochs per year; if the address is not a masternode of the last epoch, it returns 0
// Formula:
-// ROI = average_latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100
+//
+// ROI = average_latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100
func (s *PublicBlockChainAPI) GetStakerROI() float64 {
blockNumber := s.b.CurrentBlock().Number().Uint64()
lastCheckpointNumber := blockNumber - (blockNumber % s.b.ChainConfig().Posv.Epoch) - s.b.ChainConfig().Posv.Epoch // calculate for 2 epochs ago
@@ -2991,7 +3121,8 @@ func (s *PublicBlockChainAPI) GetStakerROI() float64 {
// GetStakerROIMasternode estimates the ROI for stakers of a specific masternode using the last epoch reward,
// then multiplies by the number of epochs per year; if the address is not a masternode of the last epoch, it returns 0
// Formula:
-// ROI = latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100
+//
+// ROI = latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100
func (s *PublicBlockChainAPI) GetStakerROIMasternode(masternode common.Address) float64 {
votersReward := s.b.GetVotersRewards(masternode)
if votersReward == nil {
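As a sanity check of the documented formula, a back-of-the-envelope sketch; the helper and every number below are illustrative, not values taken from this codebase:

```go
// stakerROI applies the formula documented above:
//   ROI = average_latest_epoch_reward_for_voters * number_of_epoch_per_year / latest_total_cap * 100
func stakerROI(avgEpochReward, epochsPerYear, totalCap float64) float64 {
	if totalCap == 0 {
		return 0 // no stake, no yield
	}
	return avgEpochReward * epochsPerYear / totalCap * 100
}

// Example: with 2-second blocks and 900-block epochs there are about
// 17520 epochs a year, so stakerROI(0.5, 17520, 65000) ≈ 13.5 (percent).
```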
diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go
index 16edc3a17..9a197d4e2 100644
--- a/internal/ethapi/backend.go
+++ b/internal/ethapi/backend.go
@@ -19,12 +19,8 @@ package ethapi
import (
"context"
- "github.com/tomochain/tomochain/tomox/tradingstate"
- "github.com/tomochain/tomochain/tomoxlending"
"math/big"
- "github.com/tomochain/tomochain/tomox"
-
"github.com/tomochain/tomochain/accounts"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus"
@@ -38,6 +34,9 @@ import (
"github.com/tomochain/tomochain/event"
"github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/rpc"
+ "github.com/tomochain/tomochain/tomox"
+ "github.com/tomochain/tomochain/tomox/tradingstate"
+ "github.com/tomochain/tomochain/tomoxlending"
)
// Backend interface provides the common API services (that are provided by
@@ -61,7 +60,7 @@ type Backend interface {
GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error)
GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error)
GetTd(blockHash common.Hash) *big.Int
- GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error)
+ GetEVM(ctx context.Context, msg *core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error)
SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription
SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription
SubscribeChainSideEvent(ch chan<- core.ChainSideEvent) event.Subscription
diff --git a/les/api_backend.go b/les/api_backend.go
index d8285da97..e49ee58c4 100644
--- a/les/api_backend.go
+++ b/les/api_backend.go
@@ -20,20 +20,17 @@ import (
"context"
"encoding/json"
"errors"
- "github.com/tomochain/tomochain/tomox/tradingstate"
- "github.com/tomochain/tomochain/tomoxlending"
"io/ioutil"
"math/big"
"path/filepath"
- "github.com/tomochain/tomochain/tomox"
-
"github.com/tomochain/tomochain/accounts"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/math"
"github.com/tomochain/tomochain/consensus"
"github.com/tomochain/tomochain/core"
"github.com/tomochain/tomochain/core/bloombits"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
@@ -45,6 +42,9 @@ import (
"github.com/tomochain/tomochain/light"
"github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/rpc"
+ "github.com/tomochain/tomochain/tomox"
+ "github.com/tomochain/tomochain/tomox/tradingstate"
+ "github.com/tomochain/tomochain/tomoxlending"
)
type LesApiBackend struct {
@@ -94,19 +94,19 @@ func (b *LesApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*t
}
func (b *LesApiBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) {
- return light.GetBlockReceipts(ctx, b.eth.odr, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash))
+ return light.GetBlockReceipts(ctx, b.eth.odr, blockHash, rawdb.GetBlockNumber(b.eth.chainDb, blockHash), b.ChainConfig())
}
func (b *LesApiBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) {
- return light.GetBlockLogs(ctx, b.eth.odr, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash))
+ return light.GetBlockLogs(ctx, b.eth.odr, blockHash, rawdb.GetBlockNumber(b.eth.chainDb, blockHash), b.ChainConfig())
}
func (b *LesApiBackend) GetTd(blockHash common.Hash) *big.Int {
return b.eth.blockchain.GetTdByHash(blockHash)
}
-func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) {
- state.SetBalance(msg.From(), math.MaxBig256)
+func (b *LesApiBackend) GetEVM(ctx context.Context, msg *core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) {
+ state.SetBalance(msg.From, math.MaxBig256)
context := core.NewEVMContext(msg, header, b.eth.blockchain, nil)
return vm.NewEVM(context, state, tomoxState, b.eth.chainConfig, vmCfg), state.Error, nil
}
diff --git a/les/backend.go b/les/backend.go
index 1a5cae11b..9cebbd40e 100644
--- a/les/backend.go
+++ b/les/backend.go
@@ -28,6 +28,7 @@ import (
"github.com/tomochain/tomochain/consensus"
"github.com/tomochain/tomochain/core"
"github.com/tomochain/tomochain/core/bloombits"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/eth"
"github.com/tomochain/tomochain/eth/downloader"
@@ -122,7 +123,7 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
if compat, ok := genesisErr.(*params.ConfigCompatError); ok {
log.Warn("Rewinding chain to upgrade configuration", "err", compat)
leth.blockchain.SetHead(compat.RewindTo)
- core.WriteChainConfig(chainDb, genesisHash, chainConfig)
+ rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig)
}
leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay)
diff --git a/les/fetcher.go b/les/fetcher.go
index 7edfe808b..80568bc32 100644
--- a/les/fetcher.go
+++ b/les/fetcher.go
@@ -25,7 +25,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/mclock"
"github.com/tomochain/tomochain/consensus"
- "github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/light"
"github.com/tomochain/tomochain/log"
@@ -280,7 +280,7 @@ func (f *lightFetcher) announce(p *peer, head *announceData) {
// if one of root's children is canonical, keep it, delete other branches and root itself
var newRoot *fetcherTreeNode
for i, nn := range fp.root.children {
- if core.GetCanonicalHash(f.pm.chainDb, nn.number) == nn.hash {
+ if rawdb.GetCanonicalHash(f.pm.chainDb, nn.number) == nn.hash {
fp.root.children = append(fp.root.children[:i], fp.root.children[i+1:]...)
nn.parent = nil
newRoot = nn
@@ -363,7 +363,7 @@ func (f *lightFetcher) peerHasBlock(p *peer, hash common.Hash, number uint64) bo
//
// when syncing, just check if it is part of the known chain, there is nothing better we
// can do since we do not know the most recent block hash yet
- return core.GetCanonicalHash(f.pm.chainDb, fp.root.number) == fp.root.hash && core.GetCanonicalHash(f.pm.chainDb, number) == hash
+ return rawdb.GetCanonicalHash(f.pm.chainDb, fp.root.number) == fp.root.hash && rawdb.GetCanonicalHash(f.pm.chainDb, number) == hash
}
// requestAmount calculates the amount of headers to be downloaded starting
diff --git a/les/handler.go b/les/handler.go
index b426f7fdd..eb6d1886d 100644
--- a/les/handler.go
+++ b/les/handler.go
@@ -21,15 +21,14 @@ import (
"encoding/binary"
"errors"
"fmt"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
- "net"
"sync"
"time"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/eth/downloader"
@@ -38,8 +37,8 @@ import (
"github.com/tomochain/tomochain/light"
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
"github.com/tomochain/tomochain/p2p/discv5"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/rlp"
"github.com/tomochain/tomochain/trie"
@@ -164,8 +163,7 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, protoco
var entry *poolEntry
peer := manager.newPeer(int(version), networkId, p, rw)
if manager.serverPool != nil {
- addr := p.RemoteAddr().(*net.TCPAddr)
- entry = manager.serverPool.connect(peer, addr.IP, uint16(addr.Port))
+ entry = manager.serverPool.connect(peer, peer.Node())
}
peer.poolEntry = entry
select {
@@ -187,7 +185,7 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, protoco
NodeInfo: func() interface{} {
return manager.NodeInfo()
},
- PeerInfo: func(id discover.NodeID) interface{} {
+ PeerInfo: func(id enode.ID) interface{} {
if p := manager.peers.Peer(fmt.Sprintf("%x", id[:8])); p != nil {
return p.Info()
}
@@ -388,7 +386,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
}
if p.requestAnnounceType == announceTypeSigned {
- if err := req.checkSignature(p.pubKey); err != nil {
+ if err := req.checkSignature(p.ID()); err != nil {
p.Log().Trace("Invalid announcement signature", "err", err)
return err
}
@@ -529,7 +527,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
break
}
// Retrieve the requested block body, stopping if enough was found
- if data := core.GetBodyRLP(pm.chainDb, hash, core.GetBlockNumber(pm.chainDb, hash)); len(data) != 0 {
+ if data := rawdb.GetBodyRLP(pm.chainDb, hash, rawdb.GetBlockNumber(pm.chainDb, hash)); len(data) != 0 {
bodies = append(bodies, data)
bytes += len(data)
}
@@ -580,7 +578,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
}
for _, req := range req.Reqs {
// Retrieve the requested state entry, stopping if enough was found
- if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
+ if header := rawdb.GetHeader(pm.chainDb, req.BHash, rawdb.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
statedb, err := pm.blockchain.State()
if err != nil {
continue
@@ -646,7 +644,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
break
}
// Retrieve the requested block's receipts, skipping if unknown to us
- results := core.GetBlockReceipts(pm.chainDb, hash, core.GetBlockNumber(pm.chainDb, hash))
+ results := rawdb.GetBlockReceipts(pm.chainDb, hash, rawdb.GetBlockNumber(pm.chainDb, hash), pm.chainConfig)
if results == nil {
if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
continue
@@ -706,7 +704,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
}
for _, req := range req.Reqs {
// Retrieve the requested state entry, stopping if enough was found
- if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
+ if header := rawdb.GetHeader(pm.chainDb, req.BHash, rawdb.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
statedb, err := pm.blockchain.State()
if err != nil {
continue
@@ -764,7 +762,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
if statedb == nil || req.BHash != lastBHash {
statedb, root, lastBHash = nil, common.Hash{}, req.BHash
- if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
+ if header := rawdb.GetHeader(pm.chainDb, req.BHash, rawdb.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
statedb, _ = pm.blockchain.State()
root = header.Root
}
@@ -860,7 +858,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
trieDb := trie.NewDatabase(rawdb.NewTable(pm.chainDb, light.ChtTablePrefix))
for _, req := range req.Reqs {
if header := pm.blockchain.GetHeaderByNumber(req.BlockNum); header != nil {
- sectionHead := core.GetCanonicalHash(pm.chainDb, req.ChtNum*light.CHTFrequencyServer-1)
+ sectionHead := rawdb.GetCanonicalHash(pm.chainDb, req.ChtNum*light.CHTFrequencyServer-1)
if root := light.GetChtRoot(pm.chainDb, req.ChtNum-1, sectionHead); root != (common.Hash{}) {
trie, err := trie.New(root, trieDb)
if err != nil {
@@ -1095,18 +1093,18 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
}
// getAccount retrieves an account from the state based at root.
-func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common.Hash) (state.Account, error) {
+func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common.Hash) (types.StateAccount, error) {
trie, err := trie.New(root, statedb.Database().TrieDB())
if err != nil {
- return state.Account{}, err
+ return types.StateAccount{}, err
}
- blob, err := trie.TryGet(hash[:])
+ blob, err := trie.Get(hash[:])
if err != nil {
- return state.Account{}, err
+ return types.StateAccount{}, err
}
- var account state.Account
+ var account types.StateAccount
if err = rlp.DecodeBytes(blob, &account); err != nil {
- return state.Account{}, err
+ return types.StateAccount{}, err
}
return account, nil
}
@@ -1115,10 +1113,10 @@ func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common.
func (pm *ProtocolManager) getHelperTrie(id uint, idx uint64) (common.Hash, string) {
switch id {
case htCanonical:
- sectionHead := core.GetCanonicalHash(pm.chainDb, (idx+1)*light.CHTFrequencyClient-1)
+ sectionHead := rawdb.GetCanonicalHash(pm.chainDb, (idx+1)*light.CHTFrequencyClient-1)
return light.GetChtV2Root(pm.chainDb, idx, sectionHead), light.ChtTablePrefix
case htBloomBits:
- sectionHead := core.GetCanonicalHash(pm.chainDb, (idx+1)*light.BloomTrieFrequency-1)
+ sectionHead := rawdb.GetCanonicalHash(pm.chainDb, (idx+1)*light.BloomTrieFrequency-1)
return light.GetBloomTrieRoot(pm.chainDb, idx, sectionHead), light.BloomTrieTablePrefix
}
return common.Hash{}, ""
@@ -1129,8 +1127,8 @@ func (pm *ProtocolManager) getHelperTrieAuxData(req HelperTrieReq) []byte {
switch {
case req.Type == htCanonical && req.AuxReq == auxHeader && len(req.Key) == 8:
blockNum := binary.BigEndian.Uint64(req.Key)
- hash := core.GetCanonicalHash(pm.chainDb, blockNum)
- return core.GetHeaderRLP(pm.chainDb, hash, blockNum)
+ hash := rawdb.GetCanonicalHash(pm.chainDb, blockNum)
+ return rawdb.GetHeaderRLP(pm.chainDb, hash, blockNum)
}
return nil
}
@@ -1143,9 +1141,9 @@ func (pm *ProtocolManager) txStatus(hashes []common.Hash) []txStatus {
// If the transaction is unknown to the pool, try looking it up locally
if stat == core.TxStatusUnknown {
- if block, number, index := core.GetTxLookupEntry(pm.chainDb, hashes[i]); block != (common.Hash{}) {
+ if block, number, index := rawdb.GetTxLookupEntry(pm.chainDb, hashes[i]); block != (common.Hash{}) {
stats[i].Status = core.TxStatusIncluded
- stats[i].Lookup = &core.TxLookupEntry{BlockHash: block, BlockIndex: number, Index: index}
+ stats[i].Lookup = &rawdb.TxLookupEntry{BlockHash: block, BlockIndex: number, Index: index}
}
}
}
diff --git a/les/handler_test.go b/les/handler_test.go
index 225900dd5..2492a11f8 100644
--- a/les/handler_test.go
+++ b/les/handler_test.go
@@ -18,7 +18,6 @@ package les
import (
"encoding/binary"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"math/rand"
"testing"
@@ -27,6 +26,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus/ethash"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/eth/downloader"
@@ -304,7 +304,7 @@ func testGetReceipt(t *testing.T, protocol int) {
block := bc.GetBlockByNumber(i)
hashes = append(hashes, block.Hash())
- receipts = append(receipts, core.GetBlockReceipts(db, block.Hash(), block.NumberU64()))
+ receipts = append(receipts, rawdb.GetBlockReceipts(db, block.Hash(), block.NumberU64(), pm.chainConfig))
}
// Send the hash request and verify the response
cost := peer.GetRequestCost(GetReceiptsMsg, len(hashes))
@@ -555,9 +555,9 @@ func TestTransactionStatusLes2(t *testing.T) {
}
// check if their status is included now
- block1hash := core.GetCanonicalHash(db, 1)
- test(tx1, false, txStatus{Status: core.TxStatusIncluded, Lookup: &core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}})
- test(tx2, false, txStatus{Status: core.TxStatusIncluded, Lookup: &core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}})
+ block1hash := rawdb.GetCanonicalHash(db, 1)
+ test(tx1, false, txStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}})
+ test(tx2, false, txStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}})
// create a reorg that rolls them back
gchain, _ = core.GenerateChain(params.TestChainConfig, chain.GetBlockByNumber(0), ethash.NewFaker(), db, 2, func(i int, block *core.BlockGen) {})
diff --git a/les/helper_test.go b/les/helper_test.go
index 67a932b4e..4841efb6c 100644
--- a/les/helper_test.go
+++ b/les/helper_test.go
@@ -37,7 +37,7 @@ import (
"github.com/tomochain/tomochain/les/flowcontrol"
"github.com/tomochain/tomochain/light"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/params"
)
@@ -223,8 +223,8 @@ func newTestPeer(t *testing.T, name string, version int, pm *ProtocolManager, sh
app, net := p2p.MsgPipe()
// Generate a random id and create the peer
- var id discover.NodeID
- rand.Read(id[:])
+ var id enode.ID
+ rand.Read(id[:])
peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net)
@@ -260,8 +260,8 @@ func newTestPeerPair(name string, version int, pm, pm2 *ProtocolManager) (*peer,
app, net := p2p.MsgPipe()
// Generate a random id and create the peer
- var id discover.NodeID
- rand.Read(id[:])
+ var id enode.ID
+ rand.Read(id[:])
peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net)
peer2 := pm2.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), app)
diff --git a/les/odr_requests.go b/les/odr_requests.go
index e6e68e762..8bf12f6e8 100644
--- a/les/odr_requests.go
+++ b/les/odr_requests.go
@@ -14,7 +14,7 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-// Package light implements on-demand retrieval capable state and chain objects
+// Package les implements on-demand retrieval capable state and chain objects
// for the Ethereum Light Client.
package les
@@ -24,7 +24,7 @@ import (
"fmt"
"github.com/tomochain/tomochain/common"
- "github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/ethdb"
@@ -110,11 +110,11 @@ func (r *BlockRequest) Validate(db ethdb.Database, msg *Msg) error {
body := bodies[0]
// Retrieve our stored header and validate block content against it
- header := core.GetHeader(db, r.Hash, r.Number)
+ header := rawdb.GetHeader(db, r.Hash, r.Number)
if header == nil {
return errHeaderUnavailable
}
- if header.TxHash != types.DeriveSha(types.Transactions(body.Transactions)) {
+ if header.TxHash != types.DeriveSha(types.Transactions(body.Transactions), new(trie.StackTrie)) {
return errTxHashMismatch
}
if header.UncleHash != types.CalcUncleHash(body.Uncles) {
@@ -166,11 +166,11 @@ func (r *ReceiptsRequest) Validate(db ethdb.Database, msg *Msg) error {
receipt := receipts[0]
// Retrieve our stored header and validate receipt content against it
- header := core.GetHeader(db, r.Hash, r.Number)
+ header := rawdb.GetHeader(db, r.Hash, r.Number)
if header == nil {
return errHeaderUnavailable
}
- if header.ReceiptHash != types.DeriveSha(receipt) {
+ if header.ReceiptHash != types.DeriveSha(receipt, new(trie.StackTrie)) {
return errReceiptHashMismatch
}
// Validations passed, store and return
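Both validators now pass an explicit trie hasher to `types.DeriveSha`. A compact sketch of the same check, mirroring the calls in this file (the `checkBody` helper is hypothetical):

```go
package example

import (
	"errors"

	"github.com/tomochain/tomochain/core/types"
	"github.com/tomochain/tomochain/trie"
)

// checkBody re-derives the transaction and uncle roots of a fetched body and
// compares them against the locally stored header, as BlockRequest.Validate
// does above.
func checkBody(header *types.Header, body *types.Body) error {
	if header.TxHash != types.DeriveSha(types.Transactions(body.Transactions), new(trie.StackTrie)) {
		return errors.New("transaction root mismatch")
	}
	if header.UncleHash != types.CalcUncleHash(body.Uncles) {
		return errors.New("uncle root mismatch")
	}
	return nil
}
```

Deriving the root with a stack trie avoids materializing a full in-memory trie just to hash an ordered list, which is all these validators need.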
diff --git a/les/odr_test.go b/les/odr_test.go
index 3858e3402..1180b4bfa 100644
--- a/les/odr_test.go
+++ b/les/odr_test.go
@@ -19,7 +19,6 @@ package les
import (
"bytes"
"context"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"testing"
"time"
@@ -27,6 +26,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/math"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
@@ -64,9 +64,9 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon
func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
var receipts types.Receipts
if bc != nil {
- receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash))
+ receipts = rawdb.GetBlockReceipts(db, bhash, rawdb.GetBlockNumber(db, bhash), config)
} else {
- receipts, _ = light.GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash))
+ receipts, _ = light.GetBlockReceipts(ctx, lc.Odr(), bhash, rawdb.GetBlockNumber(db, bhash), config)
}
if receipts == nil {
return nil
@@ -91,7 +91,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon
for _, addr := range acc {
if bc != nil {
header := bc.GetHeaderByHash(bhash)
- st, err = state.New(header.Root, state.NewDatabase(db))
+ st, err = state.New(header.Root, state.NewDatabase(db), nil)
} else {
header := lc.GetHeaderByHash(bhash)
st = light.NewState(ctx, header, lc.Odr())
@@ -109,12 +109,6 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon
//
//func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, odrContractCall) }
-type callmsg struct {
- types.Message
-}
-
-func (callmsg) CheckNonce() bool { return false }
-
func odrContractCall(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
data := common.Hex2Bytes("60CD26850000000000000000000000000000000000000000000000000000000000000000")
@@ -123,7 +117,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
data[35] = byte(i)
if bc != nil {
header := bc.GetHeaderByHash(bhash)
- statedb, err := state.New(header.Root, state.NewDatabase(db))
+ statedb, err := state.New(header.Root, state.NewDatabase(db), nil)
if err == nil {
from := statedb.GetOrNewStateObject(testBankAddress)
@@ -133,16 +127,26 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
if value, ok := feeCapacity[testContractAddr]; ok {
balanceTokenFee = value
}
- msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, false, balanceTokenFee)}
-
+ msg := &core.Message{
+ To: &testContractAddr,
+ From: from.Address(),
+ Nonce: 0,
+ Value: new(big.Int),
+ GasLimit: 100000,
+ GasPrice: new(big.Int),
+ Data: data,
+ SkipAccountChecks: false,
+ BalanceTokenFee: balanceTokenFee,
+ }
context := core.NewEVMContext(msg, header, bc, nil)
vmenv := vm.NewEVM(context, statedb, nil, config, vm.Config{})
//vmenv := core.NewEnv(statedb, config, bc, msg, header, vm.Config{})
gp := new(core.GasPool).AddGas(math.MaxUint64)
owner := common.Address{}
- ret, _, _, _ := core.ApplyMessage(vmenv, msg, gp, owner)
- res = append(res, ret...)
+ ret, _ := core.ApplyMessage(vmenv, msg, gp, owner)
+ res = append(res, ret.Return()...)
}
} else {
header := lc.GetHeaderByHash(bhash)
@@ -153,14 +157,24 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
if value, ok := feeCapacity[testContractAddr]; ok {
balanceTokenFee = value
}
- msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, false, balanceTokenFee)}
+ msg := &core.Message{
+ To: &testContractAddr,
+ From: testBankAddress,
+ Nonce: 0,
+ Value: new(big.Int),
+ GasLimit: 100000,
+ GasPrice: new(big.Int),
+ Data: data,
+ SkipAccountChecks: false,
+ BalanceTokenFee: balanceTokenFee,
+ }
context := core.NewEVMContext(msg, header, lc, nil)
vmenv := vm.NewEVM(context, statedb, nil, config, vm.Config{})
gp := new(core.GasPool).AddGas(math.MaxUint64)
owner := common.Address{}
- ret, _, _, _ := core.ApplyMessage(vmenv, msg, gp, owner)
+ ret, _ := core.ApplyMessage(vmenv, msg, gp, owner)
if statedb.Error() == nil {
- res = append(res, ret...)
+ res = append(res, ret.Return()...)
}
}
}
@@ -190,7 +204,7 @@ func testOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) {
test := func(expFail uint64) {
for i := uint64(0); i <= pm.blockchain.CurrentHeader().Number.Uint64(); i++ {
- bhash := core.GetCanonicalHash(db, i)
+ bhash := rawdb.GetCanonicalHash(db, i)
b1 := fn(light.NoOdr, db, pm.chainConfig, pm.blockchain.(*core.BlockChain), nil, bhash)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
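Throughout this file, `core.ApplyMessage` now returns a single `*core.ExecutionResult` instead of four values. A hedged sketch of the consuming side (`evm`, `msg`, `gp`, and `owner` are assumed to be in scope; field and method names follow the usage in this diff):

```go
res, err := core.ApplyMessage(evm, msg, gp, owner)
if err != nil {
	// Consensus-level failure: the message could not be applied at all.
	return nil, err
}
if res.Err != nil {
	// The message executed but the EVM reverted or ran out of gas.
	log.Debug("execution failed", "err", res.Err, "gasUsed", res.UsedGas)
}
output := res.Return() // non-nil only when execution succeeded
```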
diff --git a/les/peer.go b/les/peer.go
index 2723003ec..ca91562f4 100644
--- a/les/peer.go
+++ b/les/peer.go
@@ -18,8 +18,6 @@
package les
import (
- "crypto/ecdsa"
- "encoding/binary"
"errors"
"fmt"
"math/big"
@@ -36,9 +34,10 @@ import (
)
var (
- errClosed = errors.New("peer set is closed")
- errAlreadyRegistered = errors.New("peer is already registered")
- errNotRegistered = errors.New("peer is not registered")
+ errClosed = errors.New("peer set is closed")
+ errAlreadyRegistered = errors.New("peer is already registered")
+ errNotRegistered = errors.New("peer is not registered")
+ errInvalidHelpTrieReq = errors.New("invalid help trie request")
)
const maxResponseErrors = 50 // number of invalid responses tolerated (makes the protocol less brittle but still avoids spam)
@@ -51,7 +50,6 @@ const (
type peer struct {
*p2p.Peer
- pubKey *ecdsa.PublicKey
rw p2p.MsgReadWriter
@@ -80,11 +78,9 @@ type peer struct {
func newPeer(version int, network uint64, p *p2p.Peer, rw p2p.MsgReadWriter) *peer {
id := p.ID()
- pubKey, _ := id.Pubkey()
return &peer{
Peer: p,
- pubKey: pubKey,
rw: rw,
version: version,
network: network,
@@ -284,21 +280,21 @@ func (p *peer) RequestProofs(reqID, cost uint64, reqs []ProofReq) error {
}
// RequestHelperTrieProofs fetches a batch of HelperTrie merkle proofs from a remote node.
-func (p *peer) RequestHelperTrieProofs(reqID, cost uint64, reqs []HelperTrieReq) error {
- p.Log().Debug("Fetching batch of HelperTrie proofs", "count", len(reqs))
+func (p *peer) RequestHelperTrieProofs(reqID, cost uint64, data interface{}) error {
switch p.version {
case lpv1:
- reqsV1 := make([]ChtReq, len(reqs))
- for i, req := range reqs {
- if req.Type != htCanonical || req.AuxReq != auxHeader || len(req.Key) != 8 {
- return fmt.Errorf("Request invalid in LES/1 mode")
- }
- blockNum := binary.BigEndian.Uint64(req.Key)
- // convert HelperTrie request to old CHT request
- reqsV1[i] = ChtReq{ChtNum: (req.TrieIdx + 1) * (light.CHTFrequencyClient / light.CHTFrequencyServer), BlockNum: blockNum, FromLevel: req.FromLevel}
+ reqs, ok := data.([]ChtReq)
+ if !ok {
+ return errInvalidHelpTrieReq
}
- return sendRequest(p.rw, GetHeaderProofsMsg, reqID, cost, reqsV1)
+ p.Log().Debug("Fetching batch of header proofs", "count", len(reqs))
+ return sendRequest(p.rw, GetHeaderProofsMsg, reqID, cost, reqs)
case lpv2:
+ reqs, ok := data.([]HelperTrieReq)
+ if !ok {
+ return errInvalidHelpTrieReq
+ }
+ p.Log().Debug("Fetching batch of HelperTrie proofs", "count", len(reqs))
return sendRequest(p.rw, GetHelperTrieProofsMsg, reqID, cost, reqs)
default:
panic(nil)
@@ -545,9 +541,11 @@ func (ps *peerSet) notify(n peerSetNotify) {
func (ps *peerSet) Register(p *peer) error {
ps.lock.Lock()
if ps.closed {
+ ps.lock.Unlock()
return errClosed
}
if _, ok := ps.peers[p.id]; ok {
+ ps.lock.Unlock()
return errAlreadyRegistered
}
ps.peers[p.id] = p
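The `Register` hunk above adds the unlocks that were missing on the early-return paths; returning while still holding the mutex would deadlock every later peer-set operation. When nothing after the critical section must run unlocked, the same invariant is usually expressed with `defer`, as in this sketch:

```go
func (ps *peerSet) Register(p *peer) error {
	ps.lock.Lock()
	defer ps.lock.Unlock() // released on every return path

	if ps.closed {
		return errClosed
	}
	if _, ok := ps.peers[p.id]; ok {
		return errAlreadyRegistered
	}
	ps.peers[p.id] = p
	return nil
}
```

The explicit unlocks kept in the patch suggest the real function continues past this point outside the lock, which is the one case where `defer` does not fit.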
diff --git a/les/protocol.go b/les/protocol.go
index 9ca62e73e..eabac6dab 100644
--- a/les/protocol.go
+++ b/les/protocol.go
@@ -18,9 +18,7 @@
package les
import (
- "bytes"
"crypto/ecdsa"
- "crypto/elliptic"
"errors"
"fmt"
"io"
@@ -28,8 +26,9 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/crypto"
- "github.com/tomochain/tomochain/crypto/secp256k1"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/rlp"
)
@@ -147,22 +146,20 @@ func (a *announceData) sign(privKey *ecdsa.PrivateKey) {
}
// checkSignature verifies if the block announcement has a valid signature by the given pubKey
-func (a *announceData) checkSignature(pubKey *ecdsa.PublicKey) error {
+func (a *announceData) checkSignature(id enode.ID) error {
var sig []byte
if err := a.Update.decode().get("sign", &sig); err != nil {
return err
}
rlp, _ := rlp.EncodeToBytes(announceBlock{a.Hash, a.Number, a.Td})
- recPubkey, err := secp256k1.RecoverPubkey(crypto.Keccak256(rlp), sig)
+ recPubkey, err := crypto.SigToPub(crypto.Keccak256(rlp), sig)
if err != nil {
return err
}
- pbytes := elliptic.Marshal(pubKey.Curve, pubKey.X, pubKey.Y)
- if bytes.Equal(pbytes, recPubkey) {
+ if id == enode.PubkeyToIDV4(recPubkey) {
return nil
- } else {
- return errors.New("Wrong signature")
}
+ return errors.New("wrong signature")
}
type blockInfo struct {
@@ -224,6 +221,6 @@ type proofsData [][]rlp.RawValue
type txStatus struct {
Status core.TxStatus
- Lookup *core.TxLookupEntry `rlp:"nil"`
+ Lookup *rawdb.TxLookupEntry `rlp:"nil"`
Error string
}
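`checkSignature` now recovers the signer's public key from the signature and compares enode IDs rather than raw key bytes. A round-trip sketch under the same scheme (helper name and payload are illustrative):

```go
package example

import (
	"errors"

	"github.com/tomochain/tomochain/crypto"
	"github.com/tomochain/tomochain/p2p/enode"
	"github.com/tomochain/tomochain/rlp"
)

// verifyAnnounce checks that sig over the RLP encoding of payload was made by
// the node with the given ID, mirroring checkSignature above.
func verifyAnnounce(id enode.ID, payload interface{}, sig []byte) error {
	enc, err := rlp.EncodeToBytes(payload)
	if err != nil {
		return err
	}
	pub, err := crypto.SigToPub(crypto.Keccak256(enc), sig)
	if err != nil {
		return err
	}
	if id != enode.PubkeyToIDV4(pub) {
		return errors.New("wrong signature")
	}
	return nil
}
```

The signing side is the existing `announceData.sign`, which produces `sig` over the same Keccak256 digest.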
diff --git a/les/request_test.go b/les/request_test.go
index 183128d83..2313e738a 100644
--- a/les/request_test.go
+++ b/les/request_test.go
@@ -18,12 +18,11 @@ package les
import (
"context"
- "github.com/tomochain/tomochain/core/rawdb"
"testing"
"time"
"github.com/tomochain/tomochain/common"
- "github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/eth"
"github.com/tomochain/tomochain/ethdb"
@@ -59,7 +58,7 @@ func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light
//func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) }
func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
- return &light.TrieRequest{Id: light.StateTrieID(core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash))), Key: testBankSecureTrieKey}
+ return &light.TrieRequest{Id: light.StateTrieID(rawdb.GetHeader(db, bhash, rawdb.GetBlockNumber(db, bhash))), Key: testBankSecureTrieKey}
}
//func TestCodeAccessLes1(t *testing.T) { testAccess(t, 1, tfCodeAccess) }
@@ -67,7 +66,7 @@ func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) ligh
//func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) }
func tfCodeAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
- header := core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash))
+ header := rawdb.GetHeader(db, bhash, rawdb.GetBlockNumber(db, bhash))
if header.Number.Uint64() < testContractDeployed {
return nil
}
@@ -100,7 +99,7 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) {
test := func(expFail uint64) {
for i := uint64(0); i <= pm.blockchain.CurrentHeader().Number.Uint64(); i++ {
- bhash := core.GetCanonicalHash(db, i)
+ bhash := rawdb.GetCanonicalHash(db, i)
if req := fn(ldb, bhash, i); req != nil {
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
diff --git a/les/server.go b/les/server.go
index b56d2cad4..4705f599d 100644
--- a/les/server.go
+++ b/les/server.go
@@ -25,6 +25,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/eth"
"github.com/tomochain/tomochain/ethdb"
@@ -329,11 +330,11 @@ func (pm *ProtocolManager) blockLoop() {
header := ev.Block.Header()
hash := header.Hash()
number := header.Number.Uint64()
- td := core.GetTd(pm.chainDb, hash, number)
+ td := rawdb.GetTd(pm.chainDb, hash, number)
if td != nil && td.Cmp(lastBroadcastTd) > 0 {
var reorg uint64
if lastHead != nil {
- reorg = lastHead.Number.Uint64() - core.FindCommonAncestor(pm.chainDb, header, lastHead).Number.Uint64()
+ reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(pm.chainDb, header, lastHead).Number.Uint64()
}
lastHead = header
lastBroadcastTd = td
diff --git a/les/serverpool.go b/les/serverpool.go
index 313de65e9..93a37fc27 100644
--- a/les/serverpool.go
+++ b/les/serverpool.go
@@ -18,6 +18,7 @@
package les
import (
+ "crypto/ecdsa"
"fmt"
"io"
"math"
@@ -28,11 +29,12 @@ import (
"time"
"github.com/tomochain/tomochain/common/mclock"
+ "github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
"github.com/tomochain/tomochain/p2p/discv5"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/rlp"
)
@@ -73,7 +75,6 @@ const (
// and a short term value which is adjusted exponentially with a factor of
// pstatRecentAdjust with each dial/connection and also returned exponentially
// to the average with the time constant pstatReturnToMeanTC
- pstatRecentAdjust = 0.1
pstatReturnToMeanTC = time.Hour
// node address selection weight is dropped by a factor of exp(-addrFailDropLn) after
// each unsuccessful connection (restored after a successful one)
@@ -83,14 +84,31 @@ const (
responseScoreTC = time.Millisecond * 100
delayScoreTC = time.Second * 5
timeoutPow = 10
- // peerSelectMinWeight is added to calculated weights at request peer selection
- // to give poorly performing peers a little chance of coming back
- peerSelectMinWeight = 0.005
// initStatsWeight is used to initialize previously unknown peers with good
// statistics to give a chance to prove themselves
initStatsWeight = 1
)
+// connReq represents a request for peer connection.
+type connReq struct {
+ p *peer
+ node *enode.Node
+ result chan *poolEntry
+}
+
+// disconnReq represents a request for peer disconnection.
+type disconnReq struct {
+ entry *poolEntry
+ stopped bool
+ done chan struct{}
+}
+
+// registerReq represents a request for peer registration.
+type registerReq struct {
+ entry *poolEntry
+ done chan struct{}
+}
+
// serverPool implements a pool for storing and selecting newly discovered and already
// known light server nodes. It receives discovered nodes, stores statistics about
// known nodes and takes care of always having enough good quality servers connected.
@@ -105,14 +123,17 @@ type serverPool struct {
topic discv5.Topic
discSetPeriod chan time.Duration
- discNodes chan *discv5.Node
+ discNodes chan *enode.Node
discLookups chan bool
- entries map[discover.NodeID]*poolEntry
- lock sync.Mutex
+ entries map[enode.ID]*poolEntry
timeout, enableRetry chan *poolEntry
adjustStats chan poolStatAdjust
+ connCh chan *connReq
+ disconnCh chan *disconnReq
+ registerCh chan *registerReq
+
knownQueue, newQueue poolEntryQueue
knownSelect, newSelect *weightedRandomSelect
knownSelected, newSelected int
@@ -125,10 +146,13 @@ func newServerPool(db ethdb.Database, quit chan struct{}, wg *sync.WaitGroup) *s
db: db,
quit: quit,
wg: wg,
- entries: make(map[discover.NodeID]*poolEntry),
+ entries: make(map[enode.ID]*poolEntry),
timeout: make(chan *poolEntry, 1),
adjustStats: make(chan poolStatAdjust, 100),
enableRetry: make(chan *poolEntry, 1),
+ connCh: make(chan *connReq),
+ disconnCh: make(chan *disconnReq),
+ registerCh: make(chan *registerReq),
knownSelect: newWeightedRandomSelect(),
newSelect: newWeightedRandomSelect(),
fastDiscover: true,
@@ -147,13 +171,28 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) {
if pool.server.DiscV5 != nil {
pool.discSetPeriod = make(chan time.Duration, 1)
- pool.discNodes = make(chan *discv5.Node, 100)
+ pool.discNodes = make(chan *enode.Node, 100)
pool.discLookups = make(chan bool, 100)
- go pool.server.DiscV5.SearchTopic(pool.topic, pool.discSetPeriod, pool.discNodes, pool.discLookups)
+ go pool.discoverNodes()
}
-
- go pool.eventLoop()
pool.checkDial()
+ go pool.eventLoop()
+}
+
+// discoverNodes wraps SearchTopic, converting result nodes to enode.Node.
+func (pool *serverPool) discoverNodes() {
+ ch := make(chan *discv5.Node)
+ go func() {
+ pool.server.DiscV5.SearchTopic(pool.topic, pool.discSetPeriod, ch, pool.discLookups)
+ close(ch)
+ }()
+ for n := range ch {
+ pubkey, err := decodePubkey64(n.ID[:])
+ if err != nil {
+ continue
+ }
+ pool.discNodes <- enode.NewV4(pubkey, n.IP, int(n.TCP), int(n.UDP))
+ }
}
// connect should be called upon any incoming connection. If the connection has been
@@ -161,84 +200,45 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) {
// Otherwise, the connection should be rejected.
// Note that whenever a connection has been accepted and a pool entry has been returned,
// disconnect should also always be called.
-func (pool *serverPool) connect(p *peer, ip net.IP, port uint16) *poolEntry {
- pool.lock.Lock()
- defer pool.lock.Unlock()
- entry := pool.entries[p.ID()]
- if entry == nil {
- entry = pool.findOrNewNode(p.ID(), ip, port)
- }
- p.Log().Debug("Connecting to new peer", "state", entry.state)
- if entry.state == psConnected || entry.state == psRegistered {
+func (pool *serverPool) connect(p *peer, node *enode.Node) *poolEntry {
+ log.Debug("Connect new entry", "enode", p.id)
+ req := &connReq{p: p, node: node, result: make(chan *poolEntry, 1)}
+ select {
+ case pool.connCh <- req:
+ case <-pool.quit:
return nil
}
- pool.connWg.Add(1)
- entry.peer = p
- entry.state = psConnected
- addr := &poolEntryAddress{
- ip: ip,
- port: port,
- lastSeen: mclock.Now(),
- }
- entry.lastConnected = addr
- entry.addr = make(map[string]*poolEntryAddress)
- entry.addr[addr.strKey()] = addr
- entry.addrSelect = *newWeightedRandomSelect()
- entry.addrSelect.update(addr)
- return entry
+ return <-req.result
}
// registered should be called after a successful handshake
func (pool *serverPool) registered(entry *poolEntry) {
- log.Debug("Registered new entry", "enode", entry.id)
- pool.lock.Lock()
- defer pool.lock.Unlock()
-
- entry.state = psRegistered
- entry.regTime = mclock.Now()
- if !entry.known {
- pool.newQueue.remove(entry)
- entry.known = true
+ log.Debug("Registered new entry", "enode", entry.node.ID())
+ req := &registerReq{entry: entry, done: make(chan struct{})}
+ select {
+ case pool.registerCh <- req:
+ case <-pool.quit:
+ return
}
- pool.knownQueue.setLatest(entry)
- entry.shortRetry = shortRetryCnt
+ <-req.done
}
// disconnect should be called when ending a connection. Service quality statistics
// can be updated optionally (not updated if no registration happened, in this case
// only connection statistics are updated, just like in case of timeout)
func (pool *serverPool) disconnect(entry *poolEntry) {
- log.Debug("Disconnected old entry", "enode", entry.id)
- pool.lock.Lock()
- defer pool.lock.Unlock()
-
- if entry.state == psRegistered {
- connTime := mclock.Now() - entry.regTime
- connAdjust := float64(connTime) / float64(targetConnTime)
- if connAdjust > 1 {
- connAdjust = 1
- }
- stopped := false
- select {
- case <-pool.quit:
- stopped = true
- default:
- }
- if stopped {
- entry.connectStats.add(1, connAdjust)
- } else {
- entry.connectStats.add(connAdjust, 1)
- }
+ stopped := false
+ select {
+ case <-pool.quit:
+ stopped = true
+ default:
}
+ log.Debug("Disconnected old entry", "enode", entry.node.ID())
+ req := &disconnReq{entry: entry, stopped: stopped, done: make(chan struct{})}
- entry.state = psNotConnected
- if entry.knownSelected {
- pool.knownSelected--
- } else {
- pool.newSelected--
- }
- pool.setRetryDial(entry)
- pool.connWg.Done()
+ // Block until disconnection request is served.
+ pool.disconnCh <- req
+ <-req.done
}
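With this refactor, `connect`, `registered`, and `disconnect` no longer mutate pool state under a shared mutex; each wraps its arguments in a request struct and blocks until the event loop, now the only goroutine that touches the entries, has served it. A stripped-down sketch of that pattern (all names illustrative):

```go
package example

// Request/response over a channel: callers block until the single owner
// goroutine has applied their change, so no mutex is needed.
type setReq struct {
	key, val string
	done     chan struct{}
}

type pool struct {
	reqs chan *setReq
	data map[string]string // touched only by loop()
}

func (p *pool) loop(quit <-chan struct{}) {
	for {
		select {
		case req := <-p.reqs:
			p.data[req.key] = req.val
			close(req.done)
		case <-quit:
			return
		}
	}
}

func (p *pool) set(key, val string) {
	req := &setReq{key: key, val: val, done: make(chan struct{})}
	p.reqs <- req
	<-req.done // returns only after loop() handled the request
}
```

The shutdown path in the real loop is subtler: it must keep draining `disconnCh` until `connWg` empties, which is exactly what the goroutine spawned in the quit case below does.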
const (
@@ -281,25 +281,51 @@ func (pool *serverPool) eventLoop() {
if pool.discSetPeriod != nil {
pool.discSetPeriod <- time.Millisecond * 100
}
+
+ // disconnect updates service quality statistics depending on the connection time
+ // and disconnection initiator.
+ disconnect := func(req *disconnReq, stopped bool) {
+ // Handle peer disconnection requests.
+ entry := req.entry
+ if entry.state == psRegistered {
+ connAdjust := float64(mclock.Now()-entry.regTime) / float64(targetConnTime)
+ if connAdjust > 1 {
+ connAdjust = 1
+ }
+ if stopped {
+ // disconnect requested by ourselves.
+ entry.connectStats.add(1, connAdjust)
+ } else {
+ // disconnect requested by server side.
+ entry.connectStats.add(connAdjust, 1)
+ }
+ }
+ entry.state = psNotConnected
+
+ if entry.knownSelected {
+ pool.knownSelected--
+ } else {
+ pool.newSelected--
+ }
+ pool.setRetryDial(entry)
+ pool.connWg.Done()
+ close(req.done)
+ }
+
for {
select {
case entry := <-pool.timeout:
- pool.lock.Lock()
if !entry.removed {
pool.checkDialTimeout(entry)
}
- pool.lock.Unlock()
case entry := <-pool.enableRetry:
- pool.lock.Lock()
if !entry.removed {
entry.delayedRetry = false
pool.updateCheckDial(entry)
}
- pool.lock.Unlock()
case adj := <-pool.adjustStats:
- pool.lock.Lock()
switch adj.adjustType {
case pseBlockDelay:
adj.entry.delayStats.add(float64(adj.time), 1)
@@ -309,13 +335,10 @@ func (pool *serverPool) eventLoop() {
case pseResponseTimeout:
adj.entry.timeoutStats.add(1, 1)
}
- pool.lock.Unlock()
case node := <-pool.discNodes:
- pool.lock.Lock()
- entry := pool.findOrNewNode(discover.NodeID(node.ID), node.IP, node.TCP)
+ entry := pool.findOrNewNode(node)
pool.updateCheckDial(entry)
- pool.lock.Unlock()
case conv := <-pool.discLookups:
if conv {
@@ -331,31 +354,82 @@ func (pool *serverPool) eventLoop() {
}
}
+ case req := <-pool.connCh:
+ // Handle peer connection requests.
+ entry := pool.entries[req.p.ID()]
+ if entry == nil {
+ entry = pool.findOrNewNode(req.node)
+ }
+ if entry.state == psConnected || entry.state == psRegistered {
+ req.result <- nil
+ continue
+ }
+ pool.connWg.Add(1)
+ entry.peer = req.p
+ entry.state = psConnected
+ addr := &poolEntryAddress{
+ ip: req.node.IP(),
+ port: uint16(req.node.TCP()),
+ lastSeen: mclock.Now(),
+ }
+ entry.lastConnected = addr
+ entry.addr = make(map[string]*poolEntryAddress)
+ entry.addr[addr.strKey()] = addr
+ entry.addrSelect = *newWeightedRandomSelect()
+ entry.addrSelect.update(addr)
+ req.result <- entry
+
+ case req := <-pool.registerCh:
+ // Handle peer registration requests.
+ entry := req.entry
+ entry.state = psRegistered
+ entry.regTime = mclock.Now()
+ if !entry.known {
+ pool.newQueue.remove(entry)
+ entry.known = true
+ }
+ pool.knownQueue.setLatest(entry)
+ entry.shortRetry = shortRetryCnt
+ close(req.done)
+
+ case req := <-pool.disconnCh:
+ // Handle peer disconnection requests.
+ disconnect(req, req.stopped)
+
case <-pool.quit:
if pool.discSetPeriod != nil {
close(pool.discSetPeriod)
}
- pool.connWg.Wait()
+
+ // Spawn a goroutine to close the disconnCh after all connections are disconnected.
+ go func() {
+ pool.connWg.Wait()
+ close(pool.disconnCh)
+ }()
+
+ // Handle all remaining disconnection requests before exit.
+ for req := range pool.disconnCh {
+ disconnect(req, true)
+ }
pool.saveNodes()
pool.wg.Done()
return
-
}
}
}
-func (pool *serverPool) findOrNewNode(id discover.NodeID, ip net.IP, port uint16) *poolEntry {
+func (pool *serverPool) findOrNewNode(node *enode.Node) *poolEntry {
now := mclock.Now()
- entry := pool.entries[id]
+ entry := pool.entries[node.ID()]
if entry == nil {
- log.Debug("Discovered new entry", "id", id)
+ log.Debug("Discovered new entry", "id", node.ID())
entry = &poolEntry{
- id: id,
+ node: node,
addr: make(map[string]*poolEntryAddress),
addrSelect: *newWeightedRandomSelect(),
shortRetry: shortRetryCnt,
}
- pool.entries[id] = entry
+ pool.entries[node.ID()] = entry
// initialize previously unknown peers with good statistics to give a chance to prove themselves
entry.connectStats.add(1, initStatsWeight)
entry.delayStats.add(0, initStatsWeight)
@@ -363,10 +437,7 @@ func (pool *serverPool) findOrNewNode(id discover.NodeID, ip net.IP, port uint16
entry.timeoutStats.add(0, initStatsWeight)
}
entry.lastDiscovered = now
- addr := &poolEntryAddress{
- ip: ip,
- port: port,
- }
+ addr := &poolEntryAddress{ip: node.IP(), port: uint16(node.TCP())}
if a, ok := entry.addr[addr.strKey()]; ok {
addr = a
} else {
@@ -393,12 +464,12 @@ func (pool *serverPool) loadNodes() {
return
}
for _, e := range list {
- log.Debug("Loaded server stats", "id", e.id, "fails", e.lastConnected.fails,
+ log.Debug("Loaded server stats", "id", e.node.ID(), "fails", e.lastConnected.fails,
"conn", fmt.Sprintf("%v/%v", e.connectStats.avg, e.connectStats.weight),
"delay", fmt.Sprintf("%v/%v", time.Duration(e.delayStats.avg), e.delayStats.weight),
"response", fmt.Sprintf("%v/%v", time.Duration(e.responseStats.avg), e.responseStats.weight),
"timeout", fmt.Sprintf("%v/%v", e.timeoutStats.avg, e.timeoutStats.weight))
- pool.entries[e.id] = e
+ pool.entries[e.node.ID()] = e
pool.knownQueue.setLatest(e)
pool.knownSelect.update((*knownEntry)(e))
}
@@ -424,7 +495,7 @@ func (pool *serverPool) removeEntry(entry *poolEntry) {
pool.newSelect.remove((*discoveredEntry)(entry))
pool.knownSelect.remove((*knownEntry)(entry))
entry.removed = true
- delete(pool.entries, entry.id)
+ delete(pool.entries, entry.node.ID())
}
// setRetryDial starts the timer which will enable dialing a certain node again
@@ -502,10 +573,10 @@ func (pool *serverPool) dial(entry *poolEntry, knownSelected bool) {
pool.newSelected++
}
addr := entry.addrSelect.choose().(*poolEntryAddress)
- log.Debug("Dialing new peer", "lesaddr", entry.id.String()+"@"+addr.strKey(), "set", len(entry.addr), "known", knownSelected)
+ log.Debug("Dialing new peer", "lesaddr", entry.node.ID().String()+"@"+addr.strKey(), "set", len(entry.addr), "known", knownSelected)
entry.dialed = addr
go func() {
- pool.server.AddPeer(discover.NewNode(entry.id, addr.ip, addr.port, addr.port))
+ pool.server.AddPeer(entry.node)
select {
case <-pool.quit:
case <-time.After(dialTimeout):
@@ -523,7 +594,7 @@ func (pool *serverPool) checkDialTimeout(entry *poolEntry) {
if entry.state != psDialed {
return
}
- log.Debug("Dial timeout", "lesaddr", entry.id.String()+"@"+entry.dialed.strKey())
+ log.Debug("Dial timeout", "lesaddr", entry.node.ID().String()+"@"+entry.dialed.strKey())
entry.state = psNotConnected
if entry.knownSelected {
pool.knownSelected--
@@ -545,8 +616,9 @@ const (
// poolEntry represents a server node and stores its current state and statistics.
type poolEntry struct {
peer *peer
- id discover.NodeID
+ pubkey [64]byte // secp256k1 key of the node
addr map[string]*poolEntryAddress
+ node *enode.Node
lastConnected, dialed *poolEntryAddress
addrSelect weightedRandomSelect
@@ -563,23 +635,39 @@ type poolEntry struct {
shortRetry int
}
+// poolEntryEnc is the RLP encoding of poolEntry.
+type poolEntryEnc struct {
+ Pubkey []byte
+ IP net.IP
+ Port uint16
+ Fails uint
+ CStat, DStat, RStat, TStat poolStats
+}
+
func (e *poolEntry) EncodeRLP(w io.Writer) error {
- return rlp.Encode(w, []interface{}{e.id, e.lastConnected.ip, e.lastConnected.port, e.lastConnected.fails, &e.connectStats, &e.delayStats, &e.responseStats, &e.timeoutStats})
+ return rlp.Encode(w, &poolEntryEnc{
+ Pubkey: encodePubkey64(e.node.Pubkey()),
+ IP: e.lastConnected.ip,
+ Port: e.lastConnected.port,
+ Fails: e.lastConnected.fails,
+ CStat: e.connectStats,
+ DStat: e.delayStats,
+ RStat: e.responseStats,
+ TStat: e.timeoutStats,
+ })
}
func (e *poolEntry) DecodeRLP(s *rlp.Stream) error {
- var entry struct {
- ID discover.NodeID
- IP net.IP
- Port uint16
- Fails uint
- CStat, DStat, RStat, TStat poolStats
- }
+ var entry poolEntryEnc
if err := s.Decode(&entry); err != nil {
return err
}
+ pubkey, err := decodePubkey64(entry.Pubkey)
+ if err != nil {
+ return err
+ }
addr := &poolEntryAddress{ip: entry.IP, port: entry.Port, fails: entry.Fails, lastSeen: mclock.Now()}
- e.id = entry.ID
+ e.node = enode.NewV4(pubkey, entry.IP, int(entry.Port), int(entry.Port))
e.addr = make(map[string]*poolEntryAddress)
e.addr[addr.strKey()] = addr
e.addrSelect = *newWeightedRandomSelect()
@@ -594,6 +682,14 @@ func (e *poolEntry) DecodeRLP(s *rlp.Stream) error {
return nil
}
+func encodePubkey64(pub *ecdsa.PublicKey) []byte {
+ return crypto.FromECDSAPub(pub)[1:]
+}
+
+func decodePubkey64(b []byte) (*ecdsa.PublicKey, error) {
+ return crypto.UnmarshalPubkey(append([]byte{0x04}, b...))
+}
+
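`encodePubkey64` stores exactly the 64 X||Y bytes that discv5 and enode work with, dropping the leading 0x04 uncompressed-point tag that `crypto.FromECDSAPub` emits (hence the `[1:]`), and `decodePubkey64` prepends the tag again before unmarshalling. A quick round-trip sketch:

```go
key, _ := crypto.GenerateKey()
enc := encodePubkey64(&key.PublicKey) // 64 bytes: X||Y, no 0x04 tag
pub, err := decodePubkey64(enc)       // re-adds 0x04 and unmarshals
if err != nil || pub.X.Cmp(key.PublicKey.X) != 0 || pub.Y.Cmp(key.PublicKey.Y) != 0 {
	panic("pubkey64 round-trip failed")
}
```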
// discoveredEntry implements wrsItem
type discoveredEntry poolEntry
@@ -605,9 +701,8 @@ func (e *discoveredEntry) Weight() int64 {
t := time.Duration(mclock.Now() - e.lastDiscovered)
if t <= discoverExpireStart {
return 1000000000
- } else {
- return int64(1000000000 * math.Exp(-float64(t-discoverExpireStart)/float64(discoverExpireConst)))
}
+ return int64(1000000000 * math.Exp(-float64(t-discoverExpireStart)/float64(discoverExpireConst)))
}
// knownEntry implements wrsItem
diff --git a/les/sync.go b/les/sync.go
index 8e3cd47ca..993e96a58 100644
--- a/les/sync.go
+++ b/les/sync.go
@@ -20,7 +20,7 @@ import (
"context"
"time"
- "github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/eth/downloader"
"github.com/tomochain/tomochain/light"
)
@@ -61,7 +61,7 @@ func (pm *ProtocolManager) syncer() {
func (pm *ProtocolManager) needToSync(peerHead blockInfo) bool {
head := pm.blockchain.CurrentHeader()
- currentTd := core.GetTd(pm.chainDb, head.Hash(), head.Number.Uint64())
+ currentTd := rawdb.GetTd(pm.chainDb, head.Hash(), head.Number.Uint64())
return currentTd != nil && peerHead.Td.Cmp(currentTd) > 0
}
diff --git a/light/lightchain.go b/light/lightchain.go
index 6c9138977..42717f1ce 100644
--- a/light/lightchain.go
+++ b/light/lightchain.go
@@ -24,10 +24,12 @@ import (
"sync/atomic"
"time"
- "github.com/hashicorp/golang-lru"
+ lru "github.com/hashicorp/golang-lru"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
@@ -142,7 +144,7 @@ func (self *LightChain) Odr() OdrBackend {
// loadLastState loads the last known chain state from the database. This method
// assumes that the chain manager mutex is held.
func (self *LightChain) loadLastState() error {
- if head := core.GetHeadHeaderHash(self.chainDb); head == (common.Hash{}) {
+ if head := rawdb.GetHeadHeaderHash(self.chainDb); head == (common.Hash{}) {
// Corrupt or empty database, init from scratch
self.Reset()
} else {
@@ -189,10 +191,10 @@ func (bc *LightChain) ResetWithGenesisBlock(genesis *types.Block) {
defer bc.mu.Unlock()
// Prepare the genesis block and reinitialise the chain
- if err := core.WriteTd(bc.chainDb, genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil {
+ if err := rawdb.WriteTd(bc.chainDb, genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil {
log.Crit("Failed to write genesis block TD", "err", err)
}
- if err := core.WriteBlock(bc.chainDb, genesis); err != nil {
+ if err := rawdb.WriteBlock(bc.chainDb, genesis); err != nil {
log.Crit("Failed to write genesis block", "err", err)
}
bc.genesisBlock = genesis
diff --git a/light/lightchain_test.go b/light/lightchain_test.go
index 21836cc88..073efecb0 100644
--- a/light/lightchain_test.go
+++ b/light/lightchain_test.go
@@ -18,13 +18,13 @@ package light
import (
"context"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"testing"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/consensus/ethash"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/params"
@@ -123,8 +123,8 @@ func testHeaderChainImport(chain []*types.Header, lightchain *LightChain) error
}
// Manually insert the header into the database, but don't reorganize (allows subsequent testing)
lightchain.mu.Lock()
- core.WriteTd(lightchain.chainDb, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, lightchain.GetTdByHash(header.ParentHash)))
- core.WriteHeader(lightchain.chainDb, header)
+ rawdb.WriteTd(lightchain.chainDb, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, lightchain.GetTdByHash(header.ParentHash)))
+ rawdb.WriteHeader(lightchain.chainDb, header)
lightchain.mu.Unlock()
}
return nil
diff --git a/light/odr.go b/light/odr.go
index b5591fdd9..9fe919cb3 100644
--- a/light/odr.go
+++ b/light/odr.go
@@ -24,6 +24,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
)
@@ -112,7 +113,7 @@ type BlockRequest struct {
// StoreResult stores the retrieved data in local database
func (req *BlockRequest) StoreResult(db ethdb.Database) {
- core.WriteBodyRLP(db, req.Hash, req.Number, req.Rlp)
+ rawdb.WriteBodyRLP(db, req.Hash, req.Number, req.Rlp)
}
// ReceiptsRequest is the ODR request type for retrieving block receipts
@@ -125,7 +126,7 @@ type ReceiptsRequest struct {
// StoreResult stores the retrieved data in local database
func (req *ReceiptsRequest) StoreResult(db ethdb.Database) {
- core.WriteBlockReceipts(db, req.Hash, req.Number, req.Receipts)
+ rawdb.WriteBlockReceipts(db, req.Hash, req.Number, req.Receipts)
}
// ChtRequest is the ODR request type for retrieving a header by number via the Canonical Hash Trie
@@ -141,10 +142,10 @@ type ChtRequest struct {
// StoreResult stores the retrieved data in local database
func (req *ChtRequest) StoreResult(db ethdb.Database) {
// if there is a canonical hash, there is a header too
- core.WriteHeader(db, req.Header)
+ rawdb.WriteHeader(db, req.Header)
hash, num := req.Header.Hash(), req.Header.Number.Uint64()
- core.WriteTd(db, hash, num, req.Td)
- core.WriteCanonicalHash(db, hash, num)
+ rawdb.WriteTd(db, hash, num, req.Td)
+ rawdb.WriteCanonicalHash(db, hash, num)
}
// BloomRequest is the ODR request type for retrieving bloom filters from a CHT structure
@@ -161,11 +162,11 @@ type BloomRequest struct {
// StoreResult stores the retrieved data in local database
func (req *BloomRequest) StoreResult(db ethdb.Database) {
for i, sectionIdx := range req.SectionIdxList {
- sectionHead := core.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1)
+ sectionHead := rawdb.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1)
// if we don't have the canonical hash stored for this section head number, we'll still store it under
// a key with a zero sectionHead. GetBloomBits will look there too if we still don't have the canonical
// hash. In the unlikely case we've retrieved the section head hash since then, we'll just retrieve the
// bit vector again from the network.
- core.WriteBloomBits(db, req.BitIdx, sectionIdx, sectionHead, req.BloomBits[i])
+ rawdb.WriteBloomBits(db, req.BitIdx, sectionIdx, sectionHead, req.BloomBits[i])
}
}
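The write ordering in ChtRequest.StoreResult matters: the header goes in before the total difficulty and the canonical-hash mapping, so any reader that finds a canonical hash can rely on the header being present. A hedged sketch of the same sequence using the rawdb helpers this diff migrates to; error handling is elided, as in StoreResult itself:

```go
package main

import (
	"math/big"

	"github.com/tomochain/tomochain/core/rawdb"
	"github.com/tomochain/tomochain/core/types"
	"github.com/tomochain/tomochain/ethdb"
)

// storeVerifiedHeader mirrors ChtRequest.StoreResult: header first,
// then total difficulty, then the canonical-hash mapping.
func storeVerifiedHeader(db ethdb.Database, header *types.Header, td *big.Int) {
	rawdb.WriteHeader(db, header)
	hash, num := header.Hash(), header.Number.Uint64()
	rawdb.WriteTd(db, hash, num, td)
	rawdb.WriteCanonicalHash(db, hash, num)
}

func main() {}
```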
diff --git a/light/odr_test.go b/light/odr_test.go
index 0c5fc7857..c4b9e1064 100644
--- a/light/odr_test.go
+++ b/light/odr_test.go
@@ -20,16 +20,16 @@ import (
"bytes"
"context"
"errors"
- "github.com/tomochain/tomochain/consensus"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"testing"
"time"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/math"
+ "github.com/tomochain/tomochain/consensus"
"github.com/tomochain/tomochain/consensus/ethash"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/core/vm"
@@ -43,7 +43,7 @@ import (
var (
testBankKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
testBankAddress = crypto.PubkeyToAddress(testBankKey.PublicKey)
- testBankFunds = big.NewInt(100000000)
+ testBankFunds = big.NewInt(1_000_000_000_000_000_000)
acc1Key, _ = crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a")
acc2Key, _ = crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee")
@@ -72,9 +72,12 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error {
}
switch req := req.(type) {
case *BlockRequest:
- req.Rlp = core.GetBodyRLP(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash))
+ req.Rlp = rawdb.GetBodyRLP(odr.sdb, req.Hash, rawdb.GetBlockNumber(odr.sdb, req.Hash))
case *ReceiptsRequest:
- req.Receipts = core.GetBlockReceipts(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash))
+ number := rawdb.GetBlockNumber(odr.sdb, req.Hash)
+ if number != rawdb.MissingNumber {
+ req.Receipts = rawdb.ReadRawReceipts(odr.sdb, req.Hash, number)
+ }
case *TrieRequest:
t, _ := trie.New(req.Id.Root, trie.NewDatabase(odr.sdb))
nodes := NewNodeSet()
@@ -110,9 +113,13 @@ func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, odrGetReceipts) }
func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) {
var receipts types.Receipts
if bc != nil {
- receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash))
+ if number := rawdb.GetBlockNumber(db, bhash); number != rawdb.MissingNumber {
+ if header := rawdb.GetHeader(db, bhash, number); header != nil {
+ receipts = rawdb.GetBlockReceipts(db, bhash, number, bc.Config())
+ }
+ }
} else {
- receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash))
+ receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, rawdb.GetBlockNumber(db, bhash), lc.Config())
}
if receipts == nil {
return nil, nil
@@ -133,7 +140,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc
st = NewState(ctx, header, lc.Odr())
} else {
header := bc.GetHeaderByHash(bhash)
- st, _ = state.New(header.Root, state.NewDatabase(db))
+ st, _ = state.New(header.Root, state.NewDatabase(db), nil)
}
var res []byte
@@ -148,7 +155,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc
func TestOdrContractCallLes1(t *testing.T) { testChainOdr(t, 1, odrContractCall) }
type callmsg struct {
- types.Message
+ core.Message
}
func (callmsg) CheckNonce() bool { return false }
@@ -173,7 +180,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain
} else {
chain = bc
header = bc.GetHeaderByHash(bhash)
- st, _ = state.New(header.Root, state.NewDatabase(db))
+ st, _ = state.New(header.Root, state.NewDatabase(db), nil)
}
// Perform read-only call.
@@ -183,13 +190,22 @@ func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain
if value, ok := feeCapacity[testContractAddr]; ok {
balanceTokenFee = value
}
- msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 1000000, new(big.Int), data, false, balanceTokenFee)}
+ msg := &core.Message{
+ From: testBankAddress,
+ To: &testContractAddr,
+ Value: new(big.Int),
+ GasLimit: 1000000,
+ GasPrice: new(big.Int),
+ Data: data,
+ SkipAccountChecks: true,
+ BalanceTokenFee: balanceTokenFee,
+ }
context := core.NewEVMContext(msg, header, chain, nil)
vmenv := vm.NewEVM(context, st, nil, config, vm.Config{})
gp := new(core.GasPool).AddGas(math.MaxUint64)
owner := common.Address{}
- ret, _, _, _ := core.ApplyMessage(vmenv, msg, gp, owner)
- res = append(res, ret...)
+ ret, _ := core.ApplyMessage(vmenv, msg, gp, owner)
+ res = append(res, ret.Return()...)
if st.Error() != nil {
return res, st.Error()
}
@@ -202,17 +218,17 @@ func testChainGen(i int, block *core.BlockGen) {
switch i {
case 0:
// In block 1, the test bank sends account #1 some ether.
- tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testBankKey)
+ tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(10_000_000_000_000_000), params.TxGas, nil, nil), signer, testBankKey)
block.AddTx(tx)
case 1:
// In block 2, the test bank sends some more ether to account #1.
// acc1Addr passes it on to account #2.
// acc1Addr creates a test contract.
- tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testBankKey)
+ tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(1_000_000_000_000_000), params.TxGas, nil, nil), signer, testBankKey)
nonce := block.TxNonce(acc1Addr)
- tx2, _ := types.SignTx(types.NewTransaction(nonce, acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key)
+ tx2, _ := types.SignTx(types.NewTransaction(nonce, acc2Addr, big.NewInt(1_000_000_000_000_000), params.TxGas, nil, nil), signer, acc1Key)
nonce++
- tx3, _ := types.SignTx(types.NewContractCreation(nonce, big.NewInt(0), 1000000, big.NewInt(0), testContractCode), signer, acc1Key)
+ tx3, _ := types.SignTx(types.NewContractCreation(nonce, big.NewInt(0), 1000000, nil, testContractCode), signer, acc1Key)
testContractAddr = crypto.CreateAddress(acc1Addr, nonce)
block.AddTx(tx1)
block.AddTx(tx2)
@@ -240,9 +256,12 @@ func testChainGen(i int, block *core.BlockGen) {
func testChainOdr(t *testing.T, protocol int, fn odrTestFn) {
var (
- sdb = rawdb.NewMemoryDatabase()
- ldb = rawdb.NewMemoryDatabase()
- gspec = core.Genesis{Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}}}
+ sdb = rawdb.NewMemoryDatabase()
+ ldb = rawdb.NewMemoryDatabase()
+ gspec = core.Genesis{
+ Config: params.TestChainConfig,
+ Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}},
+ }
genesis = gspec.MustCommit(sdb)
)
gspec.MustCommit(ldb)
@@ -268,7 +287,7 @@ func testChainOdr(t *testing.T, protocol int, fn odrTestFn) {
test := func(expFail int) {
for i := uint64(0); i <= blockchain.CurrentHeader().Number.Uint64(); i++ {
- bhash := core.GetCanonicalHash(sdb, i)
+ bhash := rawdb.GetCanonicalHash(sdb, i)
b1, err := fn(NoOdr, sdb, blockchain, nil, bhash)
if err != nil {
t.Fatalf("error in full-node test for block %d: %v", i, err)
diff --git a/light/odr_util.go b/light/odr_util.go
index 89a63eb2b..9d38c45b2 100644
--- a/light/odr_util.go
+++ b/light/odr_util.go
@@ -22,8 +22,10 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/rlp"
)
@@ -31,10 +33,10 @@ var sha3_nil = crypto.Keccak256Hash(nil)
func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*types.Header, error) {
db := odr.Database()
- hash := core.GetCanonicalHash(db, number)
+ hash := rawdb.GetCanonicalHash(db, number)
if (hash != common.Hash{}) {
// if there is a canonical hash, there is a header too
- header := core.GetHeader(db, hash, number)
+ header := rawdb.GetHeader(db, hash, number)
if header == nil {
panic("Canonical hash present but header not found")
}
@@ -47,14 +49,14 @@ func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*typ
)
if odr.ChtIndexer() != nil {
chtCount, sectionHeadNum, sectionHead = odr.ChtIndexer().Sections()
- canonicalHash := core.GetCanonicalHash(db, sectionHeadNum)
+ canonicalHash := rawdb.GetCanonicalHash(db, sectionHeadNum)
// if the CHT was injected as a trusted checkpoint, we have no canonical hash yet so we accept zero hash too
for chtCount > 0 && canonicalHash != sectionHead && canonicalHash != (common.Hash{}) {
chtCount--
if chtCount > 0 {
sectionHeadNum = chtCount*CHTFrequencyClient - 1
sectionHead = odr.ChtIndexer().SectionHead(chtCount - 1)
- canonicalHash = core.GetCanonicalHash(db, sectionHeadNum)
+ canonicalHash = rawdb.GetCanonicalHash(db, sectionHeadNum)
}
}
}
@@ -69,7 +71,7 @@ func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*typ
}
func GetCanonicalHash(ctx context.Context, odr OdrBackend, number uint64) (common.Hash, error) {
- hash := core.GetCanonicalHash(odr.Database(), number)
+ hash := rawdb.GetCanonicalHash(odr.Database(), number)
if (hash != common.Hash{}) {
return hash, nil
}
@@ -82,7 +84,7 @@ func GetCanonicalHash(ctx context.Context, odr OdrBackend, number uint64) (commo
// GetBodyRLP retrieves the block body (transactions and uncles) in RLP encoding.
func GetBodyRLP(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (rlp.RawValue, error) {
- if data := core.GetBodyRLP(odr.Database(), hash, number); data != nil {
+ if data := rawdb.GetBodyRLP(odr.Database(), hash, number); data != nil {
return data, nil
}
r := &BlockRequest{Hash: hash, Number: number}
@@ -111,7 +113,7 @@ func GetBody(ctx context.Context, odr OdrBackend, hash common.Hash, number uint6
// back from the stored header and body.
func GetBlock(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (*types.Block, error) {
// Retrieve the block header and body contents
- header := core.GetHeader(odr.Database(), hash, number)
+ header := rawdb.GetHeader(odr.Database(), hash, number)
if header == nil {
return nil, ErrNoHeader
}
@@ -125,9 +127,9 @@ func GetBlock(ctx context.Context, odr OdrBackend, hash common.Hash, number uint
// GetBlockReceipts retrieves the receipts generated by the transactions included
// in a block given by its hash.
-func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (types.Receipts, error) {
+func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64, config *params.ChainConfig) (types.Receipts, error) {
// Retrieve the potentially incomplete receipts from disk or network
- receipts := core.GetBlockReceipts(odr.Database(), hash, number)
+ receipts := rawdb.GetBlockReceipts(odr.Database(), hash, number, config)
if receipts == nil {
r := &ReceiptsRequest{Hash: hash, Number: number}
if err := odr.Retrieve(ctx, r); err != nil {
@@ -141,22 +143,22 @@ func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, num
if err != nil {
return nil, err
}
- genesis := core.GetCanonicalHash(odr.Database(), 0)
- config, _ := core.GetChainConfig(odr.Database(), genesis)
+ genesis := rawdb.GetCanonicalHash(odr.Database(), 0)
+ config, _ := rawdb.GetChainConfig(odr.Database(), genesis)
if err := core.SetReceiptsData(config, block, receipts); err != nil {
return nil, err
}
- core.WriteBlockReceipts(odr.Database(), hash, number, receipts)
+ rawdb.WriteBlockReceipts(odr.Database(), hash, number, receipts)
}
return receipts, nil
}
// GetBlockLogs retrieves the logs generated by the transactions included in a
// block given by its hash.
-func GetBlockLogs(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) ([][]*types.Log, error) {
+func GetBlockLogs(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64, config *params.ChainConfig) ([][]*types.Log, error) {
// Retrieve the potentially incomplete receipts from disk or network
- receipts := core.GetBlockReceipts(odr.Database(), hash, number)
+ receipts := rawdb.GetBlockReceipts(odr.Database(), hash, number, config)
if receipts == nil {
r := &ReceiptsRequest{Hash: hash, Number: number}
if err := odr.Retrieve(ctx, r); err != nil {
@@ -187,24 +189,24 @@ func GetBloomBits(ctx context.Context, odr OdrBackend, bitIdx uint, sectionIdxLi
)
if odr.BloomTrieIndexer() != nil {
bloomTrieCount, sectionHeadNum, sectionHead = odr.BloomTrieIndexer().Sections()
- canonicalHash := core.GetCanonicalHash(db, sectionHeadNum)
+ canonicalHash := rawdb.GetCanonicalHash(db, sectionHeadNum)
// if the BloomTrie was injected as a trusted checkpoint, we have no canonical hash yet so we accept zero hash too
for bloomTrieCount > 0 && canonicalHash != sectionHead && canonicalHash != (common.Hash{}) {
bloomTrieCount--
if bloomTrieCount > 0 {
sectionHeadNum = bloomTrieCount*BloomTrieFrequency - 1
sectionHead = odr.BloomTrieIndexer().SectionHead(bloomTrieCount - 1)
- canonicalHash = core.GetCanonicalHash(db, sectionHeadNum)
+ canonicalHash = rawdb.GetCanonicalHash(db, sectionHeadNum)
}
}
}
for i, sectionIdx := range sectionIdxList {
- sectionHead := core.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1)
+ sectionHead := rawdb.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1)
// if we don't have the canonical hash stored for this section head number, we'll still look for
// an entry with a zero sectionHead (we store it with zero section head too if we don't know it
// at the time of the retrieval)
- bloomBits, err := core.GetBloomBits(db, bitIdx, sectionIdx, sectionHead)
+ bloomBits, err := rawdb.GetBloomBits(db, bitIdx, sectionIdx, sectionHead)
if err == nil {
result[i] = bloomBits
} else {
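GetBloomBits above relies on a two-key convention: bit vectors are stored under the section head hash when it is known, and under the zero hash otherwise. A hedged sketch of the read side, with the section frequency as a placeholder constant:

```go
package main

import (
	"fmt"

	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/rawdb"
	"github.com/tomochain/tomochain/ethdb"
)

const bloomTrieFrequency = 32768 // placeholder for the real BloomTrieFrequency constant

// readBloomBits probes the canonical section head first, then falls back
// to the zero-hash key used when the head was unknown at store time.
func readBloomBits(db ethdb.Database, bit uint, section uint64) ([]byte, error) {
	head := rawdb.GetCanonicalHash(db, (section+1)*bloomTrieFrequency-1)
	if bits, err := rawdb.GetBloomBits(db, bit, section, head); err == nil {
		return bits, nil
	}
	return rawdb.GetBloomBits(db, bit, section, common.Hash{})
}

func main() { fmt.Println("sketch only; see readBloomBits") }
```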
diff --git a/light/postprocess.go b/light/postprocess.go
index 1e83a3cd7..22526d943 100644
--- a/light/postprocess.go
+++ b/light/postprocess.go
@@ -19,13 +19,13 @@ package light
import (
"encoding/binary"
"errors"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"time"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/bitutil"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/log"
@@ -162,7 +162,7 @@ func (c *ChtIndexerBackend) Process(header *types.Header) {
hash, num := header.Hash(), header.Number.Uint64()
c.lastHash = hash
- td := core.GetTd(c.diskdb, hash, num)
+ td := rawdb.GetTd(c.diskdb, hash, num)
if td == nil {
panic(nil)
}
@@ -273,7 +273,7 @@ func (b *BloomTrieIndexerBackend) Commit() error {
binary.BigEndian.PutUint64(encKey[2:10], b.section)
var decomp []byte
for j := uint64(0); j < b.bloomTrieRatio; j++ {
- data, err := core.GetBloomBits(b.diskdb, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j])
+ data, err := rawdb.GetBloomBits(b.diskdb, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j])
if err != nil {
return err
}
diff --git a/light/trie.go b/light/trie.go
index d247f145e..8d32392f4 100644
--- a/light/trie.go
+++ b/light/trie.go
@@ -26,11 +26,12 @@ import (
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/ethdb"
+ "github.com/tomochain/tomochain/rlp"
"github.com/tomochain/tomochain/trie"
)
func NewState(ctx context.Context, head *types.Header, odr OdrBackend) *state.StateDB {
- state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr))
+ state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr), nil)
return state
}
@@ -95,27 +96,74 @@ type odrTrie struct {
trie *trie.Trie
}
-func (t *odrTrie) TryGet(key []byte) ([]byte, error) {
+func (t *odrTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) {
key = crypto.Keccak256(key)
- var res []byte
+ var enc []byte
err := t.do(key, func() (err error) {
- res, err = t.trie.TryGet(key)
+ enc, err = t.trie.Get(key)
return err
})
- return res, err
+ if err != nil || len(enc) == 0 {
+ return nil, err
+ }
+ _, content, _, err := rlp.Split(enc)
+ return content, err
+}
+
+func (t *odrTrie) GetAccount(address common.Address) (*types.StateAccount, error) {
+ var (
+ enc []byte
+ key = crypto.Keccak256(address.Bytes())
+ )
+ err := t.do(key, func() (err error) {
+ enc, err = t.trie.Get(key)
+ return err
+ })
+ if err != nil || len(enc) == 0 {
+ return nil, err
+ }
+ acct := new(types.StateAccount)
+ if err := rlp.DecodeBytes(enc, acct); err != nil {
+ return nil, err
+ }
+ return acct, nil
+}
+
+func (t *odrTrie) UpdateAccount(address common.Address, acc *types.StateAccount) error {
+ key := crypto.Keccak256(address.Bytes())
+ value, err := rlp.EncodeToBytes(acc)
+ if err != nil {
+ return fmt.Errorf("decoding error in account update: %w", err)
+ }
+ return t.do(key, func() error {
+ return t.trie.Update(key, value)
+ })
+}
+
+func (t *odrTrie) UpdateContractCode(_ common.Address, _ common.Hash, _ []byte) error {
+ return nil
}
-func (t *odrTrie) TryUpdate(key, value []byte) error {
+func (t *odrTrie) UpdateStorage(_ common.Address, key, value []byte) error {
key = crypto.Keccak256(key)
+ v, _ := rlp.EncodeToBytes(value)
return t.do(key, func() error {
- return t.trie.TryDelete(key)
+ return t.trie.Update(key, v)
})
}
-func (t *odrTrie) TryDelete(key []byte) error {
+func (t *odrTrie) DeleteStorage(_ common.Address, key []byte) error {
key = crypto.Keccak256(key)
return t.do(key, func() error {
- return t.trie.TryDelete(key)
+ return t.trie.Delete(key)
+ })
+}
+
+// DeleteAccount abstracts an account deletion from the trie.
+func (t *odrTrie) DeleteAccount(address common.Address) error {
+ key := crypto.Keccak256(address.Bytes())
+ return t.do(key, func() error {
+ return t.trie.Delete(key)
})
}
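GetStorage and UpdateStorage above are asymmetric on purpose: slot values go into the trie RLP-encoded, so reads must strip the RLP string header with rlp.Split before handing the raw bytes back to the state layer. A small round-trip sketch, assuming this tree's rlp package:

```go
package main

import (
	"fmt"

	"github.com/tomochain/tomochain/rlp"
)

func main() {
	raw := []byte{0x01, 0x02, 0x03}
	enc, err := rlp.EncodeToBytes(raw) // what UpdateStorage writes
	if err != nil {
		panic(err)
	}
	_, content, _, err := rlp.Split(enc) // what GetStorage reads back
	if err != nil {
		panic(err)
	}
	fmt.Printf("%x -> %x -> %x\n", raw, enc, content) // 010203 -> 83010203 -> 010203
}
```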
diff --git a/light/txpool.go b/light/txpool.go
index 7af86dbd6..d5bee2a3b 100644
--- a/light/txpool.go
+++ b/light/txpool.go
@@ -24,6 +24,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/core/types"
"github.com/tomochain/tomochain/ethdb"
@@ -74,10 +75,11 @@ type TxPool struct {
//
// Send instructs backend to forward new transactions
// NewHead notifies backend about a new head after processed by the tx pool,
-// including mined and rolled back transactions since the last event
+// including mined and rolled back transactions since the last event.
+//
// Discard notifies backend about transactions that should be discarded either
-// because they have been replaced by a re-send or because they have been mined
-// long ago and no rollback is expected
+// because they have been replaced by a re-send or because they have been mined
+// long ago and no rollback is expected.
type TxRelayBackend interface {
Send(txs types.Transactions)
NewHead(head common.Hash, mined []common.Hash, rollback []common.Hash)
@@ -180,10 +184,10 @@ func (pool *TxPool) checkMinedTxs(ctx context.Context, hash common.Hash, number
// If some transactions have been mined, write the needed data to disk and update
if list != nil {
// Retrieve all the receipts belonging to this block and write the lookup table
- if _, err := GetBlockReceipts(ctx, pool.odr, hash, number); err != nil { // ODR caches, ignore results
+ if _, err := GetBlockReceipts(ctx, pool.odr, hash, number, pool.config); err != nil { // ODR caches, ignore results
return err
}
- if err := core.WriteTxLookupEntries(pool.chainDb, block); err != nil {
+ if err := rawdb.WriteTxLookupEntries(pool.chainDb, block); err != nil {
return err
}
// Update the transaction pool's state
@@ -202,7 +206,7 @@ func (pool *TxPool) rollbackTxs(hash common.Hash, txc txStateChanges) {
if list, ok := pool.mined[hash]; ok {
for _, tx := range list {
txHash := tx.Hash()
- core.DeleteTxLookupEntry(pool.chainDb, txHash)
+ rawdb.DeleteTxLookupEntry(pool.chainDb, txHash)
pool.pending[txHash] = tx
txc.setState(txHash, false)
}
@@ -258,7 +262,7 @@ func (pool *TxPool) reorgOnNewHead(ctx context.Context, newHeader *types.Header)
idx2 := idx - txPermanent
if len(pool.mined) > 0 {
for i := pool.clearIdx; i < idx2; i++ {
- hash := core.GetCanonicalHash(pool.chainDb, i)
+ hash := rawdb.GetCanonicalHash(pool.chainDb, i)
if list, ok := pool.mined[hash]; ok {
hashes := make([]common.Hash, len(list))
for i, tx := range list {
diff --git a/metrics/metrics.go b/metrics/metrics.go
index dbb2727ec..3e315b19e 100644
--- a/metrics/metrics.go
+++ b/metrics/metrics.go
@@ -19,7 +19,12 @@ import (
//
// This global kill-switch helps quantify the observer effect and makes
// for less cluttered pprof profiles.
-var Enabled bool = false
+var Enabled = false
+
+// EnabledExpensive is a soft-flag meant for external packages to check if costly
+// metrics gathering is allowed or not. The goal is to separate standard metrics
+// for health monitoring and debug metrics that might impact runtime performance.
+var EnabledExpensive = false
// MetricsEnabledFlag is the CLI flag name to use to enable metrics collections.
const MetricsEnabledFlag = "metrics"
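EnabledExpensive is meant to be checked at call sites that would otherwise pay collection costs unconditionally. A minimal usage sketch; the timing body is a stand-in, not an API introduced by this change:

```go
package main

import (
	"fmt"
	"time"

	"github.com/tomochain/tomochain/metrics"
)

func expensiveWork() {
	if metrics.EnabledExpensive {
		defer func(start time.Time) {
			// a real call site would feed time.Since(start) into a timer or histogram
			fmt.Println("took", time.Since(start))
		}(time.Now())
	}
	time.Sleep(10 * time.Millisecond) // the actual work
}

func main() {
	metrics.EnabledExpensive = true
	expensiveWork()
}
```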
diff --git a/miner/worker.go b/miner/worker.go
index 995c40169..a8985a2a8 100644
--- a/miner/worker.go
+++ b/miner/worker.go
@@ -23,6 +23,7 @@ import (
"github.com/tomochain/tomochain/accounts"
"github.com/tomochain/tomochain/tomoxlending/lendingstate"
+ "github.com/tomochain/tomochain/trie"
"math/big"
"os"
@@ -204,6 +205,7 @@ func (self *worker) pending() (*types.Block, *state.StateDB) {
self.current.txs,
nil,
self.current.receipts,
+ new(trie.Trie),
), self.current.state.Copy()
}
return self.current.Block, self.current.state.Copy()
@@ -219,6 +221,7 @@ func (self *worker) pendingBlock() *types.Block {
self.current.txs,
nil,
self.current.receipts,
+ new(trie.Trie),
)
}
return self.current.Block
diff --git a/node/api.go b/node/api.go
index 23edbe2b3..25b67ac1b 100644
--- a/node/api.go
+++ b/node/api.go
@@ -26,7 +26,7 @@ import (
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/metrics"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/rpc"
)
@@ -51,9 +51,9 @@ func (api *PrivateAdminAPI) AddPeer(url string) (bool, error) {
return false, ErrNodeStopped
}
// Try to add the url as a static peer and return
- node, err := discover.ParseNode(url)
+ node, err := enode.Parse(enode.ValidSchemes, url)
if err != nil {
- return false, fmt.Errorf("invalid enode: %v", err)
+ return false, fmt.Errorf("invalid enode url: %v, err %v", url, err)
}
server.AddPeer(node)
return true, nil
@@ -67,7 +67,7 @@ func (api *PrivateAdminAPI) RemovePeer(url string) (bool, error) {
return false, ErrNodeStopped
}
// Try to remove the url as a static peer and return
- node, err := discover.ParseNode(url)
+ node, err := enode.ParseV4(url)
if err != nil {
return false, fmt.Errorf("invalid enode: %v", err)
}
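AddPeer now goes through enode.Parse with ValidSchemes, which accepts signed node records as well as classic v4 URLs, while RemovePeer keeps the stricter enode.ParseV4. A hedged sketch of the two parsers; the placeholder URL deliberately fails to parse, exercising the error path the API surfaces:

```go
package main

import (
	"fmt"

	"github.com/tomochain/tomochain/p2p/enode"
)

func main() {
	// Placeholder URL: a real one carries a 128-hex-char secp256k1 public key.
	url := "enode://nodekey@127.0.0.1:30303"
	if _, err := enode.ParseV4(url); err != nil {
		// This is the error path AddPeer/RemovePeer report back to callers.
		fmt.Printf("invalid enode URL %q: %v\n", url, err)
	}
	// Parse with ValidSchemes additionally accepts signed node records,
	// which is why AddPeer uses it instead of ParseV4.
	if _, err := enode.Parse(enode.ValidSchemes, url); err != nil {
		fmt.Println("also rejected as a signed record:", err)
	}
}
```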
diff --git a/node/config.go b/node/config.go
index b8ad712fc..1eb4e528d 100644
--- a/node/config.go
+++ b/node/config.go
@@ -32,7 +32,7 @@ import (
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
)
const (
@@ -333,18 +333,18 @@ func (c *Config) NodeKey() *ecdsa.PrivateKey {
}
// StaticNodes returns a list of node enode URLs configured as static nodes.
-func (c *Config) StaticNodes() []*discover.Node {
+func (c *Config) StaticNodes() []*enode.Node {
return c.parsePersistentNodes(c.resolvePath(datadirStaticNodes))
}
// TrustedNodes returns a list of node enode URLs configured as trusted nodes.
-func (c *Config) TrustedNodes() []*discover.Node {
+func (c *Config) TrustedNodes() []*enode.Node {
return c.parsePersistentNodes(c.resolvePath(datadirTrustedNodes))
}
// parsePersistentNodes parses a list of discovery node URLs loaded from a .json
// file from within the data directory.
-func (c *Config) parsePersistentNodes(path string) []*discover.Node {
+func (c *Config) parsePersistentNodes(path string) []*enode.Node {
// Short circuit if no node config is present
if c.DataDir == "" {
return nil
@@ -359,12 +359,12 @@ func (c *Config) parsePersistentNodes(path string) []*discover.Node {
return nil
}
// Interpret the list as a discovery node array
- var nodes []*discover.Node
+ var nodes []*enode.Node
for _, url := range nodelist {
if url == "" {
continue
}
- node, err := discover.ParseNode(url)
+ node, err := enode.ParseV4(url)
if err != nil {
log.Error(fmt.Sprintf("Node URL %s: %v\n", url, err))
continue
diff --git a/p2p/dial.go b/p2p/dial.go
index 454d2198c..7f93a12be 100644
--- a/p2p/dial.go
+++ b/p2p/dial.go
@@ -18,14 +18,13 @@ package p2p
import (
"container/heap"
- "crypto/rand"
"errors"
"fmt"
"net"
"time"
"github.com/tomochain/tomochain/log"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/netutil"
)
@@ -50,7 +49,7 @@ const (
// NodeDialer is used to connect to nodes in the network, typically by using
// an underlying net.Dialer but also using net.Pipe in tests
type NodeDialer interface {
- Dial(*discover.Node) (net.Conn, error)
+ Dial(*enode.Node) (net.Conn, error)
}
// TCPDialer implements the NodeDialer interface by using a net.Dialer to
@@ -60,8 +59,8 @@ type TCPDialer struct {
}
// Dial creates a TCP connection to the node
-func (t TCPDialer) Dial(dest *discover.Node) (net.Conn, error) {
- addr := &net.TCPAddr{IP: dest.IP, Port: int(dest.TCP)}
+func (t TCPDialer) Dial(dest *enode.Node) (net.Conn, error) {
+ addr := &net.TCPAddr{IP: dest.IP(), Port: dest.TCP()}
return t.Dialer.Dial("tcp", addr.String())
}
@@ -74,22 +73,22 @@ type dialstate struct {
netrestrict *netutil.Netlist
lookupRunning bool
- dialing map[discover.NodeID]connFlag
- lookupBuf []*discover.Node // current discovery lookup results
- randomNodes []*discover.Node // filled from Table
- static map[discover.NodeID]*dialTask
+ dialing map[enode.ID]connFlag
+ lookupBuf []*enode.Node // current discovery lookup results
+ randomNodes []*enode.Node // filled from Table
+ static map[enode.ID]*dialTask
hist *dialHistory
- start time.Time // time when the dialer was first used
- bootnodes []*discover.Node // default dials when there are no peers
+ start time.Time // time when the dialer was first used
+ bootnodes []*enode.Node // default dials when there are no peers
}
type discoverTable interface {
- Self() *discover.Node
+ Self() *enode.Node
Close()
- Resolve(target discover.NodeID) *discover.Node
- Lookup(target discover.NodeID) []*discover.Node
- ReadRandomNodes([]*discover.Node) int
+ Resolve(*enode.Node) *enode.Node
+ LookupRandom() []*enode.Node
+ ReadRandomNodes([]*enode.Node) int
}
// the dial history remembers recent dials.
@@ -97,7 +96,7 @@ type dialHistory []pastDial
// pastDial is an entry in the dial history.
type pastDial struct {
- id discover.NodeID
+ id enode.ID
exp time.Time
}
@@ -109,7 +108,7 @@ type task interface {
// fields cannot be accessed while the task is running.
type dialTask struct {
flags connFlag
- dest *discover.Node
+ dest *enode.Node
lastResolved time.Time
resolveDelay time.Duration
}
@@ -118,7 +117,7 @@ type dialTask struct {
// Only one discoverTask is active at any time.
// discoverTask.Do performs a random lookup.
type discoverTask struct {
- results []*discover.Node
+ results []*enode.Node
}
// A waitExpireTask is generated if there are no other tasks
@@ -127,15 +126,15 @@ type waitExpireTask struct {
time.Duration
}
-func newDialState(static []*discover.Node, bootnodes []*discover.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
+func newDialState(static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
s := &dialstate{
maxDynDials: maxdyn,
ntab: ntab,
netrestrict: netrestrict,
- static: make(map[discover.NodeID]*dialTask),
- dialing: make(map[discover.NodeID]connFlag),
- bootnodes: make([]*discover.Node, len(bootnodes)),
- randomNodes: make([]*discover.Node, maxdyn/2),
+ static: make(map[enode.ID]*dialTask),
+ dialing: make(map[enode.ID]connFlag),
+ bootnodes: make([]*enode.Node, len(bootnodes)),
+ randomNodes: make([]*enode.Node, maxdyn/2),
hist: new(dialHistory),
}
copy(s.bootnodes, bootnodes)
@@ -145,32 +144,32 @@ func newDialState(static []*discover.Node, bootnodes []*discover.Node, ntab disc
return s
}
-func (s *dialstate) addStatic(n *discover.Node) {
- // This overwites the task instead of updating an existing
+func (s *dialstate) addStatic(n *enode.Node) {
+ // This overwrites the task instead of updating an existing
// entry, giving users the opportunity to force a resolve operation.
- s.static[n.ID] = &dialTask{flags: staticDialedConn, dest: n}
+ s.static[n.ID()] = &dialTask{flags: staticDialedConn, dest: n}
}
-func (s *dialstate) removeStatic(n *discover.Node) {
+func (s *dialstate) removeStatic(n *enode.Node) {
// This removes a task so future attempts to connect will not be made.
- delete(s.static, n.ID)
+ delete(s.static, n.ID())
// This removes a previous dial timestamp so that application
// can force a server to reconnect with chosen peer immediately.
- s.hist.remove(n.ID)
+ s.hist.remove(n.ID())
}
-func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task {
+func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task {
if s.start.IsZero() {
s.start = now
}
var newtasks []task
- addDial := func(flag connFlag, n *discover.Node) bool {
+ addDial := func(flag connFlag, n *enode.Node) bool {
if err := s.checkDial(n, peers); err != nil {
- log.Trace("Skipping dial candidate", "id", n.ID, "addr", &net.TCPAddr{IP: n.IP, Port: int(n.TCP)}, "err", err)
+ log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err)
return false
}
- s.dialing[n.ID] = flag
+ s.dialing[n.ID()] = flag
newtasks = append(newtasks, &dialTask{flags: flag, dest: n})
return true
}
@@ -196,8 +195,8 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now
err := s.checkDial(t.dest, peers)
switch err {
case errNotWhitelisted, errSelf:
- log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)}, "err", err)
- delete(s.static, t.dest.ID)
+ log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err)
+ delete(s.static, t.dest.ID())
case nil:
s.dialing[id] = t.flags
newtasks = append(newtasks, t)
@@ -260,21 +259,18 @@ var (
errNotWhitelisted = errors.New("not contained in netrestrict whitelist")
)
-func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer) error {
- _, dialing := s.dialing[n.ID]
+func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
+ _, dialing := s.dialing[n.ID()]
switch {
case dialing:
return errAlreadyDialing
- case peers[n.ID] != nil:
- exitsPeer := peers[n.ID]
- if exitsPeer.PairPeer != nil {
- return errAlreadyConnected
- }
- case s.ntab != nil && n.ID == s.ntab.Self().ID:
+ case peers[n.ID()] != nil:
+ return errAlreadyConnected
+ case s.ntab != nil && n.ID() == s.ntab.Self().ID():
return errSelf
- case s.netrestrict != nil && !s.netrestrict.Contains(n.IP):
+ case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()):
return errNotWhitelisted
- case s.hist.contains(n.ID):
+ case s.hist.contains(n.ID()):
return errRecentlyDialed
}
return nil
@@ -283,8 +279,8 @@ func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer)
func (s *dialstate) taskDone(t task, now time.Time) {
switch t := t.(type) {
case *dialTask:
- s.hist.add(t.dest.ID, now.Add(dialHistoryExpiration))
- delete(s.dialing, t.dest.ID)
+ s.hist.add(t.dest.ID(), now.Add(dialHistoryExpiration))
+ delete(s.dialing, t.dest.ID())
case *discoverTask:
s.lookupRunning = false
s.lookupBuf = append(s.lookupBuf, t.results...)
@@ -303,26 +299,10 @@ func (t *dialTask) Do(srv *Server) {
// Try resolving the ID of static nodes if dialing failed.
if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
if t.resolve(srv) {
- err = t.dial(srv, t.dest)
+ t.dial(srv, t.dest)
}
}
}
- if err == nil {
- err = t.dial(srv, t.dest)
- if err != nil {
- // Try resolving the ID of static nodes if dialing failed.
- if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
- if t.resolve(srv) {
- err = t.dial(srv, t.dest)
- }
- }
- }
- if err == nil {
- log.Trace("Dial pair connection success", "task", t.dest)
- } else {
- log.Trace("Dial pair connection error", "task", t.dest, "err", err)
- }
- }
}
// resolve attempts to find the current endpoint for the destination
@@ -342,7 +322,7 @@ func (t *dialTask) resolve(srv *Server) bool {
if time.Since(t.lastResolved) < t.resolveDelay {
return false
}
- resolved := srv.ntab.Resolve(t.dest.ID)
+ resolved := srv.ntab.Resolve(t.dest)
t.lastResolved = time.Now()
if resolved == nil {
t.resolveDelay *= 2
@@ -355,7 +335,7 @@ func (t *dialTask) resolve(srv *Server) bool {
// The node was found.
t.resolveDelay = initialResolveDelay
t.dest = resolved
- log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)})
+ log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
return true
}
@@ -364,7 +344,7 @@ type dialError struct {
}
// dial performs the actual connection attempt.
-func (t *dialTask) dial(srv *Server, dest *discover.Node) error {
+func (t *dialTask) dial(srv *Server, dest *enode.Node) error {
fd, err := srv.Dialer.Dial(dest)
if err != nil {
return &dialError{err}
@@ -374,7 +354,8 @@ func (t *dialTask) dial(srv *Server, dest *discover.Node) error {
}
func (t *dialTask) String() string {
- return fmt.Sprintf("%v %x %v:%d", t.flags, t.dest.ID[:8], t.dest.IP, t.dest.TCP)
+ id := t.dest.ID()
+ return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP())
}
func (t *discoverTask) Do(srv *Server) {
@@ -386,9 +367,7 @@ func (t *discoverTask) Do(srv *Server) {
time.Sleep(next.Sub(now))
}
srv.lastLookup = time.Now()
- var target discover.NodeID
- rand.Read(target[:])
- t.results = srv.ntab.Lookup(target)
+ t.results = srv.ntab.LookupRandom()
}
func (t *discoverTask) String() string {
@@ -410,11 +389,11 @@ func (t waitExpireTask) String() string {
func (h dialHistory) min() pastDial {
return h[0]
}
-func (h *dialHistory) add(id discover.NodeID, exp time.Time) {
+func (h *dialHistory) add(id enode.ID, exp time.Time) {
heap.Push(h, pastDial{id, exp})
}
-func (h *dialHistory) remove(id discover.NodeID) bool {
+func (h *dialHistory) remove(id enode.ID) bool {
for i, v := range *h {
if v.id == id {
heap.Remove(h, i)
@@ -423,7 +402,7 @@ func (h *dialHistory) remove(id discover.NodeID) bool {
}
return false
}
-func (h dialHistory) contains(id discover.NodeID) bool {
+func (h dialHistory) contains(id enode.ID) bool {
for _, v := range h {
if v.id == id {
return true
diff --git a/p2p/dial_test.go b/p2p/dial_test.go
index 362f22c13..411f49a7c 100644
--- a/p2p/dial_test.go
+++ b/p2p/dial_test.go
@@ -24,7 +24,9 @@ import (
"time"
"github.com/davecgh/go-spew/spew"
- "github.com/tomochain/tomochain/p2p/discover"
+
+ "github.com/tomochain/tomochain/p2p/enode"
+ "github.com/tomochain/tomochain/p2p/enr"
"github.com/tomochain/tomochain/p2p/netutil"
)
@@ -48,10 +50,10 @@ func runDialTest(t *testing.T, test dialtest) {
vtime time.Time
running int
)
- pm := func(ps []*Peer) map[discover.NodeID]*Peer {
- m := make(map[discover.NodeID]*Peer)
+ pm := func(ps []*Peer) map[enode.ID]*Peer {
+ m := make(map[enode.ID]*Peer)
for _, p := range ps {
- m[p.rw.id] = p
+ m[p.ID()] = p
}
return m
}
@@ -69,6 +71,7 @@ func runDialTest(t *testing.T, test dialtest) {
t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n",
i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
}
+ t.Log("tasks:", spew.Sdump(new))
// Time advances by 16 seconds on every round.
vtime = vtime.Add(16 * time.Second)
@@ -76,13 +79,13 @@ func runDialTest(t *testing.T, test dialtest) {
}
}
-type fakeTable []*discover.Node
+type fakeTable []*enode.Node
-func (t fakeTable) Self() *discover.Node { return new(discover.Node) }
-func (t fakeTable) Close() {}
-func (t fakeTable) Lookup(discover.NodeID) []*discover.Node { return nil }
-func (t fakeTable) Resolve(discover.NodeID) *discover.Node { return nil }
-func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { return copy(buf, t) }
+func (t fakeTable) Self() *enode.Node { return new(enode.Node) }
+func (t fakeTable) Close() {}
+func (t fakeTable) LookupRandom() []*enode.Node { return nil }
+func (t fakeTable) Resolve(*enode.Node) *enode.Node { return nil }
+func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t) }
// This test checks that dynamic dials are launched from discovery results.
func TestDialStateDynDial(t *testing.T) {
@@ -92,63 +95,63 @@ func TestDialStateDynDial(t *testing.T) {
// A discovery query is launched.
{
peers: []*Peer{
- {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
},
new: []task{&discoverTask{}},
},
// Dynamic dials are launched when it completes.
{
peers: []*Peer{
- {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
},
done: []task{
- &discoverTask{results: []*discover.Node{
- {ID: uintID(2)}, // this one is already connected and not dialed.
- {ID: uintID(3)},
- {ID: uintID(4)},
- {ID: uintID(5)},
- {ID: uintID(6)}, // these are not tried because max dyn dials is 5
- {ID: uintID(7)}, // ...
+ &discoverTask{results: []*enode.Node{
+ newNode(uintID(2), nil), // this one is already connected and not dialed.
+ newNode(uintID(3), nil),
+ newNode(uintID(4), nil),
+ newNode(uintID(5), nil),
+ newNode(uintID(6), nil), // these are not tried because max dyn dials is 5
+ newNode(uintID(7), nil), // ...
}},
},
new: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(2)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(3)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
},
// Some of the dials complete but no new ones are launched yet because
// the sum of active dial count and dynamic peer count is == maxDynDials.
{
peers: []*Peer{
- {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(3)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(4)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
},
done: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(3)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
},
},
// No new dial tasks are launched in this round because
// maxDynDials has been reached.
{
peers: []*Peer{
- {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(3)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(4)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(5)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}},
},
done: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
new: []task{
&waitExpireTask{Duration: 14 * time.Second},
@@ -158,29 +161,31 @@ func TestDialStateDynDial(t *testing.T) {
// results from last discovery lookup are reused.
{
peers: []*Peer{
- {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(3)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(4)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(5)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}},
+ },
+ new: []task{
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)},
},
- new: []task{},
},
// More peers (3,4) drop off and dial for ID 6 completes.
// The last query result from the discovery lookup is reused
// and a new one is spawned because more candidates are needed.
{
peers: []*Peer{
- {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(5)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}},
},
done: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(6)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)},
},
new: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(7)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)},
+ &discoverTask{},
},
},
// Peer 7 is connected, but there still aren't enough dynamic peers
@@ -188,29 +193,29 @@ func TestDialStateDynDial(t *testing.T) {
// no new one is started.
{
peers: []*Peer{
- {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(5)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(7)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(7), nil)}},
},
done: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(7)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)},
},
},
// Finish the running node discovery with an empty set. A new lookup
// should be immediately requested.
{
peers: []*Peer{
- {rw: &conn{flags: staticDialedConn, id: uintID(0)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(5)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(7)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(7), nil)}},
},
done: []task{
&discoverTask{},
},
new: []task{
- &waitExpireTask{Duration: 14 * time.Second},
+ &discoverTask{},
},
},
},
@@ -219,17 +224,17 @@ func TestDialStateDynDial(t *testing.T) {
// Tests that bootnodes are dialed if no peers are connected, but not otherwise.
func TestDialStateDynDialBootnode(t *testing.T) {
- bootnodes := []*discover.Node{
- {ID: uintID(1)},
- {ID: uintID(2)},
- {ID: uintID(3)},
+ bootnodes := []*enode.Node{
+ newNode(uintID(1), nil),
+ newNode(uintID(2), nil),
+ newNode(uintID(3), nil),
}
table := fakeTable{
- {ID: uintID(4)},
- {ID: uintID(5)},
- {ID: uintID(6)},
- {ID: uintID(7)},
- {ID: uintID(8)},
+ newNode(uintID(4), nil),
+ newNode(uintID(5), nil),
+ newNode(uintID(6), nil),
+ newNode(uintID(7), nil),
+ newNode(uintID(8), nil),
}
runDialTest(t, dialtest{
init: newDialState(nil, bootnodes, table, 5, nil),
@@ -237,16 +242,16 @@ func TestDialStateDynDialBootnode(t *testing.T) {
// 2 dynamic dials attempted, bootnodes pending fallback interval
{
new: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
&discoverTask{},
},
},
// No dials succeed, bootnodes still pending fallback interval
{
done: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
},
// No dials succeed, bootnodes still pending fallback interval
@@ -254,54 +259,51 @@ func TestDialStateDynDialBootnode(t *testing.T) {
// No dials succeed, 2 dynamic dials attempted and 1 bootnode too as fallback interval was reached
{
new: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(1)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
},
// No dials succeed, 2nd bootnode is attempted
{
done: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(1)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
new: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(2)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
},
},
// No dials succeed, 3rd bootnode is attempted
{
done: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(2)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
},
new: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(3)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
},
},
// No dials succeed, 1st bootnode is attempted again, expired random nodes retried
{
done: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(3)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
},
new: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(1)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
},
// Random dial succeeds, no more bootnodes are attempted
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, id: uintID(4)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
},
done: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(1)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}},
- },
- new: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
},
},
@@ -312,14 +314,14 @@ func TestDialStateDynDialFromTable(t *testing.T) {
// This table always returns the same random nodes
// in the order given below.
table := fakeTable{
- {ID: uintID(1)},
- {ID: uintID(2)},
- {ID: uintID(3)},
- {ID: uintID(4)},
- {ID: uintID(5)},
- {ID: uintID(6)},
- {ID: uintID(7)},
- {ID: uintID(8)},
+ newNode(uintID(1), nil),
+ newNode(uintID(2), nil),
+ newNode(uintID(3), nil),
+ newNode(uintID(4), nil),
+ newNode(uintID(5), nil),
+ newNode(uintID(6), nil),
+ newNode(uintID(7), nil),
+ newNode(uintID(8), nil),
}
runDialTest(t, dialtest{
@@ -328,67 +330,63 @@ func TestDialStateDynDialFromTable(t *testing.T) {
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
{
new: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(1)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(2)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(3)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
&discoverTask{},
},
},
// Dialing nodes 1,2 succeeds. Dials from the lookup are launched.
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
},
done: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(1)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(2)}},
- &discoverTask{results: []*discover.Node{
- {ID: uintID(10)},
- {ID: uintID(11)},
- {ID: uintID(12)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
+ &discoverTask{results: []*enode.Node{
+ newNode(uintID(10), nil),
+ newNode(uintID(11), nil),
+ newNode(uintID(12), nil),
}},
},
new: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(1)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(2)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(10)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(11)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(12)}},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)},
+ &discoverTask{},
},
},
// Dialing nodes 3,4,5 fails. The dials from the lookup succeed.
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(10)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(11)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(12)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}},
},
done: []task{
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(3)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(10)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(11)}},
- &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(12)}},
- },
- new: []task{
- &discoverTask{},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)},
},
},
// Waiting for expiry. No waitExpireTask is launched because the
// discovery query is still running.
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(10)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(11)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(12)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}},
},
},
// Nodes 3,4 are not tried again because only the first two
@@ -396,30 +394,38 @@ func TestDialStateDynDialFromTable(t *testing.T) {
// already connected.
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(10)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(11)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(12)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}},
},
},
},
})
}
+func newNode(id enode.ID, ip net.IP) *enode.Node {
+ var r enr.Record
+ if ip != nil {
+ r.Set(enr.IP(ip))
+ }
+ return enode.SignNull(&r, id)
+}
+
// This test checks that candidates that do not match the netrestrict list are not dialed.
func TestDialStateNetRestrict(t *testing.T) {
// This table always returns the same random nodes
// in the order given below.
table := fakeTable{
- {ID: uintID(1), IP: net.ParseIP("127.0.0.1")},
- {ID: uintID(2), IP: net.ParseIP("127.0.0.2")},
- {ID: uintID(3), IP: net.ParseIP("127.0.0.3")},
- {ID: uintID(4), IP: net.ParseIP("127.0.0.4")},
- {ID: uintID(5), IP: net.ParseIP("127.0.2.5")},
- {ID: uintID(6), IP: net.ParseIP("127.0.2.6")},
- {ID: uintID(7), IP: net.ParseIP("127.0.2.7")},
- {ID: uintID(8), IP: net.ParseIP("127.0.2.8")},
+ newNode(uintID(1), net.ParseIP("127.0.0.1")),
+ newNode(uintID(2), net.ParseIP("127.0.0.2")),
+ newNode(uintID(3), net.ParseIP("127.0.0.3")),
+ newNode(uintID(4), net.ParseIP("127.0.0.4")),
+ newNode(uintID(5), net.ParseIP("127.0.2.5")),
+ newNode(uintID(6), net.ParseIP("127.0.2.6")),
+ newNode(uintID(7), net.ParseIP("127.0.2.7")),
+ newNode(uintID(8), net.ParseIP("127.0.2.8")),
}
restrict := new(netutil.Netlist)
restrict.Add("127.0.2.0/24")
@@ -439,12 +445,12 @@ func TestDialStateNetRestrict(t *testing.T) {
// This test checks that static dials are launched.
func TestDialStateStaticDial(t *testing.T) {
- wantStatic := []*discover.Node{
- {ID: uintID(1)},
- {ID: uintID(2)},
- {ID: uintID(3)},
- {ID: uintID(4)},
- {ID: uintID(5)},
+ wantStatic := []*enode.Node{
+ newNode(uintID(1), nil),
+ newNode(uintID(2), nil),
+ newNode(uintID(3), nil),
+ newNode(uintID(4), nil),
+ newNode(uintID(5), nil),
}
runDialTest(t, dialtest{
@@ -454,70 +460,67 @@ func TestDialStateStaticDial(t *testing.T) {
// aren't yet connected.
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
},
new: []task{
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}},
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}},
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(3)}},
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(4)}},
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(5)}},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)},
},
},
// No new tasks are launched in this round because all static
// nodes are either connected or still being dialed.
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
- {rw: &conn{flags: staticDialedConn, id: uintID(3)}},
- },
- new: []task{
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(3)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}},
},
done: []task{
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(3)}},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
},
},
// No new dial tasks are launched because all static
// nodes are now connected.
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
- {rw: &conn{flags: staticDialedConn, id: uintID(3)}},
- {rw: &conn{flags: staticDialedConn, id: uintID(4)}},
- {rw: &conn{flags: staticDialedConn, id: uintID(5)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(4), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}},
},
done: []task{
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(4)}},
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(5)}},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)},
},
new: []task{
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(4)}},
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(5)}},
+ &waitExpireTask{Duration: 14 * time.Second},
},
},
// Wait a round for dial history to expire, no new tasks should spawn.
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
- {rw: &conn{flags: staticDialedConn, id: uintID(3)}},
- {rw: &conn{flags: staticDialedConn, id: uintID(4)}},
- {rw: &conn{flags: staticDialedConn, id: uintID(5)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(4), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}},
},
},
// If a static node is dropped, it should be immediately redialed,
// irrespective of whether it was originally static or dynamic.
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: staticDialedConn, id: uintID(3)}},
- {rw: &conn{flags: staticDialedConn, id: uintID(5)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}},
+ },
+ new: []task{
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)},
},
- new: []task{},
},
},
})
@@ -525,9 +528,9 @@ func TestDialStateStaticDial(t *testing.T) {
// This test checks that static peers will be redialed immediately if they were re-added to a static list.
func TestDialStaticAfterReset(t *testing.T) {
- wantStatic := []*discover.Node{
- {ID: uintID(1)},
- {ID: uintID(2)},
+ wantStatic := []*enode.Node{
+ newNode(uintID(1), nil),
+ newNode(uintID(2), nil),
}
rounds := []round{
@@ -535,23 +538,22 @@ func TestDialStaticAfterReset(t *testing.T) {
{
peers: nil,
new: []task{
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}},
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
},
},
// No new dial tasks, all peers are connected.
{
peers: []*Peer{
- {rw: &conn{flags: staticDialedConn, id: uintID(1)}},
- {rw: &conn{flags: staticDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
},
done: []task{
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}},
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
},
new: []task{
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}},
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}},
+ &waitExpireTask{Duration: 30 * time.Second},
},
},
}
@@ -563,19 +565,17 @@ func TestDialStaticAfterReset(t *testing.T) {
for _, n := range wantStatic {
dTest.init.removeStatic(n)
dTest.init.addStatic(n)
- delete(dTest.init.dialing, n.ID)
}
-
// without removing peers they will be considered recently dialed
runDialTest(t, dTest)
}
// This test checks that past dials are not retried for some time.
func TestDialStateCache(t *testing.T) {
- wantStatic := []*discover.Node{
- {ID: uintID(1)},
- {ID: uintID(2)},
- {ID: uintID(3)},
+ wantStatic := []*enode.Node{
+ newNode(uintID(1), nil),
+ newNode(uintID(2), nil),
+ newNode(uintID(3), nil),
}
runDialTest(t, dialtest{
@@ -586,53 +586,52 @@ func TestDialStateCache(t *testing.T) {
{
peers: nil,
new: []task{
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}},
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}},
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(3)}},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
},
},
// No new tasks are launched in this round because all static
// nodes are either connected or still being dialed.
{
peers: []*Peer{
- {rw: &conn{flags: staticDialedConn, id: uintID(1)}},
- {rw: &conn{flags: staticDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
},
done: []task{
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}},
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}},
- },
- new: []task{
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}},
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
},
},
// A salvage task is launched to wait for node 3's history
// entry to expire.
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
},
done: []task{
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(3)}},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
+ },
+ new: []task{
+ &waitExpireTask{Duration: 14 * time.Second},
},
},
// Still waiting for node 3's entry to expire in the cache.
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
},
},
// The cache entry for node 3 has expired and is retried.
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, id: uintID(1)}},
- {rw: &conn{flags: dynDialedConn, id: uintID(2)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
},
new: []task{
- &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(3)}},
+ &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
},
},
},
@@ -640,12 +639,12 @@ func TestDialStateCache(t *testing.T) {
}
func TestDialResolve(t *testing.T) {
- resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444)
+ resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
table := &resolveMock{answer: resolved}
state := newDialState(nil, nil, table, 0, nil)
// Check that the task is generated with an incomplete ID.
- dest := discover.NewNode(uintID(1), nil, 0, 0)
+ dest := newNode(uintID(1), nil)
state.addStatic(dest)
tasks := state.newTasks(0, nil, time.Time{})
if !reflect.DeepEqual(tasks, []task{&dialTask{flags: staticDialedConn, dest: dest}}) {
@@ -656,7 +655,7 @@ func TestDialResolve(t *testing.T) {
config := Config{Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}}
srv := &Server{ntab: table, Config: config}
tasks[0].Do(srv)
- if !reflect.DeepEqual(table.resolveCalls, []discover.NodeID{dest.ID}) {
+ if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) {
t.Fatalf("wrong resolve calls, got %v", table.resolveCalls)
}
@@ -684,25 +683,24 @@ next:
return true
}
-func uintID(i uint32) discover.NodeID {
- var id discover.NodeID
+func uintID(i uint32) enode.ID {
+ var id enode.ID
binary.BigEndian.PutUint32(id[:], i)
return id
}
// implements discoverTable for TestDialResolve
type resolveMock struct {
- resolveCalls []discover.NodeID
- answer *discover.Node
+ resolveCalls []*enode.Node
+ answer *enode.Node
}
-func (t *resolveMock) Resolve(id discover.NodeID) *discover.Node {
- t.resolveCalls = append(t.resolveCalls, id)
+func (t *resolveMock) Resolve(n *enode.Node) *enode.Node {
+ t.resolveCalls = append(t.resolveCalls, n)
return t.answer
}
-func (t *resolveMock) Self() *discover.Node { return new(discover.Node) }
-func (t *resolveMock) Close() {}
-func (t *resolveMock) Bootstrap([]*discover.Node) {}
-func (t *resolveMock) Lookup(discover.NodeID) []*discover.Node { return nil }
-func (t *resolveMock) ReadRandomNodes(buf []*discover.Node) int { return 0 }
+func (t *resolveMock) Self() *enode.Node { return new(enode.Node) }
+func (t *resolveMock) Close() {}
+func (t *resolveMock) LookupRandom() []*enode.Node { return nil }
+func (t *resolveMock) ReadRandomNodes(buf []*enode.Node) int { return 0 }
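
The dial-test changes above swap discover.NodeID/discover.Node for enode.ID/*enode.Node, with newNode signing a "null" record so tests can mint usable nodes from bare IDs without a private key. A minimal standalone sketch of how those helpers fit together (assuming the tomochain p2p/enode and p2p/enr packages used in the diff; the main function is illustrative only):

package main

import (
	"encoding/binary"
	"fmt"
	"net"

	"github.com/tomochain/tomochain/p2p/enode"
	"github.com/tomochain/tomochain/p2p/enr"
)

// uintID packs a small integer into the high-order bytes of an enode.ID,
// mirroring the test helper above.
func uintID(i uint32) enode.ID {
	var id enode.ID
	binary.BigEndian.PutUint32(id[:], i)
	return id
}

// newNode signs an (optionally IP-carrying) record with the null identity
// scheme, so no key material is needed to obtain a *enode.Node.
func newNode(id enode.ID, ip net.IP) *enode.Node {
	var r enr.Record
	if ip != nil {
		r.Set(enr.IP(ip))
	}
	return enode.SignNull(&r, id)
}

func main() {
	n := newNode(uintID(1), net.ParseIP("127.0.0.1"))
	fmt.Println(n.ID(), n.IP())
}
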
diff --git a/p2p/discover/database.go b/p2p/discover/database.go
deleted file mode 100644
index 43a4ca37f..000000000
--- a/p2p/discover/database.go
+++ /dev/null
@@ -1,370 +0,0 @@
-// Copyright 2015 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-// Contains the node database, storing previously seen nodes and any collected
-// metadata about them for QoS purposes.
-
-package discover
-
-import (
- "bytes"
- "crypto/rand"
- "encoding/binary"
- "os"
- "sync"
- "time"
-
- "github.com/syndtr/goleveldb/leveldb"
- "github.com/syndtr/goleveldb/leveldb/errors"
- "github.com/syndtr/goleveldb/leveldb/iterator"
- "github.com/syndtr/goleveldb/leveldb/opt"
- "github.com/syndtr/goleveldb/leveldb/storage"
- "github.com/syndtr/goleveldb/leveldb/util"
- "github.com/tomochain/tomochain/crypto"
- "github.com/tomochain/tomochain/log"
- "github.com/tomochain/tomochain/rlp"
-)
-
-var (
- nodeDBNilNodeID = NodeID{} // Special node ID to use as a nil element.
- nodeDBNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped.
- nodeDBCleanupCycle = time.Hour // Time period for running the expiration task.
-)
-
-// nodeDB stores all nodes we know about.
-type nodeDB struct {
- lvl *leveldb.DB // Interface to the database itself
- self NodeID // Own node id to prevent adding it into the database
- runner sync.Once // Ensures we can start at most one expirer
- quit chan struct{} // Channel to signal the expiring thread to stop
-}
-
-// Schema layout for the node database
-var (
- nodeDBVersionKey = []byte("version") // Version of the database to flush if changes
- nodeDBItemPrefix = []byte("n:") // Identifier to prefix node entries with
-
- nodeDBDiscoverRoot = ":discover"
- nodeDBDiscoverPing = nodeDBDiscoverRoot + ":lastping"
- nodeDBDiscoverPong = nodeDBDiscoverRoot + ":lastpong"
- nodeDBDiscoverFindFails = nodeDBDiscoverRoot + ":findfail"
-)
-
-// newNodeDB creates a new node database for storing and retrieving infos about
-// known peers in the network. If no path is given, an in-memory, temporary
-// database is constructed.
-func newNodeDB(path string, version int, self NodeID) (*nodeDB, error) {
- if path == "" {
- return newMemoryNodeDB(self)
- }
- return newPersistentNodeDB(path, version, self)
-}
-
-// newMemoryNodeDB creates a new in-memory node database without a persistent
-// backend.
-func newMemoryNodeDB(self NodeID) (*nodeDB, error) {
- db, err := leveldb.Open(storage.NewMemStorage(), nil)
- if err != nil {
- return nil, err
- }
- return &nodeDB{
- lvl: db,
- self: self,
- quit: make(chan struct{}),
- }, nil
-}
-
-// newPersistentNodeDB creates/opens a leveldb backed persistent node database,
-// also flushing its contents in case of a version mismatch.
-func newPersistentNodeDB(path string, version int, self NodeID) (*nodeDB, error) {
- opts := &opt.Options{OpenFilesCacheCapacity: 5}
- db, err := leveldb.OpenFile(path, opts)
- if _, iscorrupted := err.(*errors.ErrCorrupted); iscorrupted {
- db, err = leveldb.RecoverFile(path, nil)
- }
- if err != nil {
- return nil, err
- }
- // The nodes contained in the cache correspond to a certain protocol version.
- // Flush all nodes if the version doesn't match.
- currentVer := make([]byte, binary.MaxVarintLen64)
- currentVer = currentVer[:binary.PutVarint(currentVer, int64(version))]
-
- blob, err := db.Get(nodeDBVersionKey, nil)
- switch err {
- case leveldb.ErrNotFound:
- // Version not found (i.e. empty cache), insert it
- if err := db.Put(nodeDBVersionKey, currentVer, nil); err != nil {
- db.Close()
- return nil, err
- }
-
- case nil:
- // Version present, flush if different
- if !bytes.Equal(blob, currentVer) {
- db.Close()
- if err = os.RemoveAll(path); err != nil {
- return nil, err
- }
- return newPersistentNodeDB(path, version, self)
- }
- }
- return &nodeDB{
- lvl: db,
- self: self,
- quit: make(chan struct{}),
- }, nil
-}
-
-// makeKey generates the leveldb key-blob from a node id and its particular
-// field of interest.
-func makeKey(id NodeID, field string) []byte {
- if bytes.Equal(id[:], nodeDBNilNodeID[:]) {
- return []byte(field)
- }
- return append(nodeDBItemPrefix, append(id[:], field...)...)
-}
-
-// splitKey tries to split a database key into a node id and a field part.
-func splitKey(key []byte) (id NodeID, field string) {
- // If the key is not of a node, return it plainly
- if !bytes.HasPrefix(key, nodeDBItemPrefix) {
- return NodeID{}, string(key)
- }
- // Otherwise split the id and field
- item := key[len(nodeDBItemPrefix):]
- copy(id[:], item[:len(id)])
- field = string(item[len(id):])
-
- return id, field
-}
-
-// fetchInt64 retrieves an integer instance associated with a particular
-// database key.
-func (db *nodeDB) fetchInt64(key []byte) int64 {
- blob, err := db.lvl.Get(key, nil)
- if err != nil {
- return 0
- }
- val, read := binary.Varint(blob)
- if read <= 0 {
- return 0
- }
- return val
-}
-
-// storeInt64 stores an int64 value under the given database key as a
-// varint-encoded blob.
-func (db *nodeDB) storeInt64(key []byte, n int64) error {
- blob := make([]byte, binary.MaxVarintLen64)
- blob = blob[:binary.PutVarint(blob, n)]
-
- return db.lvl.Put(key, blob, nil)
-}
-
-// node retrieves a node with a given id from the database.
-func (db *nodeDB) node(id NodeID) *Node {
- blob, err := db.lvl.Get(makeKey(id, nodeDBDiscoverRoot), nil)
- if err != nil {
- return nil
- }
- node := new(Node)
- if err := rlp.DecodeBytes(blob, node); err != nil {
- log.Error("Failed to decode node RLP", "err", err)
- return nil
- }
- node.sha = crypto.Keccak256Hash(node.ID[:])
- return node
-}
-
-// updateNode inserts - potentially overwriting - a node into the peer database.
-func (db *nodeDB) updateNode(node *Node) error {
- blob, err := rlp.EncodeToBytes(node)
- if err != nil {
- return err
- }
- return db.lvl.Put(makeKey(node.ID, nodeDBDiscoverRoot), blob, nil)
-}
-
-// deleteNode deletes all information/keys associated with a node.
-func (db *nodeDB) deleteNode(id NodeID) error {
- deleter := db.lvl.NewIterator(util.BytesPrefix(makeKey(id, "")), nil)
- for deleter.Next() {
- if err := db.lvl.Delete(deleter.Key(), nil); err != nil {
- return err
- }
- }
- return nil
-}
-
-// ensureExpirer is a small helper method ensuring that the data expiration
-// mechanism is running. If the expiration goroutine is already running, this
-// method simply returns.
-//
-// The goal is to start the data evacuation only after the network successfully
-// bootstrapped itself (to prevent dumping potentially useful seed nodes). Since
-// it would require significant overhead to exactly trace the first successful
-// convergence, it's simpler to "ensure" the correct state when an appropriate
-// condition occurs (i.e. a successful bonding), and discard further events.
-func (db *nodeDB) ensureExpirer() {
- db.runner.Do(func() { go db.expirer() })
-}
-
-// expirer should be started in a goroutine, and is responsible for looping ad
-// infinitum and dropping stale data from the database.
-func (db *nodeDB) expirer() {
- tick := time.NewTicker(nodeDBCleanupCycle)
- defer tick.Stop()
- for {
- select {
- case <-tick.C:
- if err := db.expireNodes(); err != nil {
- log.Error("Failed to expire nodedb items", "err", err)
- }
- case <-db.quit:
- return
- }
- }
-}
-
-// expireNodes iterates over the database and deletes all nodes that have not
-// been seen (i.e. received a pong from) for some allotted time.
-func (db *nodeDB) expireNodes() error {
- threshold := time.Now().Add(-nodeDBNodeExpiration)
-
- // Find discovered nodes that are older than the allowance
- it := db.lvl.NewIterator(nil, nil)
- defer it.Release()
-
- for it.Next() {
- // Skip the item if not a discovery node
- id, field := splitKey(it.Key())
- if field != nodeDBDiscoverRoot {
- continue
- }
- // Skip the node if not expired yet (and not self)
- if !bytes.Equal(id[:], db.self[:]) {
- if seen := db.bondTime(id); seen.After(threshold) {
- continue
- }
- }
- // Otherwise delete all associated information
- db.deleteNode(id)
- }
- return nil
-}
-
-// lastPing retrieves the time of the last ping packet sent to a remote node,
-// requesting binding.
-func (db *nodeDB) lastPing(id NodeID) time.Time {
- return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0)
-}
-
-// updateLastPing updates the last time we tried contacting a remote node.
-func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error {
- return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix())
-}
-
-// bondTime retrieves the time of the last successful pong from remote node.
-func (db *nodeDB) bondTime(id NodeID) time.Time {
- return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0)
-}
-
-// hasBond reports whether the given node is considered bonded.
-func (db *nodeDB) hasBond(id NodeID) bool {
- return time.Since(db.bondTime(id)) < nodeDBNodeExpiration
-}
-
-// updateBondTime updates the last pong time of a node.
-func (db *nodeDB) updateBondTime(id NodeID, instance time.Time) error {
- return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix())
-}
-
-// findFails retrieves the number of findnode failures since bonding.
-func (db *nodeDB) findFails(id NodeID) int {
- return int(db.fetchInt64(makeKey(id, nodeDBDiscoverFindFails)))
-}
-
-// updateFindFails updates the number of findnode failures since bonding.
-func (db *nodeDB) updateFindFails(id NodeID, fails int) error {
- return db.storeInt64(makeKey(id, nodeDBDiscoverFindFails), int64(fails))
-}
-
-// querySeeds retrieves random nodes to be used as potential seed nodes
-// for bootstrapping.
-func (db *nodeDB) querySeeds(n int, maxAge time.Duration) []*Node {
- var (
- now = time.Now()
- nodes = make([]*Node, 0, n)
- it = db.lvl.NewIterator(nil, nil)
- id NodeID
- )
- defer it.Release()
-
-seek:
- for seeks := 0; len(nodes) < n && seeks < n*5; seeks++ {
- // Seek to a random entry. The first byte is incremented by a
- // random amount each time in order to increase the likelihood
- // of hitting all existing nodes in very small databases.
- ctr := id[0]
- rand.Read(id[:])
- id[0] = ctr + id[0]%16
- it.Seek(makeKey(id, nodeDBDiscoverRoot))
-
- n := nextNode(it)
- if n == nil {
- id[0] = 0
- continue seek // iterator exhausted
- }
- if n.ID == db.self {
- continue seek
- }
- if now.Sub(db.bondTime(n.ID)) > maxAge {
- continue seek
- }
- for i := range nodes {
- if nodes[i].ID == n.ID {
- continue seek // duplicate
- }
- }
- nodes = append(nodes, n)
- }
- return nodes
-}
-
-// reads the next node record from the iterator, skipping over other
-// database entries.
-func nextNode(it iterator.Iterator) *Node {
- for end := false; !end; end = !it.Next() {
- id, field := splitKey(it.Key())
- if field != nodeDBDiscoverRoot {
- continue
- }
- var n Node
- if err := rlp.DecodeBytes(it.Value(), &n); err != nil {
- log.Warn("Failed to decode node RLP", "id", id, "err", err)
- continue
- }
- return &n
- }
- return nil
-}
-
-// close flushes and closes the database files.
-func (db *nodeDB) close() {
- close(db.quit)
- db.lvl.Close()
-}
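
The deleted database.go above keyed every leveldb record by node ID plus a field suffix. A self-contained sketch of that key scheme (logic copied from the makeKey/splitKey functions in the removed file; the leveldb plumbing is omitted):

package main

import (
	"bytes"
	"fmt"
)

type NodeID [64]byte

var (
	nilNodeID  = NodeID{}     // special nil element for global keys
	itemPrefix = []byte("n:") // prefix for per-node entries
)

// makeKey builds "n:<id><field>" for node entries, or the bare field name
// for global entries such as the database version.
func makeKey(id NodeID, field string) []byte {
	if bytes.Equal(id[:], nilNodeID[:]) {
		return []byte(field)
	}
	return append(itemPrefix, append(id[:], field...)...)
}

// splitKey reverses makeKey.
func splitKey(key []byte) (id NodeID, field string) {
	if !bytes.HasPrefix(key, itemPrefix) {
		return NodeID{}, string(key)
	}
	item := key[len(itemPrefix):]
	copy(id[:], item[:len(id)])
	return id, string(item[len(id):])
}

func main() {
	var id NodeID
	id[0] = 0xab
	gotID, field := splitKey(makeKey(id, ":discover:lastpong"))
	fmt.Println(gotID == id, field) // true :discover:lastpong
}
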
diff --git a/p2p/discover/database_test.go b/p2p/discover/database_test.go
deleted file mode 100644
index c4fa44d09..000000000
--- a/p2p/discover/database_test.go
+++ /dev/null
@@ -1,380 +0,0 @@
-// Copyright 2015 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package discover
-
-import (
- "bytes"
- "io/ioutil"
- "net"
- "os"
- "path/filepath"
- "reflect"
- "testing"
- "time"
-)
-
-var nodeDBKeyTests = []struct {
- id NodeID
- field string
- key []byte
-}{
- {
- id: NodeID{},
- field: "version",
- key: []byte{0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e}, // field
- },
- {
- id: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- field: ":discover",
- key: []byte{0x6e, 0x3a, // prefix
- 0x1d, 0xd9, 0xd6, 0x5c, 0x45, 0x52, 0xb5, 0xeb, // node id
- 0x43, 0xd5, 0xad, 0x55, 0xa2, 0xee, 0x3f, 0x56, //
- 0xc6, 0xcb, 0xc1, 0xc6, 0x4a, 0x5c, 0x8d, 0x65, //
- 0x9f, 0x51, 0xfc, 0xd5, 0x1b, 0xac, 0xe2, 0x43, //
- 0x51, 0x23, 0x2b, 0x8d, 0x78, 0x21, 0x61, 0x7d, //
- 0x2b, 0x29, 0xb5, 0x4b, 0x81, 0xcd, 0xef, 0xb9, //
- 0xb3, 0xe9, 0xc3, 0x7d, 0x7f, 0xd5, 0xf6, 0x32, //
- 0x70, 0xbc, 0xc9, 0xe1, 0xa6, 0xf6, 0xa4, 0x39, //
- 0x3a, 0x64, 0x69, 0x73, 0x63, 0x6f, 0x76, 0x65, 0x72, // field
- },
- },
-}
-
-func TestNodeDBKeys(t *testing.T) {
- for i, tt := range nodeDBKeyTests {
- if key := makeKey(tt.id, tt.field); !bytes.Equal(key, tt.key) {
- t.Errorf("make test %d: key mismatch: have 0x%x, want 0x%x", i, key, tt.key)
- }
- id, field := splitKey(tt.key)
- if !bytes.Equal(id[:], tt.id[:]) {
- t.Errorf("split test %d: id mismatch: have 0x%x, want 0x%x", i, id, tt.id)
- }
- if field != tt.field {
- t.Errorf("split test %d: field mismatch: have 0x%x, want 0x%x", i, field, tt.field)
- }
- }
-}
-
-var nodeDBInt64Tests = []struct {
- key []byte
- value int64
-}{
- {key: []byte{0x01}, value: 1},
- {key: []byte{0x02}, value: 2},
- {key: []byte{0x03}, value: 3},
-}
-
-func TestNodeDBInt64(t *testing.T) {
- db, _ := newNodeDB("", Version, NodeID{})
- defer db.close()
-
- tests := nodeDBInt64Tests
- for i := 0; i < len(tests); i++ {
- // Insert the next value
- if err := db.storeInt64(tests[i].key, tests[i].value); err != nil {
- t.Errorf("test %d: failed to store value: %v", i, err)
- }
- // Check all existing and non existing values
- for j := 0; j < len(tests); j++ {
- num := db.fetchInt64(tests[j].key)
- switch {
- case j <= i && num != tests[j].value:
- t.Errorf("test %d, item %d: value mismatch: have %v, want %v", i, j, num, tests[j].value)
- case j > i && num != 0:
- t.Errorf("test %d, item %d: value mismatch: have %v, want %v", i, j, num, 0)
- }
- }
- }
-}
-
-func TestNodeDBFetchStore(t *testing.T) {
- node := NewNode(
- MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- net.IP{192, 168, 0, 1},
- 30303,
- 30303,
- )
- inst := time.Now()
- num := 314
-
- db, _ := newNodeDB("", Version, NodeID{})
- defer db.close()
-
- // Check fetch/store operations on a node ping object
- if stored := db.lastPing(node.ID); stored.Unix() != 0 {
- t.Errorf("ping: non-existing object: %v", stored)
- }
- if err := db.updateLastPing(node.ID, inst); err != nil {
- t.Errorf("ping: failed to update: %v", err)
- }
- if stored := db.lastPing(node.ID); stored.Unix() != inst.Unix() {
- t.Errorf("ping: value mismatch: have %v, want %v", stored, inst)
- }
- // Check fetch/store operations on a node pong object
- if stored := db.bondTime(node.ID); stored.Unix() != 0 {
- t.Errorf("pong: non-existing object: %v", stored)
- }
- if err := db.updateBondTime(node.ID, inst); err != nil {
- t.Errorf("pong: failed to update: %v", err)
- }
- if stored := db.bondTime(node.ID); stored.Unix() != inst.Unix() {
- t.Errorf("pong: value mismatch: have %v, want %v", stored, inst)
- }
- // Check fetch/store operations on a node findnode-failure object
- if stored := db.findFails(node.ID); stored != 0 {
- t.Errorf("find-node fails: non-existing object: %v", stored)
- }
- if err := db.updateFindFails(node.ID, num); err != nil {
- t.Errorf("find-node fails: failed to update: %v", err)
- }
- if stored := db.findFails(node.ID); stored != num {
- t.Errorf("find-node fails: value mismatch: have %v, want %v", stored, num)
- }
- // Check fetch/store operations on an actual node object
- if stored := db.node(node.ID); stored != nil {
- t.Errorf("node: non-existing object: %v", stored)
- }
- if err := db.updateNode(node); err != nil {
- t.Errorf("node: failed to update: %v", err)
- }
- if stored := db.node(node.ID); stored == nil {
- t.Errorf("node: not found")
- } else if !reflect.DeepEqual(stored, node) {
- t.Errorf("node: data mismatch: have %v, want %v", stored, node)
- }
-}
-
-var nodeDBSeedQueryNodes = []struct {
- node *Node
- pong time.Time
-}{
- // This one should not be in the result set because its last
- // pong time is too far in the past.
- {
- node: NewNode(
- MustHexID("0x84d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- net.IP{127, 0, 0, 3},
- 30303,
- 30303,
- ),
- pong: time.Now().Add(-3 * time.Hour),
- },
- // This one shouldn't be in the result set because its
- // nodeID is the local node's ID.
- {
- node: NewNode(
- MustHexID("0x57d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- net.IP{127, 0, 0, 3},
- 30303,
- 30303,
- ),
- pong: time.Now().Add(-4 * time.Second),
- },
-
- // These should be in the result set.
- {
- node: NewNode(
- MustHexID("0x22d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- net.IP{127, 0, 0, 1},
- 30303,
- 30303,
- ),
- pong: time.Now().Add(-2 * time.Second),
- },
- {
- node: NewNode(
- MustHexID("0x44d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- net.IP{127, 0, 0, 2},
- 30303,
- 30303,
- ),
- pong: time.Now().Add(-3 * time.Second),
- },
- {
- node: NewNode(
- MustHexID("0xe2d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- net.IP{127, 0, 0, 3},
- 30303,
- 30303,
- ),
- pong: time.Now().Add(-1 * time.Second),
- },
-}
-
-func TestNodeDBSeedQuery(t *testing.T) {
- db, _ := newNodeDB("", Version, nodeDBSeedQueryNodes[1].node.ID)
- defer db.close()
-
- // Insert a batch of nodes for querying
- for i, seed := range nodeDBSeedQueryNodes {
- if err := db.updateNode(seed.node); err != nil {
- t.Fatalf("node %d: failed to insert: %v", i, err)
- }
- if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil {
- t.Fatalf("node %d: failed to insert bondTime: %v", i, err)
- }
- }
-
- // Retrieve the entire batch and check for duplicates
- seeds := db.querySeeds(len(nodeDBSeedQueryNodes)*2, time.Hour)
- have := make(map[NodeID]struct{})
- for _, seed := range seeds {
- have[seed.ID] = struct{}{}
- }
- want := make(map[NodeID]struct{})
- for _, seed := range nodeDBSeedQueryNodes[2:] {
- want[seed.node.ID] = struct{}{}
- }
- if len(seeds) != len(want) {
- t.Errorf("seed count mismatch: have %v, want %v", len(seeds), len(want))
- }
- for id := range have {
- if _, ok := want[id]; !ok {
- t.Errorf("extra seed: %v", id)
- }
- }
- for id := range want {
- if _, ok := have[id]; !ok {
- t.Errorf("missing seed: %v", id)
- }
- }
-}
-
-func TestNodeDBPersistency(t *testing.T) {
- root, err := ioutil.TempDir("", "nodedb-")
- if err != nil {
- t.Fatalf("failed to create temporary data folder: %v", err)
- }
- defer os.RemoveAll(root)
-
- var (
- testKey = []byte("somekey")
- testInt = int64(314)
- )
-
- // Create a persistent database and store some values
- db, err := newNodeDB(filepath.Join(root, "database"), Version, NodeID{})
- if err != nil {
- t.Fatalf("failed to create persistent database: %v", err)
- }
- if err := db.storeInt64(testKey, testInt); err != nil {
- t.Fatalf("failed to store value: %v.", err)
- }
- db.close()
-
- // Reopen the database and check the value
- db, err = newNodeDB(filepath.Join(root, "database"), Version, NodeID{})
- if err != nil {
- t.Fatalf("failed to open persistent database: %v", err)
- }
- if val := db.fetchInt64(testKey); val != testInt {
- t.Fatalf("value mismatch: have %v, want %v", val, testInt)
- }
- db.close()
-
- // Change the database version and check flush
- db, err = newNodeDB(filepath.Join(root, "database"), Version+1, NodeID{})
- if err != nil {
- t.Fatalf("failed to open persistent database: %v", err)
- }
- if val := db.fetchInt64(testKey); val != 0 {
- t.Fatalf("value mismatch: have %v, want %v", val, 0)
- }
- db.close()
-}
-
-var nodeDBExpirationNodes = []struct {
- node *Node
- pong time.Time
- exp bool
-}{
- {
- node: NewNode(
- MustHexID("0x01d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- net.IP{127, 0, 0, 1},
- 30303,
- 30303,
- ),
- pong: time.Now().Add(-nodeDBNodeExpiration + time.Minute),
- exp: false,
- }, {
- node: NewNode(
- MustHexID("0x02d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- net.IP{127, 0, 0, 2},
- 30303,
- 30303,
- ),
- pong: time.Now().Add(-nodeDBNodeExpiration - time.Minute),
- exp: true,
- },
-}
-
-func TestNodeDBExpiration(t *testing.T) {
- db, _ := newNodeDB("", Version, NodeID{})
- defer db.close()
-
- // Add all the test nodes and set their last pong time
- for i, seed := range nodeDBExpirationNodes {
- if err := db.updateNode(seed.node); err != nil {
- t.Fatalf("node %d: failed to insert: %v", i, err)
- }
- if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil {
- t.Fatalf("node %d: failed to update bondTime: %v", i, err)
- }
- }
- // Expire some of them, and check the rest
- if err := db.expireNodes(); err != nil {
- t.Fatalf("failed to expire nodes: %v", err)
- }
- for i, seed := range nodeDBExpirationNodes {
- node := db.node(seed.node.ID)
- if (node == nil && !seed.exp) || (node != nil && seed.exp) {
- t.Errorf("node %d: expiration mismatch: have %v, want %v", i, node, seed.exp)
- }
- }
-}
-
-func TestNodeDBSelfExpiration(t *testing.T) {
- // Find a node in the tests that shouldn't expire, and assign it as self
- var self NodeID
- for _, node := range nodeDBExpirationNodes {
- if !node.exp {
- self = node.node.ID
- break
- }
- }
- db, _ := newNodeDB("", Version, self)
- defer db.close()
-
- // Add all the test nodes and set their last pong time
- for i, seed := range nodeDBExpirationNodes {
- if err := db.updateNode(seed.node); err != nil {
- t.Fatalf("node %d: failed to insert: %v", i, err)
- }
- if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil {
- t.Fatalf("node %d: failed to update bondTime: %v", i, err)
- }
- }
- // Expire the nodes and make sure self has been evacuated too
- if err := db.expireNodes(); err != nil {
- t.Fatalf("failed to expire nodes: %v", err)
- }
- node := db.node(self)
- if node != nil {
- t.Errorf("self not evacuated")
- }
-}
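
The int64 fields these deleted tests exercise (last ping/pong times, findnode failure counts) were persisted as varints. A minimal sketch of that encoding, matching the removed storeInt64/fetchInt64 with the leveldb layer stripped out:

package main

import (
	"encoding/binary"
	"fmt"
)

// encodeInt64 mirrors storeInt64: a zig-zag varint trimmed to its used length.
func encodeInt64(n int64) []byte {
	blob := make([]byte, binary.MaxVarintLen64)
	return blob[:binary.PutVarint(blob, n)]
}

// decodeInt64 mirrors fetchInt64: unreadable blobs decode to zero.
func decodeInt64(blob []byte) int64 {
	val, read := binary.Varint(blob)
	if read <= 0 {
		return 0
	}
	return val
}

func main() {
	fmt.Println(decodeInt64(encodeInt64(314))) // 314
}
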
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
index 839f76227..9fe7bdb6d 100644
--- a/p2p/discover/node.go
+++ b/p2p/discover/node.go
@@ -18,415 +18,87 @@ package discover
import (
"crypto/ecdsa"
- "crypto/elliptic"
- "encoding/hex"
"errors"
- "fmt"
"math/big"
- "math/rand"
"net"
- "net/url"
- "regexp"
- "strconv"
- "strings"
"time"
- "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/common/math"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/crypto/secp256k1"
+ "github.com/tomochain/tomochain/p2p/enode"
)
-const NodeIDBits = 512
-
-// Node represents a host on the network.
+// node represents a host on the network.
// The fields of Node may not be modified.
-type Node struct {
- IP net.IP // len 4 for IPv4 or 16 for IPv6
- UDP, TCP uint16 // port numbers
- ID NodeID // the node's public key
-
- // This is a cached copy of sha3(ID) which is used for node
- // distance calculations. This is part of Node in order to make it
- // possible to write tests that need a node at a certain distance.
- // In those tests, the content of sha will not actually correspond
- // with ID.
- sha common.Hash
-
- // Time when the node was added to the table.
- addedAt time.Time
-}
-
-// NewNode creates a new node. It is mostly meant to be used for
-// testing purposes.
-func NewNode(id NodeID, ip net.IP, udpPort, tcpPort uint16) *Node {
- if ipv4 := ip.To4(); ipv4 != nil {
- ip = ipv4
- }
- return &Node{
- IP: ip,
- UDP: udpPort,
- TCP: tcpPort,
- ID: id,
- sha: crypto.Keccak256Hash(id[:]),
- }
-}
-
-func (n *Node) addr() *net.UDPAddr {
- return &net.UDPAddr{IP: n.IP, Port: int(n.UDP)}
-}
-
-// Incomplete returns true for nodes with no IP address.
-func (n *Node) Incomplete() bool {
- return n.IP == nil
-}
-
-// checks whether n is a valid complete node.
-func (n *Node) validateComplete() error {
- if n.Incomplete() {
- return errors.New("incomplete node")
- }
- if n.UDP == 0 {
- return errors.New("missing UDP port")
- }
- if n.TCP == 0 {
- return errors.New("missing TCP port")
- }
- if n.IP.IsMulticast() || n.IP.IsUnspecified() {
- return errors.New("invalid IP (multicast/unspecified)")
- }
- _, err := n.ID.Pubkey() // validate the key (on curve, etc.)
- return err
-}
-
-// The string representation of a Node is a URL.
-// Please see ParseNode for a description of the format.
-func (n *Node) String() string {
- u := url.URL{Scheme: "enode"}
- if n.Incomplete() {
- u.Host = fmt.Sprintf("%x", n.ID[:])
- } else {
- addr := net.TCPAddr{IP: n.IP, Port: int(n.TCP)}
- u.User = url.User(fmt.Sprintf("%x", n.ID[:]))
- u.Host = addr.String()
- if n.UDP != n.TCP {
- u.RawQuery = "discport=" + strconv.Itoa(int(n.UDP))
- }
- }
- return u.String()
-}
-
-var incompleteNodeURL = regexp.MustCompile("(?i)^(?:enode://)?([0-9a-f]+)$")
-
-// ParseNode parses a node designator.
-//
-// There are two basic forms of node designators
-// - incomplete nodes, which only have the public key (node ID)
-// - complete nodes, which contain the public key and IP/Port information
-//
-// For incomplete nodes, the designator must look like one of these
-//
-//    enode://<hex node id>
-//    <hex node id>
-//
-// For complete nodes, the node ID is encoded in the username portion
-// of the URL, separated from the host by an @ sign. The hostname can
-// only be given as an IP address, DNS domain names are not allowed.
-// The port in the host name section is the TCP listening port. If the
-// TCP and UDP (discovery) ports differ, the UDP port is specified as
-// query parameter "discport".
-//
-// In the following example, the node URL describes
-// a node with IP address 10.3.58.6, TCP listening port 30303
-// and UDP discovery port 30301.
-//
-//    enode://<hex node id>@10.3.58.6:30303?discport=30301
-func ParseNode(rawurl string) (*Node, error) {
- if m := incompleteNodeURL.FindStringSubmatch(rawurl); m != nil {
- id, err := HexID(m[1])
- if err != nil {
- return nil, fmt.Errorf("invalid node ID (%v)", err)
- }
- return NewNode(id, nil, 0, 0), nil
- }
- return parseComplete(rawurl)
+type node struct {
+ enode.Node
+ addedAt time.Time // time when the node was added to the table
}
-func parseComplete(rawurl string) (*Node, error) {
- var (
- id NodeID
- ip net.IP
- tcpPort, udpPort uint64
- )
- u, err := url.Parse(rawurl)
- if err != nil {
- return nil, err
- }
- if u.Scheme != "enode" {
- return nil, errors.New("invalid URL scheme, want \"enode\"")
- }
- // Parse the Node ID from the user portion.
- if u.User == nil {
- return nil, errors.New("does not contain node ID")
- }
- if id, err = HexID(u.User.String()); err != nil {
- return nil, fmt.Errorf("invalid node ID (%v)", err)
- }
- // Parse the IP address.
- host, port, err := net.SplitHostPort(u.Host)
- if err != nil {
- return nil, fmt.Errorf("invalid host: %v", err)
- }
- if ip = net.ParseIP(host); ip == nil {
- return nil, errors.New("invalid IP address")
- }
- // Ensure the IP is 4 bytes long for IPv4 addresses.
- if ipv4 := ip.To4(); ipv4 != nil {
- ip = ipv4
- }
- // Parse the port numbers.
- if tcpPort, err = strconv.ParseUint(port, 10, 16); err != nil {
- return nil, errors.New("invalid port")
- }
- udpPort = tcpPort
- qv := u.Query()
- if qv.Get("discport") != "" {
- udpPort, err = strconv.ParseUint(qv.Get("discport"), 10, 16)
- if err != nil {
- return nil, errors.New("invalid discport in query")
- }
- }
- return NewNode(id, ip, uint16(udpPort), uint16(tcpPort)), nil
-}
-
-// MustParseNode parses a node URL. It panics if the URL is not valid.
-func MustParseNode(rawurl string) *Node {
- n, err := ParseNode(rawurl)
- if err != nil {
- panic("invalid node URL: " + err.Error())
- }
- return n
-}
+type encPubkey [64]byte
-// MarshalText implements encoding.TextMarshaler.
-func (n *Node) MarshalText() ([]byte, error) {
- return []byte(n.String()), nil
+func encodePubkey(key *ecdsa.PublicKey) encPubkey {
+ var e encPubkey
+ math.ReadBits(key.X, e[:len(e)/2])
+ math.ReadBits(key.Y, e[len(e)/2:])
+ return e
}
-// UnmarshalText implements encoding.TextUnmarshaler.
-func (n *Node) UnmarshalText(text []byte) error {
- dec, err := ParseNode(string(text))
- if err == nil {
- *n = *dec
- }
- return err
-}
-
-// NodeID is a unique identifier for each node.
-// The node identifier is a marshaled elliptic curve public key.
-type NodeID [NodeIDBits / 8]byte
-
-// Bytes returns a byte slice representation of the NodeID
-func (n NodeID) Bytes() []byte {
- return n[:]
-}
-
-// NodeID prints as a long hexadecimal number.
-func (n NodeID) String() string {
- return fmt.Sprintf("%x", n[:])
-}
-
-// The Go syntax representation of a NodeID is a call to HexID.
-func (n NodeID) GoString() string {
- return fmt.Sprintf("discover.HexID(\"%x\")", n[:])
-}
-
-// TerminalString returns a shortened hex string for terminal logging.
-func (n NodeID) TerminalString() string {
- return hex.EncodeToString(n[:8])
-}
-
-// MarshalText implements the encoding.TextMarshaler interface.
-func (n NodeID) MarshalText() ([]byte, error) {
- return []byte(hex.EncodeToString(n[:])), nil
-}
-
-// UnmarshalText implements the encoding.TextUnmarshaler interface.
-func (n *NodeID) UnmarshalText(text []byte) error {
- id, err := HexID(string(text))
- if err != nil {
- return err
- }
- *n = id
- return nil
-}
-
-// BytesID converts a byte slice to a NodeID
-func BytesID(b []byte) (NodeID, error) {
- var id NodeID
- if len(b) != len(id) {
- return id, fmt.Errorf("wrong length, want %d bytes", len(id))
- }
- copy(id[:], b)
- return id, nil
-}
-
-// MustBytesID converts a byte slice to a NodeID.
-// It panics if the byte slice is not a valid NodeID.
-func MustBytesID(b []byte) NodeID {
- id, err := BytesID(b)
- if err != nil {
- panic(err)
+func decodePubkey(e encPubkey) (*ecdsa.PublicKey, error) {
+ p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)}
+ half := len(e) / 2
+ p.X.SetBytes(e[:half])
+ p.Y.SetBytes(e[half:])
+ if !p.Curve.IsOnCurve(p.X, p.Y) {
+ return nil, errors.New("invalid secp256k1 curve point")
}
- return id
+ return p, nil
}
-// HexID converts a hex string to a NodeID.
-// The string may be prefixed with 0x.
-func HexID(in string) (NodeID, error) {
- var id NodeID
- b, err := hex.DecodeString(strings.TrimPrefix(in, "0x"))
- if err != nil {
- return id, err
- } else if len(b) != len(id) {
- return id, fmt.Errorf("wrong length, want %d hex chars", len(id)*2)
- }
- copy(id[:], b)
- return id, nil
+func (e encPubkey) id() enode.ID {
+ return enode.ID(crypto.Keccak256Hash(e[:]))
}
-// MustHexID converts a hex string to a NodeID.
-// It panics if the string is not a valid NodeID.
-func MustHexID(in string) NodeID {
- id, err := HexID(in)
+// recoverNodeKey computes the public key used to sign the
+// given hash from the signature.
+func recoverNodeKey(hash, sig []byte) (key encPubkey, err error) {
+ pubkey, err := secp256k1.RecoverPubkey(hash, sig)
if err != nil {
- panic(err)
+ return key, err
}
- return id
+ copy(key[:], pubkey[1:])
+ return key, nil
}
-// PubkeyID returns a marshaled representation of the given public key.
-func PubkeyID(pub *ecdsa.PublicKey) NodeID {
- var id NodeID
- pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
- if len(pbytes)-1 != len(id) {
- panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
- }
- copy(id[:], pbytes[1:])
- return id
+func wrapNode(n *enode.Node) *node {
+ return &node{Node: *n}
}
-// Pubkey returns the public key represented by the node ID.
-// It returns an error if the ID is not a point on the curve.
-func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) {
- p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)}
- half := len(id) / 2
- p.X.SetBytes(id[:half])
- p.Y.SetBytes(id[half:])
- if !p.Curve.IsOnCurve(p.X, p.Y) {
- return nil, errors.New("id is invalid secp256k1 curve point")
+func wrapNodes(ns []*enode.Node) []*node {
+ result := make([]*node, len(ns))
+ for i, n := range ns {
+ result[i] = wrapNode(n)
}
- return p, nil
+ return result
}
-// recoverNodeID computes the public key used to sign the
-// given hash from the signature.
-func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
- pubkey, err := secp256k1.RecoverPubkey(hash, sig)
- if err != nil {
- return id, err
- }
- if len(pubkey)-1 != len(id) {
- return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
- }
- for i := range id {
- id[i] = pubkey[i+1]
- }
- return id, nil
+func unwrapNode(n *node) *enode.Node {
+ return &n.Node
}
-// distcmp compares the distances a->target and b->target.
-// Returns -1 if a is closer to target, 1 if b is closer to target
-// and 0 if they are equal.
-func distcmp(target, a, b common.Hash) int {
- for i := range target {
- da := a[i] ^ target[i]
- db := b[i] ^ target[i]
- if da > db {
- return 1
- } else if da < db {
- return -1
- }
+func unwrapNodes(ns []*node) []*enode.Node {
+ result := make([]*enode.Node, len(ns))
+ for i, n := range ns {
+ result[i] = unwrapNode(n)
}
- return 0
+ return result
}
-// table of leading zero counts for bytes [0..255]
-var lzcount = [256]int{
- 8, 7, 6, 6, 5, 5, 5, 5,
- 4, 4, 4, 4, 4, 4, 4, 4,
- 3, 3, 3, 3, 3, 3, 3, 3,
- 3, 3, 3, 3, 3, 3, 3, 3,
- 2, 2, 2, 2, 2, 2, 2, 2,
- 2, 2, 2, 2, 2, 2, 2, 2,
- 2, 2, 2, 2, 2, 2, 2, 2,
- 2, 2, 2, 2, 2, 2, 2, 2,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
+func (n *node) addr() *net.UDPAddr {
+ return &net.UDPAddr{IP: n.IP(), Port: n.UDP()}
}
-// logdist returns the logarithmic distance between a and b, log2(a ^ b).
-func logdist(a, b common.Hash) int {
- lz := 0
- for i := range a {
- x := a[i] ^ b[i]
- if x == 0 {
- lz += 8
- } else {
- lz += lzcount[x]
- break
- }
- }
- return len(a)*8 - lz
-}
-
-// hashAtDistance returns a random hash such that logdist(a, b) == n
-func hashAtDistance(a common.Hash, n int) (b common.Hash) {
- if n == 0 {
- return a
- }
- // flip bit at position n, fill the rest with random bits
- b = a
- pos := len(a) - n/8 - 1
- bit := byte(0x01) << (byte(n%8) - 1)
- if bit == 0 {
- pos++
- bit = 0x80
- }
- b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
- for i := pos + 1; i < len(a); i++ {
- b[i] = byte(rand.Intn(255))
- }
- return b
+func (n *node) String() string {
+ return n.Node.String()
}
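
The rewritten node.go above replaces the 512-bit NodeID with a thin encPubkey wrapper: the X||Y coordinates of the secp256k1 point, whose Keccak-256 hash is the node's enode.ID. A round-trip sketch under those assumptions (standalone main for illustration; package paths and the encode/decode logic follow the diff):

package main

import (
	"crypto/ecdsa"
	"crypto/rand"
	"errors"
	"fmt"
	"math/big"

	"github.com/tomochain/tomochain/common/math"
	"github.com/tomochain/tomochain/crypto"
)

type encPubkey [64]byte

func encodePubkey(key *ecdsa.PublicKey) encPubkey {
	var e encPubkey
	math.ReadBits(key.X, e[:len(e)/2]) // big-endian X in the first 32 bytes
	math.ReadBits(key.Y, e[len(e)/2:]) // big-endian Y in the last 32 bytes
	return e
}

func decodePubkey(e encPubkey) (*ecdsa.PublicKey, error) {
	p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)}
	p.X.SetBytes(e[:32])
	p.Y.SetBytes(e[32:])
	if !p.Curve.IsOnCurve(p.X, p.Y) {
		return nil, errors.New("invalid secp256k1 curve point")
	}
	return p, nil
}

func main() {
	key, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader)
	e := encodePubkey(&key.PublicKey)
	back, _ := decodePubkey(e)
	fmt.Println(back.X.Cmp(key.PublicKey.X) == 0) // true
	fmt.Printf("node id: %x\n", crypto.Keccak256Hash(e[:]))
}
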
diff --git a/p2p/discover/node_test.go b/p2p/discover/node_test.go
deleted file mode 100644
index 8e3da2c2a..000000000
--- a/p2p/discover/node_test.go
+++ /dev/null
@@ -1,335 +0,0 @@
-// Copyright 2015 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package discover
-
-import (
- "bytes"
- "fmt"
- "math/big"
- "math/rand"
- "net"
- "reflect"
- "strings"
- "testing"
- "testing/quick"
- "time"
-
- "github.com/tomochain/tomochain/common"
- "github.com/tomochain/tomochain/crypto"
-)
-
-func ExampleNewNode() {
- id := MustHexID("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439")
-
- // Complete nodes contain UDP and TCP endpoints:
- n1 := NewNode(id, net.ParseIP("2001:db8:3c4d:15::abcd:ef12"), 52150, 30303)
- fmt.Println("n1:", n1)
- fmt.Println("n1.Incomplete() ->", n1.Incomplete())
-
- // An incomplete node can be created by passing zero values
- // for all parameters except id.
- n2 := NewNode(id, nil, 0, 0)
- fmt.Println("n2:", n2)
- fmt.Println("n2.Incomplete() ->", n2.Incomplete())
-
- // Output:
- // n1: enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[2001:db8:3c4d:15::abcd:ef12]:30303?discport=52150
- // n1.Incomplete() -> false
- // n2: enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439
- // n2.Incomplete() -> true
-}
-
-var parseNodeTests = []struct {
- rawurl string
- wantError string
- wantResult *Node
-}{
- {
- rawurl: "http://foobar",
- wantError: `invalid URL scheme, want "enode"`,
- },
- {
- rawurl: "enode://01010101@123.124.125.126:3",
- wantError: `invalid node ID (wrong length, want 128 hex chars)`,
- },
- // Complete nodes with IP address.
- {
- rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@hostname:3",
- wantError: `invalid IP address`,
- },
- //{
- // rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo",
- // wantError: `parse enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo: invalid port ":foo" after host`,
- //},
- {
- rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo",
- wantError: `invalid discport in query`,
- },
- {
- rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150",
- wantResult: NewNode(
- MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- net.IP{0x7f, 0x0, 0x0, 0x1},
- 52150,
- 52150,
- ),
- },
- {
- rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150",
- wantResult: NewNode(
- MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- net.ParseIP("::"),
- 52150,
- 52150,
- ),
- },
- {
- rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[2001:db8:3c4d:15::abcd:ef12]:52150",
- wantResult: NewNode(
- MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- net.ParseIP("2001:db8:3c4d:15::abcd:ef12"),
- 52150,
- 52150,
- ),
- },
- {
- rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=22334",
- wantResult: NewNode(
- MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- net.IP{0x7f, 0x0, 0x0, 0x1},
- 22334,
- 52150,
- ),
- },
- // Incomplete nodes with no address.
- {
- rawurl: "1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439",
- wantResult: NewNode(
- MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- nil, 0, 0,
- ),
- },
- {
- rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439",
- wantResult: NewNode(
- MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- nil, 0, 0,
- ),
- },
- // Invalid URLs
- {
- rawurl: "01010101",
- wantError: `invalid node ID (wrong length, want 128 hex chars)`,
- },
- {
- rawurl: "enode://01010101",
- wantError: `invalid node ID (wrong length, want 128 hex chars)`,
- },
- {
- // This test checks that errors from url.Parse are handled.
- rawurl: "://foo",
- wantError: `parse ://foo: missing protocol scheme`,
- },
-}
-
-func TestParseNode(t *testing.T) {
- for _, test := range parseNodeTests {
- n, err := ParseNode(test.rawurl)
- if test.wantError != "" {
- if err == nil {
- t.Errorf("test %q:\n got nil error, expected %#q", test.rawurl, test.wantError)
- continue
- } else if err.Error() != test.wantError {
- t.Errorf("test %q:\n got error %#q, expected %#q", test.rawurl, err.Error(), test.wantError)
- continue
- }
- } else {
- if err != nil {
- t.Errorf("test %q:\n unexpected error: %v", test.rawurl, err)
- continue
- }
- if !reflect.DeepEqual(n, test.wantResult) {
- t.Errorf("test %q:\n result mismatch:\ngot: %#v, want: %#v", test.rawurl, n, test.wantResult)
- }
- }
- }
-}
-
-func TestNodeString(t *testing.T) {
- for i, test := range parseNodeTests {
- if test.wantError == "" && strings.HasPrefix(test.rawurl, "enode://") {
- str := test.wantResult.String()
- if str != test.rawurl {
- t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.rawurl)
- }
- }
- }
-}
-
-func TestHexID(t *testing.T) {
- ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
- id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
- id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
-
- if id1 != ref {
- t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:])
- }
- if id2 != ref {
- t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:])
- }
-}
-
-func TestNodeID_textEncoding(t *testing.T) {
- ref := NodeID{
- 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x10,
- 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x20,
- 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x30,
- 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x40,
- 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x50,
- 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x60,
- 0x61, 0x62, 0x63, 0x64,
- }
- hex := "01020304050607080910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364"
-
- text, err := ref.MarshalText()
- if err != nil {
- t.Fatal(err)
- }
- if !bytes.Equal(text, []byte(hex)) {
- t.Fatalf("text encoding did not match\nexpected: %s\ngot: %s", hex, text)
- }
-
- id := new(NodeID)
- if err := id.UnmarshalText(text); err != nil {
- t.Fatal(err)
- }
- if *id != ref {
- t.Fatalf("text decoding did not match\nexpected: %s\ngot: %s", ref, id)
- }
-}
-
-func TestNodeID_recover(t *testing.T) {
- prv := newkey()
- hash := make([]byte, 32)
- sig, err := crypto.Sign(hash, prv)
- if err != nil {
- t.Fatalf("signing error: %v", err)
- }
-
- pub := PubkeyID(&prv.PublicKey)
- recpub, err := recoverNodeID(hash, sig)
- if err != nil {
- t.Fatalf("recovery error: %v", err)
- }
- if pub != recpub {
- t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub)
- }
-
- ecdsa, err := pub.Pubkey()
- if err != nil {
- t.Errorf("Pubkey error: %v", err)
- }
- if !reflect.DeepEqual(ecdsa, &prv.PublicKey) {
- t.Errorf("Pubkey mismatch:\n got: %#v\n want: %#v", ecdsa, &prv.PublicKey)
- }
-}
-
-func TestNodeID_pubkeyBad(t *testing.T) {
- ecdsa, err := NodeID{}.Pubkey()
- if err == nil {
- t.Error("expected error for zero ID")
- }
- if ecdsa != nil {
- t.Error("expected nil result")
- }
-}
-
-func TestNodeID_distcmp(t *testing.T) {
- distcmpBig := func(target, a, b common.Hash) int {
- tbig := new(big.Int).SetBytes(target[:])
- abig := new(big.Int).SetBytes(a[:])
- bbig := new(big.Int).SetBytes(b[:])
- return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
- }
- if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg()); err != nil {
- t.Error(err)
- }
-}
-
-// the random tests is likely to miss the case where they're equal.
-func TestNodeID_distcmpEqual(t *testing.T) {
- base := common.Hash{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
- x := common.Hash{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
- if distcmp(base, x, x) != 0 {
- t.Errorf("distcmp(base, x, x) != 0")
- }
-}
-
-func TestNodeID_logdist(t *testing.T) {
- logdistBig := func(a, b common.Hash) int {
- abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
- return new(big.Int).Xor(abig, bbig).BitLen()
- }
- if err := quick.CheckEqual(logdist, logdistBig, quickcfg()); err != nil {
- t.Error(err)
- }
-}
-
-// the random tests is likely to miss the case where they're equal.
-func TestNodeID_logdistEqual(t *testing.T) {
- x := common.Hash{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
- if logdist(x, x) != 0 {
- t.Errorf("logdist(x, x) != 0")
- }
-}
-
-func TestNodeID_hashAtDistance(t *testing.T) {
- // we don't use quick.Check here because its output isn't
- // very helpful when the test fails.
- cfg := quickcfg()
- for i := 0; i < cfg.MaxCount; i++ {
- a := gen(common.Hash{}, cfg.Rand).(common.Hash)
- dist := cfg.Rand.Intn(len(common.Hash{}) * 8)
- result := hashAtDistance(a, dist)
- actualdist := logdist(result, a)
-
- if dist != actualdist {
- t.Log("a: ", a)
- t.Log("result:", result)
- t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist)
- }
- }
-}
-
-func quickcfg() *quick.Config {
- return &quick.Config{
- MaxCount: 5000,
- Rand: rand.New(rand.NewSource(time.Now().Unix())),
- }
-}
-
-// TODO: The Generate method can be dropped when we require Go >= 1.5
-// because testing/quick learned to generate arrays in 1.5.
-
-func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
- var id NodeID
- m := rand.Intn(len(id))
- for i := len(id) - 1; i > m; i-- {
- id[i] = byte(rand.Uint32())
- }
- return reflect.ValueOf(id)
-}
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index 6fdd2cfd1..729ab1d0e 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -23,9 +23,9 @@
package discover
import (
+ "crypto/ecdsa"
crand "crypto/rand"
"encoding/binary"
- "errors"
"fmt"
mrand "math/rand"
"net"
@@ -36,13 +36,14 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/log"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/netutil"
)
const (
- alpha = 3 // Kademlia concurrency factor
- bucketSize = 200 // Kademlia bucket size
- maxReplacements = 10 // Size of per-bucket replacement list
+ alpha = 3 // Kademlia concurrency factor
+ bucketSize = 16 // Kademlia bucket size
+ maxReplacements = 10 // Size of per-bucket replacement list
// We keep buckets for the upper 1/15 of distances because
// it's very unlikely we'll ever encounter a node that's closer.
@@ -54,76 +55,56 @@ const (
bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24
tableIPLimit, tableSubnet = 10, 24
- maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions
- maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped
-
- refreshInterval = 30 * time.Minute
- revalidateInterval = 10 * time.Second
- copyNodesInterval = 30 * time.Second
- seedMinTableTime = 5 * time.Minute
- seedCount = 30
- seedMaxAge = 5 * 24 * time.Hour
+ maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped
+ refreshInterval = 30 * time.Minute
+ revalidateInterval = 10 * time.Second
+ copyNodesInterval = 30 * time.Second
+ seedMinTableTime = 5 * time.Minute
+ seedCount = 30
+ seedMaxAge = 5 * 24 * time.Hour
)
type Table struct {
mutex sync.Mutex // protects buckets, bucket content, nursery, rand
buckets [nBuckets]*bucket // index of known nodes by distance
- nursery []*Node // bootstrap nodes
+ nursery []*node // bootstrap nodes
rand *mrand.Rand // source of randomness, periodically reseeded
ips netutil.DistinctNetSet
- db *nodeDB // database of known nodes
+ db *enode.DB // database of known nodes
refreshReq chan chan struct{}
initDone chan struct{}
closeReq chan struct{}
closed chan struct{}
- bondmu sync.Mutex
- bonding map[NodeID]*bondproc
- bondslots chan struct{} // limits total number of active bonding processes
-
- nodeAddedHook func(*Node) // for testing
+ nodeAddedHook func(*node) // for testing
net transport
- self *Node // metadata of the local node
-}
-
-type bondproc struct {
- err error
- n *Node
- done chan struct{}
+ self *node // metadata of the local node
}
// transport is implemented by the UDP transport.
// it is an interface so we can test without opening lots of UDP
// sockets and without generating a private key.
type transport interface {
- ping(NodeID, *net.UDPAddr) error
- waitping(NodeID) error
- findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error)
+ ping(enode.ID, *net.UDPAddr) error
+ findnode(toid enode.ID, addr *net.UDPAddr, target encPubkey) ([]*node, error)
close()
}
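
The transport interface above has shrunk from three methods to two: waitping disappears together with the whole bonding flow. A minimal sketch of a stub in the spirit of the new interface, mirroring the pingRecorder used by the table tests below (nodeID is a stand-in for enode.ID, and findnode, which would just return canned results, is omitted for brevity):

    package main

    import (
        "errors"
        "fmt"
        "net"
        "sync"
    )

    type nodeID [32]byte // stand-in for enode.ID

    // stubTransport records outgoing pings and can simulate dead peers.
    type stubTransport struct {
        mu           sync.Mutex
        dead, pinged map[nodeID]bool
    }

    func newStubTransport() *stubTransport {
        return &stubTransport{dead: make(map[nodeID]bool), pinged: make(map[nodeID]bool)}
    }

    func (t *stubTransport) ping(id nodeID, addr *net.UDPAddr) error {
        t.mu.Lock()
        defer t.mu.Unlock()
        t.pinged[id] = true
        if t.dead[id] {
            return errors.New("timeout") // a dead peer never answers with a pong
        }
        return nil
    }

    func (t *stubTransport) close() {}

    func main() {
        tr := newStubTransport()
        var id nodeID
        id[0] = 1
        tr.dead[id] = true
        fmt.Println(tr.ping(id, &net.UDPAddr{})) // prints "timeout"
    }
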
// bucket contains nodes, ordered by their last activity. the entry
// that was most recently active is the first element in entries.
type bucket struct {
- entries []*Node // live entries, sorted by time of last contact
- replacements []*Node // recently seen nodes to be used if revalidation fails
+ entries []*node // live entries, sorted by time of last contact
+ replacements []*node // recently seen nodes to be used if revalidation fails
ips netutil.DistinctNetSet
}
-func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) {
- // If no node database was given, use an in-memory one
- db, err := newNodeDB(nodeDBPath, Version, ourID)
- if err != nil {
- return nil, err
- }
+func newTable(t transport, self *enode.Node, db *enode.DB, bootnodes []*enode.Node) (*Table, error) {
tab := &Table{
net: t,
db: db,
- self: NewNode(ourID, ourAddr.IP, uint16(ourAddr.Port), uint16(ourAddr.Port)),
- bonding: make(map[NodeID]*bondproc),
- bondslots: make(chan struct{}, maxBondingPingPongs),
+ self: wrapNode(self),
refreshReq: make(chan chan struct{}),
initDone: make(chan struct{}),
closeReq: make(chan struct{}),
@@ -134,20 +115,14 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string
if err := tab.setFallbackNodes(bootnodes); err != nil {
return nil, err
}
- for i := 0; i < cap(tab.bondslots); i++ {
- tab.bondslots <- struct{}{}
- }
for i := range tab.buckets {
tab.buckets[i] = &bucket{
ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit},
}
}
tab.seedRand()
- tab.loadSeedNodes(false)
- // Start the background expiration goroutine after loading seeds so that the search for
- // seed nodes also considers older nodes that would otherwise be removed by the
- // expiration.
- tab.db.ensureExpirer()
+ tab.loadSeedNodes()
+
go tab.loop()
return tab, nil
}
@@ -162,15 +137,13 @@ func (tab *Table) seedRand() {
}
// Self returns the local node.
-// The returned node should not be modified by the caller.
-func (tab *Table) Self() *Node {
- return tab.self
+func (tab *Table) Self() *enode.Node {
+ return unwrapNode(tab.self)
}
-// ReadRandomNodes fills the given slice with random nodes from the
-// table. It will not write the same node more than once. The nodes in
-// the slice are copies and can be modified by the caller.
-func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
+// ReadRandomNodes fills the given slice with random nodes from the table. The results
+// are guaranteed to be unique for a single invocation; no node will appear twice.
+func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) {
if !tab.isInitDone() {
return 0
}
@@ -178,10 +151,10 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
defer tab.mutex.Unlock()
// Find all non-empty buckets and get a fresh slice of their entries.
- var buckets [][]*Node
- for _, b := range tab.buckets {
+ var buckets [][]*node
+ for _, b := range &tab.buckets {
if len(b.entries) > 0 {
- buckets = append(buckets, b.entries[:])
+ buckets = append(buckets, b.entries)
}
}
if len(buckets) == 0 {
@@ -196,7 +169,7 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
var i, j int
for ; i < len(buf); i, j = i+1, (j+1)%len(buckets) {
b := buckets[j]
- buf[i] = &(*b[0])
+ buf[i] = unwrapNode(b[0])
buckets[j] = b[1:]
if len(b) == 1 {
buckets = append(buckets[:j], buckets[j+1:]...)
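
The rewritten ReadRandomNodes above owes its uniqueness guarantee to a simple scheme: shuffle the non-empty buckets, then hand out the head of each bucket in round-robin order, dropping buckets as they run dry. A self-contained sketch of that draining step over plain string slices (the names are illustrative, not the package's):

    package main

    import (
        "fmt"
        "math/rand"
    )

    // drain takes the head of each bucket in turn, removing exhausted
    // buckets, until max entries are collected or nothing is left.
    func drain(buckets [][]string, max int) []string {
        if len(buckets) == 0 {
            return nil
        }
        rand.Shuffle(len(buckets), func(i, j int) {
            buckets[i], buckets[j] = buckets[j], buckets[i]
        })
        out := make([]string, 0, max)
        for i, j := 0, 0; i < max; i, j = i+1, (j+1)%len(buckets) {
            b := buckets[j]
            out = append(out, b[0]) // head of the current bucket
            buckets[j] = b[1:]
            if len(b) == 1 {
                buckets = append(buckets[:j], buckets[j+1:]...)
            }
            if len(buckets) == 0 {
                break
            }
        }
        return out
    }

    func main() {
        got := drain([][]string{{"a1", "a2"}, {"b1"}, {"c1", "c2"}}, 4)
        fmt.Println(got) // four distinct entries; no entry can repeat
    }
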
@@ -221,20 +194,13 @@ func (tab *Table) Close() {
// setFallbackNodes sets the initial points of contact. These nodes
// are used to connect to the network if the table is empty and there
// are no known nodes in the database.
-func (tab *Table) setFallbackNodes(nodes []*Node) error {
+func (tab *Table) setFallbackNodes(nodes []*enode.Node) error {
for _, n := range nodes {
- if err := n.validateComplete(); err != nil {
- return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err)
+ if err := n.ValidateComplete(); err != nil {
+ return fmt.Errorf("bad bootstrap node %q: %v", n, err)
}
}
- tab.nursery = make([]*Node, 0, len(nodes))
- for _, n := range nodes {
- cpy := *n
- // Recompute cpy.sha because the node might not have been
- // created by NewNode or ParseNode.
- cpy.sha = crypto.Keccak256Hash(n.ID[:])
- tab.nursery = append(tab.nursery, &cpy)
- }
+ tab.nursery = wrapNodes(nodes)
return nil
}
@@ -250,47 +216,48 @@ func (tab *Table) isInitDone() bool {
// Resolve searches for a specific node with the given ID.
// It returns nil if the node could not be found.
-func (tab *Table) Resolve(targetID NodeID) *Node {
+func (tab *Table) Resolve(n *enode.Node) *enode.Node {
// If the node is present in the local table, no
// network interaction is required.
- hash := crypto.Keccak256Hash(targetID[:])
+ hash := n.ID()
tab.mutex.Lock()
cl := tab.closest(hash, 1)
tab.mutex.Unlock()
- if len(cl.entries) > 0 && cl.entries[0].ID == targetID {
- return cl.entries[0]
+ if len(cl.entries) > 0 && cl.entries[0].ID() == hash {
+ return unwrapNode(cl.entries[0])
}
// Otherwise, do a network lookup.
- result := tab.Lookup(targetID)
+ result := tab.lookup(encodePubkey(n.Pubkey()), true)
for _, n := range result {
- if n.ID == targetID {
- return n
+ if n.ID() == hash {
+ return unwrapNode(n)
}
}
return nil
}
-// Lookup performs a network search for nodes close
-// to the given target. It approaches the target by querying
-// nodes that are closer to it on each iteration.
-// The given target does not need to be an actual node
-// identifier.
-func (tab *Table) Lookup(targetID NodeID) []*Node {
- return tab.lookup(targetID, true)
+// LookupRandom finds random nodes in the network.
+func (tab *Table) LookupRandom() []*enode.Node {
+ var target encPubkey
+ crand.Read(target[:])
+ return unwrapNodes(tab.lookup(target, true))
}
-func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node {
+// lookup performs a network search for nodes close to the given target. It approaches the
+// target by querying nodes that are closer to it on each iteration. The given target does
+// not need to be an actual node identifier.
+func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node {
var (
- target = crypto.Keccak256Hash(targetID[:])
- asked = make(map[NodeID]bool)
- seen = make(map[NodeID]bool)
- reply = make(chan []*Node, alpha)
+ target = enode.ID(crypto.Keccak256Hash(targetKey[:]))
+ asked = make(map[enode.ID]bool)
+ seen = make(map[enode.ID]bool)
+ reply = make(chan []*node, alpha)
pendingQueries = 0
result *nodesByDistance
)
// don't query further if we hit ourself.
// unlikely to happen often in practice.
- asked[tab.self.ID] = true
+ asked[tab.self.ID()] = true
for {
tab.mutex.Lock()
@@ -312,25 +279,10 @@ func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node {
// ask the alpha closest nodes that we haven't asked yet
for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
n := result.entries[i]
- if !asked[n.ID] {
- asked[n.ID] = true
+ if !asked[n.ID()] {
+ asked[n.ID()] = true
pendingQueries++
- go func() {
- // Find potential neighbors to bond with
- r, err := tab.net.findnode(n.ID, n.addr(), targetID)
- if err != nil {
- // Bump the failure counter to detect and evacuate non-bonded entries
- fails := tab.db.findFails(n.ID) + 1
- tab.db.updateFindFails(n.ID, fails)
- log.Trace("Bumping findnode failure counter", "id", n.ID, "failcount", fails)
-
- if fails >= maxFindnodeFailures {
- log.Trace("Too many findnode failures, dropping", "id", n.ID, "failcount", fails)
- tab.delete(n)
- }
- }
- reply <- tab.bondall(r)
- }()
+ go tab.findnode(n, targetKey, reply)
}
}
if pendingQueries == 0 {
@@ -339,8 +291,8 @@ func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node {
}
// wait for the next reply
for _, n := range <-reply {
- if n != nil && !seen[n.ID] {
- seen[n.ID] = true
+ if n != nil && !seen[n.ID()] {
+ seen[n.ID()] = true
result.push(n, bucketSize)
}
}
@@ -349,6 +301,29 @@ func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node {
return result.entries
}
+func (tab *Table) findnode(n *node, targetKey encPubkey, reply chan<- []*node) {
+ fails := tab.db.FindFails(n.ID(), n.IP())
+ r, err := tab.net.findnode(n.ID(), n.addr(), targetKey)
+ if err != nil || len(r) == 0 {
+ fails++
+ tab.db.UpdateFindFails(n.ID(), n.IP(), fails)
+ log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err)
+ if fails >= maxFindnodeFailures {
+ log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails)
+ tab.delete(n)
+ }
+ } else if fails > 0 {
+ tab.db.UpdateFindFails(n.ID(), n.IP(), fails-1)
+ }
+
+ // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll
+ // just remove those again during revalidation.
+ for _, n := range r {
+ tab.add(n)
+ }
+ reply <- r
+}
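
The extracted findnode helper above makes the failure accounting symmetric: a failed query, or one returning zero nodes, bumps the per-node counter toward eviction, while a success decays it by one. A runnable sketch of just that bookkeeping (peer is a hypothetical stand-in for the node record):

    package main

    import "fmt"

    const maxFindnodeFailures = 5 // same threshold as above

    type peer struct {
        fails   int
        dropped bool
    }

    // recordFindnode applies the bump-on-failure / decay-on-success rule.
    func recordFindnode(p *peer, ok bool) {
        if !ok {
            p.fails++
            if p.fails >= maxFindnodeFailures {
                p.dropped = true // stands in for tab.delete(n)
            }
        } else if p.fails > 0 {
            p.fails-- // successful queries slowly forgive past failures
        }
    }

    func main() {
        p := &peer{}
        for i := 0; i < maxFindnodeFailures; i++ {
            recordFindnode(p, false)
        }
        fmt.Println(p.fails, p.dropped) // 5 true
    }
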
+
func (tab *Table) refresh() <-chan struct{} {
done := make(chan struct{})
select {
@@ -401,7 +376,7 @@ loop:
case <-revalidateDone:
revalidate.Reset(tab.nextRevalidateTime())
case <-copyNodes.C:
- go tab.copyBondedNodes()
+ go tab.copyLiveNodes()
case <-tab.closeReq:
break loop
}
@@ -416,7 +391,6 @@ loop:
for _, ch := range waiting {
close(ch)
}
- tab.db.close()
close(tab.closed)
}
@@ -429,10 +403,14 @@ func (tab *Table) doRefresh(done chan struct{}) {
// Load nodes from the database and insert
// them. This should yield a few previously seen nodes that are
// (hopefully) still alive.
- tab.loadSeedNodes(true)
+ tab.loadSeedNodes()
// Run self lookup to discover new neighbor nodes.
- tab.lookup(tab.self.ID, false)
+ // We can only do this if we have a secp256k1 identity.
+ var key ecdsa.PublicKey
+ if err := tab.self.Load((*enode.Secp256k1)(&key)); err == nil {
+ tab.lookup(encodePubkey(&key), false)
+ }
// The Kademlia paper specifies that the bucket refresh should
// perform a lookup in the least recently used bucket. We cannot
@@ -441,22 +419,19 @@ func (tab *Table) doRefresh(done chan struct{}) {
// sha3 preimage that falls into a chosen bucket.
// We perform a few lookups with a random target instead.
for i := 0; i < 3; i++ {
- var target NodeID
+ var target encPubkey
crand.Read(target[:])
tab.lookup(target, false)
}
}
-func (tab *Table) loadSeedNodes(bond bool) {
- seeds := tab.db.querySeeds(seedCount, seedMaxAge)
+func (tab *Table) loadSeedNodes() {
+ seeds := wrapNodes(tab.db.QuerySeeds(seedCount, seedMaxAge))
seeds = append(seeds, tab.nursery...)
- if bond {
- seeds = tab.bondall(seeds)
- }
for i := range seeds {
seed := seeds[i]
- age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.bondTime(seed.ID)) }}
- log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age)
+ age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) }}
+ log.Debug("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age)
tab.add(seed)
}
}
@@ -473,28 +448,28 @@ func (tab *Table) doRevalidate(done chan<- struct{}) {
}
// Ping the selected node and wait for a pong.
- err := tab.ping(last.ID, last.addr())
+ err := tab.net.ping(last.ID(), last.addr())
tab.mutex.Lock()
defer tab.mutex.Unlock()
b := tab.buckets[bi]
if err == nil {
// The node responded, move it to the front.
- log.Debug("Revalidated node", "b", bi, "id", last.ID)
+ log.Debug("Revalidated node", "b", bi, "id", last.ID())
b.bump(last)
return
}
// No reply received, pick a replacement or delete the node if there aren't
// any replacements.
if r := tab.replace(b, last); r != nil {
- log.Debug("Replaced dead node", "b", bi, "id", last.ID, "ip", last.IP, "r", r.ID, "rip", r.IP)
+ log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "r", r.ID(), "rip", r.IP())
} else {
- log.Debug("Removed dead node", "b", bi, "id", last.ID, "ip", last.IP)
+ log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP())
}
}
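
The revalidation outcome above boils down to bump-or-replace: a responsive node moves to the front of its bucket, while an unresponsive one is swapped for an entry from the replacement list, or dropped if none exists. A self-contained sketch of the two branches on a toy bucket (string entries instead of *node; the real code picks a random replacement, this sketch takes the first):

    package main

    import "fmt"

    type bucket struct {
        entries      []string // assumed non-empty, as doRevalidate guarantees
        replacements []string
    }

    func revalidate(b *bucket, alive bool) {
        last := len(b.entries) - 1
        if alive {
            // bump: rotate the responder to the front
            n := b.entries[last]
            copy(b.entries[1:], b.entries[:last])
            b.entries[0] = n
            return
        }
        if len(b.replacements) > 0 {
            // swap in a recently seen candidate for the dead node
            b.entries[last] = b.replacements[0]
            b.replacements = b.replacements[1:]
        } else {
            b.entries = b.entries[:last] // no candidate: drop it
        }
    }

    func main() {
        b := &bucket{entries: []string{"a", "b", "c"}, replacements: []string{"r"}}
        revalidate(b, false)
        fmt.Println(b.entries) // [a b r]
    }
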
// nodeToRevalidate returns the last node in a random, non-empty bucket.
-func (tab *Table) nodeToRevalidate() (n *Node, bi int) {
+func (tab *Table) nodeToRevalidate() (n *node, bi int) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
@@ -515,17 +490,17 @@ func (tab *Table) nextRevalidateTime() time.Duration {
return time.Duration(tab.rand.Int63n(int64(revalidateInterval)))
}
-// copyBondedNodes adds nodes from the table to the database if they have been in the table
+// copyLiveNodes adds nodes from the table to the database if they have been in the table
// longer than seedMinTableTime.
-func (tab *Table) copyBondedNodes() {
+func (tab *Table) copyLiveNodes() {
tab.mutex.Lock()
defer tab.mutex.Unlock()
now := time.Now()
- for _, b := range tab.buckets {
+ for _, b := range &tab.buckets {
for _, n := range b.entries {
if now.Sub(n.addedAt) >= seedMinTableTime {
- tab.db.updateNode(n)
+ tab.db.UpdateNode(unwrapNode(n))
}
}
}
@@ -533,12 +508,12 @@ func (tab *Table) copyBondedNodes() {
// closest returns the n nodes in the table that are closest to the
// given id. The caller must hold tab.mutex.
-func (tab *Table) closest(target common.Hash, nresults int) *nodesByDistance {
+func (tab *Table) closest(target enode.ID, nresults int) *nodesByDistance {
// This is a very wasteful way to find the closest nodes but
// obviously correct. I believe that tree-based buckets would make
// this easier to implement efficiently.
close := &nodesByDistance{target: target}
- for _, b := range tab.buckets {
+ for _, b := range &tab.buckets {
for _, n := range b.entries {
close.push(n, nresults)
}
@@ -547,176 +522,76 @@ func (tab *Table) closest(target common.Hash, nresults int) *nodesByDistance {
}
func (tab *Table) len() (n int) {
- for _, b := range tab.buckets {
+ for _, b := range &tab.buckets {
n += len(b.entries)
}
return n
}
-// bondall bonds with all given nodes concurrently and returns
-// those nodes for which bonding has probably succeeded.
-func (tab *Table) bondall(nodes []*Node) (result []*Node) {
- rc := make(chan *Node, len(nodes))
- for i := range nodes {
- go func(n *Node) {
- nn, _ := tab.bond(false, n.ID, n.addr(), n.TCP)
- rc <- nn
- }(nodes[i])
- }
- for range nodes {
- if n := <-rc; n != nil {
- result = append(result, n)
- }
- }
- return result
-}
-
-// bond ensures the local node has a bond with the given remote node.
-// It also attempts to insert the node into the table if bonding succeeds.
-// The caller must not hold tab.mutex.
-//
-// A bond is must be established before sending findnode requests.
-// Both sides must have completed a ping/pong exchange for a bond to
-// exist. The total number of active bonding processes is limited in
-// order to restrain network use.
-//
-// bond is meant to operate idempotently in that bonding with a remote
-// node which still remembers a previously established bond will work.
-// The remote node will simply not send a ping back, causing waitping
-// to time out.
-//
-// If pinged is true, the remote node has just pinged us and one half
-// of the process can be skipped.
-func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) {
- if id == tab.self.ID {
- return nil, errors.New("is self")
- }
- if pinged && !tab.isInitDone() {
- return nil, errors.New("still initializing")
- }
- // Start bonding if we haven't seen this node for a while or if it failed findnode too often.
- node, fails := tab.db.node(id), tab.db.findFails(id)
- age := time.Since(tab.db.bondTime(id))
- var result error
- if fails > 0 || age > nodeDBNodeExpiration {
- log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age)
-
- tab.bondmu.Lock()
- w := tab.bonding[id]
- if w != nil {
- // Wait for an existing bonding process to complete.
- tab.bondmu.Unlock()
- <-w.done
- } else {
- // Register a new bonding process.
- w = &bondproc{done: make(chan struct{})}
- tab.bonding[id] = w
- tab.bondmu.Unlock()
- // Do the ping/pong. The result goes into w.
- tab.pingpong(w, pinged, id, addr, tcpPort)
- // Unregister the process after it's done.
- tab.bondmu.Lock()
- delete(tab.bonding, id)
- tab.bondmu.Unlock()
- }
- // Retrieve the bonding results
- result = w.err
- if result == nil {
- node = w.n
- }
- }
- // Add the node to the table even if the bonding ping/pong
- // fails. It will be relaced quickly if it continues to be
- // unresponsive.
- if node != nil {
- tab.add(node)
- tab.db.updateFindFails(id, 0)
- }
- return node, result
-}
-
-func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) {
- // Request a bonding slot to limit network usage
- <-tab.bondslots
- defer func() { tab.bondslots <- struct{}{} }()
-
- // Ping the remote side and wait for a pong.
- if w.err = tab.ping(id, addr); w.err != nil {
- close(w.done)
- return
- }
- if !pinged {
- // Give the remote node a chance to ping us before we start
- // sending findnode requests. If they still remember us,
- // waitping will simply time out.
- tab.net.waitping(id)
- }
- // Bonding succeeded, update the node database.
- w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort)
- close(w.done)
-}
-
-// ping a remote endpoint and wait for a reply, also updating the node
-// database accordingly.
-func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
- tab.db.updateLastPing(id, time.Now())
- if err := tab.net.ping(id, addr); err != nil {
- return err
- }
- tab.db.updateBondTime(id, time.Now())
- return nil
-}
-
// bucket returns the bucket for the given node ID hash.
-func (tab *Table) bucket(sha common.Hash) *bucket {
- d := logdist(tab.self.sha, sha)
+func (tab *Table) bucket(id enode.ID) *bucket {
+ d := enode.LogDist(tab.self.ID(), id)
if d <= bucketMinDistance {
return tab.buckets[0]
}
return tab.buckets[d-bucketMinDistance-1]
}
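
Bucket selection above is pure arithmetic on the logarithmic distance: everything at or below bucketMinDistance shares buckets[0], and each greater distance gets a bucket of its own. A worked sketch, assuming the constants match the ones derived earlier in table.go (hashBits = 256, nBuckets = 17):

    package main

    import "fmt"

    const (
        hashBits          = 32 * 8              // len(common.Hash{}) * 8
        nBuckets          = hashBits / 15       // 17
        bucketMinDistance = hashBits - nBuckets // 239
    )

    // bucketIndex maps a log distance to its bucket slot.
    func bucketIndex(logdist int) int {
        if logdist <= bucketMinDistance {
            return 0
        }
        return logdist - bucketMinDistance - 1
    }

    func main() {
        for _, d := range []int{0, 239, 240, 248, 256} {
            fmt.Printf("logdist %3d -> bucket %d\n", d, bucketIndex(d))
        }
        // logdist 240 still lands in bucket 0; 256 hits the last slot, 16.
    }
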
-// add attempts to add the given node its corresponding bucket. If the
-// bucket has space available, adding the node succeeds immediately.
-// Otherwise, the node is added if the least recently active node in
-// the bucket does not respond to a ping packet.
+// add attempts to add the given node to its corresponding bucket. If the bucket has space
+// available, adding the node succeeds immediately. Otherwise, the node is added if the
+// least recently active node in the bucket does not respond to a ping packet.
//
// The caller must not hold tab.mutex.
-func (tab *Table) add(new *Node) {
+func (tab *Table) add(n *node) {
+ if n.ID() == tab.self.ID() {
+ return
+ }
+
tab.mutex.Lock()
defer tab.mutex.Unlock()
-
- b := tab.bucket(new.sha)
- if !tab.bumpOrAdd(b, new) {
+ b := tab.bucket(n.ID())
+ if !tab.bumpOrAdd(b, n) {
// Node is not in table. Add it to the replacement list.
- tab.addReplacement(b, new)
+ tab.addReplacement(b, n)
+ }
+}
+
+// addThroughPing adds the given node to the table. Compared to plain
+// 'add' there is an additional safety measure: if the table is still
+// initializing, the node is not added. This prevents an attack where the
+// table could be filled by just sending pings repeatedly.
+//
+// The caller must not hold tab.mutex.
+func (tab *Table) addThroughPing(n *node) {
+ if !tab.isInitDone() {
+ return
}
+ tab.add(n)
}
// stuff adds the given nodes to the end of their corresponding bucket
// if the bucket is not full. The caller must not hold tab.mutex.
-func (tab *Table) stuff(nodes []*Node) {
+func (tab *Table) stuff(nodes []*node) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
for _, n := range nodes {
- if n.ID == tab.self.ID {
+ if n.ID() == tab.self.ID() {
continue // don't add self
}
- b := tab.bucket(n.sha)
+ b := tab.bucket(n.ID())
if len(b.entries) < bucketSize {
tab.bumpOrAdd(b, n)
}
}
}
-// delete removes an entry from the node table (used to evacuate
-// failed/non-bonded discovery peers).
-func (tab *Table) delete(node *Node) {
+// delete removes an entry from the node table. It is used to evacuate dead nodes.
+func (tab *Table) delete(node *node) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
- tab.deleteInBucket(tab.bucket(node.sha), node)
+ tab.deleteInBucket(tab.bucket(node.ID()), node)
}
func (tab *Table) addIP(b *bucket, ip net.IP) bool {
@@ -743,27 +618,27 @@ func (tab *Table) removeIP(b *bucket, ip net.IP) {
b.ips.Remove(ip)
}
-func (tab *Table) addReplacement(b *bucket, n *Node) {
+func (tab *Table) addReplacement(b *bucket, n *node) {
for _, e := range b.replacements {
- if e.ID == n.ID {
+ if e.ID() == n.ID() {
return // already in list
}
}
- if !tab.addIP(b, n.IP) {
+ if !tab.addIP(b, n.IP()) {
return
}
- var removed *Node
+ var removed *node
b.replacements, removed = pushNode(b.replacements, n, maxReplacements)
if removed != nil {
- tab.removeIP(b, removed.IP)
+ tab.removeIP(b, removed.IP())
}
}
// replace removes n from the replacement list and replaces 'last' with it if it is the
// last entry in the bucket. If 'last' isn't the last entry, it has either been replaced
// with someone else or has become active.
-func (tab *Table) replace(b *bucket, last *Node) *Node {
- if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID != last.ID {
+func (tab *Table) replace(b *bucket, last *node) *node {
+ if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID() != last.ID() {
// Entry has moved, don't replace it.
return nil
}
@@ -775,15 +650,15 @@ func (tab *Table) replace(b *bucket, last *Node) *Node {
r := b.replacements[tab.rand.Intn(len(b.replacements))]
b.replacements = deleteNode(b.replacements, r)
b.entries[len(b.entries)-1] = r
- tab.removeIP(b, last.IP)
+ tab.removeIP(b, last.IP())
return r
}
// bump moves the given node to the front of the bucket entry list
// if it is contained in that list.
-func (b *bucket) bump(n *Node) bool {
+func (b *bucket) bump(n *node) bool {
for i := range b.entries {
- if b.entries[i].ID == n.ID {
+ if b.entries[i].ID() == n.ID() {
// move it to the front
copy(b.entries[1:], b.entries[:i])
b.entries[0] = n
@@ -795,11 +670,11 @@ func (b *bucket) bump(n *Node) bool {
// bumpOrAdd moves n to the front of the bucket entry list or adds it if the list isn't
// full. The return value is true if n is in the bucket.
-func (tab *Table) bumpOrAdd(b *bucket, n *Node) bool {
+func (tab *Table) bumpOrAdd(b *bucket, n *node) bool {
if b.bump(n) {
return true
}
- if len(b.entries) >= bucketSize || !tab.addIP(b, n.IP) {
+ if len(b.entries) >= bucketSize || !tab.addIP(b, n.IP()) {
return false
}
b.entries, _ = pushNode(b.entries, n, bucketSize)
@@ -811,13 +686,13 @@ func (tab *Table) bumpOrAdd(b *bucket, n *Node) bool {
return true
}
-func (tab *Table) deleteInBucket(b *bucket, n *Node) {
+func (tab *Table) deleteInBucket(b *bucket, n *node) {
b.entries = deleteNode(b.entries, n)
- tab.removeIP(b, n.IP)
+ tab.removeIP(b, n.IP())
}
// pushNode adds n to the front of list, keeping at most max items.
-func pushNode(list []*Node, n *Node, max int) ([]*Node, *Node) {
+func pushNode(list []*node, n *node, max int) ([]*node, *node) {
if len(list) < max {
list = append(list, nil)
}
@@ -828,9 +703,9 @@ func pushNode(list []*Node, n *Node, max int) ([]*Node, *Node) {
}
// deleteNode removes n from list.
-func deleteNode(list []*Node, n *Node) []*Node {
+func deleteNode(list []*node, n *node) []*node {
for i := range list {
- if list[i].ID == n.ID {
+ if list[i].ID() == n.ID() {
return append(list[:i], list[i+1:]...)
}
}
@@ -840,14 +715,14 @@ func deleteNode(list []*Node, n *Node) []*Node {
// nodesByDistance is a list of nodes, ordered by
// distance to target.
type nodesByDistance struct {
- entries []*Node
- target common.Hash
+ entries []*node
+ target enode.ID
}
// push adds the given node to the list, keeping the total size below maxElems.
-func (h *nodesByDistance) push(n *Node, maxElems int) {
+func (h *nodesByDistance) push(n *node, maxElems int) {
ix := sort.Search(len(h.entries), func(i int) bool {
- return distcmp(h.target, h.entries[i].sha, n.sha) > 0
+ return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0
})
if len(h.entries) < maxElems {
h.entries = append(h.entries, n)
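
nodesByDistance.push, whose head the hunk above shows, keeps a distance-sorted list capped at maxElems: sort.Search finds the insertion point, the entry is appended while there is room, and the tail is shifted to slot it in. A self-contained sketch of the same insert over plain int distances:

    package main

    import (
        "fmt"
        "sort"
    )

    // pushByDistance mirrors nodesByDistance.push with ints standing in
    // for XOR distances: keep list ascending and capped at maxElems.
    func pushByDistance(list []int, d, maxElems int) []int {
        ix := sort.Search(len(list), func(i int) bool { return list[i] > d })
        if len(list) < maxElems {
            list = append(list, d)
        }
        if ix < len(list) {
            copy(list[ix+1:], list[ix:]) // shift; overwrites the appended slot
            list[ix] = d
        }
        return list
    }

    func main() {
        var l []int
        for _, d := range []int{5, 1, 9, 3, 7} {
            l = pushByDistance(l, d, 4)
        }
        fmt.Println(l) // [1 3 5 7]: 9 was evicted as the farthest entry
    }
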
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index b81b0bfde..388baf6da 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -20,7 +20,6 @@ import (
"crypto/ecdsa"
"fmt"
"math/rand"
- "sync"
"net"
"reflect"
@@ -28,8 +27,9 @@ import (
"testing/quick"
"time"
- "github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/p2p/enode"
+ "github.com/tomochain/tomochain/p2p/enr"
)
func TestTable_pingReplace(t *testing.T) {
@@ -49,30 +49,27 @@ func TestTable_pingReplace(t *testing.T) {
func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) {
transport := newPingRecorder()
- tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
+ tab, db := newTestTable(transport)
defer tab.Close()
+ defer db.Close()
- // Wait for init so bond is accepted.
<-tab.initDone
- // fill up the sender's bucket.
- pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
+ // Fill up the sender's bucket.
+ pingKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8")
+ pingSender := wrapNode(enode.NewV4(&pingKey.PublicKey, net.IP{}, 99, 99))
last := fillBucket(tab, pingSender)
- // this call to bond should replace the last node
- // in its bucket if the node is not responding.
- transport.dead[last.ID] = !lastInBucketIsResponding
- transport.dead[pingSender.ID] = !newNodeIsResponding
- tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
+ // Add the sender as if it just pinged us. Revalidate should replace the last node
+ // in its bucket if it is unresponsive. Revalidate again to ensure the replacement sticks.
+ transport.dead[last.ID()] = !lastInBucketIsResponding
+ transport.dead[pingSender.ID()] = !newNodeIsResponding
+ tab.add(pingSender)
+ tab.doRevalidate(make(chan struct{}, 1))
tab.doRevalidate(make(chan struct{}, 1))
- // first ping goes to sender (bonding pingback)
- if !transport.pinged[pingSender.ID] {
- t.Error("table did not ping back sender")
- }
- if !transport.pinged[last.ID] {
- // second ping goes to oldest node in bucket
- // to see whether it is still alive.
+ if !transport.pinged[last.ID()] {
+ // Oldest node in bucket is pinged to see whether it is still alive.
t.Error("table did not ping last node in bucket")
}
@@ -82,14 +79,14 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding
if !lastInBucketIsResponding && !newNodeIsResponding {
wantSize--
}
- if l := len(tab.bucket(pingSender.sha).entries); l != wantSize {
+ if l := len(tab.bucket(pingSender.ID()).entries); l != wantSize {
t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize)
}
- if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding {
+ if found := contains(tab.bucket(pingSender.ID()).entries, last.ID()); found != lastInBucketIsResponding {
t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding)
}
wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding
- if found := contains(tab.bucket(pingSender.sha).entries, pingSender.ID); found != wantNewEntry {
+ if found := contains(tab.bucket(pingSender.ID()).entries, pingSender.ID()); found != wantNewEntry {
t.Errorf("new entry found: %t, want: %t", found, wantNewEntry)
}
}
@@ -102,9 +99,9 @@ func TestBucket_bumpNoDuplicates(t *testing.T) {
Values: func(args []reflect.Value, rand *rand.Rand) {
// generate a random list of nodes. this will be the content of the bucket.
n := rand.Intn(bucketSize-1) + 1
- nodes := make([]*Node, n)
+ nodes := make([]*node, n)
for i := range nodes {
- nodes[i] = nodeAtDistance(common.Hash{}, 200)
+ nodes[i] = nodeAtDistance(enode.ID{}, 200, intIP(200))
}
args[0] = reflect.ValueOf(nodes)
// generate random bump positions.
@@ -116,8 +113,8 @@ func TestBucket_bumpNoDuplicates(t *testing.T) {
},
}
- prop := func(nodes []*Node, bumps []int) (ok bool) {
- b := &bucket{entries: make([]*Node, len(nodes))}
+ prop := func(nodes []*node, bumps []int) (ok bool) {
+ b := &bucket{entries: make([]*node, len(nodes))}
copy(b.entries, nodes)
for i, pos := range bumps {
b.bump(b.entries[pos])
@@ -139,12 +136,12 @@ func TestBucket_bumpNoDuplicates(t *testing.T) {
// This checks that the table-wide IP limit is applied correctly.
func TestTable_IPLimit(t *testing.T) {
transport := newPingRecorder()
- tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
+ tab, db := newTestTable(transport)
defer tab.Close()
+ defer db.Close()
for i := 0; i < tableIPLimit+1; i++ {
- n := nodeAtDistance(tab.self.sha, i)
- n.IP = net.IP{172, 0, 1, byte(i)}
+ n := nodeAtDistance(tab.self.ID(), i, net.IP{172, 0, 1, byte(i)})
tab.add(n)
}
if tab.len() > tableIPLimit {
@@ -152,16 +149,16 @@ func TestTable_IPLimit(t *testing.T) {
}
}
-// This checks that the table-wide IP limit is applied correctly.
+// This checks that the per-bucket IP limit is applied correctly.
func TestTable_BucketIPLimit(t *testing.T) {
transport := newPingRecorder()
- tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
+ tab, db := newTestTable(transport)
defer tab.Close()
+ defer db.Close()
d := 3
for i := 0; i < bucketIPLimit+1; i++ {
- n := nodeAtDistance(tab.self.sha, d)
- n.IP = net.IP{172, 0, 1, byte(i)}
+ n := nodeAtDistance(tab.self.ID(), d, net.IP{172, 0, 1, byte(i)})
tab.add(n)
}
if tab.len() > bucketIPLimit {
@@ -169,70 +166,18 @@ func TestTable_BucketIPLimit(t *testing.T) {
}
}
-// fillBucket inserts nodes into the given bucket until
-// it is full. The node's IDs dont correspond to their
-// hashes.
-func fillBucket(tab *Table, n *Node) (last *Node) {
- ld := logdist(tab.self.sha, n.sha)
- b := tab.bucket(n.sha)
- for len(b.entries) < bucketSize {
- b.entries = append(b.entries, nodeAtDistance(tab.self.sha, ld))
- }
- return b.entries[bucketSize-1]
-}
-
-// nodeAtDistance creates a node for which logdist(base, n.sha) == ld.
-// The node's ID does not correspond to n.sha.
-func nodeAtDistance(base common.Hash, ld int) (n *Node) {
- n = new(Node)
- n.sha = hashAtDistance(base, ld)
- n.IP = net.IP{byte(ld), 0, 2, byte(ld)}
- copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID
- return n
-}
-
-type pingRecorder struct {
- mu sync.Mutex
- dead, pinged map[NodeID]bool
-}
-
-func newPingRecorder() *pingRecorder {
- return &pingRecorder{
- dead: make(map[NodeID]bool),
- pinged: make(map[NodeID]bool),
- }
-}
-
-func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
- return nil, nil
-}
-func (t *pingRecorder) close() {}
-func (t *pingRecorder) waitping(from NodeID) error {
- return nil // remote always pings
-}
-func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
- t.mu.Lock()
- defer t.mu.Unlock()
-
- t.pinged[toid] = true
- if t.dead[toid] {
- return errTimeout
- } else {
- return nil
- }
-}
-
func TestTable_closest(t *testing.T) {
t.Parallel()
test := func(test *closeTest) bool {
// for any node table, Target and N
transport := newPingRecorder()
- tab, _ := newTable(transport, test.Self, &net.UDPAddr{}, "", nil)
+ tab, db := newTestTable(transport)
defer tab.Close()
+ defer db.Close()
tab.stuff(test.All)
- // check that doClosest(Target, N) returns nodes
+ // check that closest(Target, N) returns nodes
result := tab.closest(test.Target, test.N).entries
if hasDuplicates(result) {
t.Errorf("result contains duplicates")
@@ -258,15 +203,15 @@ func TestTable_closest(t *testing.T) {
// check that the result nodes have minimum distance to target.
for _, b := range tab.buckets {
for _, n := range b.entries {
- if contains(result, n.ID) {
+ if contains(result, n.ID()) {
continue // don't run the check below for nodes in result
}
- farthestResult := result[len(result)-1].sha
- if distcmp(test.Target, n.sha, farthestResult) < 0 {
+ farthestResult := result[len(result)-1].ID()
+ if enode.DistCmp(test.Target, n.ID(), farthestResult) < 0 {
t.Errorf("table contains node that is closer to target but it's not in result")
t.Logf(" Target: %v", test.Target)
t.Logf(" Farthest Result: %v", farthestResult)
- t.Logf(" ID: %v", n.ID)
+ t.Logf(" ID: %v", n.ID())
return false
}
}
@@ -283,25 +228,26 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
MaxCount: 200,
Rand: rand.New(rand.NewSource(time.Now().Unix())),
Values: func(args []reflect.Value, rand *rand.Rand) {
- args[0] = reflect.ValueOf(make([]*Node, rand.Intn(1000)))
+ args[0] = reflect.ValueOf(make([]*enode.Node, rand.Intn(1000)))
},
}
- test := func(buf []*Node) bool {
+ test := func(buf []*enode.Node) bool {
transport := newPingRecorder()
- tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
+ tab, db := newTestTable(transport)
defer tab.Close()
+ defer db.Close()
<-tab.initDone
for i := 0; i < len(buf); i++ {
ld := cfg.Rand.Intn(len(tab.buckets))
- tab.stuff([]*Node{nodeAtDistance(tab.self.sha, ld)})
+ tab.stuff([]*node{nodeAtDistance(tab.self.ID(), ld, intIP(ld))})
}
gotN := tab.ReadRandomNodes(buf)
if gotN != tab.len() {
t.Errorf("wrong number of nodes, got %d, want %d", gotN, tab.len())
return false
}
- if hasDuplicates(buf[:gotN]) {
+ if hasDuplicates(wrapNodes(buf[:gotN])) {
t.Errorf("result contains duplicates")
return false
}
@@ -313,302 +259,304 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
}
type closeTest struct {
- Self NodeID
- Target common.Hash
- All []*Node
+ Self enode.ID
+ Target enode.ID
+ All []*node
N int
}
func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
t := &closeTest{
- Self: gen(NodeID{}, rand).(NodeID),
- Target: gen(common.Hash{}, rand).(common.Hash),
+ Self: gen(enode.ID{}, rand).(enode.ID),
+ Target: gen(enode.ID{}, rand).(enode.ID),
N: rand.Intn(bucketSize),
}
- for _, id := range gen([]NodeID{}, rand).([]NodeID) {
- t.All = append(t.All, &Node{ID: id})
+ for _, id := range gen([]enode.ID{}, rand).([]enode.ID) {
+ n := enode.SignNull(new(enr.Record), id)
+ t.All = append(t.All, wrapNode(n))
}
return reflect.ValueOf(t)
}
-//func TestTable_Lookup(t *testing.T) {
-// bucketSizeTest := 16
-// self := nodeAtDistance(common.Hash{}, 0)
-// tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "", nil)
-// defer tab.Close()
-//
-// // lookup on empty table returns no nodes
-// if results := tab.Lookup(lookupTestnet.target); len(results) > 0 {
-// t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
-// }
-// // seed table with initial node (otherwise lookup will terminate immediately)
-// seed := NewNode(lookupTestnet.dists[256][0], net.IP{}, 256, 0)
-// tab.stuff([]*Node{seed})
-//
-// results := tab.Lookup(lookupTestnet.target)
-// t.Logf("results:")
-// for _, e := range results {
-// t.Logf(" ld=%d, %x", logdist(lookupTestnet.targetSha, e.sha), e.sha[:])
-// }
-// if len(results) != bucketSizeTest {
-// t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSizeTest)
-// }
-// if hasDuplicates(results) {
-// t.Errorf("result set contains duplicate entries")
-// }
-// if !sortedByDistanceTo(lookupTestnet.targetSha, results) {
-// t.Errorf("result set not sorted by distance to target")
-// }
-// // TODO: check result nodes are actually closest
-//}
+func TestTable_Lookup(t *testing.T) {
+ tab, db := newTestTable(lookupTestnet)
+ defer tab.Close()
+ defer db.Close()
+
+ // lookup on empty table returns no nodes
+ if results := tab.lookup(lookupTestnet.target, false); len(results) > 0 {
+ t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
+ }
+ // seed table with initial node (otherwise lookup will terminate immediately)
+ seedKey, _ := decodePubkey(lookupTestnet.dists[256][0])
+ seed := wrapNode(enode.NewV4(seedKey, net.IP{}, 0, 256))
+ tab.stuff([]*node{seed})
+
+ results := tab.lookup(lookupTestnet.target, true)
+ t.Logf("results:")
+ for _, e := range results {
+ t.Logf(" ld=%d, %x", enode.LogDist(lookupTestnet.targetSha, e.ID()), e.ID().Bytes())
+ }
+ if len(results) != bucketSize {
+ t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize)
+ }
+ if hasDuplicates(results) {
+ t.Errorf("result set contains duplicate entries")
+ }
+ if !sortedByDistanceTo(lookupTestnet.targetSha, results) {
+ t.Errorf("result set not sorted by distance to target")
+ }
+ // TODO: check result nodes are actually closest
+}
// This is the test network for the Lookup test.
// The nodes were obtained by running testnet.mine with a random NodeID as target.
var lookupTestnet = &preminedTestnet{
- target: MustHexID("166aea4f556532c6d34e8b740e5d314af7e9ac0ca79833bd751d6b665f12dfd38ec563c363b32f02aef4a80b44fd3def94612d497b99cb5f17fd24de454927ec"),
- targetSha: common.Hash{0x5c, 0x94, 0x4e, 0xe5, 0x1c, 0x5a, 0xe9, 0xf7, 0x2a, 0x95, 0xec, 0xcb, 0x8a, 0xed, 0x3, 0x74, 0xee, 0xcb, 0x51, 0x19, 0xd7, 0x20, 0xcb, 0xea, 0x68, 0x13, 0xe8, 0xe0, 0xd6, 0xad, 0x92, 0x61},
- dists: [257][]NodeID{
+ target: hexEncPubkey("166aea4f556532c6d34e8b740e5d314af7e9ac0ca79833bd751d6b665f12dfd38ec563c363b32f02aef4a80b44fd3def94612d497b99cb5f17fd24de454927ec"),
+ targetSha: enode.HexID("5c944ee51c5ae9f72a95eccb8aed0374eecb5119d720cbea6813e8e0d6ad9261"),
+ dists: [257][]encPubkey{
240: {
- MustHexID("2001ad5e3e80c71b952161bc0186731cf5ffe942d24a79230a0555802296238e57ea7a32f5b6f18564eadc1c65389448481f8c9338df0a3dbd18f708cbc2cbcb"),
- MustHexID("6ba3f4f57d084b6bf94cc4555b8c657e4a8ac7b7baf23c6874efc21dd1e4f56b7eb2721e07f5242d2f1d8381fc8cae535e860197c69236798ba1ad231b105794"),
+ hexEncPubkey("2001ad5e3e80c71b952161bc0186731cf5ffe942d24a79230a0555802296238e57ea7a32f5b6f18564eadc1c65389448481f8c9338df0a3dbd18f708cbc2cbcb"),
+ hexEncPubkey("6ba3f4f57d084b6bf94cc4555b8c657e4a8ac7b7baf23c6874efc21dd1e4f56b7eb2721e07f5242d2f1d8381fc8cae535e860197c69236798ba1ad231b105794"),
},
244: {
- MustHexID("696ba1f0a9d55c59246f776600542a9e6432490f0cd78f8bb55a196918df2081a9b521c3c3ba48e465a75c10768807717f8f689b0b4adce00e1c75737552a178"),
+ hexEncPubkey("696ba1f0a9d55c59246f776600542a9e6432490f0cd78f8bb55a196918df2081a9b521c3c3ba48e465a75c10768807717f8f689b0b4adce00e1c75737552a178"),
},
246: {
- MustHexID("d6d32178bdc38416f46ffb8b3ec9e4cb2cfff8d04dd7e4311a70e403cb62b10be1b447311b60b4f9ee221a8131fc2cbd45b96dd80deba68a949d467241facfa8"),
- MustHexID("3ea3d04a43a3dfb5ac11cffc2319248cf41b6279659393c2f55b8a0a5fc9d12581a9d97ef5d8ff9b5abf3321a290e8f63a4f785f450dc8a672aba3ba2ff4fdab"),
- MustHexID("2fc897f05ae585553e5c014effd3078f84f37f9333afacffb109f00ca8e7a3373de810a3946be971cbccdfd40249f9fe7f322118ea459ac71acca85a1ef8b7f4"),
+ hexEncPubkey("d6d32178bdc38416f46ffb8b3ec9e4cb2cfff8d04dd7e4311a70e403cb62b10be1b447311b60b4f9ee221a8131fc2cbd45b96dd80deba68a949d467241facfa8"),
+ hexEncPubkey("3ea3d04a43a3dfb5ac11cffc2319248cf41b6279659393c2f55b8a0a5fc9d12581a9d97ef5d8ff9b5abf3321a290e8f63a4f785f450dc8a672aba3ba2ff4fdab"),
+ hexEncPubkey("2fc897f05ae585553e5c014effd3078f84f37f9333afacffb109f00ca8e7a3373de810a3946be971cbccdfd40249f9fe7f322118ea459ac71acca85a1ef8b7f4"),
},
247: {
- MustHexID("3155e1427f85f10a5c9a7755877748041af1bcd8d474ec065eb33df57a97babf54bfd2103575fa829115d224c523596b401065a97f74010610fce76382c0bf32"),
- MustHexID("312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d20951933beea1e4dfa6f968212385e829f04c2d314fc2d4e255e0d3bc08792b069db"),
- MustHexID("38643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c765dd2d96126051913f44582e8c199ad7c6d6819e9a56483f637feaac9448aac"),
- MustHexID("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2d47295286fc00cc081bb542d760717d1bdd6bec2c37cd72eca367d6dd3b9df73"),
- MustHexID("8b58c6073dd98bbad4e310b97186c8f822d3a5c7d57af40e2136e88e315afd115edb27d2d0685a908cfe5aa49d0debdda6e6e63972691d6bd8c5af2d771dd2a9"),
- MustHexID("2cbb718b7dc682da19652e7d9eb4fefaf7b7147d82c1c2b6805edf77b85e29fde9f6da195741467ff2638dc62c8d3e014ea5686693c15ed0080b6de90354c137"),
- MustHexID("e84027696d3f12f2de30a9311afea8fbd313c2360daff52bb5fc8c7094d5295758bec3134e4eef24e4cdf377b40da344993284628a7a346eba94f74160998feb"),
- MustHexID("f1357a4f04f9d33753a57c0b65ba20a5d8777abbffd04e906014491c9103fb08590e45548d37aa4bd70965e2e81ddba94f31860348df01469eec8c1829200a68"),
- MustHexID("4ab0a75941b12892369b4490a1928c8ca52a9ad6d3dffbd1d8c0b907bc200fe74c022d011ec39b64808a39c0ca41f1d3254386c3e7733e7044c44259486461b6"),
- MustHexID("d45150a72dc74388773e68e03133a3b5f51447fe91837d566706b3c035ee4b56f160c878c6273394daee7f56cc398985269052f22f75a8057df2fe6172765354"),
+ hexEncPubkey("3155e1427f85f10a5c9a7755877748041af1bcd8d474ec065eb33df57a97babf54bfd2103575fa829115d224c523596b401065a97f74010610fce76382c0bf32"),
+ hexEncPubkey("312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d20951933beea1e4dfa6f968212385e829f04c2d314fc2d4e255e0d3bc08792b069db"),
+ hexEncPubkey("38643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c765dd2d96126051913f44582e8c199ad7c6d6819e9a56483f637feaac9448aac"),
+ hexEncPubkey("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2d47295286fc00cc081bb542d760717d1bdd6bec2c37cd72eca367d6dd3b9df73"),
+ hexEncPubkey("8b58c6073dd98bbad4e310b97186c8f822d3a5c7d57af40e2136e88e315afd115edb27d2d0685a908cfe5aa49d0debdda6e6e63972691d6bd8c5af2d771dd2a9"),
+ hexEncPubkey("2cbb718b7dc682da19652e7d9eb4fefaf7b7147d82c1c2b6805edf77b85e29fde9f6da195741467ff2638dc62c8d3e014ea5686693c15ed0080b6de90354c137"),
+ hexEncPubkey("e84027696d3f12f2de30a9311afea8fbd313c2360daff52bb5fc8c7094d5295758bec3134e4eef24e4cdf377b40da344993284628a7a346eba94f74160998feb"),
+ hexEncPubkey("f1357a4f04f9d33753a57c0b65ba20a5d8777abbffd04e906014491c9103fb08590e45548d37aa4bd70965e2e81ddba94f31860348df01469eec8c1829200a68"),
+ hexEncPubkey("4ab0a75941b12892369b4490a1928c8ca52a9ad6d3dffbd1d8c0b907bc200fe74c022d011ec39b64808a39c0ca41f1d3254386c3e7733e7044c44259486461b6"),
+ hexEncPubkey("d45150a72dc74388773e68e03133a3b5f51447fe91837d566706b3c035ee4b56f160c878c6273394daee7f56cc398985269052f22f75a8057df2fe6172765354"),
},
248: {
- MustHexID("6aadfce366a189bab08ac84721567483202c86590642ea6d6a14f37ca78d82bdb6509eb7b8b2f6f63c78ae3ae1d8837c89509e41497d719b23ad53dd81574afa"),
- MustHexID("a605ecfd6069a4cf4cf7f5840e5bc0ce10d23a3ac59e2aaa70c6afd5637359d2519b4524f56fc2ca180cdbebe54262f720ccaae8c1b28fd553c485675831624d"),
- MustHexID("29701451cb9448ca33fc33680b44b840d815be90146eb521641efbffed0859c154e8892d3906eae9934bfacee72cd1d2fa9dd050fd18888eea49da155ab0efd2"),
- MustHexID("3ed426322dee7572b08592e1e079f8b6c6b30e10e6243edd144a6a48fdbdb83df73a6e41b1143722cb82604f2203a32758610b5d9544f44a1a7921ba001528c1"),
- MustHexID("b2e2a2b7fdd363572a3256e75435fab1da3b16f7891a8bd2015f30995dae665d7eabfd194d87d99d5df628b4bbc7b04e5b492c596422dd8272746c7a1b0b8e4f"),
- MustHexID("0c69c9756162c593e85615b814ce57a2a8ca2df6c690b9c4e4602731b61e1531a3bbe3f7114271554427ffabea80ad8f36fa95a49fa77b675ae182c6ccac1728"),
- MustHexID("8d28be21d5a97b0876442fa4f5e5387f5bf3faad0b6f13b8607b64d6e448c0991ca28dd7fe2f64eb8eadd7150bff5d5666aa6ed868b84c71311f4ba9a38569dd"),
- MustHexID("2c677e1c64b9c9df6359348a7f5f33dc79e22f0177042486d125f8b6ca7f0dc756b1f672aceee5f1746bcff80aaf6f92a8dc0c9fbeb259b3fa0da060de5ab7e8"),
- MustHexID("3994880f94a8678f0cd247a43f474a8af375d2a072128da1ad6cae84a244105ff85e94fc7d8496f639468de7ee998908a91c7e33ef7585fff92e984b210941a1"),
- MustHexID("b45a9153c08d002a48090d15d61a7c7dad8c2af85d4ff5bd36ce23a9a11e0709bf8d56614c7b193bc028c16cbf7f20dfbcc751328b64a924995d47b41e452422"),
- MustHexID("057ab3a9e53c7a84b0f3fc586117a525cdd18e313f52a67bf31798d48078e325abe5cfee3f6c2533230cb37d0549289d692a29dd400e899b8552d4b928f6f907"),
- MustHexID("0ddf663d308791eb92e6bd88a2f8cb45e4f4f35bb16708a0e6ff7f1362aa6a73fedd0a1b1557fb3365e38e1b79d6918e2fae2788728b70c9ab6b51a3b94a4338"),
- MustHexID("f637e07ff50cc1e3731735841c4798411059f2023abcf3885674f3e8032531b0edca50fd715df6feb489b6177c345374d64f4b07d257a7745de393a107b013a5"),
- MustHexID("e24ec7c6eec094f63c7b3239f56d311ec5a3e45bc4e622a1095a65b95eea6fe13e29f3b6b7a2cbfe40906e3989f17ac834c3102dd0cadaaa26e16ee06d782b72"),
- MustHexID("b76ea1a6fd6506ef6e3506a4f1f60ed6287fff8114af6141b2ff13e61242331b54082b023cfea5b3083354a4fb3f9eb8be01fb4a518f579e731a5d0707291a6b"),
- MustHexID("9b53a37950ca8890ee349b325032d7b672cab7eced178d3060137b24ef6b92a43977922d5bdfb4a3409a2d80128e02f795f9dae6d7d99973ad0e23a2afb8442f"),
+ hexEncPubkey("6aadfce366a189bab08ac84721567483202c86590642ea6d6a14f37ca78d82bdb6509eb7b8b2f6f63c78ae3ae1d8837c89509e41497d719b23ad53dd81574afa"),
+ hexEncPubkey("a605ecfd6069a4cf4cf7f5840e5bc0ce10d23a3ac59e2aaa70c6afd5637359d2519b4524f56fc2ca180cdbebe54262f720ccaae8c1b28fd553c485675831624d"),
+ hexEncPubkey("29701451cb9448ca33fc33680b44b840d815be90146eb521641efbffed0859c154e8892d3906eae9934bfacee72cd1d2fa9dd050fd18888eea49da155ab0efd2"),
+ hexEncPubkey("3ed426322dee7572b08592e1e079f8b6c6b30e10e6243edd144a6a48fdbdb83df73a6e41b1143722cb82604f2203a32758610b5d9544f44a1a7921ba001528c1"),
+ hexEncPubkey("b2e2a2b7fdd363572a3256e75435fab1da3b16f7891a8bd2015f30995dae665d7eabfd194d87d99d5df628b4bbc7b04e5b492c596422dd8272746c7a1b0b8e4f"),
+ hexEncPubkey("0c69c9756162c593e85615b814ce57a2a8ca2df6c690b9c4e4602731b61e1531a3bbe3f7114271554427ffabea80ad8f36fa95a49fa77b675ae182c6ccac1728"),
+ hexEncPubkey("8d28be21d5a97b0876442fa4f5e5387f5bf3faad0b6f13b8607b64d6e448c0991ca28dd7fe2f64eb8eadd7150bff5d5666aa6ed868b84c71311f4ba9a38569dd"),
+ hexEncPubkey("2c677e1c64b9c9df6359348a7f5f33dc79e22f0177042486d125f8b6ca7f0dc756b1f672aceee5f1746bcff80aaf6f92a8dc0c9fbeb259b3fa0da060de5ab7e8"),
+ hexEncPubkey("3994880f94a8678f0cd247a43f474a8af375d2a072128da1ad6cae84a244105ff85e94fc7d8496f639468de7ee998908a91c7e33ef7585fff92e984b210941a1"),
+ hexEncPubkey("b45a9153c08d002a48090d15d61a7c7dad8c2af85d4ff5bd36ce23a9a11e0709bf8d56614c7b193bc028c16cbf7f20dfbcc751328b64a924995d47b41e452422"),
+ hexEncPubkey("057ab3a9e53c7a84b0f3fc586117a525cdd18e313f52a67bf31798d48078e325abe5cfee3f6c2533230cb37d0549289d692a29dd400e899b8552d4b928f6f907"),
+ hexEncPubkey("0ddf663d308791eb92e6bd88a2f8cb45e4f4f35bb16708a0e6ff7f1362aa6a73fedd0a1b1557fb3365e38e1b79d6918e2fae2788728b70c9ab6b51a3b94a4338"),
+ hexEncPubkey("f637e07ff50cc1e3731735841c4798411059f2023abcf3885674f3e8032531b0edca50fd715df6feb489b6177c345374d64f4b07d257a7745de393a107b013a5"),
+ hexEncPubkey("e24ec7c6eec094f63c7b3239f56d311ec5a3e45bc4e622a1095a65b95eea6fe13e29f3b6b7a2cbfe40906e3989f17ac834c3102dd0cadaaa26e16ee06d782b72"),
+ hexEncPubkey("b76ea1a6fd6506ef6e3506a4f1f60ed6287fff8114af6141b2ff13e61242331b54082b023cfea5b3083354a4fb3f9eb8be01fb4a518f579e731a5d0707291a6b"),
+ hexEncPubkey("9b53a37950ca8890ee349b325032d7b672cab7eced178d3060137b24ef6b92a43977922d5bdfb4a3409a2d80128e02f795f9dae6d7d99973ad0e23a2afb8442f"),
},
249: {
- MustHexID("675ae65567c3c72c50c73bc0fd4f61f202ea5f93346ca57b551de3411ccc614fad61cb9035493af47615311b9d44ee7a161972ee4d77c28fe1ec029d01434e6a"),
- MustHexID("8eb81408389da88536ae5800392b16ef5109d7ea132c18e9a82928047ecdb502693f6e4a4cdd18b54296caf561db937185731456c456c98bfe7de0baf0eaa495"),
- MustHexID("2adba8b1612a541771cb93a726a38a4b88e97b18eced2593eb7daf82f05a5321ca94a72cc780c306ff21e551a932fc2c6d791e4681907b5ceab7f084c3fa2944"),
- MustHexID("b1b4bfbda514d9b8f35b1c28961da5d5216fe50548f4066f69af3b7666a3b2e06eac646735e963e5c8f8138a2fb95af15b13b23ff00c6986eccc0efaa8ee6fb4"),
- MustHexID("d2139281b289ad0e4d7b4243c4364f5c51aac8b60f4806135de06b12b5b369c9e43a6eb494eab860d115c15c6fbb8c5a1b0e382972e0e460af395b8385363de7"),
- MustHexID("4a693df4b8fc5bdc7cec342c3ed2e228d7c5b4ab7321ddaa6cccbeb45b05a9f1d95766b4002e6d4791c2deacb8a667aadea6a700da28a3eea810a30395701bbc"),
- MustHexID("ab41611195ec3c62bb8cd762ee19fb182d194fd141f4a66780efbef4b07ce916246c022b841237a3a6b512a93431157edd221e854ed2a259b72e9c5351f44d0c"),
- MustHexID("68e8e26099030d10c3c703ae7045c0a48061fb88058d853b3e67880014c449d4311014da99d617d3150a20f1a3da5e34bf0f14f1c51fe4dd9d58afd222823176"),
- MustHexID("3fbcacf546fb129cd70fc48de3b593ba99d3c473798bc309292aca280320e0eacc04442c914cad5c4cf6950345ba79b0d51302df88285d4e83ee3fe41339eee7"),
- MustHexID("1d4a623659f7c8f80b6c3939596afdf42e78f892f682c768ad36eb7bfba402dbf97aea3a268f3badd8fe7636be216edf3d67ee1e08789ebbc7be625056bd7109"),
- MustHexID("a283c474ab09da02bbc96b16317241d0627646fcc427d1fe790b76a7bf1989ced90f92101a973047ae9940c92720dffbac8eff21df8cae468a50f72f9e159417"),
- MustHexID("dbf7e5ad7f87c3dfecae65d87c3039e14ed0bdc56caf00ce81931073e2e16719d746295512ff7937a15c3b03603e7c41a4f9df94fcd37bb200dd8f332767e9cb"),
- MustHexID("caaa070a26692f64fc77f30d7b5ae980d419b4393a0f442b1c821ef58c0862898b0d22f74a4f8c5d83069493e3ec0b92f17dc1fe6e4cd437c1ec25039e7ce839"),
- MustHexID("874cc8d1213beb65c4e0e1de38ef5d8165235893ac74ab5ea937c885eaab25c8d79dad0456e9fd3e9450626cac7e107b004478fb59842f067857f39a47cee695"),
- MustHexID("d94193f236105010972f5df1b7818b55846592a0445b9cdc4eaed811b8c4c0f7c27dc8cc9837a4774656d6b34682d6d329d42b6ebb55da1d475c2474dc3dfdf4"),
- MustHexID("edd9af6aded4094e9785637c28fccbd3980cbe28e2eb9a411048a23c2ace4bd6b0b7088a7817997b49a3dd05fc6929ca6c7abbb69438dbdabe65e971d2a794b2"),
+ hexEncPubkey("675ae65567c3c72c50c73bc0fd4f61f202ea5f93346ca57b551de3411ccc614fad61cb9035493af47615311b9d44ee7a161972ee4d77c28fe1ec029d01434e6a"),
+ hexEncPubkey("8eb81408389da88536ae5800392b16ef5109d7ea132c18e9a82928047ecdb502693f6e4a4cdd18b54296caf561db937185731456c456c98bfe7de0baf0eaa495"),
+ hexEncPubkey("2adba8b1612a541771cb93a726a38a4b88e97b18eced2593eb7daf82f05a5321ca94a72cc780c306ff21e551a932fc2c6d791e4681907b5ceab7f084c3fa2944"),
+ hexEncPubkey("b1b4bfbda514d9b8f35b1c28961da5d5216fe50548f4066f69af3b7666a3b2e06eac646735e963e5c8f8138a2fb95af15b13b23ff00c6986eccc0efaa8ee6fb4"),
+ hexEncPubkey("d2139281b289ad0e4d7b4243c4364f5c51aac8b60f4806135de06b12b5b369c9e43a6eb494eab860d115c15c6fbb8c5a1b0e382972e0e460af395b8385363de7"),
+ hexEncPubkey("4a693df4b8fc5bdc7cec342c3ed2e228d7c5b4ab7321ddaa6cccbeb45b05a9f1d95766b4002e6d4791c2deacb8a667aadea6a700da28a3eea810a30395701bbc"),
+ hexEncPubkey("ab41611195ec3c62bb8cd762ee19fb182d194fd141f4a66780efbef4b07ce916246c022b841237a3a6b512a93431157edd221e854ed2a259b72e9c5351f44d0c"),
+ hexEncPubkey("68e8e26099030d10c3c703ae7045c0a48061fb88058d853b3e67880014c449d4311014da99d617d3150a20f1a3da5e34bf0f14f1c51fe4dd9d58afd222823176"),
+ hexEncPubkey("3fbcacf546fb129cd70fc48de3b593ba99d3c473798bc309292aca280320e0eacc04442c914cad5c4cf6950345ba79b0d51302df88285d4e83ee3fe41339eee7"),
+ hexEncPubkey("1d4a623659f7c8f80b6c3939596afdf42e78f892f682c768ad36eb7bfba402dbf97aea3a268f3badd8fe7636be216edf3d67ee1e08789ebbc7be625056bd7109"),
+ hexEncPubkey("a283c474ab09da02bbc96b16317241d0627646fcc427d1fe790b76a7bf1989ced90f92101a973047ae9940c92720dffbac8eff21df8cae468a50f72f9e159417"),
+ hexEncPubkey("dbf7e5ad7f87c3dfecae65d87c3039e14ed0bdc56caf00ce81931073e2e16719d746295512ff7937a15c3b03603e7c41a4f9df94fcd37bb200dd8f332767e9cb"),
+ hexEncPubkey("caaa070a26692f64fc77f30d7b5ae980d419b4393a0f442b1c821ef58c0862898b0d22f74a4f8c5d83069493e3ec0b92f17dc1fe6e4cd437c1ec25039e7ce839"),
+ hexEncPubkey("874cc8d1213beb65c4e0e1de38ef5d8165235893ac74ab5ea937c885eaab25c8d79dad0456e9fd3e9450626cac7e107b004478fb59842f067857f39a47cee695"),
+ hexEncPubkey("d94193f236105010972f5df1b7818b55846592a0445b9cdc4eaed811b8c4c0f7c27dc8cc9837a4774656d6b34682d6d329d42b6ebb55da1d475c2474dc3dfdf4"),
+ hexEncPubkey("edd9af6aded4094e9785637c28fccbd3980cbe28e2eb9a411048a23c2ace4bd6b0b7088a7817997b49a3dd05fc6929ca6c7abbb69438dbdabe65e971d2a794b2"),
},
250: {
- MustHexID("53a5bd1215d4ab709ae8fdc2ced50bba320bced78bd9c5dc92947fb402250c914891786db0978c898c058493f86fc68b1c5de8a5cb36336150ac7a88655b6c39"),
- MustHexID("b7f79e3ab59f79262623c9ccefc8f01d682323aee56ffbe295437487e9d5acaf556a9c92e1f1c6a9601f2b9eb6b027ae1aeaebac71d61b9b78e88676efd3e1a3"),
- MustHexID("d374bf7e8d7ffff69cc00bebff38ef5bc1dcb0a8d51c1a3d70e61ac6b2e2d6617109254b0ac224354dfbf79009fe4239e09020c483cc60c071e00b9238684f30"),
- MustHexID("1e1eac1c9add703eb252eb991594f8f5a173255d526a855fab24ae57dc277e055bc3c7a7ae0b45d437c4f47a72d97eb7b126f2ba344ba6c0e14b2c6f27d4b1e6"),
- MustHexID("ae28953f63d4bc4e706712a59319c111f5ff8f312584f65d7436b4cd3d14b217b958f8486bad666b4481fe879019fb1f767cf15b3e3e2711efc33b56d460448a"),
- MustHexID("934bb1edf9c7a318b82306aca67feb3d6b434421fa275d694f0b4927afd8b1d3935b727fd4ff6e3d012e0c82f1824385174e8c6450ade59c2a43281a4b3446b6"),
- MustHexID("9eef3f28f70ce19637519a0916555bf76d26de31312ac656cf9d3e379899ea44e4dd7ffcce923b4f3563f8a00489a34bd6936db0cbb4c959d32c49f017e07d05"),
- MustHexID("82200872e8f871c48f1fad13daec6478298099b591bb3dbc4ef6890aa28ebee5860d07d70be62f4c0af85085a90ae8179ee8f937cf37915c67ea73e704b03ee7"),
- MustHexID("6c75a5834a08476b7fc37ff3dc2011dc3ea3b36524bad7a6d319b18878fad813c0ba76d1f4555cacd3890c865438c21f0e0aed1f80e0a157e642124c69f43a11"),
- MustHexID("995b873742206cb02b736e73a88580c2aacb0bd4a3c97a647b647bcab3f5e03c0e0736520a8b3600da09edf4248991fb01091ec7ff3ec7cdc8a1beae011e7aae"),
- MustHexID("c773a056594b5cdef2e850d30891ff0e927c3b1b9c35cd8e8d53a1017001e237468e1ece3ae33d612ca3e6abb0a9169aa352e9dcda358e5af2ad982b577447db"),
- MustHexID("2b46a5f6923f475c6be99ec6d134437a6d11f6bb4b4ac6bcd94572fa1092639d1c08aeefcb51f0912f0a060f71d4f38ee4da70ecc16010b05dd4a674aab14c3a"),
- MustHexID("af6ab501366debbaa0d22e20e9688f32ef6b3b644440580fd78de4fe0e99e2a16eb5636bbae0d1c259df8ddda77b35b9a35cbc36137473e9c68fbc9d203ba842"),
- MustHexID("c9f6f2dd1a941926f03f770695bda289859e85fabaf94baaae20b93e5015dc014ba41150176a36a1884adb52f405194693e63b0c464a6891cc9cc1c80d450326"),
- MustHexID("5b116f0751526868a909b61a30b0c5282c37df6925cc03ddea556ef0d0602a9595fd6c14d371f8ed7d45d89918a032dcd22be4342a8793d88fdbeb3ca3d75bd7"),
- MustHexID("50f3222fb6b82481c7c813b2172e1daea43e2710a443b9c2a57a12bd160dd37e20f87aa968c82ad639af6972185609d47036c0d93b4b7269b74ebd7073221c10"),
+ hexEncPubkey("53a5bd1215d4ab709ae8fdc2ced50bba320bced78bd9c5dc92947fb402250c914891786db0978c898c058493f86fc68b1c5de8a5cb36336150ac7a88655b6c39"),
+ hexEncPubkey("b7f79e3ab59f79262623c9ccefc8f01d682323aee56ffbe295437487e9d5acaf556a9c92e1f1c6a9601f2b9eb6b027ae1aeaebac71d61b9b78e88676efd3e1a3"),
+ hexEncPubkey("d374bf7e8d7ffff69cc00bebff38ef5bc1dcb0a8d51c1a3d70e61ac6b2e2d6617109254b0ac224354dfbf79009fe4239e09020c483cc60c071e00b9238684f30"),
+ hexEncPubkey("1e1eac1c9add703eb252eb991594f8f5a173255d526a855fab24ae57dc277e055bc3c7a7ae0b45d437c4f47a72d97eb7b126f2ba344ba6c0e14b2c6f27d4b1e6"),
+ hexEncPubkey("ae28953f63d4bc4e706712a59319c111f5ff8f312584f65d7436b4cd3d14b217b958f8486bad666b4481fe879019fb1f767cf15b3e3e2711efc33b56d460448a"),
+ hexEncPubkey("934bb1edf9c7a318b82306aca67feb3d6b434421fa275d694f0b4927afd8b1d3935b727fd4ff6e3d012e0c82f1824385174e8c6450ade59c2a43281a4b3446b6"),
+ hexEncPubkey("9eef3f28f70ce19637519a0916555bf76d26de31312ac656cf9d3e379899ea44e4dd7ffcce923b4f3563f8a00489a34bd6936db0cbb4c959d32c49f017e07d05"),
+ hexEncPubkey("82200872e8f871c48f1fad13daec6478298099b591bb3dbc4ef6890aa28ebee5860d07d70be62f4c0af85085a90ae8179ee8f937cf37915c67ea73e704b03ee7"),
+ hexEncPubkey("6c75a5834a08476b7fc37ff3dc2011dc3ea3b36524bad7a6d319b18878fad813c0ba76d1f4555cacd3890c865438c21f0e0aed1f80e0a157e642124c69f43a11"),
+ hexEncPubkey("995b873742206cb02b736e73a88580c2aacb0bd4a3c97a647b647bcab3f5e03c0e0736520a8b3600da09edf4248991fb01091ec7ff3ec7cdc8a1beae011e7aae"),
+ hexEncPubkey("c773a056594b5cdef2e850d30891ff0e927c3b1b9c35cd8e8d53a1017001e237468e1ece3ae33d612ca3e6abb0a9169aa352e9dcda358e5af2ad982b577447db"),
+ hexEncPubkey("2b46a5f6923f475c6be99ec6d134437a6d11f6bb4b4ac6bcd94572fa1092639d1c08aeefcb51f0912f0a060f71d4f38ee4da70ecc16010b05dd4a674aab14c3a"),
+ hexEncPubkey("af6ab501366debbaa0d22e20e9688f32ef6b3b644440580fd78de4fe0e99e2a16eb5636bbae0d1c259df8ddda77b35b9a35cbc36137473e9c68fbc9d203ba842"),
+ hexEncPubkey("c9f6f2dd1a941926f03f770695bda289859e85fabaf94baaae20b93e5015dc014ba41150176a36a1884adb52f405194693e63b0c464a6891cc9cc1c80d450326"),
+ hexEncPubkey("5b116f0751526868a909b61a30b0c5282c37df6925cc03ddea556ef0d0602a9595fd6c14d371f8ed7d45d89918a032dcd22be4342a8793d88fdbeb3ca3d75bd7"),
+ hexEncPubkey("50f3222fb6b82481c7c813b2172e1daea43e2710a443b9c2a57a12bd160dd37e20f87aa968c82ad639af6972185609d47036c0d93b4b7269b74ebd7073221c10"),
},
251: {
- MustHexID("9b8f702a62d1bee67bedfeb102eca7f37fa1713e310f0d6651cc0c33ea7c5477575289ccd463e5a2574a00a676a1fdce05658ba447bb9d2827f0ba47b947e894"),
- MustHexID("b97532eb83054ed054b4abdf413bb30c00e4205545c93521554dbe77faa3cfaa5bd31ef466a107b0b34a71ec97214c0c83919720142cddac93aa7a3e928d4708"),
- MustHexID("2f7a5e952bfb67f2f90b8441b5fadc9ee13b1dcde3afeeb3dd64bf937f86663cc5c55d1fa83952b5422763c7df1b7f2794b751c6be316ebc0beb4942e65ab8c1"),
- MustHexID("42c7483781727051a0b3660f14faf39e0d33de5e643702ae933837d036508ab856ce7eec8ec89c4929a4901256e5233a3d847d5d4893f91bcf21835a9a880fee"),
- MustHexID("873bae27bf1dc854408fba94046a53ab0c965cebe1e4e12290806fc62b88deb1f4a47f9e18f78fc0e7913a0c6e42ac4d0fc3a20cea6bc65f0c8a0ca90b67521e"),
- MustHexID("a7e3a370bbd761d413f8d209e85886f68bf73d5c3089b2dc6fa42aab1ecb5162635497eed95dee2417f3c9c74a3e76319625c48ead2e963c7de877cd4551f347"),
- MustHexID("528597534776a40df2addaaea15b6ff832ce36b9748a265768368f657e76d58569d9f30dbb91e91cf0ae7efe8f402f17aa0ae15f5c55051ba03ba830287f4c42"),
- MustHexID("461d8bd4f13c3c09031fdb84f104ed737a52f630261463ce0bdb5704259bab4b737dda688285b8444dbecaecad7f50f835190b38684ced5e90c54219e5adf1bc"),
- MustHexID("6ec50c0be3fd232737090fc0111caaf0bb6b18f72be453428087a11a97fd6b52db0344acbf789a689bd4f5f50f79017ea784f8fd6fe723ad6ae675b9e3b13e21"),
- MustHexID("12fc5e2f77a83fdcc727b79d8ae7fe6a516881138d3011847ee136b400fed7cfba1f53fd7a9730253c7aa4f39abeacd04f138417ba7fcb0f36cccc3514e0dab6"),
- MustHexID("4fdbe75914ccd0bce02101606a1ccf3657ec963e3b3c20239d5fec87673fe446d649b4f15f1fe1a40e6cfbd446dda2d31d40bb602b1093b8fcd5f139ba0eb46a"),
- MustHexID("3753668a0f6281e425ea69b52cb2d17ab97afbe6eb84cf5d25425bc5e53009388857640668fadd7c110721e6047c9697803bd8a6487b43bb343bfa32ebf24039"),
- MustHexID("2e81b16346637dec4410fd88e527346145b9c0a849dbf2628049ac7dae016c8f4305649d5659ec77f1e8a0fac0db457b6080547226f06283598e3740ad94849a"),
- MustHexID("802c3cc27f91c89213223d758f8d2ecd41135b357b6d698f24d811cdf113033a81c38e0bdff574a5c005b00a8c193dc2531f8c1fa05fa60acf0ab6f2858af09f"),
- MustHexID("fcc9a2e1ac3667026ff16192876d1813bb75abdbf39b929a92863012fe8b1d890badea7a0de36274d5c1eb1e8f975785532c50d80fd44b1a4b692f437303393f"),
- MustHexID("6d8b3efb461151dd4f6de809b62726f5b89e9b38e9ba1391967f61cde844f7528fecf821b74049207cee5a527096b31f3ad623928cd3ce51d926fa345a6b2951"),
+ hexEncPubkey("9b8f702a62d1bee67bedfeb102eca7f37fa1713e310f0d6651cc0c33ea7c5477575289ccd463e5a2574a00a676a1fdce05658ba447bb9d2827f0ba47b947e894"),
+ hexEncPubkey("b97532eb83054ed054b4abdf413bb30c00e4205545c93521554dbe77faa3cfaa5bd31ef466a107b0b34a71ec97214c0c83919720142cddac93aa7a3e928d4708"),
+ hexEncPubkey("2f7a5e952bfb67f2f90b8441b5fadc9ee13b1dcde3afeeb3dd64bf937f86663cc5c55d1fa83952b5422763c7df1b7f2794b751c6be316ebc0beb4942e65ab8c1"),
+ hexEncPubkey("42c7483781727051a0b3660f14faf39e0d33de5e643702ae933837d036508ab856ce7eec8ec89c4929a4901256e5233a3d847d5d4893f91bcf21835a9a880fee"),
+ hexEncPubkey("873bae27bf1dc854408fba94046a53ab0c965cebe1e4e12290806fc62b88deb1f4a47f9e18f78fc0e7913a0c6e42ac4d0fc3a20cea6bc65f0c8a0ca90b67521e"),
+ hexEncPubkey("a7e3a370bbd761d413f8d209e85886f68bf73d5c3089b2dc6fa42aab1ecb5162635497eed95dee2417f3c9c74a3e76319625c48ead2e963c7de877cd4551f347"),
+ hexEncPubkey("528597534776a40df2addaaea15b6ff832ce36b9748a265768368f657e76d58569d9f30dbb91e91cf0ae7efe8f402f17aa0ae15f5c55051ba03ba830287f4c42"),
+ hexEncPubkey("461d8bd4f13c3c09031fdb84f104ed737a52f630261463ce0bdb5704259bab4b737dda688285b8444dbecaecad7f50f835190b38684ced5e90c54219e5adf1bc"),
+ hexEncPubkey("6ec50c0be3fd232737090fc0111caaf0bb6b18f72be453428087a11a97fd6b52db0344acbf789a689bd4f5f50f79017ea784f8fd6fe723ad6ae675b9e3b13e21"),
+ hexEncPubkey("12fc5e2f77a83fdcc727b79d8ae7fe6a516881138d3011847ee136b400fed7cfba1f53fd7a9730253c7aa4f39abeacd04f138417ba7fcb0f36cccc3514e0dab6"),
+ hexEncPubkey("4fdbe75914ccd0bce02101606a1ccf3657ec963e3b3c20239d5fec87673fe446d649b4f15f1fe1a40e6cfbd446dda2d31d40bb602b1093b8fcd5f139ba0eb46a"),
+ hexEncPubkey("3753668a0f6281e425ea69b52cb2d17ab97afbe6eb84cf5d25425bc5e53009388857640668fadd7c110721e6047c9697803bd8a6487b43bb343bfa32ebf24039"),
+ hexEncPubkey("2e81b16346637dec4410fd88e527346145b9c0a849dbf2628049ac7dae016c8f4305649d5659ec77f1e8a0fac0db457b6080547226f06283598e3740ad94849a"),
+ hexEncPubkey("802c3cc27f91c89213223d758f8d2ecd41135b357b6d698f24d811cdf113033a81c38e0bdff574a5c005b00a8c193dc2531f8c1fa05fa60acf0ab6f2858af09f"),
+ hexEncPubkey("fcc9a2e1ac3667026ff16192876d1813bb75abdbf39b929a92863012fe8b1d890badea7a0de36274d5c1eb1e8f975785532c50d80fd44b1a4b692f437303393f"),
+ hexEncPubkey("6d8b3efb461151dd4f6de809b62726f5b89e9b38e9ba1391967f61cde844f7528fecf821b74049207cee5a527096b31f3ad623928cd3ce51d926fa345a6b2951"),
},
252: {
- MustHexID("f1ae93157cc48c2075dd5868fbf523e79e06caf4b8198f352f6e526680b78ff4227263de92612f7d63472bd09367bb92a636fff16fe46ccf41614f7a72495c2a"),
- MustHexID("587f482d111b239c27c0cb89b51dd5d574db8efd8de14a2e6a1400c54d4567e77c65f89c1da52841212080b91604104768350276b6682f2f961cdaf4039581c7"),
- MustHexID("e3f88274d35cefdaabdf205afe0e80e936cc982b8e3e47a84ce664c413b29016a4fb4f3a3ebae0a2f79671f8323661ed462bf4390af94c424dc8ace0c301b90f"),
- MustHexID("0ddc736077da9a12ba410dc5ea63cbcbe7659dd08596485b2bff3435221f82c10d263efd9af938e128464be64a178b7cd22e19f400d5802f4c9df54bf89f2619"),
- MustHexID("784aa34d833c6ce63fcc1279630113c3272e82c4ae8c126c5a52a88ac461b6baeed4244e607b05dc14e5b2f41c70a273c3804dea237f14f7a1e546f6d1309d14"),
- MustHexID("f253a2c354ee0e27cfcae786d726753d4ad24be6516b279a936195a487de4a59dbc296accf20463749ff55293263ed8c1b6365eecb248d44e75e9741c0d18205"),
- MustHexID("a1910b80357b3ad9b4593e0628922939614dc9056a5fbf477279c8b2c1d0b4b31d89a0c09d0d41f795271d14d3360ef08a3f821e65e7e1f56c07a36afe49c7c5"),
- MustHexID("f1168552c2efe541160f0909b0b4a9d6aeedcf595cdf0e9b165c97e3e197471a1ee6320e93389edfba28af6eaf10de98597ad56e7ab1b504ed762451996c3b98"),
- MustHexID("b0c8e5d2c8634a7930e1a6fd082e448c6cf9d2d8b7293558b59238815a4df926c286bf297d2049f14e8296a6eb3256af614ec1812c4f2bbe807673b58bf14c8c"),
- MustHexID("0fb346076396a38badc342df3679b55bd7f40a609ab103411fe45082c01f12ea016729e95914b2b5540e987ff5c9b133e85862648e7f36abdfd23100d248d234"),
- MustHexID("f736e0cc83417feaa280d9483f5d4d72d1b036cd0c6d9cbdeb8ac35ceb2604780de46dddaa32a378474e1d5ccdf79b373331c30c7911ade2ae32f98832e5de1f"),
- MustHexID("8b02991457602f42b38b342d3f2259ae4100c354b3843885f7e4e07bd644f64dab94bb7f38a3915f8b7f11d8e3f81c28e07a0078cf79d7397e38a7b7e0c857e2"),
- MustHexID("9221d9f04a8a184993d12baa91116692bb685f887671302999d69300ad103eb2d2c75a09d8979404c6dd28f12362f58a1a43619c493d9108fd47588a23ce5824"),
- MustHexID("652797801744dada833fff207d67484742eea6835d695925f3e618d71b68ec3c65bdd85b4302b2cdcb835ad3f94fd00d8da07e570b41bc0d2bcf69a8de1b3284"),
- MustHexID("d84f06fe64debc4cd0625e36d19b99014b6218375262cc2209202bdbafd7dffcc4e34ce6398e182e02fd8faeed622c3e175545864902dfd3d1ac57647cddf4c6"),
- MustHexID("d0ed87b294f38f1d741eb601020eeec30ac16331d05880fe27868f1e454446de367d7457b41c79e202eaf9525b029e4f1d7e17d85a55f83a557c005c68d7328a"),
+ hexEncPubkey("f1ae93157cc48c2075dd5868fbf523e79e06caf4b8198f352f6e526680b78ff4227263de92612f7d63472bd09367bb92a636fff16fe46ccf41614f7a72495c2a"),
+ hexEncPubkey("587f482d111b239c27c0cb89b51dd5d574db8efd8de14a2e6a1400c54d4567e77c65f89c1da52841212080b91604104768350276b6682f2f961cdaf4039581c7"),
+ hexEncPubkey("e3f88274d35cefdaabdf205afe0e80e936cc982b8e3e47a84ce664c413b29016a4fb4f3a3ebae0a2f79671f8323661ed462bf4390af94c424dc8ace0c301b90f"),
+ hexEncPubkey("0ddc736077da9a12ba410dc5ea63cbcbe7659dd08596485b2bff3435221f82c10d263efd9af938e128464be64a178b7cd22e19f400d5802f4c9df54bf89f2619"),
+ hexEncPubkey("784aa34d833c6ce63fcc1279630113c3272e82c4ae8c126c5a52a88ac461b6baeed4244e607b05dc14e5b2f41c70a273c3804dea237f14f7a1e546f6d1309d14"),
+ hexEncPubkey("f253a2c354ee0e27cfcae786d726753d4ad24be6516b279a936195a487de4a59dbc296accf20463749ff55293263ed8c1b6365eecb248d44e75e9741c0d18205"),
+ hexEncPubkey("a1910b80357b3ad9b4593e0628922939614dc9056a5fbf477279c8b2c1d0b4b31d89a0c09d0d41f795271d14d3360ef08a3f821e65e7e1f56c07a36afe49c7c5"),
+ hexEncPubkey("f1168552c2efe541160f0909b0b4a9d6aeedcf595cdf0e9b165c97e3e197471a1ee6320e93389edfba28af6eaf10de98597ad56e7ab1b504ed762451996c3b98"),
+ hexEncPubkey("b0c8e5d2c8634a7930e1a6fd082e448c6cf9d2d8b7293558b59238815a4df926c286bf297d2049f14e8296a6eb3256af614ec1812c4f2bbe807673b58bf14c8c"),
+ hexEncPubkey("0fb346076396a38badc342df3679b55bd7f40a609ab103411fe45082c01f12ea016729e95914b2b5540e987ff5c9b133e85862648e7f36abdfd23100d248d234"),
+ hexEncPubkey("f736e0cc83417feaa280d9483f5d4d72d1b036cd0c6d9cbdeb8ac35ceb2604780de46dddaa32a378474e1d5ccdf79b373331c30c7911ade2ae32f98832e5de1f"),
+ hexEncPubkey("8b02991457602f42b38b342d3f2259ae4100c354b3843885f7e4e07bd644f64dab94bb7f38a3915f8b7f11d8e3f81c28e07a0078cf79d7397e38a7b7e0c857e2"),
+ hexEncPubkey("9221d9f04a8a184993d12baa91116692bb685f887671302999d69300ad103eb2d2c75a09d8979404c6dd28f12362f58a1a43619c493d9108fd47588a23ce5824"),
+ hexEncPubkey("652797801744dada833fff207d67484742eea6835d695925f3e618d71b68ec3c65bdd85b4302b2cdcb835ad3f94fd00d8da07e570b41bc0d2bcf69a8de1b3284"),
+ hexEncPubkey("d84f06fe64debc4cd0625e36d19b99014b6218375262cc2209202bdbafd7dffcc4e34ce6398e182e02fd8faeed622c3e175545864902dfd3d1ac57647cddf4c6"),
+ hexEncPubkey("d0ed87b294f38f1d741eb601020eeec30ac16331d05880fe27868f1e454446de367d7457b41c79e202eaf9525b029e4f1d7e17d85a55f83a557c005c68d7328a"),
},
253: {
- MustHexID("ad4485e386e3cc7c7310366a7c38fb810b8896c0d52e55944bfd320ca294e7912d6c53c0a0cf85e7ce226e92491d60430e86f8f15cda0161ed71893fb4a9e3a1"),
- MustHexID("36d0e7e5b7734f98c6183eeeb8ac5130a85e910a925311a19c4941b1290f945d4fc3996b12ef4966960b6fa0fb29b1604f83a0f81bd5fd6398d2e1a22e46af0c"),
- MustHexID("7d307d8acb4a561afa23bdf0bd945d35c90245e26345ec3a1f9f7df354222a7cdcb81339c9ed6744526c27a1a0c8d10857e98df942fa433602facac71ac68a31"),
- MustHexID("d97bf55f88c83fae36232661af115d66ca600fc4bd6d1fb35ff9bb4dad674c02cf8c8d05f317525b5522250db58bb1ecafb7157392bf5aa61b178c61f098d995"),
- MustHexID("7045d678f1f9eb7a4613764d17bd5698796494d0bf977b16f2dbc272b8a0f7858a60805c022fc3d1fe4f31c37e63cdaca0416c0d053ef48a815f8b19121605e0"),
- MustHexID("14e1f21418d445748de2a95cd9a8c3b15b506f86a0acabd8af44bb968ce39885b19c8822af61b3dd58a34d1f265baec30e3ae56149dc7d2aa4a538f7319f69c8"),
- MustHexID("b9453d78281b66a4eac95a1546017111eaaa5f92a65d0de10b1122940e92b319728a24edf4dec6acc412321b1c95266d39c7b3a5d265c629c3e49a65fb022c09"),
- MustHexID("e8a49248419e3824a00d86af422f22f7366e2d4922b304b7169937616a01d9d6fa5abf5cc01061a352dc866f48e1fa2240dbb453d872b1d7be62bdfc1d5e248c"),
- MustHexID("bebcff24b52362f30e0589ee573ce2d86f073d58d18e6852a592fa86ceb1a6c9b96d7fb9ec7ed1ed98a51b6743039e780279f6bb49d0a04327ac7a182d9a56f6"),
- MustHexID("d0835e5a4291db249b8d2fca9f503049988180c7d247bedaa2cf3a1bad0a76709360a85d4f9a1423b2cbc82bb4d94b47c0cde20afc430224834c49fe312a9ae3"),
- MustHexID("6b087fe2a2da5e4f0b0f4777598a4a7fb66bf77dbd5bfc44e8a7eaa432ab585a6e226891f56a7d4f5ed11a7c57b90f1661bba1059590ca4267a35801c2802913"),
- MustHexID("d901e5bde52d1a0f4ddf010a686a53974cdae4ebe5c6551b3c37d6b6d635d38d5b0e5f80bc0186a2c7809dbf3a42870dd09643e68d32db896c6da8ba734579e7"),
- MustHexID("96419fb80efae4b674402bb969ebaab86c1274f29a83a311e24516d36cdf148fe21754d46c97688cdd7468f24c08b13e4727c29263393638a3b37b99ff60ebca"),
- MustHexID("7b9c1889ae916a5d5abcdfb0aaedcc9c6f9eb1c1a4f68d0c2d034fe79ac610ce917c3abc670744150fa891bfcd8ab14fed6983fca964de920aa393fa7b326748"),
- MustHexID("7a369b2b8962cc4c65900be046482fbf7c14f98a135bbbae25152c82ad168fb2097b3d1429197cf46d3ce9fdeb64808f908a489cc6019725db040060fdfe5405"),
- MustHexID("47bcae48288da5ecc7f5058dfa07cf14d89d06d6e449cb946e237aa6652ea050d9f5a24a65efdc0013ccf232bf88670979eddef249b054f63f38da9d7796dbd8"),
+ hexEncPubkey("ad4485e386e3cc7c7310366a7c38fb810b8896c0d52e55944bfd320ca294e7912d6c53c0a0cf85e7ce226e92491d60430e86f8f15cda0161ed71893fb4a9e3a1"),
+ hexEncPubkey("36d0e7e5b7734f98c6183eeeb8ac5130a85e910a925311a19c4941b1290f945d4fc3996b12ef4966960b6fa0fb29b1604f83a0f81bd5fd6398d2e1a22e46af0c"),
+ hexEncPubkey("7d307d8acb4a561afa23bdf0bd945d35c90245e26345ec3a1f9f7df354222a7cdcb81339c9ed6744526c27a1a0c8d10857e98df942fa433602facac71ac68a31"),
+ hexEncPubkey("d97bf55f88c83fae36232661af115d66ca600fc4bd6d1fb35ff9bb4dad674c02cf8c8d05f317525b5522250db58bb1ecafb7157392bf5aa61b178c61f098d995"),
+ hexEncPubkey("7045d678f1f9eb7a4613764d17bd5698796494d0bf977b16f2dbc272b8a0f7858a60805c022fc3d1fe4f31c37e63cdaca0416c0d053ef48a815f8b19121605e0"),
+ hexEncPubkey("14e1f21418d445748de2a95cd9a8c3b15b506f86a0acabd8af44bb968ce39885b19c8822af61b3dd58a34d1f265baec30e3ae56149dc7d2aa4a538f7319f69c8"),
+ hexEncPubkey("b9453d78281b66a4eac95a1546017111eaaa5f92a65d0de10b1122940e92b319728a24edf4dec6acc412321b1c95266d39c7b3a5d265c629c3e49a65fb022c09"),
+ hexEncPubkey("e8a49248419e3824a00d86af422f22f7366e2d4922b304b7169937616a01d9d6fa5abf5cc01061a352dc866f48e1fa2240dbb453d872b1d7be62bdfc1d5e248c"),
+ hexEncPubkey("bebcff24b52362f30e0589ee573ce2d86f073d58d18e6852a592fa86ceb1a6c9b96d7fb9ec7ed1ed98a51b6743039e780279f6bb49d0a04327ac7a182d9a56f6"),
+ hexEncPubkey("d0835e5a4291db249b8d2fca9f503049988180c7d247bedaa2cf3a1bad0a76709360a85d4f9a1423b2cbc82bb4d94b47c0cde20afc430224834c49fe312a9ae3"),
+ hexEncPubkey("6b087fe2a2da5e4f0b0f4777598a4a7fb66bf77dbd5bfc44e8a7eaa432ab585a6e226891f56a7d4f5ed11a7c57b90f1661bba1059590ca4267a35801c2802913"),
+ hexEncPubkey("d901e5bde52d1a0f4ddf010a686a53974cdae4ebe5c6551b3c37d6b6d635d38d5b0e5f80bc0186a2c7809dbf3a42870dd09643e68d32db896c6da8ba734579e7"),
+ hexEncPubkey("96419fb80efae4b674402bb969ebaab86c1274f29a83a311e24516d36cdf148fe21754d46c97688cdd7468f24c08b13e4727c29263393638a3b37b99ff60ebca"),
+ hexEncPubkey("7b9c1889ae916a5d5abcdfb0aaedcc9c6f9eb1c1a4f68d0c2d034fe79ac610ce917c3abc670744150fa891bfcd8ab14fed6983fca964de920aa393fa7b326748"),
+ hexEncPubkey("7a369b2b8962cc4c65900be046482fbf7c14f98a135bbbae25152c82ad168fb2097b3d1429197cf46d3ce9fdeb64808f908a489cc6019725db040060fdfe5405"),
+ hexEncPubkey("47bcae48288da5ecc7f5058dfa07cf14d89d06d6e449cb946e237aa6652ea050d9f5a24a65efdc0013ccf232bf88670979eddef249b054f63f38da9d7796dbd8"),
},
254: {
- MustHexID("099739d7abc8abd38ecc7a816c521a1168a4dbd359fa7212a5123ab583ffa1cf485a5fed219575d6475dbcdd541638b2d3631a6c7fce7474e7fe3cba1d4d5853"),
- MustHexID("c2b01603b088a7182d0cf7ef29fb2b04c70acb320fccf78526bf9472e10c74ee70b3fcfa6f4b11d167bd7d3bc4d936b660f2c9bff934793d97cb21750e7c3d31"),
- MustHexID("20e4d8f45f2f863e94b45548c1ef22a11f7d36f263e4f8623761e05a64c4572379b000a52211751e2561b0f14f4fc92dd4130410c8ccc71eb4f0e95a700d4ca9"),
- MustHexID("27f4a16cc085e72d86e25c98bd2eca173eaaee7565c78ec5a52e9e12b2211f35de81b5b45e9195de2ebfe29106742c59112b951a04eb7ae48822911fc1f9389e"),
- MustHexID("55db5ee7d98e7f0b1c3b9d5be6f2bc619a1b86c3cdd513160ad4dcf267037a5fffad527ac15d50aeb32c59c13d1d4c1e567ebbf4de0d25236130c8361f9aac63"),
- MustHexID("883df308b0130fc928a8559fe50667a0fff80493bc09685d18213b2db241a3ad11310ed86b0ef662b3ce21fc3d9aa7f3fc24b8d9afe17c7407e9afd3345ae548"),
- MustHexID("c7af968cc9bc8200c3ee1a387405f7563be1dce6710a3439f42ea40657d0eae9d2b3c16c42d779605351fcdece4da637b9804e60ca08cfb89aec32c197beffa6"),
- MustHexID("3e66f2b788e3ff1d04106b80597915cd7afa06c405a7ae026556b6e583dca8e05cfbab5039bb9a1b5d06083ffe8de5780b1775550e7218f5e98624bf7af9a0a8"),
- MustHexID("4fc7f53764de3337fdaec0a711d35d3a923e72fa65025444d12230b3552ed43d9b2d1ad08ccb11f2d50c58809e6dd74dde910e195294fca3b47ae5a3967cc479"),
- MustHexID("bafdfdcf6ccaa989436752fa97c77477b6baa7deb374b16c095492c529eb133e8e2f99e1977012b64767b9d34b2cf6d2048ed489bd822b5139b523f6a423167b"),
- MustHexID("7f5d78008a4312fe059104ce80202c82b8915c2eb4411c6b812b16f7642e57c00f2c9425121f5cbac4257fe0b3e81ef5dea97ea2dbaa98f6a8b6fd4d1e5980bb"),
- MustHexID("598c37fe78f922751a052f463aeb0cb0bc7f52b7c2a4cf2da72ec0931c7c32175d4165d0f8998f7320e87324ac3311c03f9382a5385c55f0407b7a66b2acd864"),
- MustHexID("f758c4136e1c148777a7f3275a76e2db0b2b04066fd738554ec398c1c6cc9fb47e14a3b4c87bd47deaeab3ffd2110514c3855685a374794daff87b605b27ee2e"),
- MustHexID("0307bb9e4fd865a49dcf1fe4333d1b944547db650ab580af0b33e53c4fef6c789531110fac801bbcbce21fc4d6f61b6d5b24abdf5b22e3030646d579f6dca9c2"),
- MustHexID("82504b6eb49bb2c0f91a7006ce9cefdbaf6df38706198502c2e06601091fc9dc91e4f15db3410d45c6af355bc270b0f268d3dff560f956985c7332d4b10bd1ed"),
- MustHexID("b39b5b677b45944ceebe76e76d1f051de2f2a0ec7b0d650da52135743e66a9a5dba45f638258f9a7545d9a790c7fe6d3fdf82c25425c7887323e45d27d06c057"),
+ hexEncPubkey("099739d7abc8abd38ecc7a816c521a1168a4dbd359fa7212a5123ab583ffa1cf485a5fed219575d6475dbcdd541638b2d3631a6c7fce7474e7fe3cba1d4d5853"),
+ hexEncPubkey("c2b01603b088a7182d0cf7ef29fb2b04c70acb320fccf78526bf9472e10c74ee70b3fcfa6f4b11d167bd7d3bc4d936b660f2c9bff934793d97cb21750e7c3d31"),
+ hexEncPubkey("20e4d8f45f2f863e94b45548c1ef22a11f7d36f263e4f8623761e05a64c4572379b000a52211751e2561b0f14f4fc92dd4130410c8ccc71eb4f0e95a700d4ca9"),
+ hexEncPubkey("27f4a16cc085e72d86e25c98bd2eca173eaaee7565c78ec5a52e9e12b2211f35de81b5b45e9195de2ebfe29106742c59112b951a04eb7ae48822911fc1f9389e"),
+ hexEncPubkey("55db5ee7d98e7f0b1c3b9d5be6f2bc619a1b86c3cdd513160ad4dcf267037a5fffad527ac15d50aeb32c59c13d1d4c1e567ebbf4de0d25236130c8361f9aac63"),
+ hexEncPubkey("883df308b0130fc928a8559fe50667a0fff80493bc09685d18213b2db241a3ad11310ed86b0ef662b3ce21fc3d9aa7f3fc24b8d9afe17c7407e9afd3345ae548"),
+ hexEncPubkey("c7af968cc9bc8200c3ee1a387405f7563be1dce6710a3439f42ea40657d0eae9d2b3c16c42d779605351fcdece4da637b9804e60ca08cfb89aec32c197beffa6"),
+ hexEncPubkey("3e66f2b788e3ff1d04106b80597915cd7afa06c405a7ae026556b6e583dca8e05cfbab5039bb9a1b5d06083ffe8de5780b1775550e7218f5e98624bf7af9a0a8"),
+ hexEncPubkey("4fc7f53764de3337fdaec0a711d35d3a923e72fa65025444d12230b3552ed43d9b2d1ad08ccb11f2d50c58809e6dd74dde910e195294fca3b47ae5a3967cc479"),
+ hexEncPubkey("bafdfdcf6ccaa989436752fa97c77477b6baa7deb374b16c095492c529eb133e8e2f99e1977012b64767b9d34b2cf6d2048ed489bd822b5139b523f6a423167b"),
+ hexEncPubkey("7f5d78008a4312fe059104ce80202c82b8915c2eb4411c6b812b16f7642e57c00f2c9425121f5cbac4257fe0b3e81ef5dea97ea2dbaa98f6a8b6fd4d1e5980bb"),
+ hexEncPubkey("598c37fe78f922751a052f463aeb0cb0bc7f52b7c2a4cf2da72ec0931c7c32175d4165d0f8998f7320e87324ac3311c03f9382a5385c55f0407b7a66b2acd864"),
+ hexEncPubkey("f758c4136e1c148777a7f3275a76e2db0b2b04066fd738554ec398c1c6cc9fb47e14a3b4c87bd47deaeab3ffd2110514c3855685a374794daff87b605b27ee2e"),
+ hexEncPubkey("0307bb9e4fd865a49dcf1fe4333d1b944547db650ab580af0b33e53c4fef6c789531110fac801bbcbce21fc4d6f61b6d5b24abdf5b22e3030646d579f6dca9c2"),
+ hexEncPubkey("82504b6eb49bb2c0f91a7006ce9cefdbaf6df38706198502c2e06601091fc9dc91e4f15db3410d45c6af355bc270b0f268d3dff560f956985c7332d4b10bd1ed"),
+ hexEncPubkey("b39b5b677b45944ceebe76e76d1f051de2f2a0ec7b0d650da52135743e66a9a5dba45f638258f9a7545d9a790c7fe6d3fdf82c25425c7887323e45d27d06c057"),
},
255: {
- MustHexID("5c4d58d46e055dd1f093f81ee60a675e1f02f54da6206720adee4dccef9b67a31efc5c2a2949c31a04ee31beadc79aba10da31440a1f9ff2a24093c63c36d784"),
- MustHexID("ea72161ffdd4b1e124c7b93b0684805f4c4b58d617ed498b37a145c670dbc2e04976f8785583d9c805ffbf343c31d492d79f841652bbbd01b61ed85640b23495"),
- MustHexID("51caa1d93352d47a8e531692a3612adac1e8ac68d0a200d086c1c57ae1e1a91aa285ab242e8c52ef9d7afe374c9485b122ae815f1707b875569d0433c1c3ce85"),
- MustHexID("c08397d5751b47bd3da044b908be0fb0e510d3149574dff7aeab33749b023bb171b5769990fe17469dbebc100bc150e798aeda426a2dcc766699a225fddd75c6"),
- MustHexID("0222c1c194b749736e593f937fad67ee348ac57287a15c7e42877aa38a9b87732a408bca370f812efd0eedbff13e6d5b854bf3ba1dec431a796ed47f32552b09"),
- MustHexID("03d859cd46ef02d9bfad5268461a6955426845eef4126de6be0fa4e8d7e0727ba2385b78f1a883a8239e95ebb814f2af8379632c7d5b100688eebc5841209582"),
- MustHexID("64d5004b7e043c39ff0bd10cb20094c287721d5251715884c280a612b494b3e9e1c64ba6f67614994c7d969a0d0c0295d107d53fc225d47c44c4b82852d6f960"),
- MustHexID("b0a5eefb2dab6f786670f35bf9641eefe6dd87fd3f1362bcab4aaa792903500ab23d88fae68411372e0813b057535a601d46e454323745a948017f6063a47b1f"),
- MustHexID("0cc6df0a3433d448b5684d2a3ffa9d1a825388177a18f44ad0008c7bd7702f1ec0fc38b83506f7de689c3b6ecb552599927e29699eed6bb867ff08f80068b287"),
- MustHexID("50772f7b8c03a4e153355fbbf79c8a80cf32af656ff0c7873c99911099d04a0dae0674706c357e0145ad017a0ade65e6052cb1b0d574fcd6f67da3eee0ace66b"),
- MustHexID("1ae37829c9ef41f8b508b82259ebac76b1ed900d7a45c08b7970f25d2d48ddd1829e2f11423a18749940b6dab8598c6e416cef0efd47e46e51f29a0bc65b37cd"),
- MustHexID("ba973cab31c2af091fc1644a93527d62b2394999e2b6ccbf158dd5ab9796a43d408786f1803ef4e29debfeb62fce2b6caa5ab2b24d1549c822a11c40c2856665"),
- MustHexID("bc413ad270dd6ea25bddba78f3298b03b8ba6f8608ac03d06007d4116fa78ef5a0cfe8c80155089382fc7a193243ee5500082660cb5d7793f60f2d7d18650964"),
- MustHexID("5a6a9ef07634d9eec3baa87c997b529b92652afa11473dfee41ef7037d5c06e0ddb9fe842364462d79dd31cff8a59a1b8d5bc2b810dea1d4cbbd3beb80ecec83"),
- MustHexID("f492c6ee2696d5f682f7f537757e52744c2ae560f1090a07024609e903d334e9e174fc01609c5a229ddbcac36c9d21adaf6457dab38a25bfd44f2f0ee4277998"),
- MustHexID("459e4db99298cb0467a90acee6888b08bb857450deac11015cced5104853be5adce5b69c740968bc7f931495d671a70cad9f48546d7cd203357fe9af0e8d2164"),
+ hexEncPubkey("5c4d58d46e055dd1f093f81ee60a675e1f02f54da6206720adee4dccef9b67a31efc5c2a2949c31a04ee31beadc79aba10da31440a1f9ff2a24093c63c36d784"),
+ hexEncPubkey("ea72161ffdd4b1e124c7b93b0684805f4c4b58d617ed498b37a145c670dbc2e04976f8785583d9c805ffbf343c31d492d79f841652bbbd01b61ed85640b23495"),
+ hexEncPubkey("51caa1d93352d47a8e531692a3612adac1e8ac68d0a200d086c1c57ae1e1a91aa285ab242e8c52ef9d7afe374c9485b122ae815f1707b875569d0433c1c3ce85"),
+ hexEncPubkey("c08397d5751b47bd3da044b908be0fb0e510d3149574dff7aeab33749b023bb171b5769990fe17469dbebc100bc150e798aeda426a2dcc766699a225fddd75c6"),
+ hexEncPubkey("0222c1c194b749736e593f937fad67ee348ac57287a15c7e42877aa38a9b87732a408bca370f812efd0eedbff13e6d5b854bf3ba1dec431a796ed47f32552b09"),
+ hexEncPubkey("03d859cd46ef02d9bfad5268461a6955426845eef4126de6be0fa4e8d7e0727ba2385b78f1a883a8239e95ebb814f2af8379632c7d5b100688eebc5841209582"),
+ hexEncPubkey("64d5004b7e043c39ff0bd10cb20094c287721d5251715884c280a612b494b3e9e1c64ba6f67614994c7d969a0d0c0295d107d53fc225d47c44c4b82852d6f960"),
+ hexEncPubkey("b0a5eefb2dab6f786670f35bf9641eefe6dd87fd3f1362bcab4aaa792903500ab23d88fae68411372e0813b057535a601d46e454323745a948017f6063a47b1f"),
+ hexEncPubkey("0cc6df0a3433d448b5684d2a3ffa9d1a825388177a18f44ad0008c7bd7702f1ec0fc38b83506f7de689c3b6ecb552599927e29699eed6bb867ff08f80068b287"),
+ hexEncPubkey("50772f7b8c03a4e153355fbbf79c8a80cf32af656ff0c7873c99911099d04a0dae0674706c357e0145ad017a0ade65e6052cb1b0d574fcd6f67da3eee0ace66b"),
+ hexEncPubkey("1ae37829c9ef41f8b508b82259ebac76b1ed900d7a45c08b7970f25d2d48ddd1829e2f11423a18749940b6dab8598c6e416cef0efd47e46e51f29a0bc65b37cd"),
+ hexEncPubkey("ba973cab31c2af091fc1644a93527d62b2394999e2b6ccbf158dd5ab9796a43d408786f1803ef4e29debfeb62fce2b6caa5ab2b24d1549c822a11c40c2856665"),
+ hexEncPubkey("bc413ad270dd6ea25bddba78f3298b03b8ba6f8608ac03d06007d4116fa78ef5a0cfe8c80155089382fc7a193243ee5500082660cb5d7793f60f2d7d18650964"),
+ hexEncPubkey("5a6a9ef07634d9eec3baa87c997b529b92652afa11473dfee41ef7037d5c06e0ddb9fe842364462d79dd31cff8a59a1b8d5bc2b810dea1d4cbbd3beb80ecec83"),
+ hexEncPubkey("f492c6ee2696d5f682f7f537757e52744c2ae560f1090a07024609e903d334e9e174fc01609c5a229ddbcac36c9d21adaf6457dab38a25bfd44f2f0ee4277998"),
+ hexEncPubkey("459e4db99298cb0467a90acee6888b08bb857450deac11015cced5104853be5adce5b69c740968bc7f931495d671a70cad9f48546d7cd203357fe9af0e8d2164"),
},
256: {
- MustHexID("a8593af8a4aef7b806b5197612017951bac8845a1917ca9a6a15dd6086d608505144990b245785c4cd2d67a295701c7aac2aa18823fb0033987284b019656268"),
- MustHexID("d2eebef914928c3aad77fc1b2a495f52d2294acf5edaa7d8a530b540f094b861a68fe8348a46a7c302f08ab609d85912a4968eacfea0740847b29421b4795d9e"),
- MustHexID("b14bfcb31495f32b650b63cf7d08492e3e29071fdc73cf2da0da48d4b191a70ba1a65f42ad8c343206101f00f8a48e8db4b08bf3f622c0853e7323b250835b91"),
- MustHexID("7feaee0d818c03eb30e4e0bf03ade0f3c21ca38e938a761aa1781cf70bda8cc5cd631a6cc53dd44f1d4a6d3e2dae6513c6c66ee50cb2f0e9ad6f7e319b309fd9"),
- MustHexID("4ca3b657b139311db8d583c25dd5963005e46689e1317620496cc64129c7f3e52870820e0ec7941d28809311df6db8a2867bbd4f235b4248af24d7a9c22d1232"),
- MustHexID("1181defb1d16851d42dd951d84424d6bd1479137f587fa184d5a8152be6b6b16ed08bcdb2c2ed8539bcde98c80c432875f9f724737c316a2bd385a39d3cab1d8"),
- MustHexID("d9dd818769fa0c3ec9f553c759b92476f082817252a04a47dc1777740b1731d280058c66f982812f173a294acf4944a85ba08346e2de153ba3ba41ce8a62cb64"),
- MustHexID("bd7c4f8a9e770aa915c771b15e107ca123d838762da0d3ffc53aa6b53e9cd076cffc534ec4d2e4c334c683f1f5ea72e0e123f6c261915ed5b58ac1b59f003d88"),
- MustHexID("3dd5739c73649d510456a70e9d6b46a855864a4a3f744e088fd8c8da11b18e4c9b5f2d7da50b1c147b2bae5ca9609ae01f7a3cdea9dce34f80a91d29cd82f918"),
- MustHexID("f0d7df1efc439b4bcc0b762118c1cfa99b2a6143a9f4b10e3c9465125f4c9fca4ab88a2504169bbcad65492cf2f50da9dd5d077c39574a944f94d8246529066b"),
- MustHexID("dd598b9ba441448e5fb1a6ec6c5f5aa9605bad6e223297c729b1705d11d05f6bfd3d41988b694681ae69bb03b9a08bff4beab5596503d12a39bffb5cd6e94c7c"),
- MustHexID("3fce284ac97e567aebae681b15b7a2b6df9d873945536335883e4bbc26460c064370537f323fd1ada828ea43154992d14ac0cec0940a2bd2a3f42ec156d60c83"),
- MustHexID("7c8dfa8c1311cb14fb29a8ac11bca23ecc115e56d9fcf7b7ac1db9066aa4eb39f8b1dabf46e192a65be95ebfb4e839b5ab4533fef414921825e996b210dd53bd"),
- MustHexID("cafa6934f82120456620573d7f801390ed5e16ed619613a37e409e44ab355ef755e83565a913b48a9466db786f8d4fbd590bfec474c2524d4a2608d4eafd6abd"),
- MustHexID("9d16600d0dd310d77045769fed2cb427f32db88cd57d86e49390c2ba8a9698cfa856f775be2013237226e7bf47b248871cf865d23015937d1edeb20db5e3e760"),
- MustHexID("17be6b6ba54199b1d80eff866d348ea11d8a4b341d63ad9a6681d3ef8a43853ac564d153eb2a8737f0afc9ab320f6f95c55aa11aaa13bbb1ff422fd16bdf8188"),
+ hexEncPubkey("a8593af8a4aef7b806b5197612017951bac8845a1917ca9a6a15dd6086d608505144990b245785c4cd2d67a295701c7aac2aa18823fb0033987284b019656268"),
+ hexEncPubkey("d2eebef914928c3aad77fc1b2a495f52d2294acf5edaa7d8a530b540f094b861a68fe8348a46a7c302f08ab609d85912a4968eacfea0740847b29421b4795d9e"),
+ hexEncPubkey("b14bfcb31495f32b650b63cf7d08492e3e29071fdc73cf2da0da48d4b191a70ba1a65f42ad8c343206101f00f8a48e8db4b08bf3f622c0853e7323b250835b91"),
+ hexEncPubkey("7feaee0d818c03eb30e4e0bf03ade0f3c21ca38e938a761aa1781cf70bda8cc5cd631a6cc53dd44f1d4a6d3e2dae6513c6c66ee50cb2f0e9ad6f7e319b309fd9"),
+ hexEncPubkey("4ca3b657b139311db8d583c25dd5963005e46689e1317620496cc64129c7f3e52870820e0ec7941d28809311df6db8a2867bbd4f235b4248af24d7a9c22d1232"),
+ hexEncPubkey("1181defb1d16851d42dd951d84424d6bd1479137f587fa184d5a8152be6b6b16ed08bcdb2c2ed8539bcde98c80c432875f9f724737c316a2bd385a39d3cab1d8"),
+ hexEncPubkey("d9dd818769fa0c3ec9f553c759b92476f082817252a04a47dc1777740b1731d280058c66f982812f173a294acf4944a85ba08346e2de153ba3ba41ce8a62cb64"),
+ hexEncPubkey("bd7c4f8a9e770aa915c771b15e107ca123d838762da0d3ffc53aa6b53e9cd076cffc534ec4d2e4c334c683f1f5ea72e0e123f6c261915ed5b58ac1b59f003d88"),
+ hexEncPubkey("3dd5739c73649d510456a70e9d6b46a855864a4a3f744e088fd8c8da11b18e4c9b5f2d7da50b1c147b2bae5ca9609ae01f7a3cdea9dce34f80a91d29cd82f918"),
+ hexEncPubkey("f0d7df1efc439b4bcc0b762118c1cfa99b2a6143a9f4b10e3c9465125f4c9fca4ab88a2504169bbcad65492cf2f50da9dd5d077c39574a944f94d8246529066b"),
+ hexEncPubkey("dd598b9ba441448e5fb1a6ec6c5f5aa9605bad6e223297c729b1705d11d05f6bfd3d41988b694681ae69bb03b9a08bff4beab5596503d12a39bffb5cd6e94c7c"),
+ hexEncPubkey("3fce284ac97e567aebae681b15b7a2b6df9d873945536335883e4bbc26460c064370537f323fd1ada828ea43154992d14ac0cec0940a2bd2a3f42ec156d60c83"),
+ hexEncPubkey("7c8dfa8c1311cb14fb29a8ac11bca23ecc115e56d9fcf7b7ac1db9066aa4eb39f8b1dabf46e192a65be95ebfb4e839b5ab4533fef414921825e996b210dd53bd"),
+ hexEncPubkey("cafa6934f82120456620573d7f801390ed5e16ed619613a37e409e44ab355ef755e83565a913b48a9466db786f8d4fbd590bfec474c2524d4a2608d4eafd6abd"),
+ hexEncPubkey("9d16600d0dd310d77045769fed2cb427f32db88cd57d86e49390c2ba8a9698cfa856f775be2013237226e7bf47b248871cf865d23015937d1edeb20db5e3e760"),
+ hexEncPubkey("17be6b6ba54199b1d80eff866d348ea11d8a4b341d63ad9a6681d3ef8a43853ac564d153eb2a8737f0afc9ab320f6f95c55aa11aaa13bbb1ff422fd16bdf8188"),
},
},
}
type preminedTestnet struct {
- target NodeID
- targetSha common.Hash // sha3(target)
- dists [hashBits + 1][]NodeID
+ target encPubkey
+ targetSha enode.ID // sha3(target)
+ dists [hashBits + 1][]encPubkey
}
-func (tn *preminedTestnet) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
+func (tn *preminedTestnet) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
// current log distance is encoded in port number
// fmt.Println("findnode query at dist", toaddr.Port)
if toaddr.Port == 0 {
panic("query to node at distance 0")
}
- next := uint16(toaddr.Port) - 1
- var result []*Node
- for i, id := range tn.dists[toaddr.Port] {
- result = append(result, NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)))
+ next := toaddr.Port - 1
+ var result []*node
+ for i, ekey := range tn.dists[toaddr.Port] {
+ key, _ := decodePubkey(ekey)
+ node := wrapNode(enode.NewV4(key, net.ParseIP("127.0.0.1"), i, next))
+ result = append(result, node)
}
return result, nil
}
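Worth spelling out the trick the comment above hints at: the premined testnet has no real network, so each fake node's UDP port encodes its log distance to the target, and every findnode answer returns nodes one distance closer. A minimal walk under that scheme (the helper name is ours, not part of the patch):

func walkToTarget(tn *preminedTestnet, startDist int) int {
    dist := startDist
    for dist > 0 && len(tn.dists[dist]) > 0 {
        addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: dist}
        nodes, _ := tn.findnode(enode.ID{}, addr, tn.target)
        dist = nodes[0].UDP() // each answer sits one distance closer
    }
    return dist // reaches 0 once the target's own bucket is hit
}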
-func (*preminedTestnet) close() {}
-func (*preminedTestnet) waitping(from NodeID) error { return nil }
-func (*preminedTestnet) ping(toid NodeID, toaddr *net.UDPAddr) error { return nil }
+func (*preminedTestnet) close() {}
+func (*preminedTestnet) waitping(from enode.ID) error { return nil }
+func (*preminedTestnet) ping(toid enode.ID, toaddr *net.UDPAddr) error { return nil }
// mine generates a testnet struct literal with nodes at
// various distances to the given target.
-func (n *preminedTestnet) mine(target NodeID) {
- n.target = target
- n.targetSha = crypto.Keccak256Hash(n.target[:])
+func (tn *preminedTestnet) mine(target encPubkey) {
+ tn.target = target
+ tn.targetSha = tn.target.id()
found := 0
for found < bucketSize*10 {
k := newkey()
- id := PubkeyID(&k.PublicKey)
- sha := crypto.Keccak256Hash(id[:])
- ld := logdist(n.targetSha, sha)
- if len(n.dists[ld]) < bucketSize {
- n.dists[ld] = append(n.dists[ld], id)
+ key := encodePubkey(&k.PublicKey)
+ ld := enode.LogDist(tn.targetSha, key.id())
+ if len(tn.dists[ld]) < bucketSize {
+ tn.dists[ld] = append(tn.dists[ld], key)
fmt.Println("found ID with ld", ld)
found++
}
}
fmt.Println("&preminedTestnet{")
- fmt.Printf(" target: %#v,\n", n.target)
- fmt.Printf(" targetSha: %#v,\n", n.targetSha)
- fmt.Printf(" dists: [%d][]NodeID{\n", len(n.dists))
- for ld, ns := range n.dists {
+ fmt.Printf(" target: %#v,\n", tn.target)
+ fmt.Printf(" targetSha: %#v,\n", tn.targetSha)
+ fmt.Printf(" dists: [%d][]encPubkey{\n", len(tn.dists))
+ for ld, ns := range tn.dists {
if len(ns) == 0 {
continue
}
- fmt.Printf(" %d: []NodeID{\n", ld)
+ fmt.Printf(" %d: []encPubkey{\n", ld)
for _, n := range ns {
- fmt.Printf(" MustHexID(\"%x\"),\n", n[:])
+ fmt.Printf(" hexEncPubkey(\"%x\"),\n", n[:])
}
fmt.Println(" },")
}
@@ -616,40 +564,6 @@ func (n *preminedTestnet) mine(target NodeID) {
fmt.Println("}")
}
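The whole dists literal above is machine-generated: mine is meant to run once, offline, with its printed output pasted over the table. A hypothetical regeneration helper, using the newkey and encodePubkey helpers from this package:

func regeneratePremined() {
    var tn preminedTestnet
    k := newkey()
    tn.mine(encodePubkey(&k.PublicKey)) // prints the struct literal to stdout
}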
-func hasDuplicates(slice []*Node) bool {
- seen := make(map[NodeID]bool)
- for i, e := range slice {
- if e == nil {
- panic(fmt.Sprintf("nil *Node at %d", i))
- }
- if seen[e.ID] {
- return true
- }
- seen[e.ID] = true
- }
- return false
-}
-
-func sortedByDistanceTo(distbase common.Hash, slice []*Node) bool {
- var last common.Hash
- for i, e := range slice {
- if i > 0 && distcmp(distbase, e.sha, last) < 0 {
- return false
- }
- last = e.sha
- }
- return true
-}
-
-func contains(ns []*Node, id NodeID) bool {
- for _, n := range ns {
- if n.ID == id {
- return true
- }
- }
- return false
-}
-
// gen wraps quick.Value so it's easier to use.
// it generates a random value of the given value's type.
func gen(typ interface{}, rand *rand.Rand) interface{} {
@@ -660,6 +574,13 @@ func gen(typ interface{}, rand *rand.Rand) interface{} {
return v.Interface()
}
+func quickcfg() *quick.Config {
+ return &quick.Config{
+ MaxCount: 5000,
+ Rand: rand.New(rand.NewSource(time.Now().Unix())),
+ }
+}
+
func newkey() *ecdsa.PrivateKey {
key, err := crypto.GenerateKey()
if err != nil {
diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go
new file mode 100644
index 000000000..fe55eb562
--- /dev/null
+++ b/p2p/discover/table_util_test.go
@@ -0,0 +1,167 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package discover
+
+import (
+ "crypto/ecdsa"
+ "encoding/hex"
+ "fmt"
+ "math/rand"
+ "net"
+ "sync"
+
+ "github.com/tomochain/tomochain/p2p/enode"
+ "github.com/tomochain/tomochain/p2p/enr"
+)
+
+func newTestTable(t transport) (*Table, *enode.DB) {
+ var r enr.Record
+ r.Set(enr.IP{0, 0, 0, 0})
+ n := enode.SignNull(&r, enode.ID{})
+ db, _ := enode.OpenDB("")
+ tab, _ := newTable(t, n, db, nil)
+ return tab, db
+}
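Usage sketch for this helper (test-only; function name is ours): the caller owns both handles and must close them.

func exampleNewTestTable() {
    tab, db := newTestTable(newPingRecorder())
    defer db.Close()
    defer tab.Close()
    // ... exercise tab ...
}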
+
+// nodeAtDistance creates a node for which enode.LogDist(base, n.id) == ld.
+func nodeAtDistance(base enode.ID, ld int, ip net.IP) *node {
+ var r enr.Record
+ r.Set(enr.IP(ip))
+ return wrapNode(enode.SignNull(&r, idAtDistance(base, ld)))
+}
+
+// idAtDistance returns a random hash such that enode.LogDist(a, b) == n.
+func idAtDistance(a enode.ID, n int) (b enode.ID) {
+ if n == 0 {
+ return a
+ }
+ // flip bit at position n, fill the rest with random bits
+ b = a
+ pos := len(a) - n/8 - 1
+ bit := byte(0x01) << (byte(n%8) - 1)
+ if bit == 0 {
+ pos++
+ bit = 0x80
+ }
+ b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
+ for i := pos + 1; i < len(a); i++ {
+ b[i] = byte(rand.Intn(255))
+ }
+ return b
+}
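The bit flipping above is easy to get wrong at byte boundaries, so a property check helps. A sketch (assuming it runs as a test in this package; math/rand is already imported) that exercises every distance:

func checkIDAtDistance() bool {
    var base enode.ID
    rand.Read(base[:])
    for n := 1; n <= len(base)*8; n++ {
        if enode.LogDist(base, idAtDistance(base, n)) != n {
            return false
        }
    }
    return true
}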
+
+func intIP(i int) net.IP {
+ return net.IP{byte(i), 0, 2, byte(i)}
+}
+
+// fillBucket inserts nodes into the given bucket until it is full.
+func fillBucket(tab *Table, n *node) (last *node) {
+ ld := enode.LogDist(tab.self.ID(), n.ID())
+ b := tab.bucket(n.ID())
+ for len(b.entries) < bucketSize {
+ b.entries = append(b.entries, nodeAtDistance(tab.self.ID(), ld, intIP(ld)))
+ }
+ return b.entries[bucketSize-1]
+}
+
+type pingRecorder struct {
+ mu sync.Mutex
+ dead, pinged map[enode.ID]bool
+}
+
+func newPingRecorder() *pingRecorder {
+ return &pingRecorder{
+ dead: make(map[enode.ID]bool),
+ pinged: make(map[enode.ID]bool),
+ }
+}
+
+func (t *pingRecorder) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
+ return nil, nil
+}
+
+func (t *pingRecorder) waitping(from enode.ID) error {
+ return nil // remote always pings
+}
+
+func (t *pingRecorder) ping(toid enode.ID, toaddr *net.UDPAddr) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ t.pinged[toid] = true
+ if t.dead[toid] {
+ return errTimeout
+ }
+ return nil
+}
+
+func (t *pingRecorder) close() {}
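pingRecorder gives tests a per-node switch: entries in dead time out, everything else answers. A sketch of how an eviction test might drive it (test body assumed):

func exampleDeadPeer(id enode.ID) {
    tr := newPingRecorder()
    tr.dead[id] = true
    err := tr.ping(id, &net.UDPAddr{})
    _ = err // errTimeout; a live table would now evict the node
}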
+
+func hasDuplicates(slice []*node) bool {
+ seen := make(map[enode.ID]bool)
+ for i, e := range slice {
+ if e == nil {
+ panic(fmt.Sprintf("nil *Node at %d", i))
+ }
+ if seen[e.ID()] {
+ return true
+ }
+ seen[e.ID()] = true
+ }
+ return false
+}
+
+func contains(ns []*node, id enode.ID) bool {
+ for _, n := range ns {
+ if n.ID() == id {
+ return true
+ }
+ }
+ return false
+}
+
+func sortedByDistanceTo(distbase enode.ID, slice []*node) bool {
+ var last enode.ID
+ for i, e := range slice {
+ if i > 0 && enode.DistCmp(distbase, e.ID(), last) < 0 {
+ return false
+ }
+ last = e.ID()
+ }
+ return true
+}
+
+func hexEncPubkey(h string) (ret encPubkey) {
+ b, err := hex.DecodeString(h)
+ if err != nil {
+ panic(err)
+ }
+ if len(b) != len(ret) {
+ panic("invalid length")
+ }
+ copy(ret[:], b)
+ return ret
+}
+
+func hexPubkey(h string) *ecdsa.PublicKey {
+ k, err := decodePubkey(hexEncPubkey(h))
+ if err != nil {
+ panic(err)
+ }
+ return k
+}
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 051477cb5..3b73e2939 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -27,13 +27,12 @@ import (
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/log"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/nat"
"github.com/tomochain/tomochain/p2p/netutil"
"github.com/tomochain/tomochain/rlp"
)
-const Version = 4
-
// Errors
var (
errPacketTooSmall = errors.New("too small")
@@ -48,9 +47,9 @@ var (
// Timeouts
const (
- respTimeout = 500 * time.Millisecond
- sendTimeout = 500 * time.Millisecond
- expiration = 20 * time.Second
+ respTimeout = 500 * time.Millisecond
+ expiration = 20 * time.Second
+ bondExpiration = 24 * time.Hour
ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP
ntpWarningCooldown = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning
@@ -63,7 +62,6 @@ const (
pongPacket
findnodePacket
neighborsPacket
- pingTomo
)
// RPC request structures
@@ -91,7 +89,7 @@ type (
// findnode is a query for nodes close to the given target.
findnode struct {
- Target NodeID // doesn't need to be an actual public key
+ Target encPubkey
Expiration uint64
// Ignore additional fields (for forward compatibility).
Rest []rlp.RawValue `rlp:"tail"`
@@ -109,7 +107,7 @@ type (
IP net.IP // len 4 for IPv4 or 16 for IPv6
UDP uint16 // for discovery protocol
TCP uint16 // for RLPx protocol
- ID NodeID
+ ID encPubkey
}
rpcEndpoint struct {
@@ -127,7 +125,7 @@ func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
}
-func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
+func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*node, error) {
if rn.UDP <= 1024 {
return nil, errors.New("low port")
}
@@ -137,17 +135,26 @@ func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) {
return nil, errors.New("not contained in netrestrict whitelist")
}
- n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
- err := n.validateComplete()
+ key, err := decodePubkey(rn.ID)
+ if err != nil {
+ return nil, err
+ }
+ n := wrapNode(enode.NewV4(key, rn.IP, int(rn.TCP), int(rn.UDP)))
+ err = n.ValidateComplete()
return n, err
}
-func nodeToRPC(n *Node) rpcNode {
- return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP}
+func nodeToRPC(n *node) rpcNode {
+ var key ecdsa.PublicKey
+ var ekey encPubkey
+ if err := n.Load((*enode.Secp256k1)(&key)); err == nil {
+ ekey = encodePubkey(&key)
+ }
+ return rpcNode{ID: ekey, IP: n.IP(), UDP: uint16(n.UDP()), TCP: uint16(n.TCP())}
}
type packet interface {
- handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
+ handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error
name() string
}
@@ -185,7 +192,7 @@ type udp struct {
// to all the callback functions for that node.
type pending struct {
// these fields must match in the reply.
- from NodeID
+ from enode.ID
ptype byte
// time when the request must complete
@@ -203,7 +210,7 @@ type pending struct {
}
type reply struct {
- from NodeID
+ from enode.ID
ptype byte
data interface{}
// loop indicates whether there was
@@ -226,7 +233,7 @@ type Config struct {
AnnounceAddr *net.UDPAddr // local address announced in the DHT
NodeDBPath string // if set, the node database is stored at this filesystem location
NetRestrict *netutil.Netlist // network whitelist
- Bootnodes []*Node // list of bootstrap nodes
+ Bootnodes []*enode.Node // list of bootstrap nodes
Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
}
@@ -241,6 +248,16 @@ func ListenUDP(c conn, cfg Config) (*Table, error) {
}
func newUDP(c conn, cfg Config) (*Table, *udp, error) {
+ realaddr := c.LocalAddr().(*net.UDPAddr)
+ if cfg.AnnounceAddr != nil {
+ realaddr = cfg.AnnounceAddr
+ }
+ self := enode.NewV4(&cfg.PrivateKey.PublicKey, realaddr.IP, realaddr.Port, realaddr.Port)
+ db, err := enode.OpenDB(cfg.NodeDBPath)
+ if err != nil {
+ return nil, nil, err
+ }
+
udp := &udp{
conn: c,
priv: cfg.PrivateKey,
@@ -249,13 +266,9 @@ func newUDP(c conn, cfg Config) (*Table, *udp, error) {
gotreply: make(chan reply),
addpending: make(chan *pending),
}
- realaddr := c.LocalAddr().(*net.UDPAddr)
- if cfg.AnnounceAddr != nil {
- realaddr = cfg.AnnounceAddr
- }
// TODO: separate TCP port
udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
- tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes)
+ tab, err := newTable(udp, self, db, cfg.Bootnodes)
if err != nil {
return nil, nil, err
}
@@ -269,36 +282,56 @@ func newUDP(c conn, cfg Config) (*Table, *udp, error) {
func (t *udp) close() {
close(t.closing)
t.conn.Close()
+ t.db.Close()
// TODO: wait for the loops to end.
}
// ping sends a ping message to the given node and waits for a reply.
-func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
+func (t *udp) ping(toid enode.ID, toaddr *net.UDPAddr) error {
+ return <-t.sendPing(toid, toaddr, nil)
+}
+
+// sendPing sends a ping message to the given node and invokes the callback
+// when the reply arrives.
+func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-chan error {
req := &ping{
- Version: Version,
+ Version: 4,
From: t.ourEndpoint,
To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
Expiration: uint64(time.Now().Add(expiration).Unix()),
}
- packet, hash, err := encodePacket(t.priv, pingTomo, req)
+ packet, hash, err := encodePacket(t.priv, pingPacket, req)
if err != nil {
- return err
+ errc := make(chan error, 1)
+ errc <- err
+ return errc
}
errc := t.pending(toid, pongPacket, func(p interface{}) bool {
- return bytes.Equal(p.(*pong).ReplyTok, hash)
+ ok := bytes.Equal(p.(*pong).ReplyTok, hash)
+ if ok && callback != nil {
+ callback()
+ }
+ return ok
})
t.write(toaddr, req.name(), packet)
- return <-errc
+ return errc
}
-func (t *udp) waitping(from NodeID) error {
- return <-t.pending(from, pingTomo, func(interface{}) bool { return true })
+func (t *udp) waitping(from enode.ID) error {
+ return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
}
// findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors.
-func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
- nodes := make([]*Node, 0, bucketSize)
+func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
+ // If we haven't seen a ping from the destination node for a while, it won't remember
+ // our endpoint proof and will reject findnode. Solicit a ping first.
+ if time.Since(t.db.LastPingReceived(toid, toaddr.IP)) > bondExpiration {
+ t.ping(toid, toaddr)
+ t.waitping(toid)
+ }
+
+ nodes := make([]*node, 0, bucketSize)
nreceived := 0
errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
reply := r.(*neighbors)
@@ -317,13 +350,12 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node
Target: target,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
- err := <-errc
- return nodes, err
+ return nodes, <-errc
}
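The endpoint proof introduced here is two-sided, which is easy to lose in the diff: before querying, we check when the peer last pinged us (does it still hold our proof?), and when serving findnode below, the handler checks when the peer last ponged us (do we still hold its proof?). Both sides reduce to the same freshness predicate, sketched here (helper name is ours):

func fresh(last time.Time) bool {
    return time.Since(last) <= bondExpiration // 24h window
}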
// pending adds a reply callback to the pending reply queue.
// see the documentation of type pending for a detailed explanation.
-func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
+func (t *udp) pending(id enode.ID, ptype byte, callback func(interface{}) bool) <-chan error {
ch := make(chan error, 1)
p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
select {
@@ -335,7 +367,7 @@ func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-
return ch
}
-func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
+func (t *udp) handleReply(from enode.ID, ptype byte, req packet) bool {
matched := make(chan bool, 1)
select {
case t.gotreply <- reply{from, ptype, req, matched}:
@@ -549,22 +581,23 @@ func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
return err
}
-func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
+func decodePacket(buf []byte) (packet, encPubkey, []byte, error) {
if len(buf) < headSize+1 {
- return nil, NodeID{}, nil, errPacketTooSmall
+ return nil, encPubkey{}, nil, errPacketTooSmall
}
hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
shouldhash := crypto.Keccak256(buf[macSize:])
if !bytes.Equal(hash, shouldhash) {
- return nil, NodeID{}, nil, errBadHash
+ return nil, encPubkey{}, nil, errBadHash
}
- fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig)
+ fromKey, err := recoverNodeKey(crypto.Keccak256(buf[headSize:]), sig)
if err != nil {
- return nil, NodeID{}, hash, err
+ return nil, fromKey, hash, err
}
+
var req packet
switch ptype := sigdata[0]; ptype {
- case pingTomo:
+ case pingPacket:
req = new(ping)
case pongPacket:
req = new(pong)
@@ -573,68 +606,78 @@ func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
case neighborsPacket:
req = new(neighbors)
default:
- return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
+ return nil, fromKey, hash, fmt.Errorf("unknown type: %d", ptype)
}
s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0)
err = s.Decode(req)
- return req, fromID, hash, err
+ return req, fromKey, hash, err
}
-func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+func (req *ping) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
if expired(req.Expiration) {
return errExpired
}
+ key, err := decodePubkey(fromKey)
+ if err != nil {
+ return fmt.Errorf("invalid public key: %v", err)
+ }
t.send(from, pongPacket, &pong{
To: makeEndpoint(from, req.From.TCP),
ReplyTok: mac,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
- if !t.handleReply(fromID, pingTomo, req) {
- // Note: we're ignoring the provided IP address right now
- go t.bond(true, fromID, from, req.From.TCP)
- }
+ n := wrapNode(enode.NewV4(key, from.IP, int(req.From.TCP), from.Port))
+ t.handleReply(n.ID(), pingPacket, req)
+ if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration {
+ t.sendPing(n.ID(), from, func() { t.addThroughPing(n) })
+ } else {
+ t.addThroughPing(n)
+ }
+ t.db.UpdateLastPingReceived(n.ID(), from.IP, time.Now())
return nil
}
-func (req *ping) name() string { return "PING TOMO/v4" }
+func (req *ping) name() string { return "PING/v4" }
-func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+func (req *pong) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
if expired(req.Expiration) {
return errExpired
}
+ fromID := fromKey.id()
if !t.handleReply(fromID, pongPacket, req) {
return errUnsolicitedReply
}
+ t.db.UpdateLastPongReceived(fromID, from.IP, time.Now())
return nil
}
func (req *pong) name() string { return "PONG/v4" }
-func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
if expired(req.Expiration) {
return errExpired
}
- if !t.db.hasBond(fromID) {
- // No bond exists, we don't process the packet. This prevents
- // an attack vector where the discovery protocol could be used
- // to amplify traffic in a DDOS attack. A malicious actor
- // would send a findnode request with the IP address and UDP
- // port of the target as the source address. The recipient of
- // the findnode packet would then send a neighbors packet
- // (which is a much bigger packet than findnode) to the victim.
+ fromID := fromKey.id()
+ if time.Since(t.db.LastPongReceived(fromID, from.IP)) > bondExpiration {
+ // No endpoint proof pong exists, we don't process the packet. This prevents an
+ // attack vector where the discovery protocol could be used to amplify traffic in a
+ // DDOS attack. A malicious actor would send a findnode request with the IP address
+ // and UDP port of the target as the source address. The recipient of the findnode
+ // packet would then send a neighbors packet (which is a much bigger packet than
+ // findnode) to the victim.
return errUnknownNode
}
- target := crypto.Keccak256Hash(req.Target[:])
+ target := enode.ID(crypto.Keccak256Hash(req.Target[:]))
t.mutex.Lock()
closest := t.closest(target, bucketSize).entries
t.mutex.Unlock()
- log.Trace("find neighbors ", "from", from, "fromID", fromID, "closest", len(closest))
+
p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
var sent bool
// Send neighbors in chunks with at most maxNeighbors per packet
// to stay below the 1280 byte limit.
for _, n := range closest {
- if netutil.CheckRelayIP(from.IP, n.IP) == nil {
+ if netutil.CheckRelayIP(from.IP, n.IP()) == nil {
p.Nodes = append(p.Nodes, nodeToRPC(n))
}
if len(p.Nodes) == maxNeighbors {
@@ -651,11 +694,11 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
func (req *findnode) name() string { return "FINDNODE/v4" }
-func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
if expired(req.Expiration) {
return errExpired
}
- if !t.handleReply(fromID, neighborsPacket, req) {
+ if !t.handleReply(fromKey.id(), neighborsPacket, req) {
return errUnsolicitedReply
}
return nil
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index b13a79658..e5fb32d08 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -36,6 +36,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/rlp"
)
@@ -46,7 +47,7 @@ func init() {
// shared test variables
var (
futureExp = uint64(time.Now().Add(10 * time.Hour).Unix())
- testTarget = NodeID{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}
+ testTarget = encPubkey{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}
testRemote = rpcEndpoint{IP: net.ParseIP("1.1.1.1").To4(), UDP: 1, TCP: 2}
testLocalAnnounced = rpcEndpoint{IP: net.ParseIP("2.2.2.2").To4(), UDP: 3, TCP: 4}
testLocal = rpcEndpoint{IP: net.ParseIP("3.3.3.3").To4(), UDP: 5, TCP: 6}
@@ -124,7 +125,7 @@ func TestUDP_packetErrors(t *testing.T) {
test := newUDPTest(t)
defer test.table.Close()
- test.packetIn(errExpired, pingTomo, &ping{From: testRemote, To: testLocalAnnounced, Version: Version})
+ test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4})
test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp})
test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp})
@@ -136,7 +137,7 @@ func TestUDP_pingTimeout(t *testing.T) {
defer test.table.Close()
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
- toid := NodeID{1, 2, 3, 4}
+ toid := enode.ID{1, 2, 3, 4}
if err := test.udp.ping(toid, toaddr); err != errTimeout {
t.Error("expected timeout error, got", err)
}
@@ -220,8 +221,8 @@ func TestUDP_findnodeTimeout(t *testing.T) {
defer test.table.Close()
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
- toid := NodeID{1, 2, 3, 4}
- target := NodeID{4, 5, 6, 7}
+ toid := enode.ID{1, 2, 3, 4}
+ target := encPubkey{4, 5, 6, 7}
result, err := test.udp.findnode(toid, toaddr, target)
if err != errTimeout {
t.Error("expected timeout error, got", err)
@@ -232,35 +233,36 @@ func TestUDP_findnodeTimeout(t *testing.T) {
}
func TestUDP_findnode(t *testing.T) {
- bucketSizeTest := 16
test := newUDPTest(t)
defer test.table.Close()
// put a few nodes into the table. their exact
// distribution shouldn't matter much, although we need to
// take care not to overflow any bucket.
- targetHash := crypto.Keccak256Hash(testTarget[:])
- nodes := &nodesByDistance{target: targetHash}
- for i := 0; i < bucketSizeTest; i++ {
- nodes.push(nodeAtDistance(test.table.self.sha, i+2), bucketSizeTest)
+ nodes := &nodesByDistance{target: testTarget.id()}
+ for i := 0; i < bucketSize; i++ {
+ key := newkey()
+ n := wrapNode(enode.NewV4(&key.PublicKey, net.IP{10, 13, 0, 1}, 0, i))
+ nodes.push(n, bucketSize)
}
test.table.stuff(nodes.entries)
// ensure there's a bond with the test node,
// findnode won't be accepted otherwise.
- test.table.db.updateBondTime(PubkeyID(&test.remotekey.PublicKey), time.Now())
+ remoteID := encodePubkey(&test.remotekey.PublicKey).id()
+ test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.IP, time.Now())
// check that closest neighbors are returned.
test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
- expected := test.table.closest(targetHash, bucketSizeTest)
+ expected := test.table.closest(testTarget.id(), bucketSize)
- waitNeighbors := func(want []*Node) {
+ waitNeighbors := func(want []*node) {
test.waitPacketOut(func(p *neighbors) {
if len(p.Nodes) != len(want) {
- t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSizeTest)
+ t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize)
}
for i := range p.Nodes {
- if p.Nodes[i].ID != want[i].ID {
+ if p.Nodes[i].ID.id() != want[i].ID() {
t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, p.Nodes[i], expected.entries[i])
}
}
@@ -274,10 +276,13 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
test := newUDPTest(t)
defer test.table.Close()
+ rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey)
+ test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.IP, time.Now())
+
// queue a pending findnode request
- resultc, errc := make(chan []*Node), make(chan error)
+ resultc, errc := make(chan []*node), make(chan error)
go func() {
- rid := PubkeyID(&test.remotekey.PublicKey)
+ rid := encodePubkey(&test.remotekey.PublicKey).id()
ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget)
if err != nil && len(ns) == 0 {
errc <- err
@@ -295,11 +300,11 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
})
// send the reply as two packets.
- list := []*Node{
- MustParseNode("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303?discport=30304"),
- MustParseNode("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303"),
- MustParseNode("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301?discport=17"),
- MustParseNode("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303"),
+ list := []*node{
+ wrapNode(enode.MustParseV4("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303?discport=30304")),
+ wrapNode(enode.MustParseV4("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303")),
+ wrapNode(enode.MustParseV4("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301?discport=17")),
+ wrapNode(enode.MustParseV4("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303")),
}
rpclist := make([]rpcNode, len(list))
for i := range list {
@@ -324,12 +329,12 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
func TestUDP_successfulPing(t *testing.T) {
test := newUDPTest(t)
- added := make(chan *Node, 1)
- test.table.nodeAddedHook = func(n *Node) { added <- n }
+ added := make(chan *node, 1)
+ test.table.nodeAddedHook = func(n *node) { added <- n }
defer test.table.Close()
// The remote side sends a ping packet to initiate the exchange.
- go test.packetIn(nil, pingTomo, &ping{From: testRemote, To: testLocalAnnounced, Version: Version, Expiration: futureExp})
+ go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
// the ping is replied to.
test.waitPacketOut(func(p *pong) {
@@ -369,18 +374,18 @@ func TestUDP_successfulPing(t *testing.T) {
// pong packet.
select {
case n := <-added:
- rid := PubkeyID(&test.remotekey.PublicKey)
- if n.ID != rid {
- t.Errorf("node has wrong ID: got %v, want %v", n.ID, rid)
+ rid := encodePubkey(&test.remotekey.PublicKey).id()
+ if n.ID() != rid {
+ t.Errorf("node has wrong ID: got %v, want %v", n.ID(), rid)
}
- if !n.IP.Equal(test.remoteaddr.IP) {
- t.Errorf("node has wrong IP: got %v, want: %v", n.IP, test.remoteaddr.IP)
+ if !n.IP().Equal(test.remoteaddr.IP) {
+ t.Errorf("node has wrong IP: got %v, want: %v", n.IP(), test.remoteaddr.IP)
}
- if int(n.UDP) != test.remoteaddr.Port {
- t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP, test.remoteaddr.Port)
+ if int(n.UDP()) != test.remoteaddr.Port {
+ t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP(), test.remoteaddr.Port)
}
- if n.TCP != testRemote.TCP {
- t.Errorf("node has wrong TCP port: got %v, want: %v", n.TCP, testRemote.TCP)
+ if n.TCP() != int(testRemote.TCP) {
+ t.Errorf("node has wrong TCP port: got %v, want: %v", n.TCP(), testRemote.TCP)
}
case <-time.After(2 * time.Second):
t.Errorf("node was not added within 2 seconds")
@@ -392,7 +397,7 @@ var testPackets = []struct {
wantPacket interface{}
}{
{
- input: "95a4d7d1909e6a58f115e9a451d47a8f016776a8874140366e702e33e85c7b4cd58a82ebece6acd0973342b66b9e716fece46b5c67a3560fc8624063dd15a310469de42ca599474b9d8cb6eb8dc41b0d5236539ea7ae10ef3c630cd94faefd800005ea04cb847f000001820cfa8215a8d790000000000000000000000000000000018208ae820d058443b9a355",
+ input: "71dbda3a79554728d4f94411e42ee1f8b0d561c10e1e5f5893367948c6a7d70bb87b235fa28a77070271b6c164a2dce8c7e13a5739b53b5e96f2e5acb0e458a02902f5965d55ecbeb2ebb6cabb8b2b232896a36b737666c55265ad0a68412f250001ea04cb847f000001820cfa8215a8d790000000000000000000000000000000018208ae820d058443b9a355",
wantPacket: &ping{
Version: 4,
From: rpcEndpoint{net.ParseIP("127.0.0.1").To4(), 3322, 5544},
@@ -402,7 +407,7 @@ var testPackets = []struct {
},
},
{
- input: "57b1c182cc24e21e9297baa70d57a67ade498439123c968ffc048541addf9d463d1d25d10cf473a7f90a3efd6a070818097ebeaef58cd53843cb3af28acaee354272cfe7801b7fa7dbd8aa13309b6059fce877ad376c8dad7524dc34de626bd80105ec04cb847f000001820cfa8215a8d790000000000000000000000000000000018208ae820d058443b9a3550102",
+ input: "e9614ccfd9fc3e74360018522d30e1419a143407ffcce748de3e22116b7e8dc92ff74788c0b6663aaa3d67d641936511c8f8d6ad8698b820a7cf9e1be7155e9a241f556658c55428ec0563514365799a4be2be5a685a80971ddcfa80cb422cdd0101ec04cb847f000001820cfa8215a8d790000000000000000000000000000000018208ae820d058443b9a3550102",
wantPacket: &ping{
Version: 4,
From: rpcEndpoint{net.ParseIP("127.0.0.1").To4(), 3322, 5544},
@@ -412,7 +417,7 @@ var testPackets = []struct {
},
},
{
- input: "e3e987421accd2c75967d4a7229c436c18760def054738d8d9669697ee4726cdc9949c51df3e90d795d33d3f57d508c4687913338f6eb9caa89873aaae9dd49a5473ade5ea452c4df9d1f842eadf03439dbc373c0de8b20b412b6760d7b479140105f83e82022bd79020010db83c4d001500000000abcdef12820cfa8215a8d79020010db885a308d313198a2e037073488208ae82823a8443b9a355c50102030405",
+ input: "577be4349c4dd26768081f58de4c6f375a7a22f3f7adda654d1428637412c3d7fe917cadc56d4e5e7ffae1dbe3efffb9849feb71b262de37977e7c7a44e677295680e9e38ab26bee2fcbae207fba3ff3d74069a50b902a82c9903ed37cc993c50001f83e82022bd79020010db83c4d001500000000abcdef12820cfa8215a8d79020010db885a308d313198a2e037073488208ae82823a8443b9a355c5010203040531b9019afde696e582a78fa8d95ea13ce3297d4afb8ba6433e4154caa5ac6431af1b80ba76023fa4090c408f6b4bc3701562c031041d4702971d102c9ab7fa5eed4cd6bab8f7af956f7d565ee1917084a95398b6a21eac920fe3dd1345ec0a7ef39367ee69ddf092cbfe5b93e5e568ebc491983c09c76d922dc3",
wantPacket: &ping{
Version: 555,
From: rpcEndpoint{net.ParseIP("2001:db8:3c4d:15::abcd:ef12"), 3322, 5544},
@@ -433,7 +438,7 @@ var testPackets = []struct {
{
input: "c7c44041b9f7c7e41934417ebac9a8e1a4c6298f74553f2fcfdcae6ed6fe53163eb3d2b52e39fe91831b8a927bf4fc222c3902202027e5e9eb812195f95d20061ef5cd31d502e47ecb61183f74a504fe04c51e73df81f25c4d506b26db4517490103f84eb840ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd31387574077f301b421bc84df7266c44e9e6d569fc56be00812904767bf5ccd1fc7f8443b9a35582999983999999280dc62cc8255c73471e0a61da0c89acdc0e035e260add7fc0c04ad9ebf3919644c91cb247affc82b69bd2ca235c71eab8e49737c937a2c396",
wantPacket: &findnode{
- Target: MustHexID("ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd31387574077f301b421bc84df7266c44e9e6d569fc56be00812904767bf5ccd1fc7f"),
+ Target: hexEncPubkey("ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd31387574077f301b421bc84df7266c44e9e6d569fc56be00812904767bf5ccd1fc7f"),
Expiration: 1136239445,
Rest: []rlp.RawValue{{0x82, 0x99, 0x99}, {0x83, 0x99, 0x99, 0x99}},
},
@@ -443,25 +448,25 @@ var testPackets = []struct {
wantPacket: &neighbors{
Nodes: []rpcNode{
{
- ID: MustHexID("3155e1427f85f10a5c9a7755877748041af1bcd8d474ec065eb33df57a97babf54bfd2103575fa829115d224c523596b401065a97f74010610fce76382c0bf32"),
+ ID: hexEncPubkey("3155e1427f85f10a5c9a7755877748041af1bcd8d474ec065eb33df57a97babf54bfd2103575fa829115d224c523596b401065a97f74010610fce76382c0bf32"),
IP: net.ParseIP("99.33.22.55").To4(),
UDP: 4444,
TCP: 4445,
},
{
- ID: MustHexID("312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d20951933beea1e4dfa6f968212385e829f04c2d314fc2d4e255e0d3bc08792b069db"),
+ ID: hexEncPubkey("312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d20951933beea1e4dfa6f968212385e829f04c2d314fc2d4e255e0d3bc08792b069db"),
IP: net.ParseIP("1.2.3.4").To4(),
UDP: 1,
TCP: 1,
},
{
- ID: MustHexID("38643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c765dd2d96126051913f44582e8c199ad7c6d6819e9a56483f637feaac9448aac"),
+ ID: hexEncPubkey("38643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c765dd2d96126051913f44582e8c199ad7c6d6819e9a56483f637feaac9448aac"),
IP: net.ParseIP("2001:db8:3c4d:15::abcd:ef12"),
UDP: 3333,
TCP: 3333,
},
{
- ID: MustHexID("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2d47295286fc00cc081bb542d760717d1bdd6bec2c37cd72eca367d6dd3b9df73"),
+ ID: hexEncPubkey("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2d47295286fc00cc081bb542d760717d1bdd6bec2c37cd72eca367d6dd3b9df73"),
IP: net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7348"),
UDP: 999,
TCP: 1000,
@@ -475,13 +480,14 @@ var testPackets = []struct {
func TestForwardCompatibility(t *testing.T) {
testkey, _ := crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
- wantNodeID := PubkeyID(&testkey.PublicKey)
+ wantNodeKey := encodePubkey(&testkey.PublicKey)
+
for _, test := range testPackets {
input, err := hex.DecodeString(test.input)
if err != nil {
t.Fatalf("invalid hex: %s", test.input)
}
- packet, nodeid, _, err := decodePacket(input)
+ packet, nodekey, _, err := decodePacket(input)
if err != nil {
t.Errorf("did not accept packet %s\n%v", test.input, err)
continue
@@ -489,8 +495,8 @@ func TestForwardCompatibility(t *testing.T) {
if !reflect.DeepEqual(packet, test.wantPacket) {
t.Errorf("got %s\nwant %s", spew.Sdump(packet), spew.Sdump(test.wantPacket))
}
- if nodeid != wantNodeID {
- t.Errorf("got id %v\nwant id %v", nodeid, wantNodeID)
+ if nodekey != wantNodeKey {
+ t.Errorf("got id %v\nwant id %v", nodekey, wantNodeKey)
}
}
}
diff --git a/p2p/discv5/node_test.go b/p2p/discv5/node_test.go
index a28f29825..d0fa6880a 100644
--- a/p2p/discv5/node_test.go
+++ b/p2p/discv5/node_test.go
@@ -141,7 +141,7 @@ var parseNodeTests = []struct {
{
// This test checks that errors from url.Parse are handled.
rawurl: "://foo",
- wantError: `parse ://foo: missing protocol scheme`,
+ wantError: `parse "://foo": missing protocol scheme`,
},
}
diff --git a/p2p/enode/idscheme.go b/p2p/enode/idscheme.go
new file mode 100644
index 000000000..87981db5c
--- /dev/null
+++ b/p2p/enode/idscheme.go
@@ -0,0 +1,161 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enode
+
+import (
+ "crypto/ecdsa"
+ "fmt"
+ "io"
+
+ "github.com/tomochain/tomochain/common/math"
+ "github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/p2p/enr"
+ "github.com/tomochain/tomochain/rlp"
+ "golang.org/x/crypto/sha3"
+)
+
+// ValidSchemes is a list of known secure identity schemes.
+var ValidSchemes = enr.SchemeMap{
+ "v4": V4ID{},
+}
+
+// ValidSchemesForTesting is a list of identity schemes for testing.
+var ValidSchemesForTesting = enr.SchemeMap{
+ "v4": V4ID{},
+ "null": NullID{},
+}
+
+// V4ID is the "v4" identity scheme.
+type V4ID struct{}
+
+// SignV4 signs a record using the v4 scheme.
+func SignV4(r *enr.Record, privkey *ecdsa.PrivateKey) error {
+ // Copy r to avoid modifying it if signing fails.
+ cpy := *r
+ cpy.Set(enr.ID("v4"))
+ cpy.Set(Secp256k1(privkey.PublicKey))
+
+ h := sha3.NewLegacyKeccak256()
+ rlp.Encode(h, cpy.AppendElements(nil))
+ sig, err := crypto.Sign(h.Sum(nil), privkey)
+ if err != nil {
+ return err
+ }
+ sig = sig[:len(sig)-1] // remove v
+ if err = cpy.SetSig(V4ID{}, sig); err == nil {
+ *r = cpy
+ }
+ return err
+}
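+
+// Example (illustrative sketch): signing a fresh record with a local key.
+//
+// var r enr.Record
+// key, _ := crypto.GenerateKey()
+// err := SignV4(&r, key) // r now carries "id", "secp256k1" and the signature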
+
+func (V4ID) Verify(r *enr.Record, sig []byte) error {
+ var entry s256raw
+ if err := r.Load(&entry); err != nil {
+ return err
+ } else if len(entry) != 33 {
+ return fmt.Errorf("invalid public key")
+ }
+
+ h := sha3.NewLegacyKeccak256()
+ rlp.Encode(h, r.AppendElements(nil))
+ if !crypto.VerifySignature(entry, h.Sum(nil), sig) {
+ return enr.ErrInvalidSig
+ }
+ return nil
+}
+
+func (V4ID) NodeAddr(r *enr.Record) []byte {
+ var pubkey Secp256k1
+ err := r.Load(&pubkey)
+ if err != nil {
+ return nil
+ }
+ buf := make([]byte, 64)
+ math.ReadBits(pubkey.X, buf[:32])
+ math.ReadBits(pubkey.Y, buf[32:])
+ return crypto.Keccak256(buf)
+}
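+
+// Note: the v4 node address is the keccak256 hash of the uncompressed 64-byte
+// public key, matching PubkeyToIDV4 and the keccak-of-pubkey IDs used by the
+// pre-ENR discovery code.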
+
+// Secp256k1 is the "secp256k1" key, which holds a public key.
+type Secp256k1 ecdsa.PublicKey
+
+func (v Secp256k1) ENRKey() string { return "secp256k1" }
+
+// EncodeRLP implements rlp.Encoder.
+func (v Secp256k1) EncodeRLP(w io.Writer) error {
+ return rlp.Encode(w, crypto.CompressPubkey((*ecdsa.PublicKey)(&v)))
+}
+
+// DecodeRLP implements rlp.Decoder.
+func (v *Secp256k1) DecodeRLP(s *rlp.Stream) error {
+ buf, err := s.Bytes()
+ if err != nil {
+ return err
+ }
+ pk, err := crypto.DecompressPubkey(buf)
+ if err != nil {
+ return err
+ }
+ *v = (Secp256k1)(*pk)
+ return nil
+}
+
+// s256raw is an unparsed secp256k1 public key entry.
+type s256raw []byte
+
+func (s256raw) ENRKey() string { return "secp256k1" }
+
+// v4CompatID is a weaker and insecure version of the "v4" scheme which only checks for the
+// presence of a secp256k1 public key, but doesn't verify the signature.
+type v4CompatID struct {
+ V4ID
+}
+
+func (v4CompatID) Verify(r *enr.Record, sig []byte) error {
+ var pubkey Secp256k1
+ return r.Load(&pubkey)
+}
+
+func signV4Compat(r *enr.Record, pubkey *ecdsa.PublicKey) {
+ r.Set((*Secp256k1)(pubkey))
+ if err := r.SetSig(v4CompatID{}, []byte{}); err != nil {
+ panic(err)
+ }
+}
+
+// NullID is the "null" ENR identity scheme. This scheme stores the node
+// ID in the record without any signature.
+type NullID struct{}
+
+func (NullID) Verify(r *enr.Record, sig []byte) error {
+ return nil
+}
+
+func (NullID) NodeAddr(r *enr.Record) []byte {
+ var id ID
+ r.Load(enr.WithEntry("nulladdr", &id))
+ return id[:]
+}
+
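+// SignNull stores the given ID in r under the "null" identity scheme. Records
+// signed this way carry no real signature and are only accepted by
+// ValidSchemesForTesting; the iterator tests use this via testNode.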
+func SignNull(r *enr.Record, id ID) *Node {
+ r.Set(enr.ID("null"))
+ r.Set(enr.WithEntry("nulladdr", id))
+ if err := r.SetSig(NullID{}, []byte{}); err != nil {
+ panic(err)
+ }
+ return &Node{r: *r, id: id}
+}
diff --git a/p2p/enode/idscheme_test.go b/p2p/enode/idscheme_test.go
new file mode 100644
index 000000000..8d7440f47
--- /dev/null
+++ b/p2p/enode/idscheme_test.go
@@ -0,0 +1,74 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enode
+
+import (
+ "bytes"
+ "crypto/ecdsa"
+ "encoding/hex"
+ "math/big"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/p2p/enr"
+ "github.com/tomochain/tomochain/rlp"
+)
+
+var (
+ privkey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
+ pubkey = &privkey.PublicKey
+)
+
+func TestEmptyNodeID(t *testing.T) {
+ var r enr.Record
+ if addr := ValidSchemes.NodeAddr(&r); addr != nil {
+ t.Errorf("wrong address on empty record: got %v, want %v", addr, nil)
+ }
+
+ require.NoError(t, SignV4(&r, privkey))
+ expected := "a448f24c6d18e575453db13171562b71999873db5b286df957af199ec94617f7"
+ assert.Equal(t, expected, hex.EncodeToString(ValidSchemes.NodeAddr(&r)))
+}
+
+// Checks that failure to sign leaves the record unmodified.
+func TestSignError(t *testing.T) {
+ invalidKey := &ecdsa.PrivateKey{D: new(big.Int), PublicKey: *pubkey}
+
+ var r enr.Record
+ emptyEnc, _ := rlp.EncodeToBytes(&r)
+ if err := SignV4(&r, invalidKey); err == nil {
+ t.Fatal("expected error from SignV4")
+ }
+ newEnc, _ := rlp.EncodeToBytes(&r)
+ if !bytes.Equal(newEnc, emptyEnc) {
+ t.Fatal("record modified even though signing failed")
+ }
+}
+
+// TestGetSetSecp256k1 tests encoding/decoding and setting/getting of the Secp256k1 key.
+func TestGetSetSecp256k1(t *testing.T) {
+ var r enr.Record
+ if err := SignV4(&r, privkey); err != nil {
+ t.Fatal(err)
+ }
+
+ var pk Secp256k1
+ require.NoError(t, r.Load(&pk))
+ assert.EqualValues(t, pubkey, &pk)
+}
diff --git a/p2p/enode/iter.go b/p2p/enode/iter.go
new file mode 100644
index 000000000..b8ab4a758
--- /dev/null
+++ b/p2p/enode/iter.go
@@ -0,0 +1,295 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enode
+
+import (
+ "sync"
+ "time"
+)
+
+// Iterator represents a sequence of nodes. The Next method moves to the next node in the
+// sequence. It returns false when the sequence has ended or the iterator is closed. Close
+// may be called concurrently with Next and Node, and interrupts Next if it is blocked.
+type Iterator interface {
+ Next() bool // moves to next node
+ Node() *Node // returns current node
+ Close() // ends the iterator
+}
+
+// ReadNodes reads at most n nodes from the given iterator. The return value contains no
+// duplicates and no nil values. To prevent looping indefinitely for small repeating node
+// sequences, this function calls Next at most n times.
+func ReadNodes(it Iterator, n int) []*Node {
+ seen := make(map[ID]*Node, n)
+ for i := 0; i < n && it.Next(); i++ {
+ // Remove duplicates, keeping the node with higher seq.
+ node := it.Node()
+ prevNode, ok := seen[node.ID()]
+ if ok && prevNode.Seq() > node.Seq() {
+ continue
+ }
+ seen[node.ID()] = node
+ }
+ result := make([]*Node, 0, len(seen))
+ for _, node := range seen {
+ result = append(result, node)
+ }
+ return result
+}
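+
+// Example (illustrative): read up to 16 unique nodes from a static slice,
+// where 'known' is a hypothetical []*Node owned by the caller:
+//
+// batch := ReadNodes(IterNodes(known), 16)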
+
+// IterNodes makes an iterator which runs through the given nodes once.
+func IterNodes(nodes []*Node) Iterator {
+ return &sliceIter{nodes: nodes, index: -1}
+}
+
+// CycleNodes makes an iterator which cycles through the given nodes indefinitely.
+func CycleNodes(nodes []*Node) Iterator {
+ return &sliceIter{nodes: nodes, index: -1, cycle: true}
+}
+
+type sliceIter struct {
+ mu sync.Mutex
+ nodes []*Node
+ index int
+ cycle bool
+}
+
+func (it *sliceIter) Next() bool {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+
+ if len(it.nodes) == 0 {
+ return false
+ }
+ it.index++
+ if it.index == len(it.nodes) {
+ if it.cycle {
+ it.index = 0
+ } else {
+ it.nodes = nil
+ return false
+ }
+ }
+ return true
+}
+
+func (it *sliceIter) Node() *Node {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+ if len(it.nodes) == 0 {
+ return nil
+ }
+ return it.nodes[it.index]
+}
+
+func (it *sliceIter) Close() {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+
+ it.nodes = nil
+}
+
+// Filter wraps an iterator such that Next only returns nodes for which
+// the 'check' function returns true.
+func Filter(it Iterator, check func(*Node) bool) Iterator {
+ return &filterIter{it, check}
+}
+
+type filterIter struct {
+ Iterator
+ check func(*Node) bool
+}
+
+func (f *filterIter) Next() bool {
+ for f.Iterator.Next() {
+ if f.check(f.Node()) {
+ return true
+ }
+ }
+ return false
+}
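+
+// Example (illustrative): keep only nodes that advertise a TCP port:
+//
+// it = Filter(it, func(n *Node) bool { return n.TCP() != 0 })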
+
+// FairMix aggregates multiple node iterators. The mixer itself is an iterator which ends
+// only when Close is called. Source iterators added via AddSource are removed from the
+// mix when they end.
+//
+// The distribution of nodes returned by Next is approximately fair, i.e. FairMix
+// attempts to draw from all sources equally often. However, if a certain source is slow
+// and doesn't return a node within the configured timeout, a node from any other source
+// will be returned.
+//
+// It's safe to call AddSource and Close concurrently with Next.
+type FairMix struct {
+ wg sync.WaitGroup
+ fromAny chan *Node
+ timeout time.Duration
+ cur *Node
+
+ mu sync.Mutex
+ closed chan struct{}
+ sources []*mixSource
+ last int
+}
+
+type mixSource struct {
+ it Iterator
+ next chan *Node
+ timeout time.Duration
+}
+
+// NewFairMix creates a mixer.
+//
+// The timeout specifies how long the mixer will wait for the next fairly-chosen source
+// before giving up and taking a node from any other source. A good way to set the timeout
+// is deciding how long you'd want to wait for a node on average. Passing a negative
+// timeout makes the mixer completely fair.
+func NewFairMix(timeout time.Duration) *FairMix {
+ m := &FairMix{
+ fromAny: make(chan *Node),
+ closed: make(chan struct{}),
+ timeout: timeout,
+ }
+ return m
+}
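+
+// Example (illustrative, mirroring testMixerFairness): mix several sources and
+// read a batch; srcA and srcB stand in for any Iterator values:
+//
+// mix := NewFairMix(time.Second)
+// mix.AddSource(srcA)
+// mix.AddSource(srcB)
+// defer mix.Close()
+// nodes := ReadNodes(mix, 100)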
+
+// AddSource adds a source of nodes.
+func (m *FairMix) AddSource(it Iterator) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.closed == nil {
+ return
+ }
+ m.wg.Add(1)
+ source := &mixSource{it, make(chan *Node), m.timeout}
+ m.sources = append(m.sources, source)
+ go m.runSource(m.closed, source)
+}
+
+// Close shuts down the mixer and all current sources.
+// Calling this is required to release resources associated with the mixer.
+func (m *FairMix) Close() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.closed == nil {
+ return
+ }
+ for _, s := range m.sources {
+ s.it.Close()
+ }
+ close(m.closed)
+ m.wg.Wait()
+ close(m.fromAny)
+ m.sources = nil
+ m.closed = nil
+}
+
+// Next returns a node, drawing from the sources in round-robin order.
+func (m *FairMix) Next() bool {
+ m.cur = nil
+
+ for {
+ source := m.pickSource()
+ if source == nil {
+ return m.nextFromAny()
+ }
+
+ var timeout <-chan time.Time
+ if source.timeout >= 0 {
+ timer := time.NewTimer(source.timeout)
+ timeout = timer.C
+ defer timer.Stop()
+ }
+
+ select {
+ case n, ok := <-source.next:
+ if ok {
+ // Here, the timeout is reset to the configured value
+ // because the source delivered a node.
+ source.timeout = m.timeout
+ m.cur = n
+ return true
+ }
+ // This source has ended.
+ m.deleteSource(source)
+ case <-timeout:
+ // The selected source did not deliver a node within the timeout, so the
+ // timeout duration is halved for next time. This is supposed to improve
+ // latency with stuck sources.
+ source.timeout /= 2
+ return m.nextFromAny()
+ }
+ }
+}
+
+// Node returns the current node.
+func (m *FairMix) Node() *Node {
+ return m.cur
+}
+
+// nextFromAny is used when there are no sources or when the 'fair' choice
+// doesn't turn up a node quickly enough.
+func (m *FairMix) nextFromAny() bool {
+ n, ok := <-m.fromAny
+ if ok {
+ m.cur = n
+ }
+ return ok
+}
+
+// pickSource chooses the next source to read from, cycling through them in order.
+func (m *FairMix) pickSource() *mixSource {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if len(m.sources) == 0 {
+ return nil
+ }
+ m.last = (m.last + 1) % len(m.sources)
+ return m.sources[m.last]
+}
+
+// deleteSource deletes a source.
+func (m *FairMix) deleteSource(s *mixSource) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ for i := range m.sources {
+ if m.sources[i] == s {
+ copy(m.sources[i:], m.sources[i+1:])
+ m.sources[len(m.sources)-1] = nil
+ m.sources = m.sources[:len(m.sources)-1]
+ break
+ }
+ }
+}
+
+// runSource reads a single source in a loop.
+func (m *FairMix) runSource(closed chan struct{}, s *mixSource) {
+ defer m.wg.Done()
+ defer close(s.next)
+ for s.it.Next() {
+ n := s.it.Node()
+ select {
+ case s.next <- n:
+ case m.fromAny <- n:
+ case <-closed:
+ return
+ }
+ }
+}
diff --git a/p2p/enode/iter_test.go b/p2p/enode/iter_test.go
new file mode 100644
index 000000000..ae980345a
--- /dev/null
+++ b/p2p/enode/iter_test.go
@@ -0,0 +1,291 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enode
+
+import (
+ "encoding/binary"
+ "runtime"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/tomochain/tomochain/p2p/enr"
+)
+
+func TestReadNodes(t *testing.T) {
+ nodes := ReadNodes(new(genIter), 10)
+ checkNodes(t, nodes, 10)
+}
+
+// This test checks that ReadNodes terminates when reading N nodes from an iterator
+// which returns less than N nodes in an endless cycle.
+func TestReadNodesCycle(t *testing.T) {
+ iter := &callCountIter{
+ Iterator: CycleNodes([]*Node{
+ testNode(0, 0),
+ testNode(1, 0),
+ testNode(2, 0),
+ }),
+ }
+ nodes := ReadNodes(iter, 10)
+ checkNodes(t, nodes, 3)
+ if iter.count != 10 {
+ t.Fatalf("%d calls to Next, want %d", iter.count, 100)
+ }
+}
+
+func TestFilterNodes(t *testing.T) {
+ nodes := make([]*Node, 100)
+ for i := range nodes {
+ nodes[i] = testNode(uint64(i), uint64(i))
+ }
+
+ it := Filter(IterNodes(nodes), func(n *Node) bool {
+ return n.Seq() >= 50
+ })
+ for i := 50; i < len(nodes); i++ {
+ if !it.Next() {
+ t.Fatal("Next returned false")
+ }
+ if it.Node() != nodes[i] {
+ t.Fatalf("iterator returned wrong node %v\nwant %v", it.Node(), nodes[i])
+ }
+ }
+ if it.Next() {
+ t.Fatal("Next returned true after underlying iterator has ended")
+ }
+}
+
+func checkNodes(t *testing.T, nodes []*Node, wantLen int) {
+ if len(nodes) != wantLen {
+ t.Errorf("slice has %d nodes, want %d", len(nodes), wantLen)
+ return
+ }
+ seen := make(map[ID]bool, len(nodes))
+ for i, e := range nodes {
+ if e == nil {
+ t.Errorf("nil node at index %d", i)
+ return
+ }
+ if seen[e.ID()] {
+ t.Errorf("slice has duplicate node %v", e.ID())
+ return
+ }
+ seen[e.ID()] = true
+ }
+}
+
+// This test checks fairness of FairMix in the happy case where all sources return nodes
+// within the context's deadline.
+func TestFairMix(t *testing.T) {
+ for i := 0; i < 500; i++ {
+ testMixerFairness(t)
+ }
+}
+
+func testMixerFairness(t *testing.T) {
+ mix := NewFairMix(1 * time.Second)
+ mix.AddSource(&genIter{index: 1})
+ mix.AddSource(&genIter{index: 2})
+ mix.AddSource(&genIter{index: 3})
+ defer mix.Close()
+
+ nodes := ReadNodes(mix, 500)
+ checkNodes(t, nodes, 500)
+
+ // Verify that the nodes slice contains an approximately equal number of nodes
+ // from each source.
+ d := idPrefixDistribution(nodes)
+ for _, count := range d {
+ if approxEqual(count, len(nodes)/3, 30) {
+ t.Fatalf("ID distribution is unfair: %v", d)
+ }
+ }
+}
+
+// This test checks that FairMix falls back to an alternative source when
+// the 'fair' choice doesn't return a node within the timeout.
+func TestFairMixNextFromAll(t *testing.T) {
+ mix := NewFairMix(1 * time.Millisecond)
+ mix.AddSource(&genIter{index: 1})
+ mix.AddSource(CycleNodes(nil))
+ defer mix.Close()
+
+ nodes := ReadNodes(mix, 500)
+ checkNodes(t, nodes, 500)
+
+ d := idPrefixDistribution(nodes)
+ if len(d) > 1 || d[1] != len(nodes) {
+ t.Fatalf("wrong ID distribution: %v", d)
+ }
+}
+
+// This test ensures FairMix works for Next with no sources.
+func TestFairMixEmpty(t *testing.T) {
+ var (
+ mix = NewFairMix(1 * time.Second)
+ testN = testNode(1, 1)
+ ch = make(chan *Node)
+ )
+ defer mix.Close()
+
+ go func() {
+ mix.Next()
+ ch <- mix.Node()
+ }()
+
+ mix.AddSource(CycleNodes([]*Node{testN}))
+ if n := <-ch; n != testN {
+ t.Errorf("got wrong node: %v", n)
+ }
+}
+
+// This test checks closing a source while Next runs.
+func TestFairMixRemoveSource(t *testing.T) {
+ mix := NewFairMix(1 * time.Second)
+ source := make(blockingIter)
+ mix.AddSource(source)
+
+ sig := make(chan *Node)
+ go func() {
+ <-sig
+ mix.Next()
+ sig <- mix.Node()
+ }()
+
+ sig <- nil
+ runtime.Gosched()
+ source.Close()
+
+ wantNode := testNode(0, 0)
+ mix.AddSource(CycleNodes([]*Node{wantNode}))
+ n := <-sig
+
+ if len(mix.sources) != 1 {
+ t.Fatalf("have %d sources, want one", len(mix.sources))
+ }
+ if n != wantNode {
+ t.Fatalf("mixer returned wrong node")
+ }
+}
+
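+// blockingIter blocks in Next until it is closed; TestFairMixRemoveSource uses
+// it to exercise removing a source while Next is running.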
+type blockingIter chan struct{}
+
+func (it blockingIter) Next() bool {
+ <-it
+ return false
+}
+
+func (it blockingIter) Node() *Node {
+ return nil
+}
+
+func (it blockingIter) Close() {
+ close(it)
+}
+
+func TestFairMixClose(t *testing.T) {
+ for i := 0; i < 20 && !t.Failed(); i++ {
+ testMixerClose(t)
+ }
+}
+
+func testMixerClose(t *testing.T) {
+ mix := NewFairMix(-1)
+ mix.AddSource(CycleNodes(nil))
+ mix.AddSource(CycleNodes(nil))
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ if mix.Next() {
+ t.Error("Next returned true")
+ }
+ }()
+ // This call is supposed to make it more likely that Next is
+ // actually executing by the time we call Close.
+ runtime.Gosched()
+
+ mix.Close()
+ select {
+ case <-done:
+ case <-time.After(3 * time.Second):
+ t.Fatal("Next didn't unblock on Close")
+ }
+
+ mix.Close() // shouldn't crash
+}
+
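+// idPrefixDistribution counts nodes by their first four ID bytes; genIter
+// encodes its source index there, so the counts show how many nodes each
+// source contributed.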
+func idPrefixDistribution(nodes []*Node) map[uint32]int {
+ d := make(map[uint32]int, len(nodes))
+ for _, node := range nodes {
+ id := node.ID()
+ d[binary.BigEndian.Uint32(id[:4])]++
+ }
+ return d
+}
+
+// approxEqual reports whether x and y differ by more than ε; note that a true
+// result means the values are NOT within tolerance, which is why
+// testMixerFairness fails when it sees true.
+func approxEqual(x, y, ε int) bool {
+ if y > x {
+ x, y = y, x
+ }
+ return x-y > ε
+}
+
+// genIter creates fake nodes with numbered IDs based on 'index' and 'gen'
+type genIter struct {
+ node *Node
+ index, gen uint32
+}
+
+func (s *genIter) Next() bool {
+ index := atomic.LoadUint32(&s.index)
+ if index == ^uint32(0) {
+ s.node = nil
+ return false
+ }
+ s.node = testNode(uint64(index)<<32|uint64(s.gen), 0)
+ s.gen++
+ return true
+}
+
+func (s *genIter) Node() *Node {
+ return s.node
+}
+
+func (s *genIter) Close() {
+ atomic.StoreUint32(&s.index, ^uint32(0))
+}
+
+func testNode(id, seq uint64) *Node {
+ var nodeID ID
+ binary.BigEndian.PutUint64(nodeID[:], id)
+ r := new(enr.Record)
+ r.SetSeq(seq)
+ return SignNull(r, nodeID)
+}
+
+// callCountIter counts calls to Next.
+type callCountIter struct {
+ Iterator
+ count int
+}
+
+func (it *callCountIter) Next() bool {
+ it.count++
+ return it.Iterator.Next()
+}
diff --git a/p2p/enode/localnode.go b/p2p/enode/localnode.go
new file mode 100644
index 000000000..06f274992
--- /dev/null
+++ b/p2p/enode/localnode.go
@@ -0,0 +1,332 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enode
+
+import (
+ "crypto/ecdsa"
+ "fmt"
+ "net"
+ "reflect"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/tomochain/tomochain/log"
+ "github.com/tomochain/tomochain/p2p/enr"
+ "github.com/tomochain/tomochain/p2p/netutil"
+)
+
+const (
+ // IP tracker configuration
+ iptrackMinStatements = 10
+ iptrackWindow = 5 * time.Minute
+ iptrackContactWindow = 10 * time.Minute
+
+ // time needed to wait between two updates to the local ENR
+ recordUpdateThrottle = time.Millisecond
+)
+
+// LocalNode produces the signed node record of a local node, i.e. a node run in the
+// current process. Setting ENR entries via the Set method updates the record. A new version
+// of the record is signed on demand when the Node method is called.
+type LocalNode struct {
+ cur atomic.Value // holds a non-nil node pointer while the record is up-to-date
+
+ id ID
+ key *ecdsa.PrivateKey
+ db *DB
+
+ // everything below is protected by a lock
+ mu sync.RWMutex
+ seq uint64
+ update time.Time // timestamp when the record was last updated
+ entries map[string]enr.Entry
+ endpoint4 lnEndpoint
+ endpoint6 lnEndpoint
+}
+
+type lnEndpoint struct {
+ track *netutil.IPTracker
+ staticIP, fallbackIP net.IP
+ fallbackUDP uint16 // port
+}
+
+// NewLocalNode creates a local node.
+func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode {
+ ln := &LocalNode{
+ id: PubkeyToIDV4(&key.PublicKey),
+ db: db,
+ key: key,
+ entries: make(map[string]enr.Entry),
+ endpoint4: lnEndpoint{
+ track: netutil.NewIPTracker(iptrackWindow, iptrackContactWindow, iptrackMinStatements),
+ },
+ endpoint6: lnEndpoint{
+ track: netutil.NewIPTracker(iptrackWindow, iptrackContactWindow, iptrackMinStatements),
+ },
+ }
+ ln.seq = db.localSeq(ln.id)
+ ln.update = time.Now()
+ ln.cur.Store((*Node)(nil))
+ return ln
+}
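+
+// Minimal setup sketch (mirroring newLocalNodeForTesting in the tests):
+//
+// db, _ := OpenDB("") // in-memory node database
+// key, _ := crypto.GenerateKey()
+// ln := NewLocalNode(db, key)
+// n := ln.Node() // first call signs the record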
+
+// Database returns the node database associated with the local node.
+func (ln *LocalNode) Database() *DB {
+ return ln.db
+}
+
+// Node returns the current version of the local node record.
+func (ln *LocalNode) Node() *Node {
+ // If we have a valid record, return that
+ n := ln.cur.Load().(*Node)
+ if n != nil {
+ return n
+ }
+
+ // Record was invalidated, sign a new copy.
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ // Double check the current record, since multiple goroutines might be waiting
+ // on the write mutex.
+ if n = ln.cur.Load().(*Node); n != nil {
+ return n
+ }
+
+ // The initial sequence number is the current timestamp in milliseconds. To ensure
+ // that the initial sequence number will always be higher than any previous sequence
+ // number (assuming the clock is correct), we want to avoid updating the record faster
+ // than once per ms. So we need to sleep here until the next possible update time has
+ // arrived.
+ lastChange := time.Since(ln.update)
+ if lastChange < recordUpdateThrottle {
+ time.Sleep(recordUpdateThrottle - lastChange)
+ }
+
+ ln.sign()
+ ln.update = time.Now()
+ return ln.cur.Load().(*Node)
+}
+
+// Seq returns the current sequence number of the local node record.
+func (ln *LocalNode) Seq() uint64 {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ return ln.seq
+}
+
+// ID returns the local node ID.
+func (ln *LocalNode) ID() ID {
+ return ln.id
+}
+
+// Set puts the given entry into the local record, overwriting any existing value.
+// Use Set*IP and SetFallbackUDP to set IP addresses and UDP port, otherwise they'll
+// be overwritten by the endpoint predictor.
+//
+// Set only invalidates the current record; a fresh record incorporating the
+// change is signed on demand the next time Node is called, throttled to at
+// most one re-sign per millisecond (see recordUpdateThrottle).
+func (ln *LocalNode) Set(e enr.Entry) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.set(e)
+}
+
+func (ln *LocalNode) set(e enr.Entry) {
+ val, exists := ln.entries[e.ENRKey()]
+ if !exists || !reflect.DeepEqual(val, e) {
+ ln.entries[e.ENRKey()] = e
+ ln.invalidate()
+ }
+}
+
+// Delete removes the given entry from the local record.
+func (ln *LocalNode) Delete(e enr.Entry) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.delete(e)
+}
+
+func (ln *LocalNode) delete(e enr.Entry) {
+ _, exists := ln.entries[e.ENRKey()]
+ if exists {
+ delete(ln.entries, e.ENRKey())
+ ln.invalidate()
+ }
+}
+
+func (ln *LocalNode) endpointForIP(ip net.IP) *lnEndpoint {
+ if ip.To4() != nil {
+ return &ln.endpoint4
+ }
+ return &ln.endpoint6
+}
+
+// SetStaticIP sets the local IP to the given one unconditionally.
+// This disables endpoint prediction.
+func (ln *LocalNode) SetStaticIP(ip net.IP) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.endpointForIP(ip).staticIP = ip
+ ln.updateEndpoints()
+}
+
+// SetFallbackIP sets the last-resort IP address. This address is used
+// if no endpoint prediction can be made and no static IP is set.
+func (ln *LocalNode) SetFallbackIP(ip net.IP) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.endpointForIP(ip).fallbackIP = ip
+ ln.updateEndpoints()
+}
+
+// SetFallbackUDP sets the last-resort UDP port (for both the IPv4 and IPv6
+// endpoints). This port is used if no endpoint prediction can be made.
+func (ln *LocalNode) SetFallbackUDP(port int) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.endpoint4.fallbackUDP = uint16(port)
+ ln.endpoint6.fallbackUDP = uint16(port)
+ ln.updateEndpoints()
+}
+
+// UDPEndpointStatement should be called whenever a statement about the local node's
+// UDP endpoint is received. It feeds the local endpoint predictor.
+func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint *net.UDPAddr) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.endpointForIP(endpoint.IP).track.AddStatement(fromaddr.String(), endpoint.String())
+ ln.updateEndpoints()
+}
+
+// UDPContact should be called whenever the local node has announced itself to another node
+// via UDP. It feeds the local endpoint predictor.
+func (ln *LocalNode) UDPContact(toaddr *net.UDPAddr) {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+
+ ln.endpointForIP(toaddr.IP).track.AddContact(toaddr.String())
+ ln.updateEndpoints()
+}
+
+// updateEndpoints updates the record with predicted endpoints.
+func (ln *LocalNode) updateEndpoints() {
+ ip4, udp4 := ln.endpoint4.get()
+ ip6, udp6 := ln.endpoint6.get()
+
+ if ip4 != nil && !ip4.IsUnspecified() {
+ ln.set(enr.IPv4(ip4))
+ } else {
+ ln.delete(enr.IPv4{})
+ }
+ if ip6 != nil && !ip6.IsUnspecified() {
+ ln.set(enr.IPv6(ip6))
+ } else {
+ ln.delete(enr.IPv6{})
+ }
+ if udp4 != 0 {
+ ln.set(enr.UDP(udp4))
+ } else {
+ ln.delete(enr.UDP(0))
+ }
+ if udp6 != 0 && udp6 != udp4 {
+ ln.set(enr.UDP6(udp6))
+ } else {
+ ln.delete(enr.UDP6(0))
+ }
+}
+
+// get returns the endpoint with highest precedence: a static IP wins over a
+// predicted endpoint, which wins over the fallback IP/port.
+func (e *lnEndpoint) get() (newIP net.IP, newPort uint16) {
+ newPort = e.fallbackUDP
+ if e.fallbackIP != nil {
+ newIP = e.fallbackIP
+ }
+ if e.staticIP != nil {
+ newIP = e.staticIP
+ } else if ip, port := predictAddr(e.track); ip != nil {
+ newIP = ip
+ newPort = port
+ }
+ return newIP, newPort
+}
+
+// predictAddr wraps IPTracker.PredictEndpoint, converting from its string-based
+// endpoint representation to IP and port types.
+func predictAddr(t *netutil.IPTracker) (net.IP, uint16) {
+ ep := t.PredictEndpoint()
+ if ep == "" {
+ return nil, 0
+ }
+ ipString, portString, _ := net.SplitHostPort(ep)
+ ip := net.ParseIP(ipString)
+ port, err := strconv.ParseUint(portString, 10, 16)
+ if err != nil {
+ return nil, 0
+ }
+ return ip, uint16(port)
+}
+
+func (ln *LocalNode) invalidate() {
+ ln.cur.Store((*Node)(nil))
+}
+
+func (ln *LocalNode) sign() {
+ if n := ln.cur.Load().(*Node); n != nil {
+ return // no changes
+ }
+
+ var r enr.Record
+ for _, e := range ln.entries {
+ r.Set(e)
+ }
+ ln.bumpSeq()
+ r.SetSeq(ln.seq)
+ if err := SignV4(&r, ln.key); err != nil {
+ panic(fmt.Errorf("enode: can't sign record: %v", err))
+ }
+ n, err := New(ValidSchemes, &r)
+ if err != nil {
+ panic(fmt.Errorf("enode: can't verify local record: %v", err))
+ }
+ ln.cur.Store(n)
+ log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IP(), "udp", n.UDP(), "tcp", n.TCP())
+}
+
+func (ln *LocalNode) bumpSeq() {
+ ln.seq++
+ ln.db.storeLocalSeq(ln.id, ln.seq)
+}
+
+// nowMilliseconds gives the current timestamp at millisecond precision.
+func nowMilliseconds() uint64 {
+ ns := time.Now().UnixNano()
+ if ns < 0 {
+ return 0
+ }
+ return uint64(ns / 1000 / 1000)
+}
diff --git a/p2p/enode/localnode_test.go b/p2p/enode/localnode_test.go
new file mode 100644
index 000000000..c7fd79ef9
--- /dev/null
+++ b/p2p/enode/localnode_test.go
@@ -0,0 +1,129 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enode
+
+import (
+ "crypto/rand"
+ "net"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/p2p/enr"
+)
+
+func newLocalNodeForTesting() (*LocalNode, *DB) {
+ db, _ := OpenDB("")
+ key, _ := crypto.GenerateKey()
+ return NewLocalNode(db, key), db
+}
+
+func TestLocalNode(t *testing.T) {
+ ln, db := newLocalNodeForTesting()
+ defer db.Close()
+
+ if ln.Node().ID() != ln.ID() {
+ t.Fatal("inconsistent ID")
+ }
+
+ ln.Set(enr.WithEntry("x", uint(3)))
+ var x uint
+ if err := ln.Node().Load(enr.WithEntry("x", &x)); err != nil {
+ t.Fatal("can't load entry 'x':", err)
+ } else if x != 3 {
+ t.Fatal("wrong value for entry 'x':", x)
+ }
+}
+
+// This test checks that the sequence number is persisted between restarts.
+func TestLocalNodeSeqPersist(t *testing.T) {
+ timestamp := nowMilliseconds()
+
+ ln, db := newLocalNodeForTesting()
+ defer db.Close()
+
+ initialSeq := ln.Node().Seq()
+ if initialSeq < timestamp {
+ t.Fatalf("wrong initial seq %d, want at least %d", initialSeq, timestamp)
+ }
+
+ ln.Set(enr.WithEntry("x", uint(1)))
+ if s := ln.Node().Seq(); s != initialSeq+1 {
+ t.Fatalf("wrong seq %d after set, want %d", s, initialSeq+1)
+ }
+
+ // Create a new instance, it should reload the sequence number.
+ // The number increases just after that because a new record is
+ // created without the "x" entry.
+ ln2 := NewLocalNode(db, ln.key)
+ if s := ln2.Node().Seq(); s != initialSeq+2 {
+ t.Fatalf("wrong seq %d on new instance, want %d", s, initialSeq+2)
+ }
+
+ finalSeq := ln2.Node().Seq()
+
+ // Create a new instance with a different node key on the same database.
+ // This should reset the sequence number.
+ key, _ := crypto.GenerateKey()
+ ln3 := NewLocalNode(db, key)
+ if s := ln3.Node().Seq(); s < finalSeq {
+ t.Fatalf("wrong seq %d on instance with changed key, want >= %d", s, finalSeq)
+ }
+}
+
+// This test checks behavior of the endpoint predictor.
+func TestLocalNodeEndpoint(t *testing.T) {
+ var (
+ fallback = &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 80}
+ predicted = &net.UDPAddr{IP: net.IP{127, 0, 1, 2}, Port: 81}
+ staticIP = net.IP{127, 0, 1, 2}
+ )
+ ln, db := newLocalNodeForTesting()
+ defer db.Close()
+
+ // Nothing is set initially.
+ assert.Equal(t, net.IP(nil), ln.Node().IP())
+ assert.Equal(t, 0, ln.Node().UDP())
+ initialSeq := ln.Node().Seq()
+
+ // Set up fallback address.
+ ln.SetFallbackIP(fallback.IP)
+ ln.SetFallbackUDP(fallback.Port)
+ assert.Equal(t, fallback.IP, ln.Node().IP())
+ assert.Equal(t, fallback.Port, ln.Node().UDP())
+ assert.Equal(t, initialSeq+1, ln.Node().Seq())
+
+ // Add endpoint statements from random hosts.
+ for i := 0; i < iptrackMinStatements; i++ {
+ assert.Equal(t, fallback.IP, ln.Node().IP())
+ assert.Equal(t, fallback.Port, ln.Node().UDP())
+ assert.Equal(t, initialSeq+1, ln.Node().Seq())
+
+ from := &net.UDPAddr{IP: make(net.IP, 4), Port: 90}
+ rand.Read(from.IP)
+ ln.UDPEndpointStatement(from, predicted)
+ }
+ assert.Equal(t, predicted.IP, ln.Node().IP())
+ assert.Equal(t, predicted.Port, ln.Node().UDP())
+ assert.Equal(t, initialSeq+2, ln.Node().Seq())
+
+ // Static IP overrides prediction.
+ ln.SetStaticIP(staticIP)
+ assert.Equal(t, staticIP, ln.Node().IP())
+ assert.Equal(t, fallback.Port, ln.Node().UDP())
+ assert.Equal(t, initialSeq+3, ln.Node().Seq())
+}
diff --git a/p2p/enode/node.go b/p2p/enode/node.go
new file mode 100644
index 000000000..7606ad40f
--- /dev/null
+++ b/p2p/enode/node.go
@@ -0,0 +1,279 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enode
+
+import (
+ "crypto/ecdsa"
+ "encoding/base64"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "math/bits"
+ "net"
+ "strings"
+
+ "github.com/tomochain/tomochain/p2p/enr"
+ "github.com/tomochain/tomochain/rlp"
+)
+
+var errMissingPrefix = errors.New("missing 'enr:' prefix for base64-encoded record")
+
+// Node represents a host on the network.
+type Node struct {
+ r enr.Record
+ id ID
+}
+
+// New wraps a node record. The record must be valid according to the given
+// identity scheme.
+func New(validSchemes enr.IdentityScheme, r *enr.Record) (*Node, error) {
+ if err := r.VerifySignature(validSchemes); err != nil {
+ return nil, err
+ }
+ node := &Node{r: *r}
+ if n := copy(node.id[:], validSchemes.NodeAddr(&node.r)); n != len(ID{}) {
+ return nil, fmt.Errorf("invalid node ID length %d, need %d", n, len(ID{}))
+ }
+ return node, nil
+}
+
+// MustParse parses a node record or enode:// URL. It panics if the input is invalid.
+func MustParse(rawurl string) *Node {
+ n, err := Parse(ValidSchemes, rawurl)
+ if err != nil {
+ panic("invalid node: " + err.Error())
+ }
+ return n
+}
+
+// Parse decodes and verifies a base64-encoded node record.
+func Parse(validSchemes enr.IdentityScheme, input string) (*Node, error) {
+ if strings.HasPrefix(input, "enode://") {
+ return ParseV4(input)
+ }
+ if !strings.HasPrefix(input, "enr:") {
+ return nil, errMissingPrefix
+ }
+ bin, err := base64.RawURLEncoding.DecodeString(input[4:])
+ if err != nil {
+ return nil, err
+ }
+ var r enr.Record
+ if err := rlp.DecodeBytes(bin, &r); err != nil {
+ return nil, err
+ }
+ return New(validSchemes, &r)
+}
+
+// ID returns the node identifier.
+func (n *Node) ID() ID {
+ return n.id
+}
+
+// Seq returns the sequence number of the underlying record.
+func (n *Node) Seq() uint64 {
+ return n.r.Seq()
+}
+
+// Incomplete returns true for nodes with no IP address.
+func (n *Node) Incomplete() bool {
+ return n.IP() == nil
+}
+
+// Load retrieves an entry from the underlying record.
+func (n *Node) Load(k enr.Entry) error {
+ return n.r.Load(k)
+}
+
+// IP returns the IP address of the node. This prefers IPv4 addresses.
+func (n *Node) IP() net.IP {
+ var (
+ ip4 enr.IPv4
+ ip6 enr.IPv6
+ )
+ if n.Load(&ip4) == nil {
+ return net.IP(ip4)
+ }
+ if n.Load(&ip6) == nil {
+ return net.IP(ip6)
+ }
+ return nil
+}
+
+// UDP returns the UDP port of the node.
+func (n *Node) UDP() int {
+ var port enr.UDP
+ n.Load(&port)
+ return int(port)
+}
+
+// TCP returns the TCP port of the node.
+func (n *Node) TCP() int {
+ var port enr.TCP
+ n.Load(&port)
+ return int(port)
+}
+
+// Pubkey returns the secp256k1 public key of the node, if present.
+func (n *Node) Pubkey() *ecdsa.PublicKey {
+ var key ecdsa.PublicKey
+ if n.Load((*Secp256k1)(&key)) != nil {
+ return nil
+ }
+ return &key
+}
+
+// Record returns the node's record. The return value is a copy and may
+// be modified by the caller.
+func (n *Node) Record() *enr.Record {
+ cpy := n.r
+ return &cpy
+}
+
+// ValidateComplete checks whether n has a valid IP and UDP port.
+// Deprecated: don't use this method.
+func (n *Node) ValidateComplete() error {
+ if n.Incomplete() {
+ return errors.New("missing IP address")
+ }
+ if n.UDP() == 0 {
+ return errors.New("missing UDP port")
+ }
+ ip := n.IP()
+ if ip.IsMulticast() || ip.IsUnspecified() {
+ return errors.New("invalid IP (multicast/unspecified)")
+ }
+ // Validate the node key (on curve, etc.).
+ var key Secp256k1
+ return n.Load(&key)
+}
+
+// String returns the text representation of the record.
+func (n *Node) String() string {
+ if isNewV4(n) {
+ return n.URLv4() // backwards-compatibility glue for NewV4 nodes
+ }
+ enc, _ := rlp.EncodeToBytes(&n.r) // always succeeds because record is valid
+ b64 := base64.RawURLEncoding.EncodeToString(enc)
+ return "enr:" + b64
+}
+
+// MarshalText implements encoding.TextMarshaler.
+func (n *Node) MarshalText() ([]byte, error) {
+ return []byte(n.String()), nil
+}
+
+// UnmarshalText implements encoding.TextUnmarshaler.
+func (n *Node) UnmarshalText(text []byte) error {
+ dec, err := Parse(ValidSchemes, string(text))
+ if err == nil {
+ *n = *dec
+ }
+ return err
+}
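+
+// Since Node implements encoding.TextMarshaler and encoding.TextUnmarshaler,
+// it round-trips through encoding/json as its textual form. A minimal sketch,
+// assuming a hypothetical config struct:
+//
+//	type config struct {
+//		Bootnode *Node `json:"bootnode"`
+//	}
+//	// json.Marshal emits the node's enr:/enode:// string;
+//	// json.Unmarshal restores it via Parse(ValidSchemes, ...).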
+
+// ID is a unique identifier for each node.
+type ID [32]byte
+
+// Bytes returns a byte slice representation of the ID.
+func (n ID) Bytes() []byte {
+ return n[:]
+}
+
+// ID prints as a long hexadecimal number.
+func (n ID) String() string {
+ return fmt.Sprintf("%x", n[:])
+}
+
+// GoString returns the Go syntax representation of an ID as a call to HexID.
+func (n ID) GoString() string {
+ return fmt.Sprintf("enode.HexID(\"%x\")", n[:])
+}
+
+// TerminalString returns a shortened hex string for terminal logging.
+func (n ID) TerminalString() string {
+ return hex.EncodeToString(n[:8])
+}
+
+// MarshalText implements the encoding.TextMarshaler interface.
+func (n ID) MarshalText() ([]byte, error) {
+ return []byte(hex.EncodeToString(n[:])), nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+func (n *ID) UnmarshalText(text []byte) error {
+ id, err := ParseID(string(text))
+ if err != nil {
+ return err
+ }
+ *n = id
+ return nil
+}
+
+// HexID converts a hex string to an ID.
+// The string may be prefixed with 0x.
+// It panics if the string is not a valid ID.
+func HexID(in string) ID {
+ id, err := ParseID(in)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
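+// ParseID parses a node ID from a hex string. The string may be prefixed with 0x.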
+func ParseID(in string) (ID, error) {
+ var id ID
+ b, err := hex.DecodeString(strings.TrimPrefix(in, "0x"))
+ if err != nil {
+ return id, err
+ } else if len(b) != len(id) {
+ return id, fmt.Errorf("wrong length, want %d hex chars", len(id)*2)
+ }
+ copy(id[:], b)
+ return id, nil
+}
+
+// DistCmp compares the distances a->target and b->target.
+// Returns -1 if a is closer to target, 1 if b is closer to target
+// and 0 if they are equal.
+func DistCmp(target, a, b ID) int {
+ for i := range target {
+ da := a[i] ^ target[i]
+ db := b[i] ^ target[i]
+ if da > db {
+ return 1
+ } else if da < db {
+ return -1
+ }
+ }
+ return 0
+}
+
+// LogDist returns the logarithmic distance between a and b, log2(a ^ b).
+func LogDist(a, b ID) int {
+ lz := 0
+ for i := range a {
+ x := a[i] ^ b[i]
+ if x == 0 {
+ lz += 8
+ } else {
+ lz += bits.LeadingZeros8(x)
+ break
+ }
+ }
+ return len(a)*8 - lz
+}
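+
+// Worked example of the XOR metric (illustrative comment, assuming zero-value
+// IDs): if two IDs first differ in the top bit of byte 0, the XOR has no
+// leading zeros and LogDist is 256; differing only in the lowest bit of
+// byte 0 gives 249. DistCmp orders by the same XOR distance.
+//
+//	a := ID{}     // all zero
+//	b := ID{0x80} // LogDist(a, b) == 256
+//	c := ID{0x01} // LogDist(a, c) == 249, DistCmp(a, b, c) == 1 (c is closer)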
diff --git a/p2p/enode/node_test.go b/p2p/enode/node_test.go
new file mode 100644
index 000000000..a2ec52655
--- /dev/null
+++ b/p2p/enode/node_test.go
@@ -0,0 +1,145 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enode
+
+import (
+ "bytes"
+ "encoding/hex"
+ "fmt"
+ "math/big"
+ "testing"
+ "testing/quick"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tomochain/tomochain/p2p/enr"
+ "github.com/tomochain/tomochain/rlp"
+)
+
+var pyRecord, _ = hex.DecodeString("f884b8407098ad865b00a582051940cb9cf36836572411a47278783077011599ed5cd16b76f2635f4e234738f30813a89eb9137e3e3df5266e3a1f11df72ecf1145ccb9c01826964827634826970847f00000189736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd31388375647082765f")
+
+// TestPythonInterop checks that we can decode and verify a record produced by the Python
+// implementation.
+func TestPythonInterop(t *testing.T) {
+ var r enr.Record
+ if err := rlp.DecodeBytes(pyRecord, &r); err != nil {
+ t.Fatalf("can't decode: %v", err)
+ }
+ n, err := New(ValidSchemes, &r)
+ if err != nil {
+ t.Fatalf("can't verify record: %v", err)
+ }
+
+ var (
+ wantID = HexID("a448f24c6d18e575453db13171562b71999873db5b286df957af199ec94617f7")
+ wantSeq = uint64(1)
+ wantIP = enr.IPv4{127, 0, 0, 1}
+ wantUDP = enr.UDP(30303)
+ )
+ if n.Seq() != wantSeq {
+ t.Errorf("wrong seq: got %d, want %d", n.Seq(), wantSeq)
+ }
+ if n.ID() != wantID {
+ t.Errorf("wrong id: got %x, want %x", n.ID(), wantID)
+ }
+ want := map[enr.Entry]interface{}{new(enr.IPv4): &wantIP, new(enr.UDP): &wantUDP}
+ for k, v := range want {
+ desc := fmt.Sprintf("loading key %q", k.ENRKey())
+ if assert.NoError(t, n.Load(k), desc) {
+ assert.Equal(t, k, v, desc)
+ }
+ }
+}
+
+func TestHexID(t *testing.T) {
+ ref := ID{0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
+ id1 := HexID("0x00000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
+ id2 := HexID("00000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
+
+ if id1 != ref {
+ t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:])
+ }
+ if id2 != ref {
+ t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:])
+ }
+}
+
+func TestID_textEncoding(t *testing.T) {
+ ref := ID{
+ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x10,
+ 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x20,
+ 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x30,
+ 0x31, 0x32,
+ }
+ hex := "0102030405060708091011121314151617181920212223242526272829303132"
+
+ text, err := ref.MarshalText()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(text, []byte(hex)) {
+ t.Fatalf("text encoding did not match\nexpected: %s\ngot: %s", hex, text)
+ }
+
+ id := new(ID)
+ if err := id.UnmarshalText(text); err != nil {
+ t.Fatal(err)
+ }
+ if *id != ref {
+ t.Fatalf("text decoding did not match\nexpected: %s\ngot: %s", ref, id)
+ }
+}
+
+func TestID_distcmp(t *testing.T) {
+ distcmpBig := func(target, a, b ID) int {
+ tbig := new(big.Int).SetBytes(target[:])
+ abig := new(big.Int).SetBytes(a[:])
+ bbig := new(big.Int).SetBytes(b[:])
+ return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
+ }
+ if err := quick.CheckEqual(DistCmp, distcmpBig, nil); err != nil {
+ t.Error(err)
+ }
+}
+
+// The random test is likely to miss the case where a and b are equal,
+// so this test checks it explicitly.
+func TestID_distcmpEqual(t *testing.T) {
+ base := ID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+ x := ID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
+ if DistCmp(base, x, x) != 0 {
+ t.Errorf("DistCmp(base, x, x) != 0")
+ }
+}
+
+func TestID_logdist(t *testing.T) {
+ logdistBig := func(a, b ID) int {
+ abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
+ return new(big.Int).Xor(abig, bbig).BitLen()
+ }
+ if err := quick.CheckEqual(LogDist, logdistBig, nil); err != nil {
+ t.Error(err)
+ }
+}
+
+// The random test is likely to miss the case where a and b are equal,
+// so this test checks it explicitly.
+func TestID_logdistEqual(t *testing.T) {
+ x := ID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+ if LogDist(x, x) != 0 {
+ t.Errorf("LogDist(x, x) != 0")
+ }
+}
diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go
new file mode 100644
index 000000000..466de5ce1
--- /dev/null
+++ b/p2p/enode/nodedb.go
@@ -0,0 +1,501 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enode
+
+import (
+ "bytes"
+ "crypto/rand"
+ "encoding/binary"
+ "fmt"
+ "net"
+ "os"
+ "sync"
+ "time"
+
+ "github.com/syndtr/goleveldb/leveldb"
+ "github.com/syndtr/goleveldb/leveldb/errors"
+ "github.com/syndtr/goleveldb/leveldb/iterator"
+ "github.com/syndtr/goleveldb/leveldb/opt"
+ "github.com/syndtr/goleveldb/leveldb/storage"
+ "github.com/syndtr/goleveldb/leveldb/util"
+ "github.com/tomochain/tomochain/rlp"
+)
+
+// Keys in the node database.
+const (
+ dbVersionKey = "version" // Version of the database; contents are flushed if it changes
+ dbNodePrefix = "n:" // Identifier to prefix node entries with
+ dbLocalPrefix = "local:"
+ dbDiscoverRoot = "v4"
+ dbDiscv5Root = "v5"
+
+ // These fields are stored per ID and IP, the full key is "n:<ID>:v4:<IP>:findfail".
+ // Use nodeItemKey to create those keys.
+ dbNodeFindFails = "findfail"
+ dbNodePing = "lastping"
+ dbNodePong = "lastpong"
+ dbNodeSeq = "seq"
+
+ // Local information is keyed by ID only, the full key is "local:<ID>:seq".
+ // Use localItemKey to create those keys.
+ dbLocalSeq = "seq"
+)
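+
+// As a sketch, with <ID> standing for the raw 32-byte node ID and <IP> for
+// the 16-byte IP, the stored keys look like:
+//
+//	n:<ID>:v4                -- node record (nodeKey)
+//	n:<ID>:v4:<IP>:lastpong  -- per-IP metadata (nodeItemKey)
+//	local:<ID>:seq           -- local sequence counter (localItemKey)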
+
+const (
+ dbNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped.
+ dbCleanupCycle = time.Hour // Time period for running the expiration task.
+ dbVersion = 9
+)
+
+var (
+ errInvalidIP = errors.New("invalid IP")
+)
+
+var zeroIP = make(net.IP, 16)
+
+// DB is the node database, storing previously seen nodes and any collected metadata about
+// them for QoS purposes.
+type DB struct {
+ lvl *leveldb.DB // Interface to the database itself
+ runner sync.Once // Ensures we can start at most one expirer
+ quit chan struct{} // Channel to signal the expiring thread to stop
+}
+
+// OpenDB opens a node database for storing and retrieving infos about known peers in the
+// network. If no path is given an in-memory, temporary database is constructed.
+func OpenDB(path string) (*DB, error) {
+ if path == "" {
+ return newMemoryDB()
+ }
+ return newPersistentDB(path)
+}
+
+// newMemoryDB creates a new in-memory node database without a persistent backend.
+func newMemoryDB() (*DB, error) {
+ db, err := leveldb.Open(storage.NewMemStorage(), nil)
+ if err != nil {
+ return nil, err
+ }
+ return &DB{lvl: db, quit: make(chan struct{})}, nil
+}
+
+// newPersistentDB creates/opens a leveldb-backed persistent node database,
+// also flushing its contents in case of a version mismatch.
+func newPersistentDB(path string) (*DB, error) {
+ opts := &opt.Options{OpenFilesCacheCapacity: 5}
+ db, err := leveldb.OpenFile(path, opts)
+ if _, iscorrupted := err.(*errors.ErrCorrupted); iscorrupted {
+ db, err = leveldb.RecoverFile(path, nil)
+ }
+ if err != nil {
+ return nil, err
+ }
+ // The nodes contained in the cache correspond to a certain protocol version.
+ // Flush all nodes if the version doesn't match.
+ currentVer := make([]byte, binary.MaxVarintLen64)
+ currentVer = currentVer[:binary.PutVarint(currentVer, int64(dbVersion))]
+
+ blob, err := db.Get([]byte(dbVersionKey), nil)
+ switch err {
+ case leveldb.ErrNotFound:
+ // Version not found (i.e. empty cache), insert it
+ if err := db.Put([]byte(dbVersionKey), currentVer, nil); err != nil {
+ db.Close()
+ return nil, err
+ }
+
+ case nil:
+ // Version present, flush if different
+ if !bytes.Equal(blob, currentVer) {
+ db.Close()
+ if err = os.RemoveAll(path); err != nil {
+ return nil, err
+ }
+ return newPersistentDB(path)
+ }
+ }
+ return &DB{lvl: db, quit: make(chan struct{})}, nil
+}
+
+// nodeKey returns the database key for a node record.
+func nodeKey(id ID) []byte {
+ key := append([]byte(dbNodePrefix), id[:]...)
+ key = append(key, ':')
+ key = append(key, dbDiscoverRoot...)
+ return key
+}
+
+// splitNodeKey returns the node ID of a key created by nodeKey.
+func splitNodeKey(key []byte) (id ID, rest []byte) {
+ if !bytes.HasPrefix(key, []byte(dbNodePrefix)) {
+ return ID{}, nil
+ }
+ item := key[len(dbNodePrefix):]
+ copy(id[:], item[:len(id)])
+ return id, item[len(id)+1:]
+}
+
+// nodeItemKey returns the database key for a node metadata field.
+func nodeItemKey(id ID, ip net.IP, field string) []byte {
+ ip16 := ip.To16()
+ if ip16 == nil {
+ panic(fmt.Errorf("invalid IP (length %d)", len(ip)))
+ }
+ return bytes.Join([][]byte{nodeKey(id), ip16, []byte(field)}, []byte{':'})
+}
+
+// splitNodeItemKey returns the components of a key created by nodeItemKey.
+func splitNodeItemKey(key []byte) (id ID, ip net.IP, field string) {
+ id, key = splitNodeKey(key)
+ // Skip discover root.
+ if string(key) == dbDiscoverRoot {
+ return id, nil, ""
+ }
+ key = key[len(dbDiscoverRoot)+1:]
+ // Split out the IP.
+ ip = key[:16]
+ if ip4 := ip.To4(); ip4 != nil {
+ ip = ip4
+ }
+ key = key[16+1:]
+ // Field is the remainder of key.
+ field = string(key)
+ return id, ip, field
+}
+
+func v5Key(id ID, ip net.IP, field string) []byte {
+ return bytes.Join([][]byte{
+ []byte(dbNodePrefix),
+ id[:],
+ []byte(dbDiscv5Root),
+ ip.To16(),
+ []byte(field),
+ }, []byte{':'})
+}
+
+// localItemKey returns the key of a local node item.
+func localItemKey(id ID, field string) []byte {
+ key := append([]byte(dbLocalPrefix), id[:]...)
+ key = append(key, ':')
+ key = append(key, field...)
+ return key
+}
+
+// fetchInt64 retrieves an integer associated with a particular key.
+func (db *DB) fetchInt64(key []byte) int64 {
+ blob, err := db.lvl.Get(key, nil)
+ if err != nil {
+ return 0
+ }
+ val, read := binary.Varint(blob)
+ if read <= 0 {
+ return 0
+ }
+ return val
+}
+
+// storeInt64 stores an integer in the given key.
+func (db *DB) storeInt64(key []byte, n int64) error {
+ blob := make([]byte, binary.MaxVarintLen64)
+ blob = blob[:binary.PutVarint(blob, n)]
+ return db.lvl.Put(key, blob, nil)
+}
+
+// fetchUint64 retrieves an integer associated with a particular key.
+func (db *DB) fetchUint64(key []byte) uint64 {
+ blob, err := db.lvl.Get(key, nil)
+ if err != nil {
+ return 0
+ }
+ val, _ := binary.Uvarint(blob)
+ return val
+}
+
+// storeUint64 stores an integer in the given key.
+func (db *DB) storeUint64(key []byte, n uint64) error {
+ blob := make([]byte, binary.MaxVarintLen64)
+ blob = blob[:binary.PutUvarint(blob, n)]
+ return db.lvl.Put(key, blob, nil)
+}
+
+// Node retrieves a node with a given id from the database.
+func (db *DB) Node(id ID) *Node {
+ blob, err := db.lvl.Get(nodeKey(id), nil)
+ if err != nil {
+ return nil
+ }
+ return mustDecodeNode(id[:], blob)
+}
+
+func mustDecodeNode(id, data []byte) *Node {
+ node := new(Node)
+ if err := rlp.DecodeBytes(data, &node.r); err != nil {
+ panic(fmt.Errorf("p2p/enode: can't decode node %x in DB: %v", id, err))
+ }
+ // Restore node id cache.
+ copy(node.id[:], id)
+ return node
+}
+
+// UpdateNode inserts - potentially overwriting - a node into the peer database.
+func (db *DB) UpdateNode(node *Node) error {
+ if node.Seq() < db.NodeSeq(node.ID()) {
+ return nil
+ }
+ blob, err := rlp.EncodeToBytes(&node.r)
+ if err != nil {
+ return err
+ }
+ if err := db.lvl.Put(nodeKey(node.ID()), blob, nil); err != nil {
+ return err
+ }
+ return db.storeUint64(nodeItemKey(node.ID(), zeroIP, dbNodeSeq), node.Seq())
+}
+
+// NodeSeq returns the stored record sequence number of the given node.
+func (db *DB) NodeSeq(id ID) uint64 {
+ return db.fetchUint64(nodeItemKey(id, zeroIP, dbNodeSeq))
+}
+
+// Resolve returns the stored record of the node if it has a larger sequence
+// number than n.
+func (db *DB) Resolve(n *Node) *Node {
+ if n.Seq() > db.NodeSeq(n.ID()) {
+ return n
+ }
+ return db.Node(n.ID())
+}
+
+// DeleteNode deletes all information associated with a node.
+func (db *DB) DeleteNode(id ID) {
+ deleteRange(db.lvl, nodeKey(id))
+}
+
+func deleteRange(db *leveldb.DB, prefix []byte) {
+ it := db.NewIterator(util.BytesPrefix(prefix), nil)
+ defer it.Release()
+ for it.Next() {
+ db.Delete(it.Key(), nil)
+ }
+}
+
+// ensureExpirer is a small helper method ensuring that the data expiration
+// mechanism is running. If the expiration goroutine is already running, this
+// method simply returns.
+//
+// The goal is to start the data evacuation only after the network successfully
+// bootstrapped itself (to prevent dumping potentially useful seed nodes). Since
+// it would require significant overhead to exactly trace the first successful
+// convergence, it's simpler to "ensure" the correct state when an appropriate
+// condition occurs (i.e. a successful bonding), and discard further events.
+func (db *DB) ensureExpirer() {
+ db.runner.Do(func() { go db.expirer() })
+}
+
+// expirer should be started in a goroutine, and is responsible for looping ad
+// infinitum and dropping stale data from the database.
+func (db *DB) expirer() {
+ tick := time.NewTicker(dbCleanupCycle)
+ defer tick.Stop()
+ for {
+ select {
+ case <-tick.C:
+ db.expireNodes()
+ case <-db.quit:
+ return
+ }
+ }
+}
+
+// expireNodes iterates over the database and deletes all nodes that have not
+// been seen (i.e. received a pong from) for some time.
+func (db *DB) expireNodes() {
+ it := db.lvl.NewIterator(util.BytesPrefix([]byte(dbNodePrefix)), nil)
+ defer it.Release()
+ if !it.Next() {
+ return
+ }
+
+ var (
+ threshold = time.Now().Add(-dbNodeExpiration).Unix()
+ youngestPong int64
+ atEnd = false
+ )
+ for !atEnd {
+ id, ip, field := splitNodeItemKey(it.Key())
+ if field == dbNodePong {
+ time, _ := binary.Varint(it.Value())
+ if time > youngestPong {
+ youngestPong = time
+ }
+ if time < threshold {
+ // Last pong from this IP older than threshold, remove fields belonging to it.
+ deleteRange(db.lvl, nodeItemKey(id, ip, ""))
+ }
+ }
+ atEnd = !it.Next()
+ nextID, _ := splitNodeKey(it.Key())
+ if atEnd || nextID != id {
+ // We've moved beyond the last entry of the current ID.
+ // Remove everything if there was no recent enough pong.
+ if youngestPong > 0 && youngestPong < threshold {
+ deleteRange(db.lvl, nodeKey(id))
+ }
+ youngestPong = 0
+ }
+ }
+}
+
+// LastPingReceived retrieves the time of the last ping packet received from
+// a remote node.
+func (db *DB) LastPingReceived(id ID, ip net.IP) time.Time {
+ if ip = ip.To16(); ip == nil {
+ return time.Time{}
+ }
+ return time.Unix(db.fetchInt64(nodeItemKey(id, ip, dbNodePing)), 0)
+}
+
+// UpdateLastPingReceived updates the last time we tried contacting a remote node.
+func (db *DB) UpdateLastPingReceived(id ID, ip net.IP, instance time.Time) error {
+ if ip = ip.To16(); ip == nil {
+ return errInvalidIP
+ }
+ return db.storeInt64(nodeItemKey(id, ip, dbNodePing), instance.Unix())
+}
+
+// LastPongReceived retrieves the time of the last successful pong from remote node.
+func (db *DB) LastPongReceived(id ID, ip net.IP) time.Time {
+ if ip = ip.To16(); ip == nil {
+ return time.Time{}
+ }
+ // Launch expirer
+ db.ensureExpirer()
+ return time.Unix(db.fetchInt64(nodeItemKey(id, ip, dbNodePong)), 0)
+}
+
+// UpdateLastPongReceived updates the last pong time of a node.
+func (db *DB) UpdateLastPongReceived(id ID, ip net.IP, instance time.Time) error {
+ if ip = ip.To16(); ip == nil {
+ return errInvalidIP
+ }
+ return db.storeInt64(nodeItemKey(id, ip, dbNodePong), instance.Unix())
+}
+
+// FindFails retrieves the number of findnode failures since bonding.
+func (db *DB) FindFails(id ID, ip net.IP) int {
+ if ip = ip.To16(); ip == nil {
+ return 0
+ }
+ return int(db.fetchInt64(nodeItemKey(id, ip, dbNodeFindFails)))
+}
+
+// UpdateFindFails updates the number of findnode failures since bonding.
+func (db *DB) UpdateFindFails(id ID, ip net.IP, fails int) error {
+ if ip = ip.To16(); ip == nil {
+ return errInvalidIP
+ }
+ return db.storeInt64(nodeItemKey(id, ip, dbNodeFindFails), int64(fails))
+}
+
+// FindFailsV5 retrieves the discv5 findnode failure counter.
+func (db *DB) FindFailsV5(id ID, ip net.IP) int {
+ if ip = ip.To16(); ip == nil {
+ return 0
+ }
+ return int(db.fetchInt64(v5Key(id, ip, dbNodeFindFails)))
+}
+
+// UpdateFindFailsV5 stores the discv5 findnode failure counter.
+func (db *DB) UpdateFindFailsV5(id ID, ip net.IP, fails int) error {
+ if ip = ip.To16(); ip == nil {
+ return errInvalidIP
+ }
+ return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails))
+}
+
+// localSeq retrieves the local record sequence counter, defaulting to the current
+// timestamp if no previous exists. This ensures that wiping all data associated
+// with a node (apart from its key) will not generate already used sequence numbers.
+func (db *DB) localSeq(id ID) uint64 {
+ if seq := db.fetchUint64(localItemKey(id, dbLocalSeq)); seq > 0 {
+ return seq
+ }
+ return nowMilliseconds()
+}
+
+// storeLocalSeq stores the local record sequence counter.
+func (db *DB) storeLocalSeq(id ID, n uint64) {
+ db.storeUint64(localItemKey(id, dbLocalSeq), n)
+}
+
+// QuerySeeds retrieves random nodes to be used as potential seed nodes
+// for bootstrapping.
+func (db *DB) QuerySeeds(n int, maxAge time.Duration) []*Node {
+ var (
+ now = time.Now()
+ nodes = make([]*Node, 0, n)
+ it = db.lvl.NewIterator(nil, nil)
+ id ID
+ )
+ defer it.Release()
+
+seek:
+ for seeks := 0; len(nodes) < n && seeks < n*5; seeks++ {
+ // Seek to a random entry. The first byte is incremented by a
+ // random amount each time in order to increase the likelihood
+ // of hitting all existing nodes in very small databases.
+ ctr := id[0]
+ rand.Read(id[:])
+ id[0] = ctr + id[0]%16
+ it.Seek(nodeKey(id))
+
+ n := nextNode(it)
+ if n == nil {
+ id[0] = 0
+ continue seek // iterator exhausted
+ }
+ if now.Sub(db.LastPongReceived(n.ID(), n.IP())) > maxAge {
+ continue seek
+ }
+ for i := range nodes {
+ if nodes[i].ID() == n.ID() {
+ continue seek // duplicate
+ }
+ }
+ nodes = append(nodes, n)
+ }
+ return nodes
+}
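+
+// A minimal lifecycle sketch (illustrative, assuming n is an existing *Node):
+//
+//	db, _ := OpenDB("") // in-memory database
+//	db.UpdateNode(n)
+//	db.UpdateLastPongReceived(n.ID(), n.IP(), time.Now())
+//	seeds := db.QuerySeeds(10, 24*time.Hour) // nodes seen within the last day
+//	db.Close()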
+
+// nextNode reads the next node record from the iterator, skipping over other
+// database entries.
+func nextNode(it iterator.Iterator) *Node {
+ for end := false; !end; end = !it.Next() {
+ id, rest := splitNodeKey(it.Key())
+ if string(rest) != dbDiscoverRoot {
+ continue
+ }
+ return mustDecodeNode(id[:], it.Value())
+ }
+ return nil
+}
+
+// Close flushes and closes the database files.
+func (db *DB) Close() {
+ close(db.quit)
+ db.lvl.Close()
+}
diff --git a/p2p/enode/nodedb_test.go b/p2p/enode/nodedb_test.go
new file mode 100644
index 000000000..38764f31b
--- /dev/null
+++ b/p2p/enode/nodedb_test.go
@@ -0,0 +1,469 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enode
+
+import (
+ "bytes"
+ "fmt"
+ "net"
+ "path/filepath"
+ "reflect"
+ "testing"
+ "time"
+)
+
+var keytestID = HexID("51232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439")
+
+func TestDBNodeKey(t *testing.T) {
+ enc := nodeKey(keytestID)
+ want := []byte{
+ 'n', ':',
+ 0x51, 0x23, 0x2b, 0x8d, 0x78, 0x21, 0x61, 0x7d, // node id
+ 0x2b, 0x29, 0xb5, 0x4b, 0x81, 0xcd, 0xef, 0xb9, //
+ 0xb3, 0xe9, 0xc3, 0x7d, 0x7f, 0xd5, 0xf6, 0x32, //
+ 0x70, 0xbc, 0xc9, 0xe1, 0xa6, 0xf6, 0xa4, 0x39, //
+ ':', 'v', '4',
+ }
+ if !bytes.Equal(enc, want) {
+ t.Errorf("wrong encoded key:\ngot %q\nwant %q", enc, want)
+ }
+ id, _ := splitNodeKey(enc)
+ if id != keytestID {
+ t.Errorf("wrong ID from splitNodeKey")
+ }
+}
+
+func TestDBNodeItemKey(t *testing.T) {
+ wantIP := net.IP{127, 0, 0, 3}
+ wantField := "foobar"
+ enc := nodeItemKey(keytestID, wantIP, wantField)
+ want := []byte{
+ 'n', ':',
+ 0x51, 0x23, 0x2b, 0x8d, 0x78, 0x21, 0x61, 0x7d, // node id
+ 0x2b, 0x29, 0xb5, 0x4b, 0x81, 0xcd, 0xef, 0xb9, //
+ 0xb3, 0xe9, 0xc3, 0x7d, 0x7f, 0xd5, 0xf6, 0x32, //
+ 0x70, 0xbc, 0xc9, 0xe1, 0xa6, 0xf6, 0xa4, 0x39, //
+ ':', 'v', '4', ':',
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // IP
+ 0x00, 0x00, 0xff, 0xff, 0x7f, 0x00, 0x00, 0x03, //
+ ':', 'f', 'o', 'o', 'b', 'a', 'r',
+ }
+ if !bytes.Equal(enc, want) {
+ t.Errorf("wrong encoded key:\ngot %q\nwant %q", enc, want)
+ }
+ id, ip, field := splitNodeItemKey(enc)
+ if id != keytestID {
+ t.Errorf("splitNodeItemKey returned wrong ID: %v", id)
+ }
+ if !ip.Equal(wantIP) {
+ t.Errorf("splitNodeItemKey returned wrong IP: %v", ip)
+ }
+ if field != wantField {
+ t.Errorf("splitNodeItemKey returned wrong field: %q", field)
+ }
+}
+
+var nodeDBInt64Tests = []struct {
+ key []byte
+ value int64
+}{
+ {key: []byte{0x01}, value: 1},
+ {key: []byte{0x02}, value: 2},
+ {key: []byte{0x03}, value: 3},
+}
+
+func TestDBInt64(t *testing.T) {
+ db, _ := OpenDB("")
+ defer db.Close()
+
+ tests := nodeDBInt64Tests
+ for i := 0; i < len(tests); i++ {
+ // Insert the next value
+ if err := db.storeInt64(tests[i].key, tests[i].value); err != nil {
+ t.Errorf("test %d: failed to store value: %v", i, err)
+ }
+ // Check all existing and non existing values
+ for j := 0; j < len(tests); j++ {
+ num := db.fetchInt64(tests[j].key)
+ switch {
+ case j <= i && num != tests[j].value:
+ t.Errorf("test %d, item %d: value mismatch: have %v, want %v", i, j, num, tests[j].value)
+ case j > i && num != 0:
+ t.Errorf("test %d, item %d: value mismatch: have %v, want %v", i, j, num, 0)
+ }
+ }
+ }
+}
+
+func TestDBFetchStore(t *testing.T) {
+ node := NewV4(
+ hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ net.IP{192, 168, 0, 1},
+ 30303,
+ 30303,
+ )
+ inst := time.Now()
+ num := 314
+
+ db, _ := OpenDB("")
+ defer db.Close()
+
+ // Check fetch/store operations on a node ping object
+ if stored := db.LastPingReceived(node.ID(), node.IP()); stored.Unix() != 0 {
+ t.Errorf("ping: non-existing object: %v", stored)
+ }
+ if err := db.UpdateLastPingReceived(node.ID(), node.IP(), inst); err != nil {
+ t.Errorf("ping: failed to update: %v", err)
+ }
+ if stored := db.LastPingReceived(node.ID(), node.IP()); stored.Unix() != inst.Unix() {
+ t.Errorf("ping: value mismatch: have %v, want %v", stored, inst)
+ }
+ // Check fetch/store operations on a node pong object
+ if stored := db.LastPongReceived(node.ID(), node.IP()); stored.Unix() != 0 {
+ t.Errorf("pong: non-existing object: %v", stored)
+ }
+ if err := db.UpdateLastPongReceived(node.ID(), node.IP(), inst); err != nil {
+ t.Errorf("pong: failed to update: %v", err)
+ }
+ if stored := db.LastPongReceived(node.ID(), node.IP()); stored.Unix() != inst.Unix() {
+ t.Errorf("pong: value mismatch: have %v, want %v", stored, inst)
+ }
+ // Check fetch/store operations on a node findnode-failure object
+ if stored := db.FindFails(node.ID(), node.IP()); stored != 0 {
+ t.Errorf("find-node fails: non-existing object: %v", stored)
+ }
+ if err := db.UpdateFindFails(node.ID(), node.IP(), num); err != nil {
+ t.Errorf("find-node fails: failed to update: %v", err)
+ }
+ if stored := db.FindFails(node.ID(), node.IP()); stored != num {
+ t.Errorf("find-node fails: value mismatch: have %v, want %v", stored, num)
+ }
+ // Check fetch/store operations on an actual node object
+ if stored := db.Node(node.ID()); stored != nil {
+ t.Errorf("node: non-existing object: %v", stored)
+ }
+ if err := db.UpdateNode(node); err != nil {
+ t.Errorf("node: failed to update: %v", err)
+ }
+ if stored := db.Node(node.ID()); stored == nil {
+ t.Errorf("node: not found")
+ } else if !reflect.DeepEqual(stored, node) {
+ t.Errorf("node: data mismatch: have %v, want %v", stored, node)
+ }
+}
+
+var nodeDBSeedQueryNodes = []struct {
+ node *Node
+ pong time.Time
+}{
+ // This one should not be in the result set because its last
+ // pong time is too far in the past.
+ {
+ node: NewV4(
+ hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ net.IP{127, 0, 0, 3},
+ 30303,
+ 30303,
+ ),
+ pong: time.Now().Add(-3 * time.Hour),
+ },
+ // This one shouldn't be in the result set because its
+ // nodeID is the local node's ID.
+ {
+ node: NewV4(
+ hexPubkey("ff93ff820abacd4351b0f14e47b324bc82ff014c226f3f66a53535734a3c150e7e38ca03ef0964ba55acddc768f5e99cd59dea95ddd4defbab1339c92fa319b2"),
+ net.IP{127, 0, 0, 3},
+ 30303,
+ 30303,
+ ),
+ pong: time.Now().Add(-4 * time.Second),
+ },
+
+ // These should be in the result set.
+ {
+ node: NewV4(
+ hexPubkey("c2b5eb3f5dde05f815b63777809ee3e7e0cbb20035a6b00ce327191e6eaa8f26a8d461c9112b7ab94698e7361fa19fd647e603e73239002946d76085b6f928d6"),
+ net.IP{127, 0, 0, 1},
+ 30303,
+ 30303,
+ ),
+ pong: time.Now().Add(-2 * time.Second),
+ },
+ {
+ node: NewV4(
+ hexPubkey("6ca1d400c8ddf8acc94bcb0dd254911ad71a57bed5e0ae5aa205beed59b28c2339908e97990c493499613cff8ecf6c3dc7112a8ead220cdcd00d8847ca3db755"),
+ net.IP{127, 0, 0, 2},
+ 30303,
+ 30303,
+ ),
+ pong: time.Now().Add(-3 * time.Second),
+ },
+ {
+ node: NewV4(
+ hexPubkey("234dc63fe4d131212b38236c4c3411288d7bec61cbf7b120ff12c43dc60c96182882f4291d209db66f8a38e986c9c010ff59231a67f9515c7d1668b86b221a47"),
+ net.IP{127, 0, 0, 3},
+ 30303,
+ 30303,
+ ),
+ pong: time.Now().Add(-1 * time.Second),
+ },
+ {
+ node: NewV4(
+ hexPubkey("c013a50b4d1ebce5c377d8af8cb7114fd933ffc9627f96ad56d90fef5b7253ec736fd07ef9a81dc2955a997e54b7bf50afd0aa9f110595e2bec5bb7ce1657004"),
+ net.IP{127, 0, 0, 3},
+ 30303,
+ 30303,
+ ),
+ pong: time.Now().Add(-2 * time.Second),
+ },
+ {
+ node: NewV4(
+ hexPubkey("f141087e3e08af1aeec261ff75f48b5b1637f594ea9ad670e50051646b0416daa3b134c28788cbe98af26992a47652889cd8577ccc108ac02c6a664db2dc1283"),
+ net.IP{127, 0, 0, 3},
+ 30303,
+ 30303,
+ ),
+ pong: time.Now().Add(-2 * time.Second),
+ },
+}
+
+func TestDBSeedQuery(t *testing.T) {
+ // Querying seeds uses seeks and might not find all nodes
+ // every time when the database is small. Run the test multiple
+ // times to avoid flakes.
+ const attempts = 15
+ var err error
+ for i := 0; i < attempts; i++ {
+ if err = testSeedQuery(); err == nil {
+ return
+ }
+ }
+ if err != nil {
+ t.Errorf("no successful run in %d attempts: %v", attempts, err)
+ }
+}
+
+func testSeedQuery() error {
+ db, _ := OpenDB("")
+ defer db.Close()
+
+ // Insert a batch of nodes for querying
+ for i, seed := range nodeDBSeedQueryNodes {
+ if err := db.UpdateNode(seed.node); err != nil {
+ return fmt.Errorf("node %d: failed to insert: %v", i, err)
+ }
+ if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IP(), seed.pong); err != nil {
+ return fmt.Errorf("node %d: failed to insert bondTime: %v", i, err)
+ }
+ }
+
+ // Retrieve the entire batch and check for duplicates
+ seeds := db.QuerySeeds(len(nodeDBSeedQueryNodes)*2, time.Hour)
+ have := make(map[ID]struct{}, len(seeds))
+ for _, seed := range seeds {
+ have[seed.ID()] = struct{}{}
+ }
+ want := make(map[ID]struct{}, len(nodeDBSeedQueryNodes[1:]))
+ for _, seed := range nodeDBSeedQueryNodes[1:] {
+ want[seed.node.ID()] = struct{}{}
+ }
+ if len(seeds) != len(want) {
+ return fmt.Errorf("seed count mismatch: have %v, want %v", len(seeds), len(want))
+ }
+ for id := range have {
+ if _, ok := want[id]; !ok {
+ return fmt.Errorf("extra seed: %v", id)
+ }
+ }
+ for id := range want {
+ if _, ok := have[id]; !ok {
+ return fmt.Errorf("missing seed: %v", id)
+ }
+ }
+ return nil
+}
+
+func TestDBPersistency(t *testing.T) {
+ root := t.TempDir()
+
+ var (
+ testKey = []byte("somekey")
+ testInt = int64(314)
+ )
+
+ // Create a persistent database and store some values
+ db, err := OpenDB(filepath.Join(root, "database"))
+ if err != nil {
+ t.Fatalf("failed to create persistent database: %v", err)
+ }
+ if err := db.storeInt64(testKey, testInt); err != nil {
+ t.Fatalf("failed to store value: %v.", err)
+ }
+ db.Close()
+
+ // Reopen the database and check the value
+ db, err = OpenDB(filepath.Join(root, "database"))
+ if err != nil {
+ t.Fatalf("failed to open persistent database: %v", err)
+ }
+ if val := db.fetchInt64(testKey); val != testInt {
+ t.Fatalf("value mismatch: have %v, want %v", val, testInt)
+ }
+ db.Close()
+}
+
+var nodeDBExpirationNodes = []struct {
+ node *Node
+ pong time.Time
+ storeNode bool
+ exp bool
+}{
+ // Node has new enough pong time and isn't expired:
+ {
+ node: NewV4(
+ hexPubkey("8d110e2ed4b446d9b5fb50f117e5f37fb7597af455e1dab0e6f045a6eeaa786a6781141659020d38bdc5e698ed3d4d2bafa8b5061810dfa63e8ac038db2e9b67"),
+ net.IP{127, 0, 0, 1},
+ 30303,
+ 30303,
+ ),
+ storeNode: true,
+ pong: time.Now().Add(-dbNodeExpiration + time.Minute),
+ exp: false,
+ },
+ // Node with pong time before expiration is removed:
+ {
+ node: NewV4(
+ hexPubkey("913a205579c32425b220dfba999d215066e5bdbf900226b11da1907eae5e93eb40616d47412cf819664e9eacbdfcca6b0c6e07e09847a38472d4be46ab0c3672"),
+ net.IP{127, 0, 0, 2},
+ 30303,
+ 30303,
+ ),
+ storeNode: true,
+ pong: time.Now().Add(-dbNodeExpiration - time.Minute),
+ exp: true,
+ },
+ // Just pong time, no node stored:
+ {
+ node: NewV4(
+ hexPubkey("b56670e0b6bad2c5dab9f9fe6f061a16cf78d68b6ae2cfda3144262d08d97ce5f46fd8799b6d1f709b1abe718f2863e224488bd7518e5e3b43809ac9bd1138ca"),
+ net.IP{127, 0, 0, 3},
+ 30303,
+ 30303,
+ ),
+ storeNode: false,
+ pong: time.Now().Add(-dbNodeExpiration - time.Minute),
+ exp: true,
+ },
+ // Node with multiple pong times, all older than expiration.
+ {
+ node: NewV4(
+ hexPubkey("29f619cebfd32c9eab34aec797ed5e3fe15b9b45be95b4df3f5fe6a9ae892f433eb08d7698b2ef3621568b0fb70d57b515ab30d4e72583b798298e0f0a66b9d1"),
+ net.IP{127, 0, 0, 4},
+ 30303,
+ 30303,
+ ),
+ storeNode: true,
+ pong: time.Now().Add(-dbNodeExpiration - time.Minute),
+ exp: true,
+ },
+ {
+ node: NewV4(
+ hexPubkey("29f619cebfd32c9eab34aec797ed5e3fe15b9b45be95b4df3f5fe6a9ae892f433eb08d7698b2ef3621568b0fb70d57b515ab30d4e72583b798298e0f0a66b9d1"),
+ net.IP{127, 0, 0, 5},
+ 30303,
+ 30303,
+ ),
+ storeNode: false,
+ pong: time.Now().Add(-dbNodeExpiration - 2*time.Minute),
+ exp: true,
+ },
+ // Node with multiple pong times, one newer, one older than expiration.
+ {
+ node: NewV4(
+ hexPubkey("3b73a9e5f4af6c4701c57c73cc8cfa0f4802840b24c11eba92aac3aef65644a3728b4b2aec8199f6d72bd66be2c65861c773129039bd47daa091ca90a6d4c857"),
+ net.IP{127, 0, 0, 6},
+ 30303,
+ 30303,
+ ),
+ storeNode: true,
+ pong: time.Now().Add(-dbNodeExpiration + time.Minute),
+ exp: false,
+ },
+ {
+ node: NewV4(
+ hexPubkey("3b73a9e5f4af6c4701c57c73cc8cfa0f4802840b24c11eba92aac3aef65644a3728b4b2aec8199f6d72bd66be2c65861c773129039bd47daa091ca90a6d4c857"),
+ net.IP{127, 0, 0, 7},
+ 30303,
+ 30303,
+ ),
+ storeNode: false,
+ pong: time.Now().Add(-dbNodeExpiration - time.Minute),
+ exp: true,
+ },
+}
+
+func TestDBExpiration(t *testing.T) {
+ db, _ := OpenDB("")
+ defer db.Close()
+
+ // Add all the test nodes and set their last pong time.
+ for i, seed := range nodeDBExpirationNodes {
+ if seed.storeNode {
+ if err := db.UpdateNode(seed.node); err != nil {
+ t.Fatalf("node %d: failed to insert: %v", i, err)
+ }
+ }
+ if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IP(), seed.pong); err != nil {
+ t.Fatalf("node %d: failed to update bondTime: %v", i, err)
+ }
+ }
+
+ db.expireNodes()
+
+ // Check that expired entries have been removed.
+ unixZeroTime := time.Unix(0, 0)
+ for i, seed := range nodeDBExpirationNodes {
+ node := db.Node(seed.node.ID())
+ pong := db.LastPongReceived(seed.node.ID(), seed.node.IP())
+ if seed.exp {
+ if seed.storeNode && node != nil {
+ t.Errorf("node %d (%s) shouldn't be present after expiration", i, seed.node.ID().TerminalString())
+ }
+ if !pong.Equal(unixZeroTime) {
+ t.Errorf("pong time %d (%s %v) shouldn't be present after expiration", i, seed.node.ID().TerminalString(), seed.node.IP())
+ }
+ } else {
+ if seed.storeNode && node == nil {
+ t.Errorf("node %d (%s) should be present after expiration", i, seed.node.ID().TerminalString())
+ }
+ if !pong.Equal(seed.pong.Truncate(1 * time.Second)) {
+ t.Errorf("pong time %d (%s) should be %v after expiration, but is %v", i, seed.node.ID().TerminalString(), seed.pong, pong)
+ }
+ }
+ }
+}
+
+// This test checks that expiration works when discovery v5 data is present
+// in the database.
+func TestDBExpireV5(t *testing.T) {
+ db, _ := OpenDB("")
+ defer db.Close()
+
+ ip := net.IP{127, 0, 0, 1}
+ db.UpdateFindFailsV5(ID{}, ip, 4)
+ db.expireNodes()
+}
diff --git a/p2p/enode/urlv4.go b/p2p/enode/urlv4.go
new file mode 100644
index 000000000..204839dcc
--- /dev/null
+++ b/p2p/enode/urlv4.go
@@ -0,0 +1,203 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enode
+
+import (
+ "crypto/ecdsa"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "net"
+ "net/url"
+ "regexp"
+ "strconv"
+
+ "github.com/tomochain/tomochain/common/math"
+ "github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/p2p/enr"
+)
+
+var (
+ incompleteNodeURL = regexp.MustCompile("(?i)^(?:enode://)?([0-9a-f]+)$")
+ lookupIPFunc = net.LookupIP
+)
+
+// MustParseV4 parses a node URL. It panics if the URL is not valid.
+func MustParseV4(rawurl string) *Node {
+ n, err := ParseV4(rawurl)
+ if err != nil {
+ panic("invalid node URL: " + err.Error())
+ }
+ return n
+}
+
+// ParseV4 parses a node URL.
+//
+// There are two basic forms of node URLs:
+//
+// - incomplete nodes, which only have the public key (node ID)
+// - complete nodes, which contain the public key and IP/Port information
+//
+// For incomplete nodes, the designator must look like one of these
+//
+//	enode://<hex node id>
+//	<hex node id>
+//
+// For complete nodes, the node ID is encoded in the username portion
+// of the URL, separated from the host by an @ sign. The hostname can
+// only be given as an IP address or as a DNS domain name.
+// The port in the host name section is the TCP listening port. If the
+// TCP and UDP (discovery) ports differ, the UDP port is specified as
+// query parameter "discport".
+//
+// In the following example, the node URL describes
+// a node with IP address 10.3.58.6, TCP listening port 30303
+// and UDP discovery port 30301.
+//
+//	enode://<hex node id>@10.3.58.6:30303?discport=30301
+func ParseV4(rawurl string) (*Node, error) {
+ if m := incompleteNodeURL.FindStringSubmatch(rawurl); m != nil {
+ id, err := parsePubkey(m[1])
+ if err != nil {
+ return nil, fmt.Errorf("invalid public key (%v)", err)
+ }
+ return NewV4(id, nil, 0, 0), nil
+ }
+ return parseComplete(rawurl)
+}
+
+// NewV4 creates a node from discovery v4 node information. The record
+// contained in the node has a zero-length signature.
+func NewV4(pubkey *ecdsa.PublicKey, ip net.IP, tcp, udp int) *Node {
+ var r enr.Record
+ if len(ip) > 0 {
+ r.Set(enr.IP(ip))
+ }
+ if udp != 0 {
+ r.Set(enr.UDP(udp))
+ }
+ if tcp != 0 {
+ r.Set(enr.TCP(tcp))
+ }
+ signV4Compat(&r, pubkey)
+ n, err := New(v4CompatID{}, &r)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
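+
+// Illustrative sketch: constructing a complete v4 node from a generated key
+// (crypto.GenerateKey is assumed to be available to the caller):
+//
+//	key, _ := crypto.GenerateKey()
+//	n := NewV4(&key.PublicKey, net.IP{10, 3, 58, 6}, 30303, 30301)
+//	// n.String() yields an enode:// URL with "discport=30301" appended,
+//	// because the UDP and TCP ports differ.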
+
+// isNewV4 returns true for nodes created by NewV4.
+func isNewV4(n *Node) bool {
+ var k s256raw
+ return n.r.IdentityScheme() == "" && n.r.Load(&k) == nil && len(n.r.Signature()) == 0
+}
+
+func parseComplete(rawurl string) (*Node, error) {
+ var (
+ id *ecdsa.PublicKey
+ tcpPort, udpPort uint64
+ )
+ u, err := url.Parse(rawurl)
+ if err != nil {
+ return nil, err
+ }
+ if u.Scheme != "enode" {
+ return nil, errors.New("invalid URL scheme, want \"enode\"")
+ }
+ // Parse the Node ID from the user portion.
+ if u.User == nil {
+ return nil, errors.New("does not contain node ID")
+ }
+ if id, err = parsePubkey(u.User.String()); err != nil {
+ return nil, fmt.Errorf("invalid public key (%v)", err)
+ }
+ // Parse the IP address.
+ ip := net.ParseIP(u.Hostname())
+ if ip == nil {
+ ips, err := lookupIPFunc(u.Hostname())
+ if err != nil {
+ return nil, err
+ }
+ ip = ips[0]
+ }
+ // Ensure the IP is 4 bytes long for IPv4 addresses.
+ if ipv4 := ip.To4(); ipv4 != nil {
+ ip = ipv4
+ }
+ // Parse the port numbers.
+ if tcpPort, err = strconv.ParseUint(u.Port(), 10, 16); err != nil {
+ return nil, errors.New("invalid port")
+ }
+ udpPort = tcpPort
+ qv := u.Query()
+ if qv.Get("discport") != "" {
+ udpPort, err = strconv.ParseUint(qv.Get("discport"), 10, 16)
+ if err != nil {
+ return nil, errors.New("invalid discport in query")
+ }
+ }
+ return NewV4(id, ip, int(tcpPort), int(udpPort)), nil
+}
+
+// parsePubkey parses a hex-encoded secp256k1 public key.
+func parsePubkey(in string) (*ecdsa.PublicKey, error) {
+ b, err := hex.DecodeString(in)
+ if err != nil {
+ return nil, err
+ } else if len(b) != 64 {
+ return nil, fmt.Errorf("wrong length, want %d hex chars", 128)
+ }
+ b = append([]byte{0x4}, b...)
+ return crypto.UnmarshalPubkey(b)
+}
+
+func (n *Node) URLv4() string {
+ var (
+ scheme enr.ID
+ nodeid string
+ key ecdsa.PublicKey
+ )
+ n.Load(&scheme)
+ n.Load((*Secp256k1)(&key))
+ switch {
+ case scheme == "v4" || key != ecdsa.PublicKey{}:
+ nodeid = fmt.Sprintf("%x", crypto.FromECDSAPub(&key)[1:])
+ default:
+ nodeid = fmt.Sprintf("%s.%x", scheme, n.id[:])
+ }
+ u := url.URL{Scheme: "enode"}
+ if n.Incomplete() {
+ u.Host = nodeid
+ } else {
+ addr := net.TCPAddr{IP: n.IP(), Port: n.TCP()}
+ u.User = url.User(nodeid)
+ u.Host = addr.String()
+ if n.UDP() != n.TCP() {
+ u.RawQuery = "discport=" + strconv.Itoa(n.UDP())
+ }
+ }
+ return u.String()
+}
+
+// PubkeyToIDV4 derives the v4 node address from the given public key.
+func PubkeyToIDV4(key *ecdsa.PublicKey) ID {
+ e := make([]byte, 64)
+ math.ReadBits(key.X, e[:len(e)/2])
+ math.ReadBits(key.Y, e[len(e)/2:])
+ return ID(crypto.Keccak256Hash(e))
+}
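+
+// In other words, the v4 node ID is keccak256 over the 64-byte uncompressed
+// public key (X || Y). A sketch of the equivalence, assuming key is a
+// *ecdsa.PublicKey:
+//
+//	id := PubkeyToIDV4(key)
+//	// equivalently: ID(crypto.Keccak256Hash(crypto.FromECDSAPub(key)[1:]))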
diff --git a/p2p/enode/urlv4_test.go b/p2p/enode/urlv4_test.go
new file mode 100644
index 000000000..f56d28632
--- /dev/null
+++ b/p2p/enode/urlv4_test.go
@@ -0,0 +1,200 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package enode
+
+import (
+ "crypto/ecdsa"
+ "errors"
+ "net"
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/p2p/enr"
+)
+
+func init() {
+ lookupIPFunc = func(name string) ([]net.IP, error) {
+ if name == "node.example.org" {
+ return []net.IP{{33, 44, 55, 66}}, nil
+ }
+ return nil, errors.New("no such host")
+ }
+}
+
+var parseNodeTests = []struct {
+ input string
+ wantError string
+ wantResult *Node
+}{
+ // Records
+ {
+ input: "enr:-IS4QGrdq0ugARp5T2BZ41TrZOqLc_oKvZoPuZP5--anqWE_J-Tucc1xgkOL7qXl0puJgT7qc2KSvcupc4NCb0nr4tdjgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQM6UUF2Rm-oFe1IH_rQkRCi00T2ybeMHRSvw1HDpRvjPYN1ZHCCdl8",
+ wantResult: func() *Node {
+ testKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8")
+ var r enr.Record
+ r.Set(enr.IP{127, 0, 0, 1})
+ r.Set(enr.UDP(30303))
+ r.SetSeq(99)
+ SignV4(&r, testKey)
+ n, _ := New(ValidSchemes, &r)
+ return n
+ }(),
+ },
+ // Invalid Records
+ {
+ input: "enr:",
+ wantError: "EOF", // could be nicer
+ },
+ {
+ input: "enr:x",
+ wantError: "illegal base64 data at input byte 0",
+ },
+ {
+ input: "enr:-EmGZm9vYmFyY4JpZIJ2NIJpcIR_AAABiXNlY3AyNTZrMaEDOlFBdkZvqBXtSB_60JEQotNE9sm3jB0Ur8NRw6Ub4z2DdWRwgnZf",
+ wantError: enr.ErrInvalidSig.Error(),
+ },
+ // Complete node URLs with IP address and ports
+ {
+ input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@invalid.:3",
+ wantError: `no such host`,
+ },
+ {
+ input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo",
+ wantError: `invalid port`,
+ },
+ {
+ input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo",
+ wantError: `invalid discport in query`,
+ },
+ {
+ input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150",
+ wantResult: NewV4(
+ hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ net.IP{127, 0, 0, 1},
+ 52150,
+ 52150,
+ ),
+ },
+ {
+ input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150",
+ wantResult: NewV4(
+ hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ net.ParseIP("::"),
+ 52150,
+ 52150,
+ ),
+ },
+ {
+ input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[2001:db8:3c4d:15::abcd:ef12]:52150",
+ wantResult: NewV4(
+ hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ net.ParseIP("2001:db8:3c4d:15::abcd:ef12"),
+ 52150,
+ 52150,
+ ),
+ },
+ {
+ input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=22334",
+ wantResult: NewV4(
+ hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ net.IP{0x7f, 0x0, 0x0, 0x1},
+ 52150,
+ 22334,
+ ),
+ },
+ // Incomplete node URLs with no address
+ {
+ input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439",
+ wantResult: NewV4(
+ hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ nil, 0, 0,
+ ),
+ },
+ // Invalid URLs
+ {
+ input: "",
+ wantError: errMissingPrefix.Error(),
+ },
+ {
+ input: "1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439",
+ wantError: errMissingPrefix.Error(),
+ },
+ {
+ input: "01010101",
+ wantError: errMissingPrefix.Error(),
+ },
+ {
+ input: "enode://01010101@123.124.125.126:3",
+ wantError: `invalid public key (wrong length, want 128 hex chars)`,
+ },
+ {
+ input: "enode://01010101",
+ wantError: `invalid public key (wrong length, want 128 hex chars)`,
+ },
+ {
+ input: "http://foobar",
+ wantError: errMissingPrefix.Error(),
+ },
+ {
+ input: "://foo",
+ wantError: errMissingPrefix.Error(),
+ },
+}
+
+func hexPubkey(h string) *ecdsa.PublicKey {
+ k, err := parsePubkey(h)
+ if err != nil {
+ panic(err)
+ }
+ return k
+}
+
+func TestParseNode(t *testing.T) {
+ for _, test := range parseNodeTests {
+ n, err := Parse(ValidSchemes, test.input)
+ if test.wantError != "" {
+ if err == nil {
+ t.Errorf("test %q:\n got nil error, expected %#q", test.input, test.wantError)
+ continue
+ } else if !strings.Contains(err.Error(), test.wantError) {
+ t.Errorf("test %q:\n got error %#q, expected %#q", test.input, err.Error(), test.wantError)
+ continue
+ }
+ } else {
+ if err != nil {
+ t.Errorf("test %q:\n unexpected error: %v", test.input, err)
+ continue
+ }
+ if !reflect.DeepEqual(n, test.wantResult) {
+ t.Errorf("test %q:\n result mismatch:\ngot: %#v\nwant: %#v", test.input, n, test.wantResult)
+ }
+ }
+ }
+}
+
+func TestNodeString(t *testing.T) {
+ for i, test := range parseNodeTests {
+ if test.wantError == "" && strings.HasPrefix(test.input, "enode://") {
+ str := test.wantResult.String()
+ if str != test.input {
+ t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.input)
+ }
+ }
+ }
+}
diff --git a/p2p/enr/enr.go b/p2p/enr/enr.go
index fabd08ae2..81cdffdf1 100644
--- a/p2p/enr/enr.go
+++ b/p2p/enr/enr.go
@@ -29,33 +29,53 @@ package enr
import (
"bytes"
- "crypto/ecdsa"
"errors"
"fmt"
"io"
"sort"
- "github.com/tomochain/tomochain/crypto"
- "github.com/tomochain/tomochain/crypto/sha3"
"github.com/tomochain/tomochain/rlp"
)
const SizeLimit = 300 // maximum encoded size of a node record in bytes
-const ID_SECP256k1_KECCAK = ID("secp256k1-keccak") // the default identity scheme
-
var (
- errNoID = errors.New("unknown or unspecified identity scheme")
- errInvalidSigsize = errors.New("invalid signature size")
- errInvalidSig = errors.New("invalid signature")
+ ErrInvalidSig = errors.New("invalid signature on node record")
errNotSorted = errors.New("record key/value pairs are not sorted by key")
errDuplicateKey = errors.New("record contains duplicate key")
errIncompletePair = errors.New("record contains incomplete k/v pair")
+ errIncompleteList = errors.New("record contains less than two list elements")
errTooBig = fmt.Errorf("record bigger than %d bytes", SizeLimit)
errEncodeUnsigned = errors.New("can't encode unsigned record")
errNotFound = errors.New("no such key in record")
)
+// An IdentityScheme is capable of verifying record signatures and
+// deriving node addresses.
+type IdentityScheme interface {
+ Verify(r *Record, sig []byte) error
+ NodeAddr(r *Record) []byte
+}
+
+// SchemeMap is a registry of named identity schemes.
+type SchemeMap map[string]IdentityScheme
+
+func (m SchemeMap) Verify(r *Record, sig []byte) error {
+ s := m[r.IdentityScheme()]
+ if s == nil {
+ return ErrInvalidSig
+ }
+ return s.Verify(r, sig)
+}
+
+func (m SchemeMap) NodeAddr(r *Record) []byte {
+ s := m[r.IdentityScheme()]
+ if s == nil {
+ return nil
+ }
+ return s.NodeAddr(r)
+}
+
// Record represents a node record. The zero value is an empty record.
type Record struct {
seq uint64 // sequence number
@@ -70,9 +90,22 @@ type pair struct {
v rlp.RawValue
}
-// Signed reports whether the record has a valid signature.
-func (r *Record) Signed() bool {
- return r.signature != nil
+// Size returns the encoded size of the record.
+func (r *Record) Size() uint64 {
+ if r.raw != nil {
+ return uint64(len(r.raw))
+ }
+ return computeSize(r)
+}
+
+func computeSize(r *Record) uint64 {
+ size := uint64(rlp.IntSize(r.seq))
+ size += rlp.BytesSize(r.signature)
+ for _, p := range r.pairs {
+ size += rlp.StringSize(p.k)
+ size += uint64(len(p.v))
+ }
+ return rlp.ListSize(size)
}
// Seq returns the sequence number.
@@ -81,8 +114,8 @@ func (r *Record) Seq() uint64 {
}
// SetSeq updates the record sequence number. This invalidates any signature on the record.
-// Calling SetSeq is usually not required because signing the redord increments the
-// sequence number.
+// Calling SetSeq is usually not required because setting any key in a signed record
+// increments the sequence number.
func (r *Record) SetSeq(s uint64) {
r.signature = nil
r.raw = nil
@@ -105,66 +138,100 @@ func (r *Record) Load(e Entry) error {
return &KeyError{Key: e.ENRKey(), Err: errNotFound}
}
-// Set adds or updates the given entry in the record.
-// It panics if the value can't be encoded.
+// Set adds or updates the given entry in the record. It panics if the value can't be
+// encoded. If the record is signed, Set increments the sequence number and invalidates
+// the signature.
func (r *Record) Set(e Entry) {
- r.signature = nil
- r.raw = nil
blob, err := rlp.EncodeToBytes(e)
if err != nil {
panic(fmt.Errorf("enr: can't encode %s: %v", e.ENRKey(), err))
}
+ r.invalidate()
- i := sort.Search(len(r.pairs), func(i int) bool { return r.pairs[i].k >= e.ENRKey() })
-
- if i < len(r.pairs) && r.pairs[i].k == e.ENRKey() {
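+ // Operate on a copy of the pairs slice (copy-on-write) so that record
+ // copies sharing the old backing array are not mutated.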
+ pairs := make([]pair, len(r.pairs))
+ copy(pairs, r.pairs)
+ i := sort.Search(len(pairs), func(i int) bool { return pairs[i].k >= e.ENRKey() })
+ switch {
+ case i < len(pairs) && pairs[i].k == e.ENRKey():
// element is present at r.pairs[i]
- r.pairs[i].v = blob
- return
- } else if i < len(r.pairs) {
+ pairs[i].v = blob
+ case i < len(r.pairs):
// insert pair before i-th elem
el := pair{e.ENRKey(), blob}
- r.pairs = append(r.pairs, pair{})
- copy(r.pairs[i+1:], r.pairs[i:])
- r.pairs[i] = el
- return
+ pairs = append(pairs, pair{})
+ copy(pairs[i+1:], pairs[i:])
+ pairs[i] = el
+ default:
+ // element should be placed at the end of r.pairs
+ pairs = append(pairs, pair{e.ENRKey(), blob})
}
+ r.pairs = pairs
+}
+
+func (r *Record) invalidate() {
+ if r.signature != nil {
+ r.seq++
+ }
+ r.signature = nil
+ r.raw = nil
+}
- // element should be placed at the end of r.pairs
- r.pairs = append(r.pairs, pair{e.ENRKey(), blob})
+// Signature returns the signature of the record.
+func (r *Record) Signature() []byte {
+ if r.signature == nil {
+ return nil
+ }
+ cpy := make([]byte, len(r.signature))
+ copy(cpy, r.signature)
+ return cpy
}
// EncodeRLP implements rlp.Encoder. Encoding fails if
// the record is unsigned.
func (r Record) EncodeRLP(w io.Writer) error {
- if !r.Signed() {
+ if r.signature == nil {
return errEncodeUnsigned
}
_, err := w.Write(r.raw)
return err
}
-// DecodeRLP implements rlp.Decoder. Decoding verifies the signature.
+// DecodeRLP implements rlp.Decoder. Decoding doesn't verify the signature.
func (r *Record) DecodeRLP(s *rlp.Stream) error {
- raw, err := s.Raw()
+ dec, raw, err := decodeRecord(s)
if err != nil {
return err
}
+ *r = dec
+ r.raw = raw
+ return nil
+}
+
+func decodeRecord(s *rlp.Stream) (dec Record, raw []byte, err error) {
+ raw, err = s.Raw()
+ if err != nil {
+ return dec, raw, err
+ }
if len(raw) > SizeLimit {
- return errTooBig
+ return dec, raw, errTooBig
}
// Decode the RLP container.
- dec := Record{raw: raw}
s = rlp.NewStream(bytes.NewReader(raw), 0)
if _, err := s.List(); err != nil {
- return err
+ return dec, raw, err
}
if err = s.Decode(&dec.signature); err != nil {
- return err
+ if err == rlp.EOL {
+ err = errIncompleteList
+ }
+ return dec, raw, err
}
if err = s.Decode(&dec.seq); err != nil {
- return err
+ if err == rlp.EOL {
+ err = errIncompleteList
+ }
+ return dec, raw, err
}
// The rest of the record contains sorted k/v pairs.
var prevkey string
@@ -174,62 +241,73 @@ func (r *Record) DecodeRLP(s *rlp.Stream) error {
if err == rlp.EOL {
break
}
- return err
+ return dec, raw, err
}
if err := s.Decode(&kv.v); err != nil {
if err == rlp.EOL {
- return errIncompletePair
+ return dec, raw, errIncompletePair
}
- return err
+ return dec, raw, err
}
if i > 0 {
if kv.k == prevkey {
- return errDuplicateKey
+ return dec, raw, errDuplicateKey
}
if kv.k < prevkey {
- return errNotSorted
+ return dec, raw, errNotSorted
}
}
dec.pairs = append(dec.pairs, kv)
prevkey = kv.k
}
- if err := s.ListEnd(); err != nil {
- return err
- }
-
- // Verify signature.
- if err = dec.verifySignature(); err != nil {
- return err
- }
- *r = dec
- return nil
+ return dec, raw, s.ListEnd()
}
-type s256raw []byte
-
-func (s256raw) ENRKey() string { return "secp256k1" }
+// IdentityScheme returns the name of the identity scheme in the record.
+func (r *Record) IdentityScheme() string {
+ var id ID
+ r.Load(&id)
+ return string(id)
+}
-// NodeAddr returns the node address. The return value will be nil if the record is
-// unsigned.
-func (r *Record) NodeAddr() []byte {
- var entry s256raw
- if r.Load(&entry) != nil {
- return nil
- }
- return crypto.Keccak256(entry)
+// VerifySignature checks whether the record is signed using the given identity scheme.
+func (r *Record) VerifySignature(s IdentityScheme) error {
+ return s.Verify(r, r.signature)
}
-// Sign signs the record with the given private key. It updates the record's identity
-// scheme, public key and increments the sequence number. Sign returns an error if the
-// encoded record is larger than the size limit.
-func (r *Record) Sign(privkey *ecdsa.PrivateKey) error {
- r.seq = r.seq + 1
- r.Set(ID_SECP256k1_KECCAK)
- r.Set(Secp256k1(privkey.PublicKey))
- return r.signAndEncode(privkey)
+// SetSig sets the record signature. It returns an error if the encoded record is larger
+// than the size limit or if the signature is invalid according to the passed scheme.
+//
+// You can also use SetSig to remove the signature explicitly by passing a nil scheme
+// and signature.
+//
+// SetSig panics when either the scheme or the signature (but not both) are nil.
+func (r *Record) SetSig(s IdentityScheme, sig []byte) error {
+ switch {
+ // Prevent storing invalid data.
+ case s == nil && sig != nil:
+ panic("enr: invalid call to SetSig with non-nil signature but nil scheme")
+ case s != nil && sig == nil:
+ panic("enr: invalid call to SetSig with nil signature but non-nil scheme")
+ // Verify if we have a scheme.
+ case s != nil:
+ if err := s.Verify(r, sig); err != nil {
+ return err
+ }
+ raw, err := r.encode(sig)
+ if err != nil {
+ return err
+ }
+ r.signature, r.raw = sig, raw
+ // Reset otherwise.
+ default:
+ r.signature, r.raw = nil, nil
+ }
+ return nil
}
-func (r *Record) appendPairs(list []interface{}) []interface{} {
+// AppendElements appends the sequence number and entries to the given slice.
+func (r *Record) AppendElements(list []interface{}) []interface{} {
list = append(list, r.seq)
for _, p := range r.pairs {
list = append(list, p.k, p.v)
@@ -237,54 +315,15 @@ func (r *Record) appendPairs(list []interface{}) []interface{} {
return list
}
-func (r *Record) signAndEncode(privkey *ecdsa.PrivateKey) error {
- // Put record elements into a flat list. Leave room for the signature.
- list := make([]interface{}, 1, len(r.pairs)*2+2)
- list = r.appendPairs(list)
-
- // Sign the tail of the list.
- h := sha3.NewKeccak256()
- rlp.Encode(h, list[1:])
- sig, err := crypto.Sign(h.Sum(nil), privkey)
- if err != nil {
- return err
- }
- sig = sig[:len(sig)-1] // remove v
-
- // Put signature in front.
- r.signature, list[0] = sig, sig
- r.raw, err = rlp.EncodeToBytes(list)
- if err != nil {
- return err
- }
- if len(r.raw) > SizeLimit {
- return errTooBig
- }
- return nil
-}
-
-func (r *Record) verifySignature() error {
- // Get identity scheme, public key, signature.
- var id ID
- var entry s256raw
- if err := r.Load(&id); err != nil {
- return err
- } else if id != ID_SECP256k1_KECCAK {
- return errNoID
+func (r *Record) encode(sig []byte) (raw []byte, err error) {
+ list := make([]interface{}, 1, 2*len(r.pairs)+2)
+ list[0] = sig
+ list = r.AppendElements(list)
+ if raw, err = rlp.EncodeToBytes(list); err != nil {
+ return nil, err
}
- if err := r.Load(&entry); err != nil {
- return err
- } else if len(entry) != 33 {
- return fmt.Errorf("invalid public key")
- }
-
- // Verify the signature.
- list := make([]interface{}, 0, len(r.pairs)*2+1)
- list = r.appendPairs(list)
- h := sha3.NewKeccak256()
- rlp.Encode(h, list)
- if !crypto.VerifySignature(entry, h.Sum(nil), r.signature) {
- return errInvalidSig
+ if len(raw) > SizeLimit {
+ return nil, errTooBig
}
- return nil
+ return raw, nil
}
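
Sketched against the interfaces above, this is roughly how a caller drives SetSig; makeSig is a hypothetical stand-in for the concrete scheme's signing primitive and is not part of this patch.

	func signRecord(s enr.IdentityScheme, r *enr.Record) error {
		// The signed content is the RLP encoding of [seq, k1, v1, k2, v2, ...],
		// i.e. the element list without the leading signature slot.
		content, err := rlp.EncodeToBytes(r.AppendElements(nil))
		if err != nil {
			return err
		}
		sig := makeSig(content) // hypothetical signing step
		// SetSig runs s.Verify on sig before storing it and re-encoding r,
		// so an inconsistent signature never ends up in the record.
		return r.SetSig(s, sig)
	}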
diff --git a/p2p/enr/enr_test.go b/p2p/enr/enr_test.go
index bba1738bc..8ea78fd9f 100644
--- a/p2p/enr/enr_test.go
+++ b/p2p/enr/enr_test.go
@@ -18,7 +18,7 @@ package enr
import (
"bytes"
- "encoding/hex"
+ "encoding/binary"
"fmt"
"math/rand"
"testing"
@@ -26,13 +26,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/tomochain/tomochain/crypto"
- "github.com/tomochain/tomochain/rlp"
-)
-var (
- privkey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
- pubkey = &privkey.PublicKey
+ "github.com/tomochain/tomochain/rlp"
)
var rnd = rand.New(rand.NewSource(time.Now().UnixNano()))
@@ -54,63 +49,51 @@ func TestGetSetID(t *testing.T) {
assert.Equal(t, id, id2)
}
-// TestGetSetIP4 tests encoding/decoding and setting/getting of the IP4 key.
-func TestGetSetIP4(t *testing.T) {
- ip := IP4{192, 168, 0, 3}
+// TestGetSetIPv4 tests encoding/decoding and setting/getting of the IPv4 key.
+func TestGetSetIPv4(t *testing.T) {
+ ip := IPv4{192, 168, 0, 3}
var r Record
r.Set(ip)
- var ip2 IP4
+ var ip2 IPv4
require.NoError(t, r.Load(&ip2))
assert.Equal(t, ip, ip2)
}
-// TestGetSetIP6 tests encoding/decoding and setting/getting of the IP6 key.
-func TestGetSetIP6(t *testing.T) {
- ip := IP6{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}
+// TestGetSetIPv6 tests encoding/decoding and setting/getting of the IPv6 key.
+func TestGetSetIPv6(t *testing.T) {
+ ip := IPv6{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}
var r Record
r.Set(ip)
- var ip2 IP6
+ var ip2 IPv6
require.NoError(t, r.Load(&ip2))
assert.Equal(t, ip, ip2)
}
-// TestGetSetDiscPort tests encoding/decoding and setting/getting of the DiscPort key.
-func TestGetSetDiscPort(t *testing.T) {
- port := DiscPort(30309)
+// TestGetSetUDP tests encoding/decoding and setting/getting of the UDP key.
+func TestGetSetUDP(t *testing.T) {
+ port := UDP(30309)
var r Record
r.Set(port)
- var port2 DiscPort
+ var port2 UDP
require.NoError(t, r.Load(&port2))
assert.Equal(t, port, port2)
}
-// TestGetSetSecp256k1 tests encoding/decoding and setting/getting of the Secp256k1 key.
-func TestGetSetSecp256k1(t *testing.T) {
- var r Record
- if err := r.Sign(privkey); err != nil {
- t.Fatal(err)
- }
-
- var pk Secp256k1
- require.NoError(t, r.Load(&pk))
- assert.EqualValues(t, pubkey, &pk)
-}
-
func TestLoadErrors(t *testing.T) {
var r Record
- ip4 := IP4{127, 0, 0, 1}
+ ip4 := IPv4{127, 0, 0, 1}
r.Set(ip4)
// Check error for missing keys.
- var ip6 IP6
- err := r.Load(&ip6)
+ var udp UDP
+ err := r.Load(&udp)
if !IsNotFound(err) {
t.Error("IsNotFound should return true for missing key")
}
- assert.Equal(t, &KeyError{Key: ip6.ENRKey(), Err: errNotFound}, err)
+ assert.Equal(t, &KeyError{Key: udp.ENRKey(), Err: errNotFound}, err)
// Check error for invalid keys.
var list []uint
@@ -167,40 +150,75 @@ func TestSortedGetAndSet(t *testing.T) {
func TestDirty(t *testing.T) {
var r Record
- if r.Signed() {
- t.Error("Signed returned true for zero record")
- }
if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
t.Errorf("expected errEncodeUnsigned, got %#v", err)
}
- require.NoError(t, r.Sign(privkey))
- if !r.Signed() {
- t.Error("Signed return false for signed record")
+ require.NoError(t, signTest([]byte{5}, &r))
+ if len(r.signature) == 0 {
+ t.Error("record is not signed")
}
_, err := rlp.EncodeToBytes(r)
assert.NoError(t, err)
r.SetSeq(3)
- if r.Signed() {
- t.Error("Signed returned true for modified record")
+ if len(r.signature) != 0 {
+ t.Error("signature still set after modification")
}
if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
t.Errorf("expected errEncodeUnsigned, got %#v", err)
}
}
+func TestSize(t *testing.T) {
+ var r Record
+
+ // Empty record size is 3 bytes.
+ // Unsigned records cannot be encoded, but if they could, the encoding
+ // would be [ 0, 0 ] -> 0xC28080.
+ assert.Equal(t, uint64(3), r.Size())
+
+ // Add one attribute. The size increases to 5, the encoding
+ // would be [ 0, 0, "k", "v" ] -> 0xC58080C26B76.
+ r.Set(WithEntry("k", "v"))
+ assert.Equal(t, uint64(5), r.Size())
+
+ // Now add a signature.
+ nodeid := []byte{1, 2, 3, 4, 5, 6, 7, 8}
+ signTest(nodeid, &r)
+ assert.Equal(t, uint64(45), r.Size())
+ enc, _ := rlp.EncodeToBytes(&r)
+ if r.Size() != uint64(len(enc)) {
+ t.Error("Size() not equal encoded length", len(enc))
+ }
+ if r.Size() != computeSize(&r) {
+ t.Error("Size() not equal computed size", computeSize(&r))
+ }
+}
+
+func TestSeq(t *testing.T) {
+ var r Record
+
+ assert.Equal(t, uint64(0), r.Seq())
+ r.Set(UDP(1))
+ assert.Equal(t, uint64(0), r.Seq())
+ signTest([]byte{5}, &r)
+ assert.Equal(t, uint64(0), r.Seq())
+ r.Set(UDP(2))
+ assert.Equal(t, uint64(1), r.Seq())
+}
+
// TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record.
func TestGetSetOverwrite(t *testing.T) {
var r Record
- ip := IP4{192, 168, 0, 3}
+ ip := IPv4{192, 168, 0, 3}
r.Set(ip)
- ip2 := IP4{192, 168, 0, 4}
+ ip2 := IPv4{192, 168, 0, 4}
r.Set(ip2)
- var ip3 IP4
+ var ip3 IPv4
require.NoError(t, r.Load(&ip3))
assert.Equal(t, ip2, ip3)
}
@@ -208,9 +226,9 @@ func TestGetSetOverwrite(t *testing.T) {
// TestSignEncodeAndDecode tests signing, RLP encoding and RLP decoding of a record.
func TestSignEncodeAndDecode(t *testing.T) {
var r Record
- r.Set(DiscPort(30303))
- r.Set(IP4{127, 0, 0, 1})
- require.NoError(t, r.Sign(privkey))
+ r.Set(UDP(30303))
+ r.Set(IPv4{127, 0, 0, 1})
+ require.NoError(t, signTest([]byte{5}, &r))
blob, err := rlp.EncodeToBytes(r)
require.NoError(t, err)
@@ -224,62 +242,43 @@ func TestSignEncodeAndDecode(t *testing.T) {
assert.Equal(t, blob, blob2)
}
-func TestNodeAddr(t *testing.T) {
- var r Record
- if addr := r.NodeAddr(); addr != nil {
- t.Errorf("wrong address on empty record: got %v, want %v", addr, nil)
- }
-
- require.NoError(t, r.Sign(privkey))
- expected := "caaa1485d83b18b32ed9ad666026151bf0cae8a0a88c857ae2d4c5be2daa6726"
- assert.Equal(t, expected, hex.EncodeToString(r.NodeAddr()))
-}
-
-var pyRecord, _ = hex.DecodeString("f896b840954dc36583c1f4b69ab59b1375f362f06ee99f3723cd77e64b6de6d211c27d7870642a79d4516997f94091325d2a7ca6215376971455fb221d34f35b277149a1018664697363763582765f82696490736563703235366b312d6b656363616b83697034847f00000189736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd3138")
-
-// TestPythonInterop checks that we can decode and verify a record produced by the Python
-// implementation.
-func TestPythonInterop(t *testing.T) {
- var r Record
- if err := rlp.DecodeBytes(pyRecord, &r); err != nil {
- t.Fatalf("can't decode: %v", err)
- }
-
- var (
- wantAddr, _ = hex.DecodeString("caaa1485d83b18b32ed9ad666026151bf0cae8a0a88c857ae2d4c5be2daa6726")
- wantSeq = uint64(1)
- wantIP = IP4{127, 0, 0, 1}
- wantDiscport = DiscPort(30303)
- )
- if r.Seq() != wantSeq {
- t.Errorf("wrong seq: got %d, want %d", r.Seq(), wantSeq)
- }
- if addr := r.NodeAddr(); !bytes.Equal(addr, wantAddr) {
- t.Errorf("wrong addr: got %x, want %x", addr, wantAddr)
- }
- want := map[Entry]interface{}{new(IP4): &wantIP, new(DiscPort): &wantDiscport}
- for k, v := range want {
- desc := fmt.Sprintf("loading key %q", k.ENRKey())
- if assert.NoError(t, r.Load(k), desc) {
- assert.Equal(t, k, v, desc)
- }
- }
-}
-
// TestRecordTooBig tests that records bigger than SizeLimit bytes cannot be signed.
func TestRecordTooBig(t *testing.T) {
var r Record
key := randomString(10)
// set a big value for random key, expect error
- r.Set(WithEntry(key, randomString(300)))
- if err := r.Sign(privkey); err != errTooBig {
+ r.Set(WithEntry(key, randomString(SizeLimit)))
+ if err := signTest([]byte{5}, &r); err != errTooBig {
t.Fatalf("expected to get errTooBig, got %#v", err)
}
// set an acceptable value for random key, expect no error
r.Set(WithEntry(key, randomString(100)))
- require.NoError(t, r.Sign(privkey))
+ require.NoError(t, signTest([]byte{5}, &r))
+}
+
+// This checks that incomplete RLP inputs are handled correctly.
+func TestDecodeIncomplete(t *testing.T) {
+ type decTest struct {
+ input []byte
+ err error
+ }
+ tests := []decTest{
+ {[]byte{0xC0}, errIncompleteList},
+ {[]byte{0xC1, 0x1}, errIncompleteList},
+ {[]byte{0xC2, 0x1, 0x2}, nil},
+ {[]byte{0xC3, 0x1, 0x2, 0x3}, errIncompletePair},
+ {[]byte{0xC4, 0x1, 0x2, 0x3, 0x4}, nil},
+ {[]byte{0xC5, 0x1, 0x2, 0x3, 0x4, 0x5}, errIncompletePair},
+ }
+ for _, test := range tests {
+ var r Record
+ err := rlp.DecodeBytes(test.input, &r)
+ if err != test.err {
+ t.Errorf("wrong error for %X: %v", test.input, err)
+ }
+ }
}
// TestSignEncodeAndDecodeRandom tests encoding/decoding of records containing random key/value pairs.
@@ -295,9 +294,12 @@ func TestSignEncodeAndDecodeRandom(t *testing.T) {
r.Set(WithEntry(key, &value))
}
- require.NoError(t, r.Sign(privkey))
- _, err := rlp.EncodeToBytes(r)
+ require.NoError(t, signTest([]byte{5}, &r))
+
+ enc, err := rlp.EncodeToBytes(r)
require.NoError(t, err)
+ require.Equal(t, uint64(len(enc)), r.Size())
+ require.Equal(t, uint64(len(enc)), computeSize(&r))
for k, v := range pairs {
desc := fmt.Sprintf("key %q", k)
@@ -308,11 +310,40 @@ func TestSignEncodeAndDecodeRandom(t *testing.T) {
}
}
-func BenchmarkDecode(b *testing.B) {
- var r Record
- for i := 0; i < b.N; i++ {
- rlp.DecodeBytes(pyRecord, &r)
+type testSig struct{}
+
+type testID []byte
+
+func (id testID) ENRKey() string { return "testid" }
+
+func signTest(id []byte, r *Record) error {
+ r.Set(ID("test"))
+ r.Set(testID(id))
+ return r.SetSig(testSig{}, makeTestSig(id, r.Seq()))
+}
+
+func makeTestSig(id []byte, seq uint64) []byte {
+ sig := make([]byte, 8, len(id)+8)
+ binary.BigEndian.PutUint64(sig[:8], seq)
+ sig = append(sig, id...)
+ return sig
+}
+
+func (testSig) Verify(r *Record, sig []byte) error {
+ var id []byte
+ if err := r.Load((*testID)(&id)); err != nil {
+ return err
+ }
+ if !bytes.Equal(sig, makeTestSig(id, r.Seq())) {
+ return ErrInvalidSig
+ }
+ return nil
+}
+
+func (testSig) NodeAddr(r *Record) []byte {
+ var id []byte
+ if err := r.Load((*testID)(&id)); err != nil {
+ return nil
}
- b.StopTimer()
- r.NodeAddr()
+ return id
}
diff --git a/p2p/enr/entries.go b/p2p/enr/entries.go
index e31a4901a..f68e7725d 100644
--- a/p2p/enr/entries.go
+++ b/p2p/enr/entries.go
@@ -62,27 +62,83 @@ type DiscPort uint16
func (v DiscPort) ENRKey() string { return "discv5" }
+// TCP is the "tcp" key, which holds the TCP port of the node.
+type TCP uint16
+
+func (v TCP) ENRKey() string { return "tcp" }
+
+// TCP6 is the "tcp6" key, which holds the IPv6-specific tcp6 port of the node.
+type TCP6 uint16
+
+func (v TCP6) ENRKey() string { return "tcp6" }
+
+// UDP is the "udp" key, which holds the UDP port of the node.
+type UDP uint16
+
+func (v UDP) ENRKey() string { return "udp" }
+
+// UDP6 is the "udp6" key, which holds the IPv6-specific UDP port of the node.
+type UDP6 uint16
+
+func (v UDP6) ENRKey() string { return "udp6" }
+
// ID is the "id" key, which holds the name of the identity scheme.
type ID string
+const IDv4 = ID("v4") // the default identity scheme
+
func (v ID) ENRKey() string { return "id" }
-// IP4 is the "ip4" key, which holds a 4-byte IPv4 address.
-type IP4 net.IP
+// IP is either the "ip" or "ip6" key, depending on the value.
+// Use this value to encode IP addresses that can be either v4 or v6.
+// To load an address from a record use the IPv4 or IPv6 types.
+type IP net.IP
-func (v IP4) ENRKey() string { return "ip4" }
+func (v IP) ENRKey() string {
+ if net.IP(v).To4() == nil {
+ return "ip6"
+ }
+ return "ip"
+}
// EncodeRLP implements rlp.Encoder.
-func (v IP4) EncodeRLP(w io.Writer) error {
+func (v IP) EncodeRLP(w io.Writer) error {
+ if ip4 := net.IP(v).To4(); ip4 != nil {
+ return rlp.Encode(w, ip4)
+ }
+ if ip6 := net.IP(v).To16(); ip6 != nil {
+ return rlp.Encode(w, ip6)
+ }
+ return fmt.Errorf("invalid IP address: %v", net.IP(v))
+}
+
+// DecodeRLP implements rlp.Decoder.
+func (v *IP) DecodeRLP(s *rlp.Stream) error {
+ if err := s.Decode((*net.IP)(v)); err != nil {
+ return err
+ }
+ if len(*v) != 4 && len(*v) != 16 {
+ return fmt.Errorf("invalid IP address, want 4 or 16 bytes: %v", *v)
+ }
+ return nil
+}
+
+// IPv4 is the "ip" key, which holds the IP address of the node.
+type IPv4 net.IP
+
+func (v IPv4) ENRKey() string { return "ip" }
+
+// EncodeRLP implements rlp.Encoder.
+func (v IPv4) EncodeRLP(w io.Writer) error {
ip4 := net.IP(v).To4()
if ip4 == nil {
- return fmt.Errorf("invalid IPv4 address: %v", v)
+ return fmt.Errorf("invalid IPv4 address: %v", net.IP(v))
}
return rlp.Encode(w, ip4)
}
// DecodeRLP implements rlp.Decoder.
-func (v *IP4) DecodeRLP(s *rlp.Stream) error {
+func (v *IPv4) DecodeRLP(s *rlp.Stream) error {
if err := s.Decode((*net.IP)(v)); err != nil {
return err
}
@@ -92,19 +148,22 @@ func (v *IP4) DecodeRLP(s *rlp.Stream) error {
return nil
}
-// IP6 is the "ip6" key, which holds a 16-byte IPv6 address.
-type IP6 net.IP
+// IPv6 is the "ip6" key, which holds the IP address of the node.
+type IPv6 net.IP
-func (v IP6) ENRKey() string { return "ip6" }
+func (v IPv6) ENRKey() string { return "ip6" }
// EncodeRLP implements rlp.Encoder.
-func (v IP6) EncodeRLP(w io.Writer) error {
- ip6 := net.IP(v)
+func (v IPv6) EncodeRLP(w io.Writer) error {
+ ip6 := net.IP(v).To16()
+ if ip6 == nil {
+ return fmt.Errorf("invalid IPv6 address: %v", net.IP(v))
+ }
return rlp.Encode(w, ip6)
}
// DecodeRLP implements rlp.Decoder.
-func (v *IP6) DecodeRLP(s *rlp.Stream) error {
+func (v *IPv6) DecodeRLP(s *rlp.Stream) error {
if err := s.Decode((*net.IP)(v)); err != nil {
return err
}
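
The net effect of the split: enr.IP chooses its key from the value it wraps, while IPv4/IPv6 pin the key explicitly. A short sketch, assuming the enr and net imports:

	var r enr.Record
	r.Set(enr.IP(net.ParseIP("192.168.0.3")))   // To4() != nil, stored under "ip"
	r.Set(enr.IP(net.ParseIP("2001:4860::68"))) // no 4-byte form, stored under "ip6"

	// Loading is type-directed: the pointed-to type selects the key.
	var ip4 enr.IPv4
	_ = r.Load(&ip4) // reads "ip"
	var ip6 enr.IPv6
	_ = r.Load(&ip6) // reads "ip6"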
diff --git a/p2p/message.go b/p2p/message.go
index c92b31581..b29a82c7b 100644
--- a/p2p/message.go
+++ b/p2p/message.go
@@ -26,7 +26,7 @@ import (
"time"
"github.com/tomochain/tomochain/event"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/rlp"
)
@@ -100,12 +100,11 @@ func Send(w MsgWriter, msgcode uint64, data interface{}) error {
// SendItems writes an RLP with the given code and data elements.
// For a call such as:
//
-// SendItems(w, code, e1, e2, e3)
+// SendItems(w, code, e1, e2, e3)
//
// the message payload will be an RLP list containing the items:
//
-// [e1, e2, e3]
-//
+// [e1, e2, e3]
func SendItems(w MsgWriter, msgcode uint64, elems ...interface{}) error {
return Send(w, msgcode, elems)
}
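
As the reworked doc comment states, SendItems simply forwards its variadic arguments to Send as one slice, so the two calls below put the identical RLP list [e1, e2, e3] on the wire:

	_ = p2p.SendItems(rw, code, e1, e2, e3)
	_ = p2p.Send(rw, code, []interface{}{e1, e2, e3})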
@@ -254,13 +253,13 @@ type msgEventer struct {
MsgReadWriter
feed *event.Feed
- peerID discover.NodeID
+ peerID enode.ID
Protocol string
}
// newMsgEventer returns a msgEventer which sends message events to the given
// feed
-func newMsgEventer(rw MsgReadWriter, feed *event.Feed, peerID discover.NodeID, proto string) *msgEventer {
+func newMsgEventer(rw MsgReadWriter, feed *event.Feed, peerID enode.ID, proto string) *msgEventer {
return &msgEventer{
MsgReadWriter: rw,
feed: feed,
diff --git a/p2p/netutil/iptrack.go b/p2p/netutil/iptrack.go
new file mode 100644
index 000000000..a8660a4d7
--- /dev/null
+++ b/p2p/netutil/iptrack.go
@@ -0,0 +1,130 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package netutil
+
+import (
+ "time"
+
+ "github.com/tomochain/tomochain/common/mclock"
+)
+
+// IPTracker predicts the external endpoint, i.e. IP address and port, of the local host
+// based on statements made by other hosts.
+type IPTracker struct {
+ window time.Duration
+ contactWindow time.Duration
+ minStatements int
+ clock mclock.Clock
+ statements map[string]ipStatement
+ contact map[string]mclock.AbsTime
+ lastStatementGC mclock.AbsTime
+ lastContactGC mclock.AbsTime
+}
+
+type ipStatement struct {
+ endpoint string
+ time mclock.AbsTime
+}
+
+// NewIPTracker creates an IP tracker.
+//
+// The window parameters configure the amount of past network events which are kept. The
+// minStatements parameter enforces a minimum number of statements which must be recorded
+// before any prediction is made. Higher values for these parameters decrease 'flapping' of
+// predictions as network conditions change. Window duration values should typically be in
+// the range of minutes.
+func NewIPTracker(window, contactWindow time.Duration, minStatements int) *IPTracker {
+ return &IPTracker{
+ window: window,
+ contactWindow: contactWindow,
+ statements: make(map[string]ipStatement),
+ minStatements: minStatements,
+ contact: make(map[string]mclock.AbsTime),
+ clock: mclock.System{},
+ }
+}
+
+// PredictFullConeNAT checks whether the local host is behind full cone NAT. It predicts by
+// checking whether any statement has been received from a node we didn't contact before
+// the statement was made.
+func (it *IPTracker) PredictFullConeNAT() bool {
+ now := it.clock.Now()
+ it.gcContact(now)
+ it.gcStatements(now)
+ for host, st := range it.statements {
+ if c, ok := it.contact[host]; !ok || c > st.time {
+ return true
+ }
+ }
+ return false
+}
+
+// PredictEndpoint returns the current prediction of the external endpoint.
+func (it *IPTracker) PredictEndpoint() string {
+ it.gcStatements(it.clock.Now())
+
+ // The current strategy is simple: find the endpoint with most statements.
+ counts := make(map[string]int, len(it.statements))
+ maxcount, max := 0, ""
+ for _, s := range it.statements {
+ c := counts[s.endpoint] + 1
+ counts[s.endpoint] = c
+ if c > maxcount && c >= it.minStatements {
+ maxcount, max = c, s.endpoint
+ }
+ }
+ return max
+}
+
+// AddStatement records that a certain host thinks our external endpoint is the one given.
+func (it *IPTracker) AddStatement(host, endpoint string) {
+ now := it.clock.Now()
+ it.statements[host] = ipStatement{endpoint, now}
+ if time.Duration(now-it.lastStatementGC) >= it.window {
+ it.gcStatements(now)
+ }
+}
+
+// AddContact records that a packet containing our endpoint information has been sent to a
+// certain host.
+func (it *IPTracker) AddContact(host string) {
+ now := it.clock.Now()
+ it.contact[host] = now
+ if time.Duration(now-it.lastContactGC) >= it.contactWindow {
+ it.gcContact(now)
+ }
+}
+
+func (it *IPTracker) gcStatements(now mclock.AbsTime) {
+ it.lastStatementGC = now
+ cutoff := now.Add(-it.window)
+ for host, s := range it.statements {
+ if s.time < cutoff {
+ delete(it.statements, host)
+ }
+ }
+}
+
+func (it *IPTracker) gcContact(now mclock.AbsTime) {
+ it.lastContactGC = now
+ cutoff := now.Add(-it.contactWindow)
+ for host, ct := range it.contact {
+ if ct < cutoff {
+ delete(it.contact, host)
+ }
+ }
+}
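
A sketch of the intended call sequence, with placeholder endpoints; PredictEndpoint returns "" until at least minStatements hosts have reported the same endpoint inside the window:

	it := netutil.NewIPTracker(10*time.Minute, 10*time.Minute, 3)

	// Record that we sent our endpoint to a host...
	it.AddContact("198.51.100.7:30303")
	// ...and what that host claims our external endpoint to be.
	it.AddStatement("198.51.100.7:30303", "203.0.113.9:30303")

	endpoint := it.PredictEndpoint()    // "" until the threshold is met
	fullCone := it.PredictFullConeNAT() // true if an uncontacted host spoke first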
diff --git a/p2p/netutil/iptrack_test.go b/p2p/netutil/iptrack_test.go
new file mode 100644
index 000000000..711e588d6
--- /dev/null
+++ b/p2p/netutil/iptrack_test.go
@@ -0,0 +1,138 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package netutil
+
+import (
+ crand "crypto/rand"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/tomochain/tomochain/common/mclock"
+)
+
+const (
+ opStatement = iota
+ opContact
+ opPredict
+ opCheckFullCone
+)
+
+type iptrackTestEvent struct {
+ op int
+ time int // absolute, in milliseconds
+ ip, from string
+}
+
+func TestIPTracker(t *testing.T) {
+ tests := map[string][]iptrackTestEvent{
+ "minStatements": {
+ {opPredict, 0, "", ""},
+ {opStatement, 0, "127.0.0.1", "127.0.0.2"},
+ {opPredict, 1000, "", ""},
+ {opStatement, 1000, "127.0.0.1", "127.0.0.3"},
+ {opPredict, 1000, "", ""},
+ {opStatement, 1000, "127.0.0.1", "127.0.0.4"},
+ {opPredict, 1000, "127.0.0.1", ""},
+ },
+ "window": {
+ {opStatement, 0, "127.0.0.1", "127.0.0.2"},
+ {opStatement, 2000, "127.0.0.1", "127.0.0.3"},
+ {opStatement, 3000, "127.0.0.1", "127.0.0.4"},
+ {opPredict, 10000, "127.0.0.1", ""},
+ {opPredict, 10001, "", ""}, // first statement expired
+ {opStatement, 10100, "127.0.0.1", "127.0.0.2"},
+ {opPredict, 10200, "127.0.0.1", ""},
+ },
+ "fullcone": {
+ {opContact, 0, "", "127.0.0.2"},
+ {opStatement, 10, "127.0.0.1", "127.0.0.2"},
+ {opContact, 2000, "", "127.0.0.3"},
+ {opStatement, 2010, "127.0.0.1", "127.0.0.3"},
+ {opContact, 3000, "", "127.0.0.4"},
+ {opStatement, 3010, "127.0.0.1", "127.0.0.4"},
+ {opCheckFullCone, 3500, "false", ""},
+ },
+ "fullcone_2": {
+ {opContact, 0, "", "127.0.0.2"},
+ {opStatement, 10, "127.0.0.1", "127.0.0.2"},
+ {opContact, 2000, "", "127.0.0.3"},
+ {opStatement, 2010, "127.0.0.1", "127.0.0.3"},
+ {opStatement, 3000, "127.0.0.1", "127.0.0.4"},
+ {opContact, 3010, "", "127.0.0.4"},
+ {opCheckFullCone, 3500, "true", ""},
+ },
+ }
+ for name, test := range tests {
+ t.Run(name, func(t *testing.T) { runIPTrackerTest(t, test) })
+ }
+}
+
+func runIPTrackerTest(t *testing.T, evs []iptrackTestEvent) {
+ var (
+ clock mclock.Simulated
+ it = NewIPTracker(10*time.Second, 10*time.Second, 3)
+ )
+ it.clock = &clock
+ for i, ev := range evs {
+ evtime := time.Duration(ev.time) * time.Millisecond
+ clock.Run(evtime - time.Duration(clock.Now()))
+ switch ev.op {
+ case opStatement:
+ it.AddStatement(ev.from, ev.ip)
+ case opContact:
+ it.AddContact(ev.from)
+ case opPredict:
+ if pred := it.PredictEndpoint(); pred != ev.ip {
+ t.Errorf("op %d: wrong prediction %q, want %q", i, pred, ev.ip)
+ }
+ case opCheckFullCone:
+ pred := fmt.Sprintf("%t", it.PredictFullConeNAT())
+ if pred != ev.ip {
+ t.Errorf("op %d: wrong prediction %s, want %s", i, pred, ev.ip)
+ }
+ }
+ }
+}
+
+// This checks that old statements and contacts are GCed even if Predict* isn't called.
+func TestIPTrackerForceGC(t *testing.T) {
+ var (
+ clock mclock.Simulated
+ window = 10 * time.Second
+ rate = 50 * time.Millisecond
+ max = int(window/rate) + 1
+ it = NewIPTracker(window, window, 3)
+ )
+ it.clock = &clock
+
+ for i := 0; i < 5*max; i++ {
+ e1 := make([]byte, 4)
+ e2 := make([]byte, 4)
+ crand.Read(e1)
+ crand.Read(e2)
+ it.AddStatement(string(e1), string(e2))
+ it.AddContact(string(e1))
+ clock.Run(rate)
+ }
+ if len(it.contact) > 2*max {
+ t.Errorf("contacts not GCed, have %d", len(it.contact))
+ }
+ if len(it.statements) > 2*max {
+ t.Errorf("statements not GCed, have %d", len(it.statements))
+ }
+}
diff --git a/p2p/peer.go b/p2p/peer.go
index 1852d59bc..04b12d893 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -17,6 +17,7 @@
package p2p
import (
+ "errors"
"fmt"
"io"
"net"
@@ -27,10 +28,15 @@ import (
"github.com/tomochain/tomochain/common/mclock"
"github.com/tomochain/tomochain/event"
"github.com/tomochain/tomochain/log"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
+ "github.com/tomochain/tomochain/p2p/enr"
"github.com/tomochain/tomochain/rlp"
)
+var (
+ ErrShuttingDown = errors.New("shutting down")
+)
+
const (
baseProtocolVersion = 5
baseProtocolLength = uint64(16)
@@ -47,8 +53,6 @@ const (
discMsg = 0x01
pingMsg = 0x02
pongMsg = 0x03
- getPeersMsg = 0x04
- peersMsg = 0x05
)
// protoHandshake is the RLP structure of the protocol handshake.
@@ -57,7 +61,7 @@ type protoHandshake struct {
Name string
Caps []Cap
ListenPort uint64
- ID discover.NodeID
+ ID []byte // secp256k1 public key
// Ignore additional fields (for forward compatibility).
Rest []rlp.RawValue `rlp:"tail"`
@@ -87,12 +91,12 @@ const (
// PeerEvent is an event emitted when peers are either added or dropped from
// a p2p.Server or when a message is sent or received on a peer connection
type PeerEvent struct {
- Type PeerEventType `json:"type"`
- Peer discover.NodeID `json:"peer"`
- Error string `json:"error,omitempty"`
- Protocol string `json:"protocol,omitempty"`
- MsgCode *uint64 `json:"msg_code,omitempty"`
- MsgSize *uint32 `json:"msg_size,omitempty"`
+ Type PeerEventType `json:"type"`
+ Peer enode.ID `json:"peer"`
+ Error string `json:"error,omitempty"`
+ Protocol string `json:"protocol,omitempty"`
+ MsgCode *uint64 `json:"msg_code,omitempty"`
+ MsgSize *uint32 `json:"msg_size,omitempty"`
}
// Peer represents a connected remote node.
@@ -108,22 +112,27 @@ type Peer struct {
disc chan DiscReason
// events receives message send / receive events if set
- events *event.Feed
- PairPeer *Peer
+ events *event.Feed
}
// NewPeer returns a peer for testing purposes.
-func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
+func NewPeer(id enode.ID, name string, caps []Cap) *Peer {
pipe, _ := net.Pipe()
- conn := &conn{fd: pipe, transport: nil, id: id, caps: caps, name: name}
+ node := enode.SignNull(new(enr.Record), id)
+ conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name}
peer := newPeer(conn, nil)
close(peer.closed) // ensures Disconnect doesn't block
return peer
}
// ID returns the node's public key.
-func (p *Peer) ID() discover.NodeID {
- return p.rw.id
+func (p *Peer) ID() enode.ID {
+ return p.rw.node.ID()
+}
+
+// Node returns the peer's node descriptor.
+func (p *Peer) Node() *enode.Node {
+ return p.rw.node
}
// Name returns the node name that the remote node advertised.
@@ -158,12 +167,13 @@ func (p *Peer) Disconnect(reason DiscReason) {
// String implements fmt.Stringer.
func (p *Peer) String() string {
- return fmt.Sprintf("Peer %x %v ", p.rw.id[:8], p.RemoteAddr())
+ id := p.ID()
+ return fmt.Sprintf("Peer %x %v", id[:8], p.RemoteAddr())
}
// Inbound returns true if the peer is an inbound connection
func (p *Peer) Inbound() bool {
- return p.rw.flags&inboundConn != 0
+ return p.rw.is(inboundConn)
}
func newPeer(conn *conn, protocols []Protocol) *Peer {
@@ -175,7 +185,7 @@ func newPeer(conn *conn, protocols []Protocol) *Peer {
disc: make(chan DiscReason),
protoErr: make(chan error, len(protomap)+1), // protocols + pingLoop
closed: make(chan struct{}),
- log: log.New("id", conn.id, "conn", conn.flags),
+ log: log.New("id", conn.node.ID(), "conn", conn.flags),
}
return p
}
@@ -223,15 +233,14 @@ loop:
reason = discReasonForError(err)
break loop
case err = <-p.disc:
+ reason = discReasonForError(err)
break loop
}
}
+
close(p.closed)
p.rw.close(reason)
p.wg.Wait()
- if p.PairPeer != nil {
- go func() { p.PairPeer.Disconnect(DiscPairPeerStop) }()
- }
return remoteRequested, err
}
@@ -348,7 +357,6 @@ func (p *Peer) startProtocols(writeStart <-chan struct{}, writeErr chan<- error)
rw = newMsgEventer(rw, p.events, p.ID(), proto.Name)
}
p.log.Trace(fmt.Sprintf("Starting protocol %s/%d", proto.Name, proto.Version))
-
go func() {
err := proto.Run(p, rw)
if err == nil {
@@ -376,7 +384,7 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) {
type protoRW struct {
Protocol
- in chan Msg // receices read messages
+ in chan Msg // receives read messages
closed <-chan struct{} // receives when peer is shutting down
wstart <-chan struct{} // receives when write may start
werr chan<- error // for write results
@@ -398,7 +406,7 @@ func (rw *protoRW) WriteMsg(msg Msg) (err error) {
// as well but we don't want to rely on that.
rw.werr <- err
case <-rw.closed:
- err = fmt.Errorf("shutting down")
+ err = ErrShuttingDown
}
return err
}
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index a3e1c74fd..1c795f9f8 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -44,9 +44,14 @@ var discard = Protocol{
}
func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) {
- fd1, fd2 := net.Pipe()
- c1 := &conn{fd: fd1, transport: newTestTransport(randomID(), fd1)}
- c2 := &conn{fd: fd2, transport: newTestTransport(randomID(), fd2)}
+ var (
+ fd1, fd2 = net.Pipe()
+ key1, key2 = newkey(), newkey()
+ t1 = newTestTransport(&key2.PublicKey, fd1)
+ t2 = newTestTransport(&key1.PublicKey, fd2)
+ )
+ c1 := &conn{fd: fd1, node: newNode(randomID(), nil), transport: t1}
+ c2 := &conn{fd: fd2, node: newNode(randomID(), nil), transport: t2}
for _, p := range protos {
c1.caps = append(c1.caps, p.cap())
c2.caps = append(c2.caps, p.cap())
diff --git a/p2p/protocol.go b/p2p/protocol.go
index dbdb19701..fafc044d8 100644
--- a/p2p/protocol.go
+++ b/p2p/protocol.go
@@ -19,7 +19,7 @@ package p2p
import (
"fmt"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
)
// Protocol represents a P2P subprotocol implementation.
@@ -51,7 +51,7 @@ type Protocol struct {
// PeerInfo is an optional helper method to retrieve protocol specific metadata
// about a certain peer in the network. If an info retrieval function is set,
// but returns nil, it is assumed that the protocol handshake is still running.
- PeerInfo func(id discover.NodeID) interface{}
+ PeerInfo func(id enode.ID) interface{}
}
func (p Protocol) cap() Cap {
diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go
index 286bbf97f..0e4523e40 100644
--- a/p2p/protocols/protocol_test.go
+++ b/p2p/protocols/protocol_test.go
@@ -24,7 +24,7 @@ import (
"time"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/simulations/adapters"
p2ptest "github.com/tomochain/tomochain/p2p/testing"
)
@@ -36,14 +36,14 @@ type hs0 struct {
// message to kill/drop the peer with nodeID
type kill struct {
- C discover.NodeID
+ C enode.ID
}
// message to drop connection
type drop struct {
}
-/// protoHandshake represents module-independent aspects of the protocol and is
+// protoHandshake represents module-independent aspects of the protocol and is
// the first message peers send and receive as part of the initial exchange
type protoHandshake struct {
Version uint // local and remote peer should have identical version
@@ -144,7 +144,7 @@ func protocolTester(t *testing.T, pp *p2ptest.TestPeerPool) *p2ptest.ProtocolTes
return p2ptest.NewProtocolTester(t, conf.ID, 2, newProtocol(pp))
}
-func protoHandshakeExchange(id discover.NodeID, proto *protoHandshake) []p2ptest.Exchange {
+func protoHandshakeExchange(id enode.ID, proto *protoHandshake) []p2ptest.Exchange {
return []p2ptest.Exchange{
{
@@ -197,7 +197,7 @@ func TestProtoHandshakeSuccess(t *testing.T) {
runProtoHandshake(t, &protoHandshake{42, "420"})
}
-func moduleHandshakeExchange(id discover.NodeID, resp uint) []p2ptest.Exchange {
+func moduleHandshakeExchange(id enode.ID, resp uint) []p2ptest.Exchange {
return []p2ptest.Exchange{
{
@@ -249,7 +249,7 @@ func TestModuleHandshakeSuccess(t *testing.T) {
}
// testing complex interactions over multiple peers, relaying, dropping
-func testMultiPeerSetup(a, b discover.NodeID) []p2ptest.Exchange {
+func testMultiPeerSetup(a, b enode.ID) []p2ptest.Exchange {
return []p2ptest.Exchange{
{
diff --git a/p2p/rlpx.go b/p2p/rlpx.go
index 2cc4d42d3..2fcfc1c18 100644
--- a/p2p/rlpx.go
+++ b/p2p/rlpx.go
@@ -36,12 +36,12 @@ import (
"time"
"github.com/golang/snappy"
- "github.com/tomochain/tomochain/crypto"
+ "github.com/tomochain/tomochain/common/bitutil"
+ crypto "github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/crypto/ecies"
"github.com/tomochain/tomochain/crypto/secp256k1"
"github.com/tomochain/tomochain/crypto/sha3"
- "github.com/tomochain/tomochain/p2p/discover"
- "github.com/tomochain/tomochain/rlp"
+ rlp "github.com/tomochain/tomochain/rlp"
)
const (
@@ -122,7 +122,6 @@ func (t *rlpx) close(err error) {
}
func (t *rlpx) doProtoHandshake(our *protoHandshake) (their *protoHandshake, err error) {
-
// Writing our handshake happens concurrently, we prefer
// returning the handshake read error. If the remote side
// disconnects us early with a valid reason, we should return it
@@ -166,7 +165,7 @@ func readProtocolHandshake(rw MsgReader, our *protoHandshake) (*protoHandshake,
if err := msg.Decode(&hs); err != nil {
return nil, err
}
- if (hs.ID == discover.NodeID{}) {
+ if len(hs.ID) != 64 || !bitutil.TestBytes(hs.ID) {
return nil, DiscInvalidIdentity
}
return &hs, nil
@@ -176,31 +175,29 @@ func readProtocolHandshake(rw MsgReader, our *protoHandshake) (*protoHandshake,
// messages. the protocol handshake is the first authenticated message
// and also verifies whether the encryption handshake 'worked' and the
// remote side actually provided the right public key.
-func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *discover.Node) (discover.NodeID, error) {
+func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *ecdsa.PublicKey) (*ecdsa.PublicKey, error) {
var (
sec secrets
err error
)
if dial == nil {
- sec, err = receiverEncHandshake(t.fd, prv, nil)
+ sec, err = receiverEncHandshake(t.fd, prv)
} else {
- sec, err = initiatorEncHandshake(t.fd, prv, dial.ID, nil)
+ sec, err = initiatorEncHandshake(t.fd, prv, dial)
}
if err != nil {
- return discover.NodeID{}, err
+ return nil, err
}
t.wmu.Lock()
t.rw = newRLPXFrameRW(t.fd, sec)
t.wmu.Unlock()
- return sec.RemoteID, nil
+ return sec.Remote.ExportECDSA(), nil
}
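
With discover.NodeID gone, both roles now exchange *ecdsa.PublicKey directly. A sketch of the two sides over an in-memory pipe, written inside package p2p as the tests below do; prv0 and prv1 are assumed ECDSA keys:

	fd0, fd1 := net.Pipe()
	go func() {
		// Dialer: must already know the remote's static public key.
		pub, err := newRLPX(fd0).doEncHandshake(prv0, &prv1.PublicKey)
		_, _ = pub, err
	}()
	// Listener: passes nil and learns the remote key from the auth message.
	pub, err := newRLPX(fd1).doEncHandshake(prv1, nil)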
// encHandshake contains the state of the encryption handshake.
type encHandshake struct {
- initiator bool
- remoteID discover.NodeID
-
- remotePub *ecies.PublicKey // remote-pubk
+ initiator bool
+ remote *ecies.PublicKey // remote-pubk
initNonce, respNonce []byte // nonce
randomPrivKey *ecies.PrivateKey // ecdhe-random
remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk
@@ -209,7 +206,7 @@ type encHandshake struct {
// secrets represents the connection secrets
// which are negotiated during the encryption handshake.
type secrets struct {
- RemoteID discover.NodeID
+ Remote *ecies.PublicKey
AES, MAC []byte
EgressMAC, IngressMAC hash.Hash
Token []byte
@@ -250,9 +247,9 @@ func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) {
sharedSecret := crypto.Keccak256(ecdheSecret, crypto.Keccak256(h.respNonce, h.initNonce))
aesSecret := crypto.Keccak256(ecdheSecret, sharedSecret)
s := secrets{
- RemoteID: h.remoteID,
- AES: aesSecret,
- MAC: crypto.Keccak256(ecdheSecret, aesSecret),
+ Remote: h.remote,
+ AES: aesSecret,
+ MAC: crypto.Keccak256(ecdheSecret, aesSecret),
}
// setup sha3 instances for the MACs
@@ -274,16 +271,16 @@ func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) {
// staticSharedSecret returns the static shared secret, the result
// of key agreement between the local and remote static node key.
func (h *encHandshake) staticSharedSecret(prv *ecdsa.PrivateKey) ([]byte, error) {
- return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen)
+ return ecies.ImportECDSA(prv).GenerateShared(h.remote, sskLen, sskLen)
}
// initiatorEncHandshake negotiates a session token on conn.
// it should be called on the dialing side of the connection.
//
// prv is the local client's private key.
-func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) {
- h := &encHandshake{initiator: true, remoteID: remoteID}
- authMsg, err := h.makeAuthMsg(prv, token)
+func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remote *ecdsa.PublicKey) (s secrets, err error) {
+ h := &encHandshake{initiator: true, remote: ecies.ImportECDSAPublic(remote)}
+ authMsg, err := h.makeAuthMsg(prv)
if err != nil {
return s, err
}
@@ -307,15 +304,11 @@ func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID d
}
// makeAuthMsg creates the initiator handshake message.
-func (h *encHandshake) makeAuthMsg(prv *ecdsa.PrivateKey, token []byte) (*authMsgV4, error) {
- rpub, err := h.remoteID.Pubkey()
- if err != nil {
- return nil, fmt.Errorf("bad remoteID: %v", err)
- }
- h.remotePub = ecies.ImportECDSAPublic(rpub)
+func (h *encHandshake) makeAuthMsg(prv *ecdsa.PrivateKey) (*authMsgV4, error) {
// Generate random initiator nonce.
h.initNonce = make([]byte, shaLen)
- if _, err := rand.Read(h.initNonce); err != nil {
+ _, err := rand.Read(h.initNonce)
+ if err != nil {
return nil, err
}
// Generate random keypair for ECDH.
@@ -325,7 +318,7 @@ func (h *encHandshake) makeAuthMsg(prv *ecdsa.PrivateKey, token []byte) (*authMs
}
// Sign known message: static-shared-secret ^ nonce
- token, err = h.staticSharedSecret(prv)
+ token, err := h.staticSharedSecret(prv)
if err != nil {
return nil, err
}
@@ -353,8 +346,7 @@ func (h *encHandshake) handleAuthResp(msg *authRespV4) (err error) {
// it should be called on the listening side of the connection.
//
// prv is the local client's private key.
-// token is the token from a previous session with this node.
-func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) {
+func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey) (s secrets, err error) {
authMsg := new(authMsgV4)
authPacket, err := readHandshakeMsg(authMsg, encAuthMsgLen, prv, conn)
if err != nil {
@@ -386,13 +378,12 @@ func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byt
func (h *encHandshake) handleAuthMsg(msg *authMsgV4, prv *ecdsa.PrivateKey) error {
// Import the remote identity.
- h.initNonce = msg.Nonce[:]
- h.remoteID = msg.InitiatorPubkey
- rpub, err := h.remoteID.Pubkey()
+ rpub, err := importPublicKey(msg.InitiatorPubkey[:])
if err != nil {
- return fmt.Errorf("bad remoteID: %#v", err)
+ return err
}
- h.remotePub = ecies.ImportECDSAPublic(rpub)
+ h.initNonce = msg.Nonce[:]
+ h.remote = rpub
// Generate random keypair for ECDH.
// If a private key is already set, use it instead of generating one (for testing).
@@ -438,7 +429,7 @@ func (msg *authMsgV4) sealPlain(h *encHandshake) ([]byte, error) {
n += copy(buf[n:], msg.InitiatorPubkey[:])
n += copy(buf[n:], msg.Nonce[:])
buf[n] = 0 // token-flag
- return ecies.Encrypt(rand.Reader, h.remotePub, buf, nil, nil)
+ return ecies.Encrypt(rand.Reader, h.remote, buf, nil, nil)
}
func (msg *authMsgV4) decodePlain(input []byte) {
@@ -454,7 +445,7 @@ func (msg *authRespV4) sealPlain(hs *encHandshake) ([]byte, error) {
buf := make([]byte, authRespLen)
n := copy(buf, msg.RandomPubkey[:])
copy(buf[n:], msg.Nonce[:])
- return ecies.Encrypt(rand.Reader, hs.remotePub, buf, nil, nil)
+ return ecies.Encrypt(rand.Reader, hs.remote, buf, nil, nil)
}
func (msg *authRespV4) decodePlain(input []byte) {
@@ -477,7 +468,7 @@ func sealEIP8(msg interface{}, h *encHandshake) ([]byte, error) {
prefix := make([]byte, 2)
binary.BigEndian.PutUint16(prefix, uint16(buf.Len()+eciesOverhead))
- enc, err := ecies.Encrypt(rand.Reader, h.remotePub, buf.Bytes(), nil, prefix)
+ enc, err := ecies.Encrypt(rand.Reader, h.remote, buf.Bytes(), nil, prefix)
return append(prefix, enc...), err
}
@@ -529,9 +520,9 @@ func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) {
return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
}
// TODO: fewer pointless conversions
- pub := crypto.ToECDSAPub(pubKey65)
- if pub.X == nil {
- return nil, fmt.Errorf("invalid public key")
+ pub, err := crypto.UnmarshalPubkey(pubKey65)
+ if err != nil {
+ return nil, err
}
return ecies.ImportECDSAPublic(pub), nil
}
diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go
index e86a1fb17..207dbf057 100644
--- a/p2p/rlpx_test.go
+++ b/p2p/rlpx_test.go
@@ -18,6 +18,7 @@ package p2p
import (
"bytes"
+ "crypto/ecdsa"
"crypto/rand"
"errors"
"fmt"
@@ -31,11 +32,11 @@ import (
"time"
"github.com/davecgh/go-spew/spew"
- "github.com/tomochain/tomochain/crypto"
+ crypto "github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/crypto/ecies"
"github.com/tomochain/tomochain/crypto/sha3"
- "github.com/tomochain/tomochain/p2p/discover"
- "github.com/tomochain/tomochain/rlp"
+ "github.com/tomochain/tomochain/p2p/simulations/pipes"
+ rlp "github.com/tomochain/tomochain/rlp"
)
func TestSharedSecret(t *testing.T) {
@@ -79,9 +80,9 @@ func TestEncHandshake(t *testing.T) {
func testEncHandshake(token []byte) error {
type result struct {
- side string
- id discover.NodeID
- err error
+ side string
+ pubkey *ecdsa.PublicKey
+ err error
}
var (
prv0, _ = crypto.GenerateKey()
@@ -96,14 +97,12 @@ func testEncHandshake(token []byte) error {
defer func() { output <- r }()
defer fd0.Close()
- dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)}
- r.id, r.err = c0.doEncHandshake(prv0, dest)
+ r.pubkey, r.err = c0.doEncHandshake(prv0, &prv1.PublicKey)
if r.err != nil {
return
}
- id1 := discover.PubkeyID(&prv1.PublicKey)
- if r.id != id1 {
- r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1)
+ if !reflect.DeepEqual(r.pubkey, &prv1.PublicKey) {
+ r.err = fmt.Errorf("remote pubkey mismatch: got %v, want: %v", r.pubkey, &prv1.PublicKey)
}
}()
go func() {
@@ -111,13 +110,12 @@ func testEncHandshake(token []byte) error {
defer func() { output <- r }()
defer fd1.Close()
- r.id, r.err = c1.doEncHandshake(prv1, nil)
+ r.pubkey, r.err = c1.doEncHandshake(prv1, nil)
if r.err != nil {
return
}
- id0 := discover.PubkeyID(&prv0.PublicKey)
- if r.id != id0 {
- r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0)
+ if !reflect.DeepEqual(r.pubkey, &prv0.PublicKey) {
+ r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.pubkey, &prv0.PublicKey)
}
}()
@@ -149,17 +147,17 @@ func testEncHandshake(token []byte) error {
func TestProtocolHandshake(t *testing.T) {
var (
prv0, _ = crypto.GenerateKey()
- node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33}
- hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}}
+ pub0 = crypto.FromECDSAPub(&prv0.PublicKey)[1:]
+ hs0 = &protoHandshake{Version: 3, ID: pub0, Caps: []Cap{{"a", 0}, {"b", 2}}}
prv1, _ = crypto.GenerateKey()
- node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44}
- hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}}
+ pub1 = crypto.FromECDSAPub(&prv1.PublicKey)[1:]
+ hs1 = &protoHandshake{Version: 3, ID: pub1, Caps: []Cap{{"c", 1}, {"d", 3}}}
wg sync.WaitGroup
)
- fd0, fd1, err := tcpPipe()
+ fd0, fd1, err := pipes.TCPPipe()
if err != nil {
t.Fatal(err)
}
@@ -169,13 +167,13 @@ func TestProtocolHandshake(t *testing.T) {
defer wg.Done()
defer fd0.Close()
rlpx := newRLPX(fd0)
- remid, err := rlpx.doEncHandshake(prv0, node1)
+ rpubkey, err := rlpx.doEncHandshake(prv0, &prv1.PublicKey)
if err != nil {
t.Errorf("dial side enc handshake failed: %v", err)
return
}
- if remid != node1.ID {
- t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID)
+ if !reflect.DeepEqual(rpubkey, &prv1.PublicKey) {
+ t.Errorf("dial side remote pubkey mismatch: got %v, want %v", rpubkey, &prv1.PublicKey)
return
}
@@ -195,13 +193,13 @@ func TestProtocolHandshake(t *testing.T) {
defer wg.Done()
defer fd1.Close()
rlpx := newRLPX(fd1)
- remid, err := rlpx.doEncHandshake(prv1, nil)
+ rpubkey, err := rlpx.doEncHandshake(prv1, nil)
if err != nil {
t.Errorf("listen side enc handshake failed: %v", err)
return
}
- if remid != node0.ID {
- t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID)
+ if !reflect.DeepEqual(rpubkey, &prv0.PublicKey) {
+ t.Errorf("listen side remote pubkey mismatch: got %v, want %v", rpubkey, &prv0.PublicKey)
return
}
@@ -601,31 +599,3 @@ func TestHandshakeForwardCompatibility(t *testing.T) {
t.Errorf("ingress-mac('foo') mismatch:\ngot %x\nwant %x", fooIngressHash, wantFooIngressHash)
}
}
-
-// tcpPipe creates an in process full duplex pipe based on a localhost TCP socket
-func tcpPipe() (net.Conn, net.Conn, error) {
- l, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- return nil, nil, err
- }
- defer l.Close()
-
- var aconn net.Conn
- aerr := make(chan error, 1)
- go func() {
- var err error
- aconn, err = l.Accept()
- aerr <- err
- }()
-
- dconn, err := net.Dial("tcp", l.Addr().String())
- if err != nil {
- <-aerr
- return nil, nil, err
- }
- if err := <-aerr; err != nil {
- dconn.Close()
- return nil, nil, err
- }
- return aconn, dconn, nil
-}
diff --git a/p2p/server.go b/p2p/server.go
index 6a5ea9e61..7875426b6 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -18,19 +18,23 @@
package p2p
import (
+ "bytes"
"crypto/ecdsa"
"errors"
"fmt"
"net"
"sync"
+ "sync/atomic"
"time"
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/mclock"
+ "github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/event"
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/p2p/discover"
"github.com/tomochain/tomochain/p2p/discv5"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/nat"
"github.com/tomochain/tomochain/p2p/netutil"
)
@@ -76,7 +80,7 @@ type Config struct {
// Disabling is useful for protocol debugging (manual topology).
NoDiscovery bool
- // DiscoveryV5 specifies whether the the new topic-discovery based V5 discovery
+ // DiscoveryV5 specifies whether the new topic-discovery based V5 discovery
// protocol should be started or not.
DiscoveryV5 bool `toml:",omitempty"`
@@ -86,7 +90,7 @@ type Config struct {
// BootstrapNodes are used to establish connectivity
// with the rest of the network.
- BootstrapNodes []*discover.Node
+ BootstrapNodes []*enode.Node
// BootstrapNodesV5 are used to establish connectivity
// with the rest of the network using the V5 discovery
@@ -95,11 +99,11 @@ type Config struct {
// Static nodes are used as pre-configured connections which are always
// maintained and re-connected on disconnects.
- StaticNodes []*discover.Node
+ StaticNodes []*enode.Node
// Trusted nodes are used as pre-configured connections which are always
// allowed to connect, even above the peer limit.
- TrustedNodes []*discover.Node
+ TrustedNodes []*enode.Node
// Connectivity can be restricted to certain IP networks.
// If this option is set to a non-nil value, only hosts which match one of the
@@ -167,8 +171,10 @@ type Server struct {
peerOpDone chan struct{}
quit chan struct{}
- addstatic chan *discover.Node
- removestatic chan *discover.Node
+ addstatic chan *enode.Node
+ removestatic chan *enode.Node
+ addtrusted chan *enode.Node
+ removetrusted chan *enode.Node
posthandshake chan *conn
addpeer chan *conn
delpeer chan peerDrop
@@ -177,7 +183,7 @@ type Server struct {
log log.Logger
}
-type peerOpFunc func(map[discover.NodeID]*Peer)
+type peerOpFunc func(map[enode.ID]*Peer)
type peerDrop struct {
*Peer
@@ -185,7 +191,7 @@ type peerDrop struct {
requested bool // true if signaled by the peer
}
-type connFlag int
+type connFlag int32
const (
dynDialedConn connFlag = 1 << iota
@@ -199,16 +205,16 @@ const (
type conn struct {
fd net.Conn
transport
+ node *enode.Node
flags connFlag
- cont chan error // The run loop uses cont to signal errors to SetupConn.
- id discover.NodeID // valid after the encryption handshake
- caps []Cap // valid after the protocol handshake
- name string // valid after the protocol handshake
+ cont chan error // The run loop uses cont to signal errors to SetupConn.
+ caps []Cap // valid after the protocol handshake
+ name string // valid after the protocol handshake
}
type transport interface {
// The two handshakes.
- doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error)
+ doEncHandshake(prv *ecdsa.PrivateKey, dialDest *ecdsa.PublicKey) (*ecdsa.PublicKey, error)
doProtoHandshake(our *protoHandshake) (*protoHandshake, error)
// The MsgReadWriter can only be used after the encryption
// handshake has completed. The code uses conn.id to track this
@@ -222,8 +228,8 @@ type transport interface {
func (c *conn) String() string {
s := c.flags.String()
- if (c.id != discover.NodeID{}) {
- s += " " + c.id.String()
+ if (c.node.ID() != enode.ID{}) {
+ s += " " + c.node.ID().String()
}
s += " " + c.fd.RemoteAddr().String()
return s
@@ -250,7 +256,23 @@ func (f connFlag) String() string {
}
func (c *conn) is(f connFlag) bool {
- return c.flags&f != 0
+ flags := connFlag(atomic.LoadInt32((*int32)(&c.flags)))
+ return flags&f != 0
+}
+
+func (c *conn) set(f connFlag, val bool) {
+ for {
+ oldFlags := connFlag(atomic.LoadInt32((*int32)(&c.flags)))
+ flags := oldFlags
+ if val {
+ flags |= f
+ } else {
+ flags &= ^f
+ }
+ if atomic.CompareAndSwapInt32((*int32)(&c.flags), int32(oldFlags), int32(flags)) {
+ return
+ }
+ }
}
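
The helpers above are the standard lock-free load/CAS retry loop; the same shape in isolation on a bare int32 (a sketch, not taken from the patch):

	var flags int32

	func setFlag(bit int32, val bool) {
		for {
			old := atomic.LoadInt32(&flags)
			next := old
			if val {
				next |= bit
			} else {
				next &^= bit
			}
			// Retry if another goroutine touched flags between load and CAS.
			if atomic.CompareAndSwapInt32(&flags, old, next) {
				return
			}
		}
	}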
// Peers returns all connected peers.
@@ -260,7 +282,7 @@ func (srv *Server) Peers() []*Peer {
// Note: We'd love to put this function into a variable but
// that seems to cause a weird compiler error in some
// environments.
- case srv.peerOp <- func(peers map[discover.NodeID]*Peer) {
+ case srv.peerOp <- func(peers map[enode.ID]*Peer) {
for _, p := range peers {
ps = append(ps, p)
}
@@ -275,7 +297,7 @@ func (srv *Server) Peers() []*Peer {
func (srv *Server) PeerCount() int {
var count int
select {
- case srv.peerOp <- func(ps map[discover.NodeID]*Peer) { count = len(ps) }:
+ case srv.peerOp <- func(ps map[enode.ID]*Peer) { count = len(ps) }:
<-srv.peerOpDone
case <-srv.quit:
}
@@ -285,8 +307,7 @@ func (srv *Server) PeerCount() int {
// AddPeer connects to the given node and maintains the connection until the
// server is shut down. If the connection fails for any reason, the server will
// attempt to reconnect the peer.
-func (srv *Server) AddPeer(node *discover.Node) {
-
+func (srv *Server) AddPeer(node *enode.Node) {
select {
case srv.addstatic <- node:
case <-srv.quit:
@@ -294,55 +315,83 @@ func (srv *Server) AddPeer(node *discover.Node) {
}
// RemovePeer disconnects from the given node
-func (srv *Server) RemovePeer(node *discover.Node) {
+func (srv *Server) RemovePeer(node *enode.Node) {
select {
case srv.removestatic <- node:
case <-srv.quit:
}
}
+// AddTrustedPeer adds the given node to a reserved whitelist which allows the
+// node to always connect, even if the slots are full.
+func (srv *Server) AddTrustedPeer(node *enode.Node) {
+ select {
+ case srv.addtrusted <- node:
+ case <-srv.quit:
+ }
+}
+
+// RemoveTrustedPeer removes the given node from the trusted peer set.
+func (srv *Server) RemoveTrustedPeer(node *enode.Node) {
+ select {
+ case srv.removetrusted <- node:
+ case <-srv.quit:
+ }
+}
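
Runtime usage of the new trusted-peer hooks, assuming a node parsed with enode.ParseV4 (the enode URL is a placeholder):

	node, err := enode.ParseV4("enode://<hex-pubkey>@203.0.113.9:30303")
	if err != nil {
		return err
	}
	srv.AddTrustedPeer(node)    // may now connect even above the peer limit
	// ...
	srv.RemoveTrustedPeer(node) // drops the exemption; an open session stays up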
+
// SubscribePeers subscribes the given channel to peer events
func (srv *Server) SubscribeEvents(ch chan *PeerEvent) event.Subscription {
return srv.peerFeed.Subscribe(ch)
}
// Self returns the local node's endpoint information.
-func (srv *Server) Self() *discover.Node {
+func (srv *Server) Self() *enode.Node {
srv.lock.Lock()
- defer srv.lock.Unlock()
+ running, listener, ntab := srv.running, srv.listener, srv.ntab
+ srv.lock.Unlock()
- if !srv.running {
- return &discover.Node{IP: net.ParseIP("0.0.0.0")}
+ if !running {
+ return enode.NewV4(&srv.PrivateKey.PublicKey, net.ParseIP("0.0.0.0"), 0, 0)
}
- return srv.makeSelf(srv.listener, srv.ntab)
+ return srv.makeSelf(listener, ntab)
}
-func (srv *Server) makeSelf(listener net.Listener, ntab discoverTable) *discover.Node {
- // If the server's not running, return an empty node.
+func (srv *Server) makeSelf(listener net.Listener, ntab discoverTable) *enode.Node {
// If the node is running but discovery is off, manually assemble the node infos.
if ntab == nil {
- // Inbound connections disabled, use zero address.
- if listener == nil {
- return &discover.Node{IP: net.ParseIP("0.0.0.0"), ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)}
- }
- // Otherwise inject the listener address too
- addr := listener.Addr().(*net.TCPAddr)
- return &discover.Node{
- ID: discover.PubkeyID(&srv.PrivateKey.PublicKey),
- IP: addr.IP,
- TCP: uint16(addr.Port),
- }
+ addr := srv.tcpAddr(listener)
+ return enode.NewV4(&srv.PrivateKey.PublicKey, addr.IP, addr.Port, 0)
}
// Otherwise return the discovery node.
return ntab.Self()
}
+func (srv *Server) tcpAddr(listener net.Listener) net.TCPAddr {
+ addr := net.TCPAddr{IP: net.IP{0, 0, 0, 0}}
+ if listener == nil {
+ return addr // Inbound connections disabled, use zero address.
+ }
+ // Otherwise inject the listener address too.
+ if a, ok := listener.Addr().(*net.TCPAddr); ok {
+ addr = *a
+ }
+ if srv.NAT != nil {
+ if ip, err := srv.NAT.ExternalIP(); err == nil {
+ addr.IP = ip
+ }
+ }
+ if addr.IP.IsUnspecified() {
+ addr.IP = net.IP{127, 0, 0, 1}
+ }
+ return addr
+}
+
// Stop terminates the server and all active peer connections.
// It blocks until all active connections have been closed.
func (srv *Server) Stop() {
srv.lock.Lock()
- defer srv.lock.Unlock()
if !srv.running {
+ srv.lock.Unlock()
return
}
srv.running = false
@@ -351,6 +400,7 @@ func (srv *Server) Stop() {
srv.listener.Close()
}
close(srv.quit)
+ srv.lock.Unlock()
srv.loopWG.Wait()
}
@@ -409,8 +459,10 @@ func (srv *Server) Start() (err error) {
srv.addpeer = make(chan *conn)
srv.delpeer = make(chan peerDrop)
srv.posthandshake = make(chan *conn)
- srv.addstatic = make(chan *discover.Node)
- srv.removestatic = make(chan *discover.Node)
+ srv.addstatic = make(chan *enode.Node)
+ srv.removestatic = make(chan *enode.Node)
+ srv.addtrusted = make(chan *enode.Node)
+ srv.removetrusted = make(chan *enode.Node)
srv.peerOp = make(chan peerOpFunc)
srv.peerOpDone = make(chan struct{})
@@ -487,7 +539,8 @@ func (srv *Server) Start() (err error) {
dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
// handshake
- srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)}
+ pubkey := crypto.FromECDSAPub(&srv.PrivateKey.PublicKey)
+ srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: pubkey[1:]}
for _, p := range srv.Protocols {
srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
}
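
The handshake identity is now the raw 64-byte uncompressed public key (the leading 0x04 byte of the SEC1 encoding is stripped), and a node's v4 enode ID is the Keccak-256 hash of exactly those bytes; that is the invariant the identity check in setupConn below relies on. A quick round-trip using the crypto and enode packages already imported by this diff:

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/tomochain/tomochain/crypto"
	"github.com/tomochain/tomochain/p2p/enode"
)

func main() {
	key, _ := crypto.GenerateKey()
	// Wire format of the identity: uncompressed point minus the 0x04 prefix.
	wire := crypto.FromECDSAPub(&key.PublicKey)[1:]
	// The v4 enode ID is the Keccak-256 hash of those 64 bytes.
	id := enode.PubkeyToIDV4(&key.PublicKey)
	fmt.Println(len(wire), bytes.Equal(crypto.Keccak256(wire), id[:])) // 64 true
}
```
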
@@ -503,7 +556,6 @@ func (srv *Server) Start() (err error) {
srv.loopWG.Add(1)
go srv.run(dialer)
- srv.running = true
return nil
}
@@ -530,27 +582,26 @@ func (srv *Server) startListening() error {
}
type dialer interface {
- newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task
+ newTasks(running int, peers map[enode.ID]*Peer, now time.Time) []task
taskDone(task, time.Time)
- addStatic(*discover.Node)
- removeStatic(*discover.Node)
+ addStatic(*enode.Node)
+ removeStatic(*enode.Node)
}
func (srv *Server) run(dialstate dialer) {
defer srv.loopWG.Done()
var (
- peers = make(map[discover.NodeID]*Peer)
+ peers = make(map[enode.ID]*Peer)
inboundCount = 0
- trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes))
+ trusted = make(map[enode.ID]bool, len(srv.TrustedNodes))
taskdone = make(chan task, maxActiveDialTasks)
runningTasks []task
queuedTasks []task // tasks that can't run yet
)
// Put trusted nodes into a map to speed up checks.
- // Trusted peers are loaded on startup and cannot be
- // modified while the server is running.
+ // Trusted peers are loaded on startup or added via AddTrustedPeer RPC.
for _, n := range srv.TrustedNodes {
- trusted[n.ID] = true
+ trusted[n.ID()] = true
}
// removes t from runningTasks
@@ -595,17 +646,37 @@ running:
// This channel is used by AddPeer to add to the
// ephemeral static peer list. Add it to the dialer,
// it will keep the node connected.
- srv.log.Debug("Adding static node", "node", n)
+ srv.log.Trace("Adding static node", "node", n)
dialstate.addStatic(n)
case n := <-srv.removestatic:
// This channel is used by RemovePeer to send a
// disconnect request to a peer and begin the
- // stop keeping the node connected
- srv.log.Debug("Removing static node", "node", n)
+ // teardown so the node is no longer kept connected.
+ srv.log.Trace("Removing static node", "node", n)
dialstate.removeStatic(n)
- if p, ok := peers[n.ID]; ok {
+ if p, ok := peers[n.ID()]; ok {
p.Disconnect(DiscRequested)
}
+ case n := <-srv.addtrusted:
+ // This channel is used by AddTrustedPeer to add an enode
+ // to the trusted node set.
+ srv.log.Trace("Adding trusted node", "node", n)
+ trusted[n.ID()] = true
+ // Mark any already-connected peer as trusted
+ if p, ok := peers[n.ID()]; ok {
+ p.rw.set(trustedConn, true)
+ }
+ case n := <-srv.removetrusted:
+ // This channel is used by RemoveTrustedPeer to remove an enode
+ // from the trusted node set.
+ srv.log.Trace("Removing trusted node", "node", n)
+ if _, ok := trusted[n.ID()]; ok {
+ delete(trusted, n.ID())
+ }
+ // Unmark any already-connected peer as trusted
+ if p, ok := peers[n.ID()]; ok {
+ p.rw.set(trustedConn, false)
+ }
case op := <-srv.peerOp:
// This channel is used by Peers and PeerCount.
op(peers)
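
Because the trusted flag is flipped directly on the live connection, a peer's status changes without a reconnect and is immediately visible through the peer info, which the updated TestServerDial below asserts. A hypothetical probe, mirroring those assertions:

```go
package p2putil

import (
	"github.com/tomochain/tomochain/p2p"
	"github.com/tomochain/tomochain/p2p/enode"
)

// isTrusted reports whether the currently connected peer matching n is
// flagged as trusted (illustrative helper, not part of the diff).
func isTrusted(srv *p2p.Server, n *enode.Node) bool {
	for _, p := range srv.Peers() {
		if p.ID() == n.ID() {
			return p.Info().Network.Trusted
		}
	}
	return false
}
```
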
@@ -620,7 +691,7 @@ running:
case c := <-srv.posthandshake:
// A connection has passed the encryption handshake so
// the remote identity is known (but hasn't been verified yet).
- if trusted[c.id] {
+ if trusted[c.node.ID()] {
// Ensure that the trusted flag is set before checking against MaxPeers.
c.flags |= trustedConn
}
@@ -643,15 +714,9 @@ running:
p.events = &srv.peerFeed
}
name := truncateName(c.name)
-
+ srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1)
go srv.runPeer(p)
- if peers[c.id] != nil {
- peers[c.id].PairPeer = p
- srv.log.Debug("Adding p2p pair peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1)
- } else {
- peers[c.id] = p
- srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1)
- }
+ peers[c.node.ID()] = p
if p.Inbound() {
inboundCount++
}
@@ -698,7 +763,7 @@ running:
}
}
-func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error {
+func (srv *Server) protoHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
// Drop connections with no matching protocols.
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
return DiscUselessPeer
@@ -708,19 +773,15 @@ func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inbound
return srv.encHandshakeChecks(peers, inboundCount, c)
}
-func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error {
+func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
switch {
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
return DiscTooManyPeers
case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns():
return DiscTooManyPeers
- case peers[c.id] != nil:
- exitPeer := peers[c.id]
- if exitPeer.PairPeer != nil {
- return DiscAlreadyConnected
- }
- return nil
- case c.id == srv.Self().ID:
+ case peers[c.node.ID()] != nil:
+ return DiscAlreadyConnected
+ case c.node.ID() == srv.Self().ID():
return DiscSelf
default:
return nil
@@ -730,7 +791,6 @@ func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCo
func (srv *Server) maxInboundConns() int {
return srv.MaxPeers - srv.maxDialedConns()
}
-
func (srv *Server) maxDialedConns() int {
if srv.NoDiscovery || srv.NoDial {
return 0
@@ -750,7 +810,7 @@ type tempError interface {
// inbound connections.
func (srv *Server) listenLoop() {
defer srv.loopWG.Done()
- srv.log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab))
+ srv.log.Info("RLPx listener up", "self", srv.Self())
tokens := defaultMaxPendingPeers
if srv.MaxPendingPeers > 0 {
@@ -803,7 +863,7 @@ func (srv *Server) listenLoop() {
// SetupConn runs the handshakes and attempts to add the connection
// as a peer. It returns when the connection has been added as a peer
// or the handshakes have failed.
-func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *discover.Node) error {
+func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) error {
self := srv.Self()
if self == nil {
return errors.New("shutdown")
@@ -812,12 +872,12 @@ func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *discover.Nod
err := srv.setupConn(c, flags, dialDest)
if err != nil {
c.close(err)
- srv.log.Trace("Setting up connection failed", "id", c.id, "err", err)
+ srv.log.Trace("Setting up connection failed", "addr", fd.RemoteAddr(), "err", err)
}
return err
}
-func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *discover.Node) error {
+func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) error {
// Prevent leftover pending conns from entering the handshake.
srv.lock.Lock()
running := srv.running
@@ -825,18 +885,30 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *discover.Node) e
if !running {
return errServerStopped
}
+ // If dialing, figure out the remote public key.
+ var dialPubkey *ecdsa.PublicKey
+ if dialDest != nil {
+ dialPubkey = new(ecdsa.PublicKey)
+ if err := dialDest.Load((*enode.Secp256k1)(dialPubkey)); err != nil {
+ return fmt.Errorf("dial destination doesn't have a secp256k1 public key")
+ }
+ }
// Run the encryption handshake.
- var err error
- if c.id, err = c.doEncHandshake(srv.PrivateKey, dialDest); err != nil {
+ remotePubkey, err := c.doEncHandshake(srv.PrivateKey, dialPubkey)
+ if err != nil {
srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err)
return err
}
- clog := srv.log.New("id", c.id, "addr", c.fd.RemoteAddr(), "conn", c.flags)
- // For dialed connections, check that the remote public key matches.
- if dialDest != nil && c.id != dialDest.ID {
- clog.Trace("Dialed identity mismatch", "want", c, dialDest.ID)
- return DiscUnexpectedIdentity
+ if dialDest != nil {
+ // For dialed connections, check that the remote public key matches.
+ if dialPubkey.X.Cmp(remotePubkey.X) != 0 || dialPubkey.Y.Cmp(remotePubkey.Y) != 0 {
+ return DiscUnexpectedIdentity
+ }
+ c.node = dialDest
+ } else {
+ c.node = nodeFromConn(remotePubkey, c.fd)
}
+ clog := srv.log.New("id", c.node.ID(), "addr", c.fd.RemoteAddr(), "conn", c.flags)
err = srv.checkpoint(c, srv.posthandshake)
if err != nil {
clog.Trace("Rejected peer before protocol handshake", "err", err)
@@ -848,8 +920,8 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *discover.Node) e
clog.Trace("Failed proto handshake", "err", err)
return err
}
- if phs.ID != c.id {
- clog.Trace("Wrong devp2p handshake identity", "err", phs.ID)
+ if id := c.node.ID(); !bytes.Equal(crypto.Keccak256(phs.ID), id[:]) {
+ clog.Trace("Wrong devp2p handshake identity", "phsid", fmt.Sprintf("%x", phs.ID))
return DiscUnexpectedIdentity
}
c.caps, c.name = phs.Caps, phs.Name
@@ -864,6 +936,16 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *discover.Node) e
return nil
}
+func nodeFromConn(pubkey *ecdsa.PublicKey, conn net.Conn) *enode.Node {
+ var ip net.IP
+ var port int
+ if tcp, ok := conn.RemoteAddr().(*net.TCPAddr); ok {
+ ip = tcp.IP
+ port = tcp.Port
+ }
+ return enode.NewV4(pubkey, ip, port, port)
+}
+
func truncateName(s string) string {
if len(s) > 20 {
return s[:20] + "..."
@@ -938,13 +1020,13 @@ func (srv *Server) NodeInfo() *NodeInfo {
info := &NodeInfo{
Name: srv.Name,
Enode: node.String(),
- ID: node.ID.String(),
- IP: node.IP.String(),
+ ID: node.ID().String(),
+ IP: node.IP().String(),
ListenAddr: srv.ListenAddr,
Protocols: make(map[string]interface{}),
}
- info.Ports.Discovery = int(node.UDP)
- info.Ports.Listener = int(node.TCP)
+ info.Ports.Discovery = node.UDP()
+ info.Ports.Listener = node.TCP()
// Gather all the running protocol infos (only once per protocol type)
for _, proto := range srv.Protocols {
diff --git a/p2p/server_test.go b/p2p/server_test.go
index b014bd9c3..da86ef63d 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -28,21 +28,22 @@ import (
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/crypto/sha3"
"github.com/tomochain/tomochain/log"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
+ "github.com/tomochain/tomochain/p2p/enr"
)
-func init() {
- // log.Root().SetHandler(log.LvlFilterHandler(log.LvlError, log.StreamHandler(os.Stderr, log.TerminalFormat(false))))
-}
+// func init() {
+// log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(false))))
+// }
type testTransport struct {
- id discover.NodeID
+ rpub *ecdsa.PublicKey
*rlpx
closeErr error
}
-func newTestTransport(id discover.NodeID, fd net.Conn) transport {
+func newTestTransport(rpub *ecdsa.PublicKey, fd net.Conn) transport {
wrapped := newRLPX(fd).(*rlpx)
wrapped.rw = newRLPXFrameRW(fd, secrets{
MAC: zero16,
@@ -50,15 +51,16 @@ func newTestTransport(id discover.NodeID, fd net.Conn) transport {
IngressMAC: sha3.NewKeccak256(),
EgressMAC: sha3.NewKeccak256(),
})
- return &testTransport{id: id, rlpx: wrapped}
+ return &testTransport{rpub: rpub, rlpx: wrapped}
}
-func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
- return c.id, nil
+func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *ecdsa.PublicKey) (*ecdsa.PublicKey, error) {
+ return c.rpub, nil
}
func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
- return &protoHandshake{ID: c.id, Name: "test"}, nil
+ pubkey := crypto.FromECDSAPub(c.rpub)[1:]
+ return &protoHandshake{ID: pubkey, Name: "test"}, nil
}
func (c *testTransport) close(err error) {
@@ -66,7 +68,7 @@ func (c *testTransport) close(err error) {
c.closeErr = err
}
-func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
+func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *Server {
config := Config{
Name: "test",
MaxPeers: 10,
@@ -76,7 +78,7 @@ func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
server := &Server{
Config: config,
newPeerHook: pf,
- newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) },
+ newTransport: func(fd net.Conn) transport { return newTestTransport(remoteKey, fd) },
}
if err := server.Start(); err != nil {
t.Fatalf("Could not start server: %v", err)
@@ -87,14 +89,11 @@ func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
func TestServerListen(t *testing.T) {
// start the test server
connected := make(chan *Peer)
- remid := randomID()
+ remid := &newkey().PublicKey
srv := startTestServer(t, remid, func(p *Peer) {
- if p.ID() != remid {
+ if p.ID() != enode.PubkeyToIDV4(remid) {
t.Error("peer func called with wrong node id")
}
- if p == nil {
- t.Error("peer func called with nil conn")
- }
connected <- p
})
defer close(connected)
@@ -141,21 +140,23 @@ func TestServerDial(t *testing.T) {
// start the server
connected := make(chan *Peer)
- remid := randomID()
+ remid := &newkey().PublicKey
srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
defer close(connected)
defer srv.Stop()
// tell the server to connect
tcpAddr := listener.Addr().(*net.TCPAddr)
- srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})
+ node := enode.NewV4(remid, tcpAddr.IP, tcpAddr.Port, 0)
+ srv.AddPeer(node)
select {
case conn := <-accepted:
defer conn.Close()
+
select {
case peer := <-connected:
- if peer.ID() != remid {
+ if peer.ID() != enode.PubkeyToIDV4(remid) {
t.Errorf("peer has wrong id")
}
if peer.Name() != "test" {
@@ -169,25 +170,33 @@ func TestServerDial(t *testing.T) {
if !reflect.DeepEqual(peers, []*Peer{peer}) {
t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
}
- case <-time.After(1 * time.Second):
- t.Error("server did not launch peer within one second")
- }
- select {
- case peer := <-connected:
- if peer.ID() != remid {
- t.Errorf("peer has wrong id")
- }
- if peer.Name() != "test" {
- t.Errorf("peer has wrong name")
- }
- if peer.RemoteAddr().String() != conn.LocalAddr().String() {
- t.Errorf("peer started with wrong conn: got %v, want %v",
- peer.RemoteAddr(), conn.LocalAddr())
+ // Test AddTrustedPeer/RemoveTrustedPeer and changing Trusted flags
+ // Particularly for race conditions on changing the flag state.
+ if peer := srv.Peers()[0]; peer.Info().Network.Trusted {
+ t.Errorf("peer is trusted prematurely: %v", peer)
}
+ done := make(chan bool)
+ go func() {
+ srv.AddTrustedPeer(node)
+ if peer := srv.Peers()[0]; !peer.Info().Network.Trusted {
+ t.Errorf("peer is not trusted after AddTrustedPeer: %v", peer)
+ }
+ srv.RemoveTrustedPeer(node)
+ if peer := srv.Peers()[0]; peer.Info().Network.Trusted {
+ t.Errorf("peer is trusted after RemoveTrustedPeer: %v", peer)
+ }
+ done <- true
+ }()
+ // Trigger potential race conditions
+ peer = srv.Peers()[0]
+ _ = peer.Inbound()
+ _ = peer.Info()
+ <-done
case <-time.After(1 * time.Second):
t.Error("server did not launch peer within one second")
}
+
case <-time.After(1 * time.Second):
t.Error("server did not connect within one second")
}
@@ -201,7 +210,7 @@ func TestServerTaskScheduling(t *testing.T) {
quit, returned = make(chan struct{}), make(chan struct{})
tc = 0
tg = taskgen{
- newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
+ newFunc: func(running int, peers map[enode.ID]*Peer) []task {
tc++
return []task{&testTask{index: tc - 1}}
},
@@ -274,7 +283,7 @@ func TestServerManyTasks(t *testing.T) {
defer srv.Stop()
srv.loopWG.Add(1)
go srv.run(taskgen{
- newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
+ newFunc: func(running int, peers map[enode.ID]*Peer) []task {
start, end = end, end+maxActiveDialTasks+10
if end > len(alltasks) {
end = len(alltasks)
@@ -309,19 +318,19 @@ func TestServerManyTasks(t *testing.T) {
}
type taskgen struct {
- newFunc func(running int, peers map[discover.NodeID]*Peer) []task
+ newFunc func(running int, peers map[enode.ID]*Peer) []task
doneFunc func(task)
}
-func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
+func (tg taskgen) newTasks(running int, peers map[enode.ID]*Peer, now time.Time) []task {
return tg.newFunc(running, peers)
}
func (tg taskgen) taskDone(t task, now time.Time) {
tg.doneFunc(t)
}
-func (tg taskgen) addStatic(*discover.Node) {
+func (tg taskgen) addStatic(*enode.Node) {
}
-func (tg taskgen) removeStatic(*discover.Node) {
+func (tg taskgen) removeStatic(*enode.Node) {
}
type testTask struct {
@@ -337,13 +346,14 @@ func (t *testTask) Do(srv *Server) {
// just after the encryption handshake when the server is
// at capacity. Trusted connections should still be accepted.
func TestServerAtCap(t *testing.T) {
- trustedID := randomID()
+ trustedNode := newkey()
+ trustedID := enode.PubkeyToIDV4(&trustedNode.PublicKey)
srv := &Server{
Config: Config{
PrivateKey: newkey(),
MaxPeers: 10,
NoDial: true,
- TrustedNodes: []*discover.Node{{ID: trustedID}},
+ TrustedNodes: []*enode.Node{newNode(trustedID, nil)},
},
}
if err := srv.Start(); err != nil {
@@ -351,10 +361,11 @@ func TestServerAtCap(t *testing.T) {
}
defer srv.Stop()
- newconn := func(id discover.NodeID) *conn {
+ newconn := func(id enode.ID) *conn {
fd, _ := net.Pipe()
- tx := newTestTransport(id, fd)
- return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)}
+ tx := newTestTransport(&trustedNode.PublicKey, fd)
+ node := enode.SignNull(new(enr.Record), id)
+ return &conn{fd: fd, transport: tx, flags: inboundConn, node: node, cont: make(chan error)}
}
// Inject a few connections to fill up the peer set.
@@ -365,7 +376,8 @@ func TestServerAtCap(t *testing.T) {
}
}
// Try inserting a non-trusted connection.
- c := newconn(randomID())
+ anotherID := randomID()
+ c := newconn(anotherID)
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
t.Error("wrong error for insert:", err)
}
@@ -378,62 +390,144 @@ func TestServerAtCap(t *testing.T) {
t.Error("Server did not set trusted flag")
}
+ // Remove from trusted set and try again
+ srv.RemoveTrustedPeer(newNode(trustedID, nil))
+ c = newconn(trustedID)
+ if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
+ t.Error("wrong error for insert:", err)
+ }
+
+ // Add anotherID to trusted set and try again
+ srv.AddTrustedPeer(newNode(anotherID, nil))
+ c = newconn(anotherID)
+ if err := srv.checkpoint(c, srv.posthandshake); err != nil {
+ t.Error("unexpected error for trusted conn @posthandshake:", err)
+ }
+ if !c.is(trustedConn) {
+ t.Error("Server did not set trusted flag")
+ }
+}
+
+func TestServerPeerLimits(t *testing.T) {
+ srvkey := newkey()
+ clientkey := newkey()
+ clientnode := enode.NewV4(&clientkey.PublicKey, nil, 0, 0)
+
+ var tp = &setupTransport{
+ pubkey: &clientkey.PublicKey,
+ phs: protoHandshake{
+ ID: crypto.FromECDSAPub(&clientkey.PublicKey)[1:],
+ // Force "DiscUselessPeer" due to mismatched caps
+ // Caps: []Cap{discard.cap()},
+ },
+ }
+
+ srv := &Server{
+ Config: Config{
+ PrivateKey: srvkey,
+ MaxPeers: 0,
+ NoDial: true,
+ Protocols: []Protocol{discard},
+ },
+ newTransport: func(fd net.Conn) transport { return tp },
+ log: log.New(),
+ }
+ if err := srv.Start(); err != nil {
+ t.Fatalf("couldn't start server: %v", err)
+ }
+ defer srv.Stop()
+
+ // Check that server is full (MaxPeers=0)
+ flags := dynDialedConn
+ dialDest := clientnode
+ conn, _ := net.Pipe()
+ srv.SetupConn(conn, flags, dialDest)
+ if tp.closeErr != DiscTooManyPeers {
+ t.Errorf("unexpected close error: %q", tp.closeErr)
+ }
+ conn.Close()
+
+ srv.AddTrustedPeer(clientnode)
+
+ // Check that server allows a trusted peer despite being full.
+ conn, _ = net.Pipe()
+ srv.SetupConn(conn, flags, dialDest)
+ if tp.closeErr == DiscTooManyPeers {
+ t.Errorf("failed to bypass MaxPeers with trusted node: %q", tp.closeErr)
+ }
+
+ if tp.closeErr != DiscUselessPeer {
+ t.Errorf("unexpected close error: %q", tp.closeErr)
+ }
+ conn.Close()
+
+ srv.RemoveTrustedPeer(clientnode)
+
+ // Check that server is full again.
+ conn, _ = net.Pipe()
+ srv.SetupConn(conn, flags, dialDest)
+ if tp.closeErr != DiscTooManyPeers {
+ t.Errorf("unexpected close error: %q", tp.closeErr)
+ }
+ conn.Close()
}
func TestServerSetupConn(t *testing.T) {
- id := randomID()
- srvkey := newkey()
- srvid := discover.PubkeyID(&srvkey.PublicKey)
+ var (
+ clientkey, srvkey = newkey(), newkey()
+ clientpub = &clientkey.PublicKey
+ srvpub = &srvkey.PublicKey
+ )
tests := []struct {
dontstart bool
tt *setupTransport
flags connFlag
- dialDest *discover.Node
+ dialDest *enode.Node
wantCloseErr error
wantCalls string
}{
{
dontstart: true,
- tt: &setupTransport{id: id},
+ tt: &setupTransport{pubkey: clientpub},
wantCalls: "close,",
wantCloseErr: errServerStopped,
},
{
- tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
+ tt: &setupTransport{pubkey: clientpub, encHandshakeErr: errors.New("read error")},
flags: inboundConn,
wantCalls: "doEncHandshake,close,",
wantCloseErr: errors.New("read error"),
},
{
- tt: &setupTransport{id: id},
- dialDest: &discover.Node{ID: randomID()},
+ tt: &setupTransport{pubkey: clientpub},
+ dialDest: enode.NewV4(&newkey().PublicKey, nil, 0, 0),
flags: dynDialedConn,
wantCalls: "doEncHandshake,close,",
wantCloseErr: DiscUnexpectedIdentity,
},
{
- tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
- dialDest: &discover.Node{ID: id},
+ tt: &setupTransport{pubkey: clientpub, phs: protoHandshake{ID: randomID().Bytes()}},
+ dialDest: enode.NewV4(clientpub, nil, 0, 0),
flags: dynDialedConn,
wantCalls: "doEncHandshake,doProtoHandshake,close,",
wantCloseErr: DiscUnexpectedIdentity,
},
{
- tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
- dialDest: &discover.Node{ID: id},
+ tt: &setupTransport{pubkey: clientpub, protoHandshakeErr: errors.New("foo")},
+ dialDest: enode.NewV4(clientpub, nil, 0, 0),
flags: dynDialedConn,
wantCalls: "doEncHandshake,doProtoHandshake,close,",
wantCloseErr: errors.New("foo"),
},
{
- tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
+ tt: &setupTransport{pubkey: srvpub, phs: protoHandshake{ID: crypto.FromECDSAPub(srvpub)[1:]}},
flags: inboundConn,
wantCalls: "doEncHandshake,close,",
wantCloseErr: DiscSelf,
},
{
- tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}},
+ tt: &setupTransport{pubkey: clientpub, phs: protoHandshake{ID: crypto.FromECDSAPub(clientpub)[1:]}},
flags: inboundConn,
wantCalls: "doEncHandshake,doProtoHandshake,close,",
wantCloseErr: DiscUselessPeer,
@@ -468,26 +562,26 @@ func TestServerSetupConn(t *testing.T) {
}
type setupTransport struct {
- id discover.NodeID
- encHandshakeErr error
-
- phs *protoHandshake
+ pubkey *ecdsa.PublicKey
+ encHandshakeErr error
+ phs protoHandshake
protoHandshakeErr error
calls string
closeErr error
}
-func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
+func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *ecdsa.PublicKey) (*ecdsa.PublicKey, error) {
c.calls += "doEncHandshake,"
- return c.id, c.encHandshakeErr
+ return c.pubkey, c.encHandshakeErr
}
+
func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
c.calls += "doProtoHandshake,"
if c.protoHandshakeErr != nil {
return nil, c.protoHandshakeErr
}
- return c.phs, nil
+ return &c.phs, nil
}
func (c *setupTransport) close(err error) {
c.calls += "close,"
@@ -510,7 +604,7 @@ func newkey() *ecdsa.PrivateKey {
return key
}
-func randomID() (id discover.NodeID) {
+func randomID() (id enode.ID) {
for i := range id {
id[i] = byte(rand.Intn(255))
}
diff --git a/p2p/simulations/adapters/docker.go b/p2p/simulations/adapters/docker.go
index 51469a4a5..10fd73232 100644
--- a/p2p/simulations/adapters/docker.go
+++ b/p2p/simulations/adapters/docker.go
@@ -30,7 +30,7 @@ import (
"github.com/docker/docker/pkg/reexec"
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/node"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
)
// DockerAdapter is a NodeAdapter which runs simulation nodes inside Docker
@@ -61,7 +61,7 @@ func NewDockerAdapter() (*DockerAdapter, error) {
return &DockerAdapter{
ExecAdapter{
- nodes: make(map[discover.NodeID]*ExecNode),
+ nodes: make(map[enode.ID]*ExecNode),
},
}, nil
}
diff --git a/p2p/simulations/adapters/exec.go b/p2p/simulations/adapters/exec.go
index 31a7dbe3f..58e261312 100644
--- a/p2p/simulations/adapters/exec.go
+++ b/p2p/simulations/adapters/exec.go
@@ -39,7 +39,7 @@ import (
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/node"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/rpc"
"golang.org/x/net/websocket"
)
@@ -55,7 +55,7 @@ type ExecAdapter struct {
// simulation node are created.
BaseDir string
- nodes map[discover.NodeID]*ExecNode
+ nodes map[enode.ID]*ExecNode
}
// NewExecAdapter returns an ExecAdapter which stores node data in
@@ -63,7 +63,7 @@ type ExecAdapter struct {
func NewExecAdapter(baseDir string) *ExecAdapter {
return &ExecAdapter{
BaseDir: baseDir,
- nodes: make(map[discover.NodeID]*ExecNode),
+ nodes: make(map[enode.ID]*ExecNode),
}
}
@@ -123,7 +123,7 @@ func (e *ExecAdapter) NewNode(config *NodeConfig) (Node, error) {
// ExecNode starts a simulation node by exec'ing the current binary and
// running the configured services
type ExecNode struct {
- ID discover.NodeID
+ ID enode.ID
Dir string
Config *execNodeConfig
Cmd *exec.Cmd
@@ -498,7 +498,7 @@ type wsRPCDialer struct {
// DialRPC implements the RPCDialer interface by creating a WebSocket RPC
// client of the given node
-func (w *wsRPCDialer) DialRPC(id discover.NodeID) (*rpc.Client, error) {
+func (w *wsRPCDialer) DialRPC(id enode.ID) (*rpc.Client, error) {
addr, ok := w.addrs[id.String()]
if !ok {
return nil, fmt.Errorf("unknown node: %s", id)
diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go
index 5ebfb9109..3a21dd279 100644
--- a/p2p/simulations/adapters/inproc.go
+++ b/p2p/simulations/adapters/inproc.go
@@ -27,7 +27,7 @@ import (
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/node"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/rpc"
)
@@ -35,7 +35,7 @@ import (
// connects them using in-memory net.Pipe connections
type SimAdapter struct {
mtx sync.RWMutex
- nodes map[discover.NodeID]*SimNode
+ nodes map[enode.ID]*SimNode
services map[string]ServiceFunc
}
@@ -44,7 +44,7 @@ type SimAdapter struct {
// particular node are passed to the NewNode function in the NodeConfig)
func NewSimAdapter(services map[string]ServiceFunc) *SimAdapter {
return &SimAdapter{
- nodes: make(map[discover.NodeID]*SimNode),
+ nodes: make(map[enode.ID]*SimNode),
services: services,
}
}
@@ -96,7 +96,7 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) {
node: n,
adapter: s,
running: make(map[string]node.Service),
- connected: make(map[discover.NodeID]bool),
+ connected: make(map[enode.ID]bool),
}
s.nodes[id] = simNode
return simNode, nil
@@ -104,27 +104,27 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) {
// Dial implements the p2p.NodeDialer interface by connecting to the node using
// an in-memory net.Pipe connection
-func (s *SimAdapter) Dial(dest *discover.Node) (conn net.Conn, err error) {
- node, ok := s.GetNode(dest.ID)
+func (s *SimAdapter) Dial(dest *enode.Node) (conn net.Conn, err error) {
+ node, ok := s.GetNode(dest.ID())
if !ok {
- return nil, fmt.Errorf("unknown node: %s", dest.ID)
+ return nil, fmt.Errorf("unknown node: %s", dest.ID())
}
- if node.connected[dest.ID] {
- return nil, fmt.Errorf("dialed node: %s", dest.ID)
+ if node.connected[dest.ID()] {
+ return nil, fmt.Errorf("dialed node: %s", dest.ID())
}
srv := node.Server()
if srv == nil {
- return nil, fmt.Errorf("node not running: %s", dest.ID)
+ return nil, fmt.Errorf("node not running: %s", dest.ID())
}
pipe1, pipe2 := net.Pipe()
go srv.SetupConn(pipe1, 0, nil)
- node.connected[dest.ID] = true
+ node.connected[dest.ID()] = true
return pipe2, nil
}
// DialRPC implements the RPCDialer interface by creating an in-memory RPC
// client of the given node
-func (s *SimAdapter) DialRPC(id discover.NodeID) (*rpc.Client, error) {
+func (s *SimAdapter) DialRPC(id enode.ID) (*rpc.Client, error) {
node, ok := s.GetNode(id)
if !ok {
return nil, fmt.Errorf("unknown node: %s", id)
@@ -137,7 +137,7 @@ func (s *SimAdapter) DialRPC(id discover.NodeID) (*rpc.Client, error) {
}
// GetNode returns the node with the given ID if it exists
-func (s *SimAdapter) GetNode(id discover.NodeID) (*SimNode, bool) {
+func (s *SimAdapter) GetNode(id enode.ID) (*SimNode, bool) {
s.mtx.RLock()
defer s.mtx.RUnlock()
node, ok := s.nodes[id]
@@ -149,14 +149,14 @@ func (s *SimAdapter) GetNode(id discover.NodeID) (*SimNode, bool) {
// protocols directly over that pipe
type SimNode struct {
lock sync.RWMutex
- ID discover.NodeID
+ ID enode.ID
config *NodeConfig
adapter *SimAdapter
node *node.Node
running map[string]node.Service
client *rpc.Client
registerOnce sync.Once
- connected map[discover.NodeID]bool
+ connected map[enode.ID]bool
}
// Addr returns the node's discovery address
@@ -164,9 +164,9 @@ func (self *SimNode) Addr() []byte {
return []byte(self.Node().String())
}
-// Node returns a discover.Node representing the SimNode
-func (self *SimNode) Node() *discover.Node {
- return discover.NewNode(self.ID, net.IP{127, 0, 0, 1}, 30303, 30303)
+// Node returns a node descriptor representing the SimNode
+func (sn *SimNode) Node() *enode.Node {
+ return sn.config.Node()
}
// Client returns an rpc.Client which can be used to communicate with the
diff --git a/p2p/simulations/adapters/types.go b/p2p/simulations/adapters/types.go
index f03bbe75b..089f50ea2 100644
--- a/p2p/simulations/adapters/types.go
+++ b/p2p/simulations/adapters/types.go
@@ -23,12 +23,14 @@ import (
"fmt"
"net"
"os"
+ "strconv"
"github.com/docker/docker/pkg/reexec"
+
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/node"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/rpc"
)
@@ -38,7 +40,6 @@ import (
// * SimNode - An in-memory node
// * ExecNode - A child process node
// * DockerNode - A Docker container node
-//
type Node interface {
// Addr returns the node's address (e.g. an Enode URL)
Addr() []byte
@@ -77,7 +78,7 @@ type NodeAdapter interface {
type NodeConfig struct {
// ID is the node's ID which is used to identify the node in the
// simulation network
- ID discover.NodeID
+ ID enode.ID
// PrivateKey is the node's private key which is used by the devp2p
// stack to encrypt communications
@@ -96,25 +97,31 @@ type NodeConfig struct {
Services []string
// function to sanction or prevent suggesting a peer
- Reachable func(id discover.NodeID) bool
+ Reachable func(id enode.ID) bool
+
+ Port uint16
}
// nodeConfigJSON is used to encode and decode NodeConfig as JSON by encoding
// all fields as strings
type nodeConfigJSON struct {
- ID string `json:"id"`
- PrivateKey string `json:"private_key"`
- Name string `json:"name"`
- Services []string `json:"services"`
+ ID string `json:"id"`
+ PrivateKey string `json:"private_key"`
+ Name string `json:"name"`
+ Services []string `json:"services"`
+ EnableMsgEvents bool `json:"enable_msg_events"`
+ Port uint16 `json:"port"`
}
// MarshalJSON implements the json.Marshaler interface by encoding the config
// fields as strings
func (n *NodeConfig) MarshalJSON() ([]byte, error) {
confJSON := nodeConfigJSON{
- ID: n.ID.String(),
- Name: n.Name,
- Services: n.Services,
+ ID: n.ID.String(),
+ Name: n.Name,
+ Services: n.Services,
+ Port: n.Port,
+ EnableMsgEvents: n.EnableMsgEvents,
}
if n.PrivateKey != nil {
confJSON.PrivateKey = hex.EncodeToString(crypto.FromECDSA(n.PrivateKey))
@@ -131,11 +138,9 @@ func (n *NodeConfig) UnmarshalJSON(data []byte) error {
}
if confJSON.ID != "" {
- nodeID, err := discover.HexID(confJSON.ID)
- if err != nil {
+ if err := n.ID.UnmarshalText([]byte(confJSON.ID)); err != nil {
return err
}
- n.ID = nodeID
}
if confJSON.PrivateKey != "" {
@@ -152,10 +157,17 @@ func (n *NodeConfig) UnmarshalJSON(data []byte) error {
n.Name = confJSON.Name
n.Services = confJSON.Services
+ n.Port = confJSON.Port
+ n.EnableMsgEvents = confJSON.EnableMsgEvents
return nil
}
+// Node returns the node descriptor represented by the config.
+func (n *NodeConfig) Node() *enode.Node {
+ return enode.NewV4(&n.PrivateKey.PublicKey, net.IP{127, 0, 0, 1}, int(n.Port), int(n.Port))
+}
+
// RandomNodeConfig returns node configuration with a randomly generated ID and
// PrivateKey
func RandomNodeConfig() *NodeConfig {
@@ -163,13 +175,36 @@ func RandomNodeConfig() *NodeConfig {
if err != nil {
panic("unable to generate key")
}
- var id discover.NodeID
- pubkey := crypto.FromECDSAPub(&key.PublicKey)
- copy(id[:], pubkey[1:])
+
+ id := enode.PubkeyToIDV4(&key.PublicKey)
+ port, err := assignTCPPort()
+ if err != nil {
+ panic("unable to assign tcp port")
+ }
return &NodeConfig{
- ID: id,
- PrivateKey: key,
+ ID: id,
+ Name: fmt.Sprintf("node_%s", id.String()),
+ PrivateKey: key,
+ Port: port,
+ EnableMsgEvents: true,
+ }
+}
+
+func assignTCPPort() (uint16, error) {
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ return 0, err
+ }
+ l.Close()
+ _, port, err := net.SplitHostPort(l.Addr().String())
+ if err != nil {
+ return 0, err
+ }
+ p, err := strconv.ParseInt(port, 10, 32)
+ if err != nil {
+ return 0, err
}
+ return uint16(p), nil
}
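
assignTCPPort uses the usual bind-to-port-0 trick: the kernel picks a free port, the listener is closed, and only the number is kept. The reservation is best-effort (another process could grab the port before the node binds it), which is acceptable for simulations. With the new fields, a random config is self-describing; a small usage sketch:

```go
package main

import (
	"fmt"

	"github.com/tomochain/tomochain/p2p/simulations/adapters"
)

func main() {
	// Name, key, and a free localhost port are all filled in; Node() turns
	// the config into a dialable enode descriptor on 127.0.0.1.
	cfg := adapters.RandomNodeConfig()
	fmt.Println(cfg.Name, cfg.Port)
	fmt.Println(cfg.Node()) // enode://...@127.0.0.1:<port>
}
```
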
// ServiceContext is a collection of options and methods which can be utilised
@@ -186,7 +221,7 @@ type ServiceContext struct {
// other nodes in the network (for example a simulated Swarm node which needs
// to connect to a Geth node to resolve ENS names)
type RPCDialer interface {
- DialRPC(id discover.NodeID) (*rpc.Client, error)
+ DialRPC(id enode.ID) (*rpc.Client, error)
}
// Services is a collection of services which can be run in a simulation
diff --git a/p2p/simulations/examples/ping-pong.go b/p2p/simulations/examples/ping-pong.go
index dae524d05..de7a9e6b5 100644
--- a/p2p/simulations/examples/ping-pong.go
+++ b/p2p/simulations/examples/ping-pong.go
@@ -28,7 +28,7 @@ import (
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/node"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/simulations"
"github.com/tomochain/tomochain/p2p/simulations/adapters"
"github.com/tomochain/tomochain/rpc"
@@ -96,12 +96,12 @@ func main() {
// sends a ping to all its connected peers every 10s and receives a pong in
// return
type pingPongService struct {
- id discover.NodeID
+ id enode.ID
log log.Logger
received int64
}
-func newPingPongService(id discover.NodeID) *pingPongService {
+func newPingPongService(id enode.ID) *pingPongService {
return &pingPongService{
id: id,
log: log.New("node.id", id),
diff --git a/p2p/simulations/http.go b/p2p/simulations/http.go
index 29159b6fc..d7ed380a4 100644
--- a/p2p/simulations/http.go
+++ b/p2p/simulations/http.go
@@ -29,10 +29,11 @@ import (
"strings"
"sync"
+ "github.com/tomochain/tomochain/p2p/enode"
+
"github.com/julienschmidt/httprouter"
"github.com/tomochain/tomochain/event"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
"github.com/tomochain/tomochain/p2p/simulations/adapters"
"github.com/tomochain/tomochain/rpc"
"golang.org/x/net/websocket"
@@ -698,18 +699,19 @@ func (s *Server) JSON(w http.ResponseWriter, status int, data interface{}) {
json.NewEncoder(w).Encode(data)
}
-// wrapHandler returns a httprouter.Handle which wraps a http.HandlerFunc by
+// wrapHandler returns an httprouter.Handle which wraps an http.HandlerFunc by
// populating request.Context with any objects from the URL params
func (s *Server) wrapHandler(handler http.HandlerFunc) httprouter.Handle {
return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
- ctx := context.Background()
+ ctx := req.Context()
if id := params.ByName("nodeid"); id != "" {
+ var nodeID enode.ID
var node *Node
- if nodeID, err := discover.HexID(id); err == nil {
+ if nodeID.UnmarshalText([]byte(id)) == nil {
node = s.network.GetNode(nodeID)
} else {
node = s.network.GetNodeByName(id)
@@ -722,8 +724,9 @@ func (s *Server) wrapHandler(handler http.HandlerFunc) httprouter.Handle {
}
if id := params.ByName("peerid"); id != "" {
+ var peerID enode.ID
var peer *Node
- if peerID, err := discover.HexID(id); err == nil {
+ if peerID.UnmarshalText([]byte(id)) == nil {
peer = s.network.GetNode(peerID)
} else {
peer = s.network.GetNodeByName(id)
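
Node IDs in request URLs are now parsed through enode.ID's encoding.TextUnmarshaler implementation instead of discover.HexID; anything that does not parse as a valid hex ID falls through to the lookup by node name. Standalone, with an arbitrary example value:

```go
package main

import (
	"fmt"

	"github.com/tomochain/tomochain/p2p/enode"
)

func main() {
	// enode.ID is the 32-byte keccak hash of the public key, so its text
	// form is 64 hex characters (the old discover.NodeID was 128).
	var id enode.ID
	err := id.UnmarshalText([]byte(
		"0102030405060708091011121314151617181920212223242526272829303132"))
	fmt.Println(id, err) // a name like "node01" would return an error instead
}
```
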
diff --git a/p2p/simulations/http_test.go b/p2p/simulations/http_test.go
index e00b8057c..d557e5f99 100644
--- a/p2p/simulations/http_test.go
+++ b/p2p/simulations/http_test.go
@@ -30,7 +30,7 @@ import (
"github.com/tomochain/tomochain/event"
"github.com/tomochain/tomochain/node"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/simulations/adapters"
"github.com/tomochain/tomochain/rpc"
)
@@ -38,12 +38,12 @@ import (
// testService implements the node.Service interface and provides protocols
// and APIs which are useful for testing nodes in a simulation network
type testService struct {
- id discover.NodeID
+ id enode.ID
// peerCount is incremented once a peer handshake has been performed
peerCount int64
- peers map[discover.NodeID]*testPeer
+ peers map[enode.ID]*testPeer
peersMtx sync.Mutex
// state stores []byte which is used to test creating and loading
@@ -54,7 +54,7 @@ type testService struct {
func newTestService(ctx *adapters.ServiceContext) (node.Service, error) {
svc := &testService{
id: ctx.Config.ID,
- peers: make(map[discover.NodeID]*testPeer),
+ peers: make(map[enode.ID]*testPeer),
}
svc.state.Store(ctx.Snapshot)
return svc, nil
@@ -65,7 +65,7 @@ type testPeer struct {
dumReady chan struct{}
}
-func (t *testService) peer(id discover.NodeID) *testPeer {
+func (t *testService) peer(id enode.ID) *testPeer {
t.peersMtx.Lock()
defer t.peersMtx.Unlock()
if peer, ok := t.peers[id]; ok {
@@ -411,7 +411,7 @@ func (t *expectEvents) nodeEvent(id string, up bool) *Event {
Type: EventTypeNode,
Node: &Node{
Config: &adapters.NodeConfig{
- ID: discover.MustHexID(id),
+ ID: enode.HexID(id),
},
Up: up,
},
@@ -422,8 +422,8 @@ func (t *expectEvents) connEvent(one, other string, up bool) *Event {
return &Event{
Type: EventTypeConn,
Conn: &Conn{
- One: discover.MustHexID(one),
- Other: discover.MustHexID(other),
+ One: enode.HexID(one),
+ Other: enode.HexID(other),
Up: up,
},
}
diff --git a/p2p/simulations/mocker.go b/p2p/simulations/mocker.go
index daff17e29..d052d9e26 100644
--- a/p2p/simulations/mocker.go
+++ b/p2p/simulations/mocker.go
@@ -25,23 +25,23 @@ import (
"time"
"github.com/tomochain/tomochain/log"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
)
-//a map of mocker names to its function
+// a map of mocker names to their functions
var mockerList = map[string]func(net *Network, quit chan struct{}, nodeCount int){
"startStop": startStop,
"probabilistic": probabilistic,
"boot": boot,
}
-//Lookup a mocker by its name, returns the mockerFn
+// LookupMocker looks up a mocker by its name and returns its mockerFn
func LookupMocker(mockerType string) func(net *Network, quit chan struct{}, nodeCount int) {
return mockerList[mockerType]
}
-//Get a list of mockers (keys of the map)
-//Useful for frontend to build available mocker selection
+// GetMockerList returns the list of mockers (keys of the map).
+// Useful for the frontend to build the available mocker selection.
func GetMockerList() []string {
list := make([]string, 0, len(mockerList))
for k := range mockerList {
@@ -50,7 +50,7 @@ func GetMockerList() []string {
return list
}
-//The boot mockerFn only connects the node in a ring and doesn't do anything else
+// The boot mockerFn only connects the nodes in a ring and doesn't do anything else
func boot(net *Network, quit chan struct{}, nodeCount int) {
_, err := connectNodesInRing(net, nodeCount)
if err != nil {
@@ -58,7 +58,7 @@ func boot(net *Network, quit chan struct{}, nodeCount int) {
}
}
-//The startStop mockerFn stops and starts nodes in a defined period (ticker)
+// The startStop mockerFn stops and starts nodes at a defined interval (ticker)
func startStop(net *Network, quit chan struct{}, nodeCount int) {
nodes, err := connectNodesInRing(net, nodeCount)
if err != nil {
@@ -95,10 +95,10 @@ func startStop(net *Network, quit chan struct{}, nodeCount int) {
}
}
-//The probabilistic mocker func has a more probabilistic pattern
-//(the implementation could probably be improved):
-//nodes are connected in a ring, then a varying number of random nodes is selected,
-//mocker then stops and starts them in random intervals, and continues the loop
+// The probabilistic mocker func has a more probabilistic pattern
+// (the implementation could probably be improved):
+// nodes are connected in a ring, then a varying number of random nodes is selected;
+// the mocker then stops and starts them at random intervals, and the loop continues
func probabilistic(net *Network, quit chan struct{}, nodeCount int) {
nodes, err := connectNodesInRing(net, nodeCount)
if err != nil {
@@ -147,7 +147,7 @@ func probabilistic(net *Network, quit chan struct{}, nodeCount int) {
wg.Done()
continue
}
- go func(id discover.NodeID) {
+ go func(id enode.ID) {
time.Sleep(randWait)
err := net.Start(id)
if err != nil {
@@ -161,9 +161,9 @@ func probabilistic(net *Network, quit chan struct{}, nodeCount int) {
}
-//connect nodeCount number of nodes in a ring
-func connectNodesInRing(net *Network, nodeCount int) ([]discover.NodeID, error) {
- ids := make([]discover.NodeID, nodeCount)
+// connectNodesInRing connects nodeCount nodes in a ring
+func connectNodesInRing(net *Network, nodeCount int) ([]enode.ID, error) {
+ ids := make([]enode.ID, nodeCount)
for i := 0; i < nodeCount; i++ {
node, err := net.NewNode()
if err != nil {
diff --git a/p2p/simulations/mocker_test.go b/p2p/simulations/mocker_test.go
index f9a23bfe1..397c4c5dc 100644
--- a/p2p/simulations/mocker_test.go
+++ b/p2p/simulations/mocker_test.go
@@ -27,7 +27,7 @@ import (
"testing"
"time"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
)
func TestMocker(t *testing.T) {
@@ -82,7 +82,7 @@ func TestMocker(t *testing.T) {
defer sub.Unsubscribe()
//wait until all nodes are started and connected
//store every node up event in a map (value is irrelevant, mimic Set datatype)
- nodemap := make(map[discover.NodeID]bool)
+ nodemap := make(map[enode.ID]bool)
wg.Add(1)
nodesComplete := false
connCount := 0
diff --git a/p2p/simulations/network.go b/p2p/simulations/network.go
index 08643f7d8..1c50b8f35 100644
--- a/p2p/simulations/network.go
+++ b/p2p/simulations/network.go
@@ -27,7 +27,7 @@ import (
"github.com/tomochain/tomochain/event"
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/simulations/adapters"
)
@@ -51,7 +51,7 @@ type Network struct {
NetworkConfig
Nodes []*Node `json:"nodes"`
- nodeMap map[discover.NodeID]int
+ nodeMap map[enode.ID]int
Conns []*Conn `json:"conns"`
connMap map[string]int
@@ -67,7 +67,7 @@ func NewNetwork(nodeAdapter adapters.NodeAdapter, conf *NetworkConfig) *Network
return &Network{
NetworkConfig: *conf,
nodeAdapter: nodeAdapter,
- nodeMap: make(map[discover.NodeID]int),
+ nodeMap: make(map[enode.ID]int),
connMap: make(map[string]int),
quitc: make(chan struct{}),
}
@@ -92,14 +92,14 @@ func (self *Network) NewNodeWithConfig(conf *adapters.NodeConfig) (*Node, error)
defer self.lock.Unlock()
// create a random ID and PrivateKey if not set
- if conf.ID == (discover.NodeID{}) {
+ if conf.ID == (enode.ID{}) {
c := adapters.RandomNodeConfig()
conf.ID = c.ID
conf.PrivateKey = c.PrivateKey
}
id := conf.ID
if conf.Reachable == nil {
- conf.Reachable = func(otherID discover.NodeID) bool {
+ conf.Reachable = func(otherID enode.ID) bool {
_, err := self.InitConn(conf.ID, otherID)
return err == nil
}
@@ -174,13 +174,13 @@ func (self *Network) StopAll() error {
}
// Start starts the node with the given ID
-func (self *Network) Start(id discover.NodeID) error {
+func (self *Network) Start(id enode.ID) error {
return self.startWithSnapshots(id, nil)
}
// startWithSnapshots starts the node with the given ID using the given
// snapshots
-func (self *Network) startWithSnapshots(id discover.NodeID, snapshots map[string][]byte) error {
+func (self *Network) startWithSnapshots(id enode.ID, snapshots map[string][]byte) error {
node := self.GetNode(id)
if node == nil {
return fmt.Errorf("node %v does not exist", id)
@@ -214,7 +214,7 @@ func (self *Network) startWithSnapshots(id discover.NodeID, snapshots map[string
// watchPeerEvents reads peer events from the given channel and emits
// corresponding network events
-func (self *Network) watchPeerEvents(id discover.NodeID, events chan *p2p.PeerEvent, sub event.Subscription) {
+func (self *Network) watchPeerEvents(id enode.ID, events chan *p2p.PeerEvent, sub event.Subscription) {
defer func() {
sub.Unsubscribe()
@@ -258,7 +258,7 @@ func (self *Network) watchPeerEvents(id discover.NodeID, events chan *p2p.PeerEv
}
// Stop stops the node with the given ID
-func (self *Network) Stop(id discover.NodeID) error {
+func (self *Network) Stop(id enode.ID) error {
node := self.GetNode(id)
if node == nil {
return fmt.Errorf("node %v does not exist", id)
@@ -278,7 +278,7 @@ func (self *Network) Stop(id discover.NodeID) error {
// Connect connects two nodes together by calling the "admin_addPeer" RPC
// method on the "one" node so that it connects to the "other" node
-func (self *Network) Connect(oneID, otherID discover.NodeID) error {
+func (self *Network) Connect(oneID, otherID enode.ID) error {
log.Debug(fmt.Sprintf("connecting %s to %s", oneID, otherID))
conn, err := self.InitConn(oneID, otherID)
if err != nil {
@@ -294,7 +294,7 @@ func (self *Network) Connect(oneID, otherID discover.NodeID) error {
// Disconnect disconnects two nodes by calling the "admin_removePeer" RPC
// method on the "one" node so that it disconnects from the "other" node
-func (self *Network) Disconnect(oneID, otherID discover.NodeID) error {
+func (self *Network) Disconnect(oneID, otherID enode.ID) error {
conn := self.GetConn(oneID, otherID)
if conn == nil {
return fmt.Errorf("connection between %v and %v does not exist", oneID, otherID)
@@ -311,7 +311,7 @@ func (self *Network) Disconnect(oneID, otherID discover.NodeID) error {
}
// DidConnect tracks the fact that the "one" node connected to the "other" node
-func (self *Network) DidConnect(one, other discover.NodeID) error {
+func (self *Network) DidConnect(one, other enode.ID) error {
conn, err := self.GetOrCreateConn(one, other)
if err != nil {
return fmt.Errorf("connection between %v and %v does not exist", one, other)
@@ -326,7 +326,7 @@ func (self *Network) DidConnect(one, other discover.NodeID) error {
// DidDisconnect tracks the fact that the "one" node disconnected from the
// "other" node
-func (self *Network) DidDisconnect(one, other discover.NodeID) error {
+func (self *Network) DidDisconnect(one, other enode.ID) error {
conn := self.GetConn(one, other)
if conn == nil {
return fmt.Errorf("connection between %v and %v does not exist", one, other)
@@ -341,7 +341,7 @@ func (self *Network) DidDisconnect(one, other discover.NodeID) error {
}
// DidSend tracks the fact that "sender" sent a message to "receiver"
-func (self *Network) DidSend(sender, receiver discover.NodeID, proto string, code uint64) error {
+func (self *Network) DidSend(sender, receiver enode.ID, proto string, code uint64) error {
msg := &Msg{
One: sender,
Other: receiver,
@@ -354,7 +354,7 @@ func (self *Network) DidSend(sender, receiver discover.NodeID, proto string, cod
}
// DidReceive tracks the fact that "receiver" received a message from "sender"
-func (self *Network) DidReceive(sender, receiver discover.NodeID, proto string, code uint64) error {
+func (self *Network) DidReceive(sender, receiver enode.ID, proto string, code uint64) error {
msg := &Msg{
One: sender,
Other: receiver,
@@ -368,7 +368,7 @@ func (self *Network) DidReceive(sender, receiver discover.NodeID, proto string,
// GetNode gets the node with the given ID, returning nil if the node does not
// exist
-func (self *Network) GetNode(id discover.NodeID) *Node {
+func (self *Network) GetNode(id enode.ID) *Node {
self.lock.Lock()
defer self.lock.Unlock()
return self.getNode(id)
@@ -382,7 +382,7 @@ func (self *Network) GetNodeByName(name string) *Node {
return self.getNodeByName(name)
}
-func (self *Network) getNode(id discover.NodeID) *Node {
+func (self *Network) getNode(id enode.ID) *Node {
i, found := self.nodeMap[id]
if !found {
return nil
@@ -410,7 +410,7 @@ func (self *Network) GetNodes() (nodes []*Node) {
// GetConn returns the connection which exists between "one" and "other"
// regardless of which node initiated the connection
-func (self *Network) GetConn(oneID, otherID discover.NodeID) *Conn {
+func (self *Network) GetConn(oneID, otherID enode.ID) *Conn {
self.lock.Lock()
defer self.lock.Unlock()
return self.getConn(oneID, otherID)
@@ -418,13 +418,13 @@ func (self *Network) GetConn(oneID, otherID discover.NodeID) *Conn {
// GetOrCreateConn is like GetConn but creates the connection if it doesn't
// already exist
-func (self *Network) GetOrCreateConn(oneID, otherID discover.NodeID) (*Conn, error) {
+func (self *Network) GetOrCreateConn(oneID, otherID enode.ID) (*Conn, error) {
self.lock.Lock()
defer self.lock.Unlock()
return self.getOrCreateConn(oneID, otherID)
}
-func (self *Network) getOrCreateConn(oneID, otherID discover.NodeID) (*Conn, error) {
+func (self *Network) getOrCreateConn(oneID, otherID enode.ID) (*Conn, error) {
if conn := self.getConn(oneID, otherID); conn != nil {
return conn, nil
}
@@ -449,7 +449,7 @@ func (self *Network) getOrCreateConn(oneID, otherID discover.NodeID) (*Conn, err
return conn, nil
}
-func (self *Network) getConn(oneID, otherID discover.NodeID) *Conn {
+func (self *Network) getConn(oneID, otherID enode.ID) *Conn {
label := ConnLabel(oneID, otherID)
i, found := self.connMap[label]
if !found {
@@ -466,7 +466,7 @@ func (self *Network) getConn(oneID, otherID discover.NodeID) *Conn {
// it also checks whether there has been a recent attempt to connect the peers;
// this is cheating, as the simulation is used as an oracle that knows about
// remote peers attempting to connect to a node, which will then not initiate the connection
-func (self *Network) InitConn(oneID, otherID discover.NodeID) (*Conn, error) {
+func (self *Network) InitConn(oneID, otherID enode.ID) (*Conn, error) {
self.lock.Lock()
defer self.lock.Unlock()
if oneID == otherID {
@@ -501,15 +501,15 @@ func (self *Network) Shutdown() {
close(self.quitc)
}
-//Reset resets all network properties:
-//emtpies the nodes and the connection list
+// Reset resets all network properties:
+// empties the nodes and the connection list
func (self *Network) Reset() {
self.lock.Lock()
defer self.lock.Unlock()
//re-initialize the maps
self.connMap = make(map[string]int)
- self.nodeMap = make(map[discover.NodeID]int)
+ self.nodeMap = make(map[enode.ID]int)
self.Nodes = nil
self.Conns = nil
@@ -528,7 +528,7 @@ type Node struct {
}
// ID returns the ID of the node
-func (self *Node) ID() discover.NodeID {
+func (self *Node) ID() enode.ID {
return self.Config.ID
}
@@ -565,10 +565,10 @@ func (self *Node) MarshalJSON() ([]byte, error) {
// Conn represents a connection between two nodes in the network
type Conn struct {
// One is the node which initiated the connection
- One discover.NodeID `json:"one"`
+ One enode.ID `json:"one"`
// Other is the node which the connection was made to
- Other discover.NodeID `json:"other"`
+ Other enode.ID `json:"other"`
// Up tracks whether or not the connection is active
Up bool `json:"up"`
@@ -597,11 +597,11 @@ func (self *Conn) String() string {
// Msg represents a p2p message sent between two nodes in the network
type Msg struct {
- One discover.NodeID `json:"one"`
- Other discover.NodeID `json:"other"`
- Protocol string `json:"protocol"`
- Code uint64 `json:"code"`
- Received bool `json:"received"`
+ One enode.ID `json:"one"`
+ Other enode.ID `json:"other"`
+ Protocol string `json:"protocol"`
+ Code uint64 `json:"code"`
+ Received bool `json:"received"`
}
// String returns a log-friendly string
@@ -612,8 +612,8 @@ func (self *Msg) String() string {
// ConnLabel generates a deterministic string which represents a connection
// between two nodes, used to compare if two connections are between the same
// nodes
-func ConnLabel(source, target discover.NodeID) string {
- var first, second discover.NodeID
+func ConnLabel(source, target enode.ID) string {
+ var first, second enode.ID
if bytes.Compare(source.Bytes(), target.Bytes()) > 0 {
first = target
second = source
diff --git a/p2p/simulations/network_test.go b/p2p/simulations/network_test.go
index da97428cf..a89083b62 100644
--- a/p2p/simulations/network_test.go
+++ b/p2p/simulations/network_test.go
@@ -22,7 +22,7 @@ import (
"testing"
"time"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/simulations/adapters"
)
@@ -39,7 +39,7 @@ func TestNetworkSimulation(t *testing.T) {
})
defer network.Shutdown()
nodeCount := 20
- ids := make([]discover.NodeID, nodeCount)
+ ids := make([]enode.ID, nodeCount)
for i := 0; i < nodeCount; i++ {
node, err := network.NewNode()
if err != nil {
@@ -63,7 +63,7 @@ func TestNetworkSimulation(t *testing.T) {
}
return nil
}
- check := func(ctx context.Context, id discover.NodeID) (bool, error) {
+ check := func(ctx context.Context, id enode.ID) (bool, error) {
// check we haven't run out of time
select {
case <-ctx.Done():
@@ -101,7 +101,7 @@ func TestNetworkSimulation(t *testing.T) {
defer cancel()
// trigger a check every 100ms
- trigger := make(chan discover.NodeID)
+ trigger := make(chan enode.ID)
go triggerChecks(ctx, ids, trigger, 100*time.Millisecond)
result := NewSimulation(network).Run(ctx, &Step{
@@ -139,7 +139,7 @@ func TestNetworkSimulation(t *testing.T) {
}
}
-func triggerChecks(ctx context.Context, ids []discover.NodeID, trigger chan discover.NodeID, interval time.Duration) {
+func triggerChecks(ctx context.Context, ids []enode.ID, trigger chan enode.ID, interval time.Duration) {
tick := time.NewTicker(interval)
defer tick.Stop()
for {
diff --git a/p2p/simulations/pipes/pipes.go b/p2p/simulations/pipes/pipes.go
new file mode 100644
index 000000000..ec277c0d1
--- /dev/null
+++ b/p2p/simulations/pipes/pipes.go
@@ -0,0 +1,55 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pipes
+
+import (
+ "net"
+)
+
+// NetPipe wraps net.Pipe in a signature returning an error
+func NetPipe() (net.Conn, net.Conn, error) {
+ p1, p2 := net.Pipe()
+ return p1, p2, nil
+}
+
+// TCPPipe creates an in-process full-duplex pipe based on a localhost TCP socket
+func TCPPipe() (net.Conn, net.Conn, error) {
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ return nil, nil, err
+ }
+ defer l.Close()
+
+ var aconn net.Conn
+ aerr := make(chan error, 1)
+ go func() {
+ var err error
+ aconn, err = l.Accept()
+ aerr <- err
+ }()
+
+ dconn, err := net.Dial("tcp", l.Addr().String())
+ if err != nil {
+ <-aerr
+ return nil, nil, err
+ }
+ if err := <-aerr; err != nil {
+ dconn.Close()
+ return nil, nil, err
+ }
+ return aconn, dconn, nil
+}
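
Both constructors share the (net.Conn, net.Conn, error) signature, so a test can be parameterized over the in-memory and TCP transports interchangeably. A usage sketch against the package above (the round-trip itself is illustrative):

    package pipes_test

    import (
        "io"
        "net"
        "testing"

        "github.com/tomochain/tomochain/p2p/simulations/pipes"
    )

    func TestPipeRoundTrip(t *testing.T) {
        cases := map[string]func() (net.Conn, net.Conn, error){
            "net": pipes.NetPipe,
            "tcp": pipes.TCPPipe,
        }
        for name, mk := range cases {
            c1, c2, err := mk()
            if err != nil {
                t.Fatalf("%s: %v", name, err)
            }
            go c1.Write([]byte("ping")) // net.Pipe writes block until read
            buf := make([]byte, 4)
            if _, err := io.ReadFull(c2, buf); err != nil {
                t.Fatalf("%s: %v", name, err)
            }
            c1.Close()
            c2.Close()
        }
    }
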
diff --git a/p2p/simulations/simulation.go b/p2p/simulations/simulation.go
index 6fc879ed1..0879cd912 100644
--- a/p2p/simulations/simulation.go
+++ b/p2p/simulations/simulation.go
@@ -20,7 +20,7 @@ import (
"context"
"time"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
)
// Simulation provides a framework for running actions in a simulated network
@@ -55,7 +55,7 @@ func (s *Simulation) Run(ctx context.Context, step *Step) (result *StepResult) {
}
// wait for all node expectations to either pass, error or timeout
- nodes := make(map[discover.NodeID]struct{}, len(step.Expect.Nodes))
+ nodes := make(map[enode.ID]struct{}, len(step.Expect.Nodes))
for _, id := range step.Expect.Nodes {
nodes[id] = struct{}{}
}
@@ -119,7 +119,7 @@ type Step struct {
// Trigger is a channel which receives node ids and triggers an
// expectation check for that node
- Trigger chan discover.NodeID
+ Trigger chan enode.ID
// Expect is the expectation to wait for when performing this step
Expect *Expectation
@@ -127,15 +127,15 @@ type Step struct {
type Expectation struct {
// Nodes is a list of nodes to check
- Nodes []discover.NodeID
+ Nodes []enode.ID
// Check checks whether a given node meets the expectation
- Check func(context.Context, discover.NodeID) (bool, error)
+ Check func(context.Context, enode.ID) (bool, error)
}
func newStepResult() *StepResult {
return &StepResult{
- Passes: make(map[discover.NodeID]time.Time),
+ Passes: make(map[enode.ID]time.Time),
}
}
@@ -150,7 +150,7 @@ type StepResult struct {
FinishedAt time.Time
// Passes are the timestamps of the successful node expectations
- Passes map[discover.NodeID]time.Time
+ Passes map[enode.ID]time.Time
// NetworkEvents are the network events which occurred during the step
NetworkEvents []*Event
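
Taken together, the migrated types compose as in the condensed sketch below, which drives one step with the enode.ID-based API (network and ids are assumed to be set up as in network_test.go above; context and enode imports are implied):

    // runStep sketches the Step/Expectation wiring after the migration.
    func runStep(ctx context.Context, network *Network, ids []enode.ID) *StepResult {
        trigger := make(chan enode.ID) // feed node ids here to request a check
        return NewSimulation(network).Run(ctx, &Step{
            Trigger: trigger,
            Expect: &Expectation{
                Nodes: ids,
                Check: func(ctx context.Context, id enode.ID) (bool, error) {
                    // Report whether node id has reached the desired state.
                    return true, nil
                },
            },
        })
    }
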
diff --git a/p2p/testing/peerpool.go b/p2p/testing/peerpool.go
index 0934cfbdb..9b80e8e05 100644
--- a/p2p/testing/peerpool.go
+++ b/p2p/testing/peerpool.go
@@ -21,22 +21,22 @@ import (
"sync"
"github.com/tomochain/tomochain/log"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
)
type TestPeer interface {
- ID() discover.NodeID
+ ID() enode.ID
Drop(error)
}
// TestPeerPool is an example peerPool to demonstrate registration of peer connections
type TestPeerPool struct {
lock sync.Mutex
- peers map[discover.NodeID]TestPeer
+ peers map[enode.ID]TestPeer
}
func NewTestPeerPool() *TestPeerPool {
- return &TestPeerPool{peers: make(map[discover.NodeID]TestPeer)}
+ return &TestPeerPool{peers: make(map[enode.ID]TestPeer)}
}
func (self *TestPeerPool) Add(p TestPeer) {
@@ -53,14 +53,14 @@ func (self *TestPeerPool) Remove(p TestPeer) {
delete(self.peers, p.ID())
}
-func (self *TestPeerPool) Has(id discover.NodeID) bool {
+func (self *TestPeerPool) Has(id enode.ID) bool {
self.lock.Lock()
defer self.lock.Unlock()
_, ok := self.peers[id]
return ok
}
-func (self *TestPeerPool) Get(id discover.NodeID) TestPeer {
+func (self *TestPeerPool) Get(id enode.ID) TestPeer {
self.lock.Lock()
defer self.lock.Unlock()
return self.peers[id]
diff --git a/p2p/testing/protocolsession.go b/p2p/testing/protocolsession.go
index 6f4d4c499..783af99bf 100644
--- a/p2p/testing/protocolsession.go
+++ b/p2p/testing/protocolsession.go
@@ -24,7 +24,7 @@ import (
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/simulations/adapters"
)
@@ -35,7 +35,7 @@ var errTimedOut = errors.New("timed out")
// receive (expect) messages
type ProtocolSession struct {
Server *p2p.Server
- IDs []discover.NodeID
+ IDs []enode.ID
adapter *adapters.SimAdapter
events chan *p2p.PeerEvent
}
@@ -56,25 +56,25 @@ type Exchange struct {
// Trigger is part of the exchange, incoming message for the pivot node
// sent by a peer
type Trigger struct {
- Msg interface{} // type of message to be sent
- Code uint64 // code of message is given
- Peer discover.NodeID // the peer to send the message to
- Timeout time.Duration // timeout duration for the sending
+ Msg interface{} // type of message to be sent
+ Code uint64 // code of message is given
+ Peer enode.ID // the peer to send the message to
+ Timeout time.Duration // timeout duration for the sending
}
// Expect is part of an exchange, outgoing message from the pivot node
// received by a peer
type Expect struct {
- Msg interface{} // type of message to expect
- Code uint64 // code of message is now given
- Peer discover.NodeID // the peer that expects the message
- Timeout time.Duration // timeout duration for receiving
+ Msg interface{} // type of message to expect
+ Code uint64 // code of message is now given
+ Peer enode.ID // the peer that expects the message
+ Timeout time.Duration // timeout duration for receiving
}
// Disconnect represents a disconnect event, used and checked by TestDisconnected
type Disconnect struct {
- Peer discover.NodeID // discconnected peer
- Error error // disconnect reason
+ Peer enode.ID // disconnected peer
+ Error error // disconnect reason
}
// trigger sends messages from peers
@@ -109,7 +109,7 @@ func (self *ProtocolSession) trigger(trig Trigger) error {
// expect checks an expectation of a message sent out by the pivot node
func (self *ProtocolSession) expect(exps []Expect) error {
// construct a map of expectations for each node
- peerExpects := make(map[discover.NodeID][]Expect)
+ peerExpects := make(map[enode.ID][]Expect)
for _, exp := range exps {
if exp.Msg == nil {
return errors.New("no message to expect")
@@ -118,7 +118,7 @@ func (self *ProtocolSession) expect(exps []Expect) error {
}
// construct a map of mockNodes for each node
- mockNodes := make(map[discover.NodeID]*mockNode)
+ mockNodes := make(map[enode.ID]*mockNode)
for nodeID := range peerExpects {
simNode, ok := self.adapter.GetNode(nodeID)
if !ok {
@@ -251,7 +251,7 @@ func (self *ProtocolSession) testExchange(e Exchange) error {
// TestDisconnected tests the disconnections given as arguments
// the disconnect structs describe what disconnect error is expected on which peer
func (self *ProtocolSession) TestDisconnected(disconnects ...*Disconnect) error {
- expects := make(map[discover.NodeID]error)
+ expects := make(map[enode.ID]error)
for _, disconnect := range disconnects {
expects[disconnect.Peer] = disconnect.Error
}
diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go
index 0ac8b05f3..f16acbac9 100644
--- a/p2p/testing/protocoltester.go
+++ b/p2p/testing/protocoltester.go
@@ -35,7 +35,7 @@ import (
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/node"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/simulations"
"github.com/tomochain/tomochain/p2p/simulations/adapters"
"github.com/tomochain/tomochain/rlp"
@@ -52,7 +52,7 @@ type ProtocolTester struct {
// NewProtocolTester constructs a new ProtocolTester
// it takes as argument the pivot node id, the number of dummy peers and the
// protocol run function called on a peer connection by the p2p server
-func NewProtocolTester(t *testing.T, id discover.NodeID, n int, run func(*p2p.Peer, p2p.MsgReadWriter) error) *ProtocolTester {
+func NewProtocolTester(t *testing.T, id enode.ID, n int, run func(*p2p.Peer, p2p.MsgReadWriter) error) *ProtocolTester {
services := adapters.Services{
"test": func(ctx *adapters.ServiceContext) (node.Service, error) {
return &testNode{run}, nil
@@ -76,7 +76,7 @@ func NewProtocolTester(t *testing.T, id discover.NodeID, n int, run func(*p2p.Pe
node := net.GetNode(id).Node.(*adapters.SimNode)
peers := make([]*adapters.NodeConfig, n)
- peerIDs := make([]discover.NodeID, n)
+ peerIDs := make([]enode.ID, n)
for i := 0; i < n; i++ {
peers[i] = adapters.RandomNodeConfig()
peers[i].Services = []string{"mock"}
@@ -108,7 +108,7 @@ func (self *ProtocolTester) Stop() error {
// Connect brings up the remote peer node and connects it using the
// p2p/simulations network connection with the in memory network adapter
-func (self *ProtocolTester) Connect(selfID discover.NodeID, peers ...*adapters.NodeConfig) {
+func (self *ProtocolTester) Connect(selfID enode.ID, peers ...*adapters.NodeConfig) {
for _, peer := range peers {
log.Trace(fmt.Sprintf("start node %v", peer.ID))
if _, err := self.network.NewNodeWithConfig(peer); err != nil {
diff --git a/params/version.go b/params/version.go
index af4d16e53..c220c1273 100644
--- a/params/version.go
+++ b/params/version.go
@@ -21,10 +21,10 @@ import (
)
const (
- VersionMajor = 2 // Major version component of the current release
- VersionMinor = 3 // Minor version component of the current release
- VersionPatch = 2 // Patch version component of the current release
- VersionMeta = "stable" // Version metadata to append to the version string
+ VersionMajor = 2 // Major version component of the current release
+ VersionMinor = 4 // Minor version component of the current release
+ VersionPatch = 0 // Patch version component of the current release
+ VersionMeta = "dev" // Version metadata to append to the version string
)
// Version holds the textual version string.
diff --git a/rlp/decode.go b/rlp/decode.go
index 60d9dab2b..ac93c139a 100644
--- a/rlp/decode.go
+++ b/rlp/decode.go
@@ -26,100 +26,77 @@ import (
"math/big"
"reflect"
"strings"
+ "sync"
+
+ "github.com/holiman/uint256"
+ "github.com/tomochain/tomochain/rlp/internal/rlpstruct"
)
+//lint:ignore ST1012 EOL is not an error.
+
+// EOL is returned when the end of the current list
+// has been reached during streaming.
+var EOL = errors.New("rlp: end of list")
+
var (
+ ErrExpectedString = errors.New("rlp: expected String or Byte")
+ ErrExpectedList = errors.New("rlp: expected List")
+ ErrCanonInt = errors.New("rlp: non-canonical integer format")
+ ErrCanonSize = errors.New("rlp: non-canonical size information")
+ ErrElemTooLarge = errors.New("rlp: element is larger than containing list")
+ ErrValueTooLarge = errors.New("rlp: value size exceeds available input length")
+ ErrMoreThanOneValue = errors.New("rlp: input contains more than one value")
+
+ // internal errors
+ errNotInList = errors.New("rlp: call of ListEnd outside of any list")
+ errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL")
+ errUintOverflow = errors.New("rlp: uint overflow")
errNoPointer = errors.New("rlp: interface given to Decode must be a pointer")
errDecodeIntoNil = errors.New("rlp: pointer given to Decode must not be nil")
+ errUint256Large = errors.New("rlp: value too large for uint256")
+
+ streamPool = sync.Pool{
+ New: func() interface{} { return new(Stream) },
+ }
)
-// Decoder is implemented by types that require custom RLP
-// decoding rules or need to decode into private fields.
+// Decoder is implemented by types that require custom RLP decoding rules or need to decode
+// into private fields.
//
-// The DecodeRLP method should read one value from the given
-// Stream. It is not forbidden to read less or more, but it might
-// be confusing.
+// The DecodeRLP method should read one value from the given Stream. It is not forbidden to
+// read less or more, but it might be confusing.
type Decoder interface {
DecodeRLP(*Stream) error
}
-// Decode parses RLP-encoded data from r and stores the result in the
-// value pointed to by val. Val must be a non-nil pointer. If r does
-// not implement ByteReader, Decode will do its own buffering.
-//
-// Decode uses the following type-dependent decoding rules:
-//
-// If the type implements the Decoder interface, decode calls
-// DecodeRLP.
-//
-// To decode into a pointer, Decode will decode into the value pointed
-// to. If the pointer is nil, a new value of the pointer's element
-// type is allocated. If the pointer is non-nil, the existing value
-// will be reused.
-//
-// To decode into a struct, Decode expects the input to be an RLP
-// list. The decoded elements of the list are assigned to each public
-// field in the order given by the struct's definition. The input list
-// must contain an element for each decoded field. Decode returns an
-// error if there are too few or too many elements.
+// Decode parses RLP-encoded data from r and stores the result in the value pointed to by
+// val. Please see package-level documentation for the decoding rules. Val must be a
+// non-nil pointer.
//
-// The decoding of struct fields honours certain struct tags, "tail",
-// "nil" and "-".
+// If r does not implement ByteReader, Decode will do its own buffering.
//
-// The "-" tag ignores fields.
+// Note that Decode does not set an input limit for all readers and may be vulnerable to
+// panics caused by huge value sizes. If you need an input limit, use
//
-// For an explanation of "tail", see the example.
-//
-// The "nil" tag applies to pointer-typed fields and changes the decoding
-// rules for the field such that input values of size zero decode as a nil
-// pointer. This tag can be useful when decoding recursive types.
-//
-// type StructWithEmptyOK struct {
-// Foo *[20]byte `rlp:"nil"`
-// }
-//
-// To decode into a slice, the input must be a list and the resulting
-// slice will contain the input elements in order. For byte slices,
-// the input must be an RLP string. Array types decode similarly, with
-// the additional restriction that the number of input elements (or
-// bytes) must match the array's length.
-//
-// To decode into a Go string, the input must be an RLP string. The
-// input bytes are taken as-is and will not necessarily be valid UTF-8.
-//
-// To decode into an unsigned integer type, the input must also be an RLP
-// string. The bytes are interpreted as a big endian representation of
-// the integer. If the RLP string is larger than the bit size of the
-// type, Decode will return an error. Decode also supports *big.Int.
-// There is no size limit for big integers.
-//
-// To decode into an interface value, Decode stores one of these
-// in the value:
-//
-// []interface{}, for RLP lists
-// []byte, for RLP strings
-//
-// Non-empty interface types are not supported, nor are booleans,
-// signed integers, floating point numbers, maps, channels and
-// functions.
-//
-// Note that Decode does not set an input limit for all readers
-// and may be vulnerable to panics cause by huge value sizes. If
-// you need an input limit, use
-//
-// NewStream(r, limit).Decode(val)
+// NewStream(r, limit).Decode(val)
func Decode(r io.Reader, val interface{}) error {
- // TODO: this could use a Stream from a pool.
- return NewStream(r, 0).Decode(val)
+ stream := streamPool.Get().(*Stream)
+ defer streamPool.Put(stream)
+
+ stream.Reset(r, 0)
+ return stream.Decode(val)
}
-// DecodeBytes parses RLP data from b into val.
-// Please see the documentation of Decode for the decoding rules.
-// The input must contain exactly one value and no trailing data.
+// DecodeBytes parses RLP data from b into val. Please see package-level documentation for
+// the decoding rules. The input must contain exactly one value and no trailing data.
func DecodeBytes(b []byte, val interface{}) error {
- // TODO: this could use a Stream from a pool.
r := bytes.NewReader(b)
- if err := NewStream(r, uint64(len(b))).Decode(val); err != nil {
+
+ stream := streamPool.Get().(*Stream)
+ defer streamPool.Put(stream)
+
+ stream.Reset(r, uint64(len(b)))
+ if err := stream.Decode(val); err != nil {
return err
}
if r.Len() > 0 {
@@ -173,21 +150,26 @@ func addErrorContext(err error, ctx string) error {
var (
decoderInterface = reflect.TypeOf(new(Decoder)).Elem()
bigInt = reflect.TypeOf(big.Int{})
+ u256Int = reflect.TypeOf(uint256.Int{})
)
-func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) {
+func makeDecoder(typ reflect.Type, tags rlpstruct.Tags) (dec decoder, err error) {
kind := typ.Kind()
switch {
case typ == rawValueType:
return decodeRawValue, nil
- case typ.Implements(decoderInterface):
- return decodeDecoder, nil
- case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(decoderInterface):
- return decodeDecoderNoPtr, nil
case typ.AssignableTo(reflect.PtrTo(bigInt)):
return decodeBigInt, nil
case typ.AssignableTo(bigInt):
return decodeBigIntNoPtr, nil
+ case typ == reflect.PtrTo(u256Int):
+ return decodeU256, nil
+ case typ == u256Int:
+ return decodeU256NoPtr, nil
+ case kind == reflect.Ptr:
+ return makePtrDecoder(typ, tags)
+ case reflect.PtrTo(typ).Implements(decoderInterface):
+ return decodeDecoder, nil
case isUint(kind):
return decodeUint, nil
case kind == reflect.Bool:
@@ -198,11 +180,6 @@ func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) {
return makeListDecoder(typ, tags)
case kind == reflect.Struct:
return makeStructDecoder(typ)
- case kind == reflect.Ptr:
- if tags.nilOK {
- return makeOptionalPtrDecoder(typ)
- }
- return makePtrDecoder(typ)
case kind == reflect.Interface:
return decodeInterface, nil
default:
@@ -252,35 +229,48 @@ func decodeBigIntNoPtr(s *Stream, val reflect.Value) error {
}
func decodeBigInt(s *Stream, val reflect.Value) error {
- b, err := s.Bytes()
+ i := val.Interface().(*big.Int)
+ if i == nil {
+ i = new(big.Int)
+ val.Set(reflect.ValueOf(i))
+ }
+
+ err := s.decodeBigInt(i)
if err != nil {
return wrapStreamError(err, val.Type())
}
- i := val.Interface().(*big.Int)
+ return nil
+}
+
+func decodeU256NoPtr(s *Stream, val reflect.Value) error {
+ return decodeU256(s, val.Addr())
+}
+
+func decodeU256(s *Stream, val reflect.Value) error {
+ i := val.Interface().(*uint256.Int)
if i == nil {
- i = new(big.Int)
+ i = new(uint256.Int)
val.Set(reflect.ValueOf(i))
}
- // Reject leading zero bytes
- if len(b) > 0 && b[0] == 0 {
- return wrapStreamError(ErrCanonInt, val.Type())
+
+ err := s.ReadUint256(i)
+ if err != nil {
+ return wrapStreamError(err, val.Type())
}
- i.SetBytes(b)
return nil
}
-func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) {
+func makeListDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) {
etype := typ.Elem()
if etype.Kind() == reflect.Uint8 && !reflect.PtrTo(etype).Implements(decoderInterface) {
if typ.Kind() == reflect.Array {
return decodeByteArray, nil
- } else {
- return decodeByteSlice, nil
}
+ return decodeByteSlice, nil
}
- etypeinfo, err := cachedTypeInfo1(etype, tags{})
- if err != nil {
- return nil, err
+ etypeinfo := theTC.infoWhileGenerating(etype, rlpstruct.Tags{})
+ if etypeinfo.decoderErr != nil {
+ return nil, etypeinfo.decoderErr
}
var dec decoder
switch {
@@ -288,7 +278,7 @@ func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) {
dec = func(s *Stream, val reflect.Value) error {
return decodeListArray(s, val, etypeinfo.decoder)
}
- case tag.tail:
+ case tag.Tail:
// A slice with "tail" tag can occur as the last field
// of a struct and is supposed to swallow all remaining
// list elements. The struct decoder already called s.List,
@@ -381,25 +371,23 @@ func decodeByteArray(s *Stream, val reflect.Value) error {
if err != nil {
return err
}
- vlen := val.Len()
+ slice := byteArrayBytes(val, val.Len())
switch kind {
case Byte:
- if vlen == 0 {
+ if len(slice) == 0 {
return &decodeError{msg: "input string too long", typ: val.Type()}
- }
- if vlen > 1 {
+ } else if len(slice) > 1 {
return &decodeError{msg: "input string too short", typ: val.Type()}
}
- bv, _ := s.Uint()
- val.Index(0).SetUint(bv)
+ slice[0] = s.byteval
+ s.kind = -1
case String:
- if uint64(vlen) < size {
+ if uint64(len(slice)) < size {
return &decodeError{msg: "input string too long", typ: val.Type()}
}
- if uint64(vlen) > size {
+ if uint64(len(slice)) > size {
return &decodeError{msg: "input string too short", typ: val.Type()}
}
- slice := val.Slice(0, vlen).Interface().([]byte)
if err := s.readFull(slice); err != nil {
return err
}
@@ -418,13 +406,25 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
if err != nil {
return nil, err
}
+ for _, f := range fields {
+ if f.info.decoderErr != nil {
+ return nil, structFieldError{typ, f.index, f.info.decoderErr}
+ }
+ }
dec := func(s *Stream, val reflect.Value) (err error) {
if _, err := s.List(); err != nil {
return wrapStreamError(err, typ)
}
- for _, f := range fields {
+ for i, f := range fields {
err := f.info.decoder(s, val.Field(f.index))
if err == EOL {
+ if f.optional {
+ // The field is optional, so reaching the end of the list before
+ // reaching the last field is acceptable. All remaining undecoded
+ // fields are zeroed.
+ zeroFields(val, fields[i:])
+ break
+ }
return &decodeError{msg: "too few elements", typ: typ}
} else if err != nil {
return addErrorContext(err, "."+typ.Field(f.index).Name)
@@ -435,15 +435,29 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
return dec, nil
}
-// makePtrDecoder creates a decoder that decodes into
-// the pointer's element type.
-func makePtrDecoder(typ reflect.Type) (decoder, error) {
+func zeroFields(structval reflect.Value, fields []field) {
+ for _, f := range fields {
+ fv := structval.Field(f.index)
+ fv.Set(reflect.Zero(fv.Type()))
+ }
+}
+
+// makePtrDecoder creates a decoder that decodes into the pointer's element type.
+func makePtrDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) {
etype := typ.Elem()
- etypeinfo, err := cachedTypeInfo1(etype, tags{})
- if err != nil {
- return nil, err
+ etypeinfo := theTC.infoWhileGenerating(etype, rlpstruct.Tags{})
+ switch {
+ case etypeinfo.decoderErr != nil:
+ return nil, etypeinfo.decoderErr
+ case !tag.NilOK:
+ return makeSimplePtrDecoder(etype, etypeinfo), nil
+ default:
+ return makeNilPtrDecoder(etype, etypeinfo, tag), nil
}
- dec := func(s *Stream, val reflect.Value) (err error) {
+}
+
+func makeSimplePtrDecoder(etype reflect.Type, etypeinfo *typeinfo) decoder {
+ return func(s *Stream, val reflect.Value) (err error) {
newval := val
if val.IsNil() {
newval = reflect.New(etype)
@@ -453,30 +467,39 @@ func makePtrDecoder(typ reflect.Type) (decoder, error) {
}
return err
}
- return dec, nil
}
-// makeOptionalPtrDecoder creates a decoder that decodes empty values
-// as nil. Non-empty values are decoded into a value of the element type,
-// just like makePtrDecoder does.
+// makeNilPtrDecoder creates a decoder that decodes empty values as nil. Non-empty
+// values are decoded into a value of the element type, just like makePtrDecoder does.
//
// This decoder is used for pointer-typed struct fields with struct tag "nil".
-func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) {
- etype := typ.Elem()
- etypeinfo, err := cachedTypeInfo1(etype, tags{})
- if err != nil {
- return nil, err
- }
- dec := func(s *Stream, val reflect.Value) (err error) {
+func makeNilPtrDecoder(etype reflect.Type, etypeinfo *typeinfo, ts rlpstruct.Tags) decoder {
+ typ := reflect.PtrTo(etype)
+ nilPtr := reflect.Zero(typ)
+
+ // Determine the value kind that results in nil pointer.
+ nilKind := typeNilKind(etype, ts)
+
+ return func(s *Stream, val reflect.Value) (err error) {
kind, size, err := s.Kind()
- if err != nil || size == 0 && kind != Byte {
+ if err != nil {
+ val.Set(nilPtr)
+ return wrapStreamError(err, typ)
+ }
+ // Handle empty values as a nil pointer.
+ if kind != Byte && size == 0 {
+ if kind != nilKind {
+ return &decodeError{
+ msg: fmt.Sprintf("wrong kind of empty value (got %v, want %v)", kind, nilKind),
+ typ: typ,
+ }
+ }
// rearm s.Kind. This is important because the input
// position must advance to the next value even though
// we don't read anything.
s.kind = -1
- // set the pointer to nil.
- val.Set(reflect.Zero(typ))
- return err
+ val.Set(nilPtr)
+ return nil
}
newval := val
if val.IsNil() {
@@ -487,7 +510,6 @@ func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) {
}
return err
}
- return dec, nil
}
var ifsliceType = reflect.TypeOf([]interface{}{})
@@ -516,25 +538,12 @@ func decodeInterface(s *Stream, val reflect.Value) error {
return nil
}
-// This decoder is used for non-pointer values of types
-// that implement the Decoder interface using a pointer receiver.
-func decodeDecoderNoPtr(s *Stream, val reflect.Value) error {
- return val.Addr().Interface().(Decoder).DecodeRLP(s)
-}
-
func decodeDecoder(s *Stream, val reflect.Value) error {
- // Decoder instances are not handled using the pointer rule if the type
- // implements Decoder with pointer receiver (i.e. always)
- // because it might handle empty values specially.
- // We need to allocate one here in this case, like makePtrDecoder does.
- if val.Kind() == reflect.Ptr && val.IsNil() {
- val.Set(reflect.New(val.Type().Elem()))
- }
- return val.Interface().(Decoder).DecodeRLP(s)
+ return val.Addr().Interface().(Decoder).DecodeRLP(s)
}
// Kind represents the kind of value contained in an RLP stream.
-type Kind int
+type Kind int8
const (
Byte Kind = iota
@@ -555,29 +564,6 @@ func (k Kind) String() string {
}
}
-var (
- // EOL is returned when the end of the current list
- // has been reached during streaming.
- EOL = errors.New("rlp: end of list")
-
- // Actual Errors
- ErrExpectedString = errors.New("rlp: expected String or Byte")
- ErrExpectedList = errors.New("rlp: expected List")
- ErrCanonInt = errors.New("rlp: non-canonical integer format")
- ErrCanonSize = errors.New("rlp: non-canonical size information")
- ErrElemTooLarge = errors.New("rlp: element is larger than containing list")
- ErrValueTooLarge = errors.New("rlp: value size exceeds available input length")
-
- // This error is reported by DecodeBytes if the slice contains
- // additional data after the first RLP value.
- ErrMoreThanOneValue = errors.New("rlp: input contains more than one value")
-
- // internal errors
- errNotInList = errors.New("rlp: call of ListEnd outside of any list")
- errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL")
- errUintOverflow = errors.New("rlp: uint overflow")
-)
-
// ByteReader must be implemented by any input reader for a Stream. It
// is implemented by e.g. bufio.Reader and bytes.Reader.
type ByteReader interface {
@@ -600,22 +586,16 @@ type ByteReader interface {
type Stream struct {
r ByteReader
- // number of bytes remaining to be read from r.
- remaining uint64
- limited bool
-
- // auxiliary buffer for integer decoding
- uintbuf []byte
-
- kind Kind // kind of value ahead
- size uint64 // size of value ahead
- byteval byte // value of single byte in type tag
- kinderr error // error from last readKind
- stack []listpos
+ remaining uint64 // number of bytes remaining to be read from r
+ size uint64 // size of value ahead
+ kinderr error // error from last readKind
+ stack []uint64 // list sizes
+ uintbuf [32]byte // auxiliary buffer for integer decoding
+ kind Kind // kind of value ahead
+ byteval byte // value of single byte in type tag
+ limited bool // true if input limit is in effect
}
-type listpos struct{ pos, size uint64 }
-
// NewStream creates a new decoding stream reading from r.
//
// If r implements the ByteReader interface, Stream will
@@ -675,6 +655,37 @@ func (s *Stream) Bytes() ([]byte, error) {
}
}
+// ReadBytes decodes the next RLP value and stores the result in b.
+// The value size must match len(b) exactly.
+func (s *Stream) ReadBytes(b []byte) error {
+ kind, size, err := s.Kind()
+ if err != nil {
+ return err
+ }
+ switch kind {
+ case Byte:
+ if len(b) != 1 {
+ return fmt.Errorf("input value has wrong size 1, want %d", len(b))
+ }
+ b[0] = s.byteval
+ s.kind = -1 // rearm Kind
+ return nil
+ case String:
+ if uint64(len(b)) != size {
+ return fmt.Errorf("input value has wrong size %d, want %d", size, len(b))
+ }
+ if err = s.readFull(b); err != nil {
+ return err
+ }
+ if size == 1 && b[0] < 128 {
+ return ErrCanonSize
+ }
+ return nil
+ default:
+ return ErrExpectedString
+ }
+}
+
// Raw reads a raw encoded value including RLP type information.
func (s *Stream) Raw() ([]byte, error) {
kind, size, err := s.Kind()
@@ -685,8 +696,8 @@ func (s *Stream) Raw() ([]byte, error) {
s.kind = -1 // rearm Kind
return []byte{s.byteval}, nil
}
- // the original header has already been read and is no longer
- // available. read content and put a new header in front of it.
+ // The original header has already been read and is no longer
+ // available. Read content and put a new header in front of it.
start := headsize(size)
buf := make([]byte, uint64(start)+size)
if err := s.readFull(buf[start:]); err != nil {
@@ -703,10 +714,31 @@ func (s *Stream) Raw() ([]byte, error) {
// Uint reads an RLP string of up to 8 bytes and returns its contents
// as an unsigned integer. If the input does not contain an RLP string, the
// returned error will be ErrExpectedString.
+//
+// Deprecated: use s.Uint64 instead.
func (s *Stream) Uint() (uint64, error) {
return s.uint(64)
}
+func (s *Stream) Uint64() (uint64, error) {
+ return s.uint(64)
+}
+
+func (s *Stream) Uint32() (uint32, error) {
+ i, err := s.uint(32)
+ return uint32(i), err
+}
+
+func (s *Stream) Uint16() (uint16, error) {
+ i, err := s.uint(16)
+ return uint16(i), err
+}
+
+func (s *Stream) Uint8() (uint8, error) {
+ i, err := s.uint(8)
+ return uint8(i), err
+}
+
func (s *Stream) uint(maxbits int) (uint64, error) {
kind, size, err := s.Kind()
if err != nil {
@@ -769,7 +801,14 @@ func (s *Stream) List() (size uint64, err error) {
if kind != List {
return 0, ErrExpectedList
}
- s.stack = append(s.stack, listpos{0, size})
+
+ // Remove size of inner list from outer list before pushing the new size
+ // onto the stack. This ensures that the remaining outer list size will
+ // be correct after the matching call to ListEnd.
+ if inList, limit := s.listLimit(); inList {
+ s.stack[len(s.stack)-1] = limit - size
+ }
+ s.stack = append(s.stack, size)
s.kind = -1
s.size = 0
return size, nil
@@ -778,22 +817,116 @@ func (s *Stream) List() (size uint64, err error) {
// ListEnd returns to the enclosing list.
// The input reader must be positioned at the end of a list.
func (s *Stream) ListEnd() error {
- if len(s.stack) == 0 {
+ // Ensure that no more data is remaining in the current list.
+ if inList, listLimit := s.listLimit(); !inList {
return errNotInList
- }
- tos := s.stack[len(s.stack)-1]
- if tos.pos != tos.size {
+ } else if listLimit > 0 {
return errNotAtEOL
}
s.stack = s.stack[:len(s.stack)-1] // pop
- if len(s.stack) > 0 {
- s.stack[len(s.stack)-1].pos += tos.size
- }
s.kind = -1
s.size = 0
return nil
}
+// MoreDataInList reports whether the current list context contains
+// more data to be read.
+func (s *Stream) MoreDataInList() bool {
+ _, listLimit := s.listLimit()
+ return listLimit > 0
+}
+
+// BigInt decodes an arbitrary-size integer value.
+func (s *Stream) BigInt() (*big.Int, error) {
+ i := new(big.Int)
+ if err := s.decodeBigInt(i); err != nil {
+ return nil, err
+ }
+ return i, nil
+}
+
+func (s *Stream) decodeBigInt(dst *big.Int) error {
+ var buffer []byte
+ kind, size, err := s.Kind()
+ switch {
+ case err != nil:
+ return err
+ case kind == List:
+ return ErrExpectedString
+ case kind == Byte:
+ buffer = s.uintbuf[:1]
+ buffer[0] = s.byteval
+ s.kind = -1 // re-arm Kind
+ case size == 0:
+ // Avoid zero-length read.
+ s.kind = -1
+ case size <= uint64(len(s.uintbuf)):
+ // For integers smaller than s.uintbuf, allocating a buffer
+ // can be avoided.
+ buffer = s.uintbuf[:size]
+ if err := s.readFull(buffer); err != nil {
+ return err
+ }
+ // Reject inputs where single byte encoding should have been used.
+ if size == 1 && buffer[0] < 128 {
+ return ErrCanonSize
+ }
+ default:
+ // For large integers, a temporary buffer is needed.
+ buffer = make([]byte, size)
+ if err := s.readFull(buffer); err != nil {
+ return err
+ }
+ }
+
+ // Reject leading zero bytes.
+ if len(buffer) > 0 && buffer[0] == 0 {
+ return ErrCanonInt
+ }
+ // Set the integer bytes.
+ dst.SetBytes(buffer)
+ return nil
+}
+
+// ReadUint256 decodes the next value as a uint256.
+func (s *Stream) ReadUint256(dst *uint256.Int) error {
+ var buffer []byte
+ kind, size, err := s.Kind()
+ switch {
+ case err != nil:
+ return err
+ case kind == List:
+ return ErrExpectedString
+ case kind == Byte:
+ buffer = s.uintbuf[:1]
+ buffer[0] = s.byteval
+ s.kind = -1 // re-arm Kind
+ case size == 0:
+ // Avoid zero-length read.
+ s.kind = -1
+ case size <= uint64(len(s.uintbuf)):
+ // All possible uint256 values fit into s.uintbuf.
+ buffer = s.uintbuf[:size]
+ if err := s.readFull(buffer); err != nil {
+ return err
+ }
+ // Reject inputs where single byte encoding should have been used.
+ if size == 1 && buffer[0] < 128 {
+ return ErrCanonSize
+ }
+ default:
+ return errUint256Large
+ }
+
+ // Reject leading zero bytes.
+ if len(buffer) > 0 && buffer[0] == 0 {
+ return ErrCanonInt
+ }
+ // Set the integer bytes.
+ dst.SetBytes(buffer)
+ return nil
+}
+
// Decode decodes a value and stores the result in the value pointed
// to by val. Please see the documentation for the Decode function
// to learn about the decoding rules.
@@ -809,14 +942,14 @@ func (s *Stream) Decode(val interface{}) error {
if rval.IsNil() {
return errDecodeIntoNil
}
- info, err := cachedTypeInfo(rtyp.Elem(), tags{})
+ decoder, err := cachedDecoder(rtyp.Elem())
if err != nil {
return err
}
- err = info.decoder(s, rval.Elem())
+ err = decoder(s, rval.Elem())
if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 {
- // add decode target type to error so context has more meaning
+ // Add decode target type to error so context has more meaning.
decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")"))
}
return err
@@ -839,6 +972,9 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) {
case *bytes.Reader:
s.remaining = uint64(br.Len())
s.limited = true
+ case *bytes.Buffer:
+ s.remaining = uint64(br.Len())
+ s.limited = true
case *strings.Reader:
s.remaining = uint64(br.Len())
s.limited = true
@@ -857,9 +993,8 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) {
s.size = 0
s.kind = -1
s.kinderr = nil
- if s.uintbuf == nil {
- s.uintbuf = make([]byte, 8)
- }
+ s.byteval = 0
+ s.uintbuf = [32]byte{}
}
// Kind returns the kind and size of the next value in the
@@ -874,35 +1009,29 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) {
// the value. Subsequent calls to Kind (until the value is decoded)
// will not advance the input reader and return cached information.
func (s *Stream) Kind() (kind Kind, size uint64, err error) {
- var tos *listpos
- if len(s.stack) > 0 {
- tos = &s.stack[len(s.stack)-1]
- }
- if s.kind < 0 {
- s.kinderr = nil
- // Don't read further if we're at the end of the
- // innermost list.
- if tos != nil && tos.pos == tos.size {
- return 0, 0, EOL
- }
- s.kind, s.size, s.kinderr = s.readKind()
- if s.kinderr == nil {
- if tos == nil {
- // At toplevel, check that the value is smaller
- // than the remaining input length.
- if s.limited && s.size > s.remaining {
- s.kinderr = ErrValueTooLarge
- }
- } else {
- // Inside a list, check that the value doesn't overflow the list.
- if s.size > tos.size-tos.pos {
- s.kinderr = ErrElemTooLarge
- }
- }
+ if s.kind >= 0 {
+ return s.kind, s.size, s.kinderr
+ }
+
+ // Check for end of list. This needs to be done here because readKind
+ // checks against the list size, and would return the wrong error.
+ inList, listLimit := s.listLimit()
+ if inList && listLimit == 0 {
+ return 0, 0, EOL
+ }
+ // Read the actual size tag.
+ s.kind, s.size, s.kinderr = s.readKind()
+ if s.kinderr == nil {
+ // Check the data size of the value ahead against input limits. This
+ // is done here because many decoders require allocating an input
+ // buffer matching the value size. Checking it here protects those
+ // decoders from inputs declaring very large value size.
+ if inList && s.size > listLimit {
+ s.kinderr = ErrElemTooLarge
+ } else if s.limited && s.size > s.remaining {
+ s.kinderr = ErrValueTooLarge
}
}
- // Note: this might return a sticky error generated
- // by an earlier call to readKind.
return s.kind, s.size, s.kinderr
}
@@ -929,37 +1058,35 @@ func (s *Stream) readKind() (kind Kind, size uint64, err error) {
s.byteval = b
return Byte, 0, nil
case b < 0xB8:
- // Otherwise, if a string is 0-55 bytes long,
- // the RLP encoding consists of a single byte with value 0x80 plus the
- // length of the string followed by the string. The range of the first
- // byte is thus [0x80, 0xB7].
+ // Otherwise, if a string is 0-55 bytes long, the RLP encoding consists
+ // of a single byte with value 0x80 plus the length of the string
+ // followed by the string. The range of the first byte is thus [0x80, 0xB7].
return String, uint64(b - 0x80), nil
case b < 0xC0:
- // If a string is more than 55 bytes long, the
- // RLP encoding consists of a single byte with value 0xB7 plus the length
- // of the length of the string in binary form, followed by the length of
- // the string, followed by the string. For example, a length-1024 string
- // would be encoded as 0xB90400 followed by the string. The range of
- // the first byte is thus [0xB8, 0xBF].
+ // If a string is more than 55 bytes long, the RLP encoding consists of a
+ // single byte with value 0xB7 plus the length of the length of the
+ // string in binary form, followed by the length of the string, followed
+ // by the string. For example, a length-1024 string would be encoded as
+ // 0xB90400 followed by the string. The range of the first byte is thus
+ // [0xB8, 0xBF].
size, err = s.readUint(b - 0xB7)
if err == nil && size < 56 {
err = ErrCanonSize
}
return String, size, err
case b < 0xF8:
- // If the total payload of a list
- // (i.e. the combined length of all its items) is 0-55 bytes long, the
- // RLP encoding consists of a single byte with value 0xC0 plus the length
- // of the list followed by the concatenation of the RLP encodings of the
- // items. The range of the first byte is thus [0xC0, 0xF7].
+ // If the total payload of a list (i.e. the combined length of all its
+ // items) is 0-55 bytes long, the RLP encoding consists of a single byte
+ // with value 0xC0 plus the length of the list followed by the
+ // concatenation of the RLP encodings of the items. The range of the
+ // first byte is thus [0xC0, 0xF7].
return List, uint64(b - 0xC0), nil
default:
- // If the total payload of a list is more than 55 bytes long,
- // the RLP encoding consists of a single byte with value 0xF7
- // plus the length of the length of the payload in binary
- // form, followed by the length of the payload, followed by
- // the concatenation of the RLP encodings of the items. The
- // range of the first byte is thus [0xF8, 0xFF].
+ // If the total payload of a list is more than 55 bytes long, the RLP
+ // encoding consists of a single byte with value 0xF7 plus the length of
+ // the length of the payload in binary form, followed by the length of
+ // the payload, followed by the concatenation of the RLP encodings of
+ // the items. The range of the first byte is thus [0xF8, 0xFF].
size, err = s.readUint(b - 0xF7)
if err == nil && size < 56 {
err = ErrCanonSize
@@ -977,23 +1104,24 @@ func (s *Stream) readUint(size byte) (uint64, error) {
b, err := s.readByte()
return uint64(b), err
default:
- start := int(8 - size)
- for i := 0; i < start; i++ {
- s.uintbuf[i] = 0
+ buffer := s.uintbuf[:8]
+ for i := range buffer {
+ buffer[i] = 0
}
- if err := s.readFull(s.uintbuf[start:]); err != nil {
+ start := int(8 - size)
+ if err := s.readFull(buffer[start:]); err != nil {
return 0, err
}
- if s.uintbuf[start] == 0 {
- // Note: readUint is also used to decode integer
- // values. The error needs to be adjusted to become
- // ErrCanonInt in this case.
+ if buffer[start] == 0 {
+ // Note: readUint is also used to decode integer values.
+ // The error needs to be adjusted to become ErrCanonInt in this case.
return 0, ErrCanonSize
}
- return binary.BigEndian.Uint64(s.uintbuf), nil
+ return binary.BigEndian.Uint64(buffer[:]), nil
}
}
+// readFull reads into buf from the underlying stream.
func (s *Stream) readFull(buf []byte) (err error) {
if err := s.willRead(uint64(len(buf))); err != nil {
return err
@@ -1004,11 +1132,18 @@ func (s *Stream) readFull(buf []byte) (err error) {
n += nn
}
if err == io.EOF {
- err = io.ErrUnexpectedEOF
+ if n < len(buf) {
+ err = io.ErrUnexpectedEOF
+ } else {
+ // Readers are allowed to give EOF even though the read succeeded.
+ // In such cases, we discard the EOF, like io.ReadFull() does.
+ err = nil
+ }
}
return err
}
+// readByte reads a single byte from the underlying stream.
func (s *Stream) readByte() (byte, error) {
if err := s.willRead(1); err != nil {
return 0, err
@@ -1020,16 +1155,16 @@ func (s *Stream) readByte() (byte, error) {
return b, err
}
+// willRead is called before any read from the underlying stream. It checks
+// n against size limits, and updates the limits if n doesn't overflow them.
func (s *Stream) willRead(n uint64) error {
s.kind = -1 // rearm Kind
- if len(s.stack) > 0 {
- // check list overflow
- tos := s.stack[len(s.stack)-1]
- if n > tos.size-tos.pos {
+ if inList, limit := s.listLimit(); inList {
+ if n > limit {
return ErrElemTooLarge
}
- s.stack[len(s.stack)-1].pos += n
+ s.stack[len(s.stack)-1] = limit - n
}
if s.limited {
if n > s.remaining {
@@ -1039,3 +1174,11 @@ func (s *Stream) willRead(n uint64) error {
}
return nil
}
+
+// listLimit returns the amount of data remaining in the innermost list.
+func (s *Stream) listLimit() (inList bool, limit uint64) {
+ if len(s.stack) == 0 {
+ return false, 0
+ }
+ return true, s.stack[len(s.stack)-1]
+}
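
The Stream pooling inside Decode and DecodeBytes relies on Reset rearming all internal state (including the new fixed-size uintbuf). Callers can use the same pattern to amortize allocations; a sketch with illustrative inputs:

    package main

    import (
        "bytes"
        "fmt"

        "github.com/tomochain/tomochain/rlp"
    )

    func main() {
        s := new(rlp.Stream)
        // RLP encodings of 1 and 1024.
        for _, enc := range [][]byte{{0x01}, {0x82, 0x04, 0x00}} {
            var x uint64
            s.Reset(bytes.NewReader(enc), uint64(len(enc))) // rearm for fresh input
            if err := s.Decode(&x); err != nil {
                panic(err)
            }
            fmt.Println(x) // 1, then 1024
        }
    }
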
diff --git a/rlp/decode_test.go b/rlp/decode_test.go
index 4d8abd001..3ee237fb0 100644
--- a/rlp/decode_test.go
+++ b/rlp/decode_test.go
@@ -26,6 +26,10 @@ import (
"reflect"
"strings"
"testing"
+
+ "github.com/tomochain/tomochain/common/math"
+
+ "github.com/holiman/uint256"
)
func TestStreamKind(t *testing.T) {
@@ -284,6 +288,47 @@ func TestStreamRaw(t *testing.T) {
}
}
+func TestStreamReadBytes(t *testing.T) {
+ tests := []struct {
+ input string
+ size int
+ err string
+ }{
+ // kind List
+ {input: "C0", size: 1, err: "rlp: expected String or Byte"},
+ // kind Byte
+ {input: "04", size: 0, err: "input value has wrong size 1, want 0"},
+ {input: "04", size: 1},
+ {input: "04", size: 2, err: "input value has wrong size 1, want 2"},
+ // kind String
+ {input: "820102", size: 0, err: "input value has wrong size 2, want 0"},
+ {input: "820102", size: 1, err: "input value has wrong size 2, want 1"},
+ {input: "820102", size: 2},
+ {input: "820102", size: 3, err: "input value has wrong size 2, want 3"},
+ }
+
+ for _, test := range tests {
+ test := test
+ name := fmt.Sprintf("input_%s/size_%d", test.input, test.size)
+ t.Run(name, func(t *testing.T) {
+ s := NewStream(bytes.NewReader(unhex(test.input)), 0)
+ b := make([]byte, test.size)
+ err := s.ReadBytes(b)
+ if test.err == "" {
+ if err != nil {
+ t.Errorf("unexpected error %q", err)
+ }
+ } else {
+ if err == nil {
+ t.Errorf("expected error, got nil")
+ } else if err.Error() != test.err {
+ t.Errorf("wrong error %q", err)
+ }
+ }
+ })
+ }
+}
+
func TestDecodeErrors(t *testing.T) {
r := bytes.NewReader(nil)
@@ -327,6 +372,15 @@ type recstruct struct {
Child *recstruct `rlp:"nil"`
}
+type bigIntStruct struct {
+ I *big.Int
+ B string
+}
+
+type invalidNilTag struct {
+ X []byte `rlp:"nil"`
+}
+
type invalidTail1 struct {
A uint `rlp:"tail"`
B string
@@ -347,19 +401,79 @@ type tailUint struct {
Tail []uint `rlp:"tail"`
}
-var (
- veryBigInt = big.NewInt(0).Add(
- big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16),
- big.NewInt(0xFFFF),
- )
-)
+type tailPrivateFields struct {
+ A uint
+ Tail []uint `rlp:"tail"`
+ x, y bool //lint:ignore U1000 unused fields required for testing purposes.
+}
+
+type nilListUint struct {
+ X *uint `rlp:"nilList"`
+}
+
+type nilStringSlice struct {
+ X *[]uint `rlp:"nilString"`
+}
+
+type intField struct {
+ X int
+}
+
+type optionalFields struct {
+ A uint
+ B uint `rlp:"optional"`
+ C uint `rlp:"optional"`
+}
+
+type optionalAndTailField struct {
+ A uint
+ B uint `rlp:"optional"`
+ Tail []uint `rlp:"tail"`
+}
+
+type optionalBigIntField struct {
+ A uint
+ B *big.Int `rlp:"optional"`
+}
+
+type optionalPtrField struct {
+ A uint
+ B *[3]byte `rlp:"optional"`
+}
+
+type nonOptionalPtrField struct {
+ A uint
+ B *[3]byte
+}
-type hasIgnoredField struct {
+type multipleOptionalFields struct {
+ A *[3]byte `rlp:"optional"`
+ B *[3]byte `rlp:"optional"`
+}
+
+type optionalPtrFieldNil struct {
+ A uint
+ B *[3]byte `rlp:"optional,nil"`
+}
+
+type ignoredField struct {
A uint
B uint `rlp:"-"`
C uint
}
+var (
+ veryBigInt = new(big.Int).Add(
+ new(big.Int).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16),
+ big.NewInt(0xFFFF),
+ )
+ veryVeryBigInt = new(big.Int).Exp(veryBigInt, big.NewInt(8), nil)
+)
+
+var (
+ veryBigInt256, _ = uint256.FromBig(veryBigInt)
+)
+
var decodeTests = []decodeTest{
// booleans
{input: "01", ptr: new(bool), value: true},
@@ -428,12 +542,31 @@ var decodeTests = []decodeTest{
{input: "C0", ptr: new(string), error: "rlp: expected input string or byte for string"},
// big ints
+ {input: "80", ptr: new(*big.Int), value: big.NewInt(0)},
{input: "01", ptr: new(*big.Int), value: big.NewInt(1)},
{input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt},
+ {input: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001", ptr: new(*big.Int), value: veryVeryBigInt},
{input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works
+
+ // big int errors
{input: "C0", ptr: new(*big.Int), error: "rlp: expected input string or byte for *big.Int"},
- {input: "820001", ptr: new(big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"},
- {input: "8105", ptr: new(big.Int), error: "rlp: non-canonical size information for *big.Int"},
+ {input: "00", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"},
+ {input: "820001", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"},
+ {input: "8105", ptr: new(*big.Int), error: "rlp: non-canonical size information for *big.Int"},
+
+ // uint256
+ {input: "80", ptr: new(*uint256.Int), value: uint256.NewInt(0)},
+ {input: "01", ptr: new(*uint256.Int), value: uint256.NewInt(1)},
+ {input: "88FFFFFFFFFFFFFFFF", ptr: new(*uint256.Int), value: uint256.NewInt(math.MaxUint64)},
+ {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*uint256.Int), value: veryBigInt256},
+ {input: "10", ptr: new(uint256.Int), value: *uint256.NewInt(16)}, // non-pointer also works
+
+ // uint256 errors
+ {input: "C0", ptr: new(*uint256.Int), error: "rlp: expected input string or byte for *uint256.Int"},
+ {input: "00", ptr: new(*uint256.Int), error: "rlp: non-canonical integer (leading zero bytes) for *uint256.Int"},
+ {input: "820001", ptr: new(*uint256.Int), error: "rlp: non-canonical integer (leading zero bytes) for *uint256.Int"},
+ {input: "8105", ptr: new(*uint256.Int), error: "rlp: non-canonical size information for *uint256.Int"},
+ {input: "A1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00", ptr: new(*uint256.Int), error: "rlp: value too large for uint256"},
// structs
{
@@ -446,6 +579,13 @@ var decodeTests = []decodeTest{
ptr: new(recstruct),
value: recstruct{1, &recstruct{2, &recstruct{3, nil}}},
},
+ {
+ // This checks that empty big.Int works correctly in struct context. It's easy to
+ // miss the update of s.kind for this case, so it needs its own test.
+ input: "C58083343434",
+ ptr: new(bigIntStruct),
+ value: bigIntStruct{new(big.Int), "444"},
+ },
// struct errors
{
@@ -479,20 +619,20 @@ var decodeTests = []decodeTest{
error: "rlp: expected input string or byte for uint, decoding into (rlp.recstruct).Child.I",
},
{
- input: "C0",
- ptr: new(invalidTail1),
- error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail1.A (must be on last field)",
- },
- {
- input: "C0",
- ptr: new(invalidTail2),
- error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail2.B (field type is not slice)",
+ input: "C103",
+ ptr: new(intField),
+ error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)",
},
{
input: "C50102C20102",
ptr: new(tailUint),
error: "rlp: expected input string or byte for uint, decoding into (rlp.tailUint).Tail[1]",
},
+ {
+ input: "C0",
+ ptr: new(invalidNilTag),
+ error: `rlp: invalid struct tag "nil" for rlp.invalidNilTag.X (field is not a pointer)`,
+ },
// struct tag "tail"
{
@@ -510,12 +650,192 @@ var decodeTests = []decodeTest{
ptr: new(tailRaw),
value: tailRaw{A: 1, Tail: []RawValue{}},
},
+ {
+ input: "C3010203",
+ ptr: new(tailPrivateFields),
+ value: tailPrivateFields{A: 1, Tail: []uint{2, 3}},
+ },
+ {
+ input: "C0",
+ ptr: new(invalidTail1),
+ error: `rlp: invalid struct tag "tail" for rlp.invalidTail1.A (must be on last field)`,
+ },
+ {
+ input: "C0",
+ ptr: new(invalidTail2),
+ error: `rlp: invalid struct tag "tail" for rlp.invalidTail2.B (field type is not slice)`,
+ },
// struct tag "-"
{
input: "C20102",
- ptr: new(hasIgnoredField),
- value: hasIgnoredField{A: 1, C: 2},
+ ptr: new(ignoredField),
+ value: ignoredField{A: 1, C: 2},
+ },
+
+ // struct tag "nilList"
+ {
+ input: "C180",
+ ptr: new(nilListUint),
+ error: "rlp: wrong kind of empty value (got String, want List) for *uint, decoding into (rlp.nilListUint).X",
+ },
+ {
+ input: "C1C0",
+ ptr: new(nilListUint),
+ value: nilListUint{},
+ },
+ {
+ input: "C103",
+ ptr: new(nilListUint),
+ value: func() interface{} {
+ v := uint(3)
+ return nilListUint{X: &v}
+ }(),
+ },
+
+ // struct tag "nilString"
+ {
+ input: "C1C0",
+ ptr: new(nilStringSlice),
+ error: "rlp: wrong kind of empty value (got List, want String) for *[]uint, decoding into (rlp.nilStringSlice).X",
+ },
+ {
+ input: "C180",
+ ptr: new(nilStringSlice),
+ value: nilStringSlice{},
+ },
+ {
+ input: "C2C103",
+ ptr: new(nilStringSlice),
+ value: nilStringSlice{X: &[]uint{3}},
+ },
+
+ // struct tag "optional"
+ {
+ input: "C101",
+ ptr: new(optionalFields),
+ value: optionalFields{1, 0, 0},
+ },
+ {
+ input: "C20102",
+ ptr: new(optionalFields),
+ value: optionalFields{1, 2, 0},
+ },
+ {
+ input: "C3010203",
+ ptr: new(optionalFields),
+ value: optionalFields{1, 2, 3},
+ },
+ {
+ input: "C401020304",
+ ptr: new(optionalFields),
+ error: "rlp: input list has too many elements for rlp.optionalFields",
+ },
+ {
+ input: "C101",
+ ptr: new(optionalAndTailField),
+ value: optionalAndTailField{A: 1},
+ },
+ {
+ input: "C20102",
+ ptr: new(optionalAndTailField),
+ value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}},
+ },
+ {
+ input: "C401020304",
+ ptr: new(optionalAndTailField),
+ value: optionalAndTailField{A: 1, B: 2, Tail: []uint{3, 4}},
+ },
+ {
+ input: "C101",
+ ptr: new(optionalBigIntField),
+ value: optionalBigIntField{A: 1, B: nil},
+ },
+ {
+ input: "C20102",
+ ptr: new(optionalBigIntField),
+ value: optionalBigIntField{A: 1, B: big.NewInt(2)},
+ },
+ {
+ input: "C101",
+ ptr: new(optionalPtrField),
+ value: optionalPtrField{A: 1},
+ },
+ {
+ input: "C20180", // not accepted because "optional" doesn't enable "nil"
+ ptr: new(optionalPtrField),
+ error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B",
+ },
+ {
+ input: "C20102",
+ ptr: new(optionalPtrField),
+ error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B",
+ },
+ {
+ input: "C50183010203",
+ ptr: new(optionalPtrField),
+ value: optionalPtrField{A: 1, B: &[3]byte{1, 2, 3}},
+ },
+ {
+ // all optional fields nil
+ input: "C0",
+ ptr: new(multipleOptionalFields),
+ value: multipleOptionalFields{A: nil, B: nil},
+ },
+ {
+ // all optional fields set
+ input: "C88301020383010203",
+ ptr: new(multipleOptionalFields),
+ value: multipleOptionalFields{A: &[3]byte{1, 2, 3}, B: &[3]byte{1, 2, 3}},
+ },
+ {
+ // nil optional field appears before a non-nil one
+ input: "C58083010203",
+ ptr: new(multipleOptionalFields),
+ error: "rlp: input string too short for [3]uint8, decoding into (rlp.multipleOptionalFields).A",
+ },
+ {
+ // decode a nil ptr into a ptr that is not nil or not optional
+ input: "C20180",
+ ptr: new(nonOptionalPtrField),
+ error: "rlp: input string too short for [3]uint8, decoding into (rlp.nonOptionalPtrField).B",
+ },
+ {
+ input: "C101",
+ ptr: new(optionalPtrFieldNil),
+ value: optionalPtrFieldNil{A: 1},
+ },
+ {
+ input: "C20180", // accepted because "nil" tag allows empty input
+ ptr: new(optionalPtrFieldNil),
+ value: optionalPtrFieldNil{A: 1},
+ },
+ {
+ input: "C20102",
+ ptr: new(optionalPtrFieldNil),
+ error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrFieldNil).B",
+ },
+
+ // struct tag "optional" field clearing
+ {
+ input: "C101",
+ ptr: &optionalFields{A: 9, B: 8, C: 7},
+ value: optionalFields{A: 1, B: 0, C: 0},
+ },
+ {
+ input: "C20102",
+ ptr: &optionalFields{A: 9, B: 8, C: 7},
+ value: optionalFields{A: 1, B: 2, C: 0},
+ },
+ {
+ input: "C20102",
+ ptr: &optionalAndTailField{A: 9, B: 8, Tail: []uint{7, 6, 5}},
+ value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}},
+ },
+ {
+ input: "C101",
+ ptr: &optionalPtrField{A: 9, B: &[3]byte{8, 7, 6}},
+ value: optionalPtrField{A: 1},
},
// RawValue
@@ -591,6 +911,26 @@ func TestDecodeWithByteReader(t *testing.T) {
})
}
+func testDecodeWithEncReader(t *testing.T, n int) {
+ s := strings.Repeat("0", n)
+ _, r, _ := EncodeToReader(s)
+ var decoded string
+ err := Decode(r, &decoded)
+ if err != nil {
+ t.Errorf("Unexpected decode error with n=%v: %v", n, err)
+ }
+ if decoded != s {
+ t.Errorf("Decode mismatch with n=%v", n)
+ }
+}
+
+// This is a regression test checking that decoding from encReader
+// works for RLP values of size 8192 bytes or more.
+func TestDecodeWithEncReader(t *testing.T) {
+ testDecodeWithEncReader(t, 8188) // length with header is 8191
+ testDecodeWithEncReader(t, 8189) // length with header is 8192
+}
+
// plainReader reads from a byte slice but does not
// implement ReadByte. It is also not recognized by the
// size validation. This is useful to test how the decoder
@@ -661,6 +1001,22 @@ func TestDecodeDecoder(t *testing.T) {
}
}
+func TestDecodeDecoderNilPointer(t *testing.T) {
+ var s struct {
+ T1 *testDecoder `rlp:"nil"`
+ T2 *testDecoder
+ }
+ if err := Decode(bytes.NewReader(unhex("C2C002")), &s); err != nil {
+ t.Fatalf("Decode error: %v", err)
+ }
+ if s.T1 != nil {
+ t.Errorf("decoder T1 allocated for empty input (called: %v)", s.T1.called)
+ }
+ if s.T2 == nil || !s.T2.called {
+ t.Errorf("decoder T2 not allocated/called")
+ }
+}
+
type byteDecoder byte
func (bd *byteDecoder) DecodeRLP(s *Stream) error {
@@ -691,13 +1047,66 @@ func TestDecoderInByteSlice(t *testing.T) {
}
}
+type unencodableDecoder func()
+
+func (f *unencodableDecoder) DecodeRLP(s *Stream) error {
+ if _, err := s.List(); err != nil {
+ return err
+ }
+ if err := s.ListEnd(); err != nil {
+ return err
+ }
+ *f = func() {}
+ return nil
+}
+
+func TestDecoderFunc(t *testing.T) {
+ var x func()
+ if err := DecodeBytes([]byte{0xC0}, (*unencodableDecoder)(&x)); err != nil {
+ t.Fatal(err)
+ }
+ x()
+}
+
+// This tests the validity checks for fields with struct tag "optional".
+func TestInvalidOptionalField(t *testing.T) {
+ type (
+ invalid1 struct {
+ A uint `rlp:"optional"`
+ B uint
+ }
+ invalid2 struct {
+ T []uint `rlp:"tail,optional"`
+ }
+ invalid3 struct {
+ T []uint `rlp:"optional,tail"`
+ }
+ )
+
+ tests := []struct {
+ v interface{}
+ err string
+ }{
+ {v: new(invalid1), err: `rlp: invalid struct tag "" for rlp.invalid1.B (must be optional because preceding field "A" is optional)`},
+ {v: new(invalid2), err: `rlp: invalid struct tag "optional" for rlp.invalid2.T (also has "tail" tag)`},
+ {v: new(invalid3), err: `rlp: invalid struct tag "tail" for rlp.invalid3.T (also has "optional" tag)`},
+ }
+ for _, test := range tests {
+ err := DecodeBytes(unhex("C20102"), test.v)
+ if err == nil {
+ t.Errorf("no error for %T", test.v)
+ } else if err.Error() != test.err {
+ t.Errorf("wrong error for %T: %v", test.v, err.Error())
+ }
+ }
+}
+
func ExampleDecode() {
input, _ := hex.DecodeString("C90A1486666F6F626172")
type example struct {
- A, B uint
- private uint // private fields are ignored
- String string
+ A, B uint
+ String string
}
var s example
@@ -708,7 +1117,7 @@ func ExampleDecode() {
fmt.Printf("Decoded value: %#v\n", s)
}
// Output:
- // Decoded value: rlp.example{A:0xa, B:0x14, private:0x0, String:"foobar"}
+ // Decoded value: rlp.example{A:0xa, B:0x14, String:"foobar"}
}
func ExampleDecode_structTagNil() {
@@ -768,7 +1177,7 @@ func ExampleStream() {
// [102 111 111 98 97 114]
}
-func BenchmarkDecode(b *testing.B) {
+func BenchmarkDecodeUints(b *testing.B) {
enc := encodeTestSlice(90000)
b.SetBytes(int64(len(enc)))
b.ReportAllocs()
@@ -783,7 +1192,7 @@ func BenchmarkDecode(b *testing.B) {
}
}
-func BenchmarkDecodeIntSliceReuse(b *testing.B) {
+func BenchmarkDecodeUintsReused(b *testing.B) {
enc := encodeTestSlice(100000)
b.SetBytes(int64(len(enc)))
b.ReportAllocs()
@@ -798,6 +1207,65 @@ func BenchmarkDecodeIntSliceReuse(b *testing.B) {
}
}
+func BenchmarkDecodeByteArrayStruct(b *testing.B) {
+ enc, err := EncodeToBytes(&byteArrayStruct{})
+ if err != nil {
+ b.Fatal(err)
+ }
+ b.SetBytes(int64(len(enc)))
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ var out byteArrayStruct
+ for i := 0; i < b.N; i++ {
+ if err := DecodeBytes(enc, &out); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkDecodeBigInts(b *testing.B) {
+ ints := make([]*big.Int, 200)
+ for i := range ints {
+ ints[i] = math.BigPow(2, int64(i))
+ }
+ enc, err := EncodeToBytes(ints)
+ if err != nil {
+ b.Fatal(err)
+ }
+ b.SetBytes(int64(len(enc)))
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ var out []*big.Int
+ for i := 0; i < b.N; i++ {
+ if err := DecodeBytes(enc, &out); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkDecodeU256Ints(b *testing.B) {
+ ints := make([]*uint256.Int, 200)
+ for i := range ints {
+ ints[i], _ = uint256.FromBig(math.BigPow(2, int64(i)))
+ }
+ enc, err := EncodeToBytes(ints)
+ if err != nil {
+ b.Fatal(err)
+ }
+ b.SetBytes(int64(len(enc)))
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ var out []*uint256.Int
+ for i := 0; i < b.N; i++ {
+ if err := DecodeBytes(enc, &out); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
func encodeTestSlice(n uint) []byte {
s := make([]uint, n)
for i := uint(0); i < n; i++ {
@@ -811,7 +1279,7 @@ func encodeTestSlice(n uint) []byte {
}
func unhex(str string) []byte {
- b, err := hex.DecodeString(strings.Replace(str, " ", "", -1))
+ b, err := hex.DecodeString(strings.ReplaceAll(str, " ", ""))
if err != nil {
panic(fmt.Sprintf("invalid hex string: %q", str))
}
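
The uint256 benchmarks above assume uint256.Int round-trips like any other unsigned integer; a quick illustrative fragment (not part of the patch):

    v := uint256.NewInt(1024)
    enc, _ := rlp.EncodeToBytes(v) // 820400, identical to a uint64 holding 1024
    var out uint256.Int
    _ = rlp.DecodeBytes(enc, &out) // out.Uint64() == 1024
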
diff --git a/rlp/doc.go b/rlp/doc.go
index b3a81fe23..eeeee9a43 100644
--- a/rlp/doc.go
+++ b/rlp/doc.go
@@ -17,17 +17,142 @@
/*
Package rlp implements the RLP serialization format.
-The purpose of RLP (Recursive Linear Prefix) is to encode arbitrarily
-nested arrays of binary data, and RLP is the main encoding method used
-to serialize objects in Ethereum. The only purpose of RLP is to encode
-structure; encoding specific atomic data types (eg. strings, ints,
-floats) is left up to higher-order protocols; in Ethereum integers
-must be represented in big endian binary form with no leading zeroes
-(thus making the integer value zero equivalent to the empty byte
-array).
-
-RLP values are distinguished by a type tag. The type tag precedes the
-value in the input stream and defines the size and kind of the bytes
-that follow.
+The purpose of RLP (Recursive Linear Prefix) is to encode arbitrarily nested arrays of
+binary data, and RLP is the main encoding method used to serialize objects in Ethereum.
+The only purpose of RLP is to encode structure; encoding specific atomic data types (e.g.
+strings, ints, floats) is left up to higher-order protocols. In Ethereum, integers must be
+represented in big endian binary form with no leading zeroes (thus making the integer
+value zero equivalent to the empty string).
+
+RLP values are distinguished by a type tag. The type tag precedes the value in the input
+stream and defines the size and kind of the bytes that follow.
+
+# Encoding Rules
+
+Package rlp uses reflection and encodes RLP based on the Go type of the value.
+
+If the type implements the Encoder interface, Encode calls EncodeRLP. It does not
+call EncodeRLP on nil pointer values.
+
+To encode a pointer, the value being pointed to is encoded. A nil pointer to a struct
+type, slice or array always encodes as an empty RLP list unless the slice or array has
+element type byte. A nil pointer to any other value encodes as the empty string.
+
+Struct values are encoded as an RLP list of all their encoded public fields. Recursive
+struct types are supported.
+
+To encode slices and arrays, the elements are encoded as an RLP list of the value's
+elements. Note that arrays and slices with element type uint8 or byte are always encoded
+as an RLP string.
+
+A Go string is encoded as an RLP string.
+
+An unsigned integer value is encoded as an RLP string. Zero always encodes as an empty RLP
+string. big.Int values are treated as integers. Signed integers (int, int8, int16, ...)
+are not supported and will return an error when encoding.
+
+Boolean values are encoded as the unsigned integers zero (false) and one (true).
+
+An interface value encodes as the value contained in the interface.
+
+Floating point numbers, maps, channels and functions are not supported.
+
+# Decoding Rules
+
+Decoding uses the following type-dependent rules:
+
+If the type implements the Decoder interface, DecodeRLP is called.
+
+To decode into a pointer, the value will be decoded as the element type of the pointer. If
+the pointer is nil, a new value of the pointer's element type is allocated. If the pointer
+is non-nil, the existing value will be reused. Note that package rlp never leaves a
+pointer-type struct field as nil unless one of the "nil" struct tags is present.
+
+To decode into a struct, decoding expects the input to be an RLP list. The decoded
+elements of the list are assigned to each public field in the order given by the struct's
+definition. The input list must contain an element for each decoded field. Decoding
+returns an error if there are too few or too many elements for the struct.
+
+To decode into a slice, the input must be a list and the resulting slice will contain the
+input elements in order. For byte slices, the input must be an RLP string. Array types
+decode similarly, with the additional restriction that the number of input elements (or
+bytes) must match the array's defined length.
+
+To decode into a Go string, the input must be an RLP string. The input bytes are taken
+as-is and will not necessarily be valid UTF-8.
+
+To decode into an unsigned integer type, the input must also be an RLP string. The bytes
+are interpreted as a big endian representation of the integer. If the RLP string is larger
+than the bit size of the type, decoding will return an error. Decode also supports
+*big.Int. There is no size limit for big integers.
+
+To decode into a boolean, the input must contain an unsigned integer of value zero (false)
+or one (true).
+
+To decode into an interface value, one of these types is stored in the value:
+
+ []interface{}, for RLP lists
+ []byte, for RLP strings
+
+Non-empty interface types are not supported when decoding.
+Signed integers, floating point numbers, maps, channels and functions cannot be decoded into.
+
+# Struct Tags
+
+As with other encoding packages, the "-" tag ignores fields.
+
+ type StructWithIgnoredField struct{
+ Ignored uint `rlp:"-"`
+ Field uint
+ }
+
+Go struct values encode/decode as RLP lists. There are two ways of influencing the mapping
+of fields to list elements. The "tail" tag, which may only be used on the last exported
+struct field, allows slurping up any excess list elements into a slice.
+
+ type StructWithTail struct{
+ Field uint
+ Tail []string `rlp:"tail"`
+ }
+
+The "optional" tag says that the field may be omitted if it is zero-valued. If this tag is
+used on a struct field, all subsequent public fields must also be declared optional.
+
+When encoding a struct with optional fields, the output RLP list contains all values up to
+the last non-zero optional field.
+
+When decoding into a struct, optional fields may be omitted from the end of the input
+list. For the example below, this means input lists of one, two, or three elements are
+accepted.
+
+ type StructWithOptionalFields struct{
+ Required uint
+ Optional1 uint `rlp:"optional"`
+ Optional2 uint `rlp:"optional"`
+ }
+
+The "nil", "nilList" and "nilString" tags apply to pointer-typed fields only, and change
+the decoding rules for the field type. For regular pointer fields without the "nil" tag,
+input values must always match the required input length exactly and the decoder does not
+produce nil values. When the "nil" tag is set, input values of size zero decode as a nil
+pointer. This is especially useful for recursive types.
+
+ type StructWithNilField struct {
+ Field *[3]byte `rlp:"nil"`
+ }
+
+In the example above, Field allows two possible input sizes. For input 0xC180 (a list
+containing an empty string) Field is set to nil after decoding. For input 0xC483000000 (a
+list containing a 3-byte string), Field is set to a non-nil array pointer.
+
+RLP supports two kinds of empty values: empty lists and empty strings. When using the
+"nil" tag, the kind of empty value allowed for a type is chosen automatically. A field
+whose Go type is a pointer to an unsigned integer, string, boolean or byte array/slice
+expects an empty RLP string. Any other pointer field type encodes/decodes as an empty RLP
+list.
+
+The choice of null value can be made explicit with the "nilList" and "nilString" struct
+tags. Using these tags encodes/decodes a Go nil pointer value as the empty RLP value kind
+defined by the tag.
*/
package rlp
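
To make the "optional" rules above concrete, here is a small sketch (record and encodeBoth are illustrative names, not part of the patch); the expected outputs match the cases in encode_test.go:

    type record struct {
        Required  uint
        Optional1 uint `rlp:"optional"`
        Optional2 uint `rlp:"optional"`
    }

    func encodeBoth() {
        a, _ := rlp.EncodeToBytes(record{Required: 1})               // C101: both optionals trimmed
        b, _ := rlp.EncodeToBytes(record{Required: 1, Optional2: 3}) // C3018003: interior zero kept
        fmt.Printf("%X %X\n", a, b)
    }
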
diff --git a/rlp/encbuffer.go b/rlp/encbuffer.go
new file mode 100644
index 000000000..8d3a3b229
--- /dev/null
+++ b/rlp/encbuffer.go
@@ -0,0 +1,423 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp
+
+import (
+ "encoding/binary"
+ "io"
+ "math/big"
+ "reflect"
+ "sync"
+
+ "github.com/holiman/uint256"
+)
+
+type encBuffer struct {
+ str []byte // string data, contains everything except list headers
+ lheads []listhead // all list headers
+ lhsize int // sum of sizes of all encoded list headers
+ sizebuf [9]byte // auxiliary buffer for uint encoding
+}
+
+// The global encBuffer pool.
+var encBufferPool = sync.Pool{
+ New: func() interface{} { return new(encBuffer) },
+}
+
+func getEncBuffer() *encBuffer {
+ buf := encBufferPool.Get().(*encBuffer)
+ buf.reset()
+ return buf
+}
+
+func (buf *encBuffer) reset() {
+ buf.lhsize = 0
+ buf.str = buf.str[:0]
+ buf.lheads = buf.lheads[:0]
+}
+
+// size returns the length of the encoded data.
+func (buf *encBuffer) size() int {
+ return len(buf.str) + buf.lhsize
+}
+
+// makeBytes creates the encoder output.
+func (buf *encBuffer) makeBytes() []byte {
+ out := make([]byte, buf.size())
+ buf.copyTo(out)
+ return out
+}
+
+func (buf *encBuffer) copyTo(dst []byte) {
+ strpos := 0
+ pos := 0
+ for _, head := range buf.lheads {
+ // write string data before header
+ n := copy(dst[pos:], buf.str[strpos:head.offset])
+ pos += n
+ strpos += n
+ // write the header
+ enc := head.encode(dst[pos:])
+ pos += len(enc)
+ }
+ // copy string data after the last list header
+ copy(dst[pos:], buf.str[strpos:])
+}
+
+// writeTo writes the encoder output to w.
+func (buf *encBuffer) writeTo(w io.Writer) (err error) {
+ strpos := 0
+ for _, head := range buf.lheads {
+ // write string data before header
+ if head.offset-strpos > 0 {
+ n, err := w.Write(buf.str[strpos:head.offset])
+ strpos += n
+ if err != nil {
+ return err
+ }
+ }
+ // write the header
+ enc := head.encode(buf.sizebuf[:])
+ if _, err = w.Write(enc); err != nil {
+ return err
+ }
+ }
+ if strpos < len(buf.str) {
+ // write string data after the last list header
+ _, err = w.Write(buf.str[strpos:])
+ }
+ return err
+}
+
+// Write implements io.Writer and appends b directly to the output.
+func (buf *encBuffer) Write(b []byte) (int, error) {
+ buf.str = append(buf.str, b...)
+ return len(b), nil
+}
+
+// writeBool writes b as the integer 0 (false) or 1 (true).
+func (buf *encBuffer) writeBool(b bool) {
+ if b {
+ buf.str = append(buf.str, 0x01)
+ } else {
+ buf.str = append(buf.str, 0x80)
+ }
+}
+
+func (buf *encBuffer) writeUint64(i uint64) {
+ if i == 0 {
+ buf.str = append(buf.str, 0x80)
+ } else if i < 128 {
+ // fits single byte
+ buf.str = append(buf.str, byte(i))
+ } else {
+ s := putint(buf.sizebuf[1:], i)
+ buf.sizebuf[0] = 0x80 + byte(s)
+ buf.str = append(buf.str, buf.sizebuf[:s+1]...)
+ }
+}
+
+func (buf *encBuffer) writeBytes(b []byte) {
+ if len(b) == 1 && b[0] <= 0x7F {
+ // fits single byte, no string header
+ buf.str = append(buf.str, b[0])
+ } else {
+ buf.encodeStringHeader(len(b))
+ buf.str = append(buf.str, b...)
+ }
+}
+
+func (buf *encBuffer) writeString(s string) {
+ buf.writeBytes([]byte(s))
+}
+
+// wordBytes is the number of bytes in a big.Word
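+// It evaluates to 8 on 64-bit platforms and 4 on 32-bit ones:
+// uint64(^big.Word(0)) >> 63 is 1 exactly when big.Word is 64 bits wide.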
+const wordBytes = (32 << (uint64(^big.Word(0)) >> 63)) / 8
+
+// writeBigInt writes i as an integer.
+func (buf *encBuffer) writeBigInt(i *big.Int) {
+ bitlen := i.BitLen()
+ if bitlen <= 64 {
+ buf.writeUint64(i.Uint64())
+ return
+ }
+ // Integer is larger than 64 bits, encode from i.Bits().
+ // The minimal byte length is bitlen rounded up to the next
+ // multiple of 8, divided by 8.
+ length := ((bitlen + 7) & -8) >> 3
+ buf.encodeStringHeader(length)
+ buf.str = append(buf.str, make([]byte, length)...)
+ index := length
+ bytesBuf := buf.str[len(buf.str)-length:]
+ for _, d := range i.Bits() {
+ for j := 0; j < wordBytes && index > 0; j++ {
+ index--
+ bytesBuf[index] = byte(d)
+ d >>= 8
+ }
+ }
+}
+
+// writeUint256 writes z as an integer.
+func (buf *encBuffer) writeUint256(z *uint256.Int) {
+ bitlen := z.BitLen()
+ if bitlen <= 64 {
+ buf.writeUint64(z.Uint64())
+ return
+ }
+ nBytes := byte((bitlen + 7) / 8)
+ var b [33]byte
+ binary.BigEndian.PutUint64(b[1:9], z[3])
+ binary.BigEndian.PutUint64(b[9:17], z[2])
+ binary.BigEndian.PutUint64(b[17:25], z[1])
+ binary.BigEndian.PutUint64(b[25:33], z[0])
+ b[32-nBytes] = 0x80 + nBytes
+ buf.str = append(buf.str, b[32-nBytes:]...)
+}
+
+// list adds a new list header to the header stack. It returns the index of the header.
+// Call listEnd with this index after encoding the content of the list.
+func (buf *encBuffer) list() int {
+ buf.lheads = append(buf.lheads, listhead{offset: len(buf.str), size: buf.lhsize})
+ return len(buf.lheads) - 1
+}
+
+func (buf *encBuffer) listEnd(index int) {
+ lh := &buf.lheads[index]
+ lh.size = buf.size() - lh.offset - lh.size
+ if lh.size < 56 {
+ buf.lhsize++ // length encoded into kind tag
+ } else {
+ buf.lhsize += 1 + intsize(uint64(lh.size))
+ }
+}
+
+func (buf *encBuffer) encode(val interface{}) error {
+ rval := reflect.ValueOf(val)
+ writer, err := cachedWriter(rval.Type())
+ if err != nil {
+ return err
+ }
+ return writer(rval, buf)
+}
+
+func (buf *encBuffer) encodeStringHeader(size int) {
+ if size < 56 {
+ buf.str = append(buf.str, 0x80+byte(size))
+ } else {
+ sizesize := putint(buf.sizebuf[1:], uint64(size))
+ buf.sizebuf[0] = 0xB7 + byte(sizesize)
+ buf.str = append(buf.str, buf.sizebuf[:sizesize+1]...)
+ }
+}
+
+// encReader is the io.Reader returned by EncodeToReader.
+// It releases its encbuf at EOF.
+type encReader struct {
+ buf *encBuffer // the buffer we're reading from. this is nil when we're at EOF.
+ lhpos int // index of list header that we're reading
+ strpos int // current position in string buffer
+ piece []byte // next piece to be read
+}
+
+func (r *encReader) Read(b []byte) (n int, err error) {
+ for {
+ if r.piece = r.next(); r.piece == nil {
+ // Put the encode buffer back into the pool at EOF when it
+ // is first encountered. Subsequent calls still return EOF
+ // as the error but the buffer is no longer valid.
+ if r.buf != nil {
+ encBufferPool.Put(r.buf)
+ r.buf = nil
+ }
+ return n, io.EOF
+ }
+ nn := copy(b[n:], r.piece)
+ n += nn
+ if nn < len(r.piece) {
+ // piece didn't fit, see you next time.
+ r.piece = r.piece[nn:]
+ return n, nil
+ }
+ r.piece = nil
+ }
+}
+
+// next returns the next piece of data to be read.
+// it returns nil at EOF.
+func (r *encReader) next() []byte {
+ switch {
+ case r.buf == nil:
+ return nil
+
+ case r.piece != nil:
+ // There is still data available for reading.
+ return r.piece
+
+ case r.lhpos < len(r.buf.lheads):
+ // We're before the last list header.
+ head := r.buf.lheads[r.lhpos]
+ sizebefore := head.offset - r.strpos
+ if sizebefore > 0 {
+ // String data before header.
+ p := r.buf.str[r.strpos:head.offset]
+ r.strpos += sizebefore
+ return p
+ }
+ r.lhpos++
+ return head.encode(r.buf.sizebuf[:])
+
+ case r.strpos < len(r.buf.str):
+ // String data at the end, after all list headers.
+ p := r.buf.str[r.strpos:]
+ r.strpos = len(r.buf.str)
+ return p
+
+ default:
+ return nil
+ }
+}
+
+func encBufferFromWriter(w io.Writer) *encBuffer {
+ switch w := w.(type) {
+ case EncoderBuffer:
+ return w.buf
+ case *EncoderBuffer:
+ return w.buf
+ case *encBuffer:
+ return w
+ default:
+ return nil
+ }
+}
+
+// EncoderBuffer is a buffer for incremental encoding.
+//
+// The zero value is NOT ready for use. To get a usable buffer,
+// create it using NewEncoderBuffer or call Reset.
+type EncoderBuffer struct {
+ buf *encBuffer
+ dst io.Writer
+
+ ownBuffer bool
+}
+
+// NewEncoderBuffer creates an encoder buffer.
+func NewEncoderBuffer(dst io.Writer) EncoderBuffer {
+ var w EncoderBuffer
+ w.Reset(dst)
+ return w
+}
+
+// Reset truncates the buffer and sets the output destination.
+func (w *EncoderBuffer) Reset(dst io.Writer) {
+ if w.buf != nil && !w.ownBuffer {
+ panic("can't Reset derived EncoderBuffer")
+ }
+
+ // If the destination writer has an *encBuffer, use it.
+ // Note that w.ownBuffer is left false here.
+ if dst != nil {
+ if outer := encBufferFromWriter(dst); outer != nil {
+ *w = EncoderBuffer{outer, nil, false}
+ return
+ }
+ }
+
+ // Get a fresh buffer.
+ if w.buf == nil {
+ w.buf = encBufferPool.Get().(*encBuffer)
+ w.ownBuffer = true
+ }
+ w.buf.reset()
+ w.dst = dst
+}
+
+// Flush writes encoded RLP data to the output writer. This can only be called once.
+// If you want to re-use the buffer after Flush, you must call Reset.
+func (w *EncoderBuffer) Flush() error {
+ var err error
+ if w.dst != nil {
+ err = w.buf.writeTo(w.dst)
+ }
+ // Release the internal buffer.
+ if w.ownBuffer {
+ encBufferPool.Put(w.buf)
+ }
+ *w = EncoderBuffer{}
+ return err
+}
+
+// ToBytes returns the encoded bytes.
+func (w *EncoderBuffer) ToBytes() []byte {
+ return w.buf.makeBytes()
+}
+
+// AppendToBytes appends the encoded bytes to dst.
+func (w *EncoderBuffer) AppendToBytes(dst []byte) []byte {
+ size := w.buf.size()
+ out := append(dst, make([]byte, size)...)
+ w.buf.copyTo(out[len(dst):])
+ return out
+}
+
+// Write appends b directly to the encoder output.
+func (w EncoderBuffer) Write(b []byte) (int, error) {
+ return w.buf.Write(b)
+}
+
+// WriteBool writes b as the integer 0 (false) or 1 (true).
+func (w EncoderBuffer) WriteBool(b bool) {
+ w.buf.writeBool(b)
+}
+
+// WriteUint64 encodes an unsigned integer.
+func (w EncoderBuffer) WriteUint64(i uint64) {
+ w.buf.writeUint64(i)
+}
+
+// WriteBigInt encodes a big.Int as an RLP string.
+// Note: Unlike with Encode, the sign of i is ignored.
+func (w EncoderBuffer) WriteBigInt(i *big.Int) {
+ w.buf.writeBigInt(i)
+}
+
+// WriteUint256 encodes uint256.Int as an RLP string.
+func (w EncoderBuffer) WriteUint256(i *uint256.Int) {
+ w.buf.writeUint256(i)
+}
+
+// WriteBytes encodes b as an RLP string.
+func (w EncoderBuffer) WriteBytes(b []byte) {
+ w.buf.writeBytes(b)
+}
+
+// WriteString encodes s as an RLP string.
+func (w EncoderBuffer) WriteString(s string) {
+ w.buf.writeString(s)
+}
+
+// List starts a list. It returns an internal index. Call EndList with
+// this index after encoding the content to finish the list.
+func (w EncoderBuffer) List() int {
+ return w.buf.list()
+}
+
+// ListEnd finishes the given list.
+func (w EncoderBuffer) ListEnd(index int) {
+ w.buf.listEnd(index)
+}
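
An informal walkthrough of the deferred-header scheme above, encoding the one-element list ["abc"] (bytes in hex):

    list()      -> lheads = [{offset: 0, size: 0}], str empty
    writeString -> str = 83 61 62 63      ("abc" with its string header, appended to str)
    listEnd(0)  -> lh.size = 4 (< 56), so lhsize becomes 1
    makeBytes   -> C4 83 61 62 63         (list header spliced in at offset 0)
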
diff --git a/rlp/encbuffer_example_test.go b/rlp/encbuffer_example_test.go
new file mode 100644
index 000000000..c41de60f0
--- /dev/null
+++ b/rlp/encbuffer_example_test.go
@@ -0,0 +1,45 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp_test
+
+import (
+ "bytes"
+ "fmt"
+
+ "github.com/tomochain/tomochain/rlp"
+)
+
+func ExampleEncoderBuffer() {
+ var w bytes.Buffer
+
+ // Encode [4, [5, 6]] to w.
+ buf := rlp.NewEncoderBuffer(&w)
+ l1 := buf.List()
+ buf.WriteUint64(4)
+ l2 := buf.List()
+ buf.WriteUint64(5)
+ buf.WriteUint64(6)
+ buf.ListEnd(l2)
+ buf.ListEnd(l1)
+
+ if err := buf.Flush(); err != nil {
+ panic(err)
+ }
+ fmt.Printf("%X\n", w.Bytes())
+ // Output:
+ // C404C20506
+}
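
Since Flush invalidates the buffer, reuse goes through Reset. A brief fragment, assuming in-memory collection via a nil destination:

    buf := rlp.NewEncoderBuffer(nil) // nil destination: bytes are collected in memory
    buf.WriteUint64(1024)
    first := buf.ToBytes() // 820400
    buf.Reset(nil)         // truncate before encoding the next value
    buf.WriteString("dog")
    second := buf.ToBytes() // 83646F67
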
diff --git a/rlp/encode.go b/rlp/encode.go
index 44592c2f5..2ca283c0a 100644
--- a/rlp/encode.go
+++ b/rlp/encode.go
@@ -17,20 +17,28 @@
package rlp
import (
+ "errors"
"fmt"
"io"
"math/big"
"reflect"
- "sync"
+
+ "github.com/holiman/uint256"
+ "github.com/tomochain/tomochain/rlp/internal/rlpstruct"
)
var (
// Common encoded values.
// These are useful when implementing EncodeRLP.
+
+ // EmptyString is the encoding of an empty string.
EmptyString = []byte{0x80}
- EmptyList = []byte{0xC0}
+ // EmptyList is the encoding of an empty list.
+ EmptyList = []byte{0xC0}
)
+var ErrNegativeBigInt = errors.New("rlp: cannot encode negative big.Int")
+
// Encoder is implemented by types that require custom
// encoding rules or want to encode private fields.
type Encoder interface {
@@ -49,80 +57,48 @@ type Encoder interface {
// perform many small writes in some cases. Consider making w
// buffered.
//
-// Encode uses the following type-dependent encoding rules:
-//
-// If the type implements the Encoder interface, Encode calls
-// EncodeRLP. This is true even for nil pointers, please see the
-// documentation for Encoder.
-//
-// To encode a pointer, the value being pointed to is encoded. For nil
-// pointers, Encode will encode the zero value of the type. A nil
-// pointer to a struct type always encodes as an empty RLP list.
-// A nil pointer to an array encodes as an empty list (or empty string
-// if the array has element type byte).
-//
-// Struct values are encoded as an RLP list of all their encoded
-// public fields. Recursive struct types are supported.
-//
-// To encode slices and arrays, the elements are encoded as an RLP
-// list of the value's elements. Note that arrays and slices with
-// element type uint8 or byte are always encoded as an RLP string.
-//
-// A Go string is encoded as an RLP string.
-//
-// An unsigned integer value is encoded as an RLP string. Zero always
-// encodes as an empty RLP string. Encode also supports *big.Int.
-//
-// An interface value encodes as the value contained in the interface.
-//
-// Boolean values are not supported, nor are signed integers, floating
-// point numbers, maps, channels and functions.
+// Please see the package-level documentation for the encoding rules.
func Encode(w io.Writer, val interface{}) error {
- if outer, ok := w.(*encbuf); ok {
- // Encode was called by some type's EncodeRLP.
- // Avoid copying by writing to the outer encbuf directly.
- return outer.encode(val)
+ // Optimization: reuse *encBuffer when called by EncodeRLP.
+ if buf := encBufferFromWriter(w); buf != nil {
+ return buf.encode(val)
}
- eb := encbufPool.Get().(*encbuf)
- defer encbufPool.Put(eb)
- eb.reset()
- if err := eb.encode(val); err != nil {
+
+ buf := getEncBuffer()
+ defer encBufferPool.Put(buf)
+ if err := buf.encode(val); err != nil {
return err
}
- return eb.toWriter(w)
+ return buf.writeTo(w)
}
-// EncodeBytes returns the RLP encoding of val.
-// Please see the documentation of Encode for the encoding rules.
+// EncodeToBytes returns the RLP encoding of val.
+// Please see package-level documentation for the encoding rules.
func EncodeToBytes(val interface{}) ([]byte, error) {
- eb := encbufPool.Get().(*encbuf)
- defer encbufPool.Put(eb)
- eb.reset()
- if err := eb.encode(val); err != nil {
+ buf := getEncBuffer()
+ defer encBufferPool.Put(buf)
+
+ if err := buf.encode(val); err != nil {
return nil, err
}
- return eb.toBytes(), nil
+ return buf.makeBytes(), nil
}
-// EncodeReader returns a reader from which the RLP encoding of val
+// EncodeToReader returns a reader from which the RLP encoding of val
// can be read. The returned size is the total size of the encoded
// data.
//
// Please see the documentation of Encode for the encoding rules.
func EncodeToReader(val interface{}) (size int, r io.Reader, err error) {
- eb := encbufPool.Get().(*encbuf)
- eb.reset()
- if err := eb.encode(val); err != nil {
+ buf := getEncBuffer()
+ if err := buf.encode(val); err != nil {
+ encBufferPool.Put(buf)
return 0, nil, err
}
- return eb.size(), &encReader{buf: eb}, nil
-}
-
-type encbuf struct {
- str []byte // string data, contains everything except list headers
- lheads []*listhead // all list headers
- lhsize int // sum of sizes of all encoded list headers
- sizebuf []byte // 9-byte auxiliary buffer for uint encoding
+ // Note: can't put the buffer back into the pool here
+ // because it is held by encReader. The reader puts it
+ // back when it has been fully consumed.
+ return buf.size(), &encReader{buf: buf}, nil
}
type listhead struct {
@@ -151,214 +127,32 @@ func puthead(buf []byte, smalltag, largetag byte, size uint64) int {
if size < 56 {
buf[0] = smalltag + byte(size)
return 1
- } else {
- sizesize := putint(buf[1:], size)
- buf[0] = largetag + byte(sizesize)
- return sizesize + 1
- }
-}
-
-// encbufs are pooled.
-var encbufPool = sync.Pool{
- New: func() interface{} { return &encbuf{sizebuf: make([]byte, 9)} },
-}
-
-func (w *encbuf) reset() {
- w.lhsize = 0
- if w.str != nil {
- w.str = w.str[:0]
- }
- if w.lheads != nil {
- w.lheads = w.lheads[:0]
- }
-}
-
-// encbuf implements io.Writer so it can be passed it into EncodeRLP.
-func (w *encbuf) Write(b []byte) (int, error) {
- w.str = append(w.str, b...)
- return len(b), nil
-}
-
-func (w *encbuf) encode(val interface{}) error {
- rval := reflect.ValueOf(val)
- ti, err := cachedTypeInfo(rval.Type(), tags{})
- if err != nil {
- return err
- }
- return ti.writer(rval, w)
-}
-
-func (w *encbuf) encodeStringHeader(size int) {
- if size < 56 {
- w.str = append(w.str, 0x80+byte(size))
- } else {
- // TODO: encode to w.str directly
- sizesize := putint(w.sizebuf[1:], uint64(size))
- w.sizebuf[0] = 0xB7 + byte(sizesize)
- w.str = append(w.str, w.sizebuf[:sizesize+1]...)
- }
-}
-
-func (w *encbuf) encodeString(b []byte) {
- if len(b) == 1 && b[0] <= 0x7F {
- // fits single byte, no string header
- w.str = append(w.str, b[0])
- } else {
- w.encodeStringHeader(len(b))
- w.str = append(w.str, b...)
- }
-}
-
-func (w *encbuf) list() *listhead {
- lh := &listhead{offset: len(w.str), size: w.lhsize}
- w.lheads = append(w.lheads, lh)
- return lh
-}
-
-func (w *encbuf) listEnd(lh *listhead) {
- lh.size = w.size() - lh.offset - lh.size
- if lh.size < 56 {
- w.lhsize += 1 // length encoded into kind tag
- } else {
- w.lhsize += 1 + intsize(uint64(lh.size))
- }
-}
-
-func (w *encbuf) size() int {
- return len(w.str) + w.lhsize
-}
-
-func (w *encbuf) toBytes() []byte {
- out := make([]byte, w.size())
- strpos := 0
- pos := 0
- for _, head := range w.lheads {
- // write string data before header
- n := copy(out[pos:], w.str[strpos:head.offset])
- pos += n
- strpos += n
- // write the header
- enc := head.encode(out[pos:])
- pos += len(enc)
- }
- // copy string data after the last list header
- copy(out[pos:], w.str[strpos:])
- return out
-}
-
-func (w *encbuf) toWriter(out io.Writer) (err error) {
- strpos := 0
- for _, head := range w.lheads {
- // write string data before header
- if head.offset-strpos > 0 {
- n, err := out.Write(w.str[strpos:head.offset])
- strpos += n
- if err != nil {
- return err
- }
- }
- // write the header
- enc := head.encode(w.sizebuf)
- if _, err = out.Write(enc); err != nil {
- return err
- }
- }
- if strpos < len(w.str) {
- // write string data after the last list header
- _, err = out.Write(w.str[strpos:])
- }
- return err
-}
-
-// encReader is the io.Reader returned by EncodeToReader.
-// It releases its encbuf at EOF.
-type encReader struct {
- buf *encbuf // the buffer we're reading from. this is nil when we're at EOF.
- lhpos int // index of list header that we're reading
- strpos int // current position in string buffer
- piece []byte // next piece to be read
-}
-
-func (r *encReader) Read(b []byte) (n int, err error) {
- for {
- if r.piece = r.next(); r.piece == nil {
- // Put the encode buffer back into the pool at EOF when it
- // is first encountered. Subsequent calls still return EOF
- // as the error but the buffer is no longer valid.
- if r.buf != nil {
- encbufPool.Put(r.buf)
- r.buf = nil
- }
- return n, io.EOF
- }
- nn := copy(b[n:], r.piece)
- n += nn
- if nn < len(r.piece) {
- // piece didn't fit, see you next time.
- r.piece = r.piece[nn:]
- return n, nil
- }
- r.piece = nil
- }
-}
-
-// next returns the next piece of data to be read.
-// it returns nil at EOF.
-func (r *encReader) next() []byte {
- switch {
- case r.buf == nil:
- return nil
-
- case r.piece != nil:
- // There is still data available for reading.
- return r.piece
-
- case r.lhpos < len(r.buf.lheads):
- // We're before the last list header.
- head := r.buf.lheads[r.lhpos]
- sizebefore := head.offset - r.strpos
- if sizebefore > 0 {
- // String data before header.
- p := r.buf.str[r.strpos:head.offset]
- r.strpos += sizebefore
- return p
- } else {
- r.lhpos++
- return head.encode(r.buf.sizebuf)
- }
-
- case r.strpos < len(r.buf.str):
- // String data at the end, after all list headers.
- p := r.buf.str[r.strpos:]
- r.strpos = len(r.buf.str)
- return p
-
- default:
- return nil
}
+ sizesize := putint(buf[1:], size)
+ buf[0] = largetag + byte(sizesize)
+ return sizesize + 1
}
-var (
- encoderInterface = reflect.TypeOf(new(Encoder)).Elem()
- big0 = big.NewInt(0)
-)
+var encoderInterface = reflect.TypeOf(new(Encoder)).Elem()
// makeWriter creates a writer function for the given type.
-func makeWriter(typ reflect.Type, ts tags) (writer, error) {
+func makeWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) {
kind := typ.Kind()
switch {
case typ == rawValueType:
return writeRawValue, nil
- case typ.Implements(encoderInterface):
- return writeEncoder, nil
- case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(encoderInterface):
- return writeEncoderNoPtr, nil
- case kind == reflect.Interface:
- return writeInterface, nil
case typ.AssignableTo(reflect.PtrTo(bigInt)):
return writeBigIntPtr, nil
case typ.AssignableTo(bigInt):
return writeBigIntNoPtr, nil
+ case typ == reflect.PtrTo(u256Int):
+ return writeU256IntPtr, nil
+ case typ == u256Int:
+ return writeU256IntNoPtr, nil
+ case kind == reflect.Ptr:
+ return makePtrWriter(typ, ts)
+ case reflect.PtrTo(typ).Implements(encoderInterface):
+ return makeEncoderWriter(typ), nil
case isUint(kind):
return writeUint, nil
case kind == reflect.Bool:
@@ -368,97 +162,116 @@ func makeWriter(typ reflect.Type, ts tags) (writer, error) {
case kind == reflect.Slice && isByte(typ.Elem()):
return writeBytes, nil
case kind == reflect.Array && isByte(typ.Elem()):
- return writeByteArray, nil
+ return makeByteArrayWriter(typ), nil
case kind == reflect.Slice || kind == reflect.Array:
return makeSliceWriter(typ, ts)
case kind == reflect.Struct:
return makeStructWriter(typ)
- case kind == reflect.Ptr:
- return makePtrWriter(typ)
+ case kind == reflect.Interface:
+ return writeInterface, nil
default:
return nil, fmt.Errorf("rlp: type %v is not RLP-serializable", typ)
}
}
-func isByte(typ reflect.Type) bool {
- return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface)
-}
-
-func writeRawValue(val reflect.Value, w *encbuf) error {
+func writeRawValue(val reflect.Value, w *encBuffer) error {
w.str = append(w.str, val.Bytes()...)
return nil
}
-func writeUint(val reflect.Value, w *encbuf) error {
- i := val.Uint()
- if i == 0 {
- w.str = append(w.str, 0x80)
- } else if i < 128 {
- // fits single byte
- w.str = append(w.str, byte(i))
- } else {
- // TODO: encode int to w.str directly
- s := putint(w.sizebuf[1:], i)
- w.sizebuf[0] = 0x80 + byte(s)
- w.str = append(w.str, w.sizebuf[:s+1]...)
- }
+func writeUint(val reflect.Value, w *encBuffer) error {
+ w.writeUint64(val.Uint())
return nil
}
-func writeBool(val reflect.Value, w *encbuf) error {
- if val.Bool() {
- w.str = append(w.str, 0x01)
- } else {
- w.str = append(w.str, 0x80)
- }
+func writeBool(val reflect.Value, w *encBuffer) error {
+ w.writeBool(val.Bool())
return nil
}
-func writeBigIntPtr(val reflect.Value, w *encbuf) error {
+func writeBigIntPtr(val reflect.Value, w *encBuffer) error {
ptr := val.Interface().(*big.Int)
if ptr == nil {
w.str = append(w.str, 0x80)
return nil
}
- return writeBigInt(ptr, w)
+ if ptr.Sign() == -1 {
+ return ErrNegativeBigInt
+ }
+ w.writeBigInt(ptr)
+ return nil
}
-func writeBigIntNoPtr(val reflect.Value, w *encbuf) error {
+func writeBigIntNoPtr(val reflect.Value, w *encBuffer) error {
i := val.Interface().(big.Int)
- return writeBigInt(&i, w)
+ if i.Sign() == -1 {
+ return ErrNegativeBigInt
+ }
+ w.writeBigInt(&i)
+ return nil
}
-func writeBigInt(i *big.Int, w *encbuf) error {
- if cmp := i.Cmp(big0); cmp == -1 {
- return fmt.Errorf("rlp: cannot encode negative *big.Int")
- } else if cmp == 0 {
+func writeU256IntPtr(val reflect.Value, w *encBuffer) error {
+ ptr := val.Interface().(*uint256.Int)
+ if ptr == nil {
w.str = append(w.str, 0x80)
- } else {
- w.encodeString(i.Bytes())
+ return nil
}
+ w.writeUint256(ptr)
+ return nil
+}
+
+func writeU256IntNoPtr(val reflect.Value, w *encBuffer) error {
+ i := val.Interface().(uint256.Int)
+ w.writeUint256(&i)
return nil
}
-func writeBytes(val reflect.Value, w *encbuf) error {
- w.encodeString(val.Bytes())
+func writeBytes(val reflect.Value, w *encBuffer) error {
+ w.writeBytes(val.Bytes())
return nil
}
-func writeByteArray(val reflect.Value, w *encbuf) error {
- if !val.CanAddr() {
- // Slice requires the value to be addressable.
- // Make it addressable by copying.
- copy := reflect.New(val.Type()).Elem()
- copy.Set(val)
- val = copy
+func makeByteArrayWriter(typ reflect.Type) writer {
+ switch typ.Len() {
+ case 0:
+ return writeLengthZeroByteArray
+ case 1:
+ return writeLengthOneByteArray
+ default:
+ length := typ.Len()
+ return func(val reflect.Value, w *encBuffer) error {
+ if !val.CanAddr() {
+ // Getting the byte slice of val requires it to be addressable. Make it
+ // addressable by copying.
+ copy := reflect.New(val.Type()).Elem()
+ copy.Set(val)
+ val = copy
+ }
+ slice := byteArrayBytes(val, length)
+ w.encodeStringHeader(len(slice))
+ w.str = append(w.str, slice...)
+ return nil
+ }
}
- size := val.Len()
- slice := val.Slice(0, size).Bytes()
- w.encodeString(slice)
+}
+
+func writeLengthZeroByteArray(val reflect.Value, w *encBuffer) error {
+ w.str = append(w.str, 0x80)
return nil
}
-func writeString(val reflect.Value, w *encbuf) error {
+func writeLengthOneByteArray(val reflect.Value, w *encBuffer) error {
+ b := byte(val.Index(0).Uint())
+ if b <= 0x7f {
+ w.str = append(w.str, b)
+ } else {
+ w.str = append(w.str, 0x81, b)
+ }
+ return nil
+}
+
+func writeString(val reflect.Value, w *encBuffer) error {
s := val.String()
if len(s) == 1 && s[0] <= 0x7f {
// fits single byte, no string header
@@ -470,27 +283,7 @@ func writeString(val reflect.Value, w *encbuf) error {
return nil
}
-func writeEncoder(val reflect.Value, w *encbuf) error {
- return val.Interface().(Encoder).EncodeRLP(w)
-}
-
-// writeEncoderNoPtr handles non-pointer values that implement Encoder
-// with a pointer receiver.
-func writeEncoderNoPtr(val reflect.Value, w *encbuf) error {
- if !val.CanAddr() {
- // We can't get the address. It would be possible to make the
- // value addressable by creating a shallow copy, but this
- // creates other problems so we're not doing it (yet).
- //
- // package json simply doesn't call MarshalJSON for cases like
- // this, but encodes the value as if it didn't implement the
- // interface. We don't want to handle it that way.
- return fmt.Errorf("rlp: game over: unadressable value of type %v, EncodeRLP is pointer method", val.Type())
- }
- return val.Addr().Interface().(Encoder).EncodeRLP(w)
-}
-
-func writeInterface(val reflect.Value, w *encbuf) error {
+func writeInterface(val reflect.Value, w *encBuffer) error {
if val.IsNil() {
// Write empty list. This is consistent with the previous RLP
// encoder that we had and should therefore avoid any
@@ -499,31 +292,51 @@ func writeInterface(val reflect.Value, w *encbuf) error {
return nil
}
eval := val.Elem()
- ti, err := cachedTypeInfo(eval.Type(), tags{})
+ writer, err := cachedWriter(eval.Type())
if err != nil {
return err
}
- return ti.writer(eval, w)
+ return writer(eval, w)
}
-func makeSliceWriter(typ reflect.Type, ts tags) (writer, error) {
- etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{})
- if err != nil {
- return nil, err
+func makeSliceWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) {
+ etypeinfo := theTC.infoWhileGenerating(typ.Elem(), rlpstruct.Tags{})
+ if etypeinfo.writerErr != nil {
+ return nil, etypeinfo.writerErr
}
- writer := func(val reflect.Value, w *encbuf) error {
- if !ts.tail {
- defer w.listEnd(w.list())
+
+ var wfn writer
+ if ts.Tail {
+ // This is for struct tail slices.
+ // w.list is not called for them.
+ wfn = func(val reflect.Value, w *encBuffer) error {
+ vlen := val.Len()
+ for i := 0; i < vlen; i++ {
+ if err := etypeinfo.writer(val.Index(i), w); err != nil {
+ return err
+ }
+ }
+ return nil
}
- vlen := val.Len()
- for i := 0; i < vlen; i++ {
- if err := etypeinfo.writer(val.Index(i), w); err != nil {
- return err
+ } else {
+ // This is for regular slices and arrays.
+ wfn = func(val reflect.Value, w *encBuffer) error {
+ vlen := val.Len()
+ if vlen == 0 {
+ w.str = append(w.str, 0xC0)
+ return nil
+ }
+ listOffset := w.list()
+ for i := 0; i < vlen; i++ {
+ if err := etypeinfo.writer(val.Index(i), w); err != nil {
+ return err
+ }
}
+ w.listEnd(listOffset)
+ return nil
}
- return nil
}
- return writer, nil
+ return wfn, nil
}
func makeStructWriter(typ reflect.Type) (writer, error) {
@@ -531,56 +344,86 @@ func makeStructWriter(typ reflect.Type) (writer, error) {
if err != nil {
return nil, err
}
- writer := func(val reflect.Value, w *encbuf) error {
- lh := w.list()
- for _, f := range fields {
- if err := f.info.writer(val.Field(f.index), w); err != nil {
- return err
+ for _, f := range fields {
+ if f.info.writerErr != nil {
+ return nil, structFieldError{typ, f.index, f.info.writerErr}
+ }
+ }
+
+ var writer writer
+ firstOptionalField := firstOptionalField(fields)
+ if firstOptionalField == len(fields) {
+ // This is the writer function for structs without any optional fields.
+ writer = func(val reflect.Value, w *encBuffer) error {
+ lh := w.list()
+ for _, f := range fields {
+ if err := f.info.writer(val.Field(f.index), w); err != nil {
+ return err
+ }
}
+ w.listEnd(lh)
+ return nil
+ }
+ } else {
+ // If there are any "optional" fields, the writer needs to perform additional
+ // checks to determine the output list length.
+ writer = func(val reflect.Value, w *encBuffer) error {
+ lastField := len(fields) - 1
+ for ; lastField >= firstOptionalField; lastField-- {
+ if !val.Field(fields[lastField].index).IsZero() {
+ break
+ }
+ }
+ lh := w.list()
+ for i := 0; i <= lastField; i++ {
+ if err := fields[i].info.writer(val.Field(fields[i].index), w); err != nil {
+ return err
+ }
+ }
+ w.listEnd(lh)
+ return nil
}
- w.listEnd(lh)
- return nil
}
return writer, nil
}
-func makePtrWriter(typ reflect.Type) (writer, error) {
- etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{})
- if err != nil {
- return nil, err
+func makePtrWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) {
+ nilEncoding := byte(0xC0)
+ if typeNilKind(typ.Elem(), ts) == String {
+ nilEncoding = 0x80
}
- // determine nil pointer handler
- var nilfunc func(*encbuf) error
- kind := typ.Elem().Kind()
- switch {
- case kind == reflect.Array && isByte(typ.Elem().Elem()):
- nilfunc = func(w *encbuf) error {
- w.str = append(w.str, 0x80)
- return nil
- }
- case kind == reflect.Struct || kind == reflect.Array:
- nilfunc = func(w *encbuf) error {
- // encoding the zero value of a struct/array could trigger
- // infinite recursion, avoid that.
- w.listEnd(w.list())
- return nil
- }
- default:
- zero := reflect.Zero(typ.Elem())
- nilfunc = func(w *encbuf) error {
- return etypeinfo.writer(zero, w)
+ etypeinfo := theTC.infoWhileGenerating(typ.Elem(), rlpstruct.Tags{})
+ if etypeinfo.writerErr != nil {
+ return nil, etypeinfo.writerErr
+ }
+
+ writer := func(val reflect.Value, w *encBuffer) error {
+ if ev := val.Elem(); ev.IsValid() {
+ return etypeinfo.writer(ev, w)
}
+ w.str = append(w.str, nilEncoding)
+ return nil
}
+ return writer, nil
+}
- writer := func(val reflect.Value, w *encbuf) error {
- if val.IsNil() {
- return nilfunc(w)
- } else {
- return etypeinfo.writer(val.Elem(), w)
+func makeEncoderWriter(typ reflect.Type) writer {
+ if typ.Implements(encoderInterface) {
+ return func(val reflect.Value, w *encBuffer) error {
+ return val.Interface().(Encoder).EncodeRLP(w)
+ }
+ }
+ w := func(val reflect.Value, w *encBuffer) error {
+ if !val.CanAddr() {
+ // package json simply doesn't call MarshalJSON for this case, but encodes the
+ // value as if it didn't implement the interface. We don't want to handle it that
+ // way.
+ return fmt.Errorf("rlp: unadressable value of type %v, EncodeRLP is pointer method", val.Type())
}
+ return val.Addr().Interface().(Encoder).EncodeRLP(w)
}
- return writer, err
+ return w
}
// putint writes i to the beginning of b in big endian byte
diff --git a/rlp/encode_test.go b/rlp/encode_test.go
index 827960f7c..9f2e6c38f 100644
--- a/rlp/encode_test.go
+++ b/rlp/encode_test.go
@@ -21,10 +21,13 @@ import (
"errors"
"fmt"
"io"
- "io/ioutil"
"math/big"
+ "runtime"
"sync"
"testing"
+
+ "github.com/holiman/uint256"
+ "github.com/tomochain/tomochain/common/math"
)
type testEncoder struct {
@@ -33,12 +36,19 @@ type testEncoder struct {
func (e *testEncoder) EncodeRLP(w io.Writer) error {
if e == nil {
- w.Write([]byte{0, 0, 0, 0})
- } else if e.err != nil {
+ panic("EncodeRLP called on nil value")
+ }
+ if e.err != nil {
return e.err
- } else {
- w.Write([]byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1})
}
+ w.Write([]byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1})
+ return nil
+}
+
+type testEncoderValueMethod struct{}
+
+func (e testEncoderValueMethod) EncodeRLP(w io.Writer) error {
+ w.Write([]byte{0xFA, 0xFE, 0xF0})
return nil
}
@@ -49,6 +59,13 @@ func (e byteEncoder) EncodeRLP(w io.Writer) error {
return nil
}
+type undecodableEncoder func()
+
+func (f undecodableEncoder) EncodeRLP(w io.Writer) error {
+ w.Write([]byte{0xF5, 0xF5, 0xF5})
+ return nil
+}
+
type encodableReader struct {
A, B uint
}
@@ -103,35 +120,95 @@ var encTests = []encTest{
{val: big.NewInt(0xFFFFFFFFFFFF), output: "86FFFFFFFFFFFF"},
{val: big.NewInt(0xFFFFFFFFFFFFFF), output: "87FFFFFFFFFFFFFF"},
{
- val: big.NewInt(0).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")),
+ val: new(big.Int).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")),
output: "8F102030405060708090A0B0C0D0E0F2",
},
{
- val: big.NewInt(0).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")),
+ val: new(big.Int).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")),
output: "9C0100020003000400050006000700080009000A000B000C000D000E01",
},
{
- val: big.NewInt(0).SetBytes(unhex("010000000000000000000000000000000000000000000000000000000000000000")),
+ val: new(big.Int).SetBytes(unhex("010000000000000000000000000000000000000000000000000000000000000000")),
output: "A1010000000000000000000000000000000000000000000000000000000000000000",
},
+ {
+ val: veryBigInt,
+ output: "89FFFFFFFFFFFFFFFFFF",
+ },
+ {
+ val: veryVeryBigInt,
+ output: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001",
+ },
// non-pointer big.Int
{val: *big.NewInt(0), output: "80"},
{val: *big.NewInt(0xFFFFFF), output: "83FFFFFF"},
// negative ints are not supported
- {val: big.NewInt(-1), error: "rlp: cannot encode negative *big.Int"},
-
- // byte slices, strings
+ {val: big.NewInt(-1), error: "rlp: cannot encode negative big.Int"},
+ {val: *big.NewInt(-1), error: "rlp: cannot encode negative big.Int"},
+
+ // uint256
+ {val: uint256.NewInt(0), output: "80"},
+ {val: uint256.NewInt(1), output: "01"},
+ {val: uint256.NewInt(127), output: "7F"},
+ {val: uint256.NewInt(128), output: "8180"},
+ {val: uint256.NewInt(256), output: "820100"},
+ {val: uint256.NewInt(1024), output: "820400"},
+ {val: uint256.NewInt(0xFFFFFF), output: "83FFFFFF"},
+ {val: uint256.NewInt(0xFFFFFFFF), output: "84FFFFFFFF"},
+ {val: uint256.NewInt(0xFFFFFFFFFF), output: "85FFFFFFFFFF"},
+ {val: uint256.NewInt(0xFFFFFFFFFFFF), output: "86FFFFFFFFFFFF"},
+ {val: uint256.NewInt(0xFFFFFFFFFFFFFF), output: "87FFFFFFFFFFFFFF"},
+ {
+ val: new(uint256.Int).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")),
+ output: "8F102030405060708090A0B0C0D0E0F2",
+ },
+ {
+ val: new(uint256.Int).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")),
+ output: "9C0100020003000400050006000700080009000A000B000C000D000E01",
+ },
+ // non-pointer uint256.Int
+ {val: *uint256.NewInt(0), output: "80"},
+ {val: *uint256.NewInt(0xFFFFFF), output: "83FFFFFF"},
+
+ // byte arrays
+ {val: [0]byte{}, output: "80"},
+ {val: [1]byte{0}, output: "00"},
+ {val: [1]byte{1}, output: "01"},
+ {val: [1]byte{0x7F}, output: "7F"},
+ {val: [1]byte{0x80}, output: "8180"},
+ {val: [1]byte{0xFF}, output: "81FF"},
+ {val: [3]byte{1, 2, 3}, output: "83010203"},
+ {val: [57]byte{1, 2, 3}, output: "B839010203000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
+
+ // named byte type arrays
+ {val: [0]namedByteType{}, output: "80"},
+ {val: [1]namedByteType{0}, output: "00"},
+ {val: [1]namedByteType{1}, output: "01"},
+ {val: [1]namedByteType{0x7F}, output: "7F"},
+ {val: [1]namedByteType{0x80}, output: "8180"},
+ {val: [1]namedByteType{0xFF}, output: "81FF"},
+ {val: [3]namedByteType{1, 2, 3}, output: "83010203"},
+ {val: [57]namedByteType{1, 2, 3}, output: "B839010203000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
+
+ // byte slices
{val: []byte{}, output: "80"},
+ {val: []byte{0}, output: "00"},
{val: []byte{0x7E}, output: "7E"},
{val: []byte{0x7F}, output: "7F"},
{val: []byte{0x80}, output: "8180"},
{val: []byte{1, 2, 3}, output: "83010203"},
+ // named byte type slices
+ {val: []namedByteType{}, output: "80"},
+ {val: []namedByteType{0}, output: "00"},
+ {val: []namedByteType{0x7E}, output: "7E"},
+ {val: []namedByteType{0x7F}, output: "7F"},
+ {val: []namedByteType{0x80}, output: "8180"},
{val: []namedByteType{1, 2, 3}, output: "83010203"},
- {val: [...]namedByteType{1, 2, 3}, output: "83010203"},
+ // strings
{val: "", output: "80"},
{val: "\x7E", output: "7E"},
{val: "\x7F", output: "7F"},
@@ -204,6 +281,12 @@ var encTests = []encTest{
output: "F90200CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376",
},
+ // Non-byte arrays are encoded as lists.
+ // Note that it is important to test [4]uint64 specifically,
+ // because that's the underlying type of uint256.Int.
+ {val: [4]uint32{1, 2, 3, 4}, output: "C401020304"},
+ {val: [4]uint64{1, 2, 3, 4}, output: "C401020304"},
+
// RawValue
{val: RawValue(unhex("01")), output: "01"},
{val: RawValue(unhex("82FFFF")), output: "82FFFF"},
@@ -214,11 +297,34 @@ var encTests = []encTest{
{val: simplestruct{A: 3, B: "foo"}, output: "C50383666F6F"},
{val: &recstruct{5, nil}, output: "C205C0"},
{val: &recstruct{5, &recstruct{4, &recstruct{3, nil}}}, output: "C605C404C203C0"},
+ {val: &intField{X: 3}, error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)"},
+
+ // struct tag "-"
+ {val: &ignoredField{A: 1, B: 2, C: 3}, output: "C20103"},
+
+ // struct tag "tail"
{val: &tailRaw{A: 1, Tail: []RawValue{unhex("02"), unhex("03")}}, output: "C3010203"},
{val: &tailRaw{A: 1, Tail: []RawValue{unhex("02")}}, output: "C20102"},
{val: &tailRaw{A: 1, Tail: []RawValue{}}, output: "C101"},
{val: &tailRaw{A: 1, Tail: nil}, output: "C101"},
- {val: &hasIgnoredField{A: 1, B: 2, C: 3}, output: "C20103"},
+
+ // struct tag "optional"
+ {val: &optionalFields{}, output: "C180"},
+ {val: &optionalFields{A: 1}, output: "C101"},
+ {val: &optionalFields{A: 1, B: 2}, output: "C20102"},
+ {val: &optionalFields{A: 1, B: 2, C: 3}, output: "C3010203"},
+ {val: &optionalFields{A: 1, B: 0, C: 3}, output: "C3018003"},
+ {val: &optionalAndTailField{A: 1}, output: "C101"},
+ {val: &optionalAndTailField{A: 1, B: 2}, output: "C20102"},
+ {val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"},
+ {val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"},
+ {val: &optionalBigIntField{A: 1}, output: "C101"},
+ {val: &optionalPtrField{A: 1}, output: "C101"},
+ {val: &optionalPtrFieldNil{A: 1}, output: "C101"},
+ {val: &multipleOptionalFields{A: nil, B: nil}, output: "C0"},
+ {val: &multipleOptionalFields{A: &[3]byte{1, 2, 3}, B: &[3]byte{1, 2, 3}}, output: "C88301020383010203"},
+ {val: &multipleOptionalFields{A: nil, B: &[3]byte{1, 2, 3}}, output: "C58083010203"}, // encodes without error but decode will fail
+ {val: &nonOptionalPtrField{A: 1}, output: "C20180"}, // encodes without error but decode will fail
// nil
{val: (*uint)(nil), output: "80"},
@@ -226,26 +332,73 @@ var encTests = []encTest{
{val: (*[]byte)(nil), output: "80"},
{val: (*[10]byte)(nil), output: "80"},
{val: (*big.Int)(nil), output: "80"},
+ {val: (*uint256.Int)(nil), output: "80"},
{val: (*[]string)(nil), output: "C0"},
{val: (*[10]string)(nil), output: "C0"},
{val: (*[]interface{})(nil), output: "C0"},
{val: (*[]struct{ uint })(nil), output: "C0"},
{val: (*interface{})(nil), output: "C0"},
+ // nil struct fields
+ {
+ val: struct {
+ X *[]byte
+ }{},
+ output: "C180",
+ },
+ {
+ val: struct {
+ X *[2]byte
+ }{},
+ output: "C180",
+ },
+ {
+ val: struct {
+ X *uint64
+ }{},
+ output: "C180",
+ },
+ {
+ val: struct {
+ X *uint64 `rlp:"nilList"`
+ }{},
+ output: "C1C0",
+ },
+ {
+ val: struct {
+ X *[]uint64
+ }{},
+ output: "C1C0",
+ },
+ {
+ val: struct {
+ X *[]uint64 `rlp:"nilString"`
+ }{},
+ output: "C180",
+ },
+
// interfaces
{val: []io.Reader{reader}, output: "C3C20102"}, // the contained value is a struct
// Encoder
- {val: (*testEncoder)(nil), output: "00000000"},
+ {val: (*testEncoder)(nil), output: "C0"},
{val: &testEncoder{}, output: "00010001000100010001"},
{val: &testEncoder{errors.New("test error")}, error: "test error"},
- // verify that pointer method testEncoder.EncodeRLP is called for
+ {val: struct{ E testEncoderValueMethod }{}, output: "C3FAFEF0"},
+ {val: struct{ E *testEncoderValueMethod }{}, output: "C1C0"},
+
+ // Verify that the Encoder interface works for unsupported types like func().
+ {val: undecodableEncoder(func() {}), output: "F5F5F5"},
+
+ // Verify that pointer method testEncoder.EncodeRLP is called for
// addressable non-pointer values.
{val: &struct{ TE testEncoder }{testEncoder{}}, output: "CA00010001000100010001"},
{val: &struct{ TE testEncoder }{testEncoder{errors.New("test error")}}, error: "test error"},
- // verify the error for non-addressable non-pointer Encoder
- {val: testEncoder{}, error: "rlp: game over: unadressable value of type rlp.testEncoder, EncodeRLP is pointer method"},
- // verify the special case for []byte
+
+ // Verify the error for non-addressable non-pointer Encoder.
+ {val: testEncoder{}, error: "rlp: unaddressable value of type rlp.testEncoder, EncodeRLP is pointer method"},
+
+ // Verify Encoder takes precedence over []byte.
{val: []byteEncoder{0, 1, 2, 3, 4}, output: "C5C0C0C0C0C0"},
}
@@ -281,13 +434,28 @@ func TestEncodeToBytes(t *testing.T) {
runEncTests(t, EncodeToBytes)
}
+func TestEncodeAppendToBytes(t *testing.T) {
+ buffer := make([]byte, 20)
+ runEncTests(t, func(val interface{}) ([]byte, error) {
+ w := NewEncoderBuffer(nil)
+ defer w.Flush()
+
+ err := Encode(w, val)
+ if err != nil {
+ return nil, err
+ }
+ output := w.AppendToBytes(buffer[:0])
+ return output, nil
+ })
+}
+
func TestEncodeToReader(t *testing.T) {
runEncTests(t, func(val interface{}) ([]byte, error) {
_, r, err := EncodeToReader(val)
if err != nil {
return nil, err
}
- return ioutil.ReadAll(r)
+ return io.ReadAll(r)
})
}
@@ -328,7 +496,7 @@ func TestEncodeToReaderReturnToPool(t *testing.T) {
go func() {
for i := 0; i < 1000; i++ {
_, r, _ := EncodeToReader("foo")
- ioutil.ReadAll(r)
+ io.ReadAll(r)
r.Read(buf)
r.Read(buf)
r.Read(buf)
@@ -339,3 +507,132 @@ func TestEncodeToReaderReturnToPool(t *testing.T) {
}
wg.Wait()
}
+
+var sink interface{}
+
+func BenchmarkIntsize(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ sink = intsize(0x12345678)
+ }
+}
+
+func BenchmarkPutint(b *testing.B) {
+ buf := make([]byte, 8)
+ for i := 0; i < b.N; i++ {
+ putint(buf, 0x12345678)
+ sink = buf
+ }
+}
+
+func BenchmarkEncodeBigInts(b *testing.B) {
+ ints := make([]*big.Int, 200)
+ for i := range ints {
+ ints[i] = math.BigPow(2, int64(i))
+ }
+ out := bytes.NewBuffer(make([]byte, 0, 4096))
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ out.Reset()
+ if err := Encode(out, ints); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkEncodeU256Ints(b *testing.B) {
+ ints := make([]*uint256.Int, 200)
+ for i := range ints {
+ ints[i], _ = uint256.FromBig(math.BigPow(2, int64(i)))
+ }
+ out := bytes.NewBuffer(make([]byte, 0, 4096))
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ out.Reset()
+ if err := Encode(out, ints); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkEncodeConcurrentInterface(b *testing.B) {
+ type struct1 struct {
+ A string
+ B *big.Int
+ C [20]byte
+ }
+ value := []interface{}{
+ uint(999),
+ &struct1{A: "hello", B: big.NewInt(0xFFFFFFFF)},
+ [10]byte{1, 2, 3, 4, 5, 6},
+ []string{"yeah", "yeah", "yeah"},
+ }
+
+ var wg sync.WaitGroup
+ for cpu := 0; cpu < runtime.NumCPU(); cpu++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ var buffer bytes.Buffer
+ for i := 0; i < b.N; i++ {
+ buffer.Reset()
+ err := Encode(&buffer, value)
+ if err != nil {
+ panic(err)
+ }
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+type byteArrayStruct struct {
+ A [20]byte
+ B [32]byte
+ C [32]byte
+}
+
+func BenchmarkEncodeByteArrayStruct(b *testing.B) {
+ var out bytes.Buffer
+ var value byteArrayStruct
+
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ out.Reset()
+ if err := Encode(&out, &value); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+type structSliceElem struct {
+ X uint64
+ Y uint64
+ Z uint64
+}
+
+type structPtrSlice []*structSliceElem
+
+func BenchmarkEncodeStructPtrSlice(b *testing.B) {
+ var out bytes.Buffer
+ var value = structPtrSlice{
+ &structSliceElem{1, 1, 1},
+ &structSliceElem{2, 2, 2},
+ &structSliceElem{3, 3, 3},
+ &structSliceElem{5, 5, 5},
+ &structSliceElem{6, 6, 6},
+ &structSliceElem{7, 7, 7},
+ }
+
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ out.Reset()
+ if err := Encode(&out, &value); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
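// A minimal usage sketch of the EncoderBuffer/AppendToBytes pattern that
// TestEncodeAppendToBytes above exercises: encode into an internal buffer,
// then append the result to a caller-owned slice. Everything used here is
// taken from the API shown in this diff; the program itself is illustrative.
package main

import (
	"fmt"

	"github.com/tomochain/tomochain/rlp"
)

func main() {
	buf := make([]byte, 0, 64) // caller-owned backing array, reusable across encodings

	w := rlp.NewEncoderBuffer(nil) // nil writer: buffer internally, emit via AppendToBytes
	if err := rlp.Encode(w, []uint{1, 2, 3}); err != nil {
		panic(err)
	}
	buf = w.AppendToBytes(buf[:0]) // append the finished encoding into buf
	w.Flush()

	fmt.Printf("%X\n", buf) // C3010203
}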
diff --git a/rlp/encoder_example_test.go b/rlp/encoder_example_test.go
index 1cffa241c..6291bfafe 100644
--- a/rlp/encoder_example_test.go
+++ b/rlp/encoder_example_test.go
@@ -14,11 +14,13 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-package rlp
+package rlp_test
import (
"fmt"
"io"
+
+ "github.com/tomochain/tomochain/rlp"
)
type MyCoolType struct {
@@ -28,27 +30,19 @@ type MyCoolType struct {
// EncodeRLP writes x as RLP list [a, b] that omits the Name field.
func (x *MyCoolType) EncodeRLP(w io.Writer) (err error) {
- // Note: the receiver can be a nil pointer. This allows you to
- // control the encoding of nil, but it also means that you have to
- // check for a nil receiver.
- if x == nil {
- err = Encode(w, []uint{0, 0})
- } else {
- err = Encode(w, []uint{x.a, x.b})
- }
- return err
+ return rlp.Encode(w, []uint{x.a, x.b})
}
func ExampleEncoder() {
var t *MyCoolType // t is nil pointer to MyCoolType
- bytes, _ := EncodeToBytes(t)
+ bytes, _ := rlp.EncodeToBytes(t)
fmt.Printf("%v → %X\n", t, bytes)
t = &MyCoolType{Name: "foobar", a: 5, b: 6}
- bytes, _ = EncodeToBytes(t)
+ bytes, _ = rlp.EncodeToBytes(t)
fmt.Printf("%v → %X\n", t, bytes)
// Output:
- // → C28080
+ // → C0
// &{foobar 5 6} → C20506
}
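// Why the example output changes from C28080 to C0: a nil pointer is now
// encoded by package rlp itself, before any custom EncodeRLP is invoked, so
// the method no longer needs a nil-receiver branch. A nil pointer encodes as
// the type's default nil value (see rlpstruct.DefaultNilValue below): the
// empty list 0xC0 for list-shaped types, the empty string 0x80 for
// string-shaped ones. A small sketch, assuming only EncodeToBytes:
package main

import (
	"fmt"

	"github.com/tomochain/tomochain/rlp"
)

type pair struct{ A, B uint }

func main() {
	enc, _ := rlp.EncodeToBytes((*pair)(nil)) // struct pointer → list-shaped
	fmt.Printf("%X\n", enc)                   // C0

	enc, _ = rlp.EncodeToBytes((*string)(nil)) // string-shaped
	fmt.Printf("%X\n", enc)                    // 80
}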
diff --git a/rlp/internal/rlpstruct/rlpstruct.go b/rlp/internal/rlpstruct/rlpstruct.go
new file mode 100644
index 000000000..2e3eeb688
--- /dev/null
+++ b/rlp/internal/rlpstruct/rlpstruct.go
@@ -0,0 +1,213 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package rlpstruct implements struct processing for RLP encoding/decoding.
+//
+// In particular, this package handles all rules around field filtering,
+// struct tags and nil value determination.
+package rlpstruct
+
+import (
+ "fmt"
+ "reflect"
+ "strings"
+)
+
+// Field represents a struct field.
+type Field struct {
+ Name string
+ Index int
+ Exported bool
+ Type Type
+ Tag string
+}
+
+// Type represents the attributes of a Go type.
+type Type struct {
+ Name string
+ Kind reflect.Kind
+ IsEncoder bool // whether type implements rlp.Encoder
+ IsDecoder bool // whether type implements rlp.Decoder
+ Elem *Type // non-nil for Kind values of Ptr, Slice, Array
+}
+
+// DefaultNilValue determines whether a nil pointer to t encodes/decodes
+// as an empty string or empty list.
+func (t Type) DefaultNilValue() NilKind {
+ k := t.Kind
+ if isUint(k) || k == reflect.String || k == reflect.Bool || isByteArray(t) {
+ return NilKindString
+ }
+ return NilKindList
+}
+
+// NilKind is the RLP value encoded in place of nil pointers.
+type NilKind uint8
+
+const (
+ NilKindString NilKind = 0x80
+ NilKindList NilKind = 0xC0
+)
+
+// Tags represents struct tags.
+type Tags struct {
+ // rlp:"nil" controls whether empty input results in a nil pointer.
+ // nilKind is the kind of empty value allowed for the field.
+ NilKind NilKind
+ NilOK bool
+
+ // rlp:"optional" allows for a field to be missing in the input list.
+ // If this is set, all subsequent fields must also be optional.
+ Optional bool
+
+ // rlp:"tail" controls whether this field swallows additional list elements. It can
+ // only be set for the last field, which must be of slice type.
+ Tail bool
+
+ // rlp:"-" ignores fields.
+ Ignored bool
+}
+
+// TagError is raised for invalid struct tags.
+type TagError struct {
+ StructType string
+
+ // These are set by this package.
+ Field string
+ Tag string
+ Err string
+}
+
+func (e TagError) Error() string {
+ field := "field " + e.Field
+ if e.StructType != "" {
+ field = e.StructType + "." + e.Field
+ }
+ return fmt.Sprintf("rlp: invalid struct tag %q for %s (%s)", e.Tag, field, e.Err)
+}
+
+// ProcessFields filters the given struct fields, returning only fields
+// that should be considered for encoding/decoding.
+func ProcessFields(allFields []Field) ([]Field, []Tags, error) {
+ lastPublic := lastPublicField(allFields)
+
+ // Gather all exported fields and their tags.
+ var fields []Field
+ var tags []Tags
+ for _, field := range allFields {
+ if !field.Exported {
+ continue
+ }
+ ts, err := parseTag(field, lastPublic)
+ if err != nil {
+ return nil, nil, err
+ }
+ if ts.Ignored {
+ continue
+ }
+ fields = append(fields, field)
+ tags = append(tags, ts)
+ }
+
+ // Verify optional field consistency. If any optional field exists,
+ // all fields after it must also be optional. Note: optional + tail
+ // is supported.
+ var anyOptional bool
+ var firstOptionalName string
+ for i, ts := range tags {
+ name := fields[i].Name
+ if ts.Optional || ts.Tail {
+ if !anyOptional {
+ firstOptionalName = name
+ }
+ anyOptional = true
+ } else {
+ if anyOptional {
+ msg := fmt.Sprintf("must be optional because preceding field %q is optional", firstOptionalName)
+ return nil, nil, TagError{Field: name, Err: msg}
+ }
+ }
+ }
+ return fields, tags, nil
+}
+
+func parseTag(field Field, lastPublic int) (Tags, error) {
+ name := field.Name
+ tag := reflect.StructTag(field.Tag)
+ var ts Tags
+ for _, t := range strings.Split(tag.Get("rlp"), ",") {
+ switch t = strings.TrimSpace(t); t {
+ case "":
+ // empty tag is allowed for some reason
+ case "-":
+ ts.Ignored = true
+ case "nil", "nilString", "nilList":
+ ts.NilOK = true
+ if field.Type.Kind != reflect.Ptr {
+ return ts, TagError{Field: name, Tag: t, Err: "field is not a pointer"}
+ }
+ switch t {
+ case "nil":
+ ts.NilKind = field.Type.Elem.DefaultNilValue()
+ case "nilString":
+ ts.NilKind = NilKindString
+ case "nilList":
+ ts.NilKind = NilKindList
+ }
+ case "optional":
+ ts.Optional = true
+ if ts.Tail {
+ return ts, TagError{Field: name, Tag: t, Err: `also has "tail" tag`}
+ }
+ case "tail":
+ ts.Tail = true
+ if field.Index != lastPublic {
+ return ts, TagError{Field: name, Tag: t, Err: "must be on last field"}
+ }
+ if ts.Optional {
+ return ts, TagError{Field: name, Tag: t, Err: `also has "optional" tag`}
+ }
+ if field.Type.Kind != reflect.Slice {
+ return ts, TagError{Field: name, Tag: t, Err: "field type is not slice"}
+ }
+ default:
+ return ts, TagError{Field: name, Tag: t, Err: "unknown tag"}
+ }
+ }
+ return ts, nil
+}
+
+func lastPublicField(fields []Field) int {
+ last := 0
+ for _, f := range fields {
+ if f.Exported {
+ last = f.Index
+ }
+ }
+ return last
+}
+
+func isUint(k reflect.Kind) bool {
+ return k >= reflect.Uint && k <= reflect.Uintptr
+}
+
+func isByte(typ Type) bool {
+ return typ.Kind == reflect.Uint8 && !typ.IsEncoder
+}
+
+func isByteArray(typ Type) bool {
+ return (typ.Kind == reflect.Slice || typ.Kind == reflect.Array) && isByte(*typ.Elem)
+}
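// A sketch of how a caller can feed reflection data through ProcessFields.
// Note the package is internal, so this only compiles inside the rlp module
// tree; the header type and field names are made up for illustration.
package rlpsketch // hypothetical package inside the rlp module tree

import (
	"fmt"
	"reflect"

	"github.com/tomochain/tomochain/rlp/internal/rlpstruct"
)

type header struct {
	Number uint64
	Extra  string `rlp:"optional"`
	cache  string // unexported: filtered out by ProcessFields
}

func describeFields() error {
	rt := reflect.TypeOf(header{})
	var all []rlpstruct.Field
	for i := 0; i < rt.NumField(); i++ {
		f := rt.Field(i)
		all = append(all, rlpstruct.Field{
			Name:     f.Name,
			Index:    i,
			Exported: f.PkgPath == "", // empty PkgPath means exported
			Tag:      string(f.Tag),
			Type:     rlpstruct.Type{Name: f.Type.String(), Kind: f.Type.Kind()},
		})
	}
	fields, tags, err := rlpstruct.ProcessFields(all)
	if err != nil {
		return err
	}
	for i, f := range fields {
		fmt.Printf("%s optional=%v\n", f.Name, tags[i].Optional) // Number false, Extra true
	}
	return nil
}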
diff --git a/rlp/iterator.go b/rlp/iterator.go
new file mode 100644
index 000000000..6be574572
--- /dev/null
+++ b/rlp/iterator.go
@@ -0,0 +1,60 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp
+
+type listIterator struct {
+ data []byte
+ next []byte
+ err error
+}
+
+// NewListIterator creates an iterator for the (list) represented by data
+// TODO: Consider removing this implementation, as it is no longer used.
+func NewListIterator(data RawValue) (*listIterator, error) {
+ k, t, c, err := readKind(data)
+ if err != nil {
+ return nil, err
+ }
+ if k != List {
+ return nil, ErrExpectedList
+ }
+ it := &listIterator{
+ data: data[t : t+c],
+ }
+ return it, nil
+}
+
+// Next forwards the iterator one step, returns true if it was not at end yet
+func (it *listIterator) Next() bool {
+ if len(it.data) == 0 {
+ return false
+ }
+ _, t, c, err := readKind(it.data)
+ it.next = it.data[:t+c]
+ it.data = it.data[t+c:]
+ it.err = err
+ return true
+}
+
+// Value returns the current value
+func (it *listIterator) Value() []byte {
+ return it.next
+}
+
+func (it *listIterator) Err() error {
+ return it.err
+}
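// The intended use, sketched from inside package rlp (the iterator type is
// unexported, so external callers cannot hold one). forEachListElem is a
// hypothetical helper; the input is assumed to be a complete encoded list.
func forEachListElem(encodedList RawValue, fn func(elem []byte) error) error {
	it, err := NewListIterator(encodedList)
	if err != nil {
		return err
	}
	for it.Next() {
		if err := it.Err(); err != nil { // Next reports true even when readKind failed
			return err
		}
		if err := fn(it.Value()); err != nil { // Value is the raw RLP of the element
			return err
		}
	}
	return it.Err()
}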
diff --git a/rlp/iterator_test.go b/rlp/iterator_test.go
new file mode 100644
index 000000000..87c11bdba
--- /dev/null
+++ b/rlp/iterator_test.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp
+
+import (
+ "testing"
+
+ "github.com/tomochain/tomochain/common/hexutil"
+)
+
+// TestIterator tests some basic things about the ListIterator. A more
+// comprehensive test can be found in core/rlp_test.go, where we can
+// use both types and rlp without dependency cycles
+func TestIterator(t *testing.T) {
+ bodyRlpHex := "0xf902cbf8d6f869800182c35094000000000000000000000000000000000000aaaa808a000000000000000000001ba01025c66fad28b4ce3370222624d952c35529e602af7cbe04f667371f61b0e3b3a00ab8813514d1217059748fd903288ace1b4001a4bc5fbde2790debdc8167de2ff869010182c35094000000000000000000000000000000000000aaaa808a000000000000000000001ca05ac4cf1d19be06f3742c21df6c49a7e929ceb3dbaf6a09f3cfb56ff6828bd9a7a06875970133a35e63ac06d360aa166d228cc013e9b96e0a2cae7f55b22e1ee2e8f901f0f901eda0c75448377c0e426b8017b23c5f77379ecf69abc1d5c224284ad3ba1c46c59adaa00000000000000000000000000000000000000000000000000000000000000000940000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000808080808080a00000000000000000000000000000000000000000000000000000000000000000880000000000000000"
+ bodyRlp := hexutil.MustDecode(bodyRlpHex)
+
+ it, err := NewListIterator(bodyRlp)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Check that txs exist
+ if !it.Next() {
+ t.Fatal("expected two elems, got zero")
+ }
+ txs := it.Value()
+ // Check that uncles exist
+ if !it.Next() {
+ t.Fatal("expected two elems, got one")
+ }
+ txit, err := NewListIterator(txs)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var i = 0
+ for txit.Next() {
+ if txit.err != nil {
+ t.Fatal(txit.err)
+ }
+ i++
+ }
+ if exp := 2; i != exp {
+ t.Errorf("count wrong, expected %d got %d", i, exp)
+ }
+}
diff --git a/rlp/raw.go b/rlp/raw.go
index 2b3f328f6..773aa7e61 100644
--- a/rlp/raw.go
+++ b/rlp/raw.go
@@ -28,12 +28,53 @@ type RawValue []byte
var rawValueType = reflect.TypeOf(RawValue{})
+// StringSize returns the encoded size of a string.
+func StringSize(s string) uint64 {
+ switch {
+ case len(s) == 0:
+ return 1
+ case len(s) == 1:
+ if s[0] <= 0x7f {
+ return 1
+ } else {
+ return 2
+ }
+ default:
+ return uint64(headsize(uint64(len(s))) + len(s))
+ }
+}
+
+// BytesSize returns the encoded size of a byte slice.
+func BytesSize(b []byte) uint64 {
+ switch {
+ case len(b) == 0:
+ return 1
+ case len(b) == 1:
+ if b[0] <= 0x7f {
+ return 1
+ } else {
+ return 2
+ }
+ default:
+ return uint64(headsize(uint64(len(b))) + len(b))
+ }
+}
+
// ListSize returns the encoded size of an RLP list with the given
// content size.
func ListSize(contentSize uint64) uint64 {
return uint64(headsize(contentSize)) + contentSize
}
+// IntSize returns the encoded size of the integer x. Note: The return type of this
+// function is 'int' for backwards-compatibility reasons. The result is always positive.
+func IntSize(x uint64) int {
+ if x < 0x80 {
+ return 1
+ }
+ return 1 + intsize(x)
+}
+
// Split returns the content of first RLP value and any
// bytes after the value as subslices of b.
func Split(b []byte) (k Kind, content, rest []byte, err error) {
@@ -57,6 +98,32 @@ func SplitString(b []byte) (content, rest []byte, err error) {
return content, rest, nil
}
+// SplitUint64 decodes an integer at the beginning of b.
+// It also returns the remaining data after the integer in 'rest'.
+func SplitUint64(b []byte) (x uint64, rest []byte, err error) {
+ content, rest, err := SplitString(b)
+ if err != nil {
+ return 0, b, err
+ }
+ switch {
+ case len(content) == 0:
+ return 0, rest, nil
+ case len(content) == 1:
+ if content[0] == 0 {
+ return 0, b, ErrCanonInt
+ }
+ return uint64(content[0]), rest, nil
+ case len(content) > 8:
+ return 0, b, errUintOverflow
+ default:
+ x, err = readSize(content, byte(len(content)))
+ if err != nil {
+ return 0, b, ErrCanonInt
+ }
+ return x, rest, nil
+ }
+}
+
// SplitList splits b into the content of a list and any remaining
// bytes after the list.
func SplitList(b []byte) (content, rest []byte, err error) {
@@ -154,3 +221,74 @@ func readSize(b []byte, slen byte) (uint64, error) {
}
return s, nil
}
+
+// AppendUint64 appends the RLP encoding of i to b, and returns the resulting slice.
+func AppendUint64(b []byte, i uint64) []byte {
+ if i == 0 {
+ return append(b, 0x80)
+ } else if i < 128 {
+ return append(b, byte(i))
+ }
+ switch {
+ case i < (1 << 8):
+ return append(b, 0x81, byte(i))
+ case i < (1 << 16):
+ return append(b, 0x82,
+ byte(i>>8),
+ byte(i),
+ )
+ case i < (1 << 24):
+ return append(b, 0x83,
+ byte(i>>16),
+ byte(i>>8),
+ byte(i),
+ )
+ case i < (1 << 32):
+ return append(b, 0x84,
+ byte(i>>24),
+ byte(i>>16),
+ byte(i>>8),
+ byte(i),
+ )
+ case i < (1 << 40):
+ return append(b, 0x85,
+ byte(i>>32),
+ byte(i>>24),
+ byte(i>>16),
+ byte(i>>8),
+ byte(i),
+ )
+
+ case i < (1 << 48):
+ return append(b, 0x86,
+ byte(i>>40),
+ byte(i>>32),
+ byte(i>>24),
+ byte(i>>16),
+ byte(i>>8),
+ byte(i),
+ )
+ case i < (1 << 56):
+ return append(b, 0x87,
+ byte(i>>48),
+ byte(i>>40),
+ byte(i>>32),
+ byte(i>>24),
+ byte(i>>16),
+ byte(i>>8),
+ byte(i),
+ )
+
+ default:
+ return append(b, 0x88,
+ byte(i>>56),
+ byte(i>>48),
+ byte(i>>40),
+ byte(i>>32),
+ byte(i>>24),
+ byte(i>>16),
+ byte(i>>8),
+ byte(i),
+ )
+ }
+}
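// A quick sketch tying the new raw helpers together: AppendUint64 writes the
// canonical integer encoding, IntSize predicts its length, and SplitUint64
// reads it back. All names are from this diff; the numbers are worked by hand.
package main

import (
	"fmt"

	"github.com/tomochain/tomochain/rlp"
)

func main() {
	buf := rlp.AppendUint64(nil, 1024)         // 82 04 00
	fmt.Println(len(buf) == rlp.IntSize(1024)) // true (both are 3)

	v, rest, err := rlp.SplitUint64(buf)
	fmt.Println(v, len(rest), err) // 1024 0 <nil>

	fmt.Println(rlp.StringSize("dog"), rlp.BytesSize([]byte{0x7F})) // 4 1
}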
diff --git a/rlp/raw_test.go b/rlp/raw_test.go
index 2aad04210..7b3255eca 100644
--- a/rlp/raw_test.go
+++ b/rlp/raw_test.go
@@ -18,9 +18,10 @@ package rlp
import (
"bytes"
+ "errors"
"io"
- "reflect"
"testing"
+ "testing/quick"
)
func TestCountValues(t *testing.T) {
@@ -53,21 +54,84 @@ func TestCountValues(t *testing.T) {
if count != test.count {
t.Errorf("test %d: count mismatch, got %d want %d\ninput: %s", i, count, test.count, test.input)
}
- if !reflect.DeepEqual(err, test.err) {
+ if !errors.Is(err, test.err) {
t.Errorf("test %d: err mismatch, got %q want %q\ninput: %s", i, err, test.err, test.input)
}
}
}
-func TestSplitTypes(t *testing.T) {
- if _, _, err := SplitString(unhex("C100")); err != ErrExpectedString {
- t.Errorf("SplitString returned %q, want %q", err, ErrExpectedString)
+func TestSplitString(t *testing.T) {
+ for i, test := range []string{
+ "C0",
+ "C100",
+ "C3010203",
+ "C88363617483646F67",
+ "F8384C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E7365637465747572206164697069736963696E6720656C6974",
+ } {
+ if _, _, err := SplitString(unhex(test)); !errors.Is(err, ErrExpectedString) {
+ t.Errorf("test %d: error mismatch: have %q, want %q", i, err, ErrExpectedString)
+ }
+ }
+}
+
+func TestSplitList(t *testing.T) {
+ for i, test := range []string{
+ "80",
+ "00",
+ "01",
+ "8180",
+ "81FF",
+ "820400",
+ "83636174",
+ "83646F67",
+ "B8384C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E7365637465747572206164697069736963696E6720656C6974",
+ } {
+ if _, _, err := SplitList(unhex(test)); !errors.Is(err, ErrExpectedList) {
+ t.Errorf("test %d: error mismatch: have %q, want %q", i, err, ErrExpectedList)
+ }
}
- if _, _, err := SplitList(unhex("01")); err != ErrExpectedList {
- t.Errorf("SplitString returned %q, want %q", err, ErrExpectedList)
+}
+
+func TestSplitUint64(t *testing.T) {
+ tests := []struct {
+ input string
+ val uint64
+ rest string
+ err error
+ }{
+ {"01", 1, "", nil},
+ {"7FFF", 0x7F, "FF", nil},
+ {"80FF", 0, "FF", nil},
+ {"81FAFF", 0xFA, "FF", nil},
+ {"82FAFAFF", 0xFAFA, "FF", nil},
+ {"83FAFAFAFF", 0xFAFAFA, "FF", nil},
+ {"84FAFAFAFAFF", 0xFAFAFAFA, "FF", nil},
+ {"85FAFAFAFAFAFF", 0xFAFAFAFAFA, "FF", nil},
+ {"86FAFAFAFAFAFAFF", 0xFAFAFAFAFAFA, "FF", nil},
+ {"87FAFAFAFAFAFAFAFF", 0xFAFAFAFAFAFAFA, "FF", nil},
+ {"88FAFAFAFAFAFAFAFAFF", 0xFAFAFAFAFAFAFAFA, "FF", nil},
+
+ // errors
+ {"", 0, "", io.ErrUnexpectedEOF},
+ {"00", 0, "00", ErrCanonInt},
+ {"81", 0, "81", ErrValueTooLarge},
+ {"8100", 0, "8100", ErrCanonSize},
+ {"8200FF", 0, "8200FF", ErrCanonInt},
+ {"8103FF", 0, "8103FF", ErrCanonSize},
+ {"89FAFAFAFAFAFAFAFAFAFF", 0, "89FAFAFAFAFAFAFAFAFAFF", errUintOverflow},
}
- if _, _, err := SplitList(unhex("81FF")); err != ErrExpectedList {
- t.Errorf("SplitString returned %q, want %q", err, ErrExpectedList)
+
+ for i, test := range tests {
+ val, rest, err := SplitUint64(unhex(test.input))
+ if val != test.val {
+ t.Errorf("test %d: val mismatch: got %x, want %x (input %q)", i, val, test.val, test.input)
+ }
+ if !bytes.Equal(rest, unhex(test.rest)) {
+ t.Errorf("test %d: rest mismatch: got %x, want %s (input %q)", i, rest, test.rest, test.input)
+ }
+ if err != test.err {
+ t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err)
+ }
}
}
@@ -78,7 +142,9 @@ func TestSplit(t *testing.T) {
val, rest string
err error
}{
+ {input: "00FFFF", kind: Byte, val: "00", rest: "FFFF"},
{input: "01FFFF", kind: Byte, val: "01", rest: "FFFF"},
+ {input: "7FFFFF", kind: Byte, val: "7F", rest: "FFFF"},
{input: "80FFFF", kind: String, val: "", rest: "FFFF"},
{input: "C3010203", kind: List, val: "010203"},
@@ -194,3 +260,79 @@ func TestReadSize(t *testing.T) {
}
}
}
+
+func TestAppendUint64(t *testing.T) {
+ tests := []struct {
+ input uint64
+ slice []byte
+ output string
+ }{
+ {0, nil, "80"},
+ {1, nil, "01"},
+ {2, nil, "02"},
+ {127, nil, "7F"},
+ {128, nil, "8180"},
+ {129, nil, "8181"},
+ {0xFFFFFF, nil, "83FFFFFF"},
+ {127, []byte{1, 2, 3}, "0102037F"},
+ {0xFFFFFF, []byte{1, 2, 3}, "01020383FFFFFF"},
+ }
+
+ for _, test := range tests {
+ x := AppendUint64(test.slice, test.input)
+ if !bytes.Equal(x, unhex(test.output)) {
+ t.Errorf("AppendUint64(%v, %d): got %x, want %s", test.slice, test.input, x, test.output)
+ }
+
+ // Check that IntSize returns the appended size.
+ length := len(x) - len(test.slice)
+ if s := IntSize(test.input); s != length {
+ t.Errorf("IntSize(%d): got %d, want %d", test.input, s, length)
+ }
+ }
+}
+
+func TestAppendUint64Random(t *testing.T) {
+ fn := func(i uint64) bool {
+ enc, _ := EncodeToBytes(i)
+ encAppend := AppendUint64(nil, i)
+ return bytes.Equal(enc, encAppend)
+ }
+ config := quick.Config{MaxCountScale: 50}
+ if err := quick.Check(fn, &config); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestBytesSize(t *testing.T) {
+ tests := []struct {
+ v []byte
+ size uint64
+ }{
+ {v: []byte{}, size: 1},
+ {v: []byte{0x1}, size: 1},
+ {v: []byte{0x7E}, size: 1},
+ {v: []byte{0x7F}, size: 1},
+ {v: []byte{0x80}, size: 2},
+ {v: []byte{0xFF}, size: 2},
+ {v: []byte{0xFF, 0xF0}, size: 3},
+ {v: make([]byte, 55), size: 56},
+ {v: make([]byte, 56), size: 58},
+ }
+
+ for _, test := range tests {
+ s := BytesSize(test.v)
+ if s != test.size {
+ t.Errorf("BytesSize(%#x) -> %d, want %d", test.v, s, test.size)
+ }
+ s = StringSize(string(test.v))
+ if s != test.size {
+ t.Errorf("StringSize(%#x) -> %d, want %d", test.v, s, test.size)
+ }
+ // Sanity check:
+ enc, _ := EncodeToBytes(test.v)
+ if uint64(len(enc)) != test.size {
+ t.Errorf("len(EncodeToBytes(%#x)) -> %d, test says %d", test.v, len(enc), test.size)
+ }
+ }
+}
diff --git a/rlp/rlpgen/gen.go b/rlp/rlpgen/gen.go
new file mode 100644
index 000000000..26ccdc574
--- /dev/null
+++ b/rlp/rlpgen/gen.go
@@ -0,0 +1,800 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "go/format"
+ "go/types"
+ "sort"
+
+ "github.com/tomochain/tomochain/rlp/internal/rlpstruct"
+)
+
+// buildContext keeps the data needed for make*Op.
+type buildContext struct {
+ topType *types.Named // the type we're creating methods for
+
+ encoderIface *types.Interface
+ decoderIface *types.Interface
+ rawValueType *types.Named
+
+ typeToStructCache map[types.Type]*rlpstruct.Type
+}
+
+func newBuildContext(packageRLP *types.Package) *buildContext {
+ enc := packageRLP.Scope().Lookup("Encoder").Type().Underlying()
+ dec := packageRLP.Scope().Lookup("Decoder").Type().Underlying()
+ rawv := packageRLP.Scope().Lookup("RawValue").Type()
+ return &buildContext{
+ typeToStructCache: make(map[types.Type]*rlpstruct.Type),
+ encoderIface: enc.(*types.Interface),
+ decoderIface: dec.(*types.Interface),
+ rawValueType: rawv.(*types.Named),
+ }
+}
+
+func (bctx *buildContext) isEncoder(typ types.Type) bool {
+ return types.Implements(typ, bctx.encoderIface)
+}
+
+func (bctx *buildContext) isDecoder(typ types.Type) bool {
+ return types.Implements(typ, bctx.decoderIface)
+}
+
+// typeToStructType converts typ to rlpstruct.Type.
+func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type {
+ if prev := bctx.typeToStructCache[typ]; prev != nil {
+ return prev // short-circuit for recursive types.
+ }
+
+ // Resolve named types to their underlying type, but keep the name.
+ name := types.TypeString(typ, nil)
+ for {
+ utype := typ.Underlying()
+ if utype == typ {
+ break
+ }
+ typ = utype
+ }
+
+ // Create the type and store it in cache.
+ t := &rlpstruct.Type{
+ Name: name,
+ Kind: typeReflectKind(typ),
+ IsEncoder: bctx.isEncoder(typ),
+ IsDecoder: bctx.isDecoder(typ),
+ }
+ bctx.typeToStructCache[typ] = t
+
+ // Assign element type.
+ switch typ.(type) {
+ case *types.Array, *types.Slice, *types.Pointer:
+ etype := typ.(interface{ Elem() types.Type }).Elem()
+ t.Elem = bctx.typeToStructType(etype)
+ }
+ return t
+}
+
+// genContext is passed to the gen* methods of op when generating
+// the output code. It tracks packages to be imported by the output
+// file and assigns unique names of temporary variables.
+type genContext struct {
+ inPackage *types.Package
+ imports map[string]struct{}
+ tempCounter int
+}
+
+func newGenContext(inPackage *types.Package) *genContext {
+ return &genContext{
+ inPackage: inPackage,
+ imports: make(map[string]struct{}),
+ }
+}
+
+func (ctx *genContext) temp() string {
+ v := fmt.Sprintf("_tmp%d", ctx.tempCounter)
+ ctx.tempCounter++
+ return v
+}
+
+func (ctx *genContext) resetTemp() {
+ ctx.tempCounter = 0
+}
+
+func (ctx *genContext) addImport(path string) {
+ if path == ctx.inPackage.Path() {
+ return // avoid importing the package that we're generating in.
+ }
+ // TODO: renaming?
+ ctx.imports[path] = struct{}{}
+}
+
+// importsList returns all packages that need to be imported.
+func (ctx *genContext) importsList() []string {
+ imp := make([]string, 0, len(ctx.imports))
+ for k := range ctx.imports {
+ imp = append(imp, k)
+ }
+ sort.Strings(imp)
+ return imp
+}
+
+// qualify is the types.Qualifier used for printing types.
+func (ctx *genContext) qualify(pkg *types.Package) string {
+ if pkg.Path() == ctx.inPackage.Path() {
+ return ""
+ }
+ ctx.addImport(pkg.Path())
+ // TODO: renaming?
+ return pkg.Name()
+}
+
+type op interface {
+ // genWrite creates the encoder. The generated code should write v,
+ // which is any Go expression, to the rlp.EncoderBuffer 'w'.
+ genWrite(ctx *genContext, v string) string
+
+ // genDecode creates the decoder. The generated code should read
+ // a value from the rlp.Stream 'dec' and store it to dst.
+ genDecode(ctx *genContext) (string, string)
+}
+
+// basicOp handles basic types bool, uint*, string.
+type basicOp struct {
+ typ types.Type
+ writeMethod string // EncoderBuffer method called to write the value
+ writeArgType types.Type // parameter type of writeMethod
+ decMethod string
+ decResultType types.Type // return type of decMethod
+ decUseBitSize bool // if true, result bit size is appended to decMethod
+}
+
+func (*buildContext) makeBasicOp(typ *types.Basic) (op, error) {
+ op := basicOp{typ: typ}
+ kind := typ.Kind()
+ switch {
+ case kind == types.Bool:
+ op.writeMethod = "WriteBool"
+ op.writeArgType = types.Typ[types.Bool]
+ op.decMethod = "Bool"
+ op.decResultType = types.Typ[types.Bool]
+ case kind >= types.Uint8 && kind <= types.Uint64:
+ op.writeMethod = "WriteUint64"
+ op.writeArgType = types.Typ[types.Uint64]
+ op.decMethod = "Uint"
+ op.decResultType = typ
+ op.decUseBitSize = true
+ case kind == types.String:
+ op.writeMethod = "WriteString"
+ op.writeArgType = types.Typ[types.String]
+ op.decMethod = "String"
+ op.decResultType = types.Typ[types.String]
+ default:
+ return nil, fmt.Errorf("unhandled basic type: %v", typ)
+ }
+ return op, nil
+}
+
+func (*buildContext) makeByteSliceOp(typ *types.Slice) op {
+ if !isByte(typ.Elem()) {
+ panic("non-byte slice type in makeByteSliceOp")
+ }
+ bslice := types.NewSlice(types.Typ[types.Uint8])
+ return basicOp{
+ typ: typ,
+ writeMethod: "WriteBytes",
+ writeArgType: bslice,
+ decMethod: "Bytes",
+ decResultType: bslice,
+ }
+}
+
+func (bctx *buildContext) makeRawValueOp() op {
+ bslice := types.NewSlice(types.Typ[types.Uint8])
+ return basicOp{
+ typ: bctx.rawValueType,
+ writeMethod: "Write",
+ writeArgType: bslice,
+ decMethod: "Raw",
+ decResultType: bslice,
+ }
+}
+
+func (op basicOp) writeNeedsConversion() bool {
+ return !types.AssignableTo(op.typ, op.writeArgType)
+}
+
+func (op basicOp) decodeNeedsConversion() bool {
+ return !types.AssignableTo(op.decResultType, op.typ)
+}
+
+func (op basicOp) genWrite(ctx *genContext, v string) string {
+ if op.writeNeedsConversion() {
+ v = fmt.Sprintf("%s(%s)", op.writeArgType, v)
+ }
+ return fmt.Sprintf("w.%s(%s)\n", op.writeMethod, v)
+}
+
+func (op basicOp) genDecode(ctx *genContext) (string, string) {
+ var (
+ resultV = ctx.temp()
+ result = resultV
+ method = op.decMethod
+ )
+ if op.decUseBitSize {
+ // Note: For now, this only works for platform-independent integer
+ // sizes. makeBasicOp forbids the platform-dependent types.
+ var sizes types.StdSizes
+ method = fmt.Sprintf("%s%d", op.decMethod, sizes.Sizeof(op.typ)*8)
+ }
+
+ // Call the decoder method.
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "%s, err := dec.%s()\n", resultV, method)
+ fmt.Fprintf(&b, "if err != nil { return err }\n")
+ if op.decodeNeedsConversion() {
+ conv := ctx.temp()
+ fmt.Fprintf(&b, "%s := %s(%s)\n", conv, types.TypeString(op.typ, ctx.qualify), resultV)
+ result = conv
+ }
+ return result, b.String()
+}
+
+// byteArrayOp handles [...]byte.
+type byteArrayOp struct {
+ typ types.Type
+ name types.Type // name != typ for named byte array types (e.g. common.Address)
+}
+
+func (bctx *buildContext) makeByteArrayOp(name *types.Named, typ *types.Array) byteArrayOp {
+ nt := types.Type(name)
+ if name == nil {
+ nt = typ
+ }
+ return byteArrayOp{typ, nt}
+}
+
+func (op byteArrayOp) genWrite(ctx *genContext, v string) string {
+ return fmt.Sprintf("w.WriteBytes(%s[:])\n", v)
+}
+
+func (op byteArrayOp) genDecode(ctx *genContext) (string, string) {
+ var resultV = ctx.temp()
+
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(op.name, ctx.qualify))
+ fmt.Fprintf(&b, "if err := dec.ReadBytes(%s[:]); err != nil { return err }\n", resultV)
+ return resultV, b.String()
+}
+
+// bigIntOp handles big.Int.
+// This exists because big.Int has its own decoder operation on rlp.Stream,
+// but the decode method returns *big.Int, so it needs to be dereferenced.
+type bigIntOp struct {
+ pointer bool
+}
+
+func (op bigIntOp) genWrite(ctx *genContext, v string) string {
+ var b bytes.Buffer
+
+ fmt.Fprintf(&b, "if %s.Sign() == -1 {\n", v)
+ fmt.Fprintf(&b, " return rlp.ErrNegativeBigInt\n")
+ fmt.Fprintf(&b, "}\n")
+ dst := v
+ if !op.pointer {
+ dst = "&" + v
+ }
+ fmt.Fprintf(&b, "w.WriteBigInt(%s)\n", dst)
+
+ // Wrap with nil check.
+ if op.pointer {
+ code := b.String()
+ b.Reset()
+ fmt.Fprintf(&b, "if %s == nil {\n", v)
+ fmt.Fprintf(&b, " w.Write(rlp.EmptyString)")
+ fmt.Fprintf(&b, "} else {\n")
+ fmt.Fprint(&b, code)
+ fmt.Fprintf(&b, "}\n")
+ }
+
+ return b.String()
+}
+
+func (op bigIntOp) genDecode(ctx *genContext) (string, string) {
+ var resultV = ctx.temp()
+
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "%s, err := dec.BigInt()\n", resultV)
+ fmt.Fprintf(&b, "if err != nil { return err }\n")
+
+ result := resultV
+ if !op.pointer {
+ result = "(*" + resultV + ")"
+ }
+ return result, b.String()
+}
+
+// uint256Op handles "github.com/holiman/uint256".Int
+type uint256Op struct {
+ pointer bool
+}
+
+func (op uint256Op) genWrite(ctx *genContext, v string) string {
+ var b bytes.Buffer
+
+ dst := v
+ if !op.pointer {
+ dst = "&" + v
+ }
+ fmt.Fprintf(&b, "w.WriteUint256(%s)\n", dst)
+
+ // Wrap with nil check.
+ if op.pointer {
+ code := b.String()
+ b.Reset()
+ fmt.Fprintf(&b, "if %s == nil {\n", v)
+ fmt.Fprintf(&b, " w.Write(rlp.EmptyString)")
+ fmt.Fprintf(&b, "} else {\n")
+ fmt.Fprint(&b, code)
+ fmt.Fprintf(&b, "}\n")
+ }
+
+ return b.String()
+}
+
+func (op uint256Op) genDecode(ctx *genContext) (string, string) {
+ ctx.addImport("github.com/holiman/uint256")
+
+ var b bytes.Buffer
+ resultV := ctx.temp()
+ fmt.Fprintf(&b, "var %s uint256.Int\n", resultV)
+ fmt.Fprintf(&b, "if err := dec.ReadUint256(&%s); err != nil { return err }\n", resultV)
+
+ result := resultV
+ if op.pointer {
+ result = "&" + resultV
+ }
+ return result, b.String()
+}
+
+// encoderDecoderOp handles rlp.Encoder and rlp.Decoder.
+// In order to be used with this, the type must implement both interfaces.
+// This restriction may be lifted in the future by creating separate ops for
+// encoding and decoding.
+type encoderDecoderOp struct {
+ typ types.Type
+}
+
+func (op encoderDecoderOp) genWrite(ctx *genContext, v string) string {
+ return fmt.Sprintf("if err := %s.EncodeRLP(w); err != nil { return err }\n", v)
+}
+
+func (op encoderDecoderOp) genDecode(ctx *genContext) (string, string) {
+ // DecodeRLP must have pointer receiver, and this is verified in makeOp.
+ etyp := op.typ.(*types.Pointer).Elem()
+ var resultV = ctx.temp()
+
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "%s := new(%s)\n", resultV, types.TypeString(etyp, ctx.qualify))
+ fmt.Fprintf(&b, "if err := %s.DecodeRLP(dec); err != nil { return err }\n", resultV)
+ return resultV, b.String()
+}
+
+// ptrOp handles pointer types.
+type ptrOp struct {
+ elemTyp types.Type
+ elem op
+ nilOK bool
+ nilValue rlpstruct.NilKind
+}
+
+func (bctx *buildContext) makePtrOp(elemTyp types.Type, tags rlpstruct.Tags) (op, error) {
+ elemOp, err := bctx.makeOp(nil, elemTyp, rlpstruct.Tags{})
+ if err != nil {
+ return nil, err
+ }
+ op := ptrOp{elemTyp: elemTyp, elem: elemOp}
+
+ // Determine nil value.
+ if tags.NilOK {
+ op.nilOK = true
+ op.nilValue = tags.NilKind
+ } else {
+ styp := bctx.typeToStructType(elemTyp)
+ op.nilValue = styp.DefaultNilValue()
+ }
+ return op, nil
+}
+
+func (op ptrOp) genWrite(ctx *genContext, v string) string {
+ // Note: in writer functions, accesses to v are read-only, i.e. v is any Go
+ // expression. To make all accesses work through the pointer, we substitute
+ // v with (*v). This is required for most accesses including `v`, `call(v)`,
+ // and `v[index]` on slices.
+ //
+ // For `v.field` and `v[:]` on arrays, the dereference operation is not required.
+ var vv string
+ _, isStruct := op.elem.(structOp)
+ _, isByteArray := op.elem.(byteArrayOp)
+ if isStruct || isByteArray {
+ vv = v
+ } else {
+ vv = fmt.Sprintf("(*%s)", v)
+ }
+
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "if %s == nil {\n", v)
+ fmt.Fprintf(&b, " w.Write([]byte{0x%X})\n", op.nilValue)
+ fmt.Fprintf(&b, "} else {\n")
+ fmt.Fprintf(&b, " %s", op.elem.genWrite(ctx, vv))
+ fmt.Fprintf(&b, "}\n")
+ return b.String()
+}
+
+func (op ptrOp) genDecode(ctx *genContext) (string, string) {
+ result, code := op.elem.genDecode(ctx)
+ if !op.nilOK {
+ // If nil pointers are not allowed, we can just decode the element.
+ return "&" + result, code
+ }
+
+ // nil is allowed, so check the kind and size first.
+ // If size is zero and kind matches the nilKind of the type,
+ // the value decodes as a nil pointer.
+ var (
+ resultV = ctx.temp()
+ kindV = ctx.temp()
+ sizeV = ctx.temp()
+ wantKind string
+ )
+ if op.nilValue == rlpstruct.NilKindList {
+ wantKind = "rlp.List"
+ } else {
+ wantKind = "rlp.String"
+ }
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(types.NewPointer(op.elemTyp), ctx.qualify))
+ fmt.Fprintf(&b, "if %s, %s, err := dec.Kind(); err != nil {\n", kindV, sizeV)
+ fmt.Fprintf(&b, " return err\n")
+ fmt.Fprintf(&b, "} else if %s != 0 || %s != %s {\n", sizeV, kindV, wantKind)
+ fmt.Fprint(&b, code)
+ fmt.Fprintf(&b, " %s = &%s\n", resultV, result)
+ fmt.Fprintf(&b, "}\n")
+ return resultV, b.String()
+}
+
+// structOp handles struct types.
+type structOp struct {
+ named *types.Named
+ typ *types.Struct
+ fields []*structField
+ optionalFields []*structField
+}
+
+type structField struct {
+ name string
+ typ types.Type
+ elem op
+}
+
+func (bctx *buildContext) makeStructOp(named *types.Named, typ *types.Struct) (op, error) {
+ // Convert fields to []rlpstruct.Field.
+ var allStructFields []rlpstruct.Field
+ for i := 0; i < typ.NumFields(); i++ {
+ f := typ.Field(i)
+ allStructFields = append(allStructFields, rlpstruct.Field{
+ Name: f.Name(),
+ Exported: f.Exported(),
+ Index: i,
+ Tag: typ.Tag(i),
+ Type: *bctx.typeToStructType(f.Type()),
+ })
+ }
+
+ // Filter/validate fields.
+ fields, tags, err := rlpstruct.ProcessFields(allStructFields)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create field ops.
+ var op = structOp{named: named, typ: typ}
+ for i, field := range fields {
+ // Advanced struct tags are not supported yet.
+ tag := tags[i]
+ if err := checkUnsupportedTags(field.Name, tag); err != nil {
+ return nil, err
+ }
+ typ := typ.Field(field.Index).Type()
+ elem, err := bctx.makeOp(nil, typ, tags[i])
+ if err != nil {
+ return nil, fmt.Errorf("field %s: %v", field.Name, err)
+ }
+ f := &structField{name: field.Name, typ: typ, elem: elem}
+ if tag.Optional {
+ op.optionalFields = append(op.optionalFields, f)
+ } else {
+ op.fields = append(op.fields, f)
+ }
+ }
+ return op, nil
+}
+
+func checkUnsupportedTags(field string, tag rlpstruct.Tags) error {
+ if tag.Tail {
+ return fmt.Errorf(`field %s has unsupported struct tag "tail"`, field)
+ }
+ return nil
+}
+
+func (op structOp) genWrite(ctx *genContext, v string) string {
+ var b bytes.Buffer
+ var listMarker = ctx.temp()
+ fmt.Fprintf(&b, "%s := w.List()\n", listMarker)
+ for _, field := range op.fields {
+ selector := v + "." + field.name
+ fmt.Fprint(&b, field.elem.genWrite(ctx, selector))
+ }
+ op.writeOptionalFields(&b, ctx, v)
+ fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker)
+ return b.String()
+}
+
+func (op structOp) writeOptionalFields(b *bytes.Buffer, ctx *genContext, v string) {
+ if len(op.optionalFields) == 0 {
+ return
+ }
+ // First check zero-ness of all optional fields.
+ var zeroV = make([]string, len(op.optionalFields))
+ for i, field := range op.optionalFields {
+ selector := v + "." + field.name
+ zeroV[i] = ctx.temp()
+ fmt.Fprintf(b, "%s := %s\n", zeroV[i], nonZeroCheck(selector, field.typ, ctx.qualify))
+ }
+ // Now write the fields.
+ for i, field := range op.optionalFields {
+ selector := v + "." + field.name
+ cond := ""
+ for j := i; j < len(op.optionalFields); j++ {
+ if j > i {
+ cond += " || "
+ }
+ cond += zeroV[j]
+ }
+ fmt.Fprintf(b, "if %s {\n", cond)
+ fmt.Fprint(b, field.elem.genWrite(ctx, selector))
+ fmt.Fprintf(b, "}\n")
+ }
+}
+
+func (op structOp) genDecode(ctx *genContext) (string, string) {
+ // Get the string representation of the type.
+ // Here, named types are handled separately because the output
+ // would contain a copy of the struct definition otherwise.
+ var typeName string
+ if op.named != nil {
+ typeName = types.TypeString(op.named, ctx.qualify)
+ } else {
+ typeName = types.TypeString(op.typ, ctx.qualify)
+ }
+
+ // Create struct object.
+ var resultV = ctx.temp()
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "var %s %s\n", resultV, typeName)
+
+ // Decode fields.
+ fmt.Fprintf(&b, "{\n")
+ fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n")
+ for _, field := range op.fields {
+ result, code := field.elem.genDecode(ctx)
+ fmt.Fprintf(&b, "// %s:\n", field.name)
+ fmt.Fprint(&b, code)
+ fmt.Fprintf(&b, "%s.%s = %s\n", resultV, field.name, result)
+ }
+ op.decodeOptionalFields(&b, ctx, resultV)
+ fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n")
+ fmt.Fprintf(&b, "}\n")
+ return resultV, b.String()
+}
+
+func (op structOp) decodeOptionalFields(b *bytes.Buffer, ctx *genContext, resultV string) {
+ var suffix bytes.Buffer
+ for _, field := range op.optionalFields {
+ result, code := field.elem.genDecode(ctx)
+ fmt.Fprintf(b, "// %s:\n", field.name)
+ fmt.Fprintf(b, "if dec.MoreDataInList() {\n")
+ fmt.Fprint(b, code)
+ fmt.Fprintf(b, "%s.%s = %s\n", resultV, field.name, result)
+ fmt.Fprintf(&suffix, "}\n")
+ }
+ suffix.WriteTo(b)
+}
+
+// sliceOp handles slice types.
+type sliceOp struct {
+ typ *types.Slice
+ elemOp op
+}
+
+func (bctx *buildContext) makeSliceOp(typ *types.Slice) (op, error) {
+ elemOp, err := bctx.makeOp(nil, typ.Elem(), rlpstruct.Tags{})
+ if err != nil {
+ return nil, err
+ }
+ return sliceOp{typ: typ, elemOp: elemOp}, nil
+}
+
+func (op sliceOp) genWrite(ctx *genContext, v string) string {
+ var (
+ listMarker = ctx.temp() // holds return value of w.List()
+ iterElemV = ctx.temp() // iteration variable
+ elemCode = op.elemOp.genWrite(ctx, iterElemV)
+ )
+
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "%s := w.List()\n", listMarker)
+ fmt.Fprintf(&b, "for _, %s := range %s {\n", iterElemV, v)
+ fmt.Fprint(&b, elemCode)
+ fmt.Fprintf(&b, "}\n")
+ fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker)
+ return b.String()
+}
+
+func (op sliceOp) genDecode(ctx *genContext) (string, string) {
+ var sliceV = ctx.temp() // holds the output slice
+ elemResult, elemCode := op.elemOp.genDecode(ctx)
+
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "var %s %s\n", sliceV, types.TypeString(op.typ, ctx.qualify))
+ fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n")
+ fmt.Fprintf(&b, "for dec.MoreDataInList() {\n")
+ fmt.Fprintf(&b, " %s", elemCode)
+ fmt.Fprintf(&b, " %s = append(%s, %s)\n", sliceV, sliceV, elemResult)
+ fmt.Fprintf(&b, "}\n")
+ fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n")
+ return sliceV, b.String()
+}
+
+func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstruct.Tags) (op, error) {
+ switch typ := typ.(type) {
+ case *types.Named:
+ if isBigInt(typ) {
+ return bigIntOp{}, nil
+ }
+ if isUint256(typ) {
+ return uint256Op{}, nil
+ }
+ if typ == bctx.rawValueType {
+ return bctx.makeRawValueOp(), nil
+ }
+ if bctx.isDecoder(typ) {
+ return nil, fmt.Errorf("type %v implements rlp.Decoder with non-pointer receiver", typ)
+ }
+ // TODO: same check for encoder?
+ return bctx.makeOp(typ, typ.Underlying(), tags)
+ case *types.Pointer:
+ if isBigInt(typ.Elem()) {
+ return bigIntOp{pointer: true}, nil
+ }
+ if isUint256(typ.Elem()) {
+ return uint256Op{pointer: true}, nil
+ }
+ // Encoder/Decoder interfaces.
+ if bctx.isEncoder(typ) {
+ if bctx.isDecoder(typ) {
+ return encoderDecoderOp{typ}, nil
+ }
+ return nil, fmt.Errorf("type %v implements rlp.Encoder but not rlp.Decoder", typ)
+ }
+ if bctx.isDecoder(typ) {
+ return nil, fmt.Errorf("type %v implements rlp.Decoder but not rlp.Encoder", typ)
+ }
+ // Default pointer handling.
+ return bctx.makePtrOp(typ.Elem(), tags)
+ case *types.Basic:
+ return bctx.makeBasicOp(typ)
+ case *types.Struct:
+ return bctx.makeStructOp(name, typ)
+ case *types.Slice:
+ etyp := typ.Elem()
+ if isByte(etyp) && !bctx.isEncoder(etyp) {
+ return bctx.makeByteSliceOp(typ), nil
+ }
+ return bctx.makeSliceOp(typ)
+ case *types.Array:
+ etyp := typ.Elem()
+ if isByte(etyp) && !bctx.isEncoder(etyp) {
+ return bctx.makeByteArrayOp(name, typ), nil
+ }
+ return nil, fmt.Errorf("unhandled array type: %v", typ)
+ default:
+ return nil, fmt.Errorf("unhandled type: %v", typ)
+ }
+}
+
+// generateDecoder generates the DecodeRLP method on 'typ'.
+func generateDecoder(ctx *genContext, typ string, op op) []byte {
+ ctx.resetTemp()
+ ctx.addImport(pathOfPackageRLP)
+
+ result, code := op.genDecode(ctx)
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "func (obj *%s) DecodeRLP(dec *rlp.Stream) error {\n", typ)
+ fmt.Fprint(&b, code)
+ fmt.Fprintf(&b, " *obj = %s\n", result)
+ fmt.Fprintf(&b, " return nil\n")
+ fmt.Fprintf(&b, "}\n")
+ return b.Bytes()
+}
+
+// generateEncoder generates the EncodeRLP method on 'typ'.
+func generateEncoder(ctx *genContext, typ string, op op) []byte {
+ ctx.resetTemp()
+ ctx.addImport("io")
+ ctx.addImport(pathOfPackageRLP)
+
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ)
+ fmt.Fprintf(&b, " w := rlp.NewEncoderBuffer(_w)\n")
+ fmt.Fprint(&b, op.genWrite(ctx, "obj"))
+ fmt.Fprintf(&b, " return w.Flush()\n")
+ fmt.Fprintf(&b, "}\n")
+ return b.Bytes()
+}
+
+func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]byte, error) {
+ bctx.topType = typ
+
+ pkg := typ.Obj().Pkg()
+ op, err := bctx.makeOp(nil, typ, rlpstruct.Tags{})
+ if err != nil {
+ return nil, err
+ }
+
+ var (
+ ctx = newGenContext(pkg)
+ encSource []byte
+ decSource []byte
+ )
+ if encoder {
+ encSource = generateEncoder(ctx, typ.Obj().Name(), op)
+ }
+ if decoder {
+ decSource = generateDecoder(ctx, typ.Obj().Name(), op)
+ }
+
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "package %s\n\n", pkg.Name())
+ for _, imp := range ctx.importsList() {
+ fmt.Fprintf(&b, "import %q\n", imp)
+ }
+ if encoder {
+ fmt.Fprintln(&b)
+ b.Write(encSource)
+ }
+ if decoder {
+ fmt.Fprintln(&b)
+ b.Write(decSource)
+ }
+
+ source := b.Bytes()
+ // fmt.Println(string(source))
+ return format.Source(source)
+}
diff --git a/rlp/rlpgen/gen_test.go b/rlp/rlpgen/gen_test.go
new file mode 100644
index 000000000..3b4f5df28
--- /dev/null
+++ b/rlp/rlpgen/gen_test.go
@@ -0,0 +1,107 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/importer"
+ "go/parser"
+ "go/token"
+ "go/types"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+// Package RLP is loaded only once and reused for all tests.
+var (
+ testFset = token.NewFileSet()
+ testImporter = importer.ForCompiler(testFset, "source", nil).(types.ImporterFrom)
+ testPackageRLP *types.Package
+)
+
+func init() {
+ cwd, err := os.Getwd()
+ if err != nil {
+ panic(err)
+ }
+ testPackageRLP, err = testImporter.ImportFrom(pathOfPackageRLP, cwd, 0)
+ if err != nil {
+ panic(fmt.Errorf("can't load package RLP: %v", err))
+ }
+}
+
+var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256"}
+
+func TestOutput(t *testing.T) {
+ for _, test := range tests {
+ test := test
+ t.Run(test, func(t *testing.T) {
+ inputFile := filepath.Join("testdata", test+".in.txt")
+ outputFile := filepath.Join("testdata", test+".out.txt")
+ bctx, typ, err := loadTestSource(inputFile, "Test")
+ if err != nil {
+ t.Fatal("error loading test source:", err)
+ }
+ output, err := bctx.generate(typ, true, true)
+ if err != nil {
+ t.Fatal("error in generate:", err)
+ }
+
+ // Set this environment variable to regenerate the test outputs.
+ if os.Getenv("WRITE_TEST_FILES") != "" {
+ os.WriteFile(outputFile, output, 0644)
+ }
+
+ // Check if output matches.
+ wantOutput, err := os.ReadFile(outputFile)
+ if err != nil {
+ t.Fatal("error loading expected test output:", err)
+ }
+ if !bytes.Equal(output, wantOutput) {
+ t.Fatalf("output mismatch, want: %v got %v", string(wantOutput), string(output))
+ }
+ })
+ }
+}
+
+func loadTestSource(file string, typeName string) (*buildContext, *types.Named, error) {
+ // Load the test input.
+ content, err := os.ReadFile(file)
+ if err != nil {
+ return nil, nil, err
+ }
+ f, err := parser.ParseFile(testFset, file, content, 0)
+ if err != nil {
+ return nil, nil, err
+ }
+ conf := types.Config{Importer: testImporter}
+ pkg, err := conf.Check("test", testFset, []*ast.File{f}, nil)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Find the test struct.
+ bctx := newBuildContext(testPackageRLP)
+ typ, err := lookupStructType(pkg.Scope(), typeName)
+ if err != nil {
+ return nil, nil, fmt.Errorf("can't find type %s: %v", typeName, err)
+ }
+ return bctx, typ, nil
+}
diff --git a/rlp/rlpgen/main.go b/rlp/rlpgen/main.go
new file mode 100644
index 000000000..87aebbc47
--- /dev/null
+++ b/rlp/rlpgen/main.go
@@ -0,0 +1,147 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package main
+
+import (
+ "bytes"
+ "errors"
+ "flag"
+ "fmt"
+ "go/types"
+ "os"
+
+ "golang.org/x/tools/go/packages"
+)
+
+const pathOfPackageRLP = "github.com/tomochain/tomochain/rlp"
+
+func main() {
+ var (
+ pkgdir = flag.String("dir", ".", "input package")
+ output = flag.String("out", "-", "output file (default is stdout)")
+ genEncoder = flag.Bool("encoder", true, "generate EncodeRLP?")
+ genDecoder = flag.Bool("decoder", false, "generate DecodeRLP?")
+ typename = flag.String("type", "", "type to generate methods for")
+ )
+ flag.Parse()
+
+ cfg := Config{
+ Dir: *pkgdir,
+ Type: *typename,
+ GenerateEncoder: *genEncoder,
+ GenerateDecoder: *genDecoder,
+ }
+ code, err := cfg.process()
+ if err != nil {
+ fatal(err)
+ }
+ if *output == "-" {
+ os.Stdout.Write(code)
+ } else if err := os.WriteFile(*output, code, 0600); err != nil {
+ fatal(err)
+ }
+}
+
+func fatal(args ...interface{}) {
+ fmt.Fprintln(os.Stderr, args...)
+ os.Exit(1)
+}
+
+type Config struct {
+ Dir string // input package directory
+ Type string
+
+ GenerateEncoder bool
+ GenerateDecoder bool
+}
+
+// process generates the Go code.
+func (cfg *Config) process() (code []byte, err error) {
+ // Load packages.
+ pcfg := &packages.Config{
+ Mode: packages.NeedName | packages.NeedTypes | packages.NeedImports | packages.NeedDeps,
+ Dir: cfg.Dir,
+ BuildFlags: []string{"-tags", "norlpgen"},
+ }
+ ps, err := packages.Load(pcfg, pathOfPackageRLP, ".")
+ if err != nil {
+ return nil, err
+ }
+ if len(ps) == 0 {
+ return nil, fmt.Errorf("no Go package found in %s", cfg.Dir)
+ }
+ packages.PrintErrors(ps)
+
+ // Find the packages that were loaded.
+ var (
+ pkg *types.Package
+ packageRLP *types.Package
+ )
+ for _, p := range ps {
+ if len(p.Errors) > 0 {
+ return nil, fmt.Errorf("package %s has errors", p.PkgPath)
+ }
+ if p.PkgPath == pathOfPackageRLP {
+ packageRLP = p.Types
+ } else {
+ pkg = p.Types
+ }
+ }
+ bctx := newBuildContext(packageRLP)
+
+ // Find the type and generate.
+ typ, err := lookupStructType(pkg.Scope(), cfg.Type)
+ if err != nil {
+ return nil, fmt.Errorf("can't find %s in %s: %v", cfg.Type, pkg, err)
+ }
+ code, err = bctx.generate(typ, cfg.GenerateEncoder, cfg.GenerateDecoder)
+ if err != nil {
+ return nil, err
+ }
+
+ // Add build comments.
+ // This is done here to avoid processing these lines with gofmt.
+ var header bytes.Buffer
+ fmt.Fprint(&header, "// Code generated by rlpgen. DO NOT EDIT.\n\n")
+ fmt.Fprint(&header, "//go:build !norlpgen\n")
+ fmt.Fprint(&header, "// +build !norlpgen\n\n")
+ return append(header.Bytes(), code...), nil
+}
+
+func lookupStructType(scope *types.Scope, name string) (*types.Named, error) {
+ typ, err := lookupType(scope, name)
+ if err != nil {
+ return nil, err
+ }
+ _, ok := typ.Underlying().(*types.Struct)
+ if !ok {
+ return nil, errors.New("not a struct type")
+ }
+ return typ, nil
+}
+
+func lookupType(scope *types.Scope, name string) (*types.Named, error) {
+ obj := scope.Lookup(name)
+ if obj == nil {
+ return nil, errors.New("no such identifier")
+ }
+ typ, ok := obj.(*types.TypeName)
+ if !ok {
+ return nil, errors.New("not a type")
+ }
+ return typ.Type().(*types.Named), nil
+}
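// A sketch of how rlpgen is typically wired into a package that defines the
// target type; Header and the output file name are placeholders. -encoder
// defaults to true, so only -decoder needs to be enabled explicitly.
//
//go:generate go run github.com/tomochain/tomochain/rlp/rlpgen -type Header -out gen_header_rlp.go -decoder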
diff --git a/rlp/rlpgen/testdata/bigint.in.txt b/rlp/rlpgen/testdata/bigint.in.txt
new file mode 100644
index 000000000..d23d84a28
--- /dev/null
+++ b/rlp/rlpgen/testdata/bigint.in.txt
@@ -0,0 +1,10 @@
+// -*- mode: go -*-
+
+package test
+
+import "math/big"
+
+type Test struct {
+ Int *big.Int
+ IntNoPtr big.Int
+}
diff --git a/rlp/rlpgen/testdata/bigint.out.txt b/rlp/rlpgen/testdata/bigint.out.txt
new file mode 100644
index 000000000..6dc7bea3b
--- /dev/null
+++ b/rlp/rlpgen/testdata/bigint.out.txt
@@ -0,0 +1,49 @@
+package test
+
+import "github.com/tomochain/tomochain/rlp"
+import "io"
+
+func (obj *Test) EncodeRLP(_w io.Writer) error {
+ w := rlp.NewEncoderBuffer(_w)
+ _tmp0 := w.List()
+ if obj.Int == nil {
+ w.Write(rlp.EmptyString)
+ } else {
+ if obj.Int.Sign() == -1 {
+ return rlp.ErrNegativeBigInt
+ }
+ w.WriteBigInt(obj.Int)
+ }
+ if obj.IntNoPtr.Sign() == -1 {
+ return rlp.ErrNegativeBigInt
+ }
+ w.WriteBigInt(&obj.IntNoPtr)
+ w.ListEnd(_tmp0)
+ return w.Flush()
+}
+
+func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
+ var _tmp0 Test
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // Int:
+ _tmp1, err := dec.BigInt()
+ if err != nil {
+ return err
+ }
+ _tmp0.Int = _tmp1
+ // IntNoPtr:
+ _tmp2, err := dec.BigInt()
+ if err != nil {
+ return err
+ }
+ _tmp0.IntNoPtr = (*_tmp2)
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ *obj = _tmp0
+ return nil
+}
diff --git a/rlp/rlpgen/testdata/nil.in.txt b/rlp/rlpgen/testdata/nil.in.txt
new file mode 100644
index 000000000..a28ff3448
--- /dev/null
+++ b/rlp/rlpgen/testdata/nil.in.txt
@@ -0,0 +1,30 @@
+// -*- mode: go -*-
+
+package test
+
+type Aux struct{
+ A uint32
+}
+
+type Test struct{
+ Uint8 *byte `rlp:"nil"`
+ Uint8List *byte `rlp:"nilList"`
+
+ Uint32 *uint32 `rlp:"nil"`
+ Uint32List *uint32 `rlp:"nilList"`
+
+ Uint64 *uint64 `rlp:"nil"`
+ Uint64List *uint64 `rlp:"nilList"`
+
+ String *string `rlp:"nil"`
+ StringList *string `rlp:"nilList"`
+
+ ByteArray *[3]byte `rlp:"nil"`
+ ByteArrayList *[3]byte `rlp:"nilList"`
+
+ ByteSlice *[]byte `rlp:"nil"`
+ ByteSliceList *[]byte `rlp:"nilList"`
+
+ Struct *Aux `rlp:"nil"`
+ StructString *Aux `rlp:"nilString"`
+}
diff --git a/rlp/rlpgen/testdata/nil.out.txt b/rlp/rlpgen/testdata/nil.out.txt
new file mode 100644
index 000000000..b3bdd0b86
--- /dev/null
+++ b/rlp/rlpgen/testdata/nil.out.txt
@@ -0,0 +1,289 @@
+package test
+
+import "github.com/tomochain/tomochain/rlp"
+import "io"
+
+func (obj *Test) EncodeRLP(_w io.Writer) error {
+ w := rlp.NewEncoderBuffer(_w)
+ _tmp0 := w.List()
+ if obj.Uint8 == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.WriteUint64(uint64((*obj.Uint8)))
+ }
+ if obj.Uint8List == nil {
+ w.Write([]byte{0xC0})
+ } else {
+ w.WriteUint64(uint64((*obj.Uint8List)))
+ }
+ if obj.Uint32 == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.WriteUint64(uint64((*obj.Uint32)))
+ }
+ if obj.Uint32List == nil {
+ w.Write([]byte{0xC0})
+ } else {
+ w.WriteUint64(uint64((*obj.Uint32List)))
+ }
+ if obj.Uint64 == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.WriteUint64((*obj.Uint64))
+ }
+ if obj.Uint64List == nil {
+ w.Write([]byte{0xC0})
+ } else {
+ w.WriteUint64((*obj.Uint64List))
+ }
+ if obj.String == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.WriteString((*obj.String))
+ }
+ if obj.StringList == nil {
+ w.Write([]byte{0xC0})
+ } else {
+ w.WriteString((*obj.StringList))
+ }
+ if obj.ByteArray == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.WriteBytes(obj.ByteArray[:])
+ }
+ if obj.ByteArrayList == nil {
+ w.Write([]byte{0xC0})
+ } else {
+ w.WriteBytes(obj.ByteArrayList[:])
+ }
+ if obj.ByteSlice == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.WriteBytes((*obj.ByteSlice))
+ }
+ if obj.ByteSliceList == nil {
+ w.Write([]byte{0xC0})
+ } else {
+ w.WriteBytes((*obj.ByteSliceList))
+ }
+ if obj.Struct == nil {
+ w.Write([]byte{0xC0})
+ } else {
+ _tmp1 := w.List()
+ w.WriteUint64(uint64(obj.Struct.A))
+ w.ListEnd(_tmp1)
+ }
+ if obj.StructString == nil {
+ w.Write([]byte{0x80})
+ } else {
+ _tmp2 := w.List()
+ w.WriteUint64(uint64(obj.StructString.A))
+ w.ListEnd(_tmp2)
+ }
+ w.ListEnd(_tmp0)
+ return w.Flush()
+}
+
+func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
+ var _tmp0 Test
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // Uint8:
+ var _tmp2 *byte
+ if _tmp3, _tmp4, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp4 != 0 || _tmp3 != rlp.String {
+ _tmp1, err := dec.Uint8()
+ if err != nil {
+ return err
+ }
+ _tmp2 = &_tmp1
+ }
+ _tmp0.Uint8 = _tmp2
+ // Uint8List:
+ var _tmp6 *byte
+ if _tmp7, _tmp8, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp8 != 0 || _tmp7 != rlp.List {
+ _tmp5, err := dec.Uint8()
+ if err != nil {
+ return err
+ }
+ _tmp6 = &_tmp5
+ }
+ _tmp0.Uint8List = _tmp6
+ // Uint32:
+ var _tmp10 *uint32
+ if _tmp11, _tmp12, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp12 != 0 || _tmp11 != rlp.String {
+ _tmp9, err := dec.Uint32()
+ if err != nil {
+ return err
+ }
+ _tmp10 = &_tmp9
+ }
+ _tmp0.Uint32 = _tmp10
+ // Uint32List:
+ var _tmp14 *uint32
+ if _tmp15, _tmp16, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp16 != 0 || _tmp15 != rlp.List {
+ _tmp13, err := dec.Uint32()
+ if err != nil {
+ return err
+ }
+ _tmp14 = &_tmp13
+ }
+ _tmp0.Uint32List = _tmp14
+ // Uint64:
+ var _tmp18 *uint64
+ if _tmp19, _tmp20, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp20 != 0 || _tmp19 != rlp.String {
+ _tmp17, err := dec.Uint64()
+ if err != nil {
+ return err
+ }
+ _tmp18 = &_tmp17
+ }
+ _tmp0.Uint64 = _tmp18
+ // Uint64List:
+ var _tmp22 *uint64
+ if _tmp23, _tmp24, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp24 != 0 || _tmp23 != rlp.List {
+ _tmp21, err := dec.Uint64()
+ if err != nil {
+ return err
+ }
+ _tmp22 = &_tmp21
+ }
+ _tmp0.Uint64List = _tmp22
+ // String:
+ var _tmp26 *string
+ if _tmp27, _tmp28, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp28 != 0 || _tmp27 != rlp.String {
+ _tmp25, err := dec.String()
+ if err != nil {
+ return err
+ }
+ _tmp26 = &_tmp25
+ }
+ _tmp0.String = _tmp26
+ // StringList:
+ var _tmp30 *string
+ if _tmp31, _tmp32, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp32 != 0 || _tmp31 != rlp.List {
+ _tmp29, err := dec.String()
+ if err != nil {
+ return err
+ }
+ _tmp30 = &_tmp29
+ }
+ _tmp0.StringList = _tmp30
+ // ByteArray:
+ var _tmp34 *[3]byte
+ if _tmp35, _tmp36, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp36 != 0 || _tmp35 != rlp.String {
+ var _tmp33 [3]byte
+ if err := dec.ReadBytes(_tmp33[:]); err != nil {
+ return err
+ }
+ _tmp34 = &_tmp33
+ }
+ _tmp0.ByteArray = _tmp34
+ // ByteArrayList:
+ var _tmp38 *[3]byte
+ if _tmp39, _tmp40, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp40 != 0 || _tmp39 != rlp.List {
+ var _tmp37 [3]byte
+ if err := dec.ReadBytes(_tmp37[:]); err != nil {
+ return err
+ }
+ _tmp38 = &_tmp37
+ }
+ _tmp0.ByteArrayList = _tmp38
+ // ByteSlice:
+ var _tmp42 *[]byte
+ if _tmp43, _tmp44, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp44 != 0 || _tmp43 != rlp.String {
+ _tmp41, err := dec.Bytes()
+ if err != nil {
+ return err
+ }
+ _tmp42 = &_tmp41
+ }
+ _tmp0.ByteSlice = _tmp42
+ // ByteSliceList:
+ var _tmp46 *[]byte
+ if _tmp47, _tmp48, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp48 != 0 || _tmp47 != rlp.List {
+ _tmp45, err := dec.Bytes()
+ if err != nil {
+ return err
+ }
+ _tmp46 = &_tmp45
+ }
+ _tmp0.ByteSliceList = _tmp46
+ // Struct:
+ var _tmp51 *Aux
+ if _tmp52, _tmp53, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp53 != 0 || _tmp52 != rlp.List {
+ var _tmp49 Aux
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // A:
+ _tmp50, err := dec.Uint32()
+ if err != nil {
+ return err
+ }
+ _tmp49.A = _tmp50
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ _tmp51 = &_tmp49
+ }
+ _tmp0.Struct = _tmp51
+ // StructString:
+ var _tmp56 *Aux
+ if _tmp57, _tmp58, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp58 != 0 || _tmp57 != rlp.String {
+ var _tmp54 Aux
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // A:
+ _tmp55, err := dec.Uint32()
+ if err != nil {
+ return err
+ }
+ _tmp54.A = _tmp55
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ _tmp56 = &_tmp54
+ }
+ _tmp0.StructString = _tmp56
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ *obj = _tmp0
+ return nil
+}
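
The expectation above pins down the nil-tag semantics: a nil pointer tagged rlp:"nil" encodes as the canonical empty string (0x80), rlp:"nilList" as the empty list (0xC0), and on decode a Kind() probe distinguishes the empty sentinel from real payload before any value is read. A round-trip sketch through the reflection-based API, assuming it honours the same tags as the generated code:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/rlp"
)

type Pair struct {
	Str  *string `rlp:"nil"`     // nil encodes as 0x80
	List *uint32 `rlp:"nilList"` // nil encodes as 0xC0
}

func main() {
	enc, _ := rlp.EncodeToBytes(&Pair{})
	fmt.Printf("%x\n", enc) // c280c0

	var dec Pair
	if err := rlp.DecodeBytes(enc, &dec); err != nil {
		panic(err)
	}
	fmt.Println(dec.Str == nil, dec.List == nil) // true true
}
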
diff --git a/rlp/rlpgen/testdata/optional.in.txt b/rlp/rlpgen/testdata/optional.in.txt
new file mode 100644
index 000000000..f1ac9f789
--- /dev/null
+++ b/rlp/rlpgen/testdata/optional.in.txt
@@ -0,0 +1,17 @@
+// -*- mode: go -*-
+
+package test
+
+type Aux struct {
+ A uint64
+}
+
+type Test struct {
+ Uint64 uint64 `rlp:"optional"`
+ Pointer *uint64 `rlp:"optional"`
+ String string `rlp:"optional"`
+ Slice []uint64 `rlp:"optional"`
+ Array [3]byte `rlp:"optional"`
+ NamedStruct Aux `rlp:"optional"`
+ AnonStruct struct{ A string } `rlp:"optional"`
+}
diff --git a/rlp/rlpgen/testdata/optional.out.txt b/rlp/rlpgen/testdata/optional.out.txt
new file mode 100644
index 000000000..fb9b95d44
--- /dev/null
+++ b/rlp/rlpgen/testdata/optional.out.txt
@@ -0,0 +1,153 @@
+package test
+
+import "github.com/tomochain/tomochain/rlp"
+import "io"
+
+func (obj *Test) EncodeRLP(_w io.Writer) error {
+ w := rlp.NewEncoderBuffer(_w)
+ _tmp0 := w.List()
+ _tmp1 := obj.Uint64 != 0
+ _tmp2 := obj.Pointer != nil
+ _tmp3 := obj.String != ""
+ _tmp4 := len(obj.Slice) > 0
+ _tmp5 := obj.Array != ([3]byte{})
+ _tmp6 := obj.NamedStruct != (Aux{})
+ _tmp7 := obj.AnonStruct != (struct{ A string }{})
+ if _tmp1 || _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 {
+ w.WriteUint64(obj.Uint64)
+ }
+ if _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 {
+ if obj.Pointer == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.WriteUint64((*obj.Pointer))
+ }
+ }
+ if _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 {
+ w.WriteString(obj.String)
+ }
+ if _tmp4 || _tmp5 || _tmp6 || _tmp7 {
+ _tmp8 := w.List()
+ for _, _tmp9 := range obj.Slice {
+ w.WriteUint64(_tmp9)
+ }
+ w.ListEnd(_tmp8)
+ }
+ if _tmp5 || _tmp6 || _tmp7 {
+ w.WriteBytes(obj.Array[:])
+ }
+ if _tmp6 || _tmp7 {
+ _tmp10 := w.List()
+ w.WriteUint64(obj.NamedStruct.A)
+ w.ListEnd(_tmp10)
+ }
+ if _tmp7 {
+ _tmp11 := w.List()
+ w.WriteString(obj.AnonStruct.A)
+ w.ListEnd(_tmp11)
+ }
+ w.ListEnd(_tmp0)
+ return w.Flush()
+}
+
+func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
+ var _tmp0 Test
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // Uint64:
+ if dec.MoreDataInList() {
+ _tmp1, err := dec.Uint64()
+ if err != nil {
+ return err
+ }
+ _tmp0.Uint64 = _tmp1
+ // Pointer:
+ if dec.MoreDataInList() {
+ _tmp2, err := dec.Uint64()
+ if err != nil {
+ return err
+ }
+ _tmp0.Pointer = &_tmp2
+ // String:
+ if dec.MoreDataInList() {
+ _tmp3, err := dec.String()
+ if err != nil {
+ return err
+ }
+ _tmp0.String = _tmp3
+ // Slice:
+ if dec.MoreDataInList() {
+ var _tmp4 []uint64
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ for dec.MoreDataInList() {
+ _tmp5, err := dec.Uint64()
+ if err != nil {
+ return err
+ }
+ _tmp4 = append(_tmp4, _tmp5)
+ }
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ _tmp0.Slice = _tmp4
+ // Array:
+ if dec.MoreDataInList() {
+ var _tmp6 [3]byte
+ if err := dec.ReadBytes(_tmp6[:]); err != nil {
+ return err
+ }
+ _tmp0.Array = _tmp6
+ // NamedStruct:
+ if dec.MoreDataInList() {
+ var _tmp7 Aux
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // A:
+ _tmp8, err := dec.Uint64()
+ if err != nil {
+ return err
+ }
+ _tmp7.A = _tmp8
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ _tmp0.NamedStruct = _tmp7
+ // AnonStruct:
+ if dec.MoreDataInList() {
+ var _tmp9 struct{ A string }
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // A:
+ _tmp10, err := dec.String()
+ if err != nil {
+ return err
+ }
+ _tmp9.A = _tmp10
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ _tmp0.AnonStruct = _tmp9
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ *obj = _tmp0
+ return nil
+}
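
The chained _tmpN booleans implement the "optional" rule: a field may only be omitted when every later optional field is also zero, otherwise element positions would shift. The decoder mirrors this with nested MoreDataInList checks so shorter (older) encodings still parse. Observable behaviour, sketched under the assumption that the reflection encoder applies the same rule:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/rlp"
)

type Record struct {
	Number uint64
	Extra  uint64 `rlp:"optional"`
	More   uint64 `rlp:"optional"`
}

func main() {
	// Trailing zero-valued optional fields are dropped entirely.
	a, _ := rlp.EncodeToBytes(&Record{Number: 1})
	fmt.Printf("%x\n", a) // c101

	// A non-zero later field forces the earlier optional zero to be written.
	b, _ := rlp.EncodeToBytes(&Record{Number: 1, More: 7})
	fmt.Printf("%x\n", b) // c3018007

	// The short form still decodes; missing fields stay zero.
	var r Record
	if err := rlp.DecodeBytes(a, &r); err != nil {
		panic(err)
	}
	fmt.Println(r.Extra, r.More) // 0 0
}
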
diff --git a/rlp/rlpgen/testdata/rawvalue.in.txt b/rlp/rlpgen/testdata/rawvalue.in.txt
new file mode 100644
index 000000000..6c1784995
--- /dev/null
+++ b/rlp/rlpgen/testdata/rawvalue.in.txt
@@ -0,0 +1,11 @@
+// -*- mode: go -*-
+
+package test
+
+import "github.com/tomochain/tomochain/rlp"
+
+type Test struct {
+ RawValue rlp.RawValue
+ PointerToRawValue *rlp.RawValue
+ SliceOfRawValue []rlp.RawValue
+}
diff --git a/rlp/rlpgen/testdata/rawvalue.out.txt b/rlp/rlpgen/testdata/rawvalue.out.txt
new file mode 100644
index 000000000..4b6eb385d
--- /dev/null
+++ b/rlp/rlpgen/testdata/rawvalue.out.txt
@@ -0,0 +1,64 @@
+package test
+
+import "github.com/tomochain/tomochain/rlp"
+import "io"
+
+func (obj *Test) EncodeRLP(_w io.Writer) error {
+ w := rlp.NewEncoderBuffer(_w)
+ _tmp0 := w.List()
+ w.Write(obj.RawValue)
+ if obj.PointerToRawValue == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.Write((*obj.PointerToRawValue))
+ }
+ _tmp1 := w.List()
+ for _, _tmp2 := range obj.SliceOfRawValue {
+ w.Write(_tmp2)
+ }
+ w.ListEnd(_tmp1)
+ w.ListEnd(_tmp0)
+ return w.Flush()
+}
+
+func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
+ var _tmp0 Test
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // RawValue:
+ _tmp1, err := dec.Raw()
+ if err != nil {
+ return err
+ }
+ _tmp0.RawValue = _tmp1
+ // PointerToRawValue:
+ _tmp2, err := dec.Raw()
+ if err != nil {
+ return err
+ }
+ _tmp0.PointerToRawValue = &_tmp2
+ // SliceOfRawValue:
+ var _tmp3 []rlp.RawValue
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ for dec.MoreDataInList() {
+ _tmp4, err := dec.Raw()
+ if err != nil {
+ return err
+ }
+ _tmp3 = append(_tmp3, _tmp4)
+ }
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ _tmp0.SliceOfRawValue = _tmp3
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ *obj = _tmp0
+ return nil
+}
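
rlp.RawValue is a pass-through in both directions: the encoder splices the stored bytes in verbatim (w.Write, not a re-encode), and Stream.Raw returns the complete encoded item, header included. A sketch:

package main

import (
	"bytes"
	"fmt"

	"github.com/tomochain/tomochain/rlp"
)

func main() {
	// Pre-encode once, then embed the raw bytes in a larger item.
	inner, _ := rlp.EncodeToBytes("abc") // 83616263
	outer, _ := rlp.EncodeToBytes([]interface{}{rlp.RawValue(inner), uint(1)})
	fmt.Printf("%x\n", outer) // c58361626301

	// Raw() hands the embedded item back untouched.
	s := rlp.NewStream(bytes.NewReader(outer), 0)
	if _, err := s.List(); err != nil {
		panic(err)
	}
	raw, _ := s.Raw()
	fmt.Printf("%x\n", raw) // 83616263
}
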
diff --git a/rlp/rlpgen/testdata/uint256.in.txt b/rlp/rlpgen/testdata/uint256.in.txt
new file mode 100644
index 000000000..ed16e0a78
--- /dev/null
+++ b/rlp/rlpgen/testdata/uint256.in.txt
@@ -0,0 +1,10 @@
+// -*- mode: go -*-
+
+package test
+
+import "github.com/holiman/uint256"
+
+type Test struct {
+ Int *uint256.Int
+ IntNoPtr uint256.Int
+}
diff --git a/rlp/rlpgen/testdata/uint256.out.txt b/rlp/rlpgen/testdata/uint256.out.txt
new file mode 100644
index 000000000..5d99ca2e6
--- /dev/null
+++ b/rlp/rlpgen/testdata/uint256.out.txt
@@ -0,0 +1,44 @@
+package test
+
+import "github.com/holiman/uint256"
+import "github.com/tomochain/tomochain/rlp"
+import "io"
+
+func (obj *Test) EncodeRLP(_w io.Writer) error {
+ w := rlp.NewEncoderBuffer(_w)
+ _tmp0 := w.List()
+ if obj.Int == nil {
+ w.Write(rlp.EmptyString)
+ } else {
+ w.WriteUint256(obj.Int)
+ }
+ w.WriteUint256(&obj.IntNoPtr)
+ w.ListEnd(_tmp0)
+ return w.Flush()
+}
+
+func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
+ var _tmp0 Test
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // Int:
+ var _tmp1 uint256.Int
+ if err := dec.ReadUint256(&_tmp1); err != nil {
+ return err
+ }
+ _tmp0.Int = &_tmp1
+ // IntNoPtr:
+ var _tmp2 uint256.Int
+ if err := dec.ReadUint256(&_tmp2); err != nil {
+ return err
+ }
+ _tmp0.IntNoPtr = _tmp2
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ *obj = _tmp0
+ return nil
+}
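
uint256.Int gets dedicated fast paths (WriteUint256 on encode, Stream.ReadUint256 on decode), with a nil pointer falling back to rlp.EmptyString just like *big.Int. A round-trip sketch, assuming the uint256 support added by this change is also wired into the reflection codec:

package main

import (
	"fmt"

	"github.com/holiman/uint256"
	"github.com/tomochain/tomochain/rlp"
)

func main() {
	v := new(uint256.Int).SetUint64(1000000)
	enc, _ := rlp.EncodeToBytes(v)
	fmt.Printf("%x\n", enc) // 830f4240

	var dec uint256.Int
	if err := rlp.DecodeBytes(enc, &dec); err != nil {
		panic(err)
	}
	fmt.Println(dec.Uint64()) // 1000000
}
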
diff --git a/rlp/rlpgen/testdata/uints.in.txt b/rlp/rlpgen/testdata/uints.in.txt
new file mode 100644
index 000000000..8095da997
--- /dev/null
+++ b/rlp/rlpgen/testdata/uints.in.txt
@@ -0,0 +1,10 @@
+// -*- mode: go -*-
+
+package test
+
+type Test struct{
+ A uint8
+ B uint16
+ C uint32
+ D uint64
+}
diff --git a/rlp/rlpgen/testdata/uints.out.txt b/rlp/rlpgen/testdata/uints.out.txt
new file mode 100644
index 000000000..17896dd30
--- /dev/null
+++ b/rlp/rlpgen/testdata/uints.out.txt
@@ -0,0 +1,53 @@
+package test
+
+import "github.com/tomochain/tomochain/rlp"
+import "io"
+
+func (obj *Test) EncodeRLP(_w io.Writer) error {
+ w := rlp.NewEncoderBuffer(_w)
+ _tmp0 := w.List()
+ w.WriteUint64(uint64(obj.A))
+ w.WriteUint64(uint64(obj.B))
+ w.WriteUint64(uint64(obj.C))
+ w.WriteUint64(obj.D)
+ w.ListEnd(_tmp0)
+ return w.Flush()
+}
+
+func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
+ var _tmp0 Test
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // A:
+ _tmp1, err := dec.Uint8()
+ if err != nil {
+ return err
+ }
+ _tmp0.A = _tmp1
+ // B:
+ _tmp2, err := dec.Uint16()
+ if err != nil {
+ return err
+ }
+ _tmp0.B = _tmp2
+ // C:
+ _tmp3, err := dec.Uint32()
+ if err != nil {
+ return err
+ }
+ _tmp0.C = _tmp3
+ // D:
+ _tmp4, err := dec.Uint64()
+ if err != nil {
+ return err
+ }
+ _tmp0.D = _tmp4
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ *obj = _tmp0
+ return nil
+}
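
All unsigned widths funnel through a single WriteUint64 with a widening cast on encode; the width is enforced on decode instead, where Uint8/Uint16/Uint32 reject over-long inputs rather than truncating them. A sketch:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/rlp"
)

type Narrow struct{ A uint8 }

func main() {
	ok, _ := rlp.EncodeToBytes([]uint64{255})
	var n Narrow
	fmt.Println(rlp.DecodeBytes(ok, &n), n.A) // <nil> 255

	// 300 needs two bytes, which overflows the uint8 target field.
	over, _ := rlp.EncodeToBytes([]uint64{300})
	fmt.Println(rlp.DecodeBytes(over, &n) != nil) // true
}
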
diff --git a/rlp/rlpgen/types.go b/rlp/rlpgen/types.go
new file mode 100644
index 000000000..ea7dc96d8
--- /dev/null
+++ b/rlp/rlpgen/types.go
@@ -0,0 +1,124 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package main
+
+import (
+ "fmt"
+ "go/types"
+ "reflect"
+)
+
+// typeReflectKind gives the reflect.Kind that represents typ.
+func typeReflectKind(typ types.Type) reflect.Kind {
+ switch typ := typ.(type) {
+ case *types.Basic:
+ k := typ.Kind()
+ if k >= types.Bool && k <= types.Complex128 {
+ // value order matches for Bool..Complex128
+ return reflect.Bool + reflect.Kind(k-types.Bool)
+ }
+ if k == types.String {
+ return reflect.String
+ }
+ if k == types.UnsafePointer {
+ return reflect.UnsafePointer
+ }
+ panic(fmt.Errorf("unhandled BasicKind %v", k))
+ case *types.Array:
+ return reflect.Array
+ case *types.Chan:
+ return reflect.Chan
+ case *types.Interface:
+ return reflect.Interface
+ case *types.Map:
+ return reflect.Map
+ case *types.Pointer:
+ return reflect.Ptr
+ case *types.Signature:
+ return reflect.Func
+ case *types.Slice:
+ return reflect.Slice
+ case *types.Struct:
+ return reflect.Struct
+ default:
+ panic(fmt.Errorf("unhandled type %T", typ))
+ }
+}
+
+// nonZeroCheck returns the expression that checks whether 'v' is a non-zero value of type 'vtyp'.
+func nonZeroCheck(v string, vtyp types.Type, qualify types.Qualifier) string {
+ // Resolve type name.
+ typ := resolveUnderlying(vtyp)
+ switch typ := typ.(type) {
+ case *types.Basic:
+ k := typ.Kind()
+ switch {
+ case k == types.Bool:
+ return v
+ case k >= types.Uint && k <= types.Complex128:
+ return fmt.Sprintf("%s != 0", v)
+ case k == types.String:
+ return fmt.Sprintf(`%s != ""`, v)
+ default:
+ panic(fmt.Errorf("unhandled BasicKind %v", k))
+ }
+ case *types.Array, *types.Struct:
+ return fmt.Sprintf("%s != (%s{})", v, types.TypeString(vtyp, qualify))
+ case *types.Interface, *types.Pointer, *types.Signature:
+ return fmt.Sprintf("%s != nil", v)
+ case *types.Slice, *types.Map:
+ return fmt.Sprintf("len(%s) > 0", v)
+ default:
+ panic(fmt.Errorf("unhandled type %T", typ))
+ }
+}
+
+// isBigInt checks whether 'typ' is "math/big".Int.
+func isBigInt(typ types.Type) bool {
+ named, ok := typ.(*types.Named)
+ if !ok {
+ return false
+ }
+ name := named.Obj()
+ return name.Pkg().Path() == "math/big" && name.Name() == "Int"
+}
+
+// isUint256 checks whether 'typ' is "github.com/holiman/uint256".Int.
+func isUint256(typ types.Type) bool {
+ named, ok := typ.(*types.Named)
+ if !ok {
+ return false
+ }
+ name := named.Obj()
+ return name.Pkg().Path() == "github.com/holiman/uint256" && name.Name() == "Int"
+}
+
+// isByte checks whether the underlying type of 'typ' is uint8.
+func isByte(typ types.Type) bool {
+ basic, ok := resolveUnderlying(typ).(*types.Basic)
+ return ok && basic.Kind() == types.Uint8
+}
+
+func resolveUnderlying(typ types.Type) types.Type {
+ for {
+ t := typ.Underlying()
+ if t == typ {
+ return t
+ }
+ typ = t
+ }
+}
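
typeReflectKind leans on go/types.BasicKind and reflect.Kind sharing their ordering over the Bool..Complex128 range, which is why a single offset addition suffices; everything outside that run is handled case by case. The alignment can be checked directly (standalone sketch, not part of the generator):

package main

import (
	"fmt"
	"go/types"
	"reflect"
)

func main() {
	// Same offset arithmetic as typeReflectKind's Basic branch.
	for _, k := range []types.BasicKind{types.Bool, types.Uint32, types.Float64} {
		fmt.Println(reflect.Bool + reflect.Kind(k-types.Bool))
	}
	// Prints: bool, uint32, float64 on separate lines.
}
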
diff --git a/rlp/safe.go b/rlp/safe.go
new file mode 100644
index 000000000..3c910337b
--- /dev/null
+++ b/rlp/safe.go
@@ -0,0 +1,27 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+//go:build nacl || js || !cgo
+// +build nacl js !cgo
+
+package rlp
+
+import "reflect"
+
+// byteArrayBytes returns a slice of the byte array v.
+func byteArrayBytes(v reflect.Value, length int) []byte {
+ return v.Slice(0, length).Bytes()
+}
diff --git a/rlp/typecache.go b/rlp/typecache.go
index 3df799e1e..c3244050b 100644
--- a/rlp/typecache.go
+++ b/rlp/typecache.go
@@ -19,138 +19,222 @@ package rlp
import (
"fmt"
"reflect"
- "strings"
"sync"
-)
+ "sync/atomic"
-var (
- typeCacheMutex sync.RWMutex
- typeCache = make(map[typekey]*typeinfo)
+ "github.com/tomochain/tomochain/rlp/internal/rlpstruct"
)
+// typeinfo is an entry in the type cache.
type typeinfo struct {
- decoder
- writer
-}
-
-// represents struct tags
-type tags struct {
- // rlp:"nil" controls whether empty input results in a nil pointer.
- nilOK bool
- // rlp:"tail" controls whether this field swallows additional list
- // elements. It can only be set for the last field, which must be
- // of slice type.
- tail bool
- // rlp:"-" ignores fields.
- ignored bool
+ decoder decoder
+ decoderErr error // error from makeDecoder
+ writer writer
+ writerErr error // error from makeWriter
}
+// typekey is the key of a type in typeCache. It includes the struct tags because
+// they might generate a different decoder.
type typekey struct {
reflect.Type
- // the key must include the struct tags because they
- // might generate a different decoder.
- tags
+ rlpstruct.Tags
}
type decoder func(*Stream, reflect.Value) error
-type writer func(reflect.Value, *encbuf) error
+type writer func(reflect.Value, *encBuffer) error
+
+var theTC = newTypeCache()
+
+type typeCache struct {
+ cur atomic.Value
+
+ // This lock synchronizes writers.
+ mu sync.Mutex
+ next map[typekey]*typeinfo
+}
+
+func newTypeCache() *typeCache {
+ c := new(typeCache)
+ c.cur.Store(make(map[typekey]*typeinfo))
+ return c
+}
+
+func cachedDecoder(typ reflect.Type) (decoder, error) {
+ info := theTC.info(typ)
+ return info.decoder, info.decoderErr
+}
+
+func cachedWriter(typ reflect.Type) (writer, error) {
+ info := theTC.info(typ)
+ return info.writer, info.writerErr
+}
+
+func (c *typeCache) info(typ reflect.Type) *typeinfo {
+ key := typekey{Type: typ}
+ if info := c.cur.Load().(map[typekey]*typeinfo)[key]; info != nil {
+ return info
+ }
+
+ // Not in the cache, need to generate info for this type.
+ return c.generate(typ, rlpstruct.Tags{})
+}
+
+func (c *typeCache) generate(typ reflect.Type, tags rlpstruct.Tags) *typeinfo {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ cur := c.cur.Load().(map[typekey]*typeinfo)
+ if info := cur[typekey{typ, tags}]; info != nil {
+ return info
+ }
-func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) {
- typeCacheMutex.RLock()
- info := typeCache[typekey{typ, tags}]
- typeCacheMutex.RUnlock()
- if info != nil {
- return info, nil
+ // Copy cur to next.
+ c.next = make(map[typekey]*typeinfo, len(cur)+1)
+ for k, v := range cur {
+ c.next[k] = v
}
- // not in the cache, need to generate info for this type.
- typeCacheMutex.Lock()
- defer typeCacheMutex.Unlock()
- return cachedTypeInfo1(typ, tags)
+
+ // Generate.
+ info := c.infoWhileGenerating(typ, tags)
+
+ // next -> cur
+ c.cur.Store(c.next)
+ c.next = nil
+ return info
}
-func cachedTypeInfo1(typ reflect.Type, tags tags) (*typeinfo, error) {
+func (c *typeCache) infoWhileGenerating(typ reflect.Type, tags rlpstruct.Tags) *typeinfo {
key := typekey{typ, tags}
- info := typeCache[key]
- if info != nil {
- // another goroutine got the write lock first
- return info, nil
+ if info := c.next[key]; info != nil {
+ return info
}
- // put a dummmy value into the cache before generating.
- // if the generator tries to lookup itself, it will get
+ // Put a dummy value into the cache before generating.
+ // If the generator tries to look up itself, it will get
// the dummy value and won't call itself recursively.
- typeCache[key] = new(typeinfo)
- info, err := genTypeInfo(typ, tags)
- if err != nil {
- // remove the dummy value if the generator fails
- delete(typeCache, key)
- return nil, err
- }
- *typeCache[key] = *info
- return typeCache[key], err
+ info := new(typeinfo)
+ c.next[key] = info
+ info.generate(typ, tags)
+ return info
}
type field struct {
- index int
- info *typeinfo
+ index int
+ info *typeinfo
+ optional bool
}
+// structFields resolves the typeinfo of all public fields in a struct type.
func structFields(typ reflect.Type) (fields []field, err error) {
+ // Convert fields to rlpstruct.Field.
+ var allStructFields []rlpstruct.Field
for i := 0; i < typ.NumField(); i++ {
- if f := typ.Field(i); f.PkgPath == "" { // exported
- tags, err := parseStructTag(typ, i)
- if err != nil {
- return nil, err
- }
- if tags.ignored {
- continue
- }
- info, err := cachedTypeInfo1(f.Type, tags)
- if err != nil {
- return nil, err
- }
- fields = append(fields, field{i, info})
+ rf := typ.Field(i)
+ allStructFields = append(allStructFields, rlpstruct.Field{
+ Name: rf.Name,
+ Index: i,
+ Exported: rf.PkgPath == "",
+ Tag: string(rf.Tag),
+ Type: *rtypeToStructType(rf.Type, nil),
+ })
+ }
+
+ // Filter/validate fields.
+ structFields, structTags, err := rlpstruct.ProcessFields(allStructFields)
+ if err != nil {
+ if tagErr, ok := err.(rlpstruct.TagError); ok {
+ tagErr.StructType = typ.String()
+ return nil, tagErr
}
+ return nil, err
+ }
+
+ // Resolve typeinfo.
+ for i, sf := range structFields {
+ typ := typ.Field(sf.Index).Type
+ tags := structTags[i]
+ info := theTC.infoWhileGenerating(typ, tags)
+ fields = append(fields, field{sf.Index, info, tags.Optional})
}
return fields, nil
}
-func parseStructTag(typ reflect.Type, fi int) (tags, error) {
- f := typ.Field(fi)
- var ts tags
- for _, t := range strings.Split(f.Tag.Get("rlp"), ",") {
- switch t = strings.TrimSpace(t); t {
- case "":
- case "-":
- ts.ignored = true
- case "nil":
- ts.nilOK = true
- case "tail":
- ts.tail = true
- if fi != typ.NumField()-1 {
- return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (must be on last field)`, typ, f.Name)
- }
- if f.Type.Kind() != reflect.Slice {
- return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (field type is not slice)`, typ, f.Name)
- }
- default:
- return ts, fmt.Errorf("rlp: unknown struct tag %q on %v.%s", t, typ, f.Name)
+// firstOptionalField returns the index of the first field with "optional" tag.
+func firstOptionalField(fields []field) int {
+ for i, f := range fields {
+ if f.optional {
+ return i
}
}
- return ts, nil
+ return len(fields)
}
-func genTypeInfo(typ reflect.Type, tags tags) (info *typeinfo, err error) {
- info = new(typeinfo)
- if info.decoder, err = makeDecoder(typ, tags); err != nil {
- return nil, err
+type structFieldError struct {
+ typ reflect.Type
+ field int
+ err error
+}
+
+func (e structFieldError) Error() string {
+ return fmt.Sprintf("%v (struct field %v.%s)", e.err, e.typ, e.typ.Field(e.field).Name)
+}
+
+func (i *typeinfo) generate(typ reflect.Type, tags rlpstruct.Tags) {
+ i.decoder, i.decoderErr = makeDecoder(typ, tags)
+ i.writer, i.writerErr = makeWriter(typ, tags)
+}
+
+// rtypeToStructType converts typ to rlpstruct.Type.
+func rtypeToStructType(typ reflect.Type, rec map[reflect.Type]*rlpstruct.Type) *rlpstruct.Type {
+ k := typ.Kind()
+ if k == reflect.Invalid {
+ panic("invalid kind")
}
- if info.writer, err = makeWriter(typ, tags); err != nil {
- return nil, err
+
+ if prev := rec[typ]; prev != nil {
+ return prev // short-circuit for recursive types
+ }
+ if rec == nil {
+ rec = make(map[reflect.Type]*rlpstruct.Type)
+ }
+
+ t := &rlpstruct.Type{
+ Name: typ.String(),
+ Kind: k,
+ IsEncoder: typ.Implements(encoderInterface),
+ IsDecoder: typ.Implements(decoderInterface),
+ }
+ rec[typ] = t
+ if k == reflect.Array || k == reflect.Slice || k == reflect.Ptr {
+ t.Elem = rtypeToStructType(typ.Elem(), rec)
+ }
+ return t
+}
+
+// typeNilKind gives the RLP value kind for nil pointers to 'typ'.
+func typeNilKind(typ reflect.Type, tags rlpstruct.Tags) Kind {
+ styp := rtypeToStructType(typ, nil)
+
+ var nk rlpstruct.NilKind
+ if tags.NilOK {
+ nk = tags.NilKind
+ } else {
+ nk = styp.DefaultNilValue()
+ }
+ switch nk {
+ case rlpstruct.NilKindString:
+ return String
+ case rlpstruct.NilKindList:
+ return List
+ default:
+ panic("invalid nil kind value")
}
- return info, nil
}
func isUint(k reflect.Kind) bool {
return k >= reflect.Uint && k <= reflect.Uintptr
}
+
+func isByte(typ reflect.Type) bool {
+ return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface)
+}
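
The rewritten type cache trades write cost for lock-free reads: a lookup is one atomic.Value load plus a map access, while the rare writer clones the published map under a mutex and atomically installs the successor. The pattern in isolation, as a standard-library-only sketch:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type cowMap struct {
	cur atomic.Value // holds map[string]int; read without locking
	mu  sync.Mutex   // serializes writers only
}

func newCowMap() *cowMap {
	m := new(cowMap)
	m.cur.Store(make(map[string]int))
	return m
}

func (m *cowMap) get(k string) (int, bool) {
	v, ok := m.cur.Load().(map[string]int)[k]
	return v, ok
}

func (m *cowMap) put(k string, v int) {
	m.mu.Lock()
	defer m.mu.Unlock()
	cur := m.cur.Load().(map[string]int)
	next := make(map[string]int, len(cur)+1)
	for key, val := range cur { // copy: the published map is never mutated
		next[key] = val
	}
	next[k] = v
	m.cur.Store(next)
}

func main() {
	m := newCowMap()
	m.put("decoder", 1)
	fmt.Println(m.get("decoder")) // 1 true
}
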
diff --git a/rlp/unsafe.go b/rlp/unsafe.go
new file mode 100644
index 000000000..2152ba35f
--- /dev/null
+++ b/rlp/unsafe.go
@@ -0,0 +1,35 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+//go:build !nacl && !js && cgo
+// +build !nacl,!js,cgo
+
+package rlp
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// byteArrayBytes returns a slice of the byte array v.
+func byteArrayBytes(v reflect.Value, length int) []byte {
+ var s []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s))
+ hdr.Data = v.UnsafeAddr()
+ hdr.Cap = length
+ hdr.Len = length
+ return s
+}
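
The two byteArrayBytes implementations are behaviourally identical; the cgo/native variant just avoids an allocation by aliasing the array's storage through a reflect.SliceHeader. Writes through the returned slice are visible in the original array, which is exactly what the encoder relies on. A demonstration of the aliasing (hypothetical driver; real callers pass addressable array values from the reflection walker):

package main

import (
	"fmt"
	"reflect"
	"unsafe"
)

func byteArrayBytes(v reflect.Value, length int) []byte {
	var s []byte
	hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s))
	hdr.Data = v.UnsafeAddr()
	hdr.Cap = length
	hdr.Len = length
	return s
}

func main() {
	arr := [4]byte{1, 2, 3, 4}
	b := byteArrayBytes(reflect.ValueOf(&arr).Elem(), len(arr))
	b[0] = 9         // no copy was made...
	fmt.Println(arr) // ...so this prints [9 2 3 4]
}
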
diff --git a/rpc/json.go b/rpc/json.go
index 715f33ee1..e35a74118 100644
--- a/rpc/json.go
+++ b/rpc/json.go
@@ -96,6 +96,10 @@ func (err *jsonError) ErrorCode() int {
return err.Code
}
+func (err *jsonError) ErrorData() interface{} {
+ return err.Data
+}
+
// NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based
// on explicitly given encoding and decoding methods.
func NewCodec(rwc io.ReadWriteCloser, encode, decode func(v interface{}) error) ServerCodec {
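
With ErrorData added, jsonError exposes the JSON-RPC error's data field alongside its message and code, which lets callers pull structured payloads such as revert reasons out of error responses. A consumer-side sketch; the DataError interface here is illustrative, not a type exported by this package:

package main

import "fmt"

// DataError is the shape a caller asserts against; jsonError now satisfies
// the ErrorData part via the method added above.
type DataError interface {
	Error() string          // the message field
	ErrorData() interface{} // the data field, e.g. a revert reason
}

func describe(err error) {
	if de, ok := err.(DataError); ok {
		fmt.Printf("rpc error %q with data %v\n", de.Error(), de.ErrorData())
		return
	}
	fmt.Println("plain error:", err)
}

func main() {
	describe(fmt.Errorf("no structured payload"))
}
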
diff --git a/swarm/network/hive.go b/swarm/network/hive.go
index 413074c47..0b8824c90 100644
--- a/swarm/network/hive.go
+++ b/swarm/network/hive.go
@@ -25,7 +25,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/metrics"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/netutil"
"github.com/tomochain/tomochain/swarm/network/kademlia"
"github.com/tomochain/tomochain/swarm/storage"
@@ -49,7 +49,7 @@ var (
type Hive struct {
listenAddr func() string
callInterval uint64
- id discover.NodeID
+ id enode.ID
addr kademlia.Address
kad *kademlia.Kademlia
path string
@@ -77,7 +77,7 @@ type HiveParams struct {
*kademlia.KadParams
}
-//create default params
+// create default params
func NewDefaultHiveParams() *HiveParams {
kad := kademlia.NewDefaultKadParams()
// kad.BucketSize = bucketSize
@@ -90,8 +90,8 @@ func NewDefaultHiveParams() *HiveParams {
}
}
-//this can only finally be set after all config options (file, cmd line, env vars)
-//have been evaluated
+// this can only finally be set after all config options (file, cmd line, env vars)
+// have been evaluated
func (self *HiveParams) Init(path string) {
self.KadDbPath = filepath.Join(path, "bzz-peers.json")
}
@@ -133,7 +133,7 @@ func (self *Hive) Addr() kademlia.Address {
// listedAddr is a function to retrieve listening address to advertise to peers
// connectPeer is a function to connect to a peer based on its NodeID or enode URL
// there are called on the p2p.Server which runs on the node
-func (self *Hive) Start(id discover.NodeID, listenAddr func() string, connectPeer func(string) error) (err error) {
+func (self *Hive) Start(id enode.ID, listenAddr func() string, connectPeer func(string) error) (err error) {
self.toggle = make(chan bool)
self.more = make(chan bool)
self.quit = make(chan bool)
diff --git a/swarm/network/messages.go b/swarm/network/messages.go
index 18ab63353..b434c2ff1 100644
--- a/swarm/network/messages.go
+++ b/swarm/network/messages.go
@@ -21,8 +21,9 @@ import (
"net"
"time"
+ "github.com/tomochain/tomochain/p2p/enode"
+
"github.com/tomochain/tomochain/contracts/chequebook"
- "github.com/tomochain/tomochain/p2p/discover"
"github.com/tomochain/tomochain/swarm/network/kademlia"
"github.com/tomochain/tomochain/swarm/services/swap"
"github.com/tomochain/tomochain/swarm/storage"
@@ -45,7 +46,7 @@ const (
)
/*
- Handshake
+ Handshake
* Version: 8 byte integer version of the protocol
* ID: arbitrary byte sequence client identifier human readable
@@ -54,7 +55,6 @@ const (
* NetworkID: 8 byte integer network identifier
* Caps: swarm-specific capabilities, format identical to devp2p
* SyncState: syncronisation state (db iterator key and address space etc) persisted about the peer
-
*/
type statusMsgData struct {
Version uint64
@@ -69,12 +69,12 @@ func (self *statusMsgData) String() string {
}
/*
- store requests are forwarded to the peers in their kademlia proximity bin
- if they are distant
- if they are within our storage radius or have any incentive to store it
- then attach your nodeID to the metadata
- if the storage request is sufficiently close (within our proxLimit, i. e., the
- last row of the routing table)
+store requests are forwarded to the peers in their kademlia proximity bin
+if they are distant
+if they are within our storage radius or have any incentive to store it
+then attach your nodeID to the metadata
+if the storage request is sufficiently close (within our proxLimit, i. e., the
+last row of the routing table)
*/
type storeRequestMsgData struct {
Key storage.Key // hash of datasize | data
@@ -181,9 +181,9 @@ type peerAddr struct {
// peerAddr pretty prints as enode
func (self *peerAddr) String() string {
- var nodeid discover.NodeID
+ var nodeid enode.ID
copy(nodeid[:], self.ID)
- return discover.NewNode(nodeid, self.IP, 0, self.Port).String()
+ return nodeid.GoString()
}
/*
diff --git a/swarm/swarm.go b/swarm/swarm.go
index 34a790eca..e970fc55c 100644
--- a/swarm/swarm.go
+++ b/swarm/swarm.go
@@ -37,7 +37,7 @@ import (
"github.com/tomochain/tomochain/metrics"
"github.com/tomochain/tomochain/node"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/params"
"github.com/tomochain/tomochain/rpc"
"github.com/tomochain/tomochain/swarm/api"
@@ -275,7 +275,7 @@ Start is called when the stack is started
func (self *Swarm) Start(srv *p2p.Server) error {
startTime = time.Now()
connectPeer := func(url string) error {
- node, err := discover.ParseNode(url)
+ node, err := enode.ParseV4(url)
if err != nil {
return fmt.Errorf("invalid node URL: %v", err)
}
@@ -296,7 +296,7 @@ func (self *Swarm) Start(srv *p2p.Server) error {
log.Warn(fmt.Sprintf("Starting Swarm service"))
self.hive.Start(
- discover.PubkeyID(&srv.PrivateKey.PublicKey),
+ enode.PubkeyToIDV4(&srv.PrivateKey.PublicKey),
func() string { return srv.ListenAddr },
connectPeer,
)
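
These swarm changes track the p2p identity rework: discover.NodeID (the raw 64-byte public key) gives way to enode.ID (a 32-byte keccak hash of it), node URLs are parsed with enode.ParseV4, and IDs are derived with enode.PubkeyToIDV4. A sketch of the new entry points:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/crypto"
	"github.com/tomochain/tomochain/p2p/enode"
)

func main() {
	key, err := crypto.GenerateKey()
	if err != nil {
		panic(err)
	}
	id := enode.PubkeyToIDV4(&key.PublicKey)
	fmt.Println(id) // 32-byte identifier, printed as hex

	// enode URLs still carry the full 64-byte pubkey in the host part.
	pub := crypto.FromECDSAPub(&key.PublicKey)[1:] // strip the 0x04 prefix
	n, err := enode.ParseV4(fmt.Sprintf("enode://%x@127.0.0.1:30303", pub))
	if err != nil {
		panic(err)
	}
	fmt.Println(n.ID() == id) // true: same key, same derived ID
}
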
diff --git a/tests/state_test.go b/tests/state_test.go
index 7c8c5e926..a6d23edac 100644
--- a/tests/state_test.go
+++ b/tests/state_test.go
@@ -26,6 +26,9 @@ import (
)
func TestState(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping testing in short mode")
+ }
t.Parallel()
st := new(testMatcher)
@@ -50,13 +53,17 @@ func TestState(t *testing.T) {
subtest := subtest
key := fmt.Sprintf("%s/%d", subtest.Fork, subtest.Index)
name := name + "/" + key
- t.Run(key, func(t *testing.T) {
- if subtest.Fork == "Constantinople" {
- t.Skip("constantinople not supported yet")
- }
+
+ t.Run(key+"/trie", func(t *testing.T) {
+ withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error {
+ _, err := test.Run(subtest, vmconfig, false)
+ return st.checkFailure(t, name+"/trie", err)
+ })
+ })
+ t.Run(key+"/snap", func(t *testing.T) {
withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error {
- _, err := test.Run(subtest, vmconfig)
- return st.checkFailure(t, name, err)
+ _, err := test.Run(subtest, vmconfig, true)
+ return st.checkFailure(t, name+"/snap", err)
})
})
}
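
Running every state test twice, under "/trie" and "/snap" subtest names, makes failures attributable to the snapshot layer at a glance. The shape of the pattern, reduced to a sketch with a hypothetical runOnce standing in for test.Run:

package tests

import "testing"

func runOnce(snapshotter bool) error { return nil } // stand-in for test.Run

func TestBothPaths(t *testing.T) {
	t.Run("trie", func(t *testing.T) {
		if err := runOnce(false); err != nil {
			t.Fatal(err)
		}
	})
	t.Run("snap", func(t *testing.T) {
		if err := runOnce(true); err != nil {
			t.Fatal(err)
		}
	})
}
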
diff --git a/tests/state_test_util.go b/tests/state_test_util.go
index e532aa8a4..8e99c9b76 100644
--- a/tests/state_test_util.go
+++ b/tests/state_test_util.go
@@ -19,8 +19,8 @@ package tests
import (
"encoding/hex"
"encoding/json"
+ "errors"
"fmt"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
"strings"
@@ -28,8 +28,9 @@ import (
"github.com/tomochain/tomochain/common/hexutil"
"github.com/tomochain/tomochain/common/math"
"github.com/tomochain/tomochain/core"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
- "github.com/tomochain/tomochain/core/types"
+ "github.com/tomochain/tomochain/core/state/snapshot"
"github.com/tomochain/tomochain/core/vm"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/crypto/sha3"
@@ -121,14 +122,14 @@ func (t *StateTest) Subtests() []StateSubtest {
}
// Run executes a specific subtest.
-func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateDB, error) {
+func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config, snapshotter bool) (*state.StateDB, error) {
config, ok := Forks[subtest.Fork]
if !ok {
return nil, UnsupportedForkError{subtest.Fork}
}
block := t.genesis(config).ToBlock(nil)
db := rawdb.NewMemoryDatabase()
- statedb := MakePreState(db, t.json.Pre)
+ statedb := MakePreState(db, t.json.Pre, snapshotter)
post := t.json.Post[subtest.Fork][subtest.Index]
msg, err := t.json.Tx.toMessage(post)
@@ -144,7 +145,7 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD
snapshot := statedb.Snapshot()
coinbase := &t.json.Env.Coinbase
- if _, _, _, err := core.ApplyMessage(evm, msg, gaspool, *coinbase); err != nil {
+ if _, err := core.ApplyMessage(evm, msg, gaspool, *coinbase); err != nil {
statedb.RevertToSnapshot(snapshot)
}
if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) {
@@ -161,9 +162,9 @@ func (t *StateTest) gasLimit(subtest StateSubtest) uint64 {
return t.json.Tx.GasLimit[t.json.Post[subtest.Fork][subtest.Index].Indexes.Gas]
}
-func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB {
+func MakePreState(db ethdb.Database, accounts core.GenesisAlloc, snapshotter bool) *state.StateDB {
sdb := state.NewDatabase(db)
- statedb, _ := state.New(common.Hash{}, sdb)
+ statedb, _ := state.New(common.Hash{}, sdb, nil)
for addr, a := range accounts {
statedb.SetCode(addr, a.Code)
statedb.SetNonce(addr, a.Nonce)
@@ -174,7 +175,12 @@ func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB
}
// Commit and re-open to start with a clean state.
root, _ := statedb.Commit(false)
- statedb, _ = state.New(root, sdb)
+
+ var snaps *snapshot.Tree
+ if snapshotter {
+ snaps = snapshot.New(db, sdb.TrieDB(), 1, root, false)
+ }
+ statedb, _ = state.New(root, sdb, snaps)
return statedb
}
@@ -190,7 +196,7 @@ func (t *StateTest) genesis(config *params.ChainConfig) *core.Genesis {
}
}
-func (tx *stTransaction) toMessage(ps stPostState) (core.Message, error) {
+func (tx *stTransaction) toMessage(ps stPostState) (*core.Message, error) {
// Derive sender from private key if present.
var from common.Address
if len(tx.PrivateKey) > 0 {
@@ -235,7 +241,21 @@ func (tx *stTransaction) toMessage(ps stPostState) (core.Message, error) {
if err != nil {
return nil, fmt.Errorf("invalid tx data %q", dataHex)
}
- msg := types.NewMessage(from, to, tx.Nonce, value, gasLimit, tx.GasPrice, data, true, nil)
+ // An explicit gas price is required; there is no baseFee here to derive an effective price from.
+ gasPrice := tx.GasPrice
+ if gasPrice == nil {
+ return nil, errors.New("no gas price provided")
+ }
+
+ msg := &core.Message{
+ From: from,
+ To: to,
+ Nonce: tx.Nonce,
+ Value: value,
+ GasLimit: gasLimit,
+ GasPrice: gasPrice,
+ Data: data,
+ }
return msg, nil
}
diff --git a/tests/vm_test.go b/tests/vm_test.go
index 9e1f73543..234d73620 100644
--- a/tests/vm_test.go
+++ b/tests/vm_test.go
@@ -17,15 +17,15 @@
package tests
import (
- "github.com/tomochain/tomochain/common"
"math/big"
"testing"
+ "github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core/vm"
)
func TestVM(t *testing.T) {
- common.TIPTomoXCancellationFee=big.NewInt(100000000)
+ common.TIPTomoXCancellationFee = big.NewInt(100000000)
t.Parallel()
vmt := new(testMatcher)
vmt.fails("^vmSystemOperationsTest.json/createNameRegistrator$", "fails without parallel execution")
@@ -37,7 +37,10 @@ func TestVM(t *testing.T) {
vmt.walk(t, vmTestDir, func(t *testing.T, name string, test *VMTest) {
withTrace(t, test.json.Exec.GasLimit, func(vmconfig vm.Config) error {
- return vmt.checkFailure(t, name, test.Run(vmconfig))
+ return vmt.checkFailure(t, name+"/trie", test.Run(vmconfig, false))
+ })
+ withTrace(t, test.json.Exec.GasLimit, func(vmconfig vm.Config) error {
+ return vmt.checkFailure(t, name+"/snap", test.Run(vmconfig, true))
})
})
}
diff --git a/tests/vm_test_util.go b/tests/vm_test_util.go
index 01c471af2..c2a56d779 100644
--- a/tests/vm_test_util.go
+++ b/tests/vm_test_util.go
@@ -20,9 +20,10 @@ import (
"bytes"
"encoding/json"
"fmt"
- "github.com/tomochain/tomochain/core/rawdb"
"math/big"
+ "github.com/tomochain/tomochain/core/rawdb"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/common/hexutil"
"github.com/tomochain/tomochain/common/math"
@@ -78,9 +79,9 @@ type vmExecMarshaling struct {
GasPrice *math.HexOrDecimal256
}
-func (t *VMTest) Run(vmconfig vm.Config) error {
+func (t *VMTest) Run(vmconfig vm.Config, snapshotter bool) error {
db := rawdb.NewMemoryDatabase()
- statedb := MakePreState(db, t.json.Pre)
+ statedb := MakePreState(db, t.json.Pre, snapshotter)
ret, gasRemaining, err := t.exec(statedb, vmconfig)
if t.json.GasRemaining == nil {
diff --git a/tomox/token.go b/tomox/token.go
index 24e6e138f..a90a3dce9 100644
--- a/tomox/token.go
+++ b/tomox/token.go
@@ -37,7 +37,7 @@ func RunContract(chain consensus.ChainContext, statedb *state.StateDB, contractA
return nil, err
}
var unpackResult interface{}
- err = abi.Unpack(&unpackResult, method, result)
+ err = abi.UnpackIntoInterface(&unpackResult, method, result)
if err != nil {
return nil, err
}
@@ -75,4 +75,4 @@ func (tomox *TomoX) GetTokenDecimal(chain consensus.ChainContext, statedb *state
// FIXME: using in unit tests only
func (tomox *TomoX) SetTokenDecimal(token common.Address, decimal *big.Int) {
tomox.tokenDecimalCache.Add(token, decimal)
-}
\ No newline at end of file
+}
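
This call site follows the ABI split described at the top of this change: Unpack now returns the decoded values, while UnpackIntoInterface keeps the old copy-into-pointer behaviour for loosely shaped targets such as the bare interface{} used by RunContract. Both forms, sketched against a hypothetical decimals() method:

package main

import (
	"fmt"
	"strings"

	"github.com/tomochain/tomochain/accounts/abi"
)

const def = `[{"type":"function","name":"decimals","constant":true,"inputs":[],"outputs":[{"name":"","type":"uint8"}]}]`

func main() {
	parsed, err := abi.JSON(strings.NewReader(def))
	if err != nil {
		panic(err)
	}
	ret := make([]byte, 32)
	ret[31] = 18 // ABI-encoded uint8(18)

	// New style: values come back as a slice.
	vals, _ := parsed.Unpack("decimals", ret)
	fmt.Println(vals[0].(uint8)) // 18

	// Copying style, as RunContract uses above.
	var out interface{}
	if err := parsed.UnpackIntoInterface(&out, "decimals", ret); err != nil {
		panic(err)
	}
	fmt.Println(out) // 18
}
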
diff --git a/tomox/tomox.go b/tomox/tomox.go
index ae5b960e6..3f19e02f3 100644
--- a/tomox/tomox.go
+++ b/tomox/tomox.go
@@ -573,6 +573,11 @@ func (tomox *TomoX) GetTradingState(block *types.Block, author common.Address) (
return tradingstate.New(root, tomox.StateCache)
}
+
+func (tomox *TomoX) GetEmptyTradingState() (*tradingstate.TradingStateDB, error) {
+ return tradingstate.New(tradingstate.EmptyRoot, tomox.StateCache)
+}
+
func (tomox *TomoX) GetStateCache() tradingstate.Database {
return tomox.StateCache
}
diff --git a/tomox/tradingstate/database.go b/tomox/tradingstate/database.go
index 56acf61ec..e77b6be1a 100644
--- a/tomox/tradingstate/database.go
+++ b/tomox/tradingstate/database.go
@@ -81,7 +81,7 @@ type Trie interface {
func NewDatabase(db ethdb.Database) Database {
csc, _ := lru.New(codeSizeCacheSize)
return &cachingDB{
- db: trie.NewDatabase(db),
+ db: trie.NewDatabaseWithConfig(db, &trie.Config{Preimages: true}),
codeSizeCache: csc,
}
}
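
Trie databases no longer record preimages unless asked, so the TomoX state layers opt back in via trie.Config; without Preimages, GetKey-style reverse lookups on hashed keys would come back empty. The two constructor forms, as a sketch:

package main

import (
	"github.com/tomochain/tomochain/core/rawdb"
	"github.com/tomochain/tomochain/trie"
)

func main() {
	diskdb := rawdb.NewMemoryDatabase()

	// Opt in to preimage recording (and optionally a clean cache, in MB).
	_ = trie.NewDatabaseWithConfig(diskdb, &trie.Config{Preimages: true})

	// Equivalent to a nil config: no preimages, no clean cache.
	_ = trie.NewDatabase(diskdb)
}
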
diff --git a/tomox/tradingstate/tomox_trie.go b/tomox/tradingstate/tomox_trie.go
index 908648def..197e50b4c 100644
--- a/tomox/tradingstate/tomox_trie.go
+++ b/tomox/tradingstate/tomox_trie.go
@@ -18,11 +18,11 @@ package tradingstate
import (
"fmt"
- "github.com/tomochain/tomochain/ethdb"
- "github.com/tomochain/tomochain/trie"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/log"
+ "github.com/tomochain/tomochain/trie"
)
// TomoXTrie wraps a trie with key hashing. In a secure trie, all
@@ -78,10 +78,10 @@ func (t *TomoXTrie) Get(key []byte) []byte {
// The value bytes must not be modified by the caller.
// If a node was not found in the database, a MissingNodeError is returned.
func (t *TomoXTrie) TryGet(key []byte) ([]byte, error) {
- return t.trie.TryGet(key)
+ return t.trie.Get(key)
}
-// TryGetBestLeftKey returns the value of max left leaf
+// TryGetBestLeftKeyAndValue returns the key and value of the leftmost leaf
// If a node was not found in the database, a MissingNodeError is returned.
func (t *TomoXTrie) TryGetBestLeftKeyAndValue() ([]byte, []byte, error) {
return t.trie.TryGetBestLeftKeyAndValue()
@@ -91,7 +91,7 @@ func (t *TomoXTrie) TryGetAllLeftKeyAndValue(limit []byte) ([][]byte, [][]byte,
return t.trie.TryGetAllLeftKeyAndValue(limit)
}
-// TryGetBestRightKey returns the value of max left leaf
+// TryGetBestRightKeyAndValue returns the key and value of the rightmost leaf
// If a node was not found in the database, a MissingNodeError is returned.
func (t *TomoXTrie) TryGetBestRightKeyAndValue() ([]byte, []byte, error) {
return t.trie.TryGetBestRightKeyAndValue()
@@ -118,7 +118,7 @@ func (t *TomoXTrie) Update(key, value []byte) {
//
// If a node was not found in the database, a MissingNodeError is returned.
func (t *TomoXTrie) TryUpdate(key, value []byte) error {
- err := t.trie.TryUpdate(key, value)
+ err := t.trie.Update(key, value)
if err != nil {
return err
}
@@ -137,7 +137,7 @@ func (t *TomoXTrie) Delete(key []byte) {
// If a node was not found in the database, a MissingNodeError is returned.
func (t *TomoXTrie) TryDelete(key []byte) error {
delete(t.getSecKeyCache(), string(key))
- return t.trie.TryDelete(key)
+ return t.trie.Delete(key)
}
// GetKey returns the sha3 preimage of a hashed key that was
diff --git a/tomoxlending/lendingstate/database.go b/tomoxlending/lendingstate/database.go
index d82360259..c27c41dcf 100644
--- a/tomoxlending/lendingstate/database.go
+++ b/tomoxlending/lendingstate/database.go
@@ -80,7 +80,7 @@ type Trie interface {
func NewDatabase(db ethdb.Database) Database {
csc, _ := lru.New(codeSizeCacheSize)
return &cachingDB{
- db: trie.NewDatabase(db),
+ db: trie.NewDatabaseWithConfig(db, &trie.Config{Preimages: true}),
codeSizeCache: csc,
}
}
diff --git a/tomoxlending/lendingstate/lendingitem_test.go b/tomoxlending/lendingstate/lendingitem_test.go
index b83c59ebe..564dffddf 100644
--- a/tomoxlending/lendingstate/lendingitem_test.go
+++ b/tomoxlending/lendingstate/lendingitem_test.go
@@ -2,17 +2,18 @@ package lendingstate
import (
"fmt"
+ "math/big"
+ "math/rand"
+ "os"
+ "testing"
+ "time"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/core/state"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/crypto/sha3"
"github.com/tomochain/tomochain/rpc"
- "math/big"
- "math/rand"
- "os"
- "testing"
- "time"
)
func TestLendingItem_VerifyLendingSide(t *testing.T) {
@@ -152,7 +153,7 @@ func SetCollateralDetail(statedb *state.StateDB, token common.Address, depositRa
func TestVerifyBalance(t *testing.T) {
db := rawdb.NewMemoryDatabase()
- statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
relayer := common.HexToAddress("0x0D3ab14BBaD3D99F4203bd7a11aCB94882050E7e")
uAddr := common.HexToAddress("0xDeE6238780f98c0ca2c2C28453149bEA49a3Abc9")
lendingToken := common.HexToAddress("0xd9bb01454c85247B2ef35BB5BE57384cC275a8cf") // USD
diff --git a/tomoxlending/lendingstate/tomox_trie.go b/tomoxlending/lendingstate/tomox_trie.go
index 8ff0a5633..2852139ae 100644
--- a/tomoxlending/lendingstate/tomox_trie.go
+++ b/tomoxlending/lendingstate/tomox_trie.go
@@ -18,11 +18,11 @@ package lendingstate
import (
"fmt"
- "github.com/tomochain/tomochain/ethdb"
- "github.com/tomochain/tomochain/trie"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/log"
+ "github.com/tomochain/tomochain/trie"
)
// TomoXTrie wraps a trie with key hashing. In a secure trie, all
@@ -78,16 +78,16 @@ func (t *TomoXTrie) Get(key []byte) []byte {
// The value bytes must not be modified by the caller.
// If a node was not found in the database, a MissingNodeError is returned.
func (t *TomoXTrie) TryGet(key []byte) ([]byte, error) {
- return t.trie.TryGet(key)
+ return t.trie.Get(key)
}
-// TryGetBestLeftKey returns the value of max left leaf
+// TryGetBestLeftKeyAndValue returns the key and value of the leftmost leaf
// If a node was not found in the database, a MissingNodeError is returned.
func (t *TomoXTrie) TryGetBestLeftKeyAndValue() ([]byte, []byte, error) {
return t.trie.TryGetBestLeftKeyAndValue()
}
-// TryGetBestRightKey returns the value of max left leaf
+// TryGetBestRightKeyAndValue returns the key and value of the rightmost leaf
// If a node was not found in the database, a MissingNodeError is returned.
func (t *TomoXTrie) TryGetBestRightKeyAndValue() ([]byte, []byte, error) {
return t.trie.TryGetBestRightKeyAndValue()
@@ -114,7 +114,7 @@ func (t *TomoXTrie) Update(key, value []byte) {
//
// If a node was not found in the database, a MissingNodeError is returned.
func (t *TomoXTrie) TryUpdate(key, value []byte) error {
- err := t.trie.TryUpdate(key, value)
+ err := t.trie.Update(key, value)
if err != nil {
return err
}
@@ -133,7 +133,7 @@ func (t *TomoXTrie) Delete(key []byte) {
// If a node was not found in the database, a MissingNodeError is returned.
func (t *TomoXTrie) TryDelete(key []byte) error {
delete(t.getSecKeyCache(), string(key))
- return t.trie.TryDelete(key)
+ return t.trie.Delete(key)
}
// GetKey returns the sha3 preimage of a hashed key that was
diff --git a/trie/committer.go b/trie/committer.go
index 78ed86bb4..43a31381b 100644
--- a/trie/committer.go
+++ b/trie/committer.go
@@ -22,8 +22,8 @@ import (
"sync"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/rlp"
- "golang.org/x/crypto/sha3"
)
// leafChanSize is the size of the leafCh. It's a pretty arbitrary number, to allow
@@ -46,7 +46,7 @@ type leaf struct {
// processed sequentially - onleaf will never be called in parallel or out of order.
type committer struct {
tmp sliceBuffer
- sha keccakState
+ sha crypto.KeccakState
onleaf LeafCallback
leafCh chan *leaf
@@ -57,7 +57,7 @@ var committerPool = sync.Pool{
New: func() interface{} {
return &committer{
tmp: make(sliceBuffer, 0, 550), // cap is as large as a full FullNode.
- sha: sha3.NewLegacyKeccak256().(keccakState),
+ sha: crypto.NewKeccakState(),
}
},
}
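
crypto.KeccakState replaces the package-local keccakState interface; its Read drains the sponge without the defensive copy that Sum performs, which is why the hashers and committer prefer it. A sketch, assuming the crypto package helpers used above:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/crypto"
)

func main() {
	sha := crypto.NewKeccakState()
	sha.Write([]byte("hello"))

	var h [32]byte
	sha.Read(h[:]) // faster than Sum: no internal-state copy

	fmt.Printf("%x\n", h[:])
	fmt.Printf("%x\n", crypto.Keccak256([]byte("hello"))) // same digest
}
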
diff --git a/trie/database.go b/trie/database.go
index bb2da07c2..446fda34c 100644
--- a/trie/database.go
+++ b/trie/database.go
@@ -25,6 +25,7 @@ import (
"time"
"github.com/VictoriaMetrics/fastcache"
+
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/ethdb"
"github.com/tomochain/tomochain/log"
@@ -65,6 +66,12 @@ const secureKeyPrefixLength = 11
// secureKeyLength is the length of the above prefix + 32byte hash.
const secureKeyLength = secureKeyPrefixLength + 32
+// Config defines all necessary options for database.
+type Config struct {
+ Cache int // Memory allowance (MB) to use for caching trie nodes in memory
+ Preimages bool // Flag whether the preimage of trie key is recorded
+}
+
// Database is an intermediate write layer between the trie data structures and
// the disk database. The aim is to accumulate trie writes in-memory and only
// periodically flush a couple tries to disk, garbage collecting the remainder.
@@ -74,6 +81,7 @@ const secureKeyLength = secureKeyPrefixLength + 32
// behind this split design is to provide read access to RPC handlers and sync
// servers even while the trie is executing expensive garbage collection.
type Database struct {
+ config *Config // Configuration for trie database
diskdb ethdb.KeyValueStore // Persistent storage for matured trie nodes
cleans *fastcache.Cache // GC friendly memory Cache of clean Node RLPs
@@ -81,7 +89,7 @@ type Database struct {
oldest common.Hash // Oldest tracked Node, flush-list head
newest common.Hash // Newest tracked Node, flush-list tail
- preimages map[common.Hash][]byte // Preimages of nodes from the secure trie
+ preimages *preimageStore // The store for caching preimages
gctime time.Duration // Time spent on garbage collection since last commit
gcnodes uint64 // Nodes garbage collected since last commit
@@ -106,7 +114,12 @@ type rawNode []byte
func (n rawNode) Cache() (HashNode, bool) { panic("this should never end up in a live trie") }
func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") }
-// rawFullNode represents only the useful data content of a full Node, with the
+func (n rawNode) EncodeRLP(w io.Writer) error {
+ _, err := w.Write([]byte(n))
+ return err
+}
+
+// rawFullNode represents only the useful data content of a full node, with the
// caches and flags stripped out to minimize its data storage. This type honors
// the same RLP encoding as the original parent.
type rawFullNode [17]Node
@@ -184,7 +197,7 @@ func (n *cachedNode) obj(hash common.Hash) Node {
// forChilds invokes the callback for all the tracked children of this Node,
// both the implicit ones from inside the Node as well as the explicit ones
-//from outside the Node.
+// from outside the Node.
func (n *cachedNode) forChilds(onChild func(hash common.Hash)) {
for child := range n.children {
onChild(child)
@@ -277,26 +290,32 @@ func expandNode(hash HashNode, n Node) Node {
// NewDatabase creates a new trie database to store ephemeral trie content before
// its written out to disk or garbage collected. No read Cache is created, so all
// data retrievals will hit the underlying disk database.
-func NewDatabase(diskdb ethdb.KeyValueStore) *Database {
- return NewDatabaseWithCache(diskdb, 0)
+func NewDatabase(diskdb ethdb.Database) *Database {
+ return NewDatabaseWithConfig(diskdb, nil)
}
-// NewDatabaseWithCache creates a new trie database to store ephemeral trie content
+// NewDatabaseWithConfig creates a new trie database to store ephemeral trie content
// before its written out to disk or garbage collected. It also acts as a read Cache
// for nodes loaded from disk.
-func NewDatabaseWithCache(diskdb ethdb.KeyValueStore, cache int) *Database {
+func NewDatabaseWithConfig(diskdb ethdb.Database, config *Config) *Database {
var cleans *fastcache.Cache
- if cache > 0 {
- cleans = fastcache.New(cache * 1024 * 1024)
+ if config != nil && config.Cache > 0 {
+ cleans = fastcache.New(config.Cache * 1024 * 1024)
+ }
+ var preimages *preimageStore
+ if config != nil && config.Preimages {
+ preimages = newPreimageStore(diskdb)
}
- return &Database{
+ db := &Database{
diskdb: diskdb,
cleans: cleans,
dirties: map[common.Hash]*cachedNode{{}: {
children: make(map[common.Hash]uint16),
}},
- preimages: make(map[common.Hash][]byte),
+ preimages: preimages,
}
+
+ return db
}
// DiskDB retrieves the persistent storage backing the trie database.
@@ -352,11 +371,12 @@ func (db *Database) insert(hash common.Hash, size int, node Node) {
// yet unknown. The method will make a copy of the slice.
//
// Note, this method assumes that the database's Lock is held!
+// This function is still kept because the TomoX tries depend on it
func (db *Database) InsertPreimage(hash common.Hash, preimage []byte) {
- if _, ok := db.preimages[hash]; ok {
+ if _, ok := db.preimages.preimages[hash]; ok {
return
}
- db.preimages[hash] = common.CopyBytes(preimage)
+ db.preimages.preimages[hash] = common.CopyBytes(preimage)
db.preimagesSize += common.StorageSize(common.HashLength + len(preimage))
}
@@ -440,7 +460,7 @@ func (db *Database) Node(hash common.Hash) ([]byte, error) {
func (db *Database) Preimage(hash common.Hash) ([]byte, error) {
// Retrieve the Node from Cache if available
db.Lock.RLock()
- preimage := db.preimages[hash]
+ preimage := db.preimages.preimages[hash]
db.Lock.RUnlock()
if preimage != nil {
@@ -607,7 +627,7 @@ func (db *Database) Cap(limit common.StorageSize) error {
// leave for later to deduplicate writes.
flushPreimages := db.preimagesSize > 4*1024*1024
if flushPreimages {
- for hash, preimage := range db.preimages {
+ for hash, preimage := range db.preimages.preimages {
copy(keyBuf[secureKeyPrefixLength:], hash[:])
if err := batch.Put(keyBuf[:], preimage); err != nil {
log.Error("Failed to commit Preimage from trie database", "err", err)
@@ -656,7 +676,7 @@ func (db *Database) Cap(limit common.StorageSize) error {
defer db.Lock.Unlock()
if flushPreimages {
- db.preimages = make(map[common.Hash][]byte)
+ db.preimages.preimages = make(map[common.Hash][]byte)
db.preimagesSize = 0
}
for db.oldest != oldest {
@@ -706,26 +726,28 @@ func (db *Database) Commit(node common.Hash, report bool) error {
copy(keyBuf[:], secureKeyPrefix)
// Move all of the accumulated preimages into a write batch
- for hash, preimage := range db.preimages {
- copy(keyBuf[secureKeyPrefixLength:], hash[:])
- if err := batch.Put(keyBuf[:], preimage); err != nil {
- log.Error("Failed to commit Preimage from trie database", "err", err)
- return err
- }
- // If the batch is too large, flush to disk
- if batch.ValueSize() > ethdb.IdealBatchSize {
- if err := batch.Write(); err != nil {
+ if db.preimages != nil {
+ for hash, preimage := range db.preimages.preimages {
+ copy(keyBuf[secureKeyPrefixLength:], hash[:])
+ if err := batch.Put(keyBuf[:], preimage); err != nil {
+ log.Error("Failed to commit Preimage from trie database", "err", err)
return err
}
- batch.Reset()
+ // If the batch is too large, flush to disk
+ if batch.ValueSize() > ethdb.IdealBatchSize {
+ if err := batch.Write(); err != nil {
+ return err
+ }
+ batch.Reset()
+ }
}
+ // Since we're going to replay trie Node writes into the clean Cache, flush out
+ // any batched pre-images before continuing.
+ if err := batch.Write(); err != nil {
+ return err
+ }
+ batch.Reset()
}
- // Since we're going to replay trie Node writes into the clean Cache, flush out
- // any batched pre-images before continuing.
- if err := batch.Write(); err != nil {
- return err
- }
- batch.Reset()
// Move the trie itself into the batch, flushing if enough data is accumulated
nodes, storage := len(db.dirties), db.dirtiesSize
@@ -747,10 +769,6 @@ func (db *Database) Commit(node common.Hash, report bool) error {
batch.Replay(uncacher)
batch.Reset()
- // Reset the storage counters and bumpd metrics
- db.preimages = make(map[common.Hash][]byte)
- db.preimagesSize = 0
-
memcacheCommitTimeTimer.Update(time.Since(start))
memcacheCommitSizeMeter.Mark(int64(storage - db.dirtiesSize))
memcacheCommitNodesMeter.Mark(int64(nodes - len(db.dirties)))
@@ -785,6 +803,7 @@ func (db *Database) commit(hash common.Hash, batch ethdb.Batch, uncacher *cleane
if err != nil {
return err
}
+
if err := batch.Put(hash[:], node.rlp()); err != nil {
return err
}
@@ -794,9 +813,12 @@ func (db *Database) commit(hash common.Hash, batch ethdb.Batch, uncacher *cleane
return err
}
db.Lock.Lock()
- batch.Replay(uncacher)
+ err := batch.Replay(uncacher)
batch.Reset()
db.Lock.Unlock()
+ if err != nil {
+ return err
+ }
}
return nil
}
@@ -810,7 +832,7 @@ type cleaner struct {
// Put reacts to database writes and implements dirty data uncaching. This is the
// post-processing step of a commit operation where the already persisted trie is
// removed from the dirty Cache and moved into the clean Cache. The reason behind
-// the two-phase commit is to ensure ensure data availability while moving from
+// the two-phase commit is to ensure data availability while moving from
// memory to disk.
func (c *cleaner) Put(key []byte, rlp []byte) error {
hash := common.BytesToHash(key)
diff --git a/trie/database_test.go b/trie/database_test.go
index ed6b58fdc..126923b12 100644
--- a/trie/database_test.go
+++ b/trie/database_test.go
@@ -20,13 +20,13 @@ import (
"testing"
"github.com/tomochain/tomochain/common"
- "github.com/tomochain/tomochain/ethdb/memorydb"
+ "github.com/tomochain/tomochain/core/rawdb"
)
// Tests that the trie database returns a missing trie Node error if attempting
// to retrieve the meta root.
func TestDatabaseMetarootFetch(t *testing.T) {
- db := NewDatabase(memorydb.New())
+ db := NewDatabase(rawdb.NewMemoryDatabase())
if _, err := db.Node(common.Hash{}); err == nil {
t.Fatalf("metaroot retrieval succeeded")
}
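
Aside: the test-side change above is mechanical, every in-memory trie database is now built over core/rawdb instead of the bare ethdb/memorydb key-value store. A minimal sketch of the pattern, assuming it lives inside the trie package; the newTestDatabase helper is hypothetical and not part of this diff:

import (
	"testing"

	"github.com/tomochain/tomochain/core/rawdb"
)

// newTestDatabase mirrors the updated tests: rawdb.NewMemoryDatabase wraps
// an in-memory key-value store in the full ethdb.Database interface that
// NewDatabase now expects.
func newTestDatabase(t *testing.T) *Database {
	t.Helper()
	return NewDatabase(rawdb.NewMemoryDatabase())
}
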
diff --git a/trie/hasher.go b/trie/hasher.go
index 8a2ea1806..d4a36dd5e 100644
--- a/trie/hasher.go
+++ b/trie/hasher.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The go-ethereum Authors
+// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
@@ -17,21 +17,12 @@
package trie
import (
- "hash"
"sync"
+ "github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/rlp"
- "golang.org/x/crypto/sha3"
)
-// keccakState wraps sha3.state. In addition to the usual hash methods, it also supports
-// Read to get a variable amount of data from the hash state. Read is faster than Sum
-// because it doesn't copy the internal state, but also modifies the internal state.
-type keccakState interface {
- hash.Hash
- Read([]byte) (int, error)
-}
-
type sliceBuffer []byte
func (b *sliceBuffer) Write(data []byte) (n int, err error) {
@@ -46,17 +37,19 @@ func (b *sliceBuffer) Reset() {
// hasher is a type used for the trie Hash operation. A hasher has some
// internal preallocated temp space
type hasher struct {
- sha keccakState
- tmp sliceBuffer
- parallel bool // Whether to use paralallel threads when hashing
+ sha crypto.KeccakState
+ tmp []byte
+ encbuf rlp.EncoderBuffer
+ parallel bool // Whether to use parallel threads when hashing
}
// hasherPool holds pureHashers
var hasherPool = sync.Pool{
New: func() interface{} {
return &hasher{
- tmp: make(sliceBuffer, 0, 550), // cap is as large as a full FullNode.
- sha: sha3.NewLegacyKeccak256().(keccakState),
+ tmp: make([]byte, 0, 550), // cap is as large as a full fullNode.
+ sha: crypto.NewKeccakState(),
+ encbuf: rlp.NewEncoderBuffer(nil),
}
},
}
@@ -71,14 +64,14 @@ func returnHasherToPool(h *hasher) {
hasherPool.Put(h)
}
-// hash collapses a Node down into a hash Node, also returning a copy of the
-// original Node initialized with the computed hash to replace the original one.
+// hash collapses a node down into a hash node, also returning a copy of the
+// original node initialized with the computed hash to replace the original one.
func (h *hasher) hash(n Node, force bool) (hashed Node, cached Node) {
- // We're not storing the Node, just hashing, use available cached data
+ // Return the cached hash if it's available
if hash, _ := n.Cache(); hash != nil {
return hash, n
}
- // Trie not processed yet or needs storage, walk the children
+ // Trie not processed yet, walk the children
switch n := n.(type) {
case *ShortNode:
collapsed, cached := h.hashShortNodeChildren(n)
@@ -106,11 +99,11 @@ func (h *hasher) hash(n Node, force bool) (hashed Node, cached Node) {
}
}
-// hashShortNodeChildren collapses the short Node. The returned collapsed Node
+// hashShortNodeChildren collapses the short node. The returned collapsed node
// holds a live reference to the Key, and must not be modified.
// The cached
func (h *hasher) hashShortNodeChildren(n *ShortNode) (collapsed, cached *ShortNode) {
- // Hash the short Node's child, caching the newly hashed subtree
+ // Hash the short node's child, caching the newly hashed subtree
collapsed, cached = n.copy(), n.copy()
// Previously, we did copy this one. We don't seem to need to actually
// do that, since we don't overwrite/reuse keys
@@ -125,7 +118,7 @@ func (h *hasher) hashShortNodeChildren(n *ShortNode) (collapsed, cached *ShortNo
}
func (h *hasher) hashFullNodeChildren(n *FullNode) (collapsed *FullNode, cached *FullNode) {
- // Hash the full Node's children, caching the newly hashed subtrees
+ // Hash the full node's children, caching the newly hashed subtrees
cached = n.copy()
collapsed = n.copy()
if h.parallel {
@@ -156,35 +149,46 @@ func (h *hasher) hashFullNodeChildren(n *FullNode) (collapsed *FullNode, cached
return collapsed, cached
}
-// shortnodeToHash creates a HashNode from a ShortNode. The supplied shortnode
+// shortnodeToHash creates a hashNode from a shortNode. The supplied shortnode
// should have hex-type Key, which will be converted (without modification)
// into compact form for RLP encoding.
// If the rlp data is smaller than 32 bytes, `nil` is returned.
func (h *hasher) shortnodeToHash(n *ShortNode, force bool) Node {
- h.tmp.Reset()
- if err := rlp.Encode(&h.tmp, n); err != nil {
- panic("encode error: " + err.Error())
- }
+ n.encode(h.encbuf)
+ enc := h.encodedBytes()
- if len(h.tmp) < 32 && !force {
+ if len(enc) < 32 && !force {
return n // Nodes smaller than 32 bytes are stored inside their parent
}
- return h.hashData(h.tmp)
+ return h.hashData(enc)
}
-// shortnodeToHash is used to creates a HashNode from a set of hashNodes, (which
+// fullnodeToHash is used to create a hashNode from a set of hashNodes, (which
// may contain nil values)
func (h *hasher) fullnodeToHash(n *FullNode, force bool) Node {
- h.tmp.Reset()
- // Generate the RLP encoding of the Node
- if err := n.EncodeRLP(&h.tmp); err != nil {
- panic("encode error: " + err.Error())
- }
+ n.encode(h.encbuf)
+ enc := h.encodedBytes()
- if len(h.tmp) < 32 && !force {
+ if len(enc) < 32 && !force {
return n // Nodes smaller than 32 bytes are stored inside their parent
}
- return h.hashData(h.tmp)
+ return h.hashData(enc)
+}
+
+// encodedBytes returns the result of the last encoding operation on h.encbuf.
+// This also resets the encoder buffer.
+//
+// All node encoding must be done like this:
+//
+// node.encode(h.encbuf)
+// enc := h.encodedBytes()
+//
+// This convention exists because node.encode can only be inlined/escape-analyzed when
+// called on a concrete receiver type.
+func (h *hasher) encodedBytes() []byte {
+ h.tmp = h.encbuf.AppendToBytes(h.tmp[:0])
+ h.encbuf.Reset(nil)
+ return h.tmp
}
// hashData hashes the provided data
@@ -197,8 +201,8 @@ func (h *hasher) hashData(data []byte) HashNode {
}
// proofHash is used to construct trie proofs, and returns the 'collapsed'
-// Node (for later RLP encoding) aswell as the hashed Node -- unless the
-// Node is smaller than 32 bytes, in which case it will be returned as is.
+// node (for later RLP encoding) as well as the hashed node -- unless the
+// node is smaller than 32 bytes, in which case it will be returned as is.
// This method does not do anything on value- or hash-nodes.
func (h *hasher) proofHash(original Node) (collapsed, hashed Node) {
switch n := original.(type) {
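
Aside: the encodedBytes convention documented above is easiest to see with a concrete node. A minimal sketch, assuming it lives inside the trie package; the sample node and function name are illustrative only:

// encodeNodeExample shows the two-step convention: encode on the concrete
// node type (so the call can be inlined/escape-analyzed), then collect the
// bytes, which also resets h.encbuf for the next node.
func encodeNodeExample(h *hasher) []byte {
	n := &ShortNode{
		Key: hexToCompact([]byte{1, 2, 16}), // hex key with terminator nibble
		Val: ValueNode([]byte("value")),
	}
	n.encode(h.encbuf)      // write the node's RLP into the shared buffer
	return h.encodedBytes() // drain the buffer and reset it
}
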
diff --git a/trie/iterator_test.go b/trie/iterator_test.go
index 26d48c95c..b93d66422 100644
--- a/trie/iterator_test.go
+++ b/trie/iterator_test.go
@@ -23,7 +23,7 @@ import (
"testing"
"github.com/tomochain/tomochain/common"
- "github.com/tomochain/tomochain/ethdb/memorydb"
+ "github.com/tomochain/tomochain/core/rawdb"
)
func TestIterator(t *testing.T) {
@@ -292,7 +292,7 @@ func TestIteratorContinueAfterErrorDisk(t *testing.T) { testIteratorContinueA
func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
func testIteratorContinueAfterError(t *testing.T, memonly bool) {
- diskdb := memorydb.New()
+ diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb)
tr, _ := New(common.Hash{}, triedb)
@@ -383,7 +383,7 @@ func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
// Commit test trie to Db, then remove the Node containing "bars".
- diskdb := memorydb.New()
+ diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb)
ctr, _ := New(common.Hash{}, triedb)
diff --git a/trie/node.go b/trie/node.go
index ffb2f1811..fbbe29341 100644
--- a/trie/node.go
+++ b/trie/node.go
@@ -30,6 +30,7 @@ var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b
type Node interface {
fstring(string) string
Cache() (HashNode, bool)
+ encode(w rlp.EncoderBuffer)
}
type (
@@ -52,16 +53,9 @@ var nilValueNode = ValueNode(nil)
// EncodeRLP encodes a full Node into the consensus RLP format.
func (n *FullNode) EncodeRLP(w io.Writer) error {
- var nodes [17]Node
-
- for i, child := range &n.Children {
- if child != nil {
- nodes[i] = child
- } else {
- nodes[i] = nilValueNode
- }
- }
- return rlp.Encode(w, nodes)
+ eb := rlp.NewEncoderBuffer(w)
+ n.encode(eb)
+ return eb.Flush()
}
func (n *FullNode) copy() *FullNode { copy := *n; return &copy }
diff --git a/trie/node_enc.go b/trie/node_enc.go
new file mode 100644
index 000000000..b987abfbf
--- /dev/null
+++ b/trie/node_enc.go
@@ -0,0 +1,72 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "github.com/tomochain/tomochain/rlp"
+)
+
+func nodeToBytes(n Node) []byte {
+ w := rlp.NewEncoderBuffer(nil)
+ n.encode(w)
+ result := w.ToBytes()
+ w.Flush()
+ return result
+}
+
+func (n *FullNode) encode(w rlp.EncoderBuffer) {
+ offset := w.List()
+ for _, c := range n.Children {
+ if c != nil {
+ c.encode(w)
+ } else {
+ w.Write(rlp.EmptyString)
+ }
+ }
+ w.ListEnd(offset)
+}
+
+func (n *ShortNode) encode(w rlp.EncoderBuffer) {
+ offset := w.List()
+ w.WriteBytes(n.Key)
+ if n.Val != nil {
+ n.Val.encode(w)
+ } else {
+ w.Write(rlp.EmptyString)
+ }
+ w.ListEnd(offset)
+}
+
+func (n HashNode) encode(w rlp.EncoderBuffer) {
+ w.WriteBytes(n)
+}
+
+func (n ValueNode) encode(w rlp.EncoderBuffer) {
+ w.WriteBytes(n)
+}
+
+func (n rawNode) encode(w rlp.EncoderBuffer) {
+ w.Write(n)
+}
+
+func (n rawShortNode) encode(w rlp.EncoderBuffer) {
+ panic("this should never end up in a live trie")
+}
+
+func (n rawFullNode) encode(w rlp.EncoderBuffer) {
+ panic("this should never end up in a live trie")
+}
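
Aside: since every node type now implements encode against an rlp.EncoderBuffer, nodeToBytes can produce consensus RLP without going through the reflection-based rlp.Encode path. A minimal sketch inside the trie package; the sample node is illustrative:

// shortNodeRLP encodes a one-entry subtree: the ShortNode writes a list
// header and its compact key, then delegates to the child's encode method.
func shortNodeRLP() []byte {
	n := &ShortNode{
		Key: hexToCompact([]byte{0x1, 0x2, 16}),
		Val: ValueNode([]byte("v")),
	}
	return nodeToBytes(n) // RLP list: [compact(key), "v"]
}
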
diff --git a/trie/preimages.go b/trie/preimages.go
new file mode 100644
index 000000000..760f2290f
--- /dev/null
+++ b/trie/preimages.go
@@ -0,0 +1,94 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "sync"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
+ "github.com/tomochain/tomochain/ethdb"
+)
+
+// preimageStore is the store for caching preimages of node keys.
+type preimageStore struct {
+ lock sync.RWMutex
+ disk ethdb.Database
+ preimages map[common.Hash][]byte // Preimages of nodes from the secure trie
+ preimagesSize common.StorageSize // Storage size of the preimages cache
+}
+
+// newPreimageStore initializes the store for caching preimages.
+func newPreimageStore(disk ethdb.Database) *preimageStore {
+ return &preimageStore{
+ disk: disk,
+ preimages: make(map[common.Hash][]byte),
+ }
+}
+
+// insertPreimage writes a new trie node pre-image to the memory database if it
+// is not yet known. The method will NOT make a copy of the slice; only use it
+// if the preimage will NOT be changed later on.
+func (store *preimageStore) insertPreimage(preimages map[common.Hash][]byte) {
+ store.lock.Lock()
+ defer store.lock.Unlock()
+
+ for hash, preimage := range preimages {
+ if _, ok := store.preimages[hash]; ok {
+ continue
+ }
+ store.preimages[hash] = preimage
+ store.preimagesSize += common.StorageSize(common.HashLength + len(preimage))
+ }
+}
+
+// preimage retrieves a cached trie node pre-image from memory. If it cannot be
+// found cached, the method queries the persistent database for the content.
+func (store *preimageStore) preimage(hash common.Hash) []byte {
+ store.lock.RLock()
+ preimage := store.preimages[hash]
+ store.lock.RUnlock()
+
+ if preimage != nil {
+ return preimage
+ }
+ return rawdb.ReadPreimage(store.disk, hash)
+}
+
+// commit flushes the cached preimages into the disk.
+func (store *preimageStore) commit(force bool) error {
+ store.lock.Lock()
+ defer store.lock.Unlock()
+
+ if store.preimagesSize <= 4*1024*1024 && !force {
+ return nil
+ }
+ if err := rawdb.WritePreimages(store.disk, 0, store.preimages); err != nil {
+ return err
+ }
+
+ store.preimages, store.preimagesSize = make(map[common.Hash][]byte), 0
+ return nil
+}
+
+// size returns the current storage size of accumulated preimages.
+func (store *preimageStore) size() common.StorageSize {
+ store.lock.RLock()
+ defer store.lock.RUnlock()
+
+ return store.preimagesSize
+}
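
Aside: the new preimageStore gives the database a single lock-guarded home for key preimages. A minimal usage sketch inside the trie package, assuming the usual common/crypto/ethdb imports (crypto.Keccak256Hash as in upstream go-ethereum); the helper name is illustrative:

// preimageRoundTrip buffers one preimage, reads it back from memory, then
// forces a flush to disk. Without force=true, commit is a no-op until the
// cache exceeds the 4MiB threshold above.
func preimageRoundTrip(disk ethdb.Database, key []byte) []byte {
	store := newPreimageStore(disk)
	hash := crypto.Keccak256Hash(key)
	store.insertPreimage(map[common.Hash][]byte{hash: key})
	pre := store.preimage(hash) // served from the in-memory map here
	if err := store.commit(true); err == nil {
		// After the flush, preimage falls back to rawdb.ReadPreimage.
		pre = store.preimage(hash)
	}
	return pre
}
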
diff --git a/trie/proof.go b/trie/proof.go
index 9e4082a27..28320e8a0 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -22,8 +22,8 @@ import (
"fmt"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/ethdb"
- "github.com/tomochain/tomochain/ethdb/memorydb"
"github.com/tomochain/tomochain/log"
"github.com/tomochain/tomochain/rlp"
)
@@ -395,11 +395,11 @@ func hasRightElement(node Node, key []byte) bool {
// Expect the normal case, this function can also be used to verify the following
// range proofs(note this function doesn't accept zero element proof):
//
-// - All elements proof. In this case the left and right proof can be nil, but the
-// range should be all the leaves in the trie.
+// - All elements proof. In this case the left and right proof can be nil, but the
+// range should be all the leaves in the trie.
//
-// - One element proof. In this case no matter the left edge proof is a non-existent
-// proof or not, we can always verify the correctness of the proof.
+// - One element proof. In this case no matter the left edge proof is a non-existent
+// proof or not, we can always verify the correctness of the proof.
//
// Except returning the error to indicate the proof is valid or not, the function will
// also return a flag to indicate whether there exists more accounts/slots in the trie.
@@ -419,15 +419,12 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, keys [][]byte, valu
// Special case, there is no edge proof at all. The given range is expected
// to be the whole leaf-set in the trie.
if firstProof == nil && lastProof == nil {
- emptytrie, err := New(common.Hash{}, NewDatabase(memorydb.New()))
- if err != nil {
- return err, false
- }
+ tr := NewStackTrie(nil)
for index, key := range keys {
- emptytrie.TryUpdate(key, values[index])
+ tr.Update(key, values[index])
}
- if emptytrie.Hash() != rootHash {
- return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, emptytrie.Hash()), false
+ if have, want := tr.Hash(), rootHash; have != want {
+ return fmt.Errorf("invalid proof, want hash %x, got %x", want, have), false
}
return nil, false // no more element.
}
@@ -464,9 +461,10 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, keys [][]byte, valu
}
// Rebuild the trie with the leave stream, the shape of trie
// should be same with the original one.
- newtrie := &Trie{root: root, Db: NewDatabase(memorydb.New())}
+
+ newtrie := &Trie{root: root, Db: NewDatabase(rawdb.NewMemoryDatabase())}
for index, key := range keys {
- newtrie.TryUpdate(key, values[index])
+ newtrie.Update(key, values[index])
}
if newtrie.Hash() != rootHash {
return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, newtrie.Hash()), false
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index f62d3d06d..cbffd559e 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -17,10 +17,9 @@
package trie
import (
- "fmt"
-
"github.com/tomochain/tomochain/common"
- "github.com/tomochain/tomochain/log"
+ "github.com/tomochain/tomochain/core/types"
+ "github.com/tomochain/tomochain/rlp"
)
// SecureTrie wraps a trie with key hashing. In a secure trie, all
@@ -35,6 +34,7 @@ import (
// SecureTrie is not safe for concurrent use.
type SecureTrie struct {
trie Trie
+ preimages *preimageStore
hashKeyBuf [common.HashLength]byte
secKeyCache map[string][]byte
secKeyCacheOwner *SecureTrie // Pointer to self, replace the key Cache on mismatch
@@ -50,7 +50,7 @@ type SecureTrie struct {
// Accessing the trie loads nodes from the database or Node pool on demand.
// Loaded nodes are kept around until their 'Cache generation' expires.
// A new Cache generation is created by each call to Commit.
-// cachelimit sets the number of past Cache generations to keep.
+// cache limit sets the number of past Cache generations to keep.
func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) {
if db == nil {
panic("trie.NewSecure called without a database")
@@ -59,49 +59,83 @@ func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) {
if err != nil {
return nil, err
}
- return &SecureTrie{trie: *trie}, nil
+ return &SecureTrie{trie: *trie, preimages: db.preimages}, nil
}
-// Get returns the value for key stored in the trie.
+// MustGet returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
-func (t *SecureTrie) Get(key []byte) []byte {
- res, err := t.TryGet(key)
- if err != nil {
- log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
+//
+// This function will ignore any encountered error and just
+// print out an error message.
+func (t *SecureTrie) MustGet(key []byte) []byte {
+ return t.trie.MustGet(t.hashKey(key))
+}
+
+// GetStorage attempts to retrieve a storage slot with provided account address
+// and slot key. The value bytes must not be modified by the caller.
+// If the specified storage slot is not in the trie, nil will be returned.
+// If a trie node is not found in the database, a MissingNodeError is returned.
+func (t *SecureTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) {
+ enc, err := t.trie.Get(t.hashKey(key))
+ if err != nil || len(enc) == 0 {
+ return nil, err
}
- return res
+ _, content, _, err := rlp.Split(enc)
+ return content, err
}
-// TryGet returns the value for key stored in the trie.
-// The value bytes must not be modified by the caller.
-// If a Node was not found in the database, a MissingNodeError is returned.
-func (t *SecureTrie) TryGet(key []byte) ([]byte, error) {
- return t.trie.TryGet(t.hashKey(key))
+// GetAccount attempts to retrieve an account with provided account address.
+// If the specified account is not in the trie, nil will be returned.
+// If a trie node is not found in the database, a MissingNodeError is returned.
+func (t *SecureTrie) GetAccount(address common.Address) (*types.StateAccount, error) {
+ res, err := t.trie.Get(t.hashKey(address.Bytes()))
+ if res == nil || err != nil {
+ return nil, err
+ }
+ ret := new(types.StateAccount)
+ err = rlp.DecodeBytes(res, ret)
+ return ret, err
}
-// Update associates key with value in the trie. Subsequent calls to
+// GetAccountByHash does the same thing as GetAccount, however it expects an
+// account hash that is the hash of address. This constitutes an abstraction
+// leak, since the client code needs to know the key format.
+func (t *SecureTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) {
+ res, err := t.trie.Get(addrHash.Bytes())
+ if res == nil || err != nil {
+ return nil, err
+ }
+ ret := new(types.StateAccount)
+ err = rlp.DecodeBytes(res, ret)
+ return ret, err
+}
+
+// MustUpdate associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil.
//
// The value bytes must not be modified by the caller while they are
// stored in the trie.
-func (t *SecureTrie) Update(key, value []byte) {
- if err := t.TryUpdate(key, value); err != nil {
- log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
- }
+//
+// This function will ignore any encountered error and just print out an
+// error message.
+func (t *SecureTrie) MustUpdate(key, value []byte) {
+ hk := t.hashKey(key)
+ t.trie.MustUpdate(hk, value)
+ t.getSecKeyCache()[string(hk)] = common.CopyBytes(key)
}
-// TryUpdate associates key with value in the trie. Subsequent calls to
+// UpdateStorage associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil.
//
// The value bytes must not be modified by the caller while they are
// stored in the trie.
//
-// If a Node was not found in the database, a MissingNodeError is returned.
-func (t *SecureTrie) TryUpdate(key, value []byte) error {
+// If a node is not found in the database, a MissingNodeError is returned.
+func (t *SecureTrie) UpdateStorage(_ common.Address, key, value []byte) error {
hk := t.hashKey(key)
- err := t.trie.TryUpdate(hk, value)
+ err := t.trie.Update(hk, value)
if err != nil {
return err
}
@@ -109,19 +143,47 @@ func (t *SecureTrie) TryUpdate(key, value []byte) error {
return nil
}
-// Delete removes any existing value for key from the trie.
-func (t *SecureTrie) Delete(key []byte) {
- if err := t.TryDelete(key); err != nil {
- log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
+// UpdateAccount will abstract the write of an account to the secure trie.
+func (t *SecureTrie) UpdateAccount(address common.Address, acc *types.StateAccount) error {
+ hk := t.hashKey(address.Bytes())
+ data, err := rlp.EncodeToBytes(acc)
+ if err != nil {
+ return err
}
+ if err := t.trie.Update(hk, data); err != nil {
+ return err
+ }
+ t.getSecKeyCache()[string(hk)] = address.Bytes()
+ return nil
+}
+
+func (t *SecureTrie) UpdateContractCode(_ common.Address, _ common.Hash, _ []byte) error {
+ return nil
+}
+
+// MustDelete removes any existing value for key from the trie. This function
+// will ignore any encountered error and just print out an error message.
+func (t *SecureTrie) MustDelete(key []byte) {
+ hk := t.hashKey(key)
+ delete(t.getSecKeyCache(), string(hk))
+ t.trie.MustDelete(hk)
}
-// TryDelete removes any existing value for key from the trie.
-// If a Node was not found in the database, a MissingNodeError is returned.
-func (t *SecureTrie) TryDelete(key []byte) error {
+// DeleteStorage removes any existing storage slot from the trie.
+// If the specified trie node is not in the trie, nothing will be changed.
+// If a node is not found in the database, a MissingNodeError is returned.
+func (t *SecureTrie) DeleteStorage(_ common.Address, key []byte) error {
hk := t.hashKey(key)
delete(t.getSecKeyCache(), string(hk))
- return t.trie.TryDelete(hk)
+ return t.trie.Delete(hk)
+}
+
+// DeleteAccount abstracts an account deletion from the trie.
+func (t *SecureTrie) DeleteAccount(address common.Address) error {
+ hk := t.hashKey(address.Bytes())
+ delete(t.getSecKeyCache(), string(hk))
+ return t.trie.Delete(hk)
}
// GetKey returns the sha3 Preimage of a hashed key that was
@@ -130,8 +192,10 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte {
if key, ok := t.getSecKeyCache()[string(shaKey)]; ok {
return key
}
- key, _ := t.trie.Db.Preimage(common.BytesToHash(shaKey))
- return key
+ if t.preimages == nil {
+ return nil
+ }
+ return t.preimages.preimage(common.BytesToHash(shaKey))
}
// Commit writes all nodes and the secure hash pre-images to the trie's database.
@@ -142,12 +206,15 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte {
func (t *SecureTrie) Commit(onleaf LeafCallback) (root common.Hash, err error) {
// Write all the pre-images to the actual disk database
if len(t.getSecKeyCache()) > 0 {
- t.trie.Db.Lock.Lock()
- for hk, key := range t.secKeyCache {
- t.trie.Db.InsertPreimage(common.BytesToHash([]byte(hk)), key)
+ if t.preimages != nil {
+ t.trie.Db.Lock.Lock()
+ preimages := make(map[common.Hash][]byte)
+ for hk, key := range t.secKeyCache {
+ preimages[common.BytesToHash([]byte(hk))] = key
+ }
+ t.preimages.insertPreimage(preimages)
+ t.trie.Db.Lock.Unlock()
}
- t.trie.Db.Lock.Unlock()
-
t.secKeyCache = make(map[string][]byte)
}
// Commit the trie to its intermediate Node database
@@ -162,8 +229,11 @@ func (t *SecureTrie) Hash() common.Hash {
// Copy returns a copy of SecureTrie.
func (t *SecureTrie) Copy() *SecureTrie {
- cpy := *t
- return &cpy
+ return &SecureTrie{
+ trie: *t.trie.Copy(),
+ preimages: t.preimages,
+ secKeyCache: t.secKeyCache,
+ }
}
// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration
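
Aside: the reworked SecureTrie API splits the old Get/Update/Delete into account- and storage-specific variants plus Must* wrappers. A minimal round-trip sketch using the exported API over an in-memory database; the helper name is hypothetical:

// secureTrieRoundTrip inserts a value under a hashed key and resolves the
// hashed key back to its preimage through the secure-key cache.
func secureTrieRoundTrip() ([]byte, []byte, error) {
	db := NewDatabase(rawdb.NewMemoryDatabase())
	tr, err := NewSecure(common.Hash{}, db)
	if err != nil {
		return nil, nil, err
	}
	tr.MustUpdate([]byte("foo"), []byte("bar"))
	val := tr.MustGet([]byte("foo")) // "bar"
	// The trie stores Keccak256("foo"); GetKey maps it back to "foo".
	key := tr.GetKey(crypto.Keccak256([]byte("foo")))
	return val, key, nil
}
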
diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go
index a015ffcff..bc17b2ca4 100644
--- a/trie/secure_trie_test.go
+++ b/trie/secure_trie_test.go
@@ -23,19 +23,19 @@ import (
"testing"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/crypto"
- "github.com/tomochain/tomochain/ethdb/memorydb"
)
func newEmptySecure() *SecureTrie {
- trie, _ := NewSecure(common.Hash{}, NewDatabase(memorydb.New()))
+ trie, _ := NewSecure(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
return trie
}
// makeTestSecureTrie creates a large enough secure trie for testing.
func makeTestSecureTrie() (*Database, *SecureTrie, map[string][]byte) {
// Create an empty trie
- triedb := NewDatabase(memorydb.New())
+ triedb := NewDatabase(rawdb.NewMemoryDatabase())
trie, _ := NewSecure(common.Hash{}, triedb)
// Fill it with some arbitrary data
@@ -44,17 +44,17 @@ func makeTestSecureTrie() (*Database, *SecureTrie, map[string][]byte) {
// Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
content[string(key)] = val
- trie.Update(key, val)
+ trie.MustUpdate(key, val)
key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
content[string(key)] = val
- trie.Update(key, val)
+ trie.MustUpdate(key, val)
// Add some other data to inflate the trie
for j := byte(3); j < 13; j++ {
key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
content[string(key)] = val
- trie.Update(key, val)
+ trie.MustUpdate(key, val)
}
}
trie.Commit(nil)
@@ -77,9 +77,9 @@ func TestSecureDelete(t *testing.T) {
}
for _, val := range vals {
if val.v != "" {
- trie.Update([]byte(val.k), []byte(val.v))
+ trie.MustUpdate([]byte(val.k), []byte(val.v))
} else {
- trie.Delete([]byte(val.k))
+ trie.MustDelete([]byte(val.k))
}
}
hash := trie.Hash()
@@ -91,13 +91,13 @@ func TestSecureDelete(t *testing.T) {
func TestSecureGetKey(t *testing.T) {
trie := newEmptySecure()
- trie.Update([]byte("foo"), []byte("bar"))
+ trie.MustUpdate([]byte("foo"), []byte("bar"))
key := []byte("foo")
value := []byte("bar")
seckey := crypto.Keccak256(key)
- if !bytes.Equal(trie.Get(key), value) {
+ if !bytes.Equal(trie.MustGet(key), value) {
t.Errorf("Get did not return bar")
}
if k := trie.GetKey(seckey); !bytes.Equal(k, key) {
@@ -125,15 +125,15 @@ func TestSecureTrieConcurrency(t *testing.T) {
for j := byte(0); j < 255; j++ {
// Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j}
- tries[index].Update(key, val)
+ tries[index].MustUpdate(key, val)
key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j}
- tries[index].Update(key, val)
+ tries[index].MustUpdate(key, val)
// Add some other data to inflate the trie
for k := byte(3); k < 13; k++ {
key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j}
- tries[index].Update(key, val)
+ tries[index].MustUpdate(key, val)
}
}
tries[index].Commit(nil)
diff --git a/trie/stacktrie.go b/trie/stacktrie.go
new file mode 100644
index 000000000..48417e556
--- /dev/null
+++ b/trie/stacktrie.go
@@ -0,0 +1,533 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/gob"
+ "errors"
+ "io"
+ "sync"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/log"
+)
+
+var ErrCommitDisabled = errors.New("no database for committing")
+
+var stPool = sync.Pool{
+ New: func() interface{} {
+ return NewStackTrie(nil)
+ },
+}
+
+// NodeWriteFunc is used to provide all information of a dirty node for committing
+// so that callers can flush nodes into database with desired scheme.
+type NodeWriteFunc = func(owner common.Hash, path []byte, hash common.Hash, blob []byte)
+
+func stackTrieFromPool(writeFn NodeWriteFunc, owner common.Hash) *StackTrie {
+ st := stPool.Get().(*StackTrie)
+ st.owner = owner
+ st.writeFn = writeFn
+ return st
+}
+
+func returnToPool(st *StackTrie) {
+ st.Reset()
+ stPool.Put(st)
+}
+
+// StackTrie is a trie implementation that expects keys to be inserted
+// in order. Once it determines that a subtree will no longer be inserted
+// into, it will hash it and free up the memory it uses.
+type StackTrie struct {
+ owner common.Hash // the owner of the trie
+ nodeType uint8 // node type (as in branch, ext, leaf)
+ val []byte // value contained by this node if it's a leaf
+ key []byte // key chunk covered by this (leaf|ext) node
+ children [16]*StackTrie // list of children (for branch and exts)
+ writeFn NodeWriteFunc // function for committing nodes, can be nil
+}
+
+// NewStackTrie allocates and initializes an empty trie.
+func NewStackTrie(writeFn NodeWriteFunc) *StackTrie {
+ return &StackTrie{
+ nodeType: emptyNode,
+ writeFn: writeFn,
+ }
+}
+
+// NewStackTrieWithOwner allocates and initializes an empty trie, but with
+// the additional owner field.
+func NewStackTrieWithOwner(writeFn NodeWriteFunc, owner common.Hash) *StackTrie {
+ return &StackTrie{
+ owner: owner,
+ nodeType: emptyNode,
+ writeFn: writeFn,
+ }
+}
+
+// NewFromBinary initialises a serialized stacktrie with the given node write function.
+func NewFromBinary(data []byte, writeFn NodeWriteFunc) (*StackTrie, error) {
+ var st StackTrie
+ if err := st.UnmarshalBinary(data); err != nil {
+ return nil, err
+ }
+ // If a write function is provided, we need to recursively add it to every child
+ if writeFn != nil {
+ st.setWriter(writeFn)
+ }
+ return &st, nil
+}
+
+// MarshalBinary implements encoding.BinaryMarshaler
+func (st *StackTrie) MarshalBinary() (data []byte, err error) {
+ var (
+ b bytes.Buffer
+ w = bufio.NewWriter(&b)
+ )
+ if err := gob.NewEncoder(w).Encode(struct {
+ Owner common.Hash
+ NodeType uint8
+ Val []byte
+ Key []byte
+ }{
+ st.owner,
+ st.nodeType,
+ st.val,
+ st.key,
+ }); err != nil {
+ return nil, err
+ }
+ for _, child := range st.children {
+ if child == nil {
+ w.WriteByte(0)
+ continue
+ }
+ w.WriteByte(1)
+ if childData, err := child.MarshalBinary(); err != nil {
+ return nil, err
+ } else {
+ w.Write(childData)
+ }
+ }
+ w.Flush()
+ return b.Bytes(), nil
+}
+
+// UnmarshalBinary implements encoding.BinaryUnmarshaler
+func (st *StackTrie) UnmarshalBinary(data []byte) error {
+ r := bytes.NewReader(data)
+ return st.unmarshalBinary(r)
+}
+
+func (st *StackTrie) unmarshalBinary(r io.Reader) error {
+ var dec struct {
+ Owner common.Hash
+ NodeType uint8
+ Val []byte
+ Key []byte
+ }
+ if err := gob.NewDecoder(r).Decode(&dec); err != nil {
+ return err
+ }
+ st.owner = dec.Owner
+ st.nodeType = dec.NodeType
+ st.val = dec.Val
+ st.key = dec.Key
+
+ var hasChild = make([]byte, 1)
+ for i := range st.children {
+ if _, err := r.Read(hasChild); err != nil {
+ return err
+ } else if hasChild[0] == 0 {
+ continue
+ }
+ var child StackTrie
+ if err := child.unmarshalBinary(r); err != nil {
+ return err
+ }
+ st.children[i] = &child
+ }
+ return nil
+}
+
+func (st *StackTrie) setWriter(writeFn NodeWriteFunc) {
+ st.writeFn = writeFn
+ for _, child := range st.children {
+ if child != nil {
+ child.setWriter(writeFn)
+ }
+ }
+}
+
+func newLeaf(owner common.Hash, key, val []byte, writeFn NodeWriteFunc) *StackTrie {
+ st := stackTrieFromPool(writeFn, owner)
+ st.nodeType = leafNode
+ st.key = append(st.key, key...)
+ st.val = val
+ return st
+}
+
+func newExt(owner common.Hash, key []byte, child *StackTrie, writeFn NodeWriteFunc) *StackTrie {
+ st := stackTrieFromPool(writeFn, owner)
+ st.nodeType = extNode
+ st.key = append(st.key, key...)
+ st.children[0] = child
+ return st
+}
+
+// List all values that StackTrie.nodeType can hold.
+const (
+ emptyNode = iota
+ branchNode
+ extNode
+ leafNode
+ hashedNode
+)
+
+// Update inserts a (key, value) pair into the stack trie.
+func (st *StackTrie) Update(key, value []byte) error {
+ k := keybytesToHex(key)
+ if len(value) == 0 {
+ panic("deletion not supported")
+ }
+ st.insert(k[:len(k)-1], value, nil)
+ return nil
+}
+
+// MustUpdate is a wrapper of Update that ignores any encountered error and
+// just prints out an error message.
+func (st *StackTrie) MustUpdate(key, value []byte) {
+ if err := st.Update(key, value); err != nil {
+ log.Error("Unhandled trie error in StackTrie.Update", "err", err)
+ }
+}
+
+func (st *StackTrie) Reset() {
+ st.owner = common.Hash{}
+ st.writeFn = nil
+ st.key = st.key[:0]
+ st.val = nil
+ for i := range st.children {
+ st.children[i] = nil
+ }
+ st.nodeType = emptyNode
+}
+
+// Helper function that, given a full key, determines the index
+// at which the chunk covered by st.key differs from
+// the same chunk in the full key.
+func (st *StackTrie) getDiffIndex(key []byte) int {
+ for idx, nibble := range st.key {
+ if nibble != key[idx] {
+ return idx
+ }
+ }
+ return len(st.key)
+}
+
+// Helper function that inserts a (key, value) pair into
+// the trie.
+func (st *StackTrie) insert(key, value []byte, prefix []byte) {
+ switch st.nodeType {
+ case branchNode: /* Branch */
+ idx := int(key[0])
+
+ // Unresolve elder siblings
+ for i := idx - 1; i >= 0; i-- {
+ if st.children[i] != nil {
+ if st.children[i].nodeType != hashedNode {
+ st.children[i].hash(append(prefix, byte(i)))
+ }
+ break
+ }
+ }
+
+ // Add new child
+ if st.children[idx] == nil {
+ st.children[idx] = newLeaf(st.owner, key[1:], value, st.writeFn)
+ } else {
+ st.children[idx].insert(key[1:], value, append(prefix, key[0]))
+ }
+
+ case extNode: /* Ext */
+ // Compare both key chunks and see where they differ
+ diffidx := st.getDiffIndex(key)
+
+ // Check if chunks are identical. If so, recurse into
+ // the child node. Otherwise, the key has to be split
+ // into 1) an optional common prefix, 2) the fullnode
+ // representing the two differing path, and 3) a leaf
+ // for each of the differentiated subtrees.
+ if diffidx == len(st.key) {
+ // Ext key and key segment are identical, recurse into
+ // the child node.
+ st.children[0].insert(key[diffidx:], value, append(prefix, key[:diffidx]...))
+ return
+ }
+ // Save the original part. Depending on whether the break is
+ // at the extension's last byte or not, create an
+ // intermediate extension or use the extension's child
+ // node directly.
+ var n *StackTrie
+ if diffidx < len(st.key)-1 {
+ // Break on the non-last byte, insert an intermediate
+ // extension. The path prefix of the newly-inserted
+ // extension should also contain the different byte.
+ n = newExt(st.owner, st.key[diffidx+1:], st.children[0], st.writeFn)
+ n.hash(append(prefix, st.key[:diffidx+1]...))
+ } else {
+ // Break on the last byte, no need to insert
+ // an extension node: reuse the current node.
+ // The path prefix of the original part should
+ // still be same.
+ n = st.children[0]
+ n.hash(append(prefix, st.key...))
+ }
+ var p *StackTrie
+ if diffidx == 0 {
+ // the break is on the first byte, so
+ // the current node is converted into
+ // a branch node.
+ st.children[0] = nil
+ p = st
+ st.nodeType = branchNode
+ } else {
+ // the common prefix is at least one byte
+ // long, insert a new intermediate branch
+ // node.
+ st.children[0] = stackTrieFromPool(st.writeFn, st.owner)
+ st.children[0].nodeType = branchNode
+ p = st.children[0]
+ }
+ // Create a leaf for the inserted part
+ o := newLeaf(st.owner, key[diffidx+1:], value, st.writeFn)
+
+ // Insert both child leaves where they belong:
+ origIdx := st.key[diffidx]
+ newIdx := key[diffidx]
+ p.children[origIdx] = n
+ p.children[newIdx] = o
+ st.key = st.key[:diffidx]
+
+ case leafNode: /* Leaf */
+ // Compare both key chunks and see where they differ
+ diffidx := st.getDiffIndex(key)
+
+ // Overwriting a key isn't supported, which means that
+ // the current leaf is expected to be split into 1) an
+ // optional extension for the common prefix of these 2
+ // keys, 2) a fullnode selecting the path on which the
+ // keys differ, and 3) one leaf for the differentiated
+ // component of each key.
+ if diffidx >= len(st.key) {
+ panic("Trying to insert into existing key")
+ }
+
+ // Check if the split occurs at the first nibble of the
+ // chunk. In that case, no prefix extnode is necessary.
+ // Otherwise, create that prefix extension node.
+ var p *StackTrie
+ if diffidx == 0 {
+ // Convert current leaf into a branch
+ st.nodeType = branchNode
+ p = st
+ st.children[0] = nil
+ } else {
+ // Convert current node into an ext,
+ // and insert a child branch node.
+ st.nodeType = extNode
+ st.children[0] = NewStackTrieWithOwner(st.writeFn, st.owner)
+ st.children[0].nodeType = branchNode
+ p = st.children[0]
+ }
+
+ // Create the two child leaves: one containing the original
+ // value and another containing the new value. The child leaf
+ // is hashed directly in order to free up some memory.
+ origIdx := st.key[diffidx]
+ p.children[origIdx] = newLeaf(st.owner, st.key[diffidx+1:], st.val, st.writeFn)
+ p.children[origIdx].hash(append(prefix, st.key[:diffidx+1]...))
+
+ newIdx := key[diffidx]
+ p.children[newIdx] = newLeaf(st.owner, key[diffidx+1:], value, st.writeFn)
+
+ // Finally, cut off the key part that has been passed
+ // over to the children.
+ st.key = st.key[:diffidx]
+ st.val = nil
+
+ case emptyNode: /* Empty */
+ st.nodeType = leafNode
+ st.key = key
+ st.val = value
+
+ case hashedNode:
+ panic("trying to insert into hash")
+
+ default:
+ panic("invalid type")
+ }
+}
+
+// hash converts st into a 'hashedNode', if possible. Possible outcomes:
+//
+// 1. The rlp-encoded value was >= 32 bytes:
+// - Then the 32-byte `hash` will be accessible in `st.val`.
+// - And the 'st.type' will be 'hashedNode'
+//
+// 2. The rlp-encoded value was < 32 bytes
+// - Then the <32 byte rlp-encoded value will be accessible in 'st.val'.
+// - And the 'st.type' will be 'hashedNode' AGAIN
+//
+// This method also sets 'st.type' to hashedNode, and clears 'st.key'.
+func (st *StackTrie) hash(path []byte) {
+ h := newHasher(false)
+ defer returnHasherToPool(h)
+
+ st.hashRec(h, path)
+}
+
+func (st *StackTrie) hashRec(hasher *hasher, path []byte) {
+ // The switch below sets this to the RLP-encoding of this node.
+ var encodedNode []byte
+
+ switch st.nodeType {
+ case hashedNode:
+ return
+
+ case emptyNode:
+ st.val = emptyRoot.Bytes()
+ st.key = st.key[:0]
+ st.nodeType = hashedNode
+ return
+
+ case branchNode:
+ var nodes FullNode
+ for i, child := range st.children {
+ if child == nil {
+ nodes.Children[i] = nilValueNode
+ continue
+ }
+ child.hashRec(hasher, append(path, byte(i)))
+ if len(child.val) < 32 {
+ nodes.Children[i] = rawNode(child.val)
+ } else {
+ nodes.Children[i] = HashNode(child.val)
+ }
+
+ // Release child back to pool.
+ st.children[i] = nil
+ returnToPool(child)
+ }
+
+ nodes.encode(hasher.encbuf)
+ encodedNode = hasher.encodedBytes()
+
+ case extNode:
+ st.children[0].hashRec(hasher, append(path, st.key...))
+
+ n := ShortNode{Key: hexToCompact(st.key)}
+ if len(st.children[0].val) < 32 {
+ n.Val = rawNode(st.children[0].val)
+ } else {
+ n.Val = HashNode(st.children[0].val)
+ }
+
+ n.encode(hasher.encbuf)
+ encodedNode = hasher.encodedBytes()
+
+ // Release child back to pool.
+ returnToPool(st.children[0])
+ st.children[0] = nil
+
+ case leafNode:
+ st.key = append(st.key, byte(16))
+ n := ShortNode{Key: hexToCompact(st.key), Val: ValueNode(st.val)}
+
+ n.encode(hasher.encbuf)
+ encodedNode = hasher.encodedBytes()
+
+ default:
+ panic("invalid node type")
+ }
+
+ st.nodeType = hashedNode
+ st.key = st.key[:0]
+ if len(encodedNode) < 32 {
+ st.val = common.CopyBytes(encodedNode)
+ return
+ }
+
+ // Write the hash to the 'val'. We allocate a new val here to not mutate
+ // input values
+ st.val = hasher.hashData(encodedNode)
+ if st.writeFn != nil {
+ st.writeFn(st.owner, path, common.BytesToHash(st.val), encodedNode)
+ }
+}
+
+// Hash returns the hash of the current node.
+func (st *StackTrie) Hash() (h common.Hash) {
+ hasher := newHasher(false)
+ defer returnHasherToPool(hasher)
+
+ st.hashRec(hasher, nil)
+ if len(st.val) == 32 {
+ copy(h[:], st.val)
+ return h
+ }
+ // If the node's RLP isn't 32 bytes long, the node will not
+ // be hashed, and instead contain the rlp-encoding of the
+ // node. For the top level node, we need to force the hashing.
+ hasher.sha.Reset()
+ hasher.sha.Write(st.val)
+ hasher.sha.Read(h[:])
+ return h
+}
+
+// Commit will first hash the entire trie if it's still not hashed and then
+// commit all nodes to the associated database. Most of the trie nodes may
+// already have been committed; the main purpose here is to commit the root
+// node.
+//
+// An associated write function is required; otherwise the whole commit
+// functionality is disabled and ErrCommitDisabled is returned.
+func (st *StackTrie) Commit() (h common.Hash, err error) {
+ if st.writeFn == nil {
+ return common.Hash{}, ErrCommitDisabled
+ }
+ hasher := newHasher(false)
+ defer returnHasherToPool(hasher)
+
+ st.hashRec(hasher, nil)
+ if len(st.val) == 32 {
+ copy(h[:], st.val)
+ return h, nil
+ }
+ // If the node's RLP isn't 32 bytes long, the node will not
+ // be hashed (and committed), and instead contain the rlp-encoding of the
+ // node. For the top level node, we need to force the hashing+commit.
+ hasher.sha.Reset()
+ hasher.sha.Write(st.val)
+ hasher.sha.Read(h[:])
+
+ st.writeFn(st.owner, nil, h, st.val)
+ return h, nil
+}
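
Aside: for callers that want the hashed nodes rather than just the root, the NodeWriteFunc hook receives every node as it is finalized. A minimal sketch; the collector map and helper name are illustrative only:

// stackTrieCommitExample builds a trie from pre-sorted keys and collects
// every committed node blob keyed by its hash.
func stackTrieCommitExample(keys, vals [][]byte) (common.Hash, map[common.Hash][]byte, error) {
	nodes := make(map[common.Hash][]byte)
	st := NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) {
		nodes[hash] = common.CopyBytes(blob) // blob may be reused by the hasher
	})
	for i, k := range keys { // keys must be inserted in ascending order
		if err := st.Update(k, vals[i]); err != nil {
			return common.Hash{}, nil, err
		}
	}
	root, err := st.Commit() // hashes what remains and writes the root node
	return root, nodes, err
}
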
diff --git a/trie/stacktrie_test.go b/trie/stacktrie_test.go
new file mode 100644
index 000000000..dd5206c87
--- /dev/null
+++ b/trie/stacktrie_test.go
@@ -0,0 +1,413 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "bytes"
+ "math/big"
+ "testing"
+
+ "github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
+ "github.com/tomochain/tomochain/crypto"
+)
+
+func TestStackTrieInsertAndHash(t *testing.T) {
+ type KeyValueHash struct {
+ K string // Hex string for key.
+ V string // Value, directly converted to bytes.
+ H string // Expected root hash after insert of (K, V) to an existing trie.
+ }
+ tests := [][]KeyValueHash{
+ { // {0:0, 7:0, f:0}
+ {"00", "v_______________________0___0", "5cb26357b95bb9af08475be00243ceb68ade0b66b5cd816b0c18a18c612d2d21"},
+ {"70", "v_______________________0___1", "8ff64309574f7a437a7ad1628e690eb7663cfde10676f8a904a8c8291dbc1603"},
+ {"f0", "v_______________________0___2", "9e3a01bd8d43efb8e9d4b5506648150b8e3ed1caea596f84ee28e01a72635470"},
+ },
+ { // {1:0cc, e:{1:fc, e:fc}}
+ {"10cc", "v_______________________1___0", "233e9b257843f3dfdb1cce6676cdaf9e595ac96ee1b55031434d852bc7ac9185"},
+ {"e1fc", "v_______________________1___1", "39c5e908ae83d0c78520c7c7bda0b3782daf594700e44546e93def8f049cca95"},
+ {"eefc", "v_______________________1___2", "d789567559fd76fe5b7d9cc42f3750f942502ac1c7f2a466e2f690ec4b6c2a7c"},
+ },
+ { // {b:{a:ac, b:ac}, d:acc}
+ {"baac", "v_______________________2___0", "8be1c86ba7ec4c61e14c1a9b75055e0464c2633ae66a055a24e75450156a5d42"},
+ {"bbac", "v_______________________2___1", "8495159b9895a7d88d973171d737c0aace6fe6ac02a4769fff1bc43bcccce4cc"},
+ {"dacc", "v_______________________2___2", "9bcfc5b220a27328deb9dc6ee2e3d46c9ebc9c69e78acda1fa2c7040602c63ca"},
+ },
+ { // {0:0cccc, 2:456{0:0, 2:2}
+ {"00cccc", "v_______________________3___0", "e57dc2785b99ce9205080cb41b32ebea7ac3e158952b44c87d186e6d190a6530"},
+ {"245600", "v_______________________3___1", "0335354adbd360a45c1871a842452287721b64b4234dfe08760b243523c998db"},
+ {"245622", "v_______________________3___2", "9e6832db0dca2b5cf81c0e0727bfde6afc39d5de33e5720bccacc183c162104e"},
+ },
+ { // {1:4567{1:1c, 3:3c}, 3:0cccccc}
+ {"1456711c", "v_______________________4___0", "f2389e78d98fed99f3e63d6d1623c1d4d9e8c91cb1d585de81fbc7c0e60d3529"},
+ {"1456733c", "v_______________________4___1", "101189b3fab852be97a0120c03d95eefcf984d3ed639f2328527de6def55a9c0"},
+ {"30cccccc", "v_______________________4___2", "3780ce111f98d15751dfde1eb21080efc7d3914b429e5c84c64db637c55405b3"},
+ },
+ { // 8800{1:f, 2:e, 3:d}
+ {"88001f", "v_______________________5___0", "e817db50d84f341d443c6f6593cafda093fc85e773a762421d47daa6ac993bd5"},
+ {"88002e", "v_______________________5___1", "d6e3e6047bdc110edd296a4d63c030aec451bee9d8075bc5a198eee8cda34f68"},
+ {"88003d", "v_______________________5___2", "b6bdf8298c703342188e5f7f84921a402042d0e5fb059969dd53a6b6b1fb989e"},
+ },
+ { // 0{1:fc, 2:ec, 4:dc}
+ {"01fc", "v_______________________6___0", "693268f2ca80d32b015f61cd2c4dba5a47a6b52a14c34f8e6945fad684e7a0d5"},
+ {"02ec", "v_______________________6___1", "e24ddd44469310c2b785a2044618874bf486d2f7822603a9b8dce58d6524d5de"},
+ {"04dc", "v_______________________6___2", "33fc259629187bbe54b92f82f0cd8083b91a12e41a9456b84fc155321e334db7"},
+ },
+ { // f{0:fccc, f:ff{0:f, f:f}}
+ {"f0fccc", "v_______________________7___0", "b0966b5aa469a3e292bc5fcfa6c396ae7a657255eef552ea7e12f996de795b90"},
+ {"ffff0f", "v_______________________7___1", "3b1ca154ec2a3d96d8d77bddef0abfe40a53a64eb03cecf78da9ec43799fa3d0"},
+ {"ffffff", "v_______________________7___2", "e75463041f1be8252781be0ace579a44ea4387bf5b2739f4607af676f7719678"},
+ },
+ { // ff{0:f{0:f, f:f}, f:fcc}
+ {"ff0f0f", "v_______________________8___0", "0928af9b14718ec8262ab89df430f1e5fbf66fac0fed037aff2b6767ae8c8684"},
+ {"ff0fff", "v_______________________8___1", "d870f4d3ce26b0bf86912810a1960693630c20a48ba56be0ad04bc3e9ddb01e6"},
+ {"ffffcc", "v_______________________8___2", "4239f10dd9d9915ecf2e047d6a576bdc1733ed77a30830f1bf29deaf7d8e966f"},
+ },
+ {
+ {"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"},
+ {"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"},
+ {"123f", "x___________________________2", "1164d7299964e74ac40d761f9189b2a3987fae959800d0f7e29d3aaf3eae9e15"},
+ },
+ {
+ {"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"},
+ {"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"},
+ {"124a", "x___________________________2", "661a96a669869d76b7231380da0649d013301425fbea9d5c5fae6405aa31cfce"},
+ },
+ {
+ {"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"},
+ {"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"},
+ {"13aa", "x___________________________2", "6590120e1fd3ffd1a90e8de5bb10750b61079bb0776cca4414dd79a24e4d4356"},
+ },
+ {
+ {"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"},
+ {"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"},
+ {"2aaa", "x___________________________2", "f869b40e0c55eace1918332ef91563616fbf0755e2b946119679f7ef8e44b514"},
+ },
+ {
+ {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"},
+ {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"},
+ {"1234fa", "x___________________________2", "4f4e368ab367090d5bc3dbf25f7729f8bd60df84de309b4633a6b69ab66142c0"},
+ },
+ {
+ {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"},
+ {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"},
+ {"1235aa", "x___________________________2", "21840121d11a91ac8bbad9a5d06af902a5c8d56a47b85600ba813814b7bfcb9b"},
+ },
+ {
+ {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"},
+ {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"},
+ {"124aaa", "x___________________________2", "ea4040ddf6ae3fbd1524bdec19c0ab1581015996262006632027fa5cf21e441e"},
+ },
+ {
+ {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"},
+ {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"},
+ {"13aaaa", "x___________________________2", "e4beb66c67e44f2dd8ba36036e45a44ff68f8d52942472b1911a45f886a34507"},
+ },
+ {
+ {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"},
+ {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"},
+ {"2aaaaa", "x___________________________2", "5f5989b820ff5d76b7d49e77bb64f26602294f6c42a1a3becc669cd9e0dc8ec9"},
+ },
+ {
+ {"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"},
+ {"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"},
+ {"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"},
+ {"1234fa", "x___________________________3", "65bb3aafea8121111d693ffe34881c14d27b128fd113fa120961f251fe28428d"},
+ },
+ {
+ {"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"},
+ {"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"},
+ {"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"},
+ {"1235aa", "x___________________________3", "f670e4d2547c533c5f21e0045442e2ecb733f347ad6d29ef36e0f5ba31bb11a8"},
+ },
+ {
+ {"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"},
+ {"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"},
+ {"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"},
+ {"124aaa", "x___________________________3", "c17464123050a9a6f29b5574bb2f92f6d305c1794976b475b7fb0316b6335598"},
+ },
+ {
+ {"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"},
+ {"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"},
+ {"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"},
+ {"13aaaa", "x___________________________3", "aa8301be8cb52ea5cd249f5feb79fb4315ee8de2140c604033f4b3fff78f0105"},
+ },
+ {
+ {"0000", "x___________________________0", "cb8c09ad07ae882136f602b3f21f8733a9f5a78f1d2525a8d24d1c13258000b2"},
+ {"123d", "x___________________________1", "8f09663deb02f08958136410dc48565e077f76bb6c9d8c84d35fc8913a657d31"},
+ {"123e", "x___________________________2", "0d230561e398c579e09a9f7b69ceaf7d3970f5a436fdb28b68b7a37c5bdd6b80"},
+ {"123f", "x___________________________3", "80f7bad1893ca57e3443bb3305a517723a74d3ba831bcaca22a170645eb7aafb"},
+ },
+ {
+ {"0000", "x___________________________0", "cb8c09ad07ae882136f602b3f21f8733a9f5a78f1d2525a8d24d1c13258000b2"},
+ {"123d", "x___________________________1", "8f09663deb02f08958136410dc48565e077f76bb6c9d8c84d35fc8913a657d31"},
+ {"123e", "x___________________________2", "0d230561e398c579e09a9f7b69ceaf7d3970f5a436fdb28b68b7a37c5bdd6b80"},
+ {"124a", "x___________________________3", "383bc1bb4f019e6bc4da3751509ea709b58dd1ac46081670834bae072f3e9557"},
+ },
+ {
+ {"0000", "x___________________________0", "cb8c09ad07ae882136f602b3f21f8733a9f5a78f1d2525a8d24d1c13258000b2"},
+ {"123d", "x___________________________1", "8f09663deb02f08958136410dc48565e077f76bb6c9d8c84d35fc8913a657d31"},
+ {"123e", "x___________________________2", "0d230561e398c579e09a9f7b69ceaf7d3970f5a436fdb28b68b7a37c5bdd6b80"},
+ {"13aa", "x___________________________3", "ff0dc70ce2e5db90ee42a4c2ad12139596b890e90eb4e16526ab38fa465b35cf"},
+ },
+ }
+ st := NewStackTrie(nil)
+ for i, test := range tests {
+ // The StackTrie does not allow Insert(), Hash(), Insert(), ...
+ // so we will create a new trie for every sequence length of inserts.
+ for l := 1; l <= len(test); l++ {
+ st.Reset()
+ for j := 0; j < l; j++ {
+ kv := &test[j]
+ if err := st.Update(common.FromHex(kv.K), []byte(kv.V)); err != nil {
+ t.Fatal(err)
+ }
+ }
+ expected := common.HexToHash(test[l-1].H)
+ if h := st.Hash(); h != expected {
+ t.Errorf("%d(%d): root hash mismatch: %x, expected %x", i, l, h, expected)
+ }
+ }
+ }
+}
+
+func TestSizeBug(t *testing.T) {
+ st := NewStackTrie(nil)
+ nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase()))
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563")
+ value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3")
+
+ nt.Update(leaf, value)
+ st.Update(leaf, value)
+
+ if nt.Hash() != st.Hash() {
+ t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+ }
+}
+
+func TestEmptyBug(t *testing.T) {
+ st := NewStackTrie(nil)
+ nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase()))
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ //leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563")
+ //value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3")
+ kvs := []struct {
+ K string
+ V string
+ }{
+ {K: "405787fa12a823e0f2b7631cc41b3ba8828b3321ca811111fa75cd3aa3bb5ace", V: "9496f4ec2bf9dab484cac6be589e8417d84781be08"},
+ {K: "40edb63a35fcf86c08022722aa3287cdd36440d671b4918131b2514795fefa9c", V: "01"},
+ {K: "b10e2d527612073b26eecdfd717e6a320cf44b4afac2b0732d9fcbe2b7fa0cf6", V: "947a30f7736e48d6599356464ba4c150d8da0302ff"},
+ {K: "c2575a0e9e593c00f959f8c92f12db2869c3395a3b0502d05e2516446f71f85b", V: "02"},
+ }
+
+ for _, kv := range kvs {
+ nt.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+ st.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+ }
+
+ if nt.Hash() != st.Hash() {
+ t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+ }
+}
+
+func TestValLength56(t *testing.T) {
+ st := NewStackTrie(nil)
+ nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase()))
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ kvs := []struct {
+ K string
+ V string
+ }{
+ {K: "405787fa12a823e0f2b7631cc41b3ba8828b3321ca811111fa75cd3aa3bb5ace", V: "1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111"},
+ }
+
+ for _, kv := range kvs {
+ nt.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+ st.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+ }
+
+ if nt.Hash() != st.Hash() {
+ t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+ }
+}
+
+ // TestUpdateSmallNodes tests a case where the leaves are small (both key
+ // and value), which causes many nodes to be embedded within other nodes.
+ // This case was found via fuzzing.
+func TestUpdateSmallNodes(t *testing.T) {
+ st := NewStackTrie(nil)
+ nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase()))
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ kvs := []struct {
+ K string
+ V string
+ }{
+ {"63303030", "3041"}, // stacktrie.Update
+ {"65", "3000"}, // stacktrie.Update
+ }
+ for _, kv := range kvs {
+ nt.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+ st.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+ }
+ if nt.Hash() != st.Hash() {
+ t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+ }
+}
+
+ // TestUpdateVariableKeys contains a case in which the stacktrie fails: when
+ // keys of different sizes are used, and the second one has the same prefix
+ // as the first, the stacktrie fails, since it is unable to 'expand' on an
+ // already added leaf. For all practical purposes this is fine, since keys
+ // are of fixed length in account and storage tries.
+ //
+ // The test is marked as 'skipped', and exists just to document the behaviour.
+// This case was found via fuzzing.
+func TestUpdateVariableKeys(t *testing.T) {
+ t.SkipNow()
+ st := NewStackTrie(nil)
+ nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase()))
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ kvs := []struct {
+ K string
+ V string
+ }{
+ {"0x33303534636532393561313031676174", "303030"},
+ {"0x3330353463653239356131303167617430", "313131"},
+ }
+ for _, kv := range kvs {
+ nt.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+ st.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+ }
+ if nt.Hash() != st.Hash() {
+ t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+ }
+}
+
+// TestStacktrieNotModifyValues checks that inserting blobs of data into the
+ // stacktrie does not mutate the blobs.
+func TestStacktrieNotModifyValues(t *testing.T) {
+ st := NewStackTrie(nil)
+ { // Test a very small trie
+ // Give it the value as a slice with large backing alloc,
+ // so if the stacktrie tries to append, it won't have to realloc
+ value := make([]byte, 1, 100)
+ value[0] = 0x2
+ want := common.CopyBytes(value)
+ st.Update([]byte{0x01}, value)
+ st.Hash()
+ if have := value; !bytes.Equal(have, want) {
+ t.Fatalf("tiny trie: have %#x want %#x", have, want)
+ }
+ st = NewStackTrie(nil)
+ }
+ // Test with a larger trie
+ keyB := big.NewInt(1)
+ keyDelta := big.NewInt(1)
+ var vals [][]byte
+ getValue := func(i int) []byte {
+ if i%2 == 0 { // large
+ return crypto.Keccak256(big.NewInt(int64(i)).Bytes())
+ } else { // small
+ return big.NewInt(int64(i)).Bytes()
+ }
+ }
+ for i := 0; i < 1000; i++ {
+ key := common.BigToHash(keyB)
+ value := getValue(i)
+ st.Update(key.Bytes(), value)
+ vals = append(vals, value)
+ keyB = keyB.Add(keyB, keyDelta)
+ keyDelta.Add(keyDelta, common.Big1)
+ }
+ st.Hash()
+ for i := 0; i < 1000; i++ {
+ want := getValue(i)
+
+ have := vals[i]
+ if !bytes.Equal(have, want) {
+ t.Fatalf("item %d, have %#x want %#x", i, have, want)
+ }
+ }
+}
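+
+ // Editorial note: the inverse direction also matters; since the trie
+ // retains the value slice, a caller that reuses its buffer should insert a
+ // copy (buf is a placeholder), as the serialization test below does:
+ //
+ //	_ = st.Update(key, common.CopyBytes(buf))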
+
+ // TestStacktrieSerialization tests that the stacktrie keeps working
+ // correctly when it is repeatedly serialized and deserialized.
+func TestStacktrieSerialization(t *testing.T) {
+ var (
+ st = NewStackTrie(nil)
+ keyB = big.NewInt(1)
+ keyDelta = big.NewInt(1)
+ vals [][]byte
+ keys [][]byte
+ )
+ nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase()))
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ getValue := func(i int) []byte {
+ if i%2 == 0 { // large
+ return crypto.Keccak256(big.NewInt(int64(i)).Bytes())
+ } else { // small
+ return big.NewInt(int64(i)).Bytes()
+ }
+ }
+ for i := 0; i < 10; i++ {
+ vals = append(vals, getValue(i))
+ keys = append(keys, common.BigToHash(keyB).Bytes())
+ keyB = keyB.Add(keyB, keyDelta)
+ keyDelta.Add(keyDelta, common.Big1)
+ }
+ for i, k := range keys {
+ nt.Update(k, common.CopyBytes(vals[i]))
+ }
+
+ for i, k := range keys {
+ blob, err := st.MarshalBinary()
+ if err != nil {
+ t.Fatal(err)
+ }
+ newSt, err := NewFromBinary(blob, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ st = newSt
+ st.Update(k, common.CopyBytes(vals[i]))
+ }
+ if have, want := st.Hash(), nt.Hash(); have != want {
+ t.Fatalf("have %#x want %#x", have, want)
+ }
+}
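+
+ // A minimal sketch (editorial addition) of the round-trip exercised above:
+ // a partially filled StackTrie can be checkpointed with MarshalBinary and
+ // resumed via NewFromBinary. The helper name is hypothetical.
+ func exampleStackTrieCheckpoint(st *StackTrie, key, value []byte) (*StackTrie, error) {
+ blob, err := st.MarshalBinary() // snapshot the in-progress trie
+ if err != nil {
+ return nil, err
+ }
+ restored, err := NewFromBinary(blob, nil) // resume from the snapshot
+ if err != nil {
+ return nil, err
+ }
+ return restored, restored.Update(key, value)
+ }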
diff --git a/trie/sync_test.go b/trie/sync_test.go
index b7627054a..25baa5c67 100644
--- a/trie/sync_test.go
+++ b/trie/sync_test.go
@@ -21,13 +21,14 @@ import (
"testing"
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/ethdb/memorydb"
)
// makeTestTrie creates a sample test trie to test Node-wise reconstruction.
func makeTestTrie() (*Database, *Trie, map[string][]byte) {
// Create an empty trie
- triedb := NewDatabase(memorydb.New())
+ triedb := NewDatabase(rawdb.NewMemoryDatabase())
trie, _ := New(common.Hash{}, triedb)
// Fill it with some arbitrary data
@@ -67,7 +68,7 @@ func checkTrieContents(t *testing.T, db *Database, root []byte, content map[stri
t.Fatalf("inconsistent trie at %x: %v", root, err)
}
for key, val := range content {
- if have := trie.Get([]byte(key)); !bytes.Equal(have, val) {
+ if have, _ := trie.Get([]byte(key)); !bytes.Equal(have, val) {
t.Errorf("entry %x: content mismatch: have %x, want %x", key, have, val)
}
}
@@ -88,8 +89,8 @@ func checkTrieConsistency(db *Database, root common.Hash) error {
// Tests that an empty trie is not scheduled for syncing.
func TestEmptySync(t *testing.T) {
- dbA := NewDatabase(memorydb.New())
- dbB := NewDatabase(memorydb.New())
+ dbA := NewDatabase(rawdb.NewMemoryDatabase())
+ dbB := NewDatabase(rawdb.NewMemoryDatabase())
emptyA, _ := New(common.Hash{}, dbA)
emptyB, _ := New(emptyRoot, dbB)
@@ -110,7 +111,7 @@ func testIterativeSync(t *testing.T, count int) {
srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler
- diskdb := memorydb.New()
+ diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb)
sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb))
@@ -145,7 +146,7 @@ func TestIterativeDelayedSync(t *testing.T) {
srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler
- diskdb := memorydb.New()
+ diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb)
sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb))
@@ -185,7 +186,7 @@ func testIterativeRandomSync(t *testing.T, count int) {
srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler
- diskdb := memorydb.New()
+ diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb)
sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb))
@@ -228,7 +229,7 @@ func TestIterativeRandomDelayedSync(t *testing.T) {
srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler
- diskdb := memorydb.New()
+ diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb)
sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb))
@@ -277,7 +278,7 @@ func TestDuplicateAvoidanceSync(t *testing.T) {
srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler
- diskdb := memorydb.New()
+ diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb)
sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb))
@@ -319,7 +320,7 @@ func TestIncompleteSync(t *testing.T) {
srcDb, srcTrie, _ := makeTestTrie()
// Create a destination trie and sync with the scheduler
- diskdb := memorydb.New()
+ diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb)
sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb))
diff --git a/trie/trie.go b/trie/trie.go
index 589a96186..9df6e5655 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -82,35 +82,45 @@ func New(root common.Hash, db *Database) (*Trie, error) {
return trie, nil
}
+// Copy returns a copy of the trie.
+func (t *Trie) Copy() *Trie {
+ return &Trie{
+ Db: t.Db,
+ root: t.root,
+ unhashed: t.unhashed,
+ }
+}
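+
+ // Editorial note (not part of the change set): the copy is shallow, sharing
+ // the underlying database handle with the original. A typical use is taking
+ // a cheap snapshot for iteration, e.g.:
+ //
+ //	snapshot := tr.Copy()
+ //	it := snapshot.NodeIterator(nil)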
+
// NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at
// the key after the given start key.
func (t *Trie) NodeIterator(start []byte) NodeIterator {
return newNodeIterator(t, start)
}
-// Get returns the value for key stored in the trie.
-// The value bytes must not be modified by the caller.
-func (t *Trie) Get(key []byte) []byte {
- res, err := t.TryGet(key)
+ // MustGet is a wrapper of Get that ignores any encountered error and
+ // merely logs an error message.
+func (t *Trie) MustGet(key []byte) []byte {
+ res, err := t.Get(key)
if err != nil {
- log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
+ log.Error("Unhandled trie error in Trie.Get", "err", err)
}
return res
}
-// TryGet returns the value for key stored in the trie.
+// Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
-// If a Node was not found in the database, a MissingNodeError is returned.
-func (t *Trie) TryGet(key []byte) ([]byte, error) {
- key = keybytesToHex(key)
- value, newroot, didResolve, err := t.tryGet(t.root, key, 0)
+//
+// If the requested node is not present in the trie, no error will be returned.
+// If the trie is corrupted, a MissingNodeError is returned.
+func (t *Trie) Get(key []byte) ([]byte, error) {
+ value, newroot, didResolve, err := t.get(t.root, keybytesToHex(key), 0)
if err == nil && didResolve {
t.root = newroot
}
return value, err
}
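+
+ // Editorial sketch (assuming a trie value tr obtained elsewhere): Get only
+ // errors when a node cannot be resolved, so callers can distinguish an
+ // absent key from an incomplete trie:
+ //
+ //	val, err := tr.Get(key)
+ //	if err != nil {
+ //		// typically a *MissingNodeError: the trie is missing nodes
+ //		return err
+ //	}
+ //	if val == nil {
+ //		// the key is simply not present
+ //	}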
-func (t *Trie) tryGet(origNode Node, key []byte, pos int) (value []byte, newnode Node, didResolve bool, err error) {
+func (t *Trie) get(origNode Node, key []byte, pos int) (value []byte, newnode Node, didResolve bool, err error) {
switch n := (origNode).(type) {
case nil:
return nil, nil, false, nil
@@ -121,14 +131,14 @@ func (t *Trie) tryGet(origNode Node, key []byte, pos int) (value []byte, newnode
// key not found in trie
return nil, n, false, nil
}
- value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key))
+ value, newnode, didResolve, err = t.get(n.Val, key, pos+len(n.Key))
if err == nil && didResolve {
n = n.copy()
n.Val = newnode
}
return value, n, didResolve, err
case *FullNode:
- value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
+ value, newnode, didResolve, err = t.get(n.Children[key[pos]], key, pos+1)
if err == nil && didResolve {
n = n.copy()
n.Children[key[pos]] = newnode
@@ -139,10 +149,10 @@ func (t *Trie) tryGet(origNode Node, key []byte, pos int) (value []byte, newnode
if err != nil {
return nil, n, true, err
}
- value, newnode, _, err := t.tryGet(child, key, pos)
+ value, newnode, _, err := t.get(child, key, pos)
return value, newnode, true, err
default:
- panic(fmt.Sprintf("%T: invalid Node: %v", origNode, origNode))
+ panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode))
}
}
@@ -310,27 +320,28 @@ func (t *Trie) tryGetBestRightKeyAndValue(origNode Node, prefix []byte) (key []b
return nil, nil, nil, false, fmt.Errorf("%T: invalid Node: %v", origNode, origNode)
}
-// Update associates key with value in the trie. Subsequent calls to
-// Get will return value. If value has length zero, any existing value
-// is deleted from the trie and calls to Get will return nil.
-//
-// The value bytes must not be modified by the caller while they are
-// stored in the trie.
-func (t *Trie) Update(key, value []byte) {
- if err := t.TryUpdate(key, value); err != nil {
- log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
+// MustUpdate is a wrapper of Update that ignores any encountered error and
+// merely logs an error message.
+func (t *Trie) MustUpdate(key, value []byte) {
+ if err := t.Update(key, value); err != nil {
+ log.Error("Unhandled trie error in Trie.Update", "err", err)
}
}
-// TryUpdate associates key with value in the trie. Subsequent calls to
+// Update associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil.
//
// The value bytes must not be modified by the caller while they are
// stored in the trie.
//
-// If a Node was not found in the database, a MissingNodeError is returned.
-func (t *Trie) TryUpdate(key, value []byte) error {
+// If the requested node is not present in the trie, no error will be returned.
+// If the trie is corrupted, a MissingNodeError is returned.
+func (t *Trie) Update(key, value []byte) error {
+ return t.update(key, value)
+}
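+
+ // Editorial sketch: per the contract above, a zero-length value acts as a
+ // delete, so the following two calls should be equivalent:
+ //
+ //	_ = tr.Update(key, nil)
+ //	_ = tr.Delete(key)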
+
+func (t *Trie) update(key, value []byte) error {
t.unhashed++
k := keybytesToHex(key)
if len(value) != 0 {
@@ -418,16 +429,19 @@ func (t *Trie) insert(n Node, prefix, key []byte, value Node) (bool, Node, error
}
}
-// Delete removes any existing value for key from the trie.
-func (t *Trie) Delete(key []byte) {
- if err := t.TryDelete(key); err != nil {
- log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
+// MustDelete is a wrapper of Delete that ignores any encountered error and
+// merely logs an error message.
+func (t *Trie) MustDelete(key []byte) {
+ if err := t.Delete(key); err != nil {
+ log.Error("Unhandled trie error in Trie.Delete", "err", err)
}
}
-// TryDelete removes any existing value for key from the trie.
-// If a Node was not found in the database, a MissingNodeError is returned.
-func (t *Trie) TryDelete(key []byte) error {
+// Delete removes any existing value for key from the trie.
+//
+// If the requested node is not present in the trie, no error will be returned.
+// If the trie is corrupted, a MissingNodeError is returned.
+func (t *Trie) Delete(key []byte) error {
t.unhashed++
k := keybytesToHex(key)
_, n, err := t.delete(t.root, nil, k)
@@ -462,8 +476,8 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) {
switch child := child.(type) {
case *ShortNode:
// Deleting from the subtrie reduced it to another
- // short Node. Merge the nodes to avoid creating a
- // ShortNode{..., ShortNode{...}}. Use concat (which
+ // short node. Merge the nodes to avoid creating a
+ // shortNode{..., shortNode{...}}. Use concat (which
// always creates a new slice) instead of append to
// avoid modifying n.Key since it might be shared with
// other nodes.
@@ -481,10 +495,18 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) {
n.flags = t.newFlag()
n.Children[key[0]] = nn
+ // Because n is a full node, it must've contained at least two children
+ // before the delete operation. If the new child value is non-nil, n still
+ // has at least two children after the deletion, and cannot be reduced to
+ // a short node.
+ if nn != nil {
+ return true, n, nil
+ }
+ // Reduction:
// Check how many non-nil entries are left after deleting and
- // reduce the full Node to a short Node if only one entry is
+ // reduce the full node to a short node if only one entry is
// left. Since n must've contained at least two children
- // before deletion (otherwise it would not be a full Node) n
+ // before deletion (otherwise it would not be a full node) n
// can never be reduced to nil.
//
// When the loop is done, pos contains the index of the single
@@ -503,10 +525,10 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) {
}
if pos >= 0 {
if pos != 16 {
- // If the remaining entry is a short Node, it replaces
+ // If the remaining entry is a short node, it replaces
// n and its key gets the missing nibble tacked to the
// front. This avoids creating an invalid
- // ShortNode{..., ShortNode{...}}. Since the entry
+ // shortNode{..., shortNode{...}}. Since the entry
// might not be loaded yet, resolve it just for this
// check.
cnode, err := t.resolve(n.Children[pos], prefix)
@@ -518,7 +540,7 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) {
return true, &ShortNode{k, cnode.Val, t.newFlag()}, nil
}
}
- // Otherwise, n is replaced by a one-nibble short Node
+ // Otherwise, n is replaced by a one-nibble short node
// containing the child.
return true, &ShortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil
}
@@ -533,7 +555,7 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) {
case HashNode:
// We've hit a part of the trie that isn't loaded yet. Load
- // the Node and delete from it. This leaves all child nodes on
+ // the node and delete from it. This leaves all child nodes on
// the path to the value in the trie.
rn, err := t.resolveHash(n, prefix)
if err != nil {
@@ -546,7 +568,7 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) {
return true, nn, nil
default:
- panic(fmt.Sprintf("%T: invalid Node: %v (%v)", n, n, key))
+ panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key))
}
}
@@ -637,3 +659,9 @@ func (t *Trie) hashRoot(db *Database) (Node, Node, error) {
t.unhashed = 0
return hashed, cached, nil
}
+
+// Reset drops the referenced root node and cleans all internal state.
+func (t *Trie) Reset() {
+ t.root = nil
+ t.unhashed = 0
+}
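+
+ // Editorial sketch (placeholder variable names): Reset allows one Trie
+ // instance to be reused across independent runs instead of allocating a
+ // fresh trie each time, mirroring the StackTrie reuse in the tests:
+ //
+ //	for _, batch := range batches {
+ //		tr.Reset()
+ //		for _, kv := range batch {
+ //			_ = tr.Update(kv.K, kv.V)
+ //		}
+ //		roots = append(roots, tr.Hash())
+ //	}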
diff --git a/trie/trie_test.go b/trie/trie_test.go
index 8087a4a8a..fdfcf4858 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -29,10 +29,11 @@ import (
"testing/quick"
"github.com/davecgh/go-spew/spew"
+
"github.com/tomochain/tomochain/common"
+ "github.com/tomochain/tomochain/core/rawdb"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/ethdb/leveldb"
- "github.com/tomochain/tomochain/ethdb/memorydb"
"github.com/tomochain/tomochain/rlp"
)
@@ -43,7 +44,7 @@ func init() {
// Used for testing
func newEmpty() *Trie {
- trie, _ := New(common.Hash{}, NewDatabase(memorydb.New()))
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
return trie
}
@@ -61,13 +62,13 @@ func TestNull(t *testing.T) {
key := make([]byte, 32)
value := []byte("test")
trie.Update(key, value)
- if !bytes.Equal(trie.Get(key), value) {
+ if !bytes.Equal(trie.MustGet(key), value) {
t.Fatal("wrong value")
}
}
func TestMissingRoot(t *testing.T) {
- trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), NewDatabase(memorydb.New()))
+ trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), NewDatabase(rawdb.NewMemoryDatabase()))
if trie != nil {
t.Error("New returned non-nil trie for invalid root")
}
@@ -80,7 +81,7 @@ func TestMissingNodeDisk(t *testing.T) { testMissingNode(t, false) }
func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) }
func testMissingNode(t *testing.T, memonly bool) {
- diskdb := memorydb.New()
+ diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb)
trie, _ := New(common.Hash{}, triedb)
@@ -92,27 +93,27 @@ func testMissingNode(t *testing.T, memonly bool) {
}
trie, _ = New(root, triedb)
- _, err := trie.TryGet([]byte("120000"))
+ _, err := trie.Get([]byte("120000"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
trie, _ = New(root, triedb)
- _, err = trie.TryGet([]byte("120099"))
+ _, err = trie.Get([]byte("120099"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
trie, _ = New(root, triedb)
- _, err = trie.TryGet([]byte("123456"))
+ _, err = trie.Get([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
trie, _ = New(root, triedb)
- err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
+ err = trie.Update([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
trie, _ = New(root, triedb)
- err = trie.TryDelete([]byte("123456"))
+ err = trie.Delete([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
@@ -125,27 +126,27 @@ func testMissingNode(t *testing.T, memonly bool) {
}
trie, _ = New(root, triedb)
- _, err = trie.TryGet([]byte("120000"))
+ _, err = trie.Get([]byte("120000"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
trie, _ = New(root, triedb)
- _, err = trie.TryGet([]byte("120099"))
+ _, err = trie.Get([]byte("120099"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
trie, _ = New(root, triedb)
- _, err = trie.TryGet([]byte("123456"))
+ _, err = trie.Get([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
trie, _ = New(root, triedb)
- err = trie.TryUpdate([]byte("120099"), []byte("zxcv"))
+ err = trie.Update([]byte("120099"), []byte("zxcv"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
trie, _ = New(root, triedb)
- err = trie.TryDelete([]byte("123456"))
+ err = trie.Delete([]byte("123456"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
@@ -403,7 +404,7 @@ func (randTest) Generate(r *rand.Rand, size int) reflect.Value {
}
func runRandTest(rt randTest) bool {
- triedb := NewDatabase(memorydb.New())
+ triedb := NewDatabase(rawdb.NewMemoryDatabase())
tr, _ := New(common.Hash{}, triedb)
values := make(map[string]string) // tracks content of the trie
@@ -419,7 +420,7 @@ func runRandTest(rt randTest) bool {
tr.Delete(step.key)
delete(values, string(step.key))
case opGet:
- v := tr.Get(step.key)
+ v := tr.MustGet(step.key)
want := values[string(step.key)]
if string(v) != want {
rt[i].err = fmt.Errorf("mismatch for key 0x%x, got 0x%x want 0x%x", step.key, v, want)
@@ -823,15 +824,11 @@ func tempDB() (string, *Database) {
if err != nil {
panic(fmt.Sprintf("can't create temporary directory: %v", err))
}
- diskdb, err := leveldb.New(dir, 256, 0, "")
- if err != nil {
- panic(fmt.Sprintf("can't create temporary database: %v", err))
- }
- return dir, NewDatabase(diskdb)
+ return dir, NewDatabase(rawdb.NewMemoryDatabase())
}
func getString(trie *Trie, k string) []byte {
- return trie.Get([]byte(k))
+ return trie.MustGet([]byte(k))
}
func updateString(trie *Trie, k, v string) {
diff --git a/whisper/whisperv5/api.go b/whisper/whisperv5/api.go
index 0ab821b4f..6c6cb1644 100644
--- a/whisper/whisperv5/api.go
+++ b/whisper/whisperv5/api.go
@@ -28,7 +28,7 @@ import (
"github.com/tomochain/tomochain/common/hexutil"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/log"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/rpc"
)
@@ -93,19 +93,19 @@ func (api *PublicWhisperAPI) SetMaxMessageSize(ctx context.Context, size uint32)
return true, api.w.SetMaxMessageSize(size)
}
-// SetMinPow sets the minimum PoW for a message before it is accepted.
+// SetMinPoW sets the minimum PoW for a message before it is accepted.
func (api *PublicWhisperAPI) SetMinPoW(ctx context.Context, pow float64) (bool, error) {
return true, api.w.SetMinimumPoW(pow)
}
// MarkTrustedPeer marks a peer trusted, which will allow it to send historic (expired) messages.
// Note: This function does not add new nodes; the node needs to exist as a peer.
-func (api *PublicWhisperAPI) MarkTrustedPeer(ctx context.Context, enode string) (bool, error) {
- n, err := discover.ParseNode(enode)
+func (api *PublicWhisperAPI) MarkTrustedPeer(ctx context.Context, url string) (bool, error) {
+ n, err := enode.ParseV4(url)
if err != nil {
return false, err
}
- return true, api.w.AllowP2PMessagesFromPeer(n.ID[:])
+ return true, api.w.AllowP2PMessagesFromPeer(n.ID().Bytes())
}
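+
+ // Editorial note (not part of the change set): enode.ParseV4 expects the v4
+ // node URL scheme rather than a bare node ID, e.g.:
+ //
+ //	enode://<128-hex-char-node-id>@10.3.58.6:30303?discport=30301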
// NewKeyPair generates a new public and private key pair for message decryption and encryption.
@@ -275,11 +275,11 @@ func (api *PublicWhisperAPI) Post(ctx context.Context, req NewMessage) (bool, er
// send to specific node (skip PoW check)
if len(req.TargetPeer) > 0 {
- n, err := discover.ParseNode(req.TargetPeer)
+ n, err := enode.ParseV4(req.TargetPeer)
if err != nil {
return false, fmt.Errorf("failed to parse target peer: %s", err)
}
- return true, api.w.SendP2PMessage(n.ID[:], env)
+ return true, api.w.SendP2PMessage(n.ID().Bytes(), env)
}
// ensure that the message PoW meets the node's minimum accepted PoW
diff --git a/whisper/whisperv5/peer_test.go b/whisper/whisperv5/peer_test.go
index 2805aa3d5..0594c22d5 100644
--- a/whisper/whisperv5/peer_test.go
+++ b/whisper/whisperv5/peer_test.go
@@ -28,7 +28,7 @@ import (
"github.com/tomochain/tomochain/common"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/nat"
)
@@ -131,12 +131,11 @@ func initialize(t *testing.T) {
port := port0 + i
addr := fmt.Sprintf(":%d", port) // e.g. ":30303"
name := common.MakeName("whisper-go", "2.0")
- var peers []*discover.Node
+ var peers []*enode.Node
if i > 0 {
peerNodeId := nodes[i-1].id
- peerPort := uint16(port - 1)
- peerNode := discover.PubkeyID(&peerNodeId.PublicKey)
- peer := discover.NewNode(peerNode, ip, peerPort, peerPort)
+ peerPort := port - 1
+ peer := enode.NewV4(&peerNodeId.PublicKey, ip, peerPort, peerPort)
peers = append(peers, peer)
}
diff --git a/whisper/whisperv6/api.go b/whisper/whisperv6/api.go
index 32831d1ec..dca7c8b3f 100644
--- a/whisper/whisperv6/api.go
+++ b/whisper/whisperv6/api.go
@@ -28,7 +28,7 @@ import (
"github.com/tomochain/tomochain/common/hexutil"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/log"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/rpc"
)
@@ -106,12 +106,12 @@ func (api *PublicWhisperAPI) SetBloomFilter(ctx context.Context, bloom hexutil.B
// MarkTrustedPeer marks a peer trusted, which will allow it to send historic (expired) messages.
// Note: This function does not add new nodes; the node needs to exist as a peer.
-func (api *PublicWhisperAPI) MarkTrustedPeer(ctx context.Context, enode string) (bool, error) {
- n, err := discover.ParseNode(enode)
+func (api *PublicWhisperAPI) MarkTrustedPeer(ctx context.Context, url string) (bool, error) {
+ n, err := enode.ParseV4(url)
if err != nil {
return false, err
}
- return true, api.w.AllowP2PMessagesFromPeer(n.ID[:])
+ return true, api.w.AllowP2PMessagesFromPeer(n.ID().Bytes())
}
// NewKeyPair generates a new public and private key pair for message decryption and encryption.
@@ -294,11 +294,11 @@ func (api *PublicWhisperAPI) Post(ctx context.Context, req NewMessage) (bool, er
// send to specific node (skip PoW check)
if len(req.TargetPeer) > 0 {
- n, err := discover.ParseNode(req.TargetPeer)
+ n, err := enode.ParseV4(req.TargetPeer)
if err != nil {
return false, fmt.Errorf("failed to parse target peer: %s", err)
}
- return true, api.w.SendP2PMessage(n.ID[:], env)
+ return true, api.w.SendP2PMessage(n.ID().Bytes(), env)
}
// ensure that the message PoW meets the node's minimum accepted PoW
diff --git a/whisper/whisperv6/peer_test.go b/whisper/whisperv6/peer_test.go
index 1f0365eac..7a3b53265 100644
--- a/whisper/whisperv6/peer_test.go
+++ b/whisper/whisperv6/peer_test.go
@@ -31,7 +31,7 @@ import (
"github.com/tomochain/tomochain/common/hexutil"
"github.com/tomochain/tomochain/crypto"
"github.com/tomochain/tomochain/p2p"
- "github.com/tomochain/tomochain/p2p/discover"
+ "github.com/tomochain/tomochain/p2p/enode"
"github.com/tomochain/tomochain/p2p/nat"
)
@@ -202,12 +202,11 @@ func initialize(t *testing.T) {
port := port0 + i
addr := fmt.Sprintf(":%d", port) // e.g. ":30303"
name := common.MakeName("whisper-go", "2.0")
- var peers []*discover.Node
+ var peers []*enode.Node
if i > 0 {
peerNodeID := nodes[i-1].id
- peerPort := uint16(port - 1)
- peerNode := discover.PubkeyID(&peerNodeID.PublicKey)
- peer := discover.NewNode(peerNode, ip, peerPort, peerPort)
+ peerPort := port - 1
+ peer := enode.NewV4(&peerNodeID.PublicKey, ip, peerPort, peerPort)
peers = append(peers, peer)
}
@@ -437,7 +436,7 @@ func checkPowExchangeForNodeZeroOnce(t *testing.T, mustPass bool) bool {
cnt := 0
for i, node := range nodes {
for peer := range node.shh.peers {
- if peer.peer.ID() == discover.PubkeyID(&nodes[0].id.PublicKey) {
+ if peer.peer.ID() == enode.PubkeyToIDV4(&nodes[0].id.PublicKey) {
cnt++
if peer.powRequirement != masterPow {
if mustPass {
@@ -458,7 +457,7 @@ func checkPowExchangeForNodeZeroOnce(t *testing.T, mustPass bool) bool {
func checkPowExchange(t *testing.T) {
for i, node := range nodes {
for peer := range node.shh.peers {
- if peer.peer.ID() != discover.PubkeyID(&nodes[0].id.PublicKey) {
+ if peer.peer.ID() != enode.PubkeyToIDV4(&nodes[0].id.PublicKey) {
if peer.powRequirement != masterPow {
t.Fatalf("node %d: failed to exchange pow requirement in round %d; expected %f, got %f",
i, round, masterPow, peer.powRequirement)