diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index d4757e193..c6bc2a470 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,3 +1,9 @@ +Hi there, + +please note that this is an issue tracker reserved for bug reports and feature requests. + +For general questions please use the gitter channel or the Ethereum stack exchange at https://ethereum.stackexchange.com. + #### System information GMC version: `gmc version` diff --git a/.travis.yml b/.travis.yml index 12b371f63..da3e91465 100644 --- a/.travis.yml +++ b/.travis.yml @@ -38,7 +38,7 @@ matrix: - sudo chmod 666 /dev/fuse - sudo chown root:$USER /etc/fuse.conf - go run build/ci.go install - - go run build/ci.go test -coverage -misspell + - go run build/ci.go test -coverage - os: osx go: 1.9.x @@ -48,7 +48,21 @@ matrix: - brew install caskroom/cask/brew-cask - brew cask install osxfuse - go run build/ci.go install - - go run build/ci.go test -coverage -misspell + - go run build/ci.go test -coverage + + # This builder only tests code linters on latest version of Go + - os: linux + dist: trusty + sudo: required + go: 1.9.x + env: + - lint + script: + - sudo -E apt-get -yq --no-install-suggests --no-install-recommends --force-yes install fuse + - sudo modprobe fuse + - sudo chmod 666 /dev/fuse + - sudo chown root:$USER /etc/fuse.conf + - go run build/ci.go lint # This builder does the Ubuntu PPA and Linux Azure uploads - os: linux diff --git a/README.md b/README.md index 9e96015c0..d86a0d926 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,8 @@ docker run -d --name musicoin-node -v /Users/alice/musicoin:/root \ This will start gmc in fast sync mode with a DB memory allowance of 512MB just as the above command does. It will also create a persistent volume in your home directory for saving your blockchain as well as map the default ports. There is also an `alpine` tag available for a slim version of the image. +Do not forget `--rpcaddr 0.0.0.0`, if you want to access RPC from other containers and/or hosts. By default, `gmc` binds to the local interface and RPC endpoints is not accessible from the outside. + ### Programatically interfacing GMC nodes As a developer, sooner rather than later you'll want to start interacting with GMC and the Musicoin diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 0621e81c2..e1d5745ca 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -59,7 +59,7 @@ type SimulatedBackend struct { // for testing purposes. func NewSimulatedBackend(alloc core.GenesisAlloc) *SimulatedBackend { database, _ := ethdb.NewMemDatabase() - genesis := core.Genesis{Config: params.AllProtocolChanges, Alloc: alloc} + genesis := core.Genesis{Config: params.AllEthashProtocolChanges, Alloc: alloc} genesis.MustCommit(database) blockchain, _ := core.NewBlockChain(database, genesis.Config, ethash.NewFaker(), vm.Config{}) backend := &SimulatedBackend{database: database, blockchain: blockchain, config: genesis.Config} diff --git a/accounts/abi/bind/bind_test.go b/accounts/abi/bind/bind_test.go index 17b221642..43ed53b92 100644 --- a/accounts/abi/bind/bind_test.go +++ b/accounts/abi/bind/bind_test.go @@ -472,7 +472,7 @@ func TestBindings(t *testing.T) { t.Fatalf("failed to create temporary workspace: %v", err) } defer os.RemoveAll(ws) - + pkg := filepath.Join(ws, "bindtest") if err = os.MkdirAll(pkg, 0700); err != nil { t.Fatalf("failed to create package: %v", err) diff --git a/accounts/abi/unpack_test.go b/accounts/abi/unpack_test.go index 9c7c339f3..294908378 100644 --- a/accounts/abi/unpack_test.go +++ b/accounts/abi/unpack_test.go @@ -365,7 +365,7 @@ func TestUnmarshal(t *testing.T) { buff.Write(common.Hex2Bytes("0102000000000000000000000000000000000000000000000000000000000000")) err = abi.Unpack(&mixedBytes, "mixedBytes", buff.Bytes()) - if err !=nil { + if err != nil { t.Error(err) } else { if bytes.Compare(p0, p0Exp) != 0 { diff --git a/build/ci.go b/build/ci.go index e2e2afc0d..a2ac1a710 100644 --- a/build/ci.go +++ b/build/ci.go @@ -24,7 +24,8 @@ Usage: go run ci.go Available commands are: install [ -arch architecture ] [ packages... ] -- builds packages and executables - test [ -coverage ] [ -misspell ] [ packages... ] -- runs the tests + test [ -coverage ] [ packages... ] -- runs the tests + lint -- runs certain pre-selected linters archive [ -arch architecture ] [ -type zip|tar ] [ -signer key-envvar ] [ -upload dest ] -- archives build artefacts importkeys -- imports signing keys from env debsrc [ -signer key-id ] [ -upload dest ] -- creates a debian source package @@ -146,6 +147,8 @@ func main() { doInstall(os.Args[2:]) case "test": doTest(os.Args[2:]) + case "lint": + doLint(os.Args[2:]) case "archive": doArchive(os.Args[2:]) case "debsrc": @@ -280,7 +283,6 @@ func goToolArch(arch string, subcmd string, args ...string) *exec.Cmd { func doTest(cmdline []string) { var ( - misspell = flag.Bool("misspell", false, "Whether to run the spell checker") coverage = flag.Bool("coverage", false, "Whether to record code coverage") ) flag.CommandLine.Parse(cmdline) @@ -294,10 +296,7 @@ func doTest(cmdline []string) { // Run analysis tools before the tests. build.MustRun(goTool("vet", packages...)) - if *misspell { - // TODO(karalabe): Reenable after false detection is fixed: https://github.com/client9/misspell/issues/105 - // spellcheck(packages) - } + // Run the actual tests. gotest := goTool("test", buildFlags(env)...) // Test a single package at a time. CI builders are slow @@ -306,36 +305,26 @@ func doTest(cmdline []string) { if *coverage { gotest.Args = append(gotest.Args, "-covermode=atomic", "-cover") } + gotest.Args = append(gotest.Args, packages...) build.MustRun(gotest) } -// spellcheck runs the client9/misspell spellchecker package on all Go, Cgo and -// test files in the requested packages. -func spellcheck(packages []string) { - // Ensure the spellchecker is available - build.MustRun(goTool("get", "github.com/client9/misspell/cmd/misspell")) +// runs gometalinter on requested packages +func doLint(cmdline []string) { + flag.CommandLine.Parse(cmdline) - // Windows chokes on long argument lists, check packages individually - for _, pkg := range packages { - // The spell checker doesn't work on packages, gather all .go files for it - out, err := goTool("list", "-f", "{{.Dir}}{{range .GoFiles}}\n{{.}}{{end}}{{range .CgoFiles}}\n{{.}}{{end}}{{range .TestGoFiles}}\n{{.}}{{end}}", pkg).CombinedOutput() - if err != nil { - log.Fatalf("source file listing failed: %v\n%s", err, string(out)) - } - // Retrieve the folder and assemble the source list - lines := strings.Split(string(out), "\n") - root := lines[0] - - sources := make([]string, 0, len(lines)-1) - for _, line := range lines[1:] { - if line = strings.TrimSpace(line); line != "" { - sources = append(sources, filepath.Join(root, line)) - } - } - // Run the spell checker for this particular package - build.MustRunCommand(filepath.Join(GOBIN, "misspell"), append([]string{"-error"}, sources...)...) + packages := []string{"./..."} + if len(flag.CommandLine.Args()) > 0 { + packages = flag.CommandLine.Args() } + // Get metalinter and install all supported linters + build.MustRun(goTool("get", "gopkg.in/alecthomas/gometalinter.v1")) + build.MustRunCommand(filepath.Join(GOBIN, "gometalinter.v1"), "--install") + + configs := []string{"--vendor", "--disable-all", "--enable=vet"} // Add additional linters to the slice with "--enable=linter-name" + + build.MustRunCommand(filepath.Join(GOBIN, "gometalinter.v1"), append(configs, packages...)...) } // Release Packaging diff --git a/cmd/faucet/website.go b/cmd/faucet/website.go index eeb8e410e..6a99f8c6f 100644 --- a/cmd/faucet/website.go +++ b/cmd/faucet/website.go @@ -182,8 +182,9 @@ type bintree struct { Func func() (*asset, error) Children map[string]*bintree } + var _bintree = &bintree{nil, map[string]*bintree{ - "faucet.html": &bintree{faucetHtml, map[string]*bintree{}}, + "faucet.html": {faucetHtml, map[string]*bintree{}}, }} // RestoreAsset restores an asset under the given directory @@ -232,4 +233,3 @@ func _filePath(dir, name string) string { cannonicalName := strings.Replace(name, "\\", "/", -1) return filepath.Join(append([]string{dir}, strings.Split(cannonicalName, "/")...)...) } - diff --git a/cmd/gmc/accountcmd_test.go b/cmd/gmc/accountcmd_test.go index 8fc0fdc18..4d7040027 100644 --- a/cmd/gmc/accountcmd_test.go +++ b/cmd/gmc/accountcmd_test.go @@ -204,7 +204,7 @@ Passphrase: {{.InputLine "foobar"}} func TestUnlockFlagPasswordFile(t *testing.T) { datadir := tmpDatadirWithKeystore(t) gmc := runGMC(t, - "--datadir", datadir, "--nat", "none", "--nodiscover", "--dev", + "--datadir", datadir, "--nat", "none", "--nodiscover", "--maxpeers", "0", "--port", "0", "--password", "testdata/passwords.txt", "--unlock", "0,2", "js", "testdata/empty.js") gmc.ExpectExit() @@ -224,7 +224,7 @@ func TestUnlockFlagPasswordFile(t *testing.T) { func TestUnlockFlagPasswordFileWrongPassword(t *testing.T) { datadir := tmpDatadirWithKeystore(t) gmc := runGMC(t, - "--datadir", datadir, "--nat", "none", "--nodiscover", "--dev", + "--datadir", datadir, "--nat", "none", "--nodiscover", "--maxpeers", "0", "--port", "0", "--password", "testdata/wrong-passwords.txt", "--unlock", "0,2") defer gmc.ExpectExit() gmc.Expect(` @@ -235,7 +235,7 @@ Fatal: Failed to unlock account 0 (could not decrypt key with given passphrase) func TestUnlockFlagAmbiguous(t *testing.T) { store := filepath.Join("..", "..", "accounts", "keystore", "testdata", "dupes") gmc := runGMC(t, - "--keystore", store, "--nat", "none", "--nodiscover", "--dev", + "--keystore", store, "--nat", "none", "--nodiscover", "--maxpeers", "0", "--port", "0", "--unlock", "f466859ead1932d743d622cb74fc058882e8648a", "js", "testdata/empty.js") defer gmc.ExpectExit() @@ -273,7 +273,7 @@ In order to avoid this warning, you need to remove the following duplicate key f func TestUnlockFlagAmbiguousWrongPassword(t *testing.T) { store := filepath.Join("..", "..", "accounts", "keystore", "testdata", "dupes") gmc := runGMC(t, - "--keystore", store, "--nat", "none", "--nodiscover", "--dev", + "--keystore", store, "--nat", "none", "--nodiscover", "--maxpeers", "0", "--port", "0", "--unlock", "f466859ead1932d743d622cb74fc058882e8648a") defer gmc.ExpectExit() diff --git a/cmd/gmc/config.go b/cmd/gmc/config.go index 5bb5790a9..dc9821c92 100644 --- a/cmd/gmc/config.go +++ b/cmd/gmc/config.go @@ -155,7 +155,7 @@ func makeFullNode(ctx *cli.Context) *node.Node { // Whisper must be explicitly enabled by specifying at least 1 whisper flag or in dev mode shhEnabled := enableWhisper(ctx) - shhAutoEnabled := !ctx.GlobalIsSet(utils.WhisperEnabledFlag.Name) && ctx.GlobalIsSet(utils.DevModeFlag.Name) + shhAutoEnabled := !ctx.GlobalIsSet(utils.WhisperEnabledFlag.Name) && ctx.GlobalIsSet(utils.DeveloperFlag.Name) if shhEnabled || shhAutoEnabled { if ctx.GlobalIsSet(utils.WhisperMaxMessageSizeFlag.Name) { cfg.Shh.MaxMessageSize = uint32(ctx.Int(utils.WhisperMaxMessageSizeFlag.Name)) diff --git a/cmd/gmc/main.go b/cmd/gmc/main.go index fab254a36..46b7033d5 100644 --- a/cmd/gmc/main.go +++ b/cmd/gmc/main.go @@ -99,7 +99,8 @@ var ( utils.NetrestrictFlag, utils.NodeKeyFileFlag, utils.NodeKeyHexFlag, - utils.DevModeFlag, + utils.DeveloperFlag, + utils.DeveloperPeriodFlag, utils.TestnetFlag, utils.RinkebyFlag, utils.VMEnableDebugFlag, @@ -284,7 +285,7 @@ func startNode(ctx *cli.Context, stack *node.Node) { } }() // Start auxiliary services if enabled - if ctx.GlobalBool(utils.MiningEnabledFlag.Name) { + if ctx.GlobalBool(utils.MiningEnabledFlag.Name) || ctx.GlobalBool(utils.DeveloperFlag.Name) { // Mining only makes sense if a full Ethereum node is running var ethereum *eth.Ethereum if err := stack.Service(ðereum); err != nil { diff --git a/cmd/gmc/usage.go b/cmd/gmc/usage.go index 210078655..77c0f2c5d 100644 --- a/cmd/gmc/usage.go +++ b/cmd/gmc/usage.go @@ -73,7 +73,6 @@ var AppHelpFlagGroups = []flagGroup{ utils.NetworkIdFlag, utils.TestnetFlag, utils.RinkebyFlag, - utils.DevModeFlag, utils.SyncModeFlag, utils.EthStatsURLFlag, utils.IdentityFlag, @@ -82,6 +81,12 @@ var AppHelpFlagGroups = []flagGroup{ utils.LightKDFFlag, }, }, + {Name: "DEVELOPER CHAIN", + Flags: []cli.Flag{ + utils.DeveloperFlag, + utils.DeveloperPeriodFlag, + }, + }, { Name: "ETHASH", Flags: []cli.Flag{ diff --git a/cmd/rlpdump/main.go b/cmd/rlpdump/main.go index 7d328e59b..d0f993c5b 100644 --- a/cmd/rlpdump/main.go +++ b/cmd/rlpdump/main.go @@ -51,7 +51,7 @@ func main() { var r io.Reader switch { case *hexMode != "": - data, err := hex.DecodeString(*hexMode) + data, err := hex.DecodeString(strings.TrimPrefix(*hexMode, "0x")) if err != nil { die(err) } diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 36a8101e0..239b2f38d 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -137,9 +137,13 @@ var ( Name: "rinkeby", Usage: "Rinkeby network: pre-configured proof-of-authority test network", } - DevModeFlag = cli.BoolFlag{ + DeveloperFlag = cli.BoolFlag{ Name: "dev", - Usage: "Developer mode: pre-configured private network with several debugging flags", + Usage: "Ephemeral proof-of-authority network with a pre-funded developer account, mining enabled", + } + DeveloperPeriodFlag = cli.IntFlag{ + Name: "dev.period", + Usage: "Block period to use in developer mode (0 = mine only if transaction pending)", } IdentityFlag = cli.StringFlag{ Name: "identity", @@ -796,7 +800,7 @@ func SetP2PConfig(ctx *cli.Context, cfg *p2p.Config) { cfg.NetRestrict = list } - if ctx.GlobalBool(DevModeFlag.Name) { + if ctx.GlobalBool(DeveloperFlag.Name) { // --dev mode can't use p2p networking. cfg.MaxPeers = 0 cfg.ListenAddr = ":0" @@ -817,8 +821,8 @@ func SetNodeConfig(ctx *cli.Context, cfg *node.Config) { switch { case ctx.GlobalIsSet(DataDirFlag.Name): cfg.DataDir = ctx.GlobalString(DataDirFlag.Name) - case ctx.GlobalBool(DevModeFlag.Name): - cfg.DataDir = filepath.Join(os.TempDir(), "ethereum_dev_mode") + case ctx.GlobalBool(DeveloperFlag.Name): + cfg.DataDir = "" // unless explicitly requested, use memory databases case ctx.GlobalBool(TestnetFlag.Name): cfg.DataDir = filepath.Join(node.DefaultDataDir(), "testnet") case ctx.GlobalBool(RinkebyFlag.Name): @@ -924,7 +928,7 @@ func SetShhConfig(ctx *cli.Context, stack *node.Node, cfg *whisper.Config) { // SetEthConfig applies eth-related command line flags to the config. func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { // Avoid conflicting network flags - checkExclusive(ctx, DevModeFlag, TestnetFlag, RinkebyFlag) + checkExclusive(ctx, DeveloperFlag, TestnetFlag, RinkebyFlag) checkExclusive(ctx, FastSyncFlag, LightModeFlag, SyncModeFlag) ks := stack.AccountManager().Backends(keystore.KeyStoreType)[0].(*keystore.KeyStore) @@ -985,14 +989,30 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { cfg.NetworkId = 4 } cfg.Genesis = core.DefaultRinkebyGenesisBlock() - case ctx.GlobalBool(DevModeFlag.Name): - cfg.Genesis = core.DevGenesisBlock() + case ctx.GlobalBool(DeveloperFlag.Name): + // Create new developer account or reuse existing one + var ( + developer accounts.Account + err error + ) + if accs := ks.Accounts(); len(accs) > 0 { + developer = ks.Accounts()[0] + } else { + developer, err = ks.NewAccount("") + if err != nil { + Fatalf("Failed to create developer account: %v", err) + } + } + if err := ks.Unlock(developer, ""); err != nil { + Fatalf("Failed to unlock developer account: %v", err) + } + log.Info("Using developer account", "address", developer.Address) + + cfg.Genesis = core.DeveloperGenesisBlock(uint64(ctx.GlobalInt(DeveloperPeriodFlag.Name)), developer.Address) if !ctx.GlobalIsSet(GasPriceFlag.Name) { - cfg.GasPrice = new(big.Int) + cfg.GasPrice = big.NewInt(1) } - cfg.PowTest = true } - // TODO(fjl): move trie cache generations into config if gen := ctx.GlobalInt(TrieCacheGenFlag.Name); gen > 0 { state.MaxTrieCacheGen = uint16(gen) @@ -1077,8 +1097,8 @@ func MakeGenesis(ctx *cli.Context) *core.Genesis { genesis = core.DefaultTestnetGenesisBlock() case ctx.GlobalBool(RinkebyFlag.Name): genesis = core.DefaultRinkebyGenesisBlock() - case ctx.GlobalBool(DevModeFlag.Name): - genesis = core.DevGenesisBlock() + case ctx.GlobalBool(DeveloperFlag.Name): + Fatalf("Developer chains are ephemeral") } return genesis } diff --git a/consensus/clique/clique.go b/consensus/clique/clique.go index 8d6cf653d..6892d8390 100644 --- a/consensus/clique/clique.go +++ b/consensus/clique/clique.go @@ -125,6 +125,11 @@ var ( // errUnauthorized is returned if a header is signed by a non-authorized entity. errUnauthorized = errors.New("unauthorized") + + // errWaitTransactions is returned if an empty block is attempted to be sealed + // on an instant chain (0 second period). It's important to refuse these as the + // block reward is zero, so an empty block just bloats the chain... fast. + errWaitTransactions = errors.New("waiting for transactions") ) // SignerFn is a signer callback function to request a hash to be signed by a @@ -211,9 +216,6 @@ func New(config *params.CliqueConfig, db ethdb.Database) *Clique { if conf.Epoch == 0 { conf.Epoch = epochLength } - if conf.Period == 0 { - conf.Period = blockPeriod - } // Allocate the snapshot caches and create the engine recents, _ := lru.NewARC(inmemorySnapshots) signatures, _ := lru.NewARC(inmemorySignatures) @@ -599,6 +601,10 @@ func (c *Clique) Seal(chain consensus.ChainReader, block *types.Block, stop <-ch if number == 0 { return nil, errUnknownBlock } + // For 0-period chains, refuse to seal empty blocks (no reward but would spin sealing) + if c.config.Period == 0 && len(block.Transactions()) == 0 { + return nil, errWaitTransactions + } // Don't hold the signer fields for the entire sealing procedure c.lock.RLock() signer, signFn := c.signer, c.signFn diff --git a/consensus/clique/snapshot_test.go b/consensus/clique/snapshot_test.go index a1717d799..8b51e6e09 100644 --- a/consensus/clique/snapshot_test.go +++ b/consensus/clique/snapshot_test.go @@ -74,7 +74,7 @@ type testerChainReader struct { db ethdb.Database } -func (r *testerChainReader) Config() *params.ChainConfig { return params.AllProtocolChanges } +func (r *testerChainReader) Config() *params.ChainConfig { return params.AllCliqueProtocolChanges } func (r *testerChainReader) CurrentHeader() *types.Header { panic("not supported") } func (r *testerChainReader) GetHeader(common.Hash, uint64) *types.Header { panic("not supported") } func (r *testerChainReader) GetBlock(common.Hash, uint64) *types.Block { panic("not supported") } diff --git a/console/console_test.go b/console/console_test.go index 8ac499bd1..a159b62bb 100644 --- a/console/console_test.go +++ b/console/console_test.go @@ -94,7 +94,7 @@ func newTester(t *testing.T, confOverride func(*eth.Config)) *tester { t.Fatalf("failed to create node: %v", err) } ethConf := ð.Config{ - Genesis: core.DevGenesisBlock(), + Genesis: core.DeveloperGenesisBlock(15, common.Address{}), Etherbase: common.HexToAddress(testAddress), PowTest: true, } diff --git a/core/bloombits/matcher.go b/core/bloombits/matcher.go index e33de018a..d38d4ba83 100644 --- a/core/bloombits/matcher.go +++ b/core/bloombits/matcher.go @@ -18,6 +18,7 @@ package bloombits import ( "bytes" + "context" "errors" "math" "sort" @@ -56,10 +57,16 @@ type partialMatches struct { // Retrieval represents a request for retrieval task assignments for a given // bit with the given number of fetch elements, or a response for such a request. // It can also have the actual results set to be used as a delivery data struct. +// +// The contest and error fields are used by the light client to terminate matching +// early if an error is enountered on some path of the pipeline. type Retrieval struct { Bit uint Sections []uint64 Bitsets [][]byte + + Context context.Context + Error error } // Matcher is a pipelined system of schedulers and logic matchers which perform @@ -137,7 +144,7 @@ func (m *Matcher) addScheduler(idx uint) { // Start starts the matching process and returns a stream of bloom matches in // a given range of blocks. If there are no more matches in the range, the result // channel is closed. -func (m *Matcher) Start(begin, end uint64, results chan uint64) (*MatcherSession, error) { +func (m *Matcher) Start(ctx context.Context, begin, end uint64, results chan uint64) (*MatcherSession, error) { // Make sure we're not creating concurrent sessions if atomic.SwapUint32(&m.running, 1) == 1 { return nil, errors.New("matcher already running") @@ -149,6 +156,7 @@ func (m *Matcher) Start(begin, end uint64, results chan uint64) (*MatcherSession matcher: m, quit: make(chan struct{}), kill: make(chan struct{}), + ctx: ctx, } for _, scheduler := range m.schedulers { scheduler.reset() @@ -502,25 +510,34 @@ func (m *Matcher) distributor(dist chan *request, session *MatcherSession) { type MatcherSession struct { matcher *Matcher - quit chan struct{} // Quit channel to request pipeline termination - kill chan struct{} // Term channel to signal non-graceful forced shutdown + closer sync.Once // Sync object to ensure we only ever close once + quit chan struct{} // Quit channel to request pipeline termination + kill chan struct{} // Term channel to signal non-graceful forced shutdown + + ctx context.Context // Context used by the light client to abort filtering + err atomic.Value // Global error to track retrieval failures deep in the chain + pend sync.WaitGroup } // Close stops the matching process and waits for all subprocesses to terminate // before returning. The timeout may be used for graceful shutdown, allowing the // currently running retrievals to complete before this time. -func (s *MatcherSession) Close(timeout time.Duration) { - // Bail out if the matcher is not running - select { - case <-s.quit: - return - default: +func (s *MatcherSession) Close() { + s.closer.Do(func() { + // Signal termination and wait for all goroutines to tear down + close(s.quit) + time.AfterFunc(time.Second, func() { close(s.kill) }) + s.pend.Wait() + }) +} + +// Error returns any failure encountered during the matching session. +func (s *MatcherSession) Error() error { + if err := s.err.Load(); err != nil { + return err.(error) } - // Signal termination and wait for all goroutines to tear down - close(s.quit) - time.AfterFunc(timeout, func() { close(s.kill) }) - s.pend.Wait() + return nil } // AllocateRetrieval assigns a bloom bit index to a client process that can either @@ -618,9 +635,13 @@ func (s *MatcherSession) Multiplex(batch int, wait time.Duration, mux chan chan case mux <- request: // Retrieval accepted, something must arrive before we're aborting - request <- &Retrieval{Bit: bit, Sections: sections} + request <- &Retrieval{Bit: bit, Sections: sections, Context: s.ctx} result := <-request + if result.Error != nil { + s.err.Store(result.Error) + s.Close() + } s.DeliverSections(result.Bit, result.Sections, result.Bitsets) } } diff --git a/core/bloombits/matcher_test.go b/core/bloombits/matcher_test.go index 2e15e7aac..4a31854c5 100644 --- a/core/bloombits/matcher_test.go +++ b/core/bloombits/matcher_test.go @@ -17,6 +17,7 @@ package bloombits import ( + "context" "math/rand" "sync/atomic" "testing" @@ -30,14 +31,14 @@ const testSectionSize = 4096 // Tests that wildcard filter rules (nil) can be specified and are handled well. func TestMatcherWildcards(t *testing.T) { matcher := NewMatcher(testSectionSize, [][][]byte{ - [][]byte{common.Address{}.Bytes(), common.Address{0x01}.Bytes()}, // Default address is not a wildcard - [][]byte{common.Hash{}.Bytes(), common.Hash{0x01}.Bytes()}, // Default hash is not a wildcard - [][]byte{common.Hash{0x01}.Bytes()}, // Plain rule, sanity check - [][]byte{common.Hash{0x01}.Bytes(), nil}, // Wildcard suffix, drop rule - [][]byte{nil, common.Hash{0x01}.Bytes()}, // Wildcard prefix, drop rule - [][]byte{nil, nil}, // Wildcard combo, drop rule - [][]byte{}, // Inited wildcard rule, drop rule - nil, // Proper wildcard rule, drop rule + {common.Address{}.Bytes(), common.Address{0x01}.Bytes()}, // Default address is not a wildcard + {common.Hash{}.Bytes(), common.Hash{0x01}.Bytes()}, // Default hash is not a wildcard + {common.Hash{0x01}.Bytes()}, // Plain rule, sanity check + {common.Hash{0x01}.Bytes(), nil}, // Wildcard suffix, drop rule + {nil, common.Hash{0x01}.Bytes()}, // Wildcard prefix, drop rule + {nil, nil}, // Wildcard combo, drop rule + {}, // Inited wildcard rule, drop rule + nil, // Proper wildcard rule, drop rule }) if len(matcher.filters) != 3 { t.Fatalf("filter system size mismatch: have %d, want %d", len(matcher.filters), 3) @@ -144,7 +145,7 @@ func testMatcher(t *testing.T, filter [][]bloomIndexes, blocks uint64, intermitt quit := make(chan struct{}) matches := make(chan uint64, 16) - session, err := matcher.Start(0, blocks-1, matches) + session, err := matcher.Start(context.Background(), 0, blocks-1, matches) if err != nil { t.Fatalf("failed to stat matcher session: %v", err) } @@ -163,13 +164,13 @@ func testMatcher(t *testing.T, filter [][]bloomIndexes, blocks uint64, intermitt } // If we're testing intermittent mode, abort and restart the pipeline if intermittent { - session.Close(time.Second) + session.Close() close(quit) quit = make(chan struct{}) matches = make(chan uint64, 16) - session, err = matcher.Start(i+1, blocks-1, matches) + session, err = matcher.Start(context.Background(), i+1, blocks-1, matches) if err != nil { t.Fatalf("failed to stat matcher session: %v", err) } @@ -183,7 +184,7 @@ func testMatcher(t *testing.T, filter [][]bloomIndexes, blocks uint64, intermitt t.Errorf("filter = %v blocks = %v intermittent = %v: expected closed channel, got #%v", filter, blocks, intermittent, match) } // Clean up the session and ensure we match the expected retrieval count - session.Close(time.Second) + session.Close() close(quit) if retrievals != 0 && requested != retrievals { diff --git a/core/bloombits/scheduler_test.go b/core/bloombits/scheduler_test.go index 8a159c237..70772e4ab 100644 --- a/core/bloombits/scheduler_test.go +++ b/core/bloombits/scheduler_test.go @@ -60,7 +60,7 @@ func testScheduler(t *testing.T, clients int, fetchers int, requests int) { req.section, // Requested data req.section, // Duplicated data (ensure it doesn't double close anything) }, [][]byte{ - []byte{}, + {}, new(big.Int).SetUint64(req.section).Bytes(), new(big.Int).SetUint64(req.section).Bytes(), }) diff --git a/core/chain_indexer.go b/core/chain_indexer.go index f4c207dcc..7e7500dc8 100644 --- a/core/chain_indexer.go +++ b/core/chain_indexer.go @@ -36,7 +36,7 @@ import ( type ChainIndexerBackend interface { // Reset initiates the processing of a new chain segment, potentially terminating // any partially completed operations (in case of a reorg). - Reset(section uint64) + Reset(section uint64, prevHead common.Hash) error // Process crunches through the next header in the chain segment. The caller // will ensure a sequential order of headers. @@ -46,6 +46,15 @@ type ChainIndexerBackend interface { Commit() error } +// ChainIndexerChain interface is used for connecting the indexer to a blockchain +type ChainIndexerChain interface { + // CurrentHeader retrieves the latest locally known header. + CurrentHeader() *types.Header + + // SubscribeChainEvent subscribes to new head header notifications. + SubscribeChainEvent(ch chan<- ChainEvent) event.Subscription +} + // ChainIndexer does a post-processing job for equally sized sections of the // canonical chain (like BlooomBits and CHT structures). A ChainIndexer is // connected to the blockchain through the event system by starting a @@ -100,11 +109,27 @@ func NewChainIndexer(chainDb, indexDb ethdb.Database, backend ChainIndexerBacken return c } +// AddKnownSectionHead marks a new section head as known/processed if it is newer +// than the already known best section head +func (c *ChainIndexer) AddKnownSectionHead(section uint64, shead common.Hash) { + c.lock.Lock() + defer c.lock.Unlock() + + if section < c.storedSections { + return + } + c.setSectionHead(section, shead) + c.setValidSections(section + 1) +} + // Start creates a goroutine to feed chain head events into the indexer for // cascading background processing. Children do not need to be started, they // are notified about new events by their parents. -func (c *ChainIndexer) Start(currentHeader *types.Header, chainEventer func(ch chan<- ChainEvent) event.Subscription) { - go c.eventLoop(currentHeader, chainEventer) +func (c *ChainIndexer) Start(chain ChainIndexerChain) { + events := make(chan ChainEvent, 10) + sub := chain.SubscribeChainEvent(events) + + go c.eventLoop(chain.CurrentHeader(), events, sub) } // Close tears down all goroutines belonging to the indexer and returns any error @@ -147,12 +172,10 @@ func (c *ChainIndexer) Close() error { // eventLoop is a secondary - optional - event loop of the indexer which is only // started for the outermost indexer to push chain head events into a processing // queue. -func (c *ChainIndexer) eventLoop(currentHeader *types.Header, chainEventer func(ch chan<- ChainEvent) event.Subscription) { +func (c *ChainIndexer) eventLoop(currentHeader *types.Header, events chan ChainEvent, sub event.Subscription) { // Mark the chain indexer as active, requiring an additional teardown atomic.StoreUint32(&c.active, 1) - events := make(chan ChainEvent, 10) - sub := chainEventer(events) defer sub.Unsubscribe() // Fire the initial new head event to start any outstanding processing @@ -178,7 +201,11 @@ func (c *ChainIndexer) eventLoop(currentHeader *types.Header, chainEventer func( } header := ev.Block.Header() if header.ParentHash != prevHash { - c.newHead(FindCommonAncestor(c.chainDb, prevHeader, header).Number.Uint64(), true) + // Reorg to the common ancestor (might not exist in light sync mode, skip reorg then) + // TODO(karalabe, zsfelfoldi): This seems a bit brittle, can we detect this case explicitly? + if h := FindCommonAncestor(c.chainDb, prevHeader, header); h != nil { + c.newHead(h.Number.Uint64(), true) + } } c.newHead(header.Number.Uint64(), false) @@ -236,6 +263,7 @@ func (c *ChainIndexer) updateLoop() { updating bool updated time.Time ) + for { select { case errc := <-c.quit: @@ -259,7 +287,7 @@ func (c *ChainIndexer) updateLoop() { section := c.storedSections var oldHead common.Hash if section > 0 { - oldHead = c.sectionHead(section - 1) + oldHead = c.SectionHead(section - 1) } // Process the newly defined section in the background c.lock.Unlock() @@ -270,7 +298,7 @@ func (c *ChainIndexer) updateLoop() { c.lock.Lock() // If processing succeeded and no reorgs occcurred, mark the section completed - if err == nil && oldHead == c.sectionHead(section-1) { + if err == nil && oldHead == c.SectionHead(section-1) { c.setSectionHead(section, newHead) c.setValidSections(section + 1) if c.storedSections == c.knownSections && updating { @@ -311,7 +339,11 @@ func (c *ChainIndexer) processSection(section uint64, lastHead common.Hash) (com c.log.Trace("Processing new chain section", "section", section) // Reset and partial processing - c.backend.Reset(section) + + if err := c.backend.Reset(section, lastHead); err != nil { + c.setValidSections(0) + return common.Hash{}, err + } for number := section * c.sectionSize; number < (section+1)*c.sectionSize; number++ { hash := GetCanonicalHash(c.chainDb, number) @@ -341,7 +373,7 @@ func (c *ChainIndexer) Sections() (uint64, uint64, common.Hash) { c.lock.Lock() defer c.lock.Unlock() - return c.storedSections, c.storedSections*c.sectionSize - 1, c.sectionHead(c.storedSections - 1) + return c.storedSections, c.storedSections*c.sectionSize - 1, c.SectionHead(c.storedSections - 1) } // AddChildIndexer adds a child ChainIndexer that can use the output of this one @@ -381,9 +413,9 @@ func (c *ChainIndexer) setValidSections(sections uint64) { c.storedSections = sections // needed if new > old } -// sectionHead retrieves the last block hash of a processed section from the +// SectionHead retrieves the last block hash of a processed section from the // index database. -func (c *ChainIndexer) sectionHead(section uint64) common.Hash { +func (c *ChainIndexer) SectionHead(section uint64) common.Hash { var data [8]byte binary.BigEndian.PutUint64(data[:], section) diff --git a/core/chain_indexer_test.go b/core/chain_indexer_test.go index b761e8a5b..9fc09eda5 100644 --- a/core/chain_indexer_test.go +++ b/core/chain_indexer_test.go @@ -23,6 +23,7 @@ import ( "testing" "time" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" ) @@ -208,9 +209,10 @@ func (b *testChainIndexBackend) reorg(headNum uint64) uint64 { return b.stored * b.indexer.sectionSize } -func (b *testChainIndexBackend) Reset(section uint64) { +func (b *testChainIndexBackend) Reset(section uint64, prevHead common.Hash) error { b.section = section b.headerCnt = 0 + return nil } func (b *testChainIndexBackend) Process(header *types.Header) { diff --git a/core/chain_makers.go b/core/chain_makers.go index dd3e2fb19..59af633df 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -235,7 +235,7 @@ func newCanonical(n int, full bool) (ethdb.Database, *BlockChain, error) { db, _ := ethdb.NewMemDatabase() genesis := gspec.MustCommit(db) - blockchain, _ := NewBlockChain(db, params.AllProtocolChanges, ethash.NewFaker(), vm.Config{}) + blockchain, _ := NewBlockChain(db, params.AllEthashProtocolChanges, ethash.NewFaker(), vm.Config{}) // Create and inject the requested chain if n == 0 { return db, blockchain, nil diff --git a/core/database_util.go b/core/database_util.go index 1730a048e..c6b125dae 100644 --- a/core/database_util.go +++ b/core/database_util.go @@ -74,9 +74,9 @@ var ( preimageHitCounter = metrics.NewCounter("db/preimage/hits") ) -// txLookupEntry is a positional metadata to help looking up the data content of +// TxLookupEntry is a positional metadata to help looking up the data content of // a transaction or receipt given only its hash. -type txLookupEntry struct { +type TxLookupEntry struct { BlockHash common.Hash BlockIndex uint64 Index uint64 @@ -260,7 +260,7 @@ func GetTxLookupEntry(db DatabaseReader, hash common.Hash) (common.Hash, uint64, return common.Hash{}, 0, 0 } // Parse and return the contents of the lookup entry - var entry txLookupEntry + var entry TxLookupEntry if err := rlp.DecodeBytes(data, &entry); err != nil { log.Error("Invalid lookup entry RLP", "hash", hash, "err", err) return common.Hash{}, 0, 0 @@ -296,7 +296,7 @@ func GetTransaction(db DatabaseReader, hash common.Hash) (*types.Transaction, co if len(data) == 0 { return nil, common.Hash{}, 0, 0 } - var entry txLookupEntry + var entry TxLookupEntry if err := rlp.DecodeBytes(data, &entry); err != nil { return nil, common.Hash{}, 0, 0 } @@ -332,14 +332,13 @@ func GetReceipt(db DatabaseReader, hash common.Hash) (*types.Receipt, common.Has // GetBloomBits retrieves the compressed bloom bit vector belonging to the given // section and bit index from the. -func GetBloomBits(db DatabaseReader, bit uint, section uint64, head common.Hash) []byte { +func GetBloomBits(db DatabaseReader, bit uint, section uint64, head common.Hash) ([]byte, error) { key := append(append(bloomBitsPrefix, make([]byte, 10)...), head.Bytes()...) binary.BigEndian.PutUint16(key[1:], uint16(bit)) binary.BigEndian.PutUint64(key[3:], section) - bits, _ := db.Get(key) - return bits + return db.Get(key) } // WriteCanonicalHash stores the canonical hash for the given block number. @@ -465,7 +464,7 @@ func WriteBlockReceipts(db ethdb.Putter, hash common.Hash, number uint64, receip func WriteTxLookupEntries(db ethdb.Putter, block *types.Block) error { // Iterate over each transaction and encode its metadata for i, tx := range block.Transactions() { - entry := txLookupEntry{ + entry := TxLookupEntry{ BlockHash: block.Hash(), BlockIndex: block.NumberU64(), Index: uint64(i), diff --git a/core/genesis.go b/core/genesis.go index dc1a53206..9f0e9f6e4 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -151,7 +151,7 @@ func (e *GenesisMismatchError) Error() string { // The returned chain configuration is never nil. func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig, common.Hash, error) { if genesis != nil && genesis.Config == nil { - return params.AllProtocolChanges, common.Hash{}, errGenesisNoConfig + return params.AllEthashProtocolChanges, common.Hash{}, errGenesisNoConfig } // Just commit the new block if there is no stored genesis block. @@ -216,7 +216,7 @@ func (g *Genesis) configOrDefault(ghash common.Hash) *params.ChainConfig { case ghash == params.TestnetGenesisHash: return params.TestnetChainConfig default: - return params.AllProtocolChanges + return params.AllEthashProtocolChanges } } @@ -285,7 +285,7 @@ func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) { } config := g.Config if config == nil { - config = params.AllProtocolChanges + config = params.AllEthashProtocolChanges } return block, WriteChainConfig(db, block.Hash(), config) } @@ -356,14 +356,30 @@ func DefaultRinkebyGenesisBlock() *Genesis { } } -// DevGenesisBlock returns the 'gmc --dev' genesis block. -func DevGenesisBlock() *Genesis { +// DeveloperGenesisBlock returns the 'gmc --dev' genesis block. Note, this must +// be seeded with the +func DeveloperGenesisBlock(period uint64, faucet common.Address) *Genesis { + // Override the default period to the user requested one + config := *params.AllCliqueProtocolChanges + config.Clique.Period = period + + // Assemble and return the genesis with the precompiles and faucet pre-funded return &Genesis{ - Config: params.AllProtocolChanges, - Nonce: 42, - GasLimit: 4712388, - Difficulty: big.NewInt(131072), - Alloc: decodePrealloc(devAllocData), + Config: &config, + ExtraData: append(append(make([]byte, 32), faucet[:]...), make([]byte, 65)...), + GasLimit: 6283185, + Difficulty: big.NewInt(1), + Alloc: map[common.Address]GenesisAccount{ + common.BytesToAddress([]byte{1}): {Balance: big.NewInt(1)}, // ECRecover + common.BytesToAddress([]byte{2}): {Balance: big.NewInt(1)}, // SHA256 + common.BytesToAddress([]byte{3}): {Balance: big.NewInt(1)}, // RIPEMD + common.BytesToAddress([]byte{4}): {Balance: big.NewInt(1)}, // Identity + common.BytesToAddress([]byte{5}): {Balance: big.NewInt(1)}, // ModExp + common.BytesToAddress([]byte{6}): {Balance: big.NewInt(1)}, // ECAdd + common.BytesToAddress([]byte{7}): {Balance: big.NewInt(1)}, // ECScalarMul + common.BytesToAddress([]byte{8}): {Balance: big.NewInt(1)}, // ECPairing + faucet: {Balance: new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 256), big.NewInt(9))}, + }, } } diff --git a/core/genesis_alloc.go b/core/genesis_alloc.go index bb2a04329..fbe63c856 100644 --- a/core/genesis_alloc.go +++ b/core/genesis_alloc.go @@ -23,4 +23,3 @@ package core const mainnetAllocData = "\xcc\xc2\x01\x01\xc2\x02\x01\xc2\x03\x01\xc2\x04\x01" const testnetAllocData = "\xf9\x03\xa4\u0080\x01\xc2\x01\x01\xc2\x02\x01\xc2\x03\x01\xc2\x04\x01\xc2\x05\x01\xc2\x06\x01\xc2\a\x01\xc2\b\x01\xc2\t\x01\xc2\n\x80\xc2\v\x80\xc2\f\x80\xc2\r\x80\xc2\x0e\x80\xc2\x0f\x80\xc2\x10\x80\xc2\x11\x80\xc2\x12\x80\xc2\x13\x80\xc2\x14\x80\xc2\x15\x80\xc2\x16\x80\xc2\x17\x80\xc2\x18\x80\xc2\x19\x80\xc2\x1a\x80\xc2\x1b\x80\xc2\x1c\x80\xc2\x1d\x80\xc2\x1e\x80\xc2\x1f\x80\xc2 \x80\xc2!\x80\xc2\"\x80\xc2#\x80\xc2$\x80\xc2%\x80\xc2&\x80\xc2'\x80\xc2(\x80\xc2)\x80\xc2*\x80\xc2+\x80\xc2,\x80\xc2-\x80\xc2.\x80\xc2/\x80\xc20\x80\xc21\x80\xc22\x80\xc23\x80\xc24\x80\xc25\x80\xc26\x80\xc27\x80\xc28\x80\xc29\x80\xc2:\x80\xc2;\x80\xc2<\x80\xc2=\x80\xc2>\x80\xc2?\x80\xc2@\x80\xc2A\x80\xc2B\x80\xc2C\x80\xc2D\x80\xc2E\x80\xc2F\x80\xc2G\x80\xc2H\x80\xc2I\x80\xc2J\x80\xc2K\x80\xc2L\x80\xc2M\x80\xc2N\x80\xc2O\x80\xc2P\x80\xc2Q\x80\xc2R\x80\xc2S\x80\xc2T\x80\xc2U\x80\xc2V\x80\xc2W\x80\xc2X\x80\xc2Y\x80\xc2Z\x80\xc2[\x80\xc2\\\x80\xc2]\x80\xc2^\x80\xc2_\x80\xc2`\x80\xc2a\x80\xc2b\x80\xc2c\x80\xc2d\x80\xc2e\x80\xc2f\x80\xc2g\x80\xc2h\x80\xc2i\x80\xc2j\x80\xc2k\x80\xc2l\x80\xc2m\x80\xc2n\x80\xc2o\x80\xc2p\x80\xc2q\x80\xc2r\x80\xc2s\x80\xc2t\x80\xc2u\x80\xc2v\x80\xc2w\x80\xc2x\x80\xc2y\x80\xc2z\x80\xc2{\x80\xc2|\x80\xc2}\x80\xc2~\x80\xc2\u007f\x80\u00c1\x80\x80\u00c1\x81\x80\u00c1\x82\x80\u00c1\x83\x80\u00c1\x84\x80\u00c1\x85\x80\u00c1\x86\x80\u00c1\x87\x80\u00c1\x88\x80\u00c1\x89\x80\u00c1\x8a\x80\u00c1\x8b\x80\u00c1\x8c\x80\u00c1\x8d\x80\u00c1\x8e\x80\u00c1\x8f\x80\u00c1\x90\x80\u00c1\x91\x80\u00c1\x92\x80\u00c1\x93\x80\u00c1\x94\x80\u00c1\x95\x80\u00c1\x96\x80\u00c1\x97\x80\u00c1\x98\x80\u00c1\x99\x80\u00c1\x9a\x80\u00c1\x9b\x80\u00c1\x9c\x80\u00c1\x9d\x80\u00c1\x9e\x80\u00c1\x9f\x80\u00c1\xa0\x80\u00c1\xa1\x80\u00c1\xa2\x80\u00c1\xa3\x80\u00c1\xa4\x80\u00c1\xa5\x80\u00c1\xa6\x80\u00c1\xa7\x80\u00c1\xa8\x80\u00c1\xa9\x80\u00c1\xaa\x80\u00c1\xab\x80\u00c1\xac\x80\u00c1\xad\x80\u00c1\xae\x80\u00c1\xaf\x80\u00c1\xb0\x80\u00c1\xb1\x80\u00c1\xb2\x80\u00c1\xb3\x80\u00c1\xb4\x80\u00c1\xb5\x80\u00c1\xb6\x80\u00c1\xb7\x80\u00c1\xb8\x80\u00c1\xb9\x80\u00c1\xba\x80\u00c1\xbb\x80\u00c1\xbc\x80\u00c1\xbd\x80\u00c1\xbe\x80\u00c1\xbf\x80\u00c1\xc0\x80\u00c1\xc1\x80\u00c1\u0080\u00c1\u00c0\u00c1\u0100\u00c1\u0140\u00c1\u0180\u00c1\u01c0\u00c1\u0200\u00c1\u0240\u00c1\u0280\u00c1\u02c0\u00c1\u0300\u00c1\u0340\u00c1\u0380\u00c1\u03c0\u00c1\u0400\u00c1\u0440\u00c1\u0480\u00c1\u04c0\u00c1\u0500\u00c1\u0540\u00c1\u0580\u00c1\u05c0\u00c1\u0600\u00c1\u0640\u00c1\u0680\u00c1\u06c0\u00c1\u0700\u00c1\u0740\u00c1\u0780\u00c1\u07c0\u00c1\xe0\x80\u00c1\xe1\x80\u00c1\xe2\x80\u00c1\xe3\x80\u00c1\xe4\x80\u00c1\xe5\x80\u00c1\xe6\x80\u00c1\xe7\x80\u00c1\xe8\x80\u00c1\xe9\x80\u00c1\xea\x80\u00c1\xeb\x80\u00c1\xec\x80\u00c1\xed\x80\u00c1\xee\x80\u00c1\xef\x80\u00c1\xf0\x80\u00c1\xf1\x80\u00c1\xf2\x80\u00c1\xf3\x80\u00c1\xf4\x80\u00c1\xf5\x80\u00c1\xf6\x80\u00c1\xf7\x80\u00c1\xf8\x80\u00c1\xf9\x80\u00c1\xfa\x80\u00c1\xfb\x80\u00c1\xfc\x80\u00c1\xfd\x80\u00c1\xfe\x80\u00c1\xff\x80\u3507KT\xa8\xbd\x15)f\xd6?pk\xae\x1f\xfe\xb0A\x19!\xe5\x8d\f\x9f,\x9c\xd0Ft\xed\xea@\x00\x00\x00" const rinkebyAllocData = "\xf9\x03\xb7\u0080\x01\xc2\x01\x01\xc2\x02\x01\xc2\x03\x01\xc2\x04\x01\xc2\x05\x01\xc2\x06\x01\xc2\a\x01\xc2\b\x01\xc2\t\x01\xc2\n\x01\xc2\v\x01\xc2\f\x01\xc2\r\x01\xc2\x0e\x01\xc2\x0f\x01\xc2\x10\x01\xc2\x11\x01\xc2\x12\x01\xc2\x13\x01\xc2\x14\x01\xc2\x15\x01\xc2\x16\x01\xc2\x17\x01\xc2\x18\x01\xc2\x19\x01\xc2\x1a\x01\xc2\x1b\x01\xc2\x1c\x01\xc2\x1d\x01\xc2\x1e\x01\xc2\x1f\x01\xc2 \x01\xc2!\x01\xc2\"\x01\xc2#\x01\xc2$\x01\xc2%\x01\xc2&\x01\xc2'\x01\xc2(\x01\xc2)\x01\xc2*\x01\xc2+\x01\xc2,\x01\xc2-\x01\xc2.\x01\xc2/\x01\xc20\x01\xc21\x01\xc22\x01\xc23\x01\xc24\x01\xc25\x01\xc26\x01\xc27\x01\xc28\x01\xc29\x01\xc2:\x01\xc2;\x01\xc2<\x01\xc2=\x01\xc2>\x01\xc2?\x01\xc2@\x01\xc2A\x01\xc2B\x01\xc2C\x01\xc2D\x01\xc2E\x01\xc2F\x01\xc2G\x01\xc2H\x01\xc2I\x01\xc2J\x01\xc2K\x01\xc2L\x01\xc2M\x01\xc2N\x01\xc2O\x01\xc2P\x01\xc2Q\x01\xc2R\x01\xc2S\x01\xc2T\x01\xc2U\x01\xc2V\x01\xc2W\x01\xc2X\x01\xc2Y\x01\xc2Z\x01\xc2[\x01\xc2\\\x01\xc2]\x01\xc2^\x01\xc2_\x01\xc2`\x01\xc2a\x01\xc2b\x01\xc2c\x01\xc2d\x01\xc2e\x01\xc2f\x01\xc2g\x01\xc2h\x01\xc2i\x01\xc2j\x01\xc2k\x01\xc2l\x01\xc2m\x01\xc2n\x01\xc2o\x01\xc2p\x01\xc2q\x01\xc2r\x01\xc2s\x01\xc2t\x01\xc2u\x01\xc2v\x01\xc2w\x01\xc2x\x01\xc2y\x01\xc2z\x01\xc2{\x01\xc2|\x01\xc2}\x01\xc2~\x01\xc2\u007f\x01\u00c1\x80\x01\u00c1\x81\x01\u00c1\x82\x01\u00c1\x83\x01\u00c1\x84\x01\u00c1\x85\x01\u00c1\x86\x01\u00c1\x87\x01\u00c1\x88\x01\u00c1\x89\x01\u00c1\x8a\x01\u00c1\x8b\x01\u00c1\x8c\x01\u00c1\x8d\x01\u00c1\x8e\x01\u00c1\x8f\x01\u00c1\x90\x01\u00c1\x91\x01\u00c1\x92\x01\u00c1\x93\x01\u00c1\x94\x01\u00c1\x95\x01\u00c1\x96\x01\u00c1\x97\x01\u00c1\x98\x01\u00c1\x99\x01\u00c1\x9a\x01\u00c1\x9b\x01\u00c1\x9c\x01\u00c1\x9d\x01\u00c1\x9e\x01\u00c1\x9f\x01\u00c1\xa0\x01\u00c1\xa1\x01\u00c1\xa2\x01\u00c1\xa3\x01\u00c1\xa4\x01\u00c1\xa5\x01\u00c1\xa6\x01\u00c1\xa7\x01\u00c1\xa8\x01\u00c1\xa9\x01\u00c1\xaa\x01\u00c1\xab\x01\u00c1\xac\x01\u00c1\xad\x01\u00c1\xae\x01\u00c1\xaf\x01\u00c1\xb0\x01\u00c1\xb1\x01\u00c1\xb2\x01\u00c1\xb3\x01\u00c1\xb4\x01\u00c1\xb5\x01\u00c1\xb6\x01\u00c1\xb7\x01\u00c1\xb8\x01\u00c1\xb9\x01\u00c1\xba\x01\u00c1\xbb\x01\u00c1\xbc\x01\u00c1\xbd\x01\u00c1\xbe\x01\u00c1\xbf\x01\u00c1\xc0\x01\u00c1\xc1\x01\u00c1\xc2\x01\u00c1\xc3\x01\u00c1\xc4\x01\u00c1\xc5\x01\u00c1\xc6\x01\u00c1\xc7\x01\u00c1\xc8\x01\u00c1\xc9\x01\u00c1\xca\x01\u00c1\xcb\x01\u00c1\xcc\x01\u00c1\xcd\x01\u00c1\xce\x01\u00c1\xcf\x01\u00c1\xd0\x01\u00c1\xd1\x01\u00c1\xd2\x01\u00c1\xd3\x01\u00c1\xd4\x01\u00c1\xd5\x01\u00c1\xd6\x01\u00c1\xd7\x01\u00c1\xd8\x01\u00c1\xd9\x01\u00c1\xda\x01\u00c1\xdb\x01\u00c1\xdc\x01\u00c1\xdd\x01\u00c1\xde\x01\u00c1\xdf\x01\u00c1\xe0\x01\u00c1\xe1\x01\u00c1\xe2\x01\u00c1\xe3\x01\u00c1\xe4\x01\u00c1\xe5\x01\u00c1\xe6\x01\u00c1\xe7\x01\u00c1\xe8\x01\u00c1\xe9\x01\u00c1\xea\x01\u00c1\xeb\x01\u00c1\xec\x01\u00c1\xed\x01\u00c1\xee\x01\u00c1\xef\x01\u00c1\xf0\x01\u00c1\xf1\x01\u00c1\xf2\x01\u00c1\xf3\x01\u00c1\xf4\x01\u00c1\xf5\x01\u00c1\xf6\x01\u00c1\xf7\x01\u00c1\xf8\x01\u00c1\xf9\x01\u00c1\xfa\x01\u00c1\xfb\x01\u00c1\xfc\x01\u00c1\xfd\x01\u00c1\xfe\x01\u00c1\xff\x01\xf6\x941\xb9\x8d\x14\x00{\xde\xe67)\x80\x86\x98\x8a\v\xbd1\x18E#\xa0\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" -const devAllocData = "\xf9\x01\x94\xc2\x01\x01\xc2\x02\x01\xc2\x03\x01\xc2\x04\x01\xf0\x94\x1a&3\x8f\r\x90^)_\u0337\x1f\xa9\ua11f\xfa\x12\xaa\xf4\x9a\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0\x94.\xf4q\x00\xe0x{\x91Q\x05\xfd^?O\xf6u y\xd5\u02da\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0\x94l8jK&\xf7<\x80/4g?rH\xbb\x11\x8f\x97BJ\x9a\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0\x94\xb9\xc0\x15\x91\x8b\u06ba$\xb4\xff\x05z\x92\xa3\x87=n\xb2\x01\xbe\x9a\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0\x94\xcd*=\x9f\x93\x8e\x13\u0354~\xc0Z\xbc\u007f\xe74\u07cd\xd8&\x9a\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0\x94\xdb\xdb\xdb,\xbd#\xb7\x83t\x1e\x8d\u007f\xcfQ\xe4Y\xb4\x97\u499a\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0\x94\xe4\x15{4\xea\x96\x15\u03fd\xe6\xb4\xfd\xa4\x19\x82\x81$\xb7\fx\x9a\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0\x94\xe6qo\x95D\xa5lS\r\x86\x8eK\xfb\xac\xb1r1[\u07ad\x9a\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" diff --git a/core/genesis_test.go b/core/genesis_test.go index 482f86190..2fe931b24 100644 --- a/core/genesis_test.go +++ b/core/genesis_test.go @@ -65,7 +65,7 @@ func TestSetupGenesis(t *testing.T) { return SetupGenesisBlock(db, new(Genesis)) }, wantErr: errGenesisNoConfig, - wantConfig: params.AllProtocolChanges, + wantConfig: params.AllEthashProtocolChanges, }, { name: "no block in DB, genesis == nil", diff --git a/core/tx_journal.go b/core/tx_journal.go index 3fd8ece49..e872d7b53 100644 --- a/core/tx_journal.go +++ b/core/tx_journal.go @@ -68,7 +68,7 @@ func (journal *txJournal) load(add func(*types.Transaction) error) error { } defer input.Close() - // Temporarilly discard any journal additions (don't double add on load) + // Temporarily discard any journal additions (don't double add on load) journal.writer = new(devNull) defer func() { journal.writer = nil }() diff --git a/core/tx_list.go b/core/tx_list.go index 2935929d7..838433b89 100644 --- a/core/tx_list.go +++ b/core/tx_list.go @@ -254,7 +254,10 @@ func (l *txList) Add(tx *types.Transaction, priceBump uint64) (bool, *types.Tran old := l.txs.Get(tx.Nonce()) if old != nil { threshold := new(big.Int).Div(new(big.Int).Mul(old.GasPrice(), big.NewInt(100+int64(priceBump))), big.NewInt(100)) - if threshold.Cmp(tx.GasPrice()) >= 0 { + // Have to ensure that the new gas price is higher than the old gas + // price as well as checking the percentage threshold to ensure that + // this is accurate for low (Wei-level) gas price replacements + if old.GasPrice().Cmp(tx.GasPrice()) >= 0 || threshold.Cmp(tx.GasPrice()) > 0 { return false, nil } } diff --git a/core/tx_pool.go b/core/tx_pool.go index a705e36d6..c3915575b 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -103,6 +103,16 @@ var ( underpricedTxCounter = metrics.NewCounter("txpool/underpriced") ) +// TxStatus is the current status of a transaction as seen py the pool. +type TxStatus uint + +const ( + TxStatusUnknown TxStatus = iota + TxStatusQueued + TxStatusPending + TxStatusIncluded +) + // blockChain provides the state of blockchain and current gas limit to do // some pre checks in tx pool and event subscribers. type blockChain interface { @@ -754,14 +764,14 @@ func (pool *TxPool) AddRemote(tx *types.Transaction) error { // AddLocals enqueues a batch of transactions into the pool if they are valid, // marking the senders as a local ones in the mean time, ensuring they go around // the local pricing constraints. -func (pool *TxPool) AddLocals(txs []*types.Transaction) error { +func (pool *TxPool) AddLocals(txs []*types.Transaction) []error { return pool.addTxs(txs, !pool.config.NoLocals) } // AddRemotes enqueues a batch of transactions into the pool if they are valid. // If the senders are not among the locally tracked ones, full pricing constraints // will apply. -func (pool *TxPool) AddRemotes(txs []*types.Transaction) error { +func (pool *TxPool) AddRemotes(txs []*types.Transaction) []error { return pool.addTxs(txs, false) } @@ -784,7 +794,7 @@ func (pool *TxPool) addTx(tx *types.Transaction, local bool) error { } // addTxs attempts to queue a batch of transactions if they are valid. -func (pool *TxPool) addTxs(txs []*types.Transaction, local bool) error { +func (pool *TxPool) addTxs(txs []*types.Transaction, local bool) []error { pool.mu.Lock() defer pool.mu.Unlock() @@ -793,11 +803,14 @@ func (pool *TxPool) addTxs(txs []*types.Transaction, local bool) error { // addTxsLocked attempts to queue a batch of transactions if they are valid, // whilst assuming the transaction pool lock is already held. -func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) error { +func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) []error { // Add the batch of transaction, tracking the accepted ones dirty := make(map[common.Address]struct{}) - for _, tx := range txs { - if replace, err := pool.add(tx, local); err == nil { + errs := make([]error, len(txs)) + + for i, tx := range txs { + var replace bool + if replace, errs[i] = pool.add(tx, local); errs[i] == nil { if !replace { from, _ := types.Sender(pool.signer, tx) // already validated dirty[from] = struct{}{} @@ -807,12 +820,32 @@ func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) error { // Only reprocess the internal state if something was actually added if len(dirty) > 0 { addrs := make([]common.Address, 0, len(dirty)) - for addr, _ := range dirty { + for addr := range dirty { addrs = append(addrs, addr) } pool.promoteExecutables(addrs) } - return nil + return errs +} + +// Status returns the status (unknown/pending/queued) of a batch of transactions +// identified by their hashes. +func (pool *TxPool) Status(hashes []common.Hash) []TxStatus { + pool.mu.RLock() + defer pool.mu.RUnlock() + + status := make([]TxStatus, len(hashes)) + for i, hash := range hashes { + if tx := pool.all[hash]; tx != nil { + from, _ := types.Sender(pool.signer, tx) // already validated + if pool.pending[from].txs.items[tx.Nonce()] != nil { + status[i] = TxStatusPending + } else { + status[i] = TxStatusQueued + } + } + } + return status } // Get returns a transaction if it is contained in the pool @@ -874,7 +907,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { // Gather all the accounts potentially needing updates if accounts == nil { accounts = make([]common.Address, 0, len(pool.queue)) - for addr, _ := range pool.queue { + for addr := range pool.queue { accounts = append(accounts, addr) } } diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index eec128cba..e9ecbb933 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -105,7 +105,7 @@ func validateTxPoolInternals(pool *TxPool) error { for addr, txs := range pool.pending { // Find the last transaction var last uint64 - for nonce, _ := range txs.txs.items { + for nonce := range txs.txs.items { if last < nonce { last = nonce } @@ -1411,10 +1411,10 @@ func TestTransactionReplacement(t *testing.T) { if err := pool.AddRemote(pricedTransaction(0, big.NewInt(100000), big.NewInt(price), key)); err != nil { t.Fatalf("failed to add original proper pending transaction: %v", err) } - if err := pool.AddRemote(pricedTransaction(0, big.NewInt(100000), big.NewInt(threshold), key)); err != ErrReplaceUnderpriced { + if err := pool.AddRemote(pricedTransaction(0, big.NewInt(100001), big.NewInt(threshold-1), key)); err != ErrReplaceUnderpriced { t.Fatalf("original proper pending transaction replacement error mismatch: have %v, want %v", err, ErrReplaceUnderpriced) } - if err := pool.AddRemote(pricedTransaction(0, big.NewInt(100000), big.NewInt(threshold+1), key)); err != nil { + if err := pool.AddRemote(pricedTransaction(0, big.NewInt(100000), big.NewInt(threshold), key)); err != nil { t.Fatalf("failed to replace original proper pending transaction: %v", err) } if err := validateEvents(events, 2); err != nil { @@ -1422,23 +1422,23 @@ func TestTransactionReplacement(t *testing.T) { } // Add queued transactions, ensuring the minimum price bump is enforced for replacement (for ultra low prices too) if err := pool.AddRemote(pricedTransaction(2, big.NewInt(100000), big.NewInt(1), key)); err != nil { - t.Fatalf("failed to add original queued transaction: %v", err) + t.Fatalf("failed to add original cheap queued transaction: %v", err) } if err := pool.AddRemote(pricedTransaction(2, big.NewInt(100001), big.NewInt(1), key)); err != ErrReplaceUnderpriced { - t.Fatalf("original queued transaction replacement error mismatch: have %v, want %v", err, ErrReplaceUnderpriced) + t.Fatalf("original cheap queued transaction replacement error mismatch: have %v, want %v", err, ErrReplaceUnderpriced) } if err := pool.AddRemote(pricedTransaction(2, big.NewInt(100000), big.NewInt(2), key)); err != nil { - t.Fatalf("failed to replace original queued transaction: %v", err) + t.Fatalf("failed to replace original cheap queued transaction: %v", err) } if err := pool.AddRemote(pricedTransaction(2, big.NewInt(100000), big.NewInt(price), key)); err != nil { - t.Fatalf("failed to add original queued transaction: %v", err) + t.Fatalf("failed to add original proper queued transaction: %v", err) } - if err := pool.AddRemote(pricedTransaction(2, big.NewInt(100001), big.NewInt(threshold), key)); err != ErrReplaceUnderpriced { - t.Fatalf("original queued transaction replacement error mismatch: have %v, want %v", err, ErrReplaceUnderpriced) + if err := pool.AddRemote(pricedTransaction(2, big.NewInt(100001), big.NewInt(threshold-1), key)); err != ErrReplaceUnderpriced { + t.Fatalf("original proper queued transaction replacement error mismatch: have %v, want %v", err, ErrReplaceUnderpriced) } - if err := pool.AddRemote(pricedTransaction(2, big.NewInt(100000), big.NewInt(threshold+1), key)); err != nil { - t.Fatalf("failed to replace original queued transaction: %v", err) + if err := pool.AddRemote(pricedTransaction(2, big.NewInt(100000), big.NewInt(threshold), key)); err != nil { + t.Fatalf("failed to replace original proper queued transaction: %v", err) } if err := validateEvents(events, 0); err != nil { diff --git a/eth/backend.go b/eth/backend.go index 7dcc3bfd6..c7fd8ebd6 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -54,6 +54,7 @@ type LesServer interface { Start(srvr *p2p.Server) Stop() Protocols() []p2p.Protocol + SetBloomBitsIndexer(bbIndexer *core.ChainIndexer) } // Ethereum implements the Ethereum full node service. @@ -95,6 +96,7 @@ type Ethereum struct { func (s *Ethereum) AddLesServer(ls LesServer) { s.lesServer = ls + ls.SetBloomBitsIndexer(s.bloomIndexer) } // New creates a new Ethereum object (including the @@ -154,7 +156,7 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) { eth.blockchain.SetHead(compat.RewindTo) core.WriteChainConfig(chainDb, genesisHash, chainConfig) } - eth.bloomIndexer.Start(eth.blockchain.CurrentHeader(), eth.blockchain.SubscribeChainEvent) + eth.bloomIndexer.Start(eth.blockchain) if config.TxPool.Journal != "" { config.TxPool.Journal = ctx.ResolvePath(config.TxPool.Journal) diff --git a/eth/bloombits.go b/eth/bloombits.go index 32f6c7b31..c5597391c 100644 --- a/eth/bloombits.go +++ b/eth/bloombits.go @@ -58,15 +58,18 @@ func (eth *Ethereum) startBloomHandlers() { case request := <-eth.bloomRequests: task := <-request - task.Bitsets = make([][]byte, len(task.Sections)) for i, section := range task.Sections { head := core.GetCanonicalHash(eth.chainDb, (section+1)*params.BloomBitsBlocks-1) - blob, err := bitutil.DecompressBytes(core.GetBloomBits(eth.chainDb, task.Bit, section, head), int(params.BloomBitsBlocks)/8) - if err != nil { - panic(err) + if compVector, err := core.GetBloomBits(eth.chainDb, task.Bit, section, head); err == nil { + if blob, err := bitutil.DecompressBytes(compVector, int(params.BloomBitsBlocks)/8); err == nil { + task.Bitsets[i] = blob + } else { + task.Error = err + } + } else { + task.Error = err } - task.Bitsets[i] = blob } request <- task } @@ -111,12 +114,10 @@ func NewBloomIndexer(db ethdb.Database, size uint64) *core.ChainIndexer { // Reset implements core.ChainIndexerBackend, starting a new bloombits index // section. -func (b *BloomIndexer) Reset(section uint64) { +func (b *BloomIndexer) Reset(section uint64, lastSectionHead common.Hash) error { gen, err := bloombits.NewGenerator(uint(b.size)) - if err != nil { - panic(err) - } b.gen, b.section, b.head = gen, section, common.Hash{} + return err } // Process implements core.ChainIndexerBackend, adding a new header's bloom into diff --git a/eth/filters/bench_test.go b/eth/filters/bench_test.go index 6e22a2da3..40205dc04 100644 --- a/eth/filters/bench_test.go +++ b/eth/filters/bench_test.go @@ -192,7 +192,7 @@ func BenchmarkNoBloomBits(b *testing.B) { start := time.Now() mux := new(event.TypeMux) backend := &testBackend{mux, db, 0, new(event.Feed), new(event.Feed), new(event.Feed), new(event.Feed)} - filter := New(backend, 0, int64(headNum), []common.Address{common.Address{}}, nil) + filter := New(backend, 0, int64(headNum), []common.Address{{}}, nil) filter.Logs(context.Background()) d := time.Since(start) fmt.Println("Finished running filter benchmarks") diff --git a/eth/filters/filter.go b/eth/filters/filter.go index 026cbf95c..d16af84ee 100644 --- a/eth/filters/filter.go +++ b/eth/filters/filter.go @@ -19,7 +19,6 @@ package filters import ( "context" "math/big" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" @@ -136,11 +135,11 @@ func (f *Filter) indexedLogs(ctx context.Context, end uint64) ([]*types.Log, err // Create a matcher session and request servicing from the backend matches := make(chan uint64, 64) - session, err := f.matcher.Start(uint64(f.begin), end, matches) + session, err := f.matcher.Start(ctx, uint64(f.begin), end, matches) if err != nil { return nil, err } - defer session.Close(time.Second) + defer session.Close() f.backend.ServiceFilter(ctx, session) @@ -152,9 +151,13 @@ func (f *Filter) indexedLogs(ctx context.Context, end uint64) ([]*types.Log, err case number, ok := <-matches: // Abort if all matches have been fulfilled if !ok { - f.begin = int64(end) + 1 - return logs, nil + err := session.Error() + if err == nil { + f.begin = int64(end) + 1 + } + return logs, err } + f.begin = int64(number) + 1 // Retrieve the suggested block and pull any truly matching logs header, err := f.backend.HeaderByNumber(ctx, rpc.BlockNumber(number)) if header == nil || err != nil { diff --git a/eth/filters/filter_system_test.go b/eth/filters/filter_system_test.go index bc3511f23..7da114fda 100644 --- a/eth/filters/filter_system_test.go +++ b/eth/filters/filter_system_test.go @@ -109,7 +109,7 @@ func (b *testBackend) ServiceFilter(ctx context.Context, session *bloombits.Matc for i, section := range task.Sections { if rand.Int()%4 != 0 { // Handle occasional missing deliveries head := core.GetCanonicalHash(b.db, (section+1)*params.BloomBitsBlocks-1) - task.Bitsets[i] = core.GetBloomBits(b.db, task.Bit, section, head) + task.Bitsets[i], _ = core.GetBloomBits(b.db, task.Bit, section, head) } } request <- task diff --git a/eth/helper_test.go b/eth/helper_test.go index b66553135..f02242b15 100644 --- a/eth/helper_test.go +++ b/eth/helper_test.go @@ -97,7 +97,7 @@ type testTxPool struct { // AddRemotes appends a batch of transactions to the pool, and notifies any // listeners if the addition channel is non nil -func (p *testTxPool) AddRemotes(txs []*types.Transaction) error { +func (p *testTxPool) AddRemotes(txs []*types.Transaction) []error { p.lock.Lock() defer p.lock.Unlock() @@ -105,8 +105,7 @@ func (p *testTxPool) AddRemotes(txs []*types.Transaction) error { if p.added != nil { p.added <- txs } - - return nil + return make([]error, len(txs)) } // Pending returns all the transactions known to the pool diff --git a/eth/protocol.go b/eth/protocol.go index 2c41376fa..cd7db57f2 100644 --- a/eth/protocol.go +++ b/eth/protocol.go @@ -97,7 +97,7 @@ var errorToString = map[int]string{ type txPool interface { // AddRemotes should add the given transactions to the pool. - AddRemotes([]*types.Transaction) error + AddRemotes([]*types.Transaction) []error // Pending should return pending transactions. // The slice should be modifiable by the caller. diff --git a/ethstats/ethstats.go b/ethstats/ethstats.go index bb03dc72b..77784ff4a 100644 --- a/ethstats/ethstats.go +++ b/ethstats/ethstats.go @@ -379,7 +379,7 @@ func (s *Service) login(conn *websocket.Conn) error { protocol = fmt.Sprintf("eth/%d", eth.ProtocolVersions[0]) } else { network = fmt.Sprintf("%d", infos.Protocols["les"].(*eth.EthNodeInfo).Network) - protocol = fmt.Sprintf("les/%d", les.ProtocolVersions[0]) + protocol = fmt.Sprintf("les/%d", les.ClientProtocolVersions[0]) } auth := &authMsg{ Id: s.node, diff --git a/les/api_backend.go b/les/api_backend.go index 0d2d31b67..56f617a7d 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -174,8 +174,15 @@ func (b *LesApiBackend) AccountManager() *accounts.Manager { } func (b *LesApiBackend) BloomStatus() (uint64, uint64) { - return params.BloomBitsBlocks, 0 + if b.eth.bloomIndexer == nil { + return 0, 0 + } + sections, _, _ := b.eth.bloomIndexer.Sections() + return light.BloomTrieFrequency, sections } func (b *LesApiBackend) ServiceFilter(ctx context.Context, session *bloombits.MatcherSession) { + for i := 0; i < bloomFilterThreads; i++ { + go session.Multiplex(bloomRetrievalBatch, bloomRetrievalWait, b.eth.bloomRequests) + } } diff --git a/les/backend.go b/les/backend.go index 4c33417c0..333df920e 100644 --- a/les/backend.go +++ b/les/backend.go @@ -27,6 +27,7 @@ import ( "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/bloombits" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth/downloader" @@ -61,6 +62,9 @@ type LightEthereum struct { // DB interfaces chainDb ethdb.Database // Block chain database + bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests + bloomIndexer, chtIndexer, bloomTrieIndexer *core.ChainIndexer + ApiBackend *LesApiBackend eventMux *event.TypeMux @@ -87,47 +91,61 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { peers := newPeerSet() quitSync := make(chan struct{}) - eth := &LightEthereum{ - chainConfig: chainConfig, - chainDb: chainDb, - eventMux: ctx.EventMux, - peers: peers, - reqDist: newRequestDistributor(peers, quitSync), - accountManager: ctx.AccountManager, - engine: eth.CreateConsensusEngine(ctx, config, chainConfig, chainDb), - shutdownChan: make(chan bool), - networkId: config.NetworkId, + leth := &LightEthereum{ + chainConfig: chainConfig, + chainDb: chainDb, + eventMux: ctx.EventMux, + peers: peers, + reqDist: newRequestDistributor(peers, quitSync), + accountManager: ctx.AccountManager, + engine: eth.CreateConsensusEngine(ctx, config, chainConfig, chainDb), + shutdownChan: make(chan bool), + networkId: config.NetworkId, + bloomRequests: make(chan chan *bloombits.Retrieval), + bloomIndexer: eth.NewBloomIndexer(chainDb, light.BloomTrieFrequency), + chtIndexer: light.NewChtIndexer(chainDb, true), + bloomTrieIndexer: light.NewBloomTrieIndexer(chainDb, true), } - eth.relay = NewLesTxRelay(peers, eth.reqDist) - eth.serverPool = newServerPool(chainDb, quitSync, ð.wg) - eth.retriever = newRetrieveManager(peers, eth.reqDist, eth.serverPool) - eth.odr = NewLesOdr(chainDb, eth.retriever) - if eth.blockchain, err = light.NewLightChain(eth.odr, eth.chainConfig, eth.engine); err != nil { + leth.relay = NewLesTxRelay(peers, leth.reqDist) + leth.serverPool = newServerPool(chainDb, quitSync, &leth.wg) + leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool) + leth.odr = NewLesOdr(chainDb, leth.chtIndexer, leth.bloomTrieIndexer, leth.bloomIndexer, leth.retriever) + if leth.blockchain, err = light.NewLightChain(leth.odr, leth.chainConfig, leth.engine); err != nil { return nil, err } + leth.bloomIndexer.Start(leth.blockchain) // Rewind the chain in case of an incompatible config upgrade. if compat, ok := genesisErr.(*params.ConfigCompatError); ok { log.Warn("Rewinding chain to upgrade configuration", "err", compat) - eth.blockchain.SetHead(compat.RewindTo) + leth.blockchain.SetHead(compat.RewindTo) core.WriteChainConfig(chainDb, genesisHash, chainConfig) } - eth.txPool = light.NewTxPool(eth.chainConfig, eth.blockchain, eth.relay) - if eth.protocolManager, err = NewProtocolManager(eth.chainConfig, true, config.NetworkId, eth.eventMux, eth.engine, eth.peers, eth.blockchain, nil, chainDb, eth.odr, eth.relay, quitSync, ð.wg); err != nil { + leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay) + if leth.protocolManager, err = NewProtocolManager(leth.chainConfig, true, ClientProtocolVersions, config.NetworkId, leth.eventMux, leth.engine, leth.peers, leth.blockchain, nil, chainDb, leth.odr, leth.relay, quitSync, &leth.wg); err != nil { return nil, err } - eth.ApiBackend = &LesApiBackend{eth, nil} + leth.ApiBackend = &LesApiBackend{leth, nil} gpoParams := config.GPO if gpoParams.Default == nil { gpoParams.Default = config.GasPrice } - eth.ApiBackend.gpo = gasprice.NewOracle(eth.ApiBackend, gpoParams) - return eth, nil + leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams) + return leth, nil } -func lesTopic(genesisHash common.Hash) discv5.Topic { - return discv5.Topic("LES@" + common.Bytes2Hex(genesisHash.Bytes()[0:8])) +func lesTopic(genesisHash common.Hash, protocolVersion uint) discv5.Topic { + var name string + switch protocolVersion { + case lpv1: + name = "LES" + case lpv2: + name = "LES2" + default: + panic(nil) + } + return discv5.Topic(name + "@" + common.Bytes2Hex(genesisHash.Bytes()[0:8])) } type LightDummyAPI struct{} @@ -200,9 +218,13 @@ func (s *LightEthereum) Protocols() []p2p.Protocol { // Start implements node.Service, starting all internal goroutines needed by the // Ethereum protocol implementation. func (s *LightEthereum) Start(srvr *p2p.Server) error { + s.startBloomHandlers() log.Warn("Light client mode is an experimental feature") s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.networkId) - s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash())) + // search the topic belonging to the oldest supported protocol because + // servers always advertise all supported protocols + protocolVersion := ClientProtocolVersions[len(ClientProtocolVersions)-1] + s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion)) s.protocolManager.Start() return nil } @@ -211,6 +233,15 @@ func (s *LightEthereum) Start(srvr *p2p.Server) error { // Ethereum protocol. func (s *LightEthereum) Stop() error { s.odr.Stop() + if s.bloomIndexer != nil { + s.bloomIndexer.Close() + } + if s.chtIndexer != nil { + s.chtIndexer.Close() + } + if s.bloomTrieIndexer != nil { + s.bloomTrieIndexer.Close() + } s.blockchain.Stop() s.protocolManager.Stop() s.txPool.Stop() diff --git a/les/bloombits.go b/les/bloombits.go new file mode 100644 index 000000000..de233d751 --- /dev/null +++ b/les/bloombits.go @@ -0,0 +1,84 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "time" + + "github.com/ethereum/go-ethereum/common/bitutil" + "github.com/ethereum/go-ethereum/light" +) + +const ( + // bloomServiceThreads is the number of goroutines used globally by an Ethereum + // instance to service bloombits lookups for all running filters. + bloomServiceThreads = 16 + + // bloomFilterThreads is the number of goroutines used locally per filter to + // multiplex requests onto the global servicing goroutines. + bloomFilterThreads = 3 + + // bloomRetrievalBatch is the maximum number of bloom bit retrievals to service + // in a single batch. + bloomRetrievalBatch = 16 + + // bloomRetrievalWait is the maximum time to wait for enough bloom bit requests + // to accumulate request an entire batch (avoiding hysteresis). + bloomRetrievalWait = time.Microsecond * 100 +) + +// startBloomHandlers starts a batch of goroutines to accept bloom bit database +// retrievals from possibly a range of filters and serving the data to satisfy. +func (eth *LightEthereum) startBloomHandlers() { + for i := 0; i < bloomServiceThreads; i++ { + go func() { + for { + select { + case <-eth.shutdownChan: + return + + case request := <-eth.bloomRequests: + task := <-request + task.Bitsets = make([][]byte, len(task.Sections)) + compVectors, err := light.GetBloomBits(task.Context, eth.odr, task.Bit, task.Sections) + if err == nil { + for i := range task.Sections { + if blob, err := bitutil.DecompressBytes(compVectors[i], int(light.BloomTrieFrequency/8)); err == nil { + task.Bitsets[i] = blob + } else { + task.Error = err + } + } + } else { + task.Error = err + } + request <- task + } + } + }() + } +} + +const ( + // bloomConfirms is the number of confirmation blocks before a bloom section is + // considered probably final and its rotated bits are calculated. + bloomConfirms = 256 + + // bloomThrottling is the time to wait between processing two consecutive index + // sections. It's useful during chain upgrades to prevent disk overload. + bloomThrottling = 100 * time.Millisecond +) diff --git a/les/distributor.go b/les/distributor.go index e8ef5b02e..159fa4c73 100644 --- a/les/distributor.go +++ b/les/distributor.go @@ -191,7 +191,7 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) { for (len(d.peers) > 0 || elem == d.reqQueue.Front()) && elem != nil { req := elem.Value.(*distReq) canSend := false - for peer, _ := range d.peers { + for peer := range d.peers { if _, ok := checkedPeers[peer]; !ok && peer.canQueue() && req.canSend(peer) { canSend = true cost := req.getCost(peer) diff --git a/les/distributor_test.go b/les/distributor_test.go index 4e7f8bd29..55defb69b 100644 --- a/les/distributor_test.go +++ b/les/distributor_test.go @@ -124,7 +124,7 @@ func testRequestDistributor(t *testing.T, resend bool) { dist := newRequestDistributor(nil, stop) var peers [testDistPeerCount]*testDistPeer - for i, _ := range peers { + for i := range peers { peers[i] = &testDistPeer{} go peers[i].worker(t, !resend, stop) dist.registerTestPeer(peers[i]) diff --git a/les/fetcher.go b/les/fetcher.go index 4fc142f0f..3fc4df30b 100644 --- a/les/fetcher.go +++ b/les/fetcher.go @@ -117,16 +117,16 @@ func newLightFetcher(pm *ProtocolManager) *lightFetcher { maxConfirmedTd: big.NewInt(0), } pm.peers.notify(f) + + f.pm.wg.Add(1) go f.syncLoop() return f } // syncLoop is the main event loop of the light fetcher func (f *lightFetcher) syncLoop() { - f.pm.wg.Add(1) - defer f.pm.wg.Done() - requesting := false + defer f.pm.wg.Done() for { select { case <-f.pm.quitSync: diff --git a/les/handler.go b/les/handler.go index df7eb6af5..613fbb79f 100644 --- a/les/handler.go +++ b/les/handler.go @@ -18,6 +18,7 @@ package les import ( + "bytes" "encoding/binary" "errors" "fmt" @@ -35,6 +36,7 @@ import ( "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/discover" @@ -50,13 +52,14 @@ const ( ethVersion = 63 // equivalent eth version for the downloader - MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request - MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request - MaxReceiptFetch = 128 // Amount of transaction receipts to allow fetching per request - MaxCodeFetch = 64 // Amount of contract codes to allow fetching per request - MaxProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request - MaxHeaderProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request - MaxTxSend = 64 // Amount of transactions to be send per request + MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request + MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request + MaxReceiptFetch = 128 // Amount of transaction receipts to allow fetching per request + MaxCodeFetch = 64 // Amount of contract codes to allow fetching per request + MaxProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request + MaxHelperTrieProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request + MaxTxSend = 64 // Amount of transactions to be send per request + MaxTxStatus = 256 // Amount of transactions to queried per request disableClientRemovePeer = false ) @@ -86,8 +89,8 @@ type BlockChain interface { } type txPool interface { - // AddRemotes should add the given transactions to the pool. - AddRemotes([]*types.Transaction) error + AddRemotes(txs []*types.Transaction) []error + Status(hashes []common.Hash) []core.TxStatus } type ProtocolManager struct { @@ -125,7 +128,7 @@ type ProtocolManager struct { // NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable // with the ethereum network. -func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, networkId uint64, mux *event.TypeMux, engine consensus.Engine, peers *peerSet, blockchain BlockChain, txpool txPool, chainDb ethdb.Database, odr *LesOdr, txrelay *LesTxRelay, quitSync chan struct{}, wg *sync.WaitGroup) (*ProtocolManager, error) { +func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, protocolVersions []uint, networkId uint64, mux *event.TypeMux, engine consensus.Engine, peers *peerSet, blockchain BlockChain, txpool txPool, chainDb ethdb.Database, odr *LesOdr, txrelay *LesTxRelay, quitSync chan struct{}, wg *sync.WaitGroup) (*ProtocolManager, error) { // Create the protocol manager with the base fields manager := &ProtocolManager{ lightSync: lightSync, @@ -147,15 +150,16 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, network manager.retriever = odr.retriever manager.reqDist = odr.retriever.dist } + // Initiate a sub-protocol for every implemented version we can handle - manager.SubProtocols = make([]p2p.Protocol, 0, len(ProtocolVersions)) - for i, version := range ProtocolVersions { + manager.SubProtocols = make([]p2p.Protocol, 0, len(protocolVersions)) + for _, version := range protocolVersions { // Compatible, initialize the sub-protocol version := version // Closure for the run manager.SubProtocols = append(manager.SubProtocols, p2p.Protocol{ Name: "les", Version: version, - Length: ProtocolLengths[i], + Length: ProtocolLengths[version], Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error { var entry *poolEntry peer := manager.newPeer(int(version), networkId, p, rw) @@ -315,7 +319,7 @@ func (pm *ProtocolManager) handle(p *peer) error { } } -var reqList = []uint64{GetBlockHeadersMsg, GetBlockBodiesMsg, GetCodeMsg, GetReceiptsMsg, GetProofsMsg, SendTxMsg, GetHeaderProofsMsg} +var reqList = []uint64{GetBlockHeadersMsg, GetBlockBodiesMsg, GetCodeMsg, GetReceiptsMsg, GetProofsV1Msg, SendTxMsg, SendTxV2Msg, GetTxStatusMsg, GetHeaderProofsMsg, GetProofsV2Msg, GetHelperTrieProofsMsg} // handleMsg is invoked whenever an inbound message is received from a remote // peer. The remote connection is torn down upon returning any error. @@ -362,11 +366,23 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { // Block header query, collect the requested headers and reply case AnnounceMsg: p.Log().Trace("Received announce message") + if p.requestAnnounceType == announceTypeNone { + return errResp(ErrUnexpectedResponse, "") + } var req announceData if err := msg.Decode(&req); err != nil { return errResp(ErrDecode, "%v: %v", msg, err) } + + if p.requestAnnounceType == announceTypeSigned { + if err := req.checkSignature(p.pubKey); err != nil { + p.Log().Trace("Invalid announcement signature", "err", err) + return err + } + p.Log().Trace("Valid announcement signature") + } + p.Log().Trace("Announce message content", "number", req.Number, "hash", req.Hash, "td", req.Td, "reorg", req.ReorgDepth) if pm.fetcher != nil { pm.fetcher.announce(p, &req) @@ -655,7 +671,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { Obj: resp.Receipts, } - case GetProofsMsg: + case GetProofsV1Msg: p.Log().Trace("Received proofs request") // Decode the retrieval message var req struct { @@ -690,9 +706,10 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } } if tr != nil { - proof := tr.Prove(req.Key) + var proof light.NodeList + tr.Prove(req.Key, 0, &proof) proofs = append(proofs, proof) - bytes += len(proof) + bytes += proof.DataSize() } } } @@ -701,7 +718,67 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) return p.SendProofs(req.ReqID, bv, proofs) - case ProofsMsg: + case GetProofsV2Msg: + p.Log().Trace("Received les/2 proofs request") + // Decode the retrieval message + var req struct { + ReqID uint64 + Reqs []ProofReq + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Gather state data until the fetch or network limits is reached + var ( + lastBHash common.Hash + lastAccKey []byte + tr, str *trie.Trie + ) + reqCnt := len(req.Reqs) + if reject(uint64(reqCnt), MaxProofsFetch) { + return errResp(ErrRequestRejected, "") + } + + nodes := light.NewNodeSet() + + for _, req := range req.Reqs { + if nodes.DataSize() >= softResponseLimit { + break + } + if tr == nil || req.BHash != lastBHash { + if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { + tr, _ = trie.New(header.Root, pm.chainDb) + } else { + tr = nil + } + lastBHash = req.BHash + str = nil + } + if tr != nil { + if len(req.AccKey) > 0 { + if str == nil || !bytes.Equal(req.AccKey, lastAccKey) { + sdata := tr.Get(req.AccKey) + str = nil + var acc state.Account + if err := rlp.DecodeBytes(sdata, &acc); err == nil { + str, _ = trie.New(acc.Root, pm.chainDb) + } + lastAccKey = common.CopyBytes(req.AccKey) + } + if str != nil { + str.Prove(req.Key, req.FromLevel, nodes) + } + } else { + tr.Prove(req.Key, req.FromLevel, nodes) + } + } + } + proofs := nodes.NodeList() + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendProofsV2(req.ReqID, bv, proofs) + + case ProofsV1Msg: if pm.odr == nil { return errResp(ErrUnexpectedResponse, "") } @@ -710,14 +787,35 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { // A batch of merkle proofs arrived to one of our previous requests var resp struct { ReqID, BV uint64 - Data [][]rlp.RawValue + Data []light.NodeList } if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } p.fcServer.GotReply(resp.ReqID, resp.BV) deliverMsg = &Msg{ - MsgType: MsgProofs, + MsgType: MsgProofsV1, + ReqID: resp.ReqID, + Obj: resp.Data, + } + + case ProofsV2Msg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + + p.Log().Trace("Received les/2 proofs response") + // A batch of merkle proofs arrived to one of our previous requests + var resp struct { + ReqID, BV uint64 + Data light.NodeList + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.GotReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgProofsV2, ReqID: resp.ReqID, Obj: resp.Data, } @@ -738,22 +836,25 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { proofs []ChtResp ) reqCnt := len(req.Reqs) - if reject(uint64(reqCnt), MaxHeaderProofsFetch) { + if reject(uint64(reqCnt), MaxHelperTrieProofsFetch) { return errResp(ErrRequestRejected, "") } + trieDb := ethdb.NewTable(pm.chainDb, light.ChtTablePrefix) for _, req := range req.Reqs { if bytes >= softResponseLimit { break } if header := pm.blockchain.GetHeaderByNumber(req.BlockNum); header != nil { - if root := getChtRoot(pm.chainDb, req.ChtNum); root != (common.Hash{}) { - if tr, _ := trie.New(root, pm.chainDb); tr != nil { + sectionHead := core.GetCanonicalHash(pm.chainDb, (req.ChtNum+1)*light.ChtV1Frequency-1) + if root := light.GetChtRoot(pm.chainDb, req.ChtNum, sectionHead); root != (common.Hash{}) { + if tr, _ := trie.New(root, trieDb); tr != nil { var encNumber [8]byte binary.BigEndian.PutUint64(encNumber[:], req.BlockNum) - proof := tr.Prove(encNumber[:]) + var proof light.NodeList + tr.Prove(encNumber[:], 0, &proof) proofs = append(proofs, ChtResp{Header: header, Proof: proof}) - bytes += len(proof) + estHeaderRlpSize + bytes += proof.DataSize() + estHeaderRlpSize } } } @@ -762,6 +863,73 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) return p.SendHeaderProofs(req.ReqID, bv, proofs) + case GetHelperTrieProofsMsg: + p.Log().Trace("Received helper trie proof request") + // Decode the retrieval message + var req struct { + ReqID uint64 + Reqs []HelperTrieReq + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Gather state data until the fetch or network limits is reached + var ( + auxBytes int + auxData [][]byte + ) + reqCnt := len(req.Reqs) + if reject(uint64(reqCnt), MaxHelperTrieProofsFetch) { + return errResp(ErrRequestRejected, "") + } + + var ( + lastIdx uint64 + lastType uint + root common.Hash + tr *trie.Trie + ) + + nodes := light.NewNodeSet() + + for _, req := range req.Reqs { + if nodes.DataSize()+auxBytes >= softResponseLimit { + break + } + if tr == nil || req.HelperTrieType != lastType || req.TrieIdx != lastIdx { + var prefix string + root, prefix = pm.getHelperTrie(req.HelperTrieType, req.TrieIdx) + if root != (common.Hash{}) { + if t, err := trie.New(root, ethdb.NewTable(pm.chainDb, prefix)); err == nil { + tr = t + } + } + lastType = req.HelperTrieType + lastIdx = req.TrieIdx + } + if req.AuxReq == auxRoot { + var data []byte + if root != (common.Hash{}) { + data = root[:] + } + auxData = append(auxData, data) + auxBytes += len(data) + } else { + if tr != nil { + tr.Prove(req.Key, req.FromLevel, nodes) + } + if req.AuxReq != 0 { + data := pm.getHelperTrieAuxData(req) + auxData = append(auxData, data) + auxBytes += len(data) + } + } + } + proofs := nodes.NodeList() + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendHelperTrieProofs(req.ReqID, bv, HelperTrieResps{Proofs: proofs, AuxData: auxData}) + case HeaderProofsMsg: if pm.odr == nil { return errResp(ErrUnexpectedResponse, "") @@ -782,9 +950,30 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { Obj: resp.Data, } + case HelperTrieProofsMsg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + + p.Log().Trace("Received helper trie proof response") + var resp struct { + ReqID, BV uint64 + Data HelperTrieResps + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + + p.fcServer.GotReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgHelperTrieProofs, + ReqID: resp.ReqID, + Obj: resp.Data, + } + case SendTxMsg: if pm.txpool == nil { - return errResp(ErrUnexpectedResponse, "") + return errResp(ErrRequestRejected, "") } // Transactions arrived, parse all of them and deliver to the pool var txs []*types.Transaction @@ -795,14 +984,85 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { if reject(uint64(reqCnt), MaxTxSend) { return errResp(ErrRequestRejected, "") } + pm.txpool.AddRemotes(txs) - if err := pm.txpool.AddRemotes(txs); err != nil { - return errResp(ErrUnexpectedResponse, "msg: %v", err) + _, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + + case SendTxV2Msg: + if pm.txpool == nil { + return errResp(ErrRequestRejected, "") + } + // Transactions arrived, parse all of them and deliver to the pool + var req struct { + ReqID uint64 + Txs []*types.Transaction + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + reqCnt := len(req.Txs) + if reject(uint64(reqCnt), MaxTxSend) { + return errResp(ErrRequestRejected, "") } - _, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + hashes := make([]common.Hash, len(req.Txs)) + for i, tx := range req.Txs { + hashes[i] = tx.Hash() + } + stats := pm.txStatus(hashes) + for i, stat := range stats { + if stat.Status == core.TxStatusUnknown { + if errs := pm.txpool.AddRemotes([]*types.Transaction{req.Txs[i]}); errs[0] != nil { + stats[i].Error = errs[0] + continue + } + stats[i] = pm.txStatus([]common.Hash{hashes[i]})[0] + } + } + + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendTxStatus(req.ReqID, bv, stats) + + case GetTxStatusMsg: + if pm.txpool == nil { + return errResp(ErrUnexpectedResponse, "") + } + // Transactions arrived, parse all of them and deliver to the pool + var req struct { + ReqID uint64 + Hashes []common.Hash + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + reqCnt := len(req.Hashes) + if reject(uint64(reqCnt), MaxTxStatus) { + return errResp(ErrRequestRejected, "") + } + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + + return p.SendTxStatus(req.ReqID, bv, pm.txStatus(req.Hashes)) + + case TxStatusMsg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + + p.Log().Trace("Received tx status response") + var resp struct { + ReqID, BV uint64 + Status []core.TxStatus + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + + p.fcServer.GotReply(resp.ReqID, resp.BV) + default: p.Log().Trace("Received unknown message", "code", msg.Code) return errResp(ErrInvalidMsgCode, "%v", msg.Code) @@ -820,6 +1080,49 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return nil } +// getHelperTrie returns the post-processed trie root for the given trie ID and section index +func (pm *ProtocolManager) getHelperTrie(id uint, idx uint64) (common.Hash, string) { + switch id { + case htCanonical: + sectionHead := core.GetCanonicalHash(pm.chainDb, (idx+1)*light.ChtFrequency-1) + return light.GetChtV2Root(pm.chainDb, idx, sectionHead), light.ChtTablePrefix + case htBloomBits: + sectionHead := core.GetCanonicalHash(pm.chainDb, (idx+1)*light.BloomTrieFrequency-1) + return light.GetBloomTrieRoot(pm.chainDb, idx, sectionHead), light.BloomTrieTablePrefix + } + return common.Hash{}, "" +} + +// getHelperTrieAuxData returns requested auxiliary data for the given HelperTrie request +func (pm *ProtocolManager) getHelperTrieAuxData(req HelperTrieReq) []byte { + if req.HelperTrieType == htCanonical && req.AuxReq == auxHeader { + if len(req.Key) != 8 { + return nil + } + blockNum := binary.BigEndian.Uint64(req.Key) + hash := core.GetCanonicalHash(pm.chainDb, blockNum) + return core.GetHeaderRLP(pm.chainDb, hash, blockNum) + } + return nil +} + +func (pm *ProtocolManager) txStatus(hashes []common.Hash) []txStatus { + stats := make([]txStatus, len(hashes)) + for i, stat := range pm.txpool.Status(hashes) { + // Save the status we've got from the transaction pool + stats[i].Status = stat + + // If the transaction is unknown to the pool, try looking it up locally + if stat == core.TxStatusUnknown { + if block, number, index := core.GetTxLookupEntry(pm.chainDb, hashes[i]); block != (common.Hash{}) { + stats[i].Status = core.TxStatusIncluded + stats[i].Lookup = &core.TxLookupEntry{BlockHash: block, BlockIndex: number, Index: index} + } + } + } + return stats +} + // NodeInfo retrieves some protocol metadata about the running host node. func (self *ProtocolManager) NodeInfo() *eth.EthNodeInfo { return ð.EthNodeInfo{ diff --git a/les/handler_test.go b/les/handler_test.go index b1f1aa095..9e63e15a6 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -17,8 +17,11 @@ package les import ( + "bytes" + "math/big" "math/rand" "testing" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" @@ -26,7 +29,9 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -39,9 +44,29 @@ func expectResponse(r p2p.MsgReader, msgcode, reqID, bv uint64, data interface{} return p2p.ExpectMsg(r, msgcode, resp{reqID, bv, data}) } +func testCheckProof(t *testing.T, exp *light.NodeSet, got light.NodeList) { + if exp.KeyCount() > len(got) { + t.Errorf("proof has fewer nodes than expected") + return + } + if exp.KeyCount() < len(got) { + t.Errorf("proof has more nodes than expected") + return + } + for _, node := range got { + n, _ := exp.Get(crypto.Keccak256(node)) + if !bytes.Equal(n, node) { + t.Errorf("proof contents mismatch") + return + } + } +} + // Tests that block headers can be retrieved from a remote chain based on user queries. func TestGetBlockHeadersLes1(t *testing.T) { testGetBlockHeaders(t, 1) } +func TestGetBlockHeadersLes2(t *testing.T) { testGetBlockHeaders(t, 2) } + func testGetBlockHeaders(t *testing.T, protocol int) { db, _ := ethdb.NewMemDatabase() pm := newTestProtocolManagerMust(t, false, downloader.MaxHashFetch+15, nil, nil, nil, db) @@ -171,6 +196,8 @@ func testGetBlockHeaders(t *testing.T, protocol int) { // Tests that block contents can be retrieved from a remote chain based on their hashes. func TestGetBlockBodiesLes1(t *testing.T) { testGetBlockBodies(t, 1) } +func TestGetBlockBodiesLes2(t *testing.T) { testGetBlockBodies(t, 2) } + func testGetBlockBodies(t *testing.T, protocol int) { db, _ := ethdb.NewMemDatabase() pm := newTestProtocolManagerMust(t, false, downloader.MaxBlockFetch+15, nil, nil, nil, db) @@ -247,6 +274,8 @@ func testGetBlockBodies(t *testing.T, protocol int) { // Tests that the contract codes can be retrieved based on account addresses. func TestGetCodeLes1(t *testing.T) { testGetCode(t, 1) } +func TestGetCodeLes2(t *testing.T) { testGetCode(t, 2) } + func testGetCode(t *testing.T, protocol int) { // Assemble the test environment db, _ := ethdb.NewMemDatabase() @@ -280,6 +309,8 @@ func testGetCode(t *testing.T, protocol int) { // Tests that the transaction receipts can be retrieved based on hashes. func TestGetReceiptLes1(t *testing.T) { testGetReceipt(t, 1) } +func TestGetReceiptLes2(t *testing.T) { testGetReceipt(t, 2) } + func testGetReceipt(t *testing.T, protocol int) { // Assemble the test environment db, _ := ethdb.NewMemDatabase() @@ -307,6 +338,8 @@ func testGetReceipt(t *testing.T, protocol int) { // Tests that trie merkle proofs can be retrieved func TestGetProofsLes1(t *testing.T) { testGetProofs(t, 1) } +func TestGetProofsLes2(t *testing.T) { testGetProofs(t, 2) } + func testGetProofs(t *testing.T, protocol int) { // Assemble the test environment db, _ := ethdb.NewMemDatabase() @@ -315,8 +348,11 @@ func testGetProofs(t *testing.T, protocol int) { peer, _ := newTestPeer(t, "peer", protocol, pm, true) defer peer.close() - var proofreqs []ProofReq - var proofs [][]rlp.RawValue + var ( + proofreqs []ProofReq + proofsV1 [][]rlp.RawValue + ) + proofsV2 := light.NewNodeSet() accounts := []common.Address{testBankAddress, acc1Addr, acc2Addr, {}} for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ { @@ -331,14 +367,135 @@ func testGetProofs(t *testing.T, protocol int) { } proofreqs = append(proofreqs, req) - proof := trie.Prove(crypto.Keccak256(acc[:])) - proofs = append(proofs, proof) + switch protocol { + case 1: + var proof light.NodeList + trie.Prove(crypto.Keccak256(acc[:]), 0, &proof) + proofsV1 = append(proofsV1, proof) + case 2: + trie.Prove(crypto.Keccak256(acc[:]), 0, proofsV2) + } } } // Send the proof request and verify the response - cost := peer.GetRequestCost(GetProofsMsg, len(proofreqs)) - sendRequest(peer.app, GetProofsMsg, 42, cost, proofreqs) - if err := expectResponse(peer.app, ProofsMsg, 42, testBufLimit, proofs); err != nil { - t.Errorf("proofs mismatch: %v", err) + switch protocol { + case 1: + cost := peer.GetRequestCost(GetProofsV1Msg, len(proofreqs)) + sendRequest(peer.app, GetProofsV1Msg, 42, cost, proofreqs) + if err := expectResponse(peer.app, ProofsV1Msg, 42, testBufLimit, proofsV1); err != nil { + t.Errorf("proofs mismatch: %v", err) + } + case 2: + cost := peer.GetRequestCost(GetProofsV2Msg, len(proofreqs)) + sendRequest(peer.app, GetProofsV2Msg, 42, cost, proofreqs) + msg, err := peer.app.ReadMsg() + if err != nil { + t.Errorf("Message read error: %v", err) + } + var resp struct { + ReqID, BV uint64 + Data light.NodeList + } + if err := msg.Decode(&resp); err != nil { + t.Errorf("reply decode error: %v", err) + } + if msg.Code != ProofsV2Msg { + t.Errorf("Message code mismatch") + } + if resp.ReqID != 42 { + t.Errorf("ReqID mismatch") + } + if resp.BV != testBufLimit { + t.Errorf("BV mismatch") + } + testCheckProof(t, proofsV2, resp.Data) + } +} + +func TestTransactionStatusLes2(t *testing.T) { + db, _ := ethdb.NewMemDatabase() + pm := newTestProtocolManagerMust(t, false, 0, nil, nil, nil, db) + chain := pm.blockchain.(*core.BlockChain) + txpool := core.NewTxPool(core.DefaultTxPoolConfig, params.TestChainConfig, chain) + pm.txpool = txpool + peer, _ := newTestPeer(t, "peer", 2, pm, true) + defer peer.close() + + var reqID uint64 + + test := func(tx *types.Transaction, send bool, expStatus txStatus) { + reqID++ + if send { + cost := peer.GetRequestCost(SendTxV2Msg, 1) + sendRequest(peer.app, SendTxV2Msg, reqID, cost, types.Transactions{tx}) + } else { + cost := peer.GetRequestCost(GetTxStatusMsg, 1) + sendRequest(peer.app, GetTxStatusMsg, reqID, cost, []common.Hash{tx.Hash()}) + } + if err := expectResponse(peer.app, TxStatusMsg, reqID, testBufLimit, []txStatus{expStatus}); err != nil { + t.Errorf("transaction status mismatch") + } + } + + signer := types.HomesteadSigner{} + + // test error status by sending an underpriced transaction + tx0, _ := types.SignTx(types.NewTransaction(0, acc1Addr, big.NewInt(10000), bigTxGas, nil, nil), signer, testBankKey) + test(tx0, true, txStatus{Status: core.TxStatusUnknown, Error: core.ErrUnderpriced}) + + tx1, _ := types.SignTx(types.NewTransaction(0, acc1Addr, big.NewInt(10000), bigTxGas, big.NewInt(100000000000), nil), signer, testBankKey) + test(tx1, false, txStatus{Status: core.TxStatusUnknown}) // query before sending, should be unknown + test(tx1, true, txStatus{Status: core.TxStatusPending}) // send valid processable tx, should return pending + test(tx1, true, txStatus{Status: core.TxStatusPending}) // adding it again should not return an error + + tx2, _ := types.SignTx(types.NewTransaction(1, acc1Addr, big.NewInt(10000), bigTxGas, big.NewInt(100000000000), nil), signer, testBankKey) + tx3, _ := types.SignTx(types.NewTransaction(2, acc1Addr, big.NewInt(10000), bigTxGas, big.NewInt(100000000000), nil), signer, testBankKey) + // send transactions in the wrong order, tx3 should be queued + test(tx3, true, txStatus{Status: core.TxStatusQueued}) + test(tx2, true, txStatus{Status: core.TxStatusPending}) + // query again, now tx3 should be pending too + test(tx3, false, txStatus{Status: core.TxStatusPending}) + + // generate and add a block with tx1 and tx2 included + gchain, _ := core.GenerateChain(params.TestChainConfig, chain.GetBlockByNumber(0), db, 1, func(i int, block *core.BlockGen) { + block.AddTx(tx1) + block.AddTx(tx2) + }) + if _, err := chain.InsertChain(gchain); err != nil { + panic(err) + } + // wait until TxPool processes the inserted block + for i := 0; i < 10; i++ { + if pending, _ := txpool.Stats(); pending == 1 { + break + } + time.Sleep(100 * time.Millisecond) + } + if pending, _ := txpool.Stats(); pending != 1 { + t.Fatalf("pending count mismatch: have %d, want 1", pending) + } + + // check if their status is included now + block1hash := core.GetCanonicalHash(db, 1) + test(tx1, false, txStatus{Status: core.TxStatusIncluded, Lookup: &core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}}) + test(tx2, false, txStatus{Status: core.TxStatusIncluded, Lookup: &core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}}) + + // create a reorg that rolls them back + gchain, _ = core.GenerateChain(params.TestChainConfig, chain.GetBlockByNumber(0), db, 2, func(i int, block *core.BlockGen) {}) + if _, err := chain.InsertChain(gchain); err != nil { + panic(err) + } + // wait until TxPool processes the reorg + for i := 0; i < 10; i++ { + if pending, _ := txpool.Stats(); pending == 3 { + break + } + time.Sleep(100 * time.Millisecond) + } + if pending, _ := txpool.Stats(); pending != 3 { + t.Fatalf("pending count mismatch: have %d, want 3", pending) } + // check if their status is pending again + test(tx1, false, txStatus{Status: core.TxStatusPending}) + test(tx2, false, txStatus{Status: core.TxStatusPending}) } diff --git a/les/helper_test.go b/les/helper_test.go index b33454e1d..a06f84cca 100644 --- a/les/helper_test.go +++ b/les/helper_test.go @@ -43,7 +43,7 @@ import ( var ( testBankKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") testBankAddress = crypto.PubkeyToAddress(testBankKey.PublicKey) - testBankFunds = big.NewInt(1000000) + testBankFunds = big.NewInt(1000000000000000000) acc1Key, _ = crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a") acc2Key, _ = crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee") @@ -156,7 +156,13 @@ func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *cor chain = blockchain } - pm, err := NewProtocolManager(gspec.Config, lightSync, NetworkId, evmux, engine, peers, chain, nil, db, odr, nil, make(chan struct{}), new(sync.WaitGroup)) + var protocolVersions []uint + if lightSync { + protocolVersions = ClientProtocolVersions + } else { + protocolVersions = ServerProtocolVersions + } + pm, err := NewProtocolManager(gspec.Config, lightSync, protocolVersions, NetworkId, evmux, engine, peers, chain, nil, db, odr, nil, make(chan struct{}), new(sync.WaitGroup)) if err != nil { return nil, err } diff --git a/les/odr.go b/les/odr.go index 3f7584b48..f8412aaad 100644 --- a/les/odr.go +++ b/les/odr.go @@ -19,6 +19,7 @@ package les import ( "context" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/log" @@ -26,33 +27,56 @@ import ( // LesOdr implements light.OdrBackend type LesOdr struct { - db ethdb.Database - stop chan struct{} - retriever *retrieveManager + db ethdb.Database + chtIndexer, bloomTrieIndexer, bloomIndexer *core.ChainIndexer + retriever *retrieveManager + stop chan struct{} } -func NewLesOdr(db ethdb.Database, retriever *retrieveManager) *LesOdr { +func NewLesOdr(db ethdb.Database, chtIndexer, bloomTrieIndexer, bloomIndexer *core.ChainIndexer, retriever *retrieveManager) *LesOdr { return &LesOdr{ - db: db, - retriever: retriever, - stop: make(chan struct{}), + db: db, + chtIndexer: chtIndexer, + bloomTrieIndexer: bloomTrieIndexer, + bloomIndexer: bloomIndexer, + retriever: retriever, + stop: make(chan struct{}), } } +// Stop cancels all pending retrievals func (odr *LesOdr) Stop() { close(odr.stop) } +// Database returns the backing database func (odr *LesOdr) Database() ethdb.Database { return odr.db } +// ChtIndexer returns the CHT chain indexer +func (odr *LesOdr) ChtIndexer() *core.ChainIndexer { + return odr.chtIndexer +} + +// BloomTrieIndexer returns the bloom trie chain indexer +func (odr *LesOdr) BloomTrieIndexer() *core.ChainIndexer { + return odr.bloomTrieIndexer +} + +// BloomIndexer returns the bloombits chain indexer +func (odr *LesOdr) BloomIndexer() *core.ChainIndexer { + return odr.bloomIndexer +} + const ( MsgBlockBodies = iota MsgCode MsgReceipts - MsgProofs + MsgProofsV1 + MsgProofsV2 MsgHeaderProofs + MsgHelperTrieProofs ) // Msg encodes a LES message that delivers reply data for a request @@ -64,7 +88,7 @@ type Msg struct { // Retrieve tries to fetch an object from the LES network. // If the network retrieval was successful, it stores the object in local db. -func (self *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err error) { +func (odr *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err error) { lreq := LesRequest(req) reqID := genReqID() @@ -84,9 +108,9 @@ func (self *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err err }, } - if err = self.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(self.db, msg) }); err == nil { + if err = odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err == nil { // retrieved from network, store in db - req.StoreResult(self.db) + req.StoreResult(odr.db) } else { log.Debug("Failed to retrieve data from network", "err", err) } diff --git a/les/odr_requests.go b/les/odr_requests.go index 1f853b341..937a4f1d9 100644 --- a/les/odr_requests.go +++ b/les/odr_requests.go @@ -36,13 +36,15 @@ import ( var ( errInvalidMessageType = errors.New("invalid message type") - errMultipleEntries = errors.New("multiple response entries") + errInvalidEntryCount = errors.New("invalid number of response entries") errHeaderUnavailable = errors.New("header unavailable") errTxHashMismatch = errors.New("transaction hash mismatch") errUncleHashMismatch = errors.New("uncle hash mismatch") errReceiptHashMismatch = errors.New("receipt hash mismatch") errDataHashMismatch = errors.New("data hash mismatch") errCHTHashMismatch = errors.New("cht hash mismatch") + errCHTNumberMismatch = errors.New("cht number mismatch") + errUselessNodes = errors.New("useless nodes in merkle proof nodeset") ) type LesOdrRequest interface { @@ -64,6 +66,8 @@ func LesRequest(req light.OdrRequest) LesOdrRequest { return (*CodeRequest)(r) case *light.ChtRequest: return (*ChtRequest)(r) + case *light.BloomRequest: + return (*BloomRequest)(r) default: return nil } @@ -101,7 +105,7 @@ func (r *BlockRequest) Validate(db ethdb.Database, msg *Msg) error { } bodies := msg.Obj.([]*types.Body) if len(bodies) != 1 { - return errMultipleEntries + return errInvalidEntryCount } body := bodies[0] @@ -157,7 +161,7 @@ func (r *ReceiptsRequest) Validate(db ethdb.Database, msg *Msg) error { } receipts := msg.Obj.([]types.Receipts) if len(receipts) != 1 { - return errMultipleEntries + return errInvalidEntryCount } receipt := receipts[0] @@ -186,7 +190,14 @@ type TrieRequest light.TrieRequest // GetCost returns the cost of the given ODR request according to the serving // peer's cost table (implementation of LesOdrRequest) func (r *TrieRequest) GetCost(peer *peer) uint64 { - return peer.GetRequestCost(GetProofsMsg, 1) + switch peer.version { + case lpv1: + return peer.GetRequestCost(GetProofsV1Msg, 1) + case lpv2: + return peer.GetRequestCost(GetProofsV2Msg, 1) + default: + panic(nil) + } } // CanSend tells if a certain peer is suitable for serving the given request @@ -197,12 +208,12 @@ func (r *TrieRequest) CanSend(peer *peer) bool { // Request sends an ODR request to the LES network (implementation of LesOdrRequest) func (r *TrieRequest) Request(reqID uint64, peer *peer) error { peer.Log().Debug("Requesting trie proof", "root", r.Id.Root, "key", r.Key) - req := &ProofReq{ + req := ProofReq{ BHash: r.Id.BlockHash, AccKey: r.Id.AccKey, Key: r.Key, } - return peer.RequestProofs(reqID, r.GetCost(peer), []*ProofReq{req}) + return peer.RequestProofs(reqID, r.GetCost(peer), []ProofReq{req}) } // Valid processes an ODR request reply message from the LES network @@ -211,20 +222,38 @@ func (r *TrieRequest) Request(reqID uint64, peer *peer) error { func (r *TrieRequest) Validate(db ethdb.Database, msg *Msg) error { log.Debug("Validating trie proof", "root", r.Id.Root, "key", r.Key) - // Ensure we have a correct message with a single proof - if msg.MsgType != MsgProofs { + switch msg.MsgType { + case MsgProofsV1: + proofs := msg.Obj.([]light.NodeList) + if len(proofs) != 1 { + return errInvalidEntryCount + } + nodeSet := proofs[0].NodeSet() + // Verify the proof and store if checks out + if _, err, _ := trie.VerifyProof(r.Id.Root, r.Key, nodeSet); err != nil { + return fmt.Errorf("merkle proof verification failed: %v", err) + } + r.Proof = nodeSet + return nil + + case MsgProofsV2: + proofs := msg.Obj.(light.NodeList) + // Verify the proof and store if checks out + nodeSet := proofs.NodeSet() + reads := &readTraceDB{db: nodeSet} + if _, err, _ := trie.VerifyProof(r.Id.Root, r.Key, reads); err != nil { + return fmt.Errorf("merkle proof verification failed: %v", err) + } + // check if all nodes have been read by VerifyProof + if len(reads.reads) != nodeSet.KeyCount() { + return errUselessNodes + } + r.Proof = nodeSet + return nil + + default: return errInvalidMessageType } - proofs := msg.Obj.([][]rlp.RawValue) - if len(proofs) != 1 { - return errMultipleEntries - } - // Verify the proof and store if checks out - if _, err := trie.VerifyProof(r.Id.Root, r.Key, proofs[0]); err != nil { - return fmt.Errorf("merkle proof verification failed: %v", err) - } - r.Proof = proofs[0] - return nil } type CodeReq struct { @@ -249,11 +278,11 @@ func (r *CodeRequest) CanSend(peer *peer) bool { // Request sends an ODR request to the LES network (implementation of LesOdrRequest) func (r *CodeRequest) Request(reqID uint64, peer *peer) error { peer.Log().Debug("Requesting code data", "hash", r.Hash) - req := &CodeReq{ + req := CodeReq{ BHash: r.Id.BlockHash, AccKey: r.Id.AccKey, } - return peer.RequestCode(reqID, r.GetCost(peer), []*CodeReq{req}) + return peer.RequestCode(reqID, r.GetCost(peer), []CodeReq{req}) } // Valid processes an ODR request reply message from the LES network @@ -268,7 +297,7 @@ func (r *CodeRequest) Validate(db ethdb.Database, msg *Msg) error { } reply := msg.Obj.([][]byte) if len(reply) != 1 { - return errMultipleEntries + return errInvalidEntryCount } data := reply[0] @@ -280,10 +309,36 @@ func (r *CodeRequest) Validate(db ethdb.Database, msg *Msg) error { return nil } +const ( + // helper trie type constants + htCanonical = iota // Canonical hash trie + htBloomBits // BloomBits trie + + // applicable for all helper trie requests + auxRoot = 1 + // applicable for htCanonical + auxHeader = 2 +) + +type HelperTrieReq struct { + HelperTrieType uint + TrieIdx uint64 + Key []byte + FromLevel, AuxReq uint +} + +type HelperTrieResps struct { // describes all responses, not just a single one + Proofs light.NodeList + AuxData [][]byte +} + +// legacy LES/1 type ChtReq struct { - ChtNum, BlockNum, FromLevel uint64 + ChtNum, BlockNum uint64 + FromLevel uint } +// legacy LES/1 type ChtResp struct { Header *types.Header Proof []rlp.RawValue @@ -295,7 +350,14 @@ type ChtRequest light.ChtRequest // GetCost returns the cost of the given ODR request according to the serving // peer's cost table (implementation of LesOdrRequest) func (r *ChtRequest) GetCost(peer *peer) uint64 { - return peer.GetRequestCost(GetHeaderProofsMsg, 1) + switch peer.version { + case lpv1: + return peer.GetRequestCost(GetHeaderProofsMsg, 1) + case lpv2: + return peer.GetRequestCost(GetHelperTrieProofsMsg, 1) + default: + panic(nil) + } } // CanSend tells if a certain peer is suitable for serving the given request @@ -303,17 +365,21 @@ func (r *ChtRequest) CanSend(peer *peer) bool { peer.lock.RLock() defer peer.lock.RUnlock() - return r.ChtNum <= (peer.headInfo.Number-light.ChtConfirmations)/light.ChtFrequency + return peer.headInfo.Number >= light.HelperTrieConfirmations && r.ChtNum <= (peer.headInfo.Number-light.HelperTrieConfirmations)/light.ChtFrequency } // Request sends an ODR request to the LES network (implementation of LesOdrRequest) func (r *ChtRequest) Request(reqID uint64, peer *peer) error { peer.Log().Debug("Requesting CHT", "cht", r.ChtNum, "block", r.BlockNum) - req := &ChtReq{ - ChtNum: r.ChtNum, - BlockNum: r.BlockNum, + var encNum [8]byte + binary.BigEndian.PutUint64(encNum[:], r.BlockNum) + req := HelperTrieReq{ + HelperTrieType: htCanonical, + TrieIdx: r.ChtNum, + Key: encNum[:], + AuxReq: auxHeader, } - return peer.RequestHeaderProofs(reqID, r.GetCost(peer), []*ChtReq{req}) + return peer.RequestHelperTrieProofs(reqID, r.GetCost(peer), []HelperTrieReq{req}) } // Valid processes an ODR request reply message from the LES network @@ -322,35 +388,179 @@ func (r *ChtRequest) Request(reqID uint64, peer *peer) error { func (r *ChtRequest) Validate(db ethdb.Database, msg *Msg) error { log.Debug("Validating CHT", "cht", r.ChtNum, "block", r.BlockNum) - // Ensure we have a correct message with a single proof element - if msg.MsgType != MsgHeaderProofs { + switch msg.MsgType { + case MsgHeaderProofs: // LES/1 backwards compatibility + proofs := msg.Obj.([]ChtResp) + if len(proofs) != 1 { + return errInvalidEntryCount + } + proof := proofs[0] + + // Verify the CHT + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], r.BlockNum) + + value, err, _ := trie.VerifyProof(r.ChtRoot, encNumber[:], light.NodeList(proof.Proof).NodeSet()) + if err != nil { + return err + } + var node light.ChtNode + if err := rlp.DecodeBytes(value, &node); err != nil { + return err + } + if node.Hash != proof.Header.Hash() { + return errCHTHashMismatch + } + // Verifications passed, store and return + r.Header = proof.Header + r.Proof = light.NodeList(proof.Proof).NodeSet() + r.Td = node.Td + case MsgHelperTrieProofs: + resp := msg.Obj.(HelperTrieResps) + if len(resp.AuxData) != 1 { + return errInvalidEntryCount + } + nodeSet := resp.Proofs.NodeSet() + headerEnc := resp.AuxData[0] + if len(headerEnc) == 0 { + return errHeaderUnavailable + } + header := new(types.Header) + if err := rlp.DecodeBytes(headerEnc, header); err != nil { + return errHeaderUnavailable + } + + // Verify the CHT + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], r.BlockNum) + + reads := &readTraceDB{db: nodeSet} + value, err, _ := trie.VerifyProof(r.ChtRoot, encNumber[:], reads) + if err != nil { + return fmt.Errorf("merkle proof verification failed: %v", err) + } + if len(reads.reads) != nodeSet.KeyCount() { + return errUselessNodes + } + + var node light.ChtNode + if err := rlp.DecodeBytes(value, &node); err != nil { + return err + } + if node.Hash != header.Hash() { + return errCHTHashMismatch + } + if r.BlockNum != header.Number.Uint64() { + return errCHTNumberMismatch + } + // Verifications passed, store and return + r.Header = header + r.Proof = nodeSet + r.Td = node.Td + default: return errInvalidMessageType } - proofs := msg.Obj.([]ChtResp) - if len(proofs) != 1 { - return errMultipleEntries - } - proof := proofs[0] + return nil +} - // Verify the CHT - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], r.BlockNum) +type BloomReq struct { + BloomTrieNum, BitIdx, SectionIdx, FromLevel uint64 +} - value, err := trie.VerifyProof(r.ChtRoot, encNumber[:], proof.Proof) - if err != nil { - return err +// ODR request type for requesting headers by Canonical Hash Trie, see LesOdrRequest interface +type BloomRequest light.BloomRequest + +// GetCost returns the cost of the given ODR request according to the serving +// peer's cost table (implementation of LesOdrRequest) +func (r *BloomRequest) GetCost(peer *peer) uint64 { + return peer.GetRequestCost(GetHelperTrieProofsMsg, len(r.SectionIdxList)) +} + +// CanSend tells if a certain peer is suitable for serving the given request +func (r *BloomRequest) CanSend(peer *peer) bool { + peer.lock.RLock() + defer peer.lock.RUnlock() + + if peer.version < lpv2 { + return false } - var node light.ChtNode - if err := rlp.DecodeBytes(value, &node); err != nil { - return err + return peer.headInfo.Number >= light.HelperTrieConfirmations && r.BloomTrieNum <= (peer.headInfo.Number-light.HelperTrieConfirmations)/light.BloomTrieFrequency +} + +// Request sends an ODR request to the LES network (implementation of LesOdrRequest) +func (r *BloomRequest) Request(reqID uint64, peer *peer) error { + peer.Log().Debug("Requesting BloomBits", "bloomTrie", r.BloomTrieNum, "bitIdx", r.BitIdx, "sections", r.SectionIdxList) + reqs := make([]HelperTrieReq, len(r.SectionIdxList)) + + var encNumber [10]byte + binary.BigEndian.PutUint16(encNumber[0:2], uint16(r.BitIdx)) + + for i, sectionIdx := range r.SectionIdxList { + binary.BigEndian.PutUint64(encNumber[2:10], sectionIdx) + reqs[i] = HelperTrieReq{ + HelperTrieType: htBloomBits, + TrieIdx: r.BloomTrieNum, + Key: common.CopyBytes(encNumber[:]), + } } - if node.Hash != proof.Header.Hash() { - return errCHTHashMismatch + return peer.RequestHelperTrieProofs(reqID, r.GetCost(peer), reqs) +} + +// Valid processes an ODR request reply message from the LES network +// returns true and stores results in memory if the message was a valid reply +// to the request (implementation of LesOdrRequest) +func (r *BloomRequest) Validate(db ethdb.Database, msg *Msg) error { + log.Debug("Validating BloomBits", "bloomTrie", r.BloomTrieNum, "bitIdx", r.BitIdx, "sections", r.SectionIdxList) + + // Ensure we have a correct message with a single proof element + if msg.MsgType != MsgHelperTrieProofs { + return errInvalidMessageType + } + resps := msg.Obj.(HelperTrieResps) + proofs := resps.Proofs + nodeSet := proofs.NodeSet() + reads := &readTraceDB{db: nodeSet} + + r.BloomBits = make([][]byte, len(r.SectionIdxList)) + + // Verify the proofs + var encNumber [10]byte + binary.BigEndian.PutUint16(encNumber[0:2], uint16(r.BitIdx)) + + for i, idx := range r.SectionIdxList { + binary.BigEndian.PutUint64(encNumber[2:10], idx) + value, err, _ := trie.VerifyProof(r.BloomTrieRoot, encNumber[:], reads) + if err != nil { + return err + } + r.BloomBits[i] = value } - // Verifications passed, store and return - r.Header = proof.Header - r.Proof = proof.Proof - r.Td = node.Td + if len(reads.reads) != nodeSet.KeyCount() { + return errUselessNodes + } + r.Proofs = nodeSet return nil } + +// readTraceDB stores the keys of database reads. We use this to check that received node +// sets contain only the trie nodes necessary to make proofs pass. +type readTraceDB struct { + db trie.DatabaseReader + reads map[string]struct{} +} + +// Get returns a stored node +func (db *readTraceDB) Get(k []byte) ([]byte, error) { + if db.reads == nil { + db.reads = make(map[string]struct{}) + } + db.reads[string(k)] = struct{}{} + return db.db.Get(k) +} + +// Has returns true if the node set contains the given key +func (db *readTraceDB) Has(key []byte) (bool, error) { + _, err := db.Get(key) + return err == nil, nil +} diff --git a/les/odr_test.go b/les/odr_test.go index f56c4036d..865f5d83e 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/params" @@ -39,6 +40,8 @@ type odrTestFn func(ctx context.Context, db ethdb.Database, config *params.Chain func TestOdrGetBlockLes1(t *testing.T) { testOdr(t, 1, 1, odrGetBlock) } +func TestOdrGetBlockLes2(t *testing.T) { testOdr(t, 2, 1, odrGetBlock) } + func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var block *types.Block if bc != nil { @@ -55,6 +58,8 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon func TestOdrGetReceiptsLes1(t *testing.T) { testOdr(t, 1, 1, odrGetReceipts) } +func TestOdrGetReceiptsLes2(t *testing.T) { testOdr(t, 2, 1, odrGetReceipts) } + func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var receipts types.Receipts if bc != nil { @@ -71,6 +76,8 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.Chain func TestOdrAccountsLes1(t *testing.T) { testOdr(t, 1, 1, odrAccounts) } +func TestOdrAccountsLes2(t *testing.T) { testOdr(t, 2, 1, odrAccounts) } + func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr} @@ -100,6 +107,8 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon func TestOdrContractCallLes1(t *testing.T) { testOdr(t, 1, 2, odrContractCall) } +func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, odrContractCall) } + type callmsg struct { types.Message } @@ -154,7 +163,7 @@ func testOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { rm := newRetrieveManager(peers, dist, nil) db, _ := ethdb.NewMemDatabase() ldb, _ := ethdb.NewMemDatabase() - odr := NewLesOdr(ldb, rm) + odr := NewLesOdr(ldb, light.NewChtIndexer(db, true), light.NewBloomTrieIndexer(db, true), eth.NewBloomIndexer(db, light.BloomTrieFrequency), rm) pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db) lpm := newTestProtocolManagerMust(t, true, 0, nil, peers, odr, ldb) _, err1, lpeer, err2 := newTestPeerPair("peer", protocol, pm, lpm) diff --git a/les/peer.go b/les/peer.go index 3ba2df3fe..524690e2f 100644 --- a/les/peer.go +++ b/les/peer.go @@ -18,6 +18,8 @@ package les import ( + "crypto/ecdsa" + "encoding/binary" "errors" "fmt" "math/big" @@ -28,6 +30,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/les/flowcontrol" + "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/rlp" ) @@ -40,14 +43,23 @@ var ( const maxResponseErrors = 50 // number of invalid responses tolerated (makes the protocol less brittle but still avoids spam) +const ( + announceTypeNone = iota + announceTypeSimple + announceTypeSigned +) + type peer struct { *p2p.Peer + pubKey *ecdsa.PublicKey rw p2p.MsgReadWriter version int // Protocol version negotiated network uint64 // Network ID being on + announceType, requestAnnounceType uint64 + id string headInfo *announceData @@ -68,9 +80,11 @@ type peer struct { func newPeer(version int, network uint64, p *p2p.Peer, rw p2p.MsgReadWriter) *peer { id := p.ID() + pubKey, _ := id.Pubkey() return &peer{ Peer: p, + pubKey: pubKey, rw: rw, version: version, network: network, @@ -197,16 +211,31 @@ func (p *peer) SendReceiptsRLP(reqID, bv uint64, receipts []rlp.RawValue) error return sendResponse(p.rw, ReceiptsMsg, reqID, bv, receipts) } -// SendProofs sends a batch of merkle proofs, corresponding to the ones requested. +// SendProofs sends a batch of legacy LES/1 merkle proofs, corresponding to the ones requested. func (p *peer) SendProofs(reqID, bv uint64, proofs proofsData) error { - return sendResponse(p.rw, ProofsMsg, reqID, bv, proofs) + return sendResponse(p.rw, ProofsV1Msg, reqID, bv, proofs) } -// SendHeaderProofs sends a batch of header proofs, corresponding to the ones requested. +// SendProofsV2 sends a batch of merkle proofs, corresponding to the ones requested. +func (p *peer) SendProofsV2(reqID, bv uint64, proofs light.NodeList) error { + return sendResponse(p.rw, ProofsV2Msg, reqID, bv, proofs) +} + +// SendHeaderProofs sends a batch of legacy LES/1 header proofs, corresponding to the ones requested. func (p *peer) SendHeaderProofs(reqID, bv uint64, proofs []ChtResp) error { return sendResponse(p.rw, HeaderProofsMsg, reqID, bv, proofs) } +// SendHelperTrieProofs sends a batch of HelperTrie proofs, corresponding to the ones requested. +func (p *peer) SendHelperTrieProofs(reqID, bv uint64, resp HelperTrieResps) error { + return sendResponse(p.rw, HelperTrieProofsMsg, reqID, bv, resp) +} + +// SendTxStatus sends a batch of transaction status records, corresponding to the ones requested. +func (p *peer) SendTxStatus(reqID, bv uint64, stats []txStatus) error { + return sendResponse(p.rw, TxStatusMsg, reqID, bv, stats) +} + // RequestHeadersByHash fetches a batch of blocks' headers corresponding to the // specified header query, based on the hash of an origin block. func (p *peer) RequestHeadersByHash(reqID, cost uint64, origin common.Hash, amount int, skip int, reverse bool) error { @@ -230,7 +259,7 @@ func (p *peer) RequestBodies(reqID, cost uint64, hashes []common.Hash) error { // RequestCode fetches a batch of arbitrary data from a node's known state // data, corresponding to the specified hashes. -func (p *peer) RequestCode(reqID, cost uint64, reqs []*CodeReq) error { +func (p *peer) RequestCode(reqID, cost uint64, reqs []CodeReq) error { p.Log().Debug("Fetching batch of codes", "count", len(reqs)) return sendRequest(p.rw, GetCodeMsg, reqID, cost, reqs) } @@ -242,20 +271,58 @@ func (p *peer) RequestReceipts(reqID, cost uint64, hashes []common.Hash) error { } // RequestProofs fetches a batch of merkle proofs from a remote node. -func (p *peer) RequestProofs(reqID, cost uint64, reqs []*ProofReq) error { +func (p *peer) RequestProofs(reqID, cost uint64, reqs []ProofReq) error { p.Log().Debug("Fetching batch of proofs", "count", len(reqs)) - return sendRequest(p.rw, GetProofsMsg, reqID, cost, reqs) + switch p.version { + case lpv1: + return sendRequest(p.rw, GetProofsV1Msg, reqID, cost, reqs) + case lpv2: + return sendRequest(p.rw, GetProofsV2Msg, reqID, cost, reqs) + default: + panic(nil) + } + +} + +// RequestHelperTrieProofs fetches a batch of HelperTrie merkle proofs from a remote node. +func (p *peer) RequestHelperTrieProofs(reqID, cost uint64, reqs []HelperTrieReq) error { + p.Log().Debug("Fetching batch of HelperTrie proofs", "count", len(reqs)) + switch p.version { + case lpv1: + reqsV1 := make([]ChtReq, len(reqs)) + for i, req := range reqs { + if req.HelperTrieType != htCanonical || req.AuxReq != auxHeader || len(req.Key) != 8 { + return fmt.Errorf("Request invalid in LES/1 mode") + } + blockNum := binary.BigEndian.Uint64(req.Key) + // convert HelperTrie request to old CHT request + reqsV1[i] = ChtReq{ChtNum: (req.TrieIdx+1)*(light.ChtFrequency/light.ChtV1Frequency) - 1, BlockNum: blockNum, FromLevel: req.FromLevel} + } + return sendRequest(p.rw, GetHeaderProofsMsg, reqID, cost, reqsV1) + case lpv2: + return sendRequest(p.rw, GetHelperTrieProofsMsg, reqID, cost, reqs) + default: + panic(nil) + } } -// RequestHeaderProofs fetches a batch of header merkle proofs from a remote node. -func (p *peer) RequestHeaderProofs(reqID, cost uint64, reqs []*ChtReq) error { - p.Log().Debug("Fetching batch of header proofs", "count", len(reqs)) - return sendRequest(p.rw, GetHeaderProofsMsg, reqID, cost, reqs) +// RequestTxStatus fetches a batch of transaction status records from a remote node. +func (p *peer) RequestTxStatus(reqID, cost uint64, txHashes []common.Hash) error { + p.Log().Debug("Requesting transaction status", "count", len(txHashes)) + return sendRequest(p.rw, GetTxStatusMsg, reqID, cost, txHashes) } +// SendTxStatus sends a batch of transactions to be added to the remote transaction pool. func (p *peer) SendTxs(reqID, cost uint64, txs types.Transactions) error { p.Log().Debug("Fetching batch of transactions", "count", len(txs)) - return p2p.Send(p.rw, SendTxMsg, txs) + switch p.version { + case lpv1: + return p2p.Send(p.rw, SendTxMsg, txs) // old message format does not include reqID + case lpv2: + return sendRequest(p.rw, SendTxV2Msg, reqID, cost, txs) + default: + panic(nil) + } } type keyValueEntry struct { @@ -289,7 +356,7 @@ func (l keyValueList) decode() keyValueMap { func (m keyValueMap) get(key string, val interface{}) error { enc, ok := m[key] if !ok { - return errResp(ErrHandshakeMissingKey, "%s", key) + return errResp(ErrMissingKey, "%s", key) } if val == nil { return nil @@ -348,6 +415,9 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis list := server.fcCostStats.getCurrentList() send = send.add("flowControl/MRC", list) p.fcCosts = list.decode() + } else { + p.requestAnnounceType = announceTypeSimple // set to default until "very light" client mode is implemented + send = send.add("announceType", p.requestAnnounceType) } recvList, err := p.sendReceiveHandshake(send) if err != nil { @@ -392,6 +462,9 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis /*if recv.get("serveStateSince", nil) == nil { return errResp(ErrUselessPeer, "wanted client, got server") }*/ + if recv.get("announceType", &p.announceType) != nil { + p.announceType = announceTypeSimple + } p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams) } else { if recv.get("serveChainSince", nil) != nil { @@ -456,11 +529,15 @@ func newPeerSet() *peerSet { // notify adds a service to be notified about added or removed peers func (ps *peerSet) notify(n peerSetNotify) { ps.lock.Lock() - defer ps.lock.Unlock() - ps.notifyList = append(ps.notifyList, n) + peers := make([]*peer, 0, len(ps.peers)) for _, p := range ps.peers { - go n.registerPeer(p) + peers = append(peers, p) + } + ps.lock.Unlock() + + for _, p := range peers { + n.registerPeer(p) } } @@ -468,8 +545,6 @@ func (ps *peerSet) notify(n peerSetNotify) { // peer is already known. func (ps *peerSet) Register(p *peer) error { ps.lock.Lock() - defer ps.lock.Unlock() - if ps.closed { return errClosed } @@ -478,8 +553,12 @@ func (ps *peerSet) Register(p *peer) error { } ps.peers[p.id] = p p.sendQueue = newExecQueue(100) - for _, n := range ps.notifyList { - go n.registerPeer(p) + peers := make([]peerSetNotify, len(ps.notifyList)) + copy(peers, ps.notifyList) + ps.lock.Unlock() + + for _, n := range peers { + n.registerPeer(p) } return nil } @@ -488,19 +567,22 @@ func (ps *peerSet) Register(p *peer) error { // actions to/from that particular entity. It also initiates disconnection at the networking layer. func (ps *peerSet) Unregister(id string) error { ps.lock.Lock() - defer ps.lock.Unlock() - if p, ok := ps.peers[id]; !ok { + ps.lock.Unlock() return errNotRegistered } else { - for _, n := range ps.notifyList { - go n.unregisterPeer(p) + delete(ps.peers, id) + peers := make([]peerSetNotify, len(ps.notifyList)) + copy(peers, ps.notifyList) + ps.lock.Unlock() + + for _, n := range peers { + n.unregisterPeer(p) } p.sendQueue.quit() p.Peer.Disconnect(p2p.DiscUselessPeer) + return nil } - delete(ps.peers, id) - return nil } // AllPeerIDs returns a list of all registered peer IDs diff --git a/les/protocol.go b/les/protocol.go index 9f06c8b42..9b5068f33 100644 --- a/les/protocol.go +++ b/les/protocol.go @@ -18,24 +18,35 @@ package les import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "errors" "fmt" "io" "math/big" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/secp256k1" "github.com/ethereum/go-ethereum/rlp" ) // Constants to match up protocol versions and messages const ( lpv1 = 1 + lpv2 = 2 ) -// Supported versions of the les protocol (first is primary). -var ProtocolVersions = []uint{lpv1} +// Supported versions of the les protocol (first is primary) +var ( + ClientProtocolVersions = []uint{lpv2, lpv1} + ServerProtocolVersions = []uint{lpv2, lpv1} +) // Number of implemented message corresponding to different protocol versions. -var ProtocolLengths = []uint64{15} +var ProtocolLengths = map[uint]uint64{lpv1: 15, lpv2: 22} const ( NetworkId = 7762959 @@ -53,13 +64,21 @@ const ( BlockBodiesMsg = 0x05 GetReceiptsMsg = 0x06 ReceiptsMsg = 0x07 - GetProofsMsg = 0x08 - ProofsMsg = 0x09 + GetProofsV1Msg = 0x08 + ProofsV1Msg = 0x09 GetCodeMsg = 0x0a CodeMsg = 0x0b SendTxMsg = 0x0c GetHeaderProofsMsg = 0x0d HeaderProofsMsg = 0x0e + // Protocol messages belonging to LPV2 + GetProofsV2Msg = 0x0f + ProofsV2Msg = 0x10 + GetHelperTrieProofsMsg = 0x11 + HelperTrieProofsMsg = 0x12 + SendTxV2Msg = 0x13 + GetTxStatusMsg = 0x14 + TxStatusMsg = 0x15 ) type errCode int @@ -79,7 +98,7 @@ const ( ErrUnexpectedResponse ErrInvalidResponse ErrTooManyTimeouts - ErrHandshakeMissingKey + ErrMissingKey ) func (e errCode) String() string { @@ -101,7 +120,13 @@ var errorToString = map[int]string{ ErrUnexpectedResponse: "Unexpected response", ErrInvalidResponse: "Invalid response", ErrTooManyTimeouts: "Too many request timeouts", - ErrHandshakeMissingKey: "Key missing from handshake message", + ErrMissingKey: "Key missing from list", +} + +type announceBlock struct { + Hash common.Hash // Hash of one particular block being announced + Number uint64 // Number of one particular block being announced + Td *big.Int // Total difficulty of one particular block being announced } // announceData is the network packet for the block announcements. @@ -113,6 +138,32 @@ type announceData struct { Update keyValueList } +// sign adds a signature to the block announcement by the given privKey +func (a *announceData) sign(privKey *ecdsa.PrivateKey) { + rlp, _ := rlp.EncodeToBytes(announceBlock{a.Hash, a.Number, a.Td}) + sig, _ := crypto.Sign(crypto.Keccak256(rlp), privKey) + a.Update = a.Update.add("sign", sig) +} + +// checkSignature verifies if the block announcement has a valid signature by the given pubKey +func (a *announceData) checkSignature(pubKey *ecdsa.PublicKey) error { + var sig []byte + if err := a.Update.decode().get("sign", &sig); err != nil { + return err + } + rlp, _ := rlp.EncodeToBytes(announceBlock{a.Hash, a.Number, a.Td}) + recPubkey, err := secp256k1.RecoverPubkey(crypto.Keccak256(rlp), sig) + if err != nil { + return err + } + pbytes := elliptic.Marshal(pubKey.Curve, pubKey.X, pubKey.Y) + if bytes.Equal(pbytes, recPubkey) { + return nil + } else { + return errors.New("Wrong signature") + } +} + type blockInfo struct { Hash common.Hash // Hash of one particular block being announced Number uint64 // Number of one particular block being announced @@ -169,3 +220,9 @@ type CodeData []struct { } type proofsData [][]rlp.RawValue + +type txStatus struct { + Status core.TxStatus + Lookup *core.TxLookupEntry + Error error +} diff --git a/les/request_test.go b/les/request_test.go index 6b594462d..c13625de8 100644 --- a/les/request_test.go +++ b/les/request_test.go @@ -24,6 +24,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" ) @@ -38,24 +39,32 @@ type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) ligh func TestBlockAccessLes1(t *testing.T) { testAccess(t, 1, tfBlockAccess) } +func TestBlockAccessLes2(t *testing.T) { testAccess(t, 2, tfBlockAccess) } + func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.BlockRequest{Hash: bhash, Number: number} } func TestReceiptsAccessLes1(t *testing.T) { testAccess(t, 1, tfReceiptsAccess) } +func TestReceiptsAccessLes2(t *testing.T) { testAccess(t, 2, tfReceiptsAccess) } + func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.ReceiptsRequest{Hash: bhash, Number: number} } func TestTrieEntryAccessLes1(t *testing.T) { testAccess(t, 1, tfTrieEntryAccess) } +func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) } + func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.TrieRequest{Id: light.StateTrieID(core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash))), Key: testBankSecureTrieKey} } func TestCodeAccessLes1(t *testing.T) { testAccess(t, 1, tfCodeAccess) } +func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) } + func tfCodeAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { header := core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash)) if header.Number.Uint64() < testContractDeployed { @@ -73,7 +82,7 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) { rm := newRetrieveManager(peers, dist, nil) db, _ := ethdb.NewMemDatabase() ldb, _ := ethdb.NewMemDatabase() - odr := NewLesOdr(ldb, rm) + odr := NewLesOdr(ldb, light.NewChtIndexer(db, true), light.NewBloomTrieIndexer(db, true), eth.NewBloomIndexer(db, light.BloomTrieFrequency), rm) pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db) lpm := newTestProtocolManagerMust(t, true, 0, nil, peers, odr, ldb) diff --git a/les/retrieve.go b/les/retrieve.go index b060e0b0d..dd15b56ac 100644 --- a/les/retrieve.go +++ b/les/retrieve.go @@ -22,6 +22,7 @@ import ( "context" "crypto/rand" "encoding/binary" + "fmt" "sync" "time" @@ -111,12 +112,14 @@ func newRetrieveManager(peers *peerSet, dist *requestDistributor, serverPool pee // that is delivered through the deliver function and successfully validated by the // validator callback. It returns when a valid answer is delivered or the context is // cancelled. -func (rm *retrieveManager) retrieve(ctx context.Context, reqID uint64, req *distReq, val validatorFunc) error { +func (rm *retrieveManager) retrieve(ctx context.Context, reqID uint64, req *distReq, val validatorFunc, shutdown chan struct{}) error { sentReq := rm.sendReq(reqID, req, val) select { case <-sentReq.stopCh: case <-ctx.Done(): sentReq.stop(ctx.Err()) + case <-shutdown: + sentReq.stop(fmt.Errorf("Client is shutting down")) } return sentReq.getError() } diff --git a/les/server.go b/les/server.go index 8b2730714..d8f93cd87 100644 --- a/les/server.go +++ b/les/server.go @@ -18,10 +18,11 @@ package les import ( + "crypto/ecdsa" "encoding/binary" + "fmt" "math" "sync" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" @@ -34,7 +35,6 @@ import ( "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/trie" ) type LesServer struct { @@ -42,23 +42,55 @@ type LesServer struct { fcManager *flowcontrol.ClientManager // nil if our node is client only fcCostStats *requestCostStats defParams *flowcontrol.ServerParams - lesTopic discv5.Topic + lesTopics []discv5.Topic + privateKey *ecdsa.PrivateKey quitSync chan struct{} + + chtIndexer, bloomTrieIndexer *core.ChainIndexer } func NewLesServer(eth *eth.Ethereum, config *eth.Config) (*LesServer, error) { quitSync := make(chan struct{}) - pm, err := NewProtocolManager(eth.BlockChain().Config(), false, config.NetworkId, eth.EventMux(), eth.Engine(), newPeerSet(), eth.BlockChain(), eth.TxPool(), eth.ChainDb(), nil, nil, quitSync, new(sync.WaitGroup)) + pm, err := NewProtocolManager(eth.BlockChain().Config(), false, ServerProtocolVersions, config.NetworkId, eth.EventMux(), eth.Engine(), newPeerSet(), eth.BlockChain(), eth.TxPool(), eth.ChainDb(), nil, nil, quitSync, new(sync.WaitGroup)) if err != nil { return nil, err } - pm.blockLoop() + + lesTopics := make([]discv5.Topic, len(ServerProtocolVersions)) + for i, pv := range ServerProtocolVersions { + lesTopics[i] = lesTopic(eth.BlockChain().Genesis().Hash(), pv) + } srv := &LesServer{ - protocolManager: pm, - quitSync: quitSync, - lesTopic: lesTopic(eth.BlockChain().Genesis().Hash()), + protocolManager: pm, + quitSync: quitSync, + lesTopics: lesTopics, + chtIndexer: light.NewChtIndexer(eth.ChainDb(), false), + bloomTrieIndexer: light.NewBloomTrieIndexer(eth.ChainDb(), false), + } + logger := log.New() + + chtV1SectionCount, _, _ := srv.chtIndexer.Sections() // indexer still uses LES/1 4k section size for backwards server compatibility + chtV2SectionCount := chtV1SectionCount / (light.ChtFrequency / light.ChtV1Frequency) + if chtV2SectionCount != 0 { + // convert to LES/2 section + chtLastSection := chtV2SectionCount - 1 + // convert last LES/2 section index back to LES/1 index for chtIndexer.SectionHead + chtLastSectionV1 := (chtLastSection+1)*(light.ChtFrequency/light.ChtV1Frequency) - 1 + chtSectionHead := srv.chtIndexer.SectionHead(chtLastSectionV1) + chtRoot := light.GetChtV2Root(pm.chainDb, chtLastSection, chtSectionHead) + logger.Info("CHT", "section", chtLastSection, "sectionHead", fmt.Sprintf("%064x", chtSectionHead), "root", fmt.Sprintf("%064x", chtRoot)) + } + + bloomTrieSectionCount, _, _ := srv.bloomTrieIndexer.Sections() + if bloomTrieSectionCount != 0 { + bloomTrieLastSection := bloomTrieSectionCount - 1 + bloomTrieSectionHead := srv.bloomTrieIndexer.SectionHead(bloomTrieLastSection) + bloomTrieRoot := light.GetBloomTrieRoot(pm.chainDb, bloomTrieLastSection, bloomTrieSectionHead) + logger.Info("BloomTrie", "section", bloomTrieLastSection, "sectionHead", fmt.Sprintf("%064x", bloomTrieSectionHead), "root", fmt.Sprintf("%064x", bloomTrieRoot)) } + + srv.chtIndexer.Start(eth.BlockChain()) pm.server = srv srv.defParams = &flowcontrol.ServerParams{ @@ -77,17 +109,28 @@ func (s *LesServer) Protocols() []p2p.Protocol { // Start starts the LES server func (s *LesServer) Start(srvr *p2p.Server) { s.protocolManager.Start() - go func() { - logger := log.New("topic", s.lesTopic) - logger.Info("Starting topic registration") - defer logger.Info("Terminated topic registration") + for _, topic := range s.lesTopics { + topic := topic + go func() { + logger := log.New("topic", topic) + logger.Info("Starting topic registration") + defer logger.Info("Terminated topic registration") + + srvr.DiscV5.RegisterTopic(topic, s.quitSync) + }() + } + s.privateKey = srvr.PrivateKey + s.protocolManager.blockLoop() +} - srvr.DiscV5.RegisterTopic(s.lesTopic, s.quitSync) - }() +func (s *LesServer) SetBloomBitsIndexer(bloomIndexer *core.ChainIndexer) { + bloomIndexer.AddChildIndexer(s.bloomTrieIndexer) } // Stop stops the LES service func (s *LesServer) Stop() { + s.chtIndexer.Close() + // bloom trie indexer is closed by parent bloombits indexer s.fcCostStats.store() s.fcManager.Stop() go func() { @@ -273,10 +316,7 @@ func (pm *ProtocolManager) blockLoop() { pm.wg.Add(1) headCh := make(chan core.ChainHeadEvent, 10) headSub := pm.blockchain.SubscribeChainHeadEvent(headCh) - newCht := make(chan struct{}, 10) - newCht <- struct{}{} go func() { - var mu sync.Mutex var lastHead *types.Header lastBroadcastTd := common.Big0 for { @@ -299,26 +339,37 @@ func (pm *ProtocolManager) blockLoop() { log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg) announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg} + var ( + signed bool + signedAnnounce announceData + ) + for _, p := range peers { - select { - case p.announceChn <- announce: - default: - pm.removePeer(p.id) + switch p.announceType { + + case announceTypeSimple: + select { + case p.announceChn <- announce: + default: + pm.removePeer(p.id) + } + + case announceTypeSigned: + if !signed { + signedAnnounce = announce + signedAnnounce.sign(pm.server.privateKey) + signed = true + } + + select { + case p.announceChn <- signedAnnounce: + default: + pm.removePeer(p.id) + } } } } } - newCht <- struct{}{} - case <-newCht: - go func() { - mu.Lock() - more := makeCht(pm.chainDb) - mu.Unlock() - if more { - time.Sleep(time.Millisecond * 10) - newCht <- struct{}{} - } - }() case <-pm.quitSync: headSub.Unsubscribe() pm.wg.Done() @@ -327,86 +378,3 @@ func (pm *ProtocolManager) blockLoop() { } }() } - -var ( - lastChtKey = []byte("LastChtNumber") // chtNum (uint64 big endian) - chtPrefix = []byte("cht") // chtPrefix + chtNum (uint64 big endian) -> trie root hash -) - -func getChtRoot(db ethdb.Database, num uint64) common.Hash { - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], num) - data, _ := db.Get(append(chtPrefix, encNumber[:]...)) - return common.BytesToHash(data) -} - -func storeChtRoot(db ethdb.Database, num uint64, root common.Hash) { - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], num) - db.Put(append(chtPrefix, encNumber[:]...), root[:]) -} - -func makeCht(db ethdb.Database) bool { - headHash := core.GetHeadBlockHash(db) - headNum := core.GetBlockNumber(db, headHash) - - var newChtNum uint64 - if headNum > light.ChtConfirmations { - newChtNum = (headNum - light.ChtConfirmations) / light.ChtFrequency - } - - var lastChtNum uint64 - data, _ := db.Get(lastChtKey) - if len(data) == 8 { - lastChtNum = binary.BigEndian.Uint64(data[:]) - } - if newChtNum <= lastChtNum { - return false - } - - var t *trie.Trie - if lastChtNum > 0 { - var err error - t, err = trie.New(getChtRoot(db, lastChtNum), db) - if err != nil { - lastChtNum = 0 - } - } - if lastChtNum == 0 { - t, _ = trie.New(common.Hash{}, db) - } - - for num := lastChtNum * light.ChtFrequency; num < (lastChtNum+1)*light.ChtFrequency; num++ { - hash := core.GetCanonicalHash(db, num) - if hash == (common.Hash{}) { - panic("Canonical hash not found") - } - td := core.GetTd(db, hash, num) - if td == nil { - panic("TD not found") - } - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], num) - var node light.ChtNode - node.Hash = hash - node.Td = td - data, _ := rlp.EncodeToBytes(node) - t.Update(encNumber[:], data) - } - - root, err := t.Commit() - if err != nil { - lastChtNum = 0 - } else { - lastChtNum++ - - log.Trace("Generated CHT", "number", lastChtNum, "root", root.Hex()) - - storeChtRoot(db, lastChtNum, root) - var data [8]byte - binary.BigEndian.PutUint64(data[:], lastChtNum) - db.Put(lastChtKey, data[:]) - } - - return newChtNum > lastChtNum -} diff --git a/les/serverpool.go b/les/serverpool.go index f4e4df2fb..dc1ea6bf0 100644 --- a/les/serverpool.go +++ b/les/serverpool.go @@ -145,15 +145,15 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) { pool.wg.Add(1) pool.loadNodes() - go pool.eventLoop() - - pool.checkDial() if pool.server.DiscV5 != nil { pool.discSetPeriod = make(chan time.Duration, 1) pool.discNodes = make(chan *discv5.Node, 100) pool.discLookups = make(chan bool, 100) go pool.server.DiscV5.SearchTopic(pool.topic, pool.discSetPeriod, pool.discNodes, pool.discLookups) } + + go pool.eventLoop() + pool.checkDial() } // connect should be called upon any incoming connection. If the connection has been diff --git a/light/lightchain.go b/light/lightchain.go index 4c877a771..30baeaccb 100644 --- a/light/lightchain.go +++ b/light/lightchain.go @@ -95,15 +95,8 @@ func NewLightChain(odr OdrBackend, config *params.ChainConfig, engine consensus. if bc.genesisBlock == nil { return nil, core.ErrNoGenesis } - if bc.genesisBlock.Hash() == params.MainnetGenesisHash { - // add trusted CHT - WriteTrustedCht(bc.chainDb, TrustedCht{Number: 1040, Root: common.HexToHash("bb4fb4076cbe6923c8a8ce8f158452bbe19564959313466989fda095a60884ca")}) - log.Info("Added trusted CHT for mainnet") - } - if bc.genesisBlock.Hash() == params.TestnetGenesisHash { - // add trusted CHT - WriteTrustedCht(bc.chainDb, TrustedCht{Number: 400, Root: common.HexToHash("2a4befa19e4675d939c3dc22dca8c6ae9fcd642be1f04b06bd6e4203cc304660")}) - log.Info("Added trusted CHT for ropsten testnet") + if cp, ok := trustedCheckpoints[bc.genesisBlock.Hash()]; ok { + bc.addTrustedCheckpoint(cp) } if err := bc.loadLastState(); err != nil { @@ -120,6 +113,22 @@ func NewLightChain(odr OdrBackend, config *params.ChainConfig, engine consensus. return bc, nil } +// addTrustedCheckpoint adds a trusted checkpoint to the blockchain +func (self *LightChain) addTrustedCheckpoint(cp trustedCheckpoint) { + if self.odr.ChtIndexer() != nil { + StoreChtRoot(self.chainDb, cp.sectionIdx, cp.sectionHead, cp.chtRoot) + self.odr.ChtIndexer().AddKnownSectionHead(cp.sectionIdx, cp.sectionHead) + } + if self.odr.BloomTrieIndexer() != nil { + StoreBloomTrieRoot(self.chainDb, cp.sectionIdx, cp.sectionHead, cp.bloomTrieRoot) + self.odr.BloomTrieIndexer().AddKnownSectionHead(cp.sectionIdx, cp.sectionHead) + } + if self.odr.BloomIndexer() != nil { + self.odr.BloomIndexer().AddKnownSectionHead(cp.sectionIdx, cp.sectionHead) + } + log.Info("Added trusted checkpoint", "chain name", cp.name) +} + func (self *LightChain) getProcInterrupt() bool { return atomic.LoadInt32(&self.procInterrupt) == 1 } @@ -449,10 +458,13 @@ func (self *LightChain) GetHeaderByNumberOdr(ctx context.Context, number uint64) } func (self *LightChain) SyncCht(ctx context.Context) bool { + if self.odr.ChtIndexer() == nil { + return false + } headNum := self.CurrentHeader().Number.Uint64() - cht := GetTrustedCht(self.chainDb) - if headNum+1 < cht.Number*ChtFrequency { - num := cht.Number*ChtFrequency - 1 + chtCount, _, _ := self.odr.ChtIndexer().Sections() + if headNum+1 < chtCount*ChtFrequency { + num := chtCount*ChtFrequency - 1 header, err := GetHeaderByNumber(ctx, self.odr, num) if header != nil && err == nil { self.mu.Lock() diff --git a/light/nodeset.go b/light/nodeset.go new file mode 100644 index 000000000..c530a4fbe --- /dev/null +++ b/light/nodeset.go @@ -0,0 +1,141 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package light + +import ( + "errors" + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" +) + +// NodeSet stores a set of trie nodes. It implements trie.Database and can also +// act as a cache for another trie.Database. +type NodeSet struct { + db map[string][]byte + dataSize int + lock sync.RWMutex +} + +// NewNodeSet creates an empty node set +func NewNodeSet() *NodeSet { + return &NodeSet{ + db: make(map[string][]byte), + } +} + +// Put stores a new node in the set +func (db *NodeSet) Put(key []byte, value []byte) error { + db.lock.Lock() + defer db.lock.Unlock() + + if _, ok := db.db[string(key)]; !ok { + db.db[string(key)] = common.CopyBytes(value) + db.dataSize += len(value) + } + return nil +} + +// Get returns a stored node +func (db *NodeSet) Get(key []byte) ([]byte, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + if entry, ok := db.db[string(key)]; ok { + return entry, nil + } + return nil, errors.New("not found") +} + +// Has returns true if the node set contains the given key +func (db *NodeSet) Has(key []byte) (bool, error) { + _, err := db.Get(key) + return err == nil, nil +} + +// KeyCount returns the number of nodes in the set +func (db *NodeSet) KeyCount() int { + db.lock.RLock() + defer db.lock.RUnlock() + + return len(db.db) +} + +// DataSize returns the aggregated data size of nodes in the set +func (db *NodeSet) DataSize() int { + db.lock.RLock() + defer db.lock.RUnlock() + + return db.dataSize +} + +// NodeList converts the node set to a NodeList +func (db *NodeSet) NodeList() NodeList { + db.lock.RLock() + defer db.lock.RUnlock() + + var values NodeList + for _, value := range db.db { + values = append(values, value) + } + return values +} + +// Store writes the contents of the set to the given database +func (db *NodeSet) Store(target trie.Database) { + db.lock.RLock() + defer db.lock.RUnlock() + + for key, value := range db.db { + target.Put([]byte(key), value) + } +} + +// NodeList stores an ordered list of trie nodes. It implements trie.DatabaseWriter. +type NodeList []rlp.RawValue + +// Store writes the contents of the list to the given database +func (n NodeList) Store(db trie.Database) { + for _, node := range n { + db.Put(crypto.Keccak256(node), node) + } +} + +// NodeSet converts the node list to a NodeSet +func (n NodeList) NodeSet() *NodeSet { + db := NewNodeSet() + n.Store(db) + return db +} + +// Put stores a new node at the end of the list +func (n *NodeList) Put(key []byte, value []byte) error { + *n = append(*n, value) + return nil +} + +// DataSize returns the aggregated data size of nodes in the list +func (n NodeList) DataSize() int { + var size int + for _, node := range n { + size += len(node) + } + return size +} diff --git a/light/odr.go b/light/odr.go index d19a488f6..e2c3d9c5a 100644 --- a/light/odr.go +++ b/light/odr.go @@ -25,9 +25,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/rlp" ) // NoOdr is the default context passed to an ODR capable function when the ODR @@ -37,6 +35,9 @@ var NoOdr = context.Background() // OdrBackend is an interface to a backend service that handles ODR retrievals type type OdrBackend interface { Database() ethdb.Database + ChtIndexer() *core.ChainIndexer + BloomTrieIndexer() *core.ChainIndexer + BloomIndexer() *core.ChainIndexer Retrieve(ctx context.Context, req OdrRequest) error } @@ -80,23 +81,12 @@ type TrieRequest struct { OdrRequest Id *TrieID Key []byte - Proof []rlp.RawValue + Proof *NodeSet } // StoreResult stores the retrieved data in local database func (req *TrieRequest) StoreResult(db ethdb.Database) { - storeProof(db, req.Proof) -} - -// storeProof stores the new trie nodes obtained from a merkle proof in the database -func storeProof(db ethdb.Database, proof []rlp.RawValue) { - for _, buf := range proof { - hash := crypto.Keccak256(buf) - val, _ := db.Get(hash) - if val == nil { - db.Put(hash, buf) - } - } + req.Proof.Store(db) } // CodeRequest is the ODR request type for retrieving contract code @@ -138,14 +128,14 @@ func (req *ReceiptsRequest) StoreResult(db ethdb.Database) { core.WriteBlockReceipts(db, req.Hash, req.Number, req.Receipts) } -// TrieRequest is the ODR request type for state/storage trie entries +// ChtRequest is the ODR request type for state/storage trie entries type ChtRequest struct { OdrRequest ChtNum, BlockNum uint64 ChtRoot common.Hash Header *types.Header Td *big.Int - Proof []rlp.RawValue + Proof *NodeSet } // StoreResult stores the retrieved data in local database @@ -155,5 +145,27 @@ func (req *ChtRequest) StoreResult(db ethdb.Database) { hash, num := req.Header.Hash(), req.Header.Number.Uint64() core.WriteTd(db, hash, num, req.Td) core.WriteCanonicalHash(db, hash, num) - //storeProof(db, req.Proof) +} + +// BloomRequest is the ODR request type for retrieving bloom filters from a CHT structure +type BloomRequest struct { + OdrRequest + BloomTrieNum uint64 + BitIdx uint + SectionIdxList []uint64 + BloomTrieRoot common.Hash + BloomBits [][]byte + Proofs *NodeSet +} + +// StoreResult stores the retrieved data in local database +func (req *BloomRequest) StoreResult(db ethdb.Database) { + for i, sectionIdx := range req.SectionIdxList { + sectionHead := core.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1) + // if we don't have the canonical hash stored for this section head number, we'll still store it under + // a key with a zero sectionHead. GetBloomBits will look there too if we still don't have the canonical + // hash. In the unlikely case we've retrieved the section head hash since then, we'll just retrieve the + // bit vector again from the network. + core.WriteBloomBits(db, req.BitIdx, sectionIdx, sectionHead, req.BloomBits[i]) + } } diff --git a/light/odr_test.go b/light/odr_test.go index c0c5438fd..e6afb1a48 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -77,7 +77,9 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error { req.Receipts = core.GetBlockReceipts(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash)) case *TrieRequest: t, _ := trie.New(req.Id.Root, odr.sdb) - req.Proof = t.Prove(req.Key) + nodes := NewNodeSet() + t.Prove(req.Key, 0, nodes) + req.Proof = nodes case *CodeRequest: req.Data, _ = odr.sdb.Get(req.Hash[:]) } diff --git a/light/odr_util.go b/light/odr_util.go index fcdfdb82c..a0eb6303d 100644 --- a/light/odr_util.go +++ b/light/odr_util.go @@ -19,56 +19,16 @@ package light import ( "bytes" "context" - "errors" - "math/big" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/rlp" ) var sha3_nil = crypto.Keccak256Hash(nil) -var ( - ErrNoTrustedCht = errors.New("No trusted canonical hash trie") - ErrNoHeader = errors.New("Header not found") - - ChtFrequency = uint64(4096) - ChtConfirmations = uint64(2048) - trustedChtKey = []byte("TrustedCHT") -) - -type ChtNode struct { - Hash common.Hash - Td *big.Int -} - -type TrustedCht struct { - Number uint64 - Root common.Hash -} - -func GetTrustedCht(db ethdb.Database) TrustedCht { - data, _ := db.Get(trustedChtKey) - var res TrustedCht - if err := rlp.DecodeBytes(data, &res); err != nil { - return TrustedCht{0, common.Hash{}} - } - return res -} - -func WriteTrustedCht(db ethdb.Database, cht TrustedCht) { - data, _ := rlp.EncodeToBytes(cht) - db.Put(trustedChtKey, data) -} - -func DeleteTrustedCht(db ethdb.Database) { - db.Delete(trustedChtKey) -} - func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*types.Header, error) { db := odr.Database() hash := core.GetCanonicalHash(db, number) @@ -81,12 +41,29 @@ func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*typ return header, nil } - cht := GetTrustedCht(db) - if number >= cht.Number*ChtFrequency { + var ( + chtCount, sectionHeadNum uint64 + sectionHead common.Hash + ) + if odr.ChtIndexer() != nil { + chtCount, sectionHeadNum, sectionHead = odr.ChtIndexer().Sections() + canonicalHash := core.GetCanonicalHash(db, sectionHeadNum) + // if the CHT was injected as a trusted checkpoint, we have no canonical hash yet so we accept zero hash too + for chtCount > 0 && canonicalHash != sectionHead && canonicalHash != (common.Hash{}) { + chtCount-- + if chtCount > 0 { + sectionHeadNum = chtCount*ChtFrequency - 1 + sectionHead = odr.ChtIndexer().SectionHead(chtCount - 1) + canonicalHash = core.GetCanonicalHash(db, sectionHeadNum) + } + } + } + + if number >= chtCount*ChtFrequency { return nil, ErrNoTrustedCht } - r := &ChtRequest{ChtRoot: cht.Root, ChtNum: cht.Number, BlockNum: number} + r := &ChtRequest{ChtRoot: GetChtRoot(db, chtCount-1, sectionHead), ChtNum: chtCount - 1, BlockNum: number} if err := odr.Retrieve(ctx, r); err != nil { return nil, err } else { @@ -162,3 +139,61 @@ func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, num } return r.Receipts, nil } + +// GetBloomBits retrieves a batch of compressed bloomBits vectors belonging to the given bit index and section indexes +func GetBloomBits(ctx context.Context, odr OdrBackend, bitIdx uint, sectionIdxList []uint64) ([][]byte, error) { + db := odr.Database() + result := make([][]byte, len(sectionIdxList)) + var ( + reqList []uint64 + reqIdx []int + ) + + var ( + bloomTrieCount, sectionHeadNum uint64 + sectionHead common.Hash + ) + if odr.BloomTrieIndexer() != nil { + bloomTrieCount, sectionHeadNum, sectionHead = odr.BloomTrieIndexer().Sections() + canonicalHash := core.GetCanonicalHash(db, sectionHeadNum) + // if the BloomTrie was injected as a trusted checkpoint, we have no canonical hash yet so we accept zero hash too + for bloomTrieCount > 0 && canonicalHash != sectionHead && canonicalHash != (common.Hash{}) { + bloomTrieCount-- + if bloomTrieCount > 0 { + sectionHeadNum = bloomTrieCount*BloomTrieFrequency - 1 + sectionHead = odr.BloomTrieIndexer().SectionHead(bloomTrieCount - 1) + canonicalHash = core.GetCanonicalHash(db, sectionHeadNum) + } + } + } + + for i, sectionIdx := range sectionIdxList { + sectionHead := core.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1) + // if we don't have the canonical hash stored for this section head number, we'll still look for + // an entry with a zero sectionHead (we store it with zero section head too if we don't know it + // at the time of the retrieval) + bloomBits, err := core.GetBloomBits(db, bitIdx, sectionIdx, sectionHead) + if err == nil { + result[i] = bloomBits + } else { + if sectionIdx >= bloomTrieCount { + return nil, ErrNoTrustedBloomTrie + } + reqList = append(reqList, sectionIdx) + reqIdx = append(reqIdx, i) + } + } + if reqList == nil { + return result, nil + } + + r := &BloomRequest{BloomTrieRoot: GetBloomTrieRoot(db, bloomTrieCount-1, sectionHead), BloomTrieNum: bloomTrieCount - 1, BitIdx: bitIdx, SectionIdxList: reqList} + if err := odr.Retrieve(ctx, r); err != nil { + return nil, err + } else { + for i, idx := range reqIdx { + result[idx] = r.BloomBits[i] + } + return result, nil + } +} diff --git a/light/postprocess.go b/light/postprocess.go new file mode 100644 index 000000000..e7e513880 --- /dev/null +++ b/light/postprocess.go @@ -0,0 +1,295 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package light + +import ( + "encoding/binary" + "errors" + "fmt" + "math/big" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/bitutil" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" +) + +const ( + ChtFrequency = 32768 + ChtV1Frequency = 4096 // as long as we want to retain LES/1 compatibility, servers generate CHTs with the old, higher frequency + HelperTrieConfirmations = 2048 // number of confirmations before a server is expected to have the given HelperTrie available + HelperTrieProcessConfirmations = 256 // number of confirmations before a HelperTrie is generated +) + +// trustedCheckpoint represents a set of post-processed trie roots (CHT and BloomTrie) associated with +// the appropriate section index and head hash. It is used to start light syncing from this checkpoint +// and avoid downloading the entire header chain while still being able to securely access old headers/logs. +type trustedCheckpoint struct { + name string + sectionIdx uint64 + sectionHead, chtRoot, bloomTrieRoot common.Hash +} + +var ( + mainnetCheckpoint = trustedCheckpoint{ + name: "ETH mainnet", + sectionIdx: 129, + sectionHead: common.HexToHash("64100587c8ec9a76870056d07cb0f58622552d16de6253a59cac4b580c899501"), + chtRoot: common.HexToHash("bb4fb4076cbe6923c8a8ce8f158452bbe19564959313466989fda095a60884ca"), + bloomTrieRoot: common.HexToHash("0db524b2c4a2a9520a42fd842b02d2e8fb58ff37c75cf57bd0eb82daeace6716"), + } + + ropstenCheckpoint = trustedCheckpoint{ + name: "Ropsten testnet", + sectionIdx: 50, + sectionHead: common.HexToHash("00bd65923a1aa67f85e6b4ae67835784dd54be165c37f056691723c55bf016bd"), + chtRoot: common.HexToHash("6f56dc61936752cc1f8c84b4addabdbe6a1c19693de3f21cb818362df2117f03"), + bloomTrieRoot: common.HexToHash("aca7d7c504d22737242effc3fdc604a762a0af9ced898036b5986c3a15220208"), + } +) + +// trustedCheckpoints associates each known checkpoint with the genesis hash of the chain it belongs to +var trustedCheckpoints = map[common.Hash]trustedCheckpoint{ + params.MainnetGenesisHash: mainnetCheckpoint, + params.TestnetGenesisHash: ropstenCheckpoint, +} + +var ( + ErrNoTrustedCht = errors.New("No trusted canonical hash trie") + ErrNoTrustedBloomTrie = errors.New("No trusted bloom trie") + ErrNoHeader = errors.New("Header not found") + chtPrefix = []byte("chtRoot-") // chtPrefix + chtNum (uint64 big endian) -> trie root hash + ChtTablePrefix = "cht-" +) + +// ChtNode structures are stored in the Canonical Hash Trie in an RLP encoded format +type ChtNode struct { + Hash common.Hash + Td *big.Int +} + +// GetChtRoot reads the CHT root assoctiated to the given section from the database +// Note that sectionIdx is specified according to LES/1 CHT section size +func GetChtRoot(db ethdb.Database, sectionIdx uint64, sectionHead common.Hash) common.Hash { + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], sectionIdx) + data, _ := db.Get(append(append(chtPrefix, encNumber[:]...), sectionHead.Bytes()...)) + return common.BytesToHash(data) +} + +// GetChtV2Root reads the CHT root assoctiated to the given section from the database +// Note that sectionIdx is specified according to LES/2 CHT section size +func GetChtV2Root(db ethdb.Database, sectionIdx uint64, sectionHead common.Hash) common.Hash { + return GetChtRoot(db, (sectionIdx+1)*(ChtFrequency/ChtV1Frequency)-1, sectionHead) +} + +// StoreChtRoot writes the CHT root assoctiated to the given section into the database +// Note that sectionIdx is specified according to LES/1 CHT section size +func StoreChtRoot(db ethdb.Database, sectionIdx uint64, sectionHead, root common.Hash) { + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], sectionIdx) + db.Put(append(append(chtPrefix, encNumber[:]...), sectionHead.Bytes()...), root.Bytes()) +} + +// ChtIndexerBackend implements core.ChainIndexerBackend +type ChtIndexerBackend struct { + db, cdb ethdb.Database + section, sectionSize uint64 + lastHash common.Hash + trie *trie.Trie +} + +// NewBloomTrieIndexer creates a BloomTrie chain indexer +func NewChtIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer { + cdb := ethdb.NewTable(db, ChtTablePrefix) + idb := ethdb.NewTable(db, "chtIndex-") + var sectionSize, confirmReq uint64 + if clientMode { + sectionSize = ChtFrequency + confirmReq = HelperTrieConfirmations + } else { + sectionSize = ChtV1Frequency + confirmReq = HelperTrieProcessConfirmations + } + return core.NewChainIndexer(db, idb, &ChtIndexerBackend{db: db, cdb: cdb, sectionSize: sectionSize}, sectionSize, confirmReq, time.Millisecond*100, "cht") +} + +// Reset implements core.ChainIndexerBackend +func (c *ChtIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error { + var root common.Hash + if section > 0 { + root = GetChtRoot(c.db, section-1, lastSectionHead) + } + var err error + c.trie, err = trie.New(root, c.cdb) + c.section = section + return err +} + +// Process implements core.ChainIndexerBackend +func (c *ChtIndexerBackend) Process(header *types.Header) { + hash, num := header.Hash(), header.Number.Uint64() + c.lastHash = hash + + td := core.GetTd(c.db, hash, num) + if td == nil { + panic(nil) + } + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], num) + data, _ := rlp.EncodeToBytes(ChtNode{hash, td}) + c.trie.Update(encNumber[:], data) +} + +// Commit implements core.ChainIndexerBackend +func (c *ChtIndexerBackend) Commit() error { + batch := c.cdb.NewBatch() + root, err := c.trie.CommitTo(batch) + if err != nil { + return err + } else { + batch.Write() + if ((c.section+1)*c.sectionSize)%ChtFrequency == 0 { + log.Info("Storing CHT", "idx", c.section*c.sectionSize/ChtFrequency, "sectionHead", fmt.Sprintf("%064x", c.lastHash), "root", fmt.Sprintf("%064x", root)) + } + StoreChtRoot(c.db, c.section, c.lastHash, root) + } + return nil +} + +const ( + BloomTrieFrequency = 32768 + ethBloomBitsSection = 4096 + ethBloomBitsConfirmations = 256 +) + +var ( + bloomTriePrefix = []byte("bltRoot-") // bloomTriePrefix + bloomTrieNum (uint64 big endian) -> trie root hash + BloomTrieTablePrefix = "blt-" +) + +// GetBloomTrieRoot reads the BloomTrie root assoctiated to the given section from the database +func GetBloomTrieRoot(db ethdb.Database, sectionIdx uint64, sectionHead common.Hash) common.Hash { + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], sectionIdx) + data, _ := db.Get(append(append(bloomTriePrefix, encNumber[:]...), sectionHead.Bytes()...)) + return common.BytesToHash(data) +} + +// StoreBloomTrieRoot writes the BloomTrie root assoctiated to the given section into the database +func StoreBloomTrieRoot(db ethdb.Database, sectionIdx uint64, sectionHead, root common.Hash) { + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], sectionIdx) + db.Put(append(append(bloomTriePrefix, encNumber[:]...), sectionHead.Bytes()...), root.Bytes()) +} + +// BloomTrieIndexerBackend implements core.ChainIndexerBackend +type BloomTrieIndexerBackend struct { + db, cdb ethdb.Database + section, parentSectionSize, bloomTrieRatio uint64 + trie *trie.Trie + sectionHeads []common.Hash +} + +// NewBloomTrieIndexer creates a BloomTrie chain indexer +func NewBloomTrieIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer { + cdb := ethdb.NewTable(db, BloomTrieTablePrefix) + idb := ethdb.NewTable(db, "bltIndex-") + backend := &BloomTrieIndexerBackend{db: db, cdb: cdb} + var confirmReq uint64 + if clientMode { + backend.parentSectionSize = BloomTrieFrequency + confirmReq = HelperTrieConfirmations + } else { + backend.parentSectionSize = ethBloomBitsSection + confirmReq = HelperTrieProcessConfirmations + } + backend.bloomTrieRatio = BloomTrieFrequency / backend.parentSectionSize + backend.sectionHeads = make([]common.Hash, backend.bloomTrieRatio) + return core.NewChainIndexer(db, idb, backend, BloomTrieFrequency, confirmReq-ethBloomBitsConfirmations, time.Millisecond*100, "bloomtrie") +} + +// Reset implements core.ChainIndexerBackend +func (b *BloomTrieIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error { + var root common.Hash + if section > 0 { + root = GetBloomTrieRoot(b.db, section-1, lastSectionHead) + } + var err error + b.trie, err = trie.New(root, b.cdb) + b.section = section + return err +} + +// Process implements core.ChainIndexerBackend +func (b *BloomTrieIndexerBackend) Process(header *types.Header) { + num := header.Number.Uint64() - b.section*BloomTrieFrequency + if (num+1)%b.parentSectionSize == 0 { + b.sectionHeads[num/b.parentSectionSize] = header.Hash() + } +} + +// Commit implements core.ChainIndexerBackend +func (b *BloomTrieIndexerBackend) Commit() error { + var compSize, decompSize uint64 + + for i := uint(0); i < types.BloomBitLength; i++ { + var encKey [10]byte + binary.BigEndian.PutUint16(encKey[0:2], uint16(i)) + binary.BigEndian.PutUint64(encKey[2:10], b.section) + var decomp []byte + for j := uint64(0); j < b.bloomTrieRatio; j++ { + data, err := core.GetBloomBits(b.db, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j]) + if err != nil { + return err + } + decompData, err2 := bitutil.DecompressBytes(data, int(b.parentSectionSize/8)) + if err2 != nil { + return err2 + } + decomp = append(decomp, decompData...) + } + comp := bitutil.CompressBytes(decomp) + + decompSize += uint64(len(decomp)) + compSize += uint64(len(comp)) + if len(comp) > 0 { + b.trie.Update(encKey[:], comp) + } else { + b.trie.Delete(encKey[:]) + } + } + + batch := b.cdb.NewBatch() + root, err := b.trie.CommitTo(batch) + if err != nil { + return err + } else { + batch.Write() + sectionHead := b.sectionHeads[b.bloomTrieRatio-1] + log.Info("Storing BloomTrie", "section", b.section, "sectionHead", fmt.Sprintf("%064x", sectionHead), "root", fmt.Sprintf("%064x", root), "compression ratio", float64(compSize)/float64(decompSize)) + StoreBloomTrieRoot(b.db, b.section, sectionHead, root) + } + + return nil +} diff --git a/miner/worker.go b/miner/worker.go index bf24970f5..c1f848e32 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -269,6 +269,11 @@ func (self *worker) update() { self.current.commitTransactions(self.mux, txset, self.chain, self.coinbase) self.currentMu.Unlock() + } else { + // If we're mining, but nothing is being processed, wake on new transactions + if self.config.Clique != nil && self.config.Clique.Period == 0 { + self.commitNewWork() + } } // System stopped diff --git a/mobile/ethclient.go b/mobile/ethclient.go index f56f72f1b..6bc5ff5d1 100644 --- a/mobile/ethclient.go +++ b/mobile/ethclient.go @@ -198,8 +198,8 @@ func (ec *EthereumClient) FilterLogs(ctx *Context, query *FilterQuery) (logs *Lo } // Temp hack due to vm.Logs being []*vm.Log res := make([]*types.Log, len(rawLogs)) - for i, log := range rawLogs { - res[i] = &log + for i := range rawLogs { + res[i] = &rawLogs[i] } return &Logs{res}, nil } diff --git a/params/config.go b/params/config.go index e2732700d..730acb30f 100644 --- a/params/config.go +++ b/params/config.go @@ -83,16 +83,21 @@ var ( }, } - // AllProtocolChanges contains every protocol change (EIPs) - // introduced and accepted by the Ethereum core developers. + // AllEthashProtocolChanges contains every protocol change (EIPs) introduced + // and accepted by the Ethereum core developers into the Ethash consensus. // - // This configuration is intentionally not using keyed fields. - // This configuration must *always* have all forks enabled, which - // means that all fields must be set at all times. This forces - // anyone adding flags to the config to also have to set these - // fields. - AllProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), new(EthashConfig), nil} - TestChainConfig = &ChainConfig{big.NewInt(1), big.NewInt(1), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), new(EthashConfig), nil} + // This configuration is intentionally not using keyed fields to force anyone + // adding flags to the config to also have to set these fields. + AllEthashProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), new(EthashConfig), nil} + + // AllCliqueProtocolChanges contains every protocol change (EIPs) introduced + // and accepted by the Ethereum core developers into the Clique consensus. + // + // This configuration is intentionally not using keyed fields to force anyone + // adding flags to the config to also have to set these fields. + AllCliqueProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, &CliqueConfig{Period: 0, Epoch: 30000}} + + TestChainConfig = &ChainConfig{big.NewInt(1), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), new(EthashConfig), nil} TestRules = TestChainConfig.Rules(new(big.Int)) ) diff --git a/params/config_test.go b/params/config_test.go index 487dc380c..02c5fe291 100644 --- a/params/config_test.go +++ b/params/config_test.go @@ -29,8 +29,8 @@ func TestCheckCompatible(t *testing.T) { wantErr *ConfigCompatError } tests := []test{ - {stored: AllProtocolChanges, new: AllProtocolChanges, head: 0, wantErr: nil}, - {stored: AllProtocolChanges, new: AllProtocolChanges, head: 100, wantErr: nil}, + {stored: AllEthashProtocolChanges, new: AllEthashProtocolChanges, head: 0, wantErr: nil}, + {stored: AllEthashProtocolChanges, new: AllEthashProtocolChanges, head: 100, wantErr: nil}, { stored: &ChainConfig{EIP150Block: big.NewInt(10)}, new: &ChainConfig{EIP150Block: big.NewInt(20)}, @@ -38,7 +38,7 @@ func TestCheckCompatible(t *testing.T) { wantErr: nil, }, { - stored: AllProtocolChanges, + stored: AllEthashProtocolChanges, new: &ChainConfig{HomesteadBlock: nil}, head: 3, wantErr: &ConfigCompatError{ @@ -49,7 +49,7 @@ func TestCheckCompatible(t *testing.T) { }, }, { - stored: AllProtocolChanges, + stored: AllEthashProtocolChanges, new: &ChainConfig{HomesteadBlock: big.NewInt(1)}, head: 3, wantErr: &ConfigCompatError{ diff --git a/rpc/http.go b/rpc/http.go index 4143e2a8d..3f572b34c 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -23,6 +23,7 @@ import ( "fmt" "io" "io/ioutil" + "mime" "net" "net/http" "sync" @@ -151,6 +152,16 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.StatusRequestEntityTooLarge) return } + + ct := r.Header.Get("content-type") + mt, _, err := mime.ParseMediaType(ct) + if err != nil || mt != "application/json" { + http.Error(w, + "invalid content type, only application/json is supported", + http.StatusUnsupportedMediaType) + return + } + w.Header().Set("content-type", "application/json") // create a codec that reads direct from the request body until diff --git a/rpc/subscription_test.go b/rpc/subscription_test.go index 39f759692..0ba177e63 100644 --- a/rpc/subscription_test.go +++ b/rpc/subscription_test.go @@ -290,7 +290,7 @@ func TestSubscriptionMultipleNamespaces(t *testing.T) { for { done := true - for id, _ := range count { + for id := range count { if count, found := count[id]; !found || count < (2*n) { done = false } diff --git a/swarm/api/client/client_test.go b/swarm/api/client/client_test.go index edf385dd0..c1d144e37 100644 --- a/swarm/api/client/client_test.go +++ b/swarm/api/client/client_test.go @@ -244,25 +244,25 @@ func TestClientFileList(t *testing.T) { } tests := map[string][]string{ - "": []string{"dir1/", "dir2/", "file1.txt", "file2.txt"}, - "file": []string{"file1.txt", "file2.txt"}, - "file1": []string{"file1.txt"}, - "file2.txt": []string{"file2.txt"}, - "file12": []string{}, - "dir": []string{"dir1/", "dir2/"}, - "dir1": []string{"dir1/"}, - "dir1/": []string{"dir1/file3.txt", "dir1/file4.txt"}, - "dir1/file": []string{"dir1/file3.txt", "dir1/file4.txt"}, - "dir1/file3.txt": []string{"dir1/file3.txt"}, - "dir1/file34": []string{}, - "dir2/": []string{"dir2/dir3/", "dir2/dir4/", "dir2/file5.txt"}, - "dir2/file": []string{"dir2/file5.txt"}, - "dir2/dir": []string{"dir2/dir3/", "dir2/dir4/"}, - "dir2/dir3/": []string{"dir2/dir3/file6.txt"}, - "dir2/dir4/": []string{"dir2/dir4/file7.txt", "dir2/dir4/file8.txt"}, - "dir2/dir4/file": []string{"dir2/dir4/file7.txt", "dir2/dir4/file8.txt"}, - "dir2/dir4/file7.txt": []string{"dir2/dir4/file7.txt"}, - "dir2/dir4/file78": []string{}, + "": {"dir1/", "dir2/", "file1.txt", "file2.txt"}, + "file": {"file1.txt", "file2.txt"}, + "file1": {"file1.txt"}, + "file2.txt": {"file2.txt"}, + "file12": {}, + "dir": {"dir1/", "dir2/"}, + "dir1": {"dir1/"}, + "dir1/": {"dir1/file3.txt", "dir1/file4.txt"}, + "dir1/file": {"dir1/file3.txt", "dir1/file4.txt"}, + "dir1/file3.txt": {"dir1/file3.txt"}, + "dir1/file34": {}, + "dir2/": {"dir2/dir3/", "dir2/dir4/", "dir2/file5.txt"}, + "dir2/file": {"dir2/file5.txt"}, + "dir2/dir": {"dir2/dir3/", "dir2/dir4/"}, + "dir2/dir3/": {"dir2/dir3/file6.txt"}, + "dir2/dir4/": {"dir2/dir4/file7.txt", "dir2/dir4/file8.txt"}, + "dir2/dir4/file": {"dir2/dir4/file7.txt", "dir2/dir4/file8.txt"}, + "dir2/dir4/file7.txt": {"dir2/dir4/file7.txt"}, + "dir2/dir4/file78": {}, } for prefix, expected := range tests { actual := ls(prefix) diff --git a/swarm/api/http/error.go b/swarm/api/http/error.go index b4d46b3c4..dbd97182f 100644 --- a/swarm/api/http/error.go +++ b/swarm/api/http/error.go @@ -72,7 +72,7 @@ func initErrHandling() { //ShowMultipeChoices is used when a user requests a resource in a manifest which results //in ambiguous results. It returns a HTML page with clickable links of each of the entry //in the manifest which fits the request URI ambiguity. -//For example, if the user requests bzz://read and that manifest containes entries +//For example, if the user requests bzz://read and that manifest contains entries //"readme.md" and "readinglist.txt", a HTML page is returned with this two links. //This only applies if the manifest has no default entry func ShowMultipleChoices(w http.ResponseWriter, r *http.Request, list api.ManifestList) { diff --git a/swarm/api/http/error_test.go b/swarm/api/http/error_test.go index 465fbf4ca..ed52bafbd 100644 --- a/swarm/api/http/error_test.go +++ b/swarm/api/http/error_test.go @@ -132,7 +132,7 @@ func TestJsonResponse(t *testing.T) { } if !isJSON(string(respbody)) { - t.Fatalf("Expected repsonse to be JSON, received invalid JSON: %s", string(respbody)) + t.Fatalf("Expected response to be JSON, received invalid JSON: %s", string(respbody)) } } diff --git a/swarm/storage/chunker.go b/swarm/storage/chunker.go index 0454828b9..8c0d62cbe 100644 --- a/swarm/storage/chunker.go +++ b/swarm/storage/chunker.go @@ -50,7 +50,6 @@ data_{i} := size(subtree_{i}) || key_{j} || key_{j+1} .... || key_{j+n-1} The underlying hash function is configurable */ - /* Tree chunker is a concrete implementation of data chunking. This chunker works in a simple way, it builds a tree out of the document so that each node either represents a chunk of real data or a chunk of data representing an branching non-leaf node of the tree. In particular each such non-leaf chunk will represent is a concatenation of the hash of its respective children. This scheme simultaneously guarantees data integrity as well as self addressing. Abstract nodes are transparent since their represented size component is strictly greater than their maximum data size, since they encode a subtree. @@ -61,17 +60,17 @@ The hashing itself does use extra copies and allocation though, since it does ne var ( errAppendOppNotSuported = errors.New("Append operation not supported") - errOperationTimedOut = errors.New("operation timed out") + errOperationTimedOut = errors.New("operation timed out") ) type TreeChunker struct { branches int64 hashFunc SwarmHasher // calculated - hashSize int64 // self.hashFunc.New().Size() - chunkSize int64 // hashSize* branches - workerCount int64 // the number of worker routines used - workerLock sync.RWMutex // lock for the worker count + hashSize int64 // self.hashFunc.New().Size() + chunkSize int64 // hashSize* branches + workerCount int64 // the number of worker routines used + workerLock sync.RWMutex // lock for the worker count } func NewTreeChunker(params *ChunkerParams) (self *TreeChunker) { @@ -124,7 +123,6 @@ func (self *TreeChunker) Split(data io.Reader, size int64, chunkC chan *Chunk, s panic("chunker must be initialised") } - jobC := make(chan *hashJob, 2*ChunkProcessors) wg := &sync.WaitGroup{} errC := make(chan error) @@ -164,7 +162,6 @@ func (self *TreeChunker) Split(data io.Reader, size int64, chunkC chan *Chunk, s close(errC) }() - defer close(quitC) select { case err := <-errC: @@ -172,7 +169,7 @@ func (self *TreeChunker) Split(data io.Reader, size int64, chunkC chan *Chunk, s return nil, err } case <-time.NewTimer(splitTimeout).C: - return nil,errOperationTimedOut + return nil, errOperationTimedOut } return key, nil diff --git a/swarm/storage/pyramid.go b/swarm/storage/pyramid.go index e3be2a987..42b83583d 100644 --- a/swarm/storage/pyramid.go +++ b/swarm/storage/pyramid.go @@ -27,7 +27,7 @@ import ( /* The main idea of a pyramid chunker is to process the input data without knowing the entire size apriori. For this to be achieved, the chunker tree is built from the ground up until the data is exhausted. - This opens up new aveneus such as easy append and other sort of modifications to the tree therby avoiding + This opens up new aveneus such as easy append and other sort of modifications to the tree thereby avoiding duplication of data chunks. @@ -123,7 +123,7 @@ type PyramidChunker struct { hashSize int64 branches int64 workerCount int64 - workerLock sync.RWMutex + workerLock sync.RWMutex } func NewPyramidChunker(params *ChunkerParams) (self *PyramidChunker) { @@ -451,7 +451,7 @@ func (self *PyramidChunker) prepareChunks(isAppend bool, chunkLevel [][]*TreeEnt } } - // Data ended in chunk boundry.. just signal to start bulding tree + // Data ended in chunk boundary.. just signal to start bulding tree if n == 0 { self.buildTree(isAppend, chunkLevel, parent, chunkWG, jobC, quitC, true, rootKey) break @@ -634,4 +634,4 @@ func (self *PyramidChunker) enqueueDataChunk(chunkData []byte, size uint64, pare return pkey -} \ No newline at end of file +} diff --git a/tests/init.go b/tests/init.go index a2c633ad6..9e884efe3 100644 --- a/tests/init.go +++ b/tests/init.go @@ -25,26 +25,26 @@ import ( // This table defines supported forks and their chain config. var Forks = map[string]*params.ChainConfig{ - "Frontier": ¶ms.ChainConfig{ + "Frontier": { ChainId: big.NewInt(1), }, - "Homestead": ¶ms.ChainConfig{ + "Homestead": { ChainId: big.NewInt(1), HomesteadBlock: big.NewInt(0), }, - "EIP150": ¶ms.ChainConfig{ + "EIP150": { ChainId: big.NewInt(1), HomesteadBlock: big.NewInt(0), EIP150Block: big.NewInt(0), }, - "EIP158": ¶ms.ChainConfig{ + "EIP158": { ChainId: big.NewInt(1), HomesteadBlock: big.NewInt(0), EIP150Block: big.NewInt(0), EIP155Block: big.NewInt(0), EIP158Block: big.NewInt(0), }, - "Byzantium": ¶ms.ChainConfig{ + "Byzantium": { ChainId: big.NewInt(1), HomesteadBlock: big.NewInt(0), EIP150Block: big.NewInt(0), @@ -53,22 +53,22 @@ var Forks = map[string]*params.ChainConfig{ DAOForkBlock: big.NewInt(0), ByzantiumBlock: big.NewInt(0), }, - "FrontierToHomesteadAt5": ¶ms.ChainConfig{ + "FrontierToHomesteadAt5": { ChainId: big.NewInt(1), HomesteadBlock: big.NewInt(5), }, - "HomesteadToEIP150At5": ¶ms.ChainConfig{ + "HomesteadToEIP150At5": { ChainId: big.NewInt(1), HomesteadBlock: big.NewInt(0), EIP150Block: big.NewInt(5), }, - "HomesteadToDaoAt5": ¶ms.ChainConfig{ + "HomesteadToDaoAt5": { ChainId: big.NewInt(1), HomesteadBlock: big.NewInt(0), DAOForkBlock: big.NewInt(5), DAOForkSupport: true, }, - "EIP158ToByzantiumAt5": ¶ms.ChainConfig{ + "EIP158ToByzantiumAt5": { ChainId: big.NewInt(1), HomesteadBlock: big.NewInt(0), EIP150Block: big.NewInt(0), diff --git a/tests/state_test_util.go b/tests/state_test_util.go index 64bf09cb4..352f840d9 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -112,7 +112,7 @@ type stTransactionMarshaling struct { func (t *StateTest) Subtests() []StateSubtest { var sub []StateSubtest for fork, pss := range t.json.Post { - for i, _ := range pss { + for i := range pss { sub = append(sub, StateSubtest{fork, i}) } } diff --git a/trie/proof.go b/trie/proof.go index 298f648c4..5e886a259 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -18,11 +18,10 @@ package trie import ( "bytes" - "errors" "fmt" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" ) @@ -36,7 +35,7 @@ import ( // contains all nodes of the longest existing prefix of the key // (at least the root node), ending with the node that proves the // absence of the key. -func (t *Trie) Prove(key []byte) []rlp.RawValue { +func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error { // Collect all nodes on the path to key. key = keybytesToHex(key) nodes := []node{} @@ -61,67 +60,63 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue { tn, err = t.resolveHash(n, nil) if err != nil { log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) - return nil + return err } default: panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) } } hasher := newHasher(0, 0) - proof := make([]rlp.RawValue, 0, len(nodes)) for i, n := range nodes { // Don't bother checking for errors here since hasher panics // if encoding doesn't work and we're not writing to any database. n, _, _ = hasher.hashChildren(n, nil) hn, _ := hasher.store(n, nil, false) - if _, ok := hn.(hashNode); ok || i == 0 { + if hash, ok := hn.(hashNode); ok || i == 0 { // If the node's database encoding is a hash (or is the // root node), it becomes a proof element. - enc, _ := rlp.EncodeToBytes(n) - proof = append(proof, enc) + if fromLevel > 0 { + fromLevel-- + } else { + enc, _ := rlp.EncodeToBytes(n) + if !ok { + hash = crypto.Keccak256(enc) + } + proofDb.Put(hash, enc) + } } } - return proof + return nil } // VerifyProof checks merkle proofs. The given proof must contain the // value for key in a trie with the given root hash. VerifyProof // returns an error if the proof contains invalid trie nodes or the // wrong value. -func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) { +func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (value []byte, err error, nodes int) { key = keybytesToHex(key) - sha := sha3.NewKeccak256() - wantHash := rootHash.Bytes() - for i, buf := range proof { - sha.Reset() - sha.Write(buf) - if !bytes.Equal(sha.Sum(nil), wantHash) { - return nil, fmt.Errorf("bad proof node %d: hash mismatch", i) + wantHash := rootHash[:] + for i := 0; ; i++ { + buf, _ := proofDb.Get(wantHash) + if buf == nil { + return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash[:]), i } n, err := decodeNode(wantHash, buf, 0) if err != nil { - return nil, fmt.Errorf("bad proof node %d: %v", i, err) + return nil, fmt.Errorf("bad proof node %d: %v", i, err), i } keyrest, cld := get(n, key) switch cld := cld.(type) { case nil: - if i != len(proof)-1 { - return nil, fmt.Errorf("key mismatch at proof node %d", i) - } else { - // The trie doesn't contain the key. - return nil, nil - } + // The trie doesn't contain the key. + return nil, nil, i case hashNode: key = keyrest wantHash = cld case valueNode: - if i != len(proof)-1 { - return nil, errors.New("additional nodes at end of proof") - } - return cld, nil + return cld, nil, i + 1 } } - return nil, errors.New("unexpected end of proof") } func get(tn node, key []byte) ([]byte, node) { diff --git a/trie/proof_test.go b/trie/proof_test.go index 91ebcd4a5..fff313d7f 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -24,7 +24,8 @@ import ( "time" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" ) func init() { @@ -35,13 +36,13 @@ func TestProof(t *testing.T) { trie, vals := randomTrie(500) root := trie.Hash() for _, kv := range vals { - proof := trie.Prove(kv.k) - if proof == nil { + proofs, _ := ethdb.NewMemDatabase() + if trie.Prove(kv.k, 0, proofs) != nil { t.Fatalf("missing key %x while constructing proof", kv.k) } - val, err := VerifyProof(root, kv.k, proof) + val, err, _ := VerifyProof(root, kv.k, proofs) if err != nil { - t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %x", kv.k, err, proof) + t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %v", kv.k, err, proofs) } if !bytes.Equal(val, kv.v) { t.Fatalf("VerifyProof returned wrong value for key %x: got %x, want %x", kv.k, val, kv.v) @@ -52,16 +53,14 @@ func TestProof(t *testing.T) { func TestOneElementProof(t *testing.T) { trie := new(Trie) updateString(trie, "k", "v") - proof := trie.Prove([]byte("k")) - if proof == nil { - t.Fatal("nil proof") - } - if len(proof) != 1 { + proofs, _ := ethdb.NewMemDatabase() + trie.Prove([]byte("k"), 0, proofs) + if len(proofs.Keys()) != 1 { t.Error("proof should have one element") } - val, err := VerifyProof(trie.Hash(), []byte("k"), proof) + val, err, _ := VerifyProof(trie.Hash(), []byte("k"), proofs) if err != nil { - t.Fatalf("VerifyProof error: %v\nraw proof: %x", err, proof) + t.Fatalf("VerifyProof error: %v\nproof hashes: %v", err, proofs.Keys()) } if !bytes.Equal(val, []byte("v")) { t.Fatalf("VerifyProof returned wrong value: got %x, want 'k'", val) @@ -72,12 +71,18 @@ func TestVerifyBadProof(t *testing.T) { trie, vals := randomTrie(800) root := trie.Hash() for _, kv := range vals { - proof := trie.Prove(kv.k) - if proof == nil { - t.Fatal("nil proof") + proofs, _ := ethdb.NewMemDatabase() + trie.Prove(kv.k, 0, proofs) + if len(proofs.Keys()) == 0 { + t.Fatal("zero length proof") } - mutateByte(proof[mrand.Intn(len(proof))]) - if _, err := VerifyProof(root, kv.k, proof); err == nil { + keys := proofs.Keys() + key := keys[mrand.Intn(len(keys))] + node, _ := proofs.Get(key) + proofs.Delete(key) + mutateByte(node) + proofs.Put(crypto.Keccak256(node), node) + if _, err, _ := VerifyProof(root, kv.k, proofs); err == nil { t.Fatalf("expected proof to fail for key %x", kv.k) } } @@ -104,8 +109,9 @@ func BenchmarkProve(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { kv := vals[keys[i%len(keys)]] - if trie.Prove(kv.k) == nil { - b.Fatalf("nil proof for %x", kv.k) + proofs, _ := ethdb.NewMemDatabase() + if trie.Prove(kv.k, 0, proofs); len(proofs.Keys()) == 0 { + b.Fatalf("zero length proof for %x", kv.k) } } } @@ -114,16 +120,18 @@ func BenchmarkVerifyProof(b *testing.B) { trie, vals := randomTrie(100) root := trie.Hash() var keys []string - var proofs [][]rlp.RawValue + var proofs []*ethdb.MemDatabase for k := range vals { keys = append(keys, k) - proofs = append(proofs, trie.Prove([]byte(k))) + proof, _ := ethdb.NewMemDatabase() + trie.Prove([]byte(k), 0, proof) + proofs = append(proofs, proof) } b.ResetTimer() for i := 0; i < b.N; i++ { im := i % len(keys) - if _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil { + if _, err, _ := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil { b.Fatalf("key %x: %v", keys[im], err) } } diff --git a/whisper/whisperv6/api.go b/whisper/whisperv6/api.go new file mode 100644 index 000000000..3dddb6953 --- /dev/null +++ b/whisper/whisperv6/api.go @@ -0,0 +1,591 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package whisperv6 + +import ( + "context" + "crypto/ecdsa" + "errors" + "fmt" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/rpc" +) + +const ( + filterTimeout = 300 // filters are considered timeout out after filterTimeout seconds +) + +var ( + ErrSymAsym = errors.New("specify either a symmetric or an asymmetric key") + ErrInvalidSymmetricKey = errors.New("invalid symmetric key") + ErrInvalidPublicKey = errors.New("invalid public key") + ErrInvalidSigningPubKey = errors.New("invalid signing public key") + ErrTooLowPoW = errors.New("message rejected, PoW too low") + ErrNoTopics = errors.New("missing topic(s)") +) + +// PublicWhisperAPI provides the whisper RPC service that can be +// use publicly without security implications. +type PublicWhisperAPI struct { + w *Whisper + + mu sync.Mutex + lastUsed map[string]time.Time // keeps track when a filter was polled for the last time. +} + +// NewPublicWhisperAPI create a new RPC whisper service. +func NewPublicWhisperAPI(w *Whisper) *PublicWhisperAPI { + api := &PublicWhisperAPI{ + w: w, + lastUsed: make(map[string]time.Time), + } + + go api.run() + return api +} + +// run the api event loop. +// this loop deletes filter that have not been used within filterTimeout +func (api *PublicWhisperAPI) run() { + timeout := time.NewTicker(2 * time.Minute) + for { + <-timeout.C + + api.mu.Lock() + for id, lastUsed := range api.lastUsed { + if time.Since(lastUsed).Seconds() >= filterTimeout { + delete(api.lastUsed, id) + if err := api.w.Unsubscribe(id); err != nil { + log.Error("could not unsubscribe whisper filter", "error", err) + } + log.Debug("delete whisper filter (timeout)", "id", id) + } + } + api.mu.Unlock() + } +} + +// Version returns the Whisper sub-protocol version. +func (api *PublicWhisperAPI) Version(ctx context.Context) string { + return ProtocolVersionStr +} + +// Info contains diagnostic information. +type Info struct { + Memory int `json:"memory"` // Memory size of the floating messages in bytes. + Messages int `json:"messages"` // Number of floating messages. + MinPow float64 `json:"minPow"` // Minimal accepted PoW + MaxMessageSize uint32 `json:"maxMessageSize"` // Maximum accepted message size +} + +// Info returns diagnostic information about the whisper node. +func (api *PublicWhisperAPI) Info(ctx context.Context) Info { + stats := api.w.Stats() + return Info{ + Memory: stats.memoryUsed, + Messages: len(api.w.messageQueue) + len(api.w.p2pMsgQueue), + MinPow: api.w.MinPow(), + MaxMessageSize: api.w.MaxMessageSize(), + } +} + +// SetMaxMessageSize sets the maximum message size that is accepted. +// Upper limit is defined by MaxMessageSize. +func (api *PublicWhisperAPI) SetMaxMessageSize(ctx context.Context, size uint32) (bool, error) { + return true, api.w.SetMaxMessageSize(size) +} + +// SetMinPow sets the minimum PoW for a message before it is accepted. +func (api *PublicWhisperAPI) SetMinPoW(ctx context.Context, pow float64) (bool, error) { + return true, api.w.SetMinimumPoW(pow) +} + +// MarkTrustedPeer marks a peer trusted. , which will allow it to send historic (expired) messages. +// Note: This function is not adding new nodes, the node needs to exists as a peer. +func (api *PublicWhisperAPI) MarkTrustedPeer(ctx context.Context, enode string) (bool, error) { + n, err := discover.ParseNode(enode) + if err != nil { + return false, err + } + return true, api.w.AllowP2PMessagesFromPeer(n.ID[:]) +} + +// NewKeyPair generates a new public and private key pair for message decryption and encryption. +// It returns an ID that can be used to refer to the keypair. +func (api *PublicWhisperAPI) NewKeyPair(ctx context.Context) (string, error) { + return api.w.NewKeyPair() +} + +// AddPrivateKey imports the given private key. +func (api *PublicWhisperAPI) AddPrivateKey(ctx context.Context, privateKey hexutil.Bytes) (string, error) { + key, err := crypto.ToECDSA(privateKey) + if err != nil { + return "", err + } + return api.w.AddKeyPair(key) +} + +// DeleteKeyPair removes the key with the given key if it exists. +func (api *PublicWhisperAPI) DeleteKeyPair(ctx context.Context, key string) (bool, error) { + if ok := api.w.DeleteKeyPair(key); ok { + return true, nil + } + return false, fmt.Errorf("key pair %s not found", key) +} + +// HasKeyPair returns an indication if the node has a key pair that is associated with the given id. +func (api *PublicWhisperAPI) HasKeyPair(ctx context.Context, id string) bool { + return api.w.HasKeyPair(id) +} + +// GetPublicKey returns the public key associated with the given key. The key is the hex +// encoded representation of a key in the form specified in section 4.3.6 of ANSI X9.62. +func (api *PublicWhisperAPI) GetPublicKey(ctx context.Context, id string) (hexutil.Bytes, error) { + key, err := api.w.GetPrivateKey(id) + if err != nil { + return hexutil.Bytes{}, err + } + return crypto.FromECDSAPub(&key.PublicKey), nil +} + +// GetPublicKey returns the private key associated with the given key. The key is the hex +// encoded representation of a key in the form specified in section 4.3.6 of ANSI X9.62. +func (api *PublicWhisperAPI) GetPrivateKey(ctx context.Context, id string) (hexutil.Bytes, error) { + key, err := api.w.GetPrivateKey(id) + if err != nil { + return hexutil.Bytes{}, err + } + return crypto.FromECDSA(key), nil +} + +// NewSymKey generate a random symmetric key. +// It returns an ID that can be used to refer to the key. +// Can be used encrypting and decrypting messages where the key is known to both parties. +func (api *PublicWhisperAPI) NewSymKey(ctx context.Context) (string, error) { + return api.w.GenerateSymKey() +} + +// AddSymKey import a symmetric key. +// It returns an ID that can be used to refer to the key. +// Can be used encrypting and decrypting messages where the key is known to both parties. +func (api *PublicWhisperAPI) AddSymKey(ctx context.Context, key hexutil.Bytes) (string, error) { + return api.w.AddSymKeyDirect([]byte(key)) +} + +// GenerateSymKeyFromPassword derive a key from the given password, stores it, and returns its ID. +func (api *PublicWhisperAPI) GenerateSymKeyFromPassword(ctx context.Context, passwd string) (string, error) { + return api.w.AddSymKeyFromPassword(passwd) +} + +// HasSymKey returns an indication if the node has a symmetric key associated with the given key. +func (api *PublicWhisperAPI) HasSymKey(ctx context.Context, id string) bool { + return api.w.HasSymKey(id) +} + +// GetSymKey returns the symmetric key associated with the given id. +func (api *PublicWhisperAPI) GetSymKey(ctx context.Context, id string) (hexutil.Bytes, error) { + return api.w.GetSymKey(id) +} + +// DeleteSymKey deletes the symmetric key that is associated with the given id. +func (api *PublicWhisperAPI) DeleteSymKey(ctx context.Context, id string) bool { + return api.w.DeleteSymKey(id) +} + +//go:generate gencodec -type NewMessage -field-override newMessageOverride -out gen_newmessage_json.go + +// NewMessage represents a new whisper message that is posted through the RPC. +type NewMessage struct { + SymKeyID string `json:"symKeyID"` + PublicKey []byte `json:"pubKey"` + Sig string `json:"sig"` + TTL uint32 `json:"ttl"` + Topic TopicType `json:"topic"` + Payload []byte `json:"payload"` + Padding []byte `json:"padding"` + PowTime uint32 `json:"powTime"` + PowTarget float64 `json:"powTarget"` + TargetPeer string `json:"targetPeer"` +} + +type newMessageOverride struct { + PublicKey hexutil.Bytes + Payload hexutil.Bytes + Padding hexutil.Bytes +} + +// Post a message on the Whisper network. +func (api *PublicWhisperAPI) Post(ctx context.Context, req NewMessage) (bool, error) { + var ( + symKeyGiven = len(req.SymKeyID) > 0 + pubKeyGiven = len(req.PublicKey) > 0 + err error + ) + + // user must specify either a symmetric or an asymmetric key + if (symKeyGiven && pubKeyGiven) || (!symKeyGiven && !pubKeyGiven) { + return false, ErrSymAsym + } + + params := &MessageParams{ + TTL: req.TTL, + Payload: req.Payload, + Padding: req.Padding, + WorkTime: req.PowTime, + PoW: req.PowTarget, + Topic: req.Topic, + } + + // Set key that is used to sign the message + if len(req.Sig) > 0 { + if params.Src, err = api.w.GetPrivateKey(req.Sig); err != nil { + return false, err + } + } + + // Set symmetric key that is used to encrypt the message + if symKeyGiven { + if params.Topic == (TopicType{}) { // topics are mandatory with symmetric encryption + return false, ErrNoTopics + } + if params.KeySym, err = api.w.GetSymKey(req.SymKeyID); err != nil { + return false, err + } + if !validateSymmetricKey(params.KeySym) { + return false, ErrInvalidSymmetricKey + } + } + + // Set asymmetric key that is used to encrypt the message + if pubKeyGiven { + params.Dst = crypto.ToECDSAPub(req.PublicKey) + if !ValidatePublicKey(params.Dst) { + return false, ErrInvalidPublicKey + } + } + + // encrypt and sent message + whisperMsg, err := NewSentMessage(params) + if err != nil { + return false, err + } + + env, err := whisperMsg.Wrap(params) + if err != nil { + return false, err + } + + // send to specific node (skip PoW check) + if len(req.TargetPeer) > 0 { + n, err := discover.ParseNode(req.TargetPeer) + if err != nil { + return false, fmt.Errorf("failed to parse target peer: %s", err) + } + return true, api.w.SendP2PMessage(n.ID[:], env) + } + + // ensure that the message PoW meets the node's minimum accepted PoW + if req.PowTarget < api.w.MinPow() { + return false, ErrTooLowPoW + } + + return true, api.w.Send(env) +} + +//go:generate gencodec -type Criteria -field-override criteriaOverride -out gen_criteria_json.go + +// Criteria holds various filter options for inbound messages. +type Criteria struct { + SymKeyID string `json:"symKeyID"` + PrivateKeyID string `json:"privateKeyID"` + Sig []byte `json:"sig"` + MinPow float64 `json:"minPow"` + Topics []TopicType `json:"topics"` + AllowP2P bool `json:"allowP2P"` +} + +type criteriaOverride struct { + Sig hexutil.Bytes +} + +// Messages set up a subscription that fires events when messages arrive that match +// the given set of criteria. +func (api *PublicWhisperAPI) Messages(ctx context.Context, crit Criteria) (*rpc.Subscription, error) { + var ( + symKeyGiven = len(crit.SymKeyID) > 0 + pubKeyGiven = len(crit.PrivateKeyID) > 0 + err error + ) + + // ensure that the RPC connection supports subscriptions + notifier, supported := rpc.NotifierFromContext(ctx) + if !supported { + return nil, rpc.ErrNotificationsUnsupported + } + + // user must specify either a symmetric or an asymmetric key + if (symKeyGiven && pubKeyGiven) || (!symKeyGiven && !pubKeyGiven) { + return nil, ErrSymAsym + } + + filter := Filter{ + PoW: crit.MinPow, + Messages: make(map[common.Hash]*ReceivedMessage), + AllowP2P: crit.AllowP2P, + } + + if len(crit.Sig) > 0 { + filter.Src = crypto.ToECDSAPub(crit.Sig) + if !ValidatePublicKey(filter.Src) { + return nil, ErrInvalidSigningPubKey + } + } + + for i, bt := range crit.Topics { + if len(bt) == 0 || len(bt) > 4 { + return nil, fmt.Errorf("subscribe: topic %d has wrong size: %d", i, len(bt)) + } + filter.Topics = append(filter.Topics, bt[:]) + } + + // listen for message that are encrypted with the given symmetric key + if symKeyGiven { + if len(filter.Topics) == 0 { + return nil, ErrNoTopics + } + key, err := api.w.GetSymKey(crit.SymKeyID) + if err != nil { + return nil, err + } + if !validateSymmetricKey(key) { + return nil, ErrInvalidSymmetricKey + } + filter.KeySym = key + filter.SymKeyHash = crypto.Keccak256Hash(filter.KeySym) + } + + // listen for messages that are encrypted with the given public key + if pubKeyGiven { + filter.KeyAsym, err = api.w.GetPrivateKey(crit.PrivateKeyID) + if err != nil || filter.KeyAsym == nil { + return nil, ErrInvalidPublicKey + } + } + + id, err := api.w.Subscribe(&filter) + if err != nil { + return nil, err + } + + // create subscription and start waiting for message events + rpcSub := notifier.CreateSubscription() + go func() { + // for now poll internally, refactor whisper internal for channel support + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if filter := api.w.GetFilter(id); filter != nil { + for _, rpcMessage := range toMessage(filter.Retrieve()) { + if err := notifier.Notify(rpcSub.ID, rpcMessage); err != nil { + log.Error("Failed to send notification", "err", err) + } + } + } + case <-rpcSub.Err(): + api.w.Unsubscribe(id) + return + case <-notifier.Closed(): + api.w.Unsubscribe(id) + return + } + } + }() + + return rpcSub, nil +} + +//go:generate gencodec -type Message -field-override messageOverride -out gen_message_json.go + +// Message is the RPC representation of a whisper message. +type Message struct { + Sig []byte `json:"sig,omitempty"` + TTL uint32 `json:"ttl"` + Timestamp uint32 `json:"timestamp"` + Topic TopicType `json:"topic"` + Payload []byte `json:"payload"` + Padding []byte `json:"padding"` + PoW float64 `json:"pow"` + Hash []byte `json:"hash"` + Dst []byte `json:"recipientPublicKey,omitempty"` +} + +type messageOverride struct { + Sig hexutil.Bytes + Payload hexutil.Bytes + Padding hexutil.Bytes + Hash hexutil.Bytes + Dst hexutil.Bytes +} + +// ToWhisperMessage converts an internal message into an API version. +func ToWhisperMessage(message *ReceivedMessage) *Message { + msg := Message{ + Payload: message.Payload, + Padding: message.Padding, + Timestamp: message.Sent, + TTL: message.TTL, + PoW: message.PoW, + Hash: message.EnvelopeHash.Bytes(), + Topic: message.Topic, + } + + if message.Dst != nil { + b := crypto.FromECDSAPub(message.Dst) + if b != nil { + msg.Dst = b + } + } + + if isMessageSigned(message.Raw[0]) { + b := crypto.FromECDSAPub(message.SigToPubKey()) + if b != nil { + msg.Sig = b + } + } + + return &msg +} + +// toMessage converts a set of messages to its RPC representation. +func toMessage(messages []*ReceivedMessage) []*Message { + msgs := make([]*Message, len(messages)) + for i, msg := range messages { + msgs[i] = ToWhisperMessage(msg) + } + return msgs +} + +// GetFilterMessages returns the messages that match the filter criteria and +// are received between the last poll and now. +func (api *PublicWhisperAPI) GetFilterMessages(id string) ([]*Message, error) { + api.mu.Lock() + f := api.w.GetFilter(id) + if f == nil { + api.mu.Unlock() + return nil, fmt.Errorf("filter not found") + } + api.lastUsed[id] = time.Now() + api.mu.Unlock() + + receivedMessages := f.Retrieve() + messages := make([]*Message, 0, len(receivedMessages)) + for _, msg := range receivedMessages { + messages = append(messages, ToWhisperMessage(msg)) + } + + return messages, nil +} + +// DeleteMessageFilter deletes a filter. +func (api *PublicWhisperAPI) DeleteMessageFilter(id string) (bool, error) { + api.mu.Lock() + defer api.mu.Unlock() + + delete(api.lastUsed, id) + return true, api.w.Unsubscribe(id) +} + +// NewMessageFilter creates a new filter that can be used to poll for +// (new) messages that satisfy the given criteria. +func (api *PublicWhisperAPI) NewMessageFilter(req Criteria) (string, error) { + var ( + src *ecdsa.PublicKey + keySym []byte + keyAsym *ecdsa.PrivateKey + topics [][]byte + + symKeyGiven = len(req.SymKeyID) > 0 + asymKeyGiven = len(req.PrivateKeyID) > 0 + + err error + ) + + // user must specify either a symmetric or an asymmetric key + if (symKeyGiven && asymKeyGiven) || (!symKeyGiven && !asymKeyGiven) { + return "", ErrSymAsym + } + + if len(req.Sig) > 0 { + src = crypto.ToECDSAPub(req.Sig) + if !ValidatePublicKey(src) { + return "", ErrInvalidSigningPubKey + } + } + + if symKeyGiven { + if keySym, err = api.w.GetSymKey(req.SymKeyID); err != nil { + return "", err + } + if !validateSymmetricKey(keySym) { + return "", ErrInvalidSymmetricKey + } + } + + if asymKeyGiven { + if keyAsym, err = api.w.GetPrivateKey(req.PrivateKeyID); err != nil { + return "", err + } + } + + if len(req.Topics) > 0 { + topics = make([][]byte, 1) + for _, topic := range req.Topics { + topics = append(topics, topic[:]) + } + } + + f := &Filter{ + Src: src, + KeySym: keySym, + KeyAsym: keyAsym, + PoW: req.MinPow, + AllowP2P: req.AllowP2P, + Topics: topics, + Messages: make(map[common.Hash]*ReceivedMessage), + } + + id, err := api.w.Subscribe(f) + if err != nil { + return "", err + } + + api.mu.Lock() + api.lastUsed[id] = time.Now() + api.mu.Unlock() + + return id, nil +} diff --git a/whisper/whisperv6/benchmarks_test.go b/whisper/whisperv6/benchmarks_test.go new file mode 100644 index 000000000..9f413e7b0 --- /dev/null +++ b/whisper/whisperv6/benchmarks_test.go @@ -0,0 +1,206 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package whisperv6 + +import ( + "testing" + + "github.com/ethereum/go-ethereum/crypto" +) + +func BenchmarkDeriveKeyMaterial(b *testing.B) { + for i := 0; i < b.N; i++ { + deriveKeyMaterial([]byte("test"), 0) + } +} + +func BenchmarkEncryptionSym(b *testing.B) { + InitSingleTest() + + params, err := generateMessageParams() + if err != nil { + b.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + + for i := 0; i < b.N; i++ { + msg, _ := NewSentMessage(params) + _, err := msg.Wrap(params) + if err != nil { + b.Errorf("failed Wrap with seed %d: %s.", seed, err) + b.Errorf("i = %d, len(msg.Raw) = %d, params.Payload = %d.", i, len(msg.Raw), len(params.Payload)) + return + } + } +} + +func BenchmarkEncryptionAsym(b *testing.B) { + InitSingleTest() + + params, err := generateMessageParams() + if err != nil { + b.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + key, err := crypto.GenerateKey() + if err != nil { + b.Fatalf("failed GenerateKey with seed %d: %s.", seed, err) + } + params.KeySym = nil + params.Dst = &key.PublicKey + + for i := 0; i < b.N; i++ { + msg, _ := NewSentMessage(params) + _, err := msg.Wrap(params) + if err != nil { + b.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + } +} + +func BenchmarkDecryptionSymValid(b *testing.B) { + InitSingleTest() + + params, err := generateMessageParams() + if err != nil { + b.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + msg, _ := NewSentMessage(params) + env, err := msg.Wrap(params) + if err != nil { + b.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + f := Filter{KeySym: params.KeySym} + + for i := 0; i < b.N; i++ { + msg := env.Open(&f) + if msg == nil { + b.Fatalf("failed to open with seed %d.", seed) + } + } +} + +func BenchmarkDecryptionSymInvalid(b *testing.B) { + InitSingleTest() + + params, err := generateMessageParams() + if err != nil { + b.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + msg, _ := NewSentMessage(params) + env, err := msg.Wrap(params) + if err != nil { + b.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + f := Filter{KeySym: []byte("arbitrary stuff here")} + + for i := 0; i < b.N; i++ { + msg := env.Open(&f) + if msg != nil { + b.Fatalf("opened envelope with invalid key, seed: %d.", seed) + } + } +} + +func BenchmarkDecryptionAsymValid(b *testing.B) { + InitSingleTest() + + params, err := generateMessageParams() + if err != nil { + b.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + key, err := crypto.GenerateKey() + if err != nil { + b.Fatalf("failed GenerateKey with seed %d: %s.", seed, err) + } + f := Filter{KeyAsym: key} + params.KeySym = nil + params.Dst = &key.PublicKey + msg, _ := NewSentMessage(params) + env, err := msg.Wrap(params) + if err != nil { + b.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + + for i := 0; i < b.N; i++ { + msg := env.Open(&f) + if msg == nil { + b.Fatalf("fail to open, seed: %d.", seed) + } + } +} + +func BenchmarkDecryptionAsymInvalid(b *testing.B) { + InitSingleTest() + + params, err := generateMessageParams() + if err != nil { + b.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + key, err := crypto.GenerateKey() + if err != nil { + b.Fatalf("failed GenerateKey with seed %d: %s.", seed, err) + } + params.KeySym = nil + params.Dst = &key.PublicKey + msg, _ := NewSentMessage(params) + env, err := msg.Wrap(params) + if err != nil { + b.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + + key, err = crypto.GenerateKey() + if err != nil { + b.Fatalf("failed GenerateKey with seed %d: %s.", seed, err) + } + f := Filter{KeyAsym: key} + + for i := 0; i < b.N; i++ { + msg := env.Open(&f) + if msg != nil { + b.Fatalf("opened envelope with invalid key, seed: %d.", seed) + } + } +} + +func increment(x []byte) { + for i := 0; i < len(x); i++ { + x[i]++ + if x[i] != 0 { + break + } + } +} + +func BenchmarkPoW(b *testing.B) { + InitSingleTest() + + params, err := generateMessageParams() + if err != nil { + b.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + params.Payload = make([]byte, 32) + params.PoW = 10.0 + params.TTL = 1 + + for i := 0; i < b.N; i++ { + increment(params.Payload) + msg, _ := NewSentMessage(params) + _, err := msg.Wrap(params) + if err != nil { + b.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + } +} diff --git a/whisper/whisperv6/config.go b/whisper/whisperv6/config.go new file mode 100644 index 000000000..d7f817aa2 --- /dev/null +++ b/whisper/whisperv6/config.go @@ -0,0 +1,27 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package whisperv6 + +type Config struct { + MaxMessageSize uint32 `toml:",omitempty"` + MinimumAcceptedPOW float64 `toml:",omitempty"` +} + +var DefaultConfig = Config{ + MaxMessageSize: DefaultMaxMessageSize, + MinimumAcceptedPOW: DefaultMinimumPoW, +} diff --git a/whisper/whisperv6/doc.go b/whisper/whisperv6/doc.go new file mode 100644 index 000000000..e64dd2f42 --- /dev/null +++ b/whisper/whisperv6/doc.go @@ -0,0 +1,87 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +/* +Package whisper implements the Whisper protocol (version 6). + +Whisper combines aspects of both DHTs and datagram messaging systems (e.g. UDP). +As such it may be likened and compared to both, not dissimilar to the +matter/energy duality (apologies to physicists for the blatant abuse of a +fundamental and beautiful natural principle). + +Whisper is a pure identity-based messaging system. Whisper provides a low-level +(non-application-specific) but easily-accessible API without being based upon +or prejudiced by the low-level hardware attributes and characteristics, +particularly the notion of singular endpoints. +*/ +package whisperv6 + +import ( + "fmt" + "time" +) + +const ( + EnvelopeVersion = uint64(0) + ProtocolVersion = uint64(5) + ProtocolVersionStr = "5.0" + ProtocolName = "shh" + + statusCode = 0 // used by whisper protocol + messagesCode = 1 // normal whisper message + p2pCode = 2 // peer-to-peer message (to be consumed by the peer, but not forwarded any further) + p2pRequestCode = 3 // peer-to-peer message, used by Dapp protocol + NumberOfMessageCodes = 64 + + paddingMask = byte(3) + signatureFlag = byte(4) + + TopicLength = 4 + signatureLength = 65 + aesKeyLength = 32 + AESNonceLength = 12 + keyIdSize = 32 + + MaxMessageSize = uint32(10 * 1024 * 1024) // maximum accepted size of a message. + DefaultMaxMessageSize = uint32(1024 * 1024) + DefaultMinimumPoW = 0.2 + + padSizeLimit = 256 // just an arbitrary number, could be changed without breaking the protocol (must not exceed 2^24) + messageQueueLimit = 1024 + + expirationCycle = time.Second + transmissionCycle = 300 * time.Millisecond + + DefaultTTL = 50 // seconds + SynchAllowance = 10 // seconds +) + +type unknownVersionError uint64 + +func (e unknownVersionError) Error() string { + return fmt.Sprintf("invalid envelope version %d", uint64(e)) +} + +// MailServer represents a mail server, capable of +// archiving the old messages for subsequent delivery +// to the peers. Any implementation must ensure that both +// functions are thread-safe. Also, they must return ASAP. +// DeliverMail should use directMessagesCode for delivery, +// in order to bypass the expiry checks. +type MailServer interface { + Archive(env *Envelope) + DeliverMail(whisperPeer *Peer, request *Envelope) +} diff --git a/whisper/whisperv6/envelope.go b/whisper/whisperv6/envelope.go new file mode 100644 index 000000000..a5f4770b0 --- /dev/null +++ b/whisper/whisperv6/envelope.go @@ -0,0 +1,246 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Contains the Whisper protocol Envelope element. + +package whisperv6 + +import ( + "crypto/ecdsa" + "encoding/binary" + "fmt" + gmath "math" + "math/big" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/math" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/ecies" + "github.com/ethereum/go-ethereum/rlp" +) + +// Envelope represents a clear-text data packet to transmit through the Whisper +// network. Its contents may or may not be encrypted and signed. +type Envelope struct { + Version []byte + Expiry uint32 + TTL uint32 + Topic TopicType + AESNonce []byte + Data []byte + EnvNonce uint64 + + pow float64 // Message-specific PoW as described in the Whisper specification. + hash common.Hash // Cached hash of the envelope to avoid rehashing every time. + // Don't access hash directly, use Hash() function instead. +} + +// size returns the size of envelope as it is sent (i.e. public fields only) +func (e *Envelope) size() int { + return 20 + len(e.Version) + len(e.AESNonce) + len(e.Data) +} + +// rlpWithoutNonce returns the RLP encoded envelope contents, except the nonce. +func (e *Envelope) rlpWithoutNonce() []byte { + res, _ := rlp.EncodeToBytes([]interface{}{e.Version, e.Expiry, e.TTL, e.Topic, e.AESNonce, e.Data}) + return res +} + +// NewEnvelope wraps a Whisper message with expiration and destination data +// included into an envelope for network forwarding. +func NewEnvelope(ttl uint32, topic TopicType, aesNonce []byte, msg *sentMessage) *Envelope { + env := Envelope{ + Version: make([]byte, 1), + Expiry: uint32(time.Now().Add(time.Second * time.Duration(ttl)).Unix()), + TTL: ttl, + Topic: topic, + AESNonce: aesNonce, + Data: msg.Raw, + EnvNonce: 0, + } + + if EnvelopeVersion < 256 { + env.Version[0] = byte(EnvelopeVersion) + } else { + panic("please increase the size of Envelope.Version before releasing this version") + } + + return &env +} + +func (e *Envelope) IsSymmetric() bool { + return len(e.AESNonce) > 0 +} + +func (e *Envelope) isAsymmetric() bool { + return !e.IsSymmetric() +} + +func (e *Envelope) Ver() uint64 { + return bytesToUintLittleEndian(e.Version) +} + +// Seal closes the envelope by spending the requested amount of time as a proof +// of work on hashing the data. +func (e *Envelope) Seal(options *MessageParams) error { + var target, bestBit int + if options.PoW == 0 { + // adjust for the duration of Seal() execution only if execution time is predefined unconditionally + e.Expiry += options.WorkTime + } else { + target = e.powToFirstBit(options.PoW) + if target < 1 { + target = 1 + } + } + + buf := make([]byte, 64) + h := crypto.Keccak256(e.rlpWithoutNonce()) + copy(buf[:32], h) + + finish := time.Now().Add(time.Duration(options.WorkTime) * time.Second).UnixNano() + for nonce := uint64(0); time.Now().UnixNano() < finish; { + for i := 0; i < 1024; i++ { + binary.BigEndian.PutUint64(buf[56:], nonce) + d := new(big.Int).SetBytes(crypto.Keccak256(buf)) + firstBit := math.FirstBitSet(d) + if firstBit > bestBit { + e.EnvNonce, bestBit = nonce, firstBit + if target > 0 && bestBit >= target { + return nil + } + } + nonce++ + } + } + + if target > 0 && bestBit < target { + return fmt.Errorf("failed to reach the PoW target, specified pow time (%d seconds) was insufficient", options.WorkTime) + } + + return nil +} + +func (e *Envelope) PoW() float64 { + if e.pow == 0 { + e.calculatePoW(0) + } + return e.pow +} + +func (e *Envelope) calculatePoW(diff uint32) { + buf := make([]byte, 64) + h := crypto.Keccak256(e.rlpWithoutNonce()) + copy(buf[:32], h) + binary.BigEndian.PutUint64(buf[56:], e.EnvNonce) + d := new(big.Int).SetBytes(crypto.Keccak256(buf)) + firstBit := math.FirstBitSet(d) + x := gmath.Pow(2, float64(firstBit)) + x /= float64(e.size()) + x /= float64(e.TTL + diff) + e.pow = x +} + +func (e *Envelope) powToFirstBit(pow float64) int { + x := pow + x *= float64(e.size()) + x *= float64(e.TTL) + bits := gmath.Log2(x) + bits = gmath.Ceil(bits) + return int(bits) +} + +// Hash returns the SHA3 hash of the envelope, calculating it if not yet done. +func (e *Envelope) Hash() common.Hash { + if (e.hash == common.Hash{}) { + encoded, _ := rlp.EncodeToBytes(e) + e.hash = crypto.Keccak256Hash(encoded) + } + return e.hash +} + +// DecodeRLP decodes an Envelope from an RLP data stream. +func (e *Envelope) DecodeRLP(s *rlp.Stream) error { + raw, err := s.Raw() + if err != nil { + return err + } + // The decoding of Envelope uses the struct fields but also needs + // to compute the hash of the whole RLP-encoded envelope. This + // type has the same structure as Envelope but is not an + // rlp.Decoder (does not implement DecodeRLP function). + // Only public members will be encoded. + type rlpenv Envelope + if err := rlp.DecodeBytes(raw, (*rlpenv)(e)); err != nil { + return err + } + e.hash = crypto.Keccak256Hash(raw) + return nil +} + +// OpenAsymmetric tries to decrypt an envelope, potentially encrypted with a particular key. +func (e *Envelope) OpenAsymmetric(key *ecdsa.PrivateKey) (*ReceivedMessage, error) { + message := &ReceivedMessage{Raw: e.Data} + err := message.decryptAsymmetric(key) + switch err { + case nil: + return message, nil + case ecies.ErrInvalidPublicKey: // addressed to somebody else + return nil, err + default: + return nil, fmt.Errorf("unable to open envelope, decrypt failed: %v", err) + } +} + +// OpenSymmetric tries to decrypt an envelope, potentially encrypted with a particular key. +func (e *Envelope) OpenSymmetric(key []byte) (msg *ReceivedMessage, err error) { + msg = &ReceivedMessage{Raw: e.Data} + err = msg.decryptSymmetric(key, e.AESNonce) + if err != nil { + msg = nil + } + return msg, err +} + +// Open tries to decrypt an envelope, and populates the message fields in case of success. +func (e *Envelope) Open(watcher *Filter) (msg *ReceivedMessage) { + if e.isAsymmetric() { + msg, _ = e.OpenAsymmetric(watcher.KeyAsym) + if msg != nil { + msg.Dst = &watcher.KeyAsym.PublicKey + } + } else if e.IsSymmetric() { + msg, _ = e.OpenSymmetric(watcher.KeySym) + if msg != nil { + msg.SymKeyHash = crypto.Keccak256Hash(watcher.KeySym) + } + } + + if msg != nil { + ok := msg.Validate() + if !ok { + return nil + } + msg.Topic = e.Topic + msg.PoW = e.PoW() + msg.TTL = e.TTL + msg.Sent = e.Expiry - e.TTL + msg.EnvelopeHash = e.Hash() + msg.EnvelopeVersion = e.Ver() + } + return msg +} diff --git a/whisper/whisperv6/filter.go b/whisper/whisperv6/filter.go new file mode 100644 index 000000000..5cb371b7d --- /dev/null +++ b/whisper/whisperv6/filter.go @@ -0,0 +1,239 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package whisperv6 + +import ( + "crypto/ecdsa" + "fmt" + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/log" +) + +type Filter struct { + Src *ecdsa.PublicKey // Sender of the message + KeyAsym *ecdsa.PrivateKey // Private Key of recipient + KeySym []byte // Key associated with the Topic + Topics [][]byte // Topics to filter messages with + PoW float64 // Proof of work as described in the Whisper spec + AllowP2P bool // Indicates whether this filter is interested in direct peer-to-peer messages + SymKeyHash common.Hash // The Keccak256Hash of the symmetric key, needed for optimization + + Messages map[common.Hash]*ReceivedMessage + mutex sync.RWMutex +} + +type Filters struct { + watchers map[string]*Filter + whisper *Whisper + mutex sync.RWMutex +} + +func NewFilters(w *Whisper) *Filters { + return &Filters{ + watchers: make(map[string]*Filter), + whisper: w, + } +} + +func (fs *Filters) Install(watcher *Filter) (string, error) { + if watcher.Messages == nil { + watcher.Messages = make(map[common.Hash]*ReceivedMessage) + } + + id, err := GenerateRandomID() + if err != nil { + return "", err + } + + fs.mutex.Lock() + defer fs.mutex.Unlock() + + if fs.watchers[id] != nil { + return "", fmt.Errorf("failed to generate unique ID") + } + + if watcher.expectsSymmetricEncryption() { + watcher.SymKeyHash = crypto.Keccak256Hash(watcher.KeySym) + } + + fs.watchers[id] = watcher + return id, err +} + +func (fs *Filters) Uninstall(id string) bool { + fs.mutex.Lock() + defer fs.mutex.Unlock() + if fs.watchers[id] != nil { + delete(fs.watchers, id) + return true + } + return false +} + +func (fs *Filters) Get(id string) *Filter { + fs.mutex.RLock() + defer fs.mutex.RUnlock() + return fs.watchers[id] +} + +func (fs *Filters) NotifyWatchers(env *Envelope, p2pMessage bool) { + var msg *ReceivedMessage + + fs.mutex.RLock() + defer fs.mutex.RUnlock() + + i := -1 // only used for logging info + for _, watcher := range fs.watchers { + i++ + if p2pMessage && !watcher.AllowP2P { + log.Trace(fmt.Sprintf("msg [%x], filter [%d]: p2p messages are not allowed", env.Hash(), i)) + continue + } + + var match bool + if msg != nil { + match = watcher.MatchMessage(msg) + } else { + match = watcher.MatchEnvelope(env) + if match { + msg = env.Open(watcher) + if msg == nil { + log.Trace("processing message: failed to open", "message", env.Hash().Hex(), "filter", i) + } + } else { + log.Trace("processing message: does not match", "message", env.Hash().Hex(), "filter", i) + } + } + + if match && msg != nil { + log.Trace("processing message: decrypted", "hash", env.Hash().Hex()) + if watcher.Src == nil || IsPubKeyEqual(msg.Src, watcher.Src) { + watcher.Trigger(msg) + } + } + } +} + +func (f *Filter) processEnvelope(env *Envelope) *ReceivedMessage { + if f.MatchEnvelope(env) { + msg := env.Open(f) + if msg != nil { + return msg + } else { + log.Trace("processing envelope: failed to open", "hash", env.Hash().Hex()) + } + } else { + log.Trace("processing envelope: does not match", "hash", env.Hash().Hex()) + } + return nil +} + +func (f *Filter) expectsAsymmetricEncryption() bool { + return f.KeyAsym != nil +} + +func (f *Filter) expectsSymmetricEncryption() bool { + return f.KeySym != nil +} + +func (f *Filter) Trigger(msg *ReceivedMessage) { + f.mutex.Lock() + defer f.mutex.Unlock() + + if _, exist := f.Messages[msg.EnvelopeHash]; !exist { + f.Messages[msg.EnvelopeHash] = msg + } +} + +func (f *Filter) Retrieve() (all []*ReceivedMessage) { + f.mutex.Lock() + defer f.mutex.Unlock() + + all = make([]*ReceivedMessage, 0, len(f.Messages)) + for _, msg := range f.Messages { + all = append(all, msg) + } + + f.Messages = make(map[common.Hash]*ReceivedMessage) // delete old messages + return all +} + +func (f *Filter) MatchMessage(msg *ReceivedMessage) bool { + if f.PoW > 0 && msg.PoW < f.PoW { + return false + } + + if f.expectsAsymmetricEncryption() && msg.isAsymmetricEncryption() { + return IsPubKeyEqual(&f.KeyAsym.PublicKey, msg.Dst) && f.MatchTopic(msg.Topic) + } else if f.expectsSymmetricEncryption() && msg.isSymmetricEncryption() { + return f.SymKeyHash == msg.SymKeyHash && f.MatchTopic(msg.Topic) + } + return false +} + +func (f *Filter) MatchEnvelope(envelope *Envelope) bool { + if f.PoW > 0 && envelope.pow < f.PoW { + return false + } + + if f.expectsAsymmetricEncryption() && envelope.isAsymmetric() { + return f.MatchTopic(envelope.Topic) + } else if f.expectsSymmetricEncryption() && envelope.IsSymmetric() { + return f.MatchTopic(envelope.Topic) + } + return false +} + +func (f *Filter) MatchTopic(topic TopicType) bool { + if len(f.Topics) == 0 { + // any topic matches + return true + } + + for _, bt := range f.Topics { + if matchSingleTopic(topic, bt) { + return true + } + } + return false +} + +func matchSingleTopic(topic TopicType, bt []byte) bool { + if len(bt) > 4 { + bt = bt[:4] + } + + for j, b := range bt { + if topic[j] != b { + return false + } + } + return true +} + +func IsPubKeyEqual(a, b *ecdsa.PublicKey) bool { + if !ValidatePublicKey(a) { + return false + } else if !ValidatePublicKey(b) { + return false + } + // the curve is always the same, just compare the points + return a.X.Cmp(b.X) == 0 && a.Y.Cmp(b.Y) == 0 +} diff --git a/whisper/whisperv6/filter_test.go b/whisper/whisperv6/filter_test.go new file mode 100644 index 000000000..58d90d60c --- /dev/null +++ b/whisper/whisperv6/filter_test.go @@ -0,0 +1,814 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package whisperv6 + +import ( + "math/big" + mrand "math/rand" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" +) + +var seed int64 + +// InitSingleTest should be called in the beginning of every +// test, which uses RNG, in order to make the tests +// reproduciblity independent of their sequence. +func InitSingleTest() { + seed = time.Now().Unix() + mrand.Seed(seed) +} + +func InitDebugTest(i int64) { + seed = i + mrand.Seed(seed) +} + +type FilterTestCase struct { + f *Filter + id string + alive bool + msgCnt int +} + +func generateFilter(t *testing.T, symmetric bool) (*Filter, error) { + var f Filter + f.Messages = make(map[common.Hash]*ReceivedMessage) + + const topicNum = 8 + f.Topics = make([][]byte, topicNum) + for i := 0; i < topicNum; i++ { + f.Topics[i] = make([]byte, 4) + mrand.Read(f.Topics[i][:]) + f.Topics[i][0] = 0x01 + } + + key, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("generateFilter 1 failed with seed %d.", seed) + return nil, err + } + f.Src = &key.PublicKey + + if symmetric { + f.KeySym = make([]byte, aesKeyLength) + mrand.Read(f.KeySym) + f.SymKeyHash = crypto.Keccak256Hash(f.KeySym) + } else { + f.KeyAsym, err = crypto.GenerateKey() + if err != nil { + t.Fatalf("generateFilter 2 failed with seed %d.", seed) + return nil, err + } + } + + // AcceptP2P & PoW are not set + return &f, nil +} + +func generateTestCases(t *testing.T, SizeTestFilters int) []FilterTestCase { + cases := make([]FilterTestCase, SizeTestFilters) + for i := 0; i < SizeTestFilters; i++ { + f, _ := generateFilter(t, true) + cases[i].f = f + cases[i].alive = (mrand.Int()&int(1) == 0) + } + return cases +} + +func TestInstallFilters(t *testing.T) { + InitSingleTest() + + const SizeTestFilters = 256 + w := New(&Config{}) + filters := NewFilters(w) + tst := generateTestCases(t, SizeTestFilters) + + var err error + var j string + for i := 0; i < SizeTestFilters; i++ { + j, err = filters.Install(tst[i].f) + if err != nil { + t.Fatalf("seed %d: failed to install filter: %s", seed, err) + } + tst[i].id = j + if len(j) != keyIdSize*2 { + t.Fatalf("seed %d: wrong filter id size [%d]", seed, len(j)) + } + } + + for _, testCase := range tst { + if !testCase.alive { + filters.Uninstall(testCase.id) + } + } + + for i, testCase := range tst { + fil := filters.Get(testCase.id) + exist := (fil != nil) + if exist != testCase.alive { + t.Fatalf("seed %d: failed alive: %d, %v, %v", seed, i, exist, testCase.alive) + } + if exist && fil.PoW != testCase.f.PoW { + t.Fatalf("seed %d: failed Get: %d, %v, %v", seed, i, exist, testCase.alive) + } + } +} + +func TestInstallSymKeyGeneratesHash(t *testing.T) { + InitSingleTest() + + w := New(&Config{}) + filters := NewFilters(w) + filter, _ := generateFilter(t, true) + + // save the current SymKeyHash for comparison + initialSymKeyHash := filter.SymKeyHash + + // ensure the SymKeyHash is invalid, for Install to recreate it + var invalid common.Hash + filter.SymKeyHash = invalid + + _, err := filters.Install(filter) + + if err != nil { + t.Fatalf("Error installing the filter: %s", err) + } + + for i, b := range filter.SymKeyHash { + if b != initialSymKeyHash[i] { + t.Fatalf("The filter's symmetric key hash was not properly generated by Install") + } + } +} + +func TestInstallIdenticalFilters(t *testing.T) { + InitSingleTest() + + w := New(&Config{}) + filters := NewFilters(w) + filter1, _ := generateFilter(t, true) + + // Copy the first filter since some of its fields + // are randomly gnerated. + filter2 := &Filter{ + KeySym: filter1.KeySym, + Topics: filter1.Topics, + PoW: filter1.PoW, + AllowP2P: filter1.AllowP2P, + Messages: make(map[common.Hash]*ReceivedMessage), + } + + _, err := filters.Install(filter1) + + if err != nil { + t.Fatalf("Error installing the first filter with seed %d: %s", seed, err) + } + + _, err = filters.Install(filter2) + + if err != nil { + t.Fatalf("Error installing the second filter with seed %d: %s", seed, err) + } + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("Error generating message parameters with seed %d: %s", seed, err) + } + + params.KeySym = filter1.KeySym + params.Topic = BytesToTopic(filter1.Topics[0]) + + filter1.Src = ¶ms.Src.PublicKey + filter2.Src = ¶ms.Src.PublicKey + + sentMessage, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := sentMessage.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + msg := env.Open(filter1) + if msg == nil { + t.Fatalf("failed to Open with filter1") + } + + if !filter1.MatchEnvelope(env) { + t.Fatalf("failed matching with the first filter") + } + + if !filter2.MatchEnvelope(env) { + t.Fatalf("failed matching with the first filter") + } + + if !filter1.MatchMessage(msg) { + t.Fatalf("failed matching with the second filter") + } + + if !filter2.MatchMessage(msg) { + t.Fatalf("failed matching with the second filter") + } +} + +func TestComparePubKey(t *testing.T) { + InitSingleTest() + + key1, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("failed to generate first key with seed %d: %s.", seed, err) + } + key2, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("failed to generate second key with seed %d: %s.", seed, err) + } + if IsPubKeyEqual(&key1.PublicKey, &key2.PublicKey) { + t.Fatalf("public keys are equal, seed %d.", seed) + } + + // generate key3 == key1 + mrand.Seed(seed) + key3, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("failed to generate third key with seed %d: %s.", seed, err) + } + if IsPubKeyEqual(&key1.PublicKey, &key3.PublicKey) { + t.Fatalf("key1 == key3, seed %d.", seed) + } +} + +func TestMatchEnvelope(t *testing.T) { + InitSingleTest() + + fsym, err := generateFilter(t, true) + if err != nil { + t.Fatalf("failed generateFilter with seed %d: %s.", seed, err) + } + + fasym, err := generateFilter(t, false) + if err != nil { + t.Fatalf("failed generateFilter() with seed %d: %s.", seed, err) + } + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + + params.Topic[0] = 0xFF // ensure mismatch + + // mismatch with pseudo-random data + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := msg.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + match := fsym.MatchEnvelope(env) + if match { + t.Fatalf("failed MatchEnvelope symmetric with seed %d.", seed) + } + match = fasym.MatchEnvelope(env) + if match { + t.Fatalf("failed MatchEnvelope asymmetric with seed %d.", seed) + } + + // encrypt symmetrically + i := mrand.Int() % 4 + fsym.Topics[i] = params.Topic[:] + fasym.Topics[i] = params.Topic[:] + msg, err = NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err = msg.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap() with seed %d: %s.", seed, err) + } + + // symmetric + matching topic: match + match = fsym.MatchEnvelope(env) + if !match { + t.Fatalf("failed MatchEnvelope() symmetric with seed %d.", seed) + } + + // asymmetric + matching topic: mismatch + match = fasym.MatchEnvelope(env) + if match { + t.Fatalf("failed MatchEnvelope() asymmetric with seed %d.", seed) + } + + // symmetric + matching topic + insufficient PoW: mismatch + fsym.PoW = env.PoW() + 1.0 + match = fsym.MatchEnvelope(env) + if match { + t.Fatalf("failed MatchEnvelope(symmetric + matching topic + insufficient PoW) asymmetric with seed %d.", seed) + } + + // symmetric + matching topic + sufficient PoW: match + fsym.PoW = env.PoW() / 2 + match = fsym.MatchEnvelope(env) + if !match { + t.Fatalf("failed MatchEnvelope(symmetric + matching topic + sufficient PoW) with seed %d.", seed) + } + + // symmetric + topics are nil (wildcard): match + prevTopics := fsym.Topics + fsym.Topics = nil + match = fsym.MatchEnvelope(env) + if !match { + t.Fatalf("failed MatchEnvelope(symmetric + topics are nil) with seed %d.", seed) + } + fsym.Topics = prevTopics + + // encrypt asymmetrically + key, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("failed GenerateKey with seed %d: %s.", seed, err) + } + params.KeySym = nil + params.Dst = &key.PublicKey + msg, err = NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err = msg.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap() with seed %d: %s.", seed, err) + } + + // encryption method mismatch + match = fsym.MatchEnvelope(env) + if match { + t.Fatalf("failed MatchEnvelope(encryption method mismatch) with seed %d.", seed) + } + + // asymmetric + mismatching topic: mismatch + match = fasym.MatchEnvelope(env) + if !match { + t.Fatalf("failed MatchEnvelope(asymmetric + mismatching topic) with seed %d.", seed) + } + + // asymmetric + matching topic: match + fasym.Topics[i] = fasym.Topics[i+1] + match = fasym.MatchEnvelope(env) + if match { + t.Fatalf("failed MatchEnvelope(asymmetric + matching topic) with seed %d.", seed) + } + + // asymmetric + filter without topic (wildcard): match + fasym.Topics = nil + match = fasym.MatchEnvelope(env) + if !match { + t.Fatalf("failed MatchEnvelope(asymmetric + filter without topic) with seed %d.", seed) + } + + // asymmetric + insufficient PoW: mismatch + fasym.PoW = env.PoW() + 1.0 + match = fasym.MatchEnvelope(env) + if match { + t.Fatalf("failed MatchEnvelope(asymmetric + insufficient PoW) with seed %d.", seed) + } + + // asymmetric + sufficient PoW: match + fasym.PoW = env.PoW() / 2 + match = fasym.MatchEnvelope(env) + if !match { + t.Fatalf("failed MatchEnvelope(asymmetric + sufficient PoW) with seed %d.", seed) + } + + // filter without topic + envelope without topic: match + env.Topic = TopicType{} + match = fasym.MatchEnvelope(env) + if !match { + t.Fatalf("failed MatchEnvelope(filter without topic + envelope without topic) with seed %d.", seed) + } + + // filter with topic + envelope without topic: mismatch + fasym.Topics = fsym.Topics + match = fasym.MatchEnvelope(env) + if match { + t.Fatalf("failed MatchEnvelope(filter without topic + envelope without topic) with seed %d.", seed) + } +} + +func TestMatchMessageSym(t *testing.T) { + InitSingleTest() + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + + f, err := generateFilter(t, true) + if err != nil { + t.Fatalf("failed generateFilter with seed %d: %s.", seed, err) + } + + const index = 1 + params.KeySym = f.KeySym + params.Topic = BytesToTopic(f.Topics[index]) + + sentMessage, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := sentMessage.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + msg := env.Open(f) + if msg == nil { + t.Fatalf("failed Open with seed %d.", seed) + } + + // Src: match + *f.Src.X = *params.Src.PublicKey.X + *f.Src.Y = *params.Src.PublicKey.Y + if !f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(src match) with seed %d.", seed) + } + + // insufficient PoW: mismatch + f.PoW = msg.PoW + 1.0 + if f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(insufficient PoW) with seed %d.", seed) + } + + // sufficient PoW: match + f.PoW = msg.PoW / 2 + if !f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(sufficient PoW) with seed %d.", seed) + } + + // topic mismatch + f.Topics[index][0]++ + if f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(topic mismatch) with seed %d.", seed) + } + f.Topics[index][0]-- + + // key mismatch + f.SymKeyHash[0]++ + if f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(key mismatch) with seed %d.", seed) + } + f.SymKeyHash[0]-- + + // Src absent: match + f.Src = nil + if !f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(src absent) with seed %d.", seed) + } + + // key hash mismatch + h := f.SymKeyHash + f.SymKeyHash = common.Hash{} + if f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(key hash mismatch) with seed %d.", seed) + } + f.SymKeyHash = h + if !f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(key hash match) with seed %d.", seed) + } + + // encryption method mismatch + f.KeySym = nil + f.KeyAsym, err = crypto.GenerateKey() + if err != nil { + t.Fatalf("failed GenerateKey with seed %d: %s.", seed, err) + } + if f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(encryption method mismatch) with seed %d.", seed) + } +} + +func TestMatchMessageAsym(t *testing.T) { + InitSingleTest() + + f, err := generateFilter(t, false) + if err != nil { + t.Fatalf("failed generateFilter with seed %d: %s.", seed, err) + } + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + + const index = 1 + params.Topic = BytesToTopic(f.Topics[index]) + params.Dst = &f.KeyAsym.PublicKey + keySymOrig := params.KeySym + params.KeySym = nil + + sentMessage, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := sentMessage.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + msg := env.Open(f) + if msg == nil { + t.Fatalf("failed to open with seed %d.", seed) + } + + // Src: match + *f.Src.X = *params.Src.PublicKey.X + *f.Src.Y = *params.Src.PublicKey.Y + if !f.MatchMessage(msg) { + t.Fatalf("failed MatchMessage(src match) with seed %d.", seed) + } + + // insufficient PoW: mismatch + f.PoW = msg.PoW + 1.0 + if f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(insufficient PoW) with seed %d.", seed) + } + + // sufficient PoW: match + f.PoW = msg.PoW / 2 + if !f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(sufficient PoW) with seed %d.", seed) + } + + // topic mismatch + f.Topics[index][0]++ + if f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(topic mismatch) with seed %d.", seed) + } + f.Topics[index][0]-- + + // key mismatch + prev := *f.KeyAsym.PublicKey.X + zero := *big.NewInt(0) + *f.KeyAsym.PublicKey.X = zero + if f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(key mismatch) with seed %d.", seed) + } + *f.KeyAsym.PublicKey.X = prev + + // Src absent: match + f.Src = nil + if !f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(src absent) with seed %d.", seed) + } + + // encryption method mismatch + f.KeySym = keySymOrig + f.KeyAsym = nil + if f.MatchMessage(msg) { + t.Fatalf("failed MatchEnvelope(encryption method mismatch) with seed %d.", seed) + } +} + +func cloneFilter(orig *Filter) *Filter { + var clone Filter + clone.Messages = make(map[common.Hash]*ReceivedMessage) + clone.Src = orig.Src + clone.KeyAsym = orig.KeyAsym + clone.KeySym = orig.KeySym + clone.Topics = orig.Topics + clone.PoW = orig.PoW + clone.AllowP2P = orig.AllowP2P + clone.SymKeyHash = orig.SymKeyHash + return &clone +} + +func generateCompatibeEnvelope(t *testing.T, f *Filter) *Envelope { + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + return nil + } + + params.KeySym = f.KeySym + params.Topic = BytesToTopic(f.Topics[2]) + sentMessage, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := sentMessage.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + return nil + } + return env +} + +func TestWatchers(t *testing.T) { + InitSingleTest() + + const NumFilters = 16 + const NumMessages = 256 + var i int + var j uint32 + var e *Envelope + var x, firstID string + var err error + + w := New(&Config{}) + filters := NewFilters(w) + tst := generateTestCases(t, NumFilters) + for i = 0; i < NumFilters; i++ { + tst[i].f.Src = nil + x, err = filters.Install(tst[i].f) + if err != nil { + t.Fatalf("failed to install filter with seed %d: %s.", seed, err) + } + tst[i].id = x + if len(firstID) == 0 { + firstID = x + } + } + + lastID := x + + var envelopes [NumMessages]*Envelope + for i = 0; i < NumMessages; i++ { + j = mrand.Uint32() % NumFilters + e = generateCompatibeEnvelope(t, tst[j].f) + envelopes[i] = e + tst[j].msgCnt++ + } + + for i = 0; i < NumMessages; i++ { + filters.NotifyWatchers(envelopes[i], false) + } + + var total int + var mail []*ReceivedMessage + var count [NumFilters]int + + for i = 0; i < NumFilters; i++ { + mail = tst[i].f.Retrieve() + count[i] = len(mail) + total += len(mail) + } + + if total != NumMessages { + t.Fatalf("failed with seed %d: total = %d, want: %d.", seed, total, NumMessages) + } + + for i = 0; i < NumFilters; i++ { + mail = tst[i].f.Retrieve() + if len(mail) != 0 { + t.Fatalf("failed with seed %d: i = %d.", seed, i) + } + + if tst[i].msgCnt != count[i] { + t.Fatalf("failed with seed %d: count[%d]: get %d, want %d.", seed, i, tst[i].msgCnt, count[i]) + } + } + + // another round with a cloned filter + + clone := cloneFilter(tst[0].f) + filters.Uninstall(lastID) + total = 0 + last := NumFilters - 1 + tst[last].f = clone + filters.Install(clone) + for i = 0; i < NumFilters; i++ { + tst[i].msgCnt = 0 + count[i] = 0 + } + + // make sure that the first watcher receives at least one message + e = generateCompatibeEnvelope(t, tst[0].f) + envelopes[0] = e + tst[0].msgCnt++ + for i = 1; i < NumMessages; i++ { + j = mrand.Uint32() % NumFilters + e = generateCompatibeEnvelope(t, tst[j].f) + envelopes[i] = e + tst[j].msgCnt++ + } + + for i = 0; i < NumMessages; i++ { + filters.NotifyWatchers(envelopes[i], false) + } + + for i = 0; i < NumFilters; i++ { + mail = tst[i].f.Retrieve() + count[i] = len(mail) + total += len(mail) + } + + combined := tst[0].msgCnt + tst[last].msgCnt + if total != NumMessages+count[0] { + t.Fatalf("failed with seed %d: total = %d, count[0] = %d.", seed, total, count[0]) + } + + if combined != count[0] { + t.Fatalf("failed with seed %d: combined = %d, count[0] = %d.", seed, combined, count[0]) + } + + if combined != count[last] { + t.Fatalf("failed with seed %d: combined = %d, count[last] = %d.", seed, combined, count[last]) + } + + for i = 1; i < NumFilters-1; i++ { + mail = tst[i].f.Retrieve() + if len(mail) != 0 { + t.Fatalf("failed with seed %d: i = %d.", seed, i) + } + + if tst[i].msgCnt != count[i] { + t.Fatalf("failed with seed %d: i = %d, get %d, want %d.", seed, i, tst[i].msgCnt, count[i]) + } + } + + // test AcceptP2P + + total = 0 + filters.NotifyWatchers(envelopes[0], true) + + for i = 0; i < NumFilters; i++ { + mail = tst[i].f.Retrieve() + total += len(mail) + } + + if total != 0 { + t.Fatalf("failed with seed %d: total: got %d, want 0.", seed, total) + } + + f := filters.Get(firstID) + if f == nil { + t.Fatalf("failed to get the filter with seed %d.", seed) + } + f.AllowP2P = true + total = 0 + filters.NotifyWatchers(envelopes[0], true) + + for i = 0; i < NumFilters; i++ { + mail = tst[i].f.Retrieve() + total += len(mail) + } + + if total != 1 { + t.Fatalf("failed with seed %d: total: got %d, want 1.", seed, total) + } +} + +func TestVariableTopics(t *testing.T) { + InitSingleTest() + + var match bool + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := msg.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + + f, err := generateFilter(t, true) + if err != nil { + t.Fatalf("failed generateFilter with seed %d: %s.", seed, err) + } + + for i := 0; i < 4; i++ { + arr := make([]byte, i+1, 4) + copy(arr, env.Topic[:i+1]) + + f.Topics[4] = arr + match = f.MatchEnvelope(env) + if !match { + t.Fatalf("failed MatchEnvelope symmetric with seed %d, step %d.", seed, i) + } + + f.Topics[4][i]++ + match = f.MatchEnvelope(env) + if match { + t.Fatalf("MatchEnvelope symmetric with seed %d, step %d: false positive.", seed, i) + } + } +} diff --git a/whisper/whisperv6/gen_criteria_json.go b/whisper/whisperv6/gen_criteria_json.go new file mode 100644 index 000000000..52a4d3cb6 --- /dev/null +++ b/whisper/whisperv6/gen_criteria_json.go @@ -0,0 +1,64 @@ +// Code generated by github.com/fjl/gencodec. DO NOT EDIT. + +package whisperv6 + +import ( + "encoding/json" + + "github.com/ethereum/go-ethereum/common/hexutil" +) + +var _ = (*criteriaOverride)(nil) + +func (c Criteria) MarshalJSON() ([]byte, error) { + type Criteria struct { + SymKeyID string `json:"symKeyID"` + PrivateKeyID string `json:"privateKeyID"` + Sig hexutil.Bytes `json:"sig"` + MinPow float64 `json:"minPow"` + Topics []TopicType `json:"topics"` + AllowP2P bool `json:"allowP2P"` + } + var enc Criteria + enc.SymKeyID = c.SymKeyID + enc.PrivateKeyID = c.PrivateKeyID + enc.Sig = c.Sig + enc.MinPow = c.MinPow + enc.Topics = c.Topics + enc.AllowP2P = c.AllowP2P + return json.Marshal(&enc) +} + +func (c *Criteria) UnmarshalJSON(input []byte) error { + type Criteria struct { + SymKeyID *string `json:"symKeyID"` + PrivateKeyID *string `json:"privateKeyID"` + Sig hexutil.Bytes `json:"sig"` + MinPow *float64 `json:"minPow"` + Topics []TopicType `json:"topics"` + AllowP2P *bool `json:"allowP2P"` + } + var dec Criteria + if err := json.Unmarshal(input, &dec); err != nil { + return err + } + if dec.SymKeyID != nil { + c.SymKeyID = *dec.SymKeyID + } + if dec.PrivateKeyID != nil { + c.PrivateKeyID = *dec.PrivateKeyID + } + if dec.Sig != nil { + c.Sig = dec.Sig + } + if dec.MinPow != nil { + c.MinPow = *dec.MinPow + } + if dec.Topics != nil { + c.Topics = dec.Topics + } + if dec.AllowP2P != nil { + c.AllowP2P = *dec.AllowP2P + } + return nil +} diff --git a/whisper/whisperv6/gen_message_json.go b/whisper/whisperv6/gen_message_json.go new file mode 100644 index 000000000..27b46752b --- /dev/null +++ b/whisper/whisperv6/gen_message_json.go @@ -0,0 +1,82 @@ +// Code generated by github.com/fjl/gencodec. DO NOT EDIT. + +package whisperv6 + +import ( + "encoding/json" + + "github.com/ethereum/go-ethereum/common/hexutil" +) + +var _ = (*messageOverride)(nil) + +func (m Message) MarshalJSON() ([]byte, error) { + type Message struct { + Sig hexutil.Bytes `json:"sig,omitempty"` + TTL uint32 `json:"ttl"` + Timestamp uint32 `json:"timestamp"` + Topic TopicType `json:"topic"` + Payload hexutil.Bytes `json:"payload"` + Padding hexutil.Bytes `json:"padding"` + PoW float64 `json:"pow"` + Hash hexutil.Bytes `json:"hash"` + Dst hexutil.Bytes `json:"recipientPublicKey,omitempty"` + } + var enc Message + enc.Sig = m.Sig + enc.TTL = m.TTL + enc.Timestamp = m.Timestamp + enc.Topic = m.Topic + enc.Payload = m.Payload + enc.Padding = m.Padding + enc.PoW = m.PoW + enc.Hash = m.Hash + enc.Dst = m.Dst + return json.Marshal(&enc) +} + +func (m *Message) UnmarshalJSON(input []byte) error { + type Message struct { + Sig hexutil.Bytes `json:"sig,omitempty"` + TTL *uint32 `json:"ttl"` + Timestamp *uint32 `json:"timestamp"` + Topic *TopicType `json:"topic"` + Payload hexutil.Bytes `json:"payload"` + Padding hexutil.Bytes `json:"padding"` + PoW *float64 `json:"pow"` + Hash hexutil.Bytes `json:"hash"` + Dst hexutil.Bytes `json:"recipientPublicKey,omitempty"` + } + var dec Message + if err := json.Unmarshal(input, &dec); err != nil { + return err + } + if dec.Sig != nil { + m.Sig = dec.Sig + } + if dec.TTL != nil { + m.TTL = *dec.TTL + } + if dec.Timestamp != nil { + m.Timestamp = *dec.Timestamp + } + if dec.Topic != nil { + m.Topic = *dec.Topic + } + if dec.Payload != nil { + m.Payload = dec.Payload + } + if dec.Padding != nil { + m.Padding = dec.Padding + } + if dec.PoW != nil { + m.PoW = *dec.PoW + } + if dec.Hash != nil { + m.Hash = dec.Hash + } + if dec.Dst != nil { + m.Dst = dec.Dst + } + return nil +} diff --git a/whisper/whisperv6/gen_newmessage_json.go b/whisper/whisperv6/gen_newmessage_json.go new file mode 100644 index 000000000..d16011a57 --- /dev/null +++ b/whisper/whisperv6/gen_newmessage_json.go @@ -0,0 +1,88 @@ +// Code generated by github.com/fjl/gencodec. DO NOT EDIT. + +package whisperv6 + +import ( + "encoding/json" + + "github.com/ethereum/go-ethereum/common/hexutil" +) + +var _ = (*newMessageOverride)(nil) + +func (n NewMessage) MarshalJSON() ([]byte, error) { + type NewMessage struct { + SymKeyID string `json:"symKeyID"` + PublicKey hexutil.Bytes `json:"pubKey"` + Sig string `json:"sig"` + TTL uint32 `json:"ttl"` + Topic TopicType `json:"topic"` + Payload hexutil.Bytes `json:"payload"` + Padding hexutil.Bytes `json:"padding"` + PowTime uint32 `json:"powTime"` + PowTarget float64 `json:"powTarget"` + TargetPeer string `json:"targetPeer"` + } + var enc NewMessage + enc.SymKeyID = n.SymKeyID + enc.PublicKey = n.PublicKey + enc.Sig = n.Sig + enc.TTL = n.TTL + enc.Topic = n.Topic + enc.Payload = n.Payload + enc.Padding = n.Padding + enc.PowTime = n.PowTime + enc.PowTarget = n.PowTarget + enc.TargetPeer = n.TargetPeer + return json.Marshal(&enc) +} + +func (n *NewMessage) UnmarshalJSON(input []byte) error { + type NewMessage struct { + SymKeyID *string `json:"symKeyID"` + PublicKey hexutil.Bytes `json:"pubKey"` + Sig *string `json:"sig"` + TTL *uint32 `json:"ttl"` + Topic *TopicType `json:"topic"` + Payload hexutil.Bytes `json:"payload"` + Padding hexutil.Bytes `json:"padding"` + PowTime *uint32 `json:"powTime"` + PowTarget *float64 `json:"powTarget"` + TargetPeer *string `json:"targetPeer"` + } + var dec NewMessage + if err := json.Unmarshal(input, &dec); err != nil { + return err + } + if dec.SymKeyID != nil { + n.SymKeyID = *dec.SymKeyID + } + if dec.PublicKey != nil { + n.PublicKey = dec.PublicKey + } + if dec.Sig != nil { + n.Sig = *dec.Sig + } + if dec.TTL != nil { + n.TTL = *dec.TTL + } + if dec.Topic != nil { + n.Topic = *dec.Topic + } + if dec.Payload != nil { + n.Payload = dec.Payload + } + if dec.Padding != nil { + n.Padding = dec.Padding + } + if dec.PowTime != nil { + n.PowTime = *dec.PowTime + } + if dec.PowTarget != nil { + n.PowTarget = *dec.PowTarget + } + if dec.TargetPeer != nil { + n.TargetPeer = *dec.TargetPeer + } + return nil +} diff --git a/whisper/whisperv6/message.go b/whisper/whisperv6/message.go new file mode 100644 index 000000000..0815f07a2 --- /dev/null +++ b/whisper/whisperv6/message.go @@ -0,0 +1,352 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Contains the Whisper protocol Message element. + +package whisperv6 + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/ecdsa" + crand "crypto/rand" + "encoding/binary" + "errors" + "strconv" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/ecies" + "github.com/ethereum/go-ethereum/log" +) + +// Options specifies the exact way a message should be wrapped into an Envelope. +type MessageParams struct { + TTL uint32 + Src *ecdsa.PrivateKey + Dst *ecdsa.PublicKey + KeySym []byte + Topic TopicType + WorkTime uint32 + PoW float64 + Payload []byte + Padding []byte +} + +// SentMessage represents an end-user data packet to transmit through the +// Whisper protocol. These are wrapped into Envelopes that need not be +// understood by intermediate nodes, just forwarded. +type sentMessage struct { + Raw []byte +} + +// ReceivedMessage represents a data packet to be received through the +// Whisper protocol. +type ReceivedMessage struct { + Raw []byte + + Payload []byte + Padding []byte + Signature []byte + + PoW float64 // Proof of work as described in the Whisper spec + Sent uint32 // Time when the message was posted into the network + TTL uint32 // Maximum time to live allowed for the message + Src *ecdsa.PublicKey // Message recipient (identity used to decode the message) + Dst *ecdsa.PublicKey // Message recipient (identity used to decode the message) + Topic TopicType + + SymKeyHash common.Hash // The Keccak256Hash of the key, associated with the Topic + EnvelopeHash common.Hash // Message envelope hash to act as a unique id + EnvelopeVersion uint64 +} + +func isMessageSigned(flags byte) bool { + return (flags & signatureFlag) != 0 +} + +func (msg *ReceivedMessage) isSymmetricEncryption() bool { + return msg.SymKeyHash != common.Hash{} +} + +func (msg *ReceivedMessage) isAsymmetricEncryption() bool { + return msg.Dst != nil +} + +// NewMessage creates and initializes a non-signed, non-encrypted Whisper message. +func NewSentMessage(params *MessageParams) (*sentMessage, error) { + msg := sentMessage{} + msg.Raw = make([]byte, 1, len(params.Payload)+len(params.Padding)+signatureLength+padSizeLimit) + msg.Raw[0] = 0 // set all the flags to zero + err := msg.appendPadding(params) + if err != nil { + return nil, err + } + msg.Raw = append(msg.Raw, params.Payload...) + return &msg, nil +} + +// getSizeOfLength returns the number of bytes necessary to encode the entire size padding (including these bytes) +func getSizeOfLength(b []byte) (sz int, err error) { + sz = intSize(len(b)) // first iteration + sz = intSize(len(b) + sz) // second iteration + if sz > 3 { + err = errors.New("oversized padding parameter") + } + return sz, err +} + +// sizeOfIntSize returns minimal number of bytes necessary to encode an integer value +func intSize(i int) (s int) { + for s = 1; i >= 256; s++ { + i /= 256 + } + return s +} + +// appendPadding appends the pseudorandom padding bytes and sets the padding flag. +// The last byte contains the size of padding (thus, its size must not exceed 256). +func (msg *sentMessage) appendPadding(params *MessageParams) error { + rawSize := len(params.Payload) + 1 + if params.Src != nil { + rawSize += signatureLength + } + odd := rawSize % padSizeLimit + + if len(params.Padding) != 0 { + padSize := len(params.Padding) + padLengthSize, err := getSizeOfLength(params.Padding) + if err != nil { + return err + } + totalPadSize := padSize + padLengthSize + buf := make([]byte, 8) + binary.LittleEndian.PutUint32(buf, uint32(totalPadSize)) + buf = buf[:padLengthSize] + msg.Raw = append(msg.Raw, buf...) + msg.Raw = append(msg.Raw, params.Padding...) + msg.Raw[0] |= byte(padLengthSize) // number of bytes indicating the padding size + } else if odd != 0 { + totalPadSize := padSizeLimit - odd + if totalPadSize > 255 { + // this algorithm is only valid if padSizeLimit < 256. + // if padSizeLimit will ever change, please fix the algorithm + // (please see also ReceivedMessage.extractPadding() function). + panic("please fix the padding algorithm before releasing new version") + } + buf := make([]byte, totalPadSize) + _, err := crand.Read(buf[1:]) + if err != nil { + return err + } + if totalPadSize > 6 && !validateSymmetricKey(buf) { + return errors.New("failed to generate random padding of size " + strconv.Itoa(totalPadSize)) + } + buf[0] = byte(totalPadSize) + msg.Raw = append(msg.Raw, buf...) + msg.Raw[0] |= byte(0x1) // number of bytes indicating the padding size + } + return nil +} + +// sign calculates and sets the cryptographic signature for the message, +// also setting the sign flag. +func (msg *sentMessage) sign(key *ecdsa.PrivateKey) error { + if isMessageSigned(msg.Raw[0]) { + // this should not happen, but no reason to panic + log.Error("failed to sign the message: already signed") + return nil + } + + msg.Raw[0] |= signatureFlag + hash := crypto.Keccak256(msg.Raw) + signature, err := crypto.Sign(hash, key) + if err != nil { + msg.Raw[0] &= ^signatureFlag // clear the flag + return err + } + msg.Raw = append(msg.Raw, signature...) + return nil +} + +// encryptAsymmetric encrypts a message with a public key. +func (msg *sentMessage) encryptAsymmetric(key *ecdsa.PublicKey) error { + if !ValidatePublicKey(key) { + return errors.New("invalid public key provided for asymmetric encryption") + } + encrypted, err := ecies.Encrypt(crand.Reader, ecies.ImportECDSAPublic(key), msg.Raw, nil, nil) + if err == nil { + msg.Raw = encrypted + } + return err +} + +// encryptSymmetric encrypts a message with a topic key, using AES-GCM-256. +// nonce size should be 12 bytes (see cipher.gcmStandardNonceSize). +func (msg *sentMessage) encryptSymmetric(key []byte) (nonce []byte, err error) { + if !validateSymmetricKey(key) { + return nil, errors.New("invalid key provided for symmetric encryption") + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + // never use more than 2^32 random nonces with a given key + nonce = make([]byte, aesgcm.NonceSize()) + _, err = crand.Read(nonce) + if err != nil { + return nil, err + } else if !validateSymmetricKey(nonce) { + return nil, errors.New("crypto/rand failed to generate nonce") + } + + msg.Raw = aesgcm.Seal(nil, nonce, msg.Raw, nil) + return nonce, nil +} + +// Wrap bundles the message into an Envelope to transmit over the network. +func (msg *sentMessage) Wrap(options *MessageParams) (envelope *Envelope, err error) { + if options.TTL == 0 { + options.TTL = DefaultTTL + } + if options.Src != nil { + if err = msg.sign(options.Src); err != nil { + return nil, err + } + } + var nonce []byte + if options.Dst != nil { + err = msg.encryptAsymmetric(options.Dst) + } else if options.KeySym != nil { + nonce, err = msg.encryptSymmetric(options.KeySym) + } else { + err = errors.New("unable to encrypt the message: neither symmetric nor assymmetric key provided") + } + if err != nil { + return nil, err + } + + envelope = NewEnvelope(options.TTL, options.Topic, nonce, msg) + if err = envelope.Seal(options); err != nil { + return nil, err + } + return envelope, nil +} + +// decryptSymmetric decrypts a message with a topic key, using AES-GCM-256. +// nonce size should be 12 bytes (see cipher.gcmStandardNonceSize). +func (msg *ReceivedMessage) decryptSymmetric(key []byte, nonce []byte) error { + block, err := aes.NewCipher(key) + if err != nil { + return err + } + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return err + } + if len(nonce) != aesgcm.NonceSize() { + log.Error("decrypting the message", "AES nonce size", len(nonce)) + return errors.New("wrong AES nonce size") + } + decrypted, err := aesgcm.Open(nil, nonce, msg.Raw, nil) + if err != nil { + return err + } + msg.Raw = decrypted + return nil +} + +// decryptAsymmetric decrypts an encrypted payload with a private key. +func (msg *ReceivedMessage) decryptAsymmetric(key *ecdsa.PrivateKey) error { + decrypted, err := ecies.ImportECDSA(key).Decrypt(crand.Reader, msg.Raw, nil, nil) + if err == nil { + msg.Raw = decrypted + } + return err +} + +// Validate checks the validity and extracts the fields in case of success +func (msg *ReceivedMessage) Validate() bool { + end := len(msg.Raw) + if end < 1 { + return false + } + + if isMessageSigned(msg.Raw[0]) { + end -= signatureLength + if end <= 1 { + return false + } + msg.Signature = msg.Raw[end:] + msg.Src = msg.SigToPubKey() + if msg.Src == nil { + return false + } + } + + padSize, ok := msg.extractPadding(end) + if !ok { + return false + } + + msg.Payload = msg.Raw[1+padSize : end] + return true +} + +// extractPadding extracts the padding from raw message. +// although we don't support sending messages with padding size +// exceeding 255 bytes, such messages are perfectly valid, and +// can be successfully decrypted. +func (msg *ReceivedMessage) extractPadding(end int) (int, bool) { + paddingSize := 0 + sz := int(msg.Raw[0] & paddingMask) // number of bytes indicating the entire size of padding (including these bytes) + // could be zero -- it means no padding + if sz != 0 { + paddingSize = int(bytesToUintLittleEndian(msg.Raw[1 : 1+sz])) + if paddingSize < sz || paddingSize+1 > end { + return 0, false + } + msg.Padding = msg.Raw[1+sz : 1+paddingSize] + } + return paddingSize, true +} + +// Recover retrieves the public key of the message signer. +func (msg *ReceivedMessage) SigToPubKey() *ecdsa.PublicKey { + defer func() { recover() }() // in case of invalid signature + + pub, err := crypto.SigToPub(msg.hash(), msg.Signature) + if err != nil { + log.Error("failed to recover public key from signature", "err", err) + return nil + } + return pub +} + +// hash calculates the SHA3 checksum of the message flags, payload and padding. +func (msg *ReceivedMessage) hash() []byte { + if isMessageSigned(msg.Raw[0]) { + sz := len(msg.Raw) - signatureLength + return crypto.Keccak256(msg.Raw[:sz]) + } + return crypto.Keccak256(msg.Raw) +} diff --git a/whisper/whisperv6/message_test.go b/whisper/whisperv6/message_test.go new file mode 100644 index 000000000..912b90f14 --- /dev/null +++ b/whisper/whisperv6/message_test.go @@ -0,0 +1,415 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package whisperv6 + +import ( + "bytes" + mrand "math/rand" + "testing" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/rlp" +) + +func generateMessageParams() (*MessageParams, error) { + // set all the parameters except p.Dst and p.Padding + + buf := make([]byte, 4) + mrand.Read(buf) + sz := mrand.Intn(400) + + var p MessageParams + p.PoW = 0.01 + p.WorkTime = 1 + p.TTL = uint32(mrand.Intn(1024)) + p.Payload = make([]byte, sz) + p.KeySym = make([]byte, aesKeyLength) + mrand.Read(p.Payload) + mrand.Read(p.KeySym) + p.Topic = BytesToTopic(buf) + + var err error + p.Src, err = crypto.GenerateKey() + if err != nil { + return nil, err + } + + return &p, nil +} + +func singleMessageTest(t *testing.T, symmetric bool) { + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + + key, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("failed GenerateKey with seed %d: %s.", seed, err) + } + + if !symmetric { + params.KeySym = nil + params.Dst = &key.PublicKey + } + + text := make([]byte, 0, 512) + text = append(text, params.Payload...) + + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := msg.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + + var decrypted *ReceivedMessage + if symmetric { + decrypted, err = env.OpenSymmetric(params.KeySym) + } else { + decrypted, err = env.OpenAsymmetric(key) + } + + if err != nil { + t.Fatalf("failed to encrypt with seed %d: %s.", seed, err) + } + + if !decrypted.Validate() { + t.Fatalf("failed to validate with seed %d.", seed) + } + + if !bytes.Equal(text, decrypted.Payload) { + t.Fatalf("failed with seed %d: compare payload.", seed) + } + if !isMessageSigned(decrypted.Raw[0]) { + t.Fatalf("failed with seed %d: unsigned.", seed) + } + if len(decrypted.Signature) != signatureLength { + t.Fatalf("failed with seed %d: signature len %d.", seed, len(decrypted.Signature)) + } + if !IsPubKeyEqual(decrypted.Src, ¶ms.Src.PublicKey) { + t.Fatalf("failed with seed %d: signature mismatch.", seed) + } +} + +func TestMessageEncryption(t *testing.T) { + InitSingleTest() + + var symmetric bool + for i := 0; i < 256; i++ { + singleMessageTest(t, symmetric) + symmetric = !symmetric + } +} + +func TestMessageWrap(t *testing.T) { + seed = int64(1777444222) + mrand.Seed(seed) + target := 128.0 + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + params.TTL = 1 + params.WorkTime = 12 + params.PoW = target + env, err := msg.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + + pow := env.PoW() + if pow < target { + t.Fatalf("failed Wrap with seed %d: pow < target (%f vs. %f).", seed, pow, target) + } + + // set PoW target too high, expect error + msg2, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + params.TTL = 1000000 + params.WorkTime = 1 + params.PoW = 10000000.0 + _, err = msg2.Wrap(params) + if err == nil { + t.Fatalf("unexpectedly reached the PoW target with seed %d.", seed) + } +} + +func TestMessageSeal(t *testing.T) { + // this test depends on deterministic choice of seed (1976726903) + seed = int64(1976726903) + mrand.Seed(seed) + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + params.TTL = 1 + aesnonce := make([]byte, 12) + mrand.Read(aesnonce) + + env := NewEnvelope(params.TTL, params.Topic, aesnonce, msg) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + + env.Expiry = uint32(seed) // make it deterministic + target := 32.0 + params.WorkTime = 4 + params.PoW = target + env.Seal(params) + + env.calculatePoW(0) + pow := env.PoW() + if pow < target { + t.Fatalf("failed Wrap with seed %d: pow < target (%f vs. %f).", seed, pow, target) + } + + params.WorkTime = 1 + params.PoW = 1000000000.0 + env.Seal(params) + env.calculatePoW(0) + pow = env.PoW() + if pow < 2*target { + t.Fatalf("failed Wrap with seed %d: pow too small %f.", seed, pow) + } +} + +func TestEnvelopeOpen(t *testing.T) { + InitSingleTest() + + var symmetric bool + for i := 0; i < 256; i++ { + singleEnvelopeOpenTest(t, symmetric) + symmetric = !symmetric + } +} + +func singleEnvelopeOpenTest(t *testing.T, symmetric bool) { + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + + key, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("failed GenerateKey with seed %d: %s.", seed, err) + } + + if !symmetric { + params.KeySym = nil + params.Dst = &key.PublicKey + } + + text := make([]byte, 0, 512) + text = append(text, params.Payload...) + + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := msg.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + + f := Filter{KeyAsym: key, KeySym: params.KeySym} + decrypted := env.Open(&f) + if decrypted == nil { + t.Fatalf("failed to open with seed %d.", seed) + } + + if !bytes.Equal(text, decrypted.Payload) { + t.Fatalf("failed with seed %d: compare payload.", seed) + } + if !isMessageSigned(decrypted.Raw[0]) { + t.Fatalf("failed with seed %d: unsigned.", seed) + } + if len(decrypted.Signature) != signatureLength { + t.Fatalf("failed with seed %d: signature len %d.", seed, len(decrypted.Signature)) + } + if !IsPubKeyEqual(decrypted.Src, ¶ms.Src.PublicKey) { + t.Fatalf("failed with seed %d: signature mismatch.", seed) + } + if decrypted.isAsymmetricEncryption() == symmetric { + t.Fatalf("failed with seed %d: asymmetric %v vs. %v.", seed, decrypted.isAsymmetricEncryption(), symmetric) + } + if decrypted.isSymmetricEncryption() != symmetric { + t.Fatalf("failed with seed %d: symmetric %v vs. %v.", seed, decrypted.isSymmetricEncryption(), symmetric) + } + if !symmetric { + if decrypted.Dst == nil { + t.Fatalf("failed with seed %d: dst is nil.", seed) + } + if !IsPubKeyEqual(decrypted.Dst, &key.PublicKey) { + t.Fatalf("failed with seed %d: Dst.", seed) + } + } +} + +func TestEncryptWithZeroKey(t *testing.T) { + InitSingleTest() + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + params.KeySym = make([]byte, aesKeyLength) + _, err = msg.Wrap(params) + if err == nil { + t.Fatalf("wrapped with zero key, seed: %d.", seed) + } + + params, err = generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + msg, err = NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + params.KeySym = make([]byte, 0) + _, err = msg.Wrap(params) + if err == nil { + t.Fatalf("wrapped with empty key, seed: %d.", seed) + } + + params, err = generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + msg, err = NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + params.KeySym = nil + _, err = msg.Wrap(params) + if err == nil { + t.Fatalf("wrapped with nil key, seed: %d.", seed) + } +} + +func TestRlpEncode(t *testing.T) { + InitSingleTest() + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := msg.Wrap(params) + if err != nil { + t.Fatalf("wrapped with zero key, seed: %d.", seed) + } + + raw, err := rlp.EncodeToBytes(env) + if err != nil { + t.Fatalf("RLP encode failed: %s.", err) + } + + var decoded Envelope + rlp.DecodeBytes(raw, &decoded) + if err != nil { + t.Fatalf("RLP decode failed: %s.", err) + } + + he := env.Hash() + hd := decoded.Hash() + + if he != hd { + t.Fatalf("Hashes are not equal: %x vs. %x", he, hd) + } +} + +func singlePaddingTest(t *testing.T, padSize int) { + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d and sz=%d: %s.", seed, padSize, err) + } + params.Padding = make([]byte, padSize) + params.PoW = 0.0000000001 + pad := make([]byte, padSize) + _, err = mrand.Read(pad) + if err != nil { + t.Fatalf("padding is not generated (seed %d): %s", seed, err) + } + n := copy(params.Padding, pad) + if n != padSize { + t.Fatalf("padding is not copied (seed %d): %s", seed, err) + } + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := msg.Wrap(params) + if err != nil { + t.Fatalf("failed to wrap, seed: %d and sz=%d.", seed, padSize) + } + f := Filter{KeySym: params.KeySym} + decrypted := env.Open(&f) + if decrypted == nil { + t.Fatalf("failed to open, seed and sz=%d: %d.", seed, padSize) + } + if !bytes.Equal(pad, decrypted.Padding) { + t.Fatalf("padding is not retireved as expected with seed %d and sz=%d:\n[%x]\n[%x].", seed, padSize, pad, decrypted.Padding) + } +} + +func TestPadding(t *testing.T) { + InitSingleTest() + + for i := 1; i < 260; i++ { + singlePaddingTest(t, i) + } + + lim := 256 * 256 + for i := lim - 5; i < lim+2; i++ { + singlePaddingTest(t, i) + } + + for i := 0; i < 256; i++ { + n := mrand.Intn(256*254) + 256 + singlePaddingTest(t, n) + } + + for i := 0; i < 256; i++ { + n := mrand.Intn(256*1024) + 256*256 + singlePaddingTest(t, n) + } +} diff --git a/whisper/whisperv6/peer.go b/whisper/whisperv6/peer.go new file mode 100644 index 000000000..ac7b3b12b --- /dev/null +++ b/whisper/whisperv6/peer.go @@ -0,0 +1,174 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package whisperv6 + +import ( + "fmt" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/rlp" + set "gopkg.in/fatih/set.v0" +) + +// peer represents a whisper protocol peer connection. +type Peer struct { + host *Whisper + peer *p2p.Peer + ws p2p.MsgReadWriter + trusted bool + + known *set.Set // Messages already known by the peer to avoid wasting bandwidth + + quit chan struct{} +} + +// newPeer creates a new whisper peer object, but does not run the handshake itself. +func newPeer(host *Whisper, remote *p2p.Peer, rw p2p.MsgReadWriter) *Peer { + return &Peer{ + host: host, + peer: remote, + ws: rw, + trusted: false, + known: set.New(), + quit: make(chan struct{}), + } +} + +// start initiates the peer updater, periodically broadcasting the whisper packets +// into the network. +func (p *Peer) start() { + go p.update() + log.Trace("start", "peer", p.ID()) +} + +// stop terminates the peer updater, stopping message forwarding to it. +func (p *Peer) stop() { + close(p.quit) + log.Trace("stop", "peer", p.ID()) +} + +// handshake sends the protocol initiation status message to the remote peer and +// verifies the remote status too. +func (p *Peer) handshake() error { + // Send the handshake status message asynchronously + errc := make(chan error, 1) + go func() { + errc <- p2p.Send(p.ws, statusCode, ProtocolVersion) + }() + // Fetch the remote status packet and verify protocol match + packet, err := p.ws.ReadMsg() + if err != nil { + return err + } + if packet.Code != statusCode { + return fmt.Errorf("peer [%x] sent packet %x before status packet", p.ID(), packet.Code) + } + s := rlp.NewStream(packet.Payload, uint64(packet.Size)) + peerVersion, err := s.Uint() + if err != nil { + return fmt.Errorf("peer [%x] sent bad status message: %v", p.ID(), err) + } + if peerVersion != ProtocolVersion { + return fmt.Errorf("peer [%x]: protocol version mismatch %d != %d", p.ID(), peerVersion, ProtocolVersion) + } + // Wait until out own status is consumed too + if err := <-errc; err != nil { + return fmt.Errorf("peer [%x] failed to send status packet: %v", p.ID(), err) + } + return nil +} + +// update executes periodic operations on the peer, including message transmission +// and expiration. +func (p *Peer) update() { + // Start the tickers for the updates + expire := time.NewTicker(expirationCycle) + transmit := time.NewTicker(transmissionCycle) + + // Loop and transmit until termination is requested + for { + select { + case <-expire.C: + p.expire() + + case <-transmit.C: + if err := p.broadcast(); err != nil { + log.Trace("broadcast failed", "reason", err, "peer", p.ID()) + return + } + + case <-p.quit: + return + } + } +} + +// mark marks an envelope known to the peer so that it won't be sent back. +func (peer *Peer) mark(envelope *Envelope) { + peer.known.Add(envelope.Hash()) +} + +// marked checks if an envelope is already known to the remote peer. +func (peer *Peer) marked(envelope *Envelope) bool { + return peer.known.Has(envelope.Hash()) +} + +// expire iterates over all the known envelopes in the host and removes all +// expired (unknown) ones from the known list. +func (peer *Peer) expire() { + unmark := make(map[common.Hash]struct{}) + peer.known.Each(func(v interface{}) bool { + if !peer.host.isEnvelopeCached(v.(common.Hash)) { + unmark[v.(common.Hash)] = struct{}{} + } + return true + }) + // Dump all known but no longer cached + for hash := range unmark { + peer.known.Remove(hash) + } +} + +// broadcast iterates over the collection of envelopes and transmits yet unknown +// ones over the network. +func (p *Peer) broadcast() error { + var cnt int + envelopes := p.host.Envelopes() + for _, envelope := range envelopes { + if !p.marked(envelope) { + err := p2p.Send(p.ws, messagesCode, envelope) + if err != nil { + return err + } else { + p.mark(envelope) + cnt++ + } + } + } + if cnt > 0 { + log.Trace("broadcast", "num. messages", cnt) + } + return nil +} + +func (p *Peer) ID() []byte { + id := p.peer.ID() + return id[:] +} diff --git a/whisper/whisperv6/peer_test.go b/whisper/whisperv6/peer_test.go new file mode 100644 index 000000000..39a4ab198 --- /dev/null +++ b/whisper/whisperv6/peer_test.go @@ -0,0 +1,306 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package whisperv6 + +import ( + "bytes" + "crypto/ecdsa" + "fmt" + "net" + "sync" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/nat" +) + +var keys []string = []string{ + "d49dcf37238dc8a7aac57dc61b9fee68f0a97f062968978b9fafa7d1033d03a9", + "73fd6143c48e80ed3c56ea159fe7494a0b6b393a392227b422f4c3e8f1b54f98", + "119dd32adb1daa7a4c7bf77f847fb28730785aa92947edf42fdd997b54de40dc", + "deeda8709dea935bb772248a3144dea449ffcc13e8e5a1fd4ef20ce4e9c87837", + "5bd208a079633befa349441bdfdc4d85ba9bd56081525008380a63ac38a407cf", + "1d27fb4912002d58a2a42a50c97edb05c1b3dffc665dbaa42df1fe8d3d95c9b5", + "15def52800c9d6b8ca6f3066b7767a76afc7b611786c1276165fbc61636afb68", + "51be6ab4b2dc89f251ff2ace10f3c1cc65d6855f3e083f91f6ff8efdfd28b48c", + "ef1ef7441bf3c6419b162f05da6037474664f198b58db7315a6f4de52414b4a0", + "09bdf6985aabc696dc1fbeb5381aebd7a6421727343872eb2fadfc6d82486fd9", + "15d811bf2e01f99a224cdc91d0cf76cea08e8c67905c16fee9725c9be71185c4", + "2f83e45cf1baaea779789f755b7da72d8857aeebff19362dd9af31d3c9d14620", + "73f04e34ac6532b19c2aae8f8e52f38df1ac8f5cd10369f92325b9b0494b0590", + "1e2e07b69e5025537fb73770f483dc8d64f84ae3403775ef61cd36e3faf162c1", + "8963d9bbb3911aac6d30388c786756b1c423c4fbbc95d1f96ddbddf39809e43a", + "0422da85abc48249270b45d8de38a4cc3c02032ede1fcf0864a51092d58a2f1f", + "8ae5c15b0e8c7cade201fdc149831aa9b11ff626a7ffd27188886cc108ad0fa8", + "acd8f5a71d4aecfcb9ad00d32aa4bcf2a602939b6a9dd071bab443154184f805", + "a285a922125a7481600782ad69debfbcdb0316c1e97c267aff29ef50001ec045", + "28fd4eee78c6cd4bf78f39f8ab30c32c67c24a6223baa40e6f9c9a0e1de7cef5", + "c5cca0c9e6f043b288c6f1aef448ab59132dab3e453671af5d0752961f013fc7", + "46df99b051838cb6f8d1b73f232af516886bd8c4d0ee07af9a0a033c391380fd", + "c6a06a53cbaadbb432884f36155c8f3244e244881b5ee3e92e974cfa166d793f", + "783b90c75c63dc72e2f8d11b6f1b4de54d63825330ec76ee8db34f06b38ea211", + "9450038f10ca2c097a8013e5121b36b422b95b04892232f930a29292d9935611", + "e215e6246ed1cfdcf7310d4d8cdbe370f0d6a8371e4eb1089e2ae05c0e1bc10f", + "487110939ed9d64ebbc1f300adeab358bc58875faf4ca64990fbd7fe03b78f2b", + "824a70ea76ac81366da1d4f4ac39de851c8ac49dca456bb3f0a186ceefa269a5", + "ba8f34fa40945560d1006a328fe70c42e35cc3d1017e72d26864cd0d1b150f15", + "30a5dfcfd144997f428901ea88a43c8d176b19c79dde54cc58eea001aa3d246c", + "de59f7183aca39aa245ce66a05245fecfc7e2c75884184b52b27734a4a58efa2", + "92629e2ff5f0cb4f5f08fffe0f64492024d36f045b901efb271674b801095c5a", + "7184c1701569e3a4c4d2ddce691edd983b81e42e09196d332e1ae2f1e062cff4", +} + +const NumNodes = 16 // must not exceed the number of keys (32) + +type TestData struct { + counter [NumNodes]int + mutex sync.RWMutex +} + +type TestNode struct { + shh *Whisper + id *ecdsa.PrivateKey + server *p2p.Server + filerId string +} + +var result TestData +var nodes [NumNodes]*TestNode +var sharedKey []byte = []byte("some arbitrary data here") +var sharedTopic TopicType = TopicType{0xF, 0x1, 0x2, 0} +var expectedMessage []byte = []byte("per rectum ad astra") + +// This test does the following: +// 1. creates a chain of whisper nodes, +// 2. installs the filters with shared (predefined) parameters, +// 3. each node sends a number of random (undecryptable) messages, +// 4. first node sends one expected (decryptable) message, +// 5. checks if each node have received and decrypted exactly one message. +func TestSimulation(t *testing.T) { + initialize(t) + + for i := 0; i < NumNodes; i++ { + sendMsg(t, false, i) + } + + sendMsg(t, true, 0) + checkPropagation(t) + stopServers() +} + +func initialize(t *testing.T) { + var err error + ip := net.IPv4(127, 0, 0, 1) + port0 := 30303 + + for i := 0; i < NumNodes; i++ { + var node TestNode + node.shh = New(&DefaultConfig) + node.shh.SetMinimumPoW(0.00000001) + node.shh.Start(nil) + topics := make([]TopicType, 0) + topics = append(topics, sharedTopic) + f := Filter{KeySym: sharedKey} + f.Topics = [][]byte{topics[0][:]} + node.filerId, err = node.shh.Subscribe(&f) + if err != nil { + t.Fatalf("failed to install the filter: %s.", err) + } + node.id, err = crypto.HexToECDSA(keys[i]) + if err != nil { + t.Fatalf("failed convert the key: %s.", keys[i]) + } + port := port0 + i + addr := fmt.Sprintf(":%d", port) // e.g. ":30303" + name := common.MakeName("whisper-go", "2.0") + var peers []*discover.Node + if i > 0 { + peerNodeId := nodes[i-1].id + peerPort := uint16(port - 1) + peerNode := discover.PubkeyID(&peerNodeId.PublicKey) + peer := discover.NewNode(peerNode, ip, peerPort, peerPort) + peers = append(peers, peer) + } + + node.server = &p2p.Server{ + Config: p2p.Config{ + PrivateKey: node.id, + MaxPeers: NumNodes/2 + 1, + Name: name, + Protocols: node.shh.Protocols(), + ListenAddr: addr, + NAT: nat.Any(), + BootstrapNodes: peers, + StaticNodes: peers, + TrustedNodes: peers, + }, + } + + err = node.server.Start() + if err != nil { + t.Fatalf("failed to start server %d.", i) + } + + nodes[i] = &node + } +} + +func stopServers() { + for i := 0; i < NumNodes; i++ { + n := nodes[i] + if n != nil { + n.shh.Unsubscribe(n.filerId) + n.shh.Stop() + n.server.Stop() + } + } +} + +func checkPropagation(t *testing.T) { + if t.Failed() { + return + } + + const cycle = 100 + const iterations = 100 + + for j := 0; j < iterations; j++ { + time.Sleep(cycle * time.Millisecond) + + for i := 0; i < NumNodes; i++ { + f := nodes[i].shh.GetFilter(nodes[i].filerId) + if f == nil { + t.Fatalf("failed to get filterId %s from node %d.", nodes[i].filerId, i) + } + + mail := f.Retrieve() + if !validateMail(t, i, mail) { + return + } + + if isTestComplete() { + return + } + } + } + + t.Fatalf("Test was not complete: timeout %d seconds.", iterations*cycle/1000) +} + +func validateMail(t *testing.T, index int, mail []*ReceivedMessage) bool { + var cnt int + for _, m := range mail { + if bytes.Equal(m.Payload, expectedMessage) { + cnt++ + } + } + + if cnt == 0 { + // no messages received yet: nothing is wrong + return true + } + if cnt > 1 { + t.Fatalf("node %d received %d.", index, cnt) + return false + } + + if cnt > 0 { + result.mutex.Lock() + defer result.mutex.Unlock() + result.counter[index] += cnt + if result.counter[index] > 1 { + t.Fatalf("node %d accumulated %d.", index, result.counter[index]) + } + } + return true +} + +func isTestComplete() bool { + result.mutex.RLock() + defer result.mutex.RUnlock() + + for i := 0; i < NumNodes; i++ { + if result.counter[i] < 1 { + return false + } + } + + for i := 0; i < NumNodes; i++ { + envelopes := nodes[i].shh.Envelopes() + if len(envelopes) < 2 { + return false + } + } + + return true +} + +func sendMsg(t *testing.T, expected bool, id int) { + if t.Failed() { + return + } + + opt := MessageParams{KeySym: sharedKey, Topic: sharedTopic, Payload: expectedMessage, PoW: 0.00000001, WorkTime: 1} + if !expected { + opt.KeySym[0]++ + opt.Topic[0]++ + opt.Payload = opt.Payload[1:] + } + + msg, err := NewSentMessage(&opt) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + envelope, err := msg.Wrap(&opt) + if err != nil { + t.Fatalf("failed to seal message: %s", err) + } + + err = nodes[id].shh.Send(envelope) + if err != nil { + t.Fatalf("failed to send message: %s", err) + } +} + +func TestPeerBasic(t *testing.T) { + InitSingleTest() + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d.", seed) + } + + params.PoW = 0.001 + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := msg.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d.", seed) + } + + p := newPeer(nil, nil, nil) + p.mark(env) + if !p.marked(env) { + t.Fatalf("failed mark with seed %d.", seed) + } +} diff --git a/whisper/whisperv6/topic.go b/whisper/whisperv6/topic.go new file mode 100644 index 000000000..5ef7f6939 --- /dev/null +++ b/whisper/whisperv6/topic.go @@ -0,0 +1,55 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Contains the Whisper protocol Topic element. + +package whisperv6 + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" +) + +// Topic represents a cryptographically secure, probabilistic partial +// classifications of a message, determined as the first (left) 4 bytes of the +// SHA3 hash of some arbitrary data given by the original author of the message. +type TopicType [TopicLength]byte + +func BytesToTopic(b []byte) (t TopicType) { + sz := TopicLength + if x := len(b); x < TopicLength { + sz = x + } + for i := 0; i < sz; i++ { + t[i] = b[i] + } + return t +} + +// String converts a topic byte array to a string representation. +func (topic *TopicType) String() string { + return string(common.ToHex(topic[:])) +} + +// MarshalText returns the hex representation of t. +func (t TopicType) MarshalText() ([]byte, error) { + return hexutil.Bytes(t[:]).MarshalText() +} + +// UnmarshalText parses a hex representation to a topic. +func (t *TopicType) UnmarshalText(input []byte) error { + return hexutil.UnmarshalFixedText("Topic", input, t[:]) +} diff --git a/whisper/whisperv6/topic_test.go b/whisper/whisperv6/topic_test.go new file mode 100644 index 000000000..454afe0de --- /dev/null +++ b/whisper/whisperv6/topic_test.go @@ -0,0 +1,134 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package whisperv6 + +import ( + "encoding/json" + "testing" +) + +var topicStringTests = []struct { + topic TopicType + str string +}{ + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, str: "0x00000000"}, + {topic: TopicType{0x00, 0x7f, 0x80, 0xff}, str: "0x007f80ff"}, + {topic: TopicType{0xff, 0x80, 0x7f, 0x00}, str: "0xff807f00"}, + {topic: TopicType{0xf2, 0x6e, 0x77, 0x79}, str: "0xf26e7779"}, +} + +func TestTopicString(t *testing.T) { + for i, tst := range topicStringTests { + s := tst.topic.String() + if s != tst.str { + t.Fatalf("failed test %d: have %s, want %s.", i, s, tst.str) + } + } +} + +var bytesToTopicTests = []struct { + data []byte + topic TopicType +}{ + {topic: TopicType{0x8f, 0x9a, 0x2b, 0x7d}, data: []byte{0x8f, 0x9a, 0x2b, 0x7d}}, + {topic: TopicType{0x00, 0x7f, 0x80, 0xff}, data: []byte{0x00, 0x7f, 0x80, 0xff}}, + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, data: []byte{0x00, 0x00, 0x00, 0x00}}, + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, data: []byte{0x00, 0x00, 0x00}}, + {topic: TopicType{0x01, 0x00, 0x00, 0x00}, data: []byte{0x01}}, + {topic: TopicType{0x00, 0xfe, 0x00, 0x00}, data: []byte{0x00, 0xfe}}, + {topic: TopicType{0xea, 0x1d, 0x43, 0x00}, data: []byte{0xea, 0x1d, 0x43}}, + {topic: TopicType{0x6f, 0x3c, 0xb0, 0xdd}, data: []byte{0x6f, 0x3c, 0xb0, 0xdd, 0x0f, 0x00, 0x90}}, + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, data: []byte{}}, + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, data: nil}, +} + +var unmarshalTestsGood = []struct { + topic TopicType + data []byte +}{ + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, data: []byte(`"0x00000000"`)}, + {topic: TopicType{0x00, 0x7f, 0x80, 0xff}, data: []byte(`"0x007f80ff"`)}, + {topic: TopicType{0xff, 0x80, 0x7f, 0x00}, data: []byte(`"0xff807f00"`)}, + {topic: TopicType{0xf2, 0x6e, 0x77, 0x79}, data: []byte(`"0xf26e7779"`)}, +} + +var unmarshalTestsBad = []struct { + topic TopicType + data []byte +}{ + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, data: []byte(`"0x000000"`)}, + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, data: []byte(`"0x0000000"`)}, + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, data: []byte(`"0x000000000"`)}, + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, data: []byte(`"0x0000000000"`)}, + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, data: []byte(`"000000"`)}, + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, data: []byte(`"0000000"`)}, + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, data: []byte(`"000000000"`)}, + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, data: []byte(`"0000000000"`)}, + {topic: TopicType{0x00, 0x00, 0x00, 0x00}, data: []byte(`"abcdefg0"`)}, +} + +var unmarshalTestsUgly = []struct { + topic TopicType + data []byte +}{ + {topic: TopicType{0x01, 0x00, 0x00, 0x00}, data: []byte(`"0x00000001"`)}, +} + +func TestBytesToTopic(t *testing.T) { + for i, tst := range bytesToTopicTests { + top := BytesToTopic(tst.data) + if top != tst.topic { + t.Fatalf("failed test %d: have %v, want %v.", i, t, tst.topic) + } + } +} + +func TestUnmarshalTestsGood(t *testing.T) { + for i, tst := range unmarshalTestsGood { + var top TopicType + err := json.Unmarshal(tst.data, &top) + if err != nil { + t.Errorf("failed test %d. input: %v. err: %v", i, tst.data, err) + } else if top != tst.topic { + t.Errorf("failed test %d: have %v, want %v.", i, t, tst.topic) + } + } +} + +func TestUnmarshalTestsBad(t *testing.T) { + // in this test UnmarshalJSON() is supposed to fail + for i, tst := range unmarshalTestsBad { + var top TopicType + err := json.Unmarshal(tst.data, &top) + if err == nil { + t.Fatalf("failed test %d. input: %v.", i, tst.data) + } + } +} + +func TestUnmarshalTestsUgly(t *testing.T) { + // in this test UnmarshalJSON() is NOT supposed to fail, but result should be wrong + for i, tst := range unmarshalTestsUgly { + var top TopicType + err := json.Unmarshal(tst.data, &top) + if err != nil { + t.Errorf("failed test %d. input: %v.", i, tst.data) + } else if top == tst.topic { + t.Errorf("failed test %d: have %v, want %v.", i, top, tst.topic) + } + } +} diff --git a/whisper/whisperv6/whisper.go b/whisper/whisperv6/whisper.go new file mode 100644 index 000000000..553ac3f00 --- /dev/null +++ b/whisper/whisperv6/whisper.go @@ -0,0 +1,858 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package whisperv6 + +import ( + "bytes" + "crypto/ecdsa" + crand "crypto/rand" + "crypto/sha256" + "fmt" + "runtime" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/rpc" + "github.com/syndtr/goleveldb/leveldb/errors" + "golang.org/x/crypto/pbkdf2" + "golang.org/x/sync/syncmap" + set "gopkg.in/fatih/set.v0" +) + +type Statistics struct { + messagesCleared int + memoryCleared int + memoryUsed int + cycles int + totalMessagesCleared int +} + +const ( + minPowIdx = iota // Minimal PoW required by the whisper node + maxMsgSizeIdx = iota // Maximal message length allowed by the whisper node + overflowIdx = iota // Indicator of message queue overflow +) + +// Whisper represents a dark communication interface through the Ethereum +// network, using its very own P2P communication layer. +type Whisper struct { + protocol p2p.Protocol // Protocol description and parameters + filters *Filters // Message filters installed with Subscribe function + + privateKeys map[string]*ecdsa.PrivateKey // Private key storage + symKeys map[string][]byte // Symmetric key storage + keyMu sync.RWMutex // Mutex associated with key storages + + poolMu sync.RWMutex // Mutex to sync the message and expiration pools + envelopes map[common.Hash]*Envelope // Pool of envelopes currently tracked by this node + expirations map[uint32]*set.SetNonTS // Message expiration pool + + peerMu sync.RWMutex // Mutex to sync the active peer set + peers map[*Peer]struct{} // Set of currently active peers + + messageQueue chan *Envelope // Message queue for normal whisper messages + p2pMsgQueue chan *Envelope // Message queue for peer-to-peer messages (not to be forwarded any further) + quit chan struct{} // Channel used for graceful exit + + settings syncmap.Map // holds configuration settings that can be dynamically changed + + statsMu sync.Mutex // guard stats + stats Statistics // Statistics of whisper node + + mailServer MailServer // MailServer interface +} + +// New creates a Whisper client ready to communicate through the Ethereum P2P network. +func New(cfg *Config) *Whisper { + if cfg == nil { + cfg = &DefaultConfig + } + + whisper := &Whisper{ + privateKeys: make(map[string]*ecdsa.PrivateKey), + symKeys: make(map[string][]byte), + envelopes: make(map[common.Hash]*Envelope), + expirations: make(map[uint32]*set.SetNonTS), + peers: make(map[*Peer]struct{}), + messageQueue: make(chan *Envelope, messageQueueLimit), + p2pMsgQueue: make(chan *Envelope, messageQueueLimit), + quit: make(chan struct{}), + } + + whisper.filters = NewFilters(whisper) + + whisper.settings.Store(minPowIdx, cfg.MinimumAcceptedPOW) + whisper.settings.Store(maxMsgSizeIdx, cfg.MaxMessageSize) + whisper.settings.Store(overflowIdx, false) + + // p2p whisper sub protocol handler + whisper.protocol = p2p.Protocol{ + Name: ProtocolName, + Version: uint(ProtocolVersion), + Length: NumberOfMessageCodes, + Run: whisper.HandlePeer, + NodeInfo: func() interface{} { + return map[string]interface{}{ + "version": ProtocolVersionStr, + "maxMessageSize": whisper.MaxMessageSize(), + "minimumPoW": whisper.MinPow(), + } + }, + } + + return whisper +} + +func (w *Whisper) MinPow() float64 { + val, _ := w.settings.Load(minPowIdx) + return val.(float64) +} + +// MaxMessageSize returns the maximum accepted message size. +func (w *Whisper) MaxMessageSize() uint32 { + val, _ := w.settings.Load(maxMsgSizeIdx) + return val.(uint32) +} + +// Overflow returns an indication if the message queue is full. +func (w *Whisper) Overflow() bool { + val, _ := w.settings.Load(overflowIdx) + return val.(bool) +} + +// APIs returns the RPC descriptors the Whisper implementation offers +func (w *Whisper) APIs() []rpc.API { + return []rpc.API{ + { + Namespace: ProtocolName, + Version: ProtocolVersionStr, + Service: NewPublicWhisperAPI(w), + Public: true, + }, + } +} + +// RegisterServer registers MailServer interface. +// MailServer will process all the incoming messages with p2pRequestCode. +func (w *Whisper) RegisterServer(server MailServer) { + w.mailServer = server +} + +// Protocols returns the whisper sub-protocols ran by this particular client. +func (w *Whisper) Protocols() []p2p.Protocol { + return []p2p.Protocol{w.protocol} +} + +// Version returns the whisper sub-protocols version number. +func (w *Whisper) Version() uint { + return w.protocol.Version +} + +// SetMaxMessageSize sets the maximal message size allowed by this node +func (w *Whisper) SetMaxMessageSize(size uint32) error { + if size > MaxMessageSize { + return fmt.Errorf("message size too large [%d>%d]", size, MaxMessageSize) + } + w.settings.Store(maxMsgSizeIdx, uint32(size)) + return nil +} + +// SetMinimumPoW sets the minimal PoW required by this node +func (w *Whisper) SetMinimumPoW(val float64) error { + if val <= 0.0 { + return fmt.Errorf("invalid PoW: %f", val) + } + w.settings.Store(minPowIdx, val) + return nil +} + +// getPeer retrieves peer by ID +func (w *Whisper) getPeer(peerID []byte) (*Peer, error) { + w.peerMu.Lock() + defer w.peerMu.Unlock() + for p := range w.peers { + id := p.peer.ID() + if bytes.Equal(peerID, id[:]) { + return p, nil + } + } + return nil, fmt.Errorf("Could not find peer with ID: %x", peerID) +} + +// AllowP2PMessagesFromPeer marks specific peer trusted, +// which will allow it to send historic (expired) messages. +func (w *Whisper) AllowP2PMessagesFromPeer(peerID []byte) error { + p, err := w.getPeer(peerID) + if err != nil { + return err + } + p.trusted = true + return nil +} + +// RequestHistoricMessages sends a message with p2pRequestCode to a specific peer, +// which is known to implement MailServer interface, and is supposed to process this +// request and respond with a number of peer-to-peer messages (possibly expired), +// which are not supposed to be forwarded any further. +// The whisper protocol is agnostic of the format and contents of envelope. +func (w *Whisper) RequestHistoricMessages(peerID []byte, envelope *Envelope) error { + p, err := w.getPeer(peerID) + if err != nil { + return err + } + p.trusted = true + return p2p.Send(p.ws, p2pRequestCode, envelope) +} + +// SendP2PMessage sends a peer-to-peer message to a specific peer. +func (w *Whisper) SendP2PMessage(peerID []byte, envelope *Envelope) error { + p, err := w.getPeer(peerID) + if err != nil { + return err + } + return w.SendP2PDirect(p, envelope) +} + +// SendP2PDirect sends a peer-to-peer message to a specific peer. +func (w *Whisper) SendP2PDirect(peer *Peer, envelope *Envelope) error { + return p2p.Send(peer.ws, p2pCode, envelope) +} + +// NewKeyPair generates a new cryptographic identity for the client, and injects +// it into the known identities for message decryption. Returns ID of the new key pair. +func (w *Whisper) NewKeyPair() (string, error) { + key, err := crypto.GenerateKey() + if err != nil || !validatePrivateKey(key) { + key, err = crypto.GenerateKey() // retry once + } + if err != nil { + return "", err + } + if !validatePrivateKey(key) { + return "", fmt.Errorf("failed to generate valid key") + } + + id, err := GenerateRandomID() + if err != nil { + return "", fmt.Errorf("failed to generate ID: %s", err) + } + + w.keyMu.Lock() + defer w.keyMu.Unlock() + + if w.privateKeys[id] != nil { + return "", fmt.Errorf("failed to generate unique ID") + } + w.privateKeys[id] = key + return id, nil +} + +// DeleteKeyPair deletes the specified key if it exists. +func (w *Whisper) DeleteKeyPair(key string) bool { + w.keyMu.Lock() + defer w.keyMu.Unlock() + + if w.privateKeys[key] != nil { + delete(w.privateKeys, key) + return true + } + return false +} + +// AddKeyPair imports a asymmetric private key and returns it identifier. +func (w *Whisper) AddKeyPair(key *ecdsa.PrivateKey) (string, error) { + id, err := GenerateRandomID() + if err != nil { + return "", fmt.Errorf("failed to generate ID: %s", err) + } + + w.keyMu.Lock() + w.privateKeys[id] = key + w.keyMu.Unlock() + + return id, nil +} + +// HasKeyPair checks if the the whisper node is configured with the private key +// of the specified public pair. +func (w *Whisper) HasKeyPair(id string) bool { + w.keyMu.RLock() + defer w.keyMu.RUnlock() + return w.privateKeys[id] != nil +} + +// GetPrivateKey retrieves the private key of the specified identity. +func (w *Whisper) GetPrivateKey(id string) (*ecdsa.PrivateKey, error) { + w.keyMu.RLock() + defer w.keyMu.RUnlock() + key := w.privateKeys[id] + if key == nil { + return nil, fmt.Errorf("invalid id") + } + return key, nil +} + +// GenerateSymKey generates a random symmetric key and stores it under id, +// which is then returned. Will be used in the future for session key exchange. +func (w *Whisper) GenerateSymKey() (string, error) { + key := make([]byte, aesKeyLength) + _, err := crand.Read(key) + if err != nil { + return "", err + } else if !validateSymmetricKey(key) { + return "", fmt.Errorf("error in GenerateSymKey: crypto/rand failed to generate random data") + } + + id, err := GenerateRandomID() + if err != nil { + return "", fmt.Errorf("failed to generate ID: %s", err) + } + + w.keyMu.Lock() + defer w.keyMu.Unlock() + + if w.symKeys[id] != nil { + return "", fmt.Errorf("failed to generate unique ID") + } + w.symKeys[id] = key + return id, nil +} + +// AddSymKeyDirect stores the key, and returns its id. +func (w *Whisper) AddSymKeyDirect(key []byte) (string, error) { + if len(key) != aesKeyLength { + return "", fmt.Errorf("wrong key size: %d", len(key)) + } + + id, err := GenerateRandomID() + if err != nil { + return "", fmt.Errorf("failed to generate ID: %s", err) + } + + w.keyMu.Lock() + defer w.keyMu.Unlock() + + if w.symKeys[id] != nil { + return "", fmt.Errorf("failed to generate unique ID") + } + w.symKeys[id] = key + return id, nil +} + +// AddSymKeyFromPassword generates the key from password, stores it, and returns its id. +func (w *Whisper) AddSymKeyFromPassword(password string) (string, error) { + id, err := GenerateRandomID() + if err != nil { + return "", fmt.Errorf("failed to generate ID: %s", err) + } + if w.HasSymKey(id) { + return "", fmt.Errorf("failed to generate unique ID") + } + + derived, err := deriveKeyMaterial([]byte(password), EnvelopeVersion) + if err != nil { + return "", err + } + + w.keyMu.Lock() + defer w.keyMu.Unlock() + + // double check is necessary, because deriveKeyMaterial() is very slow + if w.symKeys[id] != nil { + return "", fmt.Errorf("critical error: failed to generate unique ID") + } + w.symKeys[id] = derived + return id, nil +} + +// HasSymKey returns true if there is a key associated with the given id. +// Otherwise returns false. +func (w *Whisper) HasSymKey(id string) bool { + w.keyMu.RLock() + defer w.keyMu.RUnlock() + return w.symKeys[id] != nil +} + +// DeleteSymKey deletes the key associated with the name string if it exists. +func (w *Whisper) DeleteSymKey(id string) bool { + w.keyMu.Lock() + defer w.keyMu.Unlock() + if w.symKeys[id] != nil { + delete(w.symKeys, id) + return true + } + return false +} + +// GetSymKey returns the symmetric key associated with the given id. +func (w *Whisper) GetSymKey(id string) ([]byte, error) { + w.keyMu.RLock() + defer w.keyMu.RUnlock() + if w.symKeys[id] != nil { + return w.symKeys[id], nil + } + return nil, fmt.Errorf("non-existent key ID") +} + +// Subscribe installs a new message handler used for filtering, decrypting +// and subsequent storing of incoming messages. +func (w *Whisper) Subscribe(f *Filter) (string, error) { + return w.filters.Install(f) +} + +// GetFilter returns the filter by id. +func (w *Whisper) GetFilter(id string) *Filter { + return w.filters.Get(id) +} + +// Unsubscribe removes an installed message handler. +func (w *Whisper) Unsubscribe(id string) error { + ok := w.filters.Uninstall(id) + if !ok { + return fmt.Errorf("Unsubscribe: Invalid ID") + } + return nil +} + +// Send injects a message into the whisper send queue, to be distributed in the +// network in the coming cycles. +func (w *Whisper) Send(envelope *Envelope) error { + ok, err := w.add(envelope) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("failed to add envelope") + } + return err +} + +// Start implements node.Service, starting the background data propagation thread +// of the Whisper protocol. +func (w *Whisper) Start(*p2p.Server) error { + log.Info("started whisper v." + ProtocolVersionStr) + go w.update() + + numCPU := runtime.NumCPU() + for i := 0; i < numCPU; i++ { + go w.processQueue() + } + + return nil +} + +// Stop implements node.Service, stopping the background data propagation thread +// of the Whisper protocol. +func (w *Whisper) Stop() error { + close(w.quit) + log.Info("whisper stopped") + return nil +} + +// HandlePeer is called by the underlying P2P layer when the whisper sub-protocol +// connection is negotiated. +func (wh *Whisper) HandlePeer(peer *p2p.Peer, rw p2p.MsgReadWriter) error { + // Create the new peer and start tracking it + whisperPeer := newPeer(wh, peer, rw) + + wh.peerMu.Lock() + wh.peers[whisperPeer] = struct{}{} + wh.peerMu.Unlock() + + defer func() { + wh.peerMu.Lock() + delete(wh.peers, whisperPeer) + wh.peerMu.Unlock() + }() + + // Run the peer handshake and state updates + if err := whisperPeer.handshake(); err != nil { + return err + } + whisperPeer.start() + defer whisperPeer.stop() + + return wh.runMessageLoop(whisperPeer, rw) +} + +// runMessageLoop reads and processes inbound messages directly to merge into client-global state. +func (wh *Whisper) runMessageLoop(p *Peer, rw p2p.MsgReadWriter) error { + for { + // fetch the next packet + packet, err := rw.ReadMsg() + if err != nil { + log.Warn("message loop", "peer", p.peer.ID(), "err", err) + return err + } + if packet.Size > wh.MaxMessageSize() { + log.Warn("oversized message received", "peer", p.peer.ID()) + return errors.New("oversized message received") + } + + switch packet.Code { + case statusCode: + // this should not happen, but no need to panic; just ignore this message. + log.Warn("unxepected status message received", "peer", p.peer.ID()) + case messagesCode: + // decode the contained envelopes + var envelope Envelope + if err := packet.Decode(&envelope); err != nil { + log.Warn("failed to decode envelope, peer will be disconnected", "peer", p.peer.ID(), "err", err) + return errors.New("invalid envelope") + } + cached, err := wh.add(&envelope) + if err != nil { + log.Warn("bad envelope received, peer will be disconnected", "peer", p.peer.ID(), "err", err) + return errors.New("invalid envelope") + } + if cached { + p.mark(&envelope) + } + case p2pCode: + // peer-to-peer message, sent directly to peer bypassing PoW checks, etc. + // this message is not supposed to be forwarded to other peers, and + // therefore might not satisfy the PoW, expiry and other requirements. + // these messages are only accepted from the trusted peer. + if p.trusted { + var envelope Envelope + if err := packet.Decode(&envelope); err != nil { + log.Warn("failed to decode direct message, peer will be disconnected", "peer", p.peer.ID(), "err", err) + return errors.New("invalid direct message") + } + wh.postEvent(&envelope, true) + } + case p2pRequestCode: + // Must be processed if mail server is implemented. Otherwise ignore. + if wh.mailServer != nil { + var request Envelope + if err := packet.Decode(&request); err != nil { + log.Warn("failed to decode p2p request message, peer will be disconnected", "peer", p.peer.ID(), "err", err) + return errors.New("invalid p2p request") + } + wh.mailServer.DeliverMail(p, &request) + } + default: + // New message types might be implemented in the future versions of Whisper. + // For forward compatibility, just ignore. + } + + packet.Discard() + } +} + +// add inserts a new envelope into the message pool to be distributed within the +// whisper network. It also inserts the envelope into the expiration pool at the +// appropriate time-stamp. In case of error, connection should be dropped. +func (wh *Whisper) add(envelope *Envelope) (bool, error) { + now := uint32(time.Now().Unix()) + sent := envelope.Expiry - envelope.TTL + + if sent > now { + if sent-SynchAllowance > now { + return false, fmt.Errorf("envelope created in the future [%x]", envelope.Hash()) + } else { + // recalculate PoW, adjusted for the time difference, plus one second for latency + envelope.calculatePoW(sent - now + 1) + } + } + + if envelope.Expiry < now { + if envelope.Expiry+SynchAllowance*2 < now { + return false, fmt.Errorf("very old message") + } else { + log.Debug("expired envelope dropped", "hash", envelope.Hash().Hex()) + return false, nil // drop envelope without error + } + } + + if uint32(envelope.size()) > wh.MaxMessageSize() { + return false, fmt.Errorf("huge messages are not allowed [%x]", envelope.Hash()) + } + + if len(envelope.Version) > 4 { + return false, fmt.Errorf("oversized version [%x]", envelope.Hash()) + } + + aesNonceSize := len(envelope.AESNonce) + if aesNonceSize != 0 && aesNonceSize != AESNonceLength { + // the standard AES GCM nonce size is 12 bytes, + // but constant gcmStandardNonceSize cannot be accessed (not exported) + return false, fmt.Errorf("wrong size of AESNonce: %d bytes [env: %x]", aesNonceSize, envelope.Hash()) + } + + if envelope.PoW() < wh.MinPow() { + log.Debug("envelope with low PoW dropped", "PoW", envelope.PoW(), "hash", envelope.Hash().Hex()) + return false, nil // drop envelope without error + } + + hash := envelope.Hash() + + wh.poolMu.Lock() + _, alreadyCached := wh.envelopes[hash] + if !alreadyCached { + wh.envelopes[hash] = envelope + if wh.expirations[envelope.Expiry] == nil { + wh.expirations[envelope.Expiry] = set.NewNonTS() + } + if !wh.expirations[envelope.Expiry].Has(hash) { + wh.expirations[envelope.Expiry].Add(hash) + } + } + wh.poolMu.Unlock() + + if alreadyCached { + log.Trace("whisper envelope already cached", "hash", envelope.Hash().Hex()) + } else { + log.Trace("cached whisper envelope", "hash", envelope.Hash().Hex()) + wh.statsMu.Lock() + wh.stats.memoryUsed += envelope.size() + wh.statsMu.Unlock() + wh.postEvent(envelope, false) // notify the local node about the new message + if wh.mailServer != nil { + wh.mailServer.Archive(envelope) + } + } + return true, nil +} + +// postEvent queues the message for further processing. +func (w *Whisper) postEvent(envelope *Envelope, isP2P bool) { + // if the version of incoming message is higher than + // currently supported version, we can not decrypt it, + // and therefore just ignore this message + if envelope.Ver() <= EnvelopeVersion { + if isP2P { + w.p2pMsgQueue <- envelope + } else { + w.checkOverflow() + w.messageQueue <- envelope + } + } +} + +// checkOverflow checks if message queue overflow occurs and reports it if necessary. +func (w *Whisper) checkOverflow() { + queueSize := len(w.messageQueue) + + if queueSize == messageQueueLimit { + if !w.Overflow() { + w.settings.Store(overflowIdx, true) + log.Warn("message queue overflow") + } + } else if queueSize <= messageQueueLimit/2 { + if w.Overflow() { + w.settings.Store(overflowIdx, false) + log.Warn("message queue overflow fixed (back to normal)") + } + } +} + +// processQueue delivers the messages to the watchers during the lifetime of the whisper node. +func (w *Whisper) processQueue() { + var e *Envelope + for { + select { + case <-w.quit: + return + + case e = <-w.messageQueue: + w.filters.NotifyWatchers(e, false) + + case e = <-w.p2pMsgQueue: + w.filters.NotifyWatchers(e, true) + } + } +} + +// update loops until the lifetime of the whisper node, updating its internal +// state by expiring stale messages from the pool. +func (w *Whisper) update() { + // Start a ticker to check for expirations + expire := time.NewTicker(expirationCycle) + + // Repeat updates until termination is requested + for { + select { + case <-expire.C: + w.expire() + + case <-w.quit: + return + } + } +} + +// expire iterates over all the expiration timestamps, removing all stale +// messages from the pools. +func (w *Whisper) expire() { + w.poolMu.Lock() + defer w.poolMu.Unlock() + + w.statsMu.Lock() + defer w.statsMu.Unlock() + w.stats.reset() + now := uint32(time.Now().Unix()) + for expiry, hashSet := range w.expirations { + if expiry < now { + // Dump all expired messages and remove timestamp + hashSet.Each(func(v interface{}) bool { + sz := w.envelopes[v.(common.Hash)].size() + delete(w.envelopes, v.(common.Hash)) + w.stats.messagesCleared++ + w.stats.memoryCleared += sz + w.stats.memoryUsed -= sz + return true + }) + w.expirations[expiry].Clear() + delete(w.expirations, expiry) + } + } +} + +// Stats returns the whisper node statistics. +func (w *Whisper) Stats() Statistics { + w.statsMu.Lock() + defer w.statsMu.Unlock() + + return w.stats +} + +// Envelopes retrieves all the messages currently pooled by the node. +func (w *Whisper) Envelopes() []*Envelope { + w.poolMu.RLock() + defer w.poolMu.RUnlock() + + all := make([]*Envelope, 0, len(w.envelopes)) + for _, envelope := range w.envelopes { + all = append(all, envelope) + } + return all +} + +// Messages iterates through all currently floating envelopes +// and retrieves all the messages, that this filter could decrypt. +func (w *Whisper) Messages(id string) []*ReceivedMessage { + result := make([]*ReceivedMessage, 0) + w.poolMu.RLock() + defer w.poolMu.RUnlock() + + if filter := w.filters.Get(id); filter != nil { + for _, env := range w.envelopes { + msg := filter.processEnvelope(env) + if msg != nil { + result = append(result, msg) + } + } + } + return result +} + +// isEnvelopeCached checks if envelope with specific hash has already been received and cached. +func (w *Whisper) isEnvelopeCached(hash common.Hash) bool { + w.poolMu.Lock() + defer w.poolMu.Unlock() + + _, exist := w.envelopes[hash] + return exist +} + +// reset resets the node's statistics after each expiry cycle. +func (s *Statistics) reset() { + s.cycles++ + s.totalMessagesCleared += s.messagesCleared + + s.memoryCleared = 0 + s.messagesCleared = 0 +} + +// ValidatePublicKey checks the format of the given public key. +func ValidatePublicKey(k *ecdsa.PublicKey) bool { + return k != nil && k.X != nil && k.Y != nil && k.X.Sign() != 0 && k.Y.Sign() != 0 +} + +// validatePrivateKey checks the format of the given private key. +func validatePrivateKey(k *ecdsa.PrivateKey) bool { + if k == nil || k.D == nil || k.D.Sign() == 0 { + return false + } + return ValidatePublicKey(&k.PublicKey) +} + +// validateSymmetricKey returns false if the key contains all zeros +func validateSymmetricKey(k []byte) bool { + return len(k) > 0 && !containsOnlyZeros(k) +} + +// containsOnlyZeros checks if the data contain only zeros. +func containsOnlyZeros(data []byte) bool { + for _, b := range data { + if b != 0 { + return false + } + } + return true +} + +// bytesToUintLittleEndian converts the slice to 64-bit unsigned integer. +func bytesToUintLittleEndian(b []byte) (res uint64) { + mul := uint64(1) + for i := 0; i < len(b); i++ { + res += uint64(b[i]) * mul + mul *= 256 + } + return res +} + +// BytesToUintBigEndian converts the slice to 64-bit unsigned integer. +func BytesToUintBigEndian(b []byte) (res uint64) { + for i := 0; i < len(b); i++ { + res *= 256 + res += uint64(b[i]) + } + return res +} + +// deriveKeyMaterial derives symmetric key material from the key or password. +// pbkdf2 is used for security, in case people use password instead of randomly generated keys. +func deriveKeyMaterial(key []byte, version uint64) (derivedKey []byte, err error) { + if version == 0 { + // kdf should run no less than 0.1 seconds on average compute, + // because it's a once in a session experience + derivedKey := pbkdf2.Key(key, nil, 65356, aesKeyLength, sha256.New) + return derivedKey, nil + } else { + return nil, unknownVersionError(version) + } +} + +// GenerateRandomID generates a random string, which is then returned to be used as a key id +func GenerateRandomID() (id string, err error) { + buf := make([]byte, keyIdSize) + _, err = crand.Read(buf) + if err != nil { + return "", err + } + if !validateSymmetricKey(buf) { + return "", fmt.Errorf("error in generateRandomID: crypto/rand failed to generate random data") + } + id = common.Bytes2Hex(buf) + return id, err +} diff --git a/whisper/whisperv6/whisper_test.go b/whisper/whisperv6/whisper_test.go new file mode 100644 index 000000000..c7cea4014 --- /dev/null +++ b/whisper/whisperv6/whisper_test.go @@ -0,0 +1,851 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package whisperv6 + +import ( + "bytes" + "crypto/ecdsa" + mrand "math/rand" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" +) + +func TestWhisperBasic(t *testing.T) { + w := New(&DefaultConfig) + p := w.Protocols() + shh := p[0] + if shh.Name != ProtocolName { + t.Fatalf("failed Protocol Name: %v.", shh.Name) + } + if uint64(shh.Version) != ProtocolVersion { + t.Fatalf("failed Protocol Version: %v.", shh.Version) + } + if shh.Length != NumberOfMessageCodes { + t.Fatalf("failed Protocol Length: %v.", shh.Length) + } + if shh.Run == nil { + t.Fatalf("failed shh.Run.") + } + if uint64(w.Version()) != ProtocolVersion { + t.Fatalf("failed whisper Version: %v.", shh.Version) + } + if w.GetFilter("non-existent") != nil { + t.Fatalf("failed GetFilter.") + } + + peerID := make([]byte, 64) + mrand.Read(peerID) + peer, _ := w.getPeer(peerID) + if peer != nil { + t.Fatal("found peer for random key.") + } + if err := w.AllowP2PMessagesFromPeer(peerID); err == nil { + t.Fatalf("failed MarkPeerTrusted.") + } + exist := w.HasSymKey("non-existing") + if exist { + t.Fatalf("failed HasSymKey.") + } + key, err := w.GetSymKey("non-existing") + if err == nil { + t.Fatalf("failed GetSymKey(non-existing): false positive.") + } + if key != nil { + t.Fatalf("failed GetSymKey: false positive.") + } + mail := w.Envelopes() + if len(mail) != 0 { + t.Fatalf("failed w.Envelopes().") + } + m := w.Messages("non-existent") + if len(m) != 0 { + t.Fatalf("failed w.Messages.") + } + + var derived []byte + ver := uint64(0xDEADBEEF) + if _, err := deriveKeyMaterial(peerID, ver); err != unknownVersionError(ver) { + t.Fatalf("failed deriveKeyMaterial with param = %v: %s.", peerID, err) + } + derived, err = deriveKeyMaterial(peerID, 0) + if err != nil { + t.Fatalf("failed second deriveKeyMaterial with param = %v: %s.", peerID, err) + } + if !validateSymmetricKey(derived) { + t.Fatalf("failed validateSymmetricKey with param = %v.", derived) + } + if containsOnlyZeros(derived) { + t.Fatalf("failed containsOnlyZeros with param = %v.", derived) + } + + buf := []byte{0xFF, 0xE5, 0x80, 0x2, 0} + le := bytesToUintLittleEndian(buf) + be := BytesToUintBigEndian(buf) + if le != uint64(0x280e5ff) { + t.Fatalf("failed bytesToIntLittleEndian: %d.", le) + } + if be != uint64(0xffe5800200) { + t.Fatalf("failed BytesToIntBigEndian: %d.", be) + } + + id, err := w.NewKeyPair() + if err != nil { + t.Fatalf("failed to generate new key pair: %s.", err) + } + pk, err := w.GetPrivateKey(id) + if err != nil { + t.Fatalf("failed to retrieve new key pair: %s.", err) + } + if !validatePrivateKey(pk) { + t.Fatalf("failed validatePrivateKey: %v.", pk) + } + if !ValidatePublicKey(&pk.PublicKey) { + t.Fatalf("failed ValidatePublicKey: %v.", pk) + } +} + +func TestWhisperAsymmetricKeyImport(t *testing.T) { + var ( + w = New(&DefaultConfig) + privateKeys []*ecdsa.PrivateKey + ) + + for i := 0; i < 50; i++ { + id, err := w.NewKeyPair() + if err != nil { + t.Fatalf("could not generate key: %v", err) + } + + pk, err := w.GetPrivateKey(id) + if err != nil { + t.Fatalf("could not export private key: %v", err) + } + + privateKeys = append(privateKeys, pk) + + if !w.DeleteKeyPair(id) { + t.Fatalf("could not delete private key") + } + } + + for _, pk := range privateKeys { + if _, err := w.AddKeyPair(pk); err != nil { + t.Fatalf("could not import private key: %v", err) + } + } +} + +func TestWhisperIdentityManagement(t *testing.T) { + w := New(&DefaultConfig) + id1, err := w.NewKeyPair() + if err != nil { + t.Fatalf("failed to generate new key pair: %s.", err) + } + id2, err := w.NewKeyPair() + if err != nil { + t.Fatalf("failed to generate new key pair: %s.", err) + } + pk1, err := w.GetPrivateKey(id1) + if err != nil { + t.Fatalf("failed to retrieve the key pair: %s.", err) + } + pk2, err := w.GetPrivateKey(id2) + if err != nil { + t.Fatalf("failed to retrieve the key pair: %s.", err) + } + + if !w.HasKeyPair(id1) { + t.Fatalf("failed HasIdentity(pk1).") + } + if !w.HasKeyPair(id2) { + t.Fatalf("failed HasIdentity(pk2).") + } + if pk1 == nil { + t.Fatalf("failed GetIdentity(pk1).") + } + if pk2 == nil { + t.Fatalf("failed GetIdentity(pk2).") + } + + if !validatePrivateKey(pk1) { + t.Fatalf("pk1 is invalid.") + } + if !validatePrivateKey(pk2) { + t.Fatalf("pk2 is invalid.") + } + + // Delete one identity + done := w.DeleteKeyPair(id1) + if !done { + t.Fatalf("failed to delete id1.") + } + pk1, err = w.GetPrivateKey(id1) + if err == nil { + t.Fatalf("retrieve the key pair: false positive.") + } + pk2, err = w.GetPrivateKey(id2) + if err != nil { + t.Fatalf("failed to retrieve the key pair: %s.", err) + } + if w.HasKeyPair(id1) { + t.Fatalf("failed DeleteIdentity(pub1): still exist.") + } + if !w.HasKeyPair(id2) { + t.Fatalf("failed DeleteIdentity(pub1): pub2 does not exist.") + } + if pk1 != nil { + t.Fatalf("failed DeleteIdentity(pub1): first key still exist.") + } + if pk2 == nil { + t.Fatalf("failed DeleteIdentity(pub1): second key does not exist.") + } + + // Delete again non-existing identity + done = w.DeleteKeyPair(id1) + if done { + t.Fatalf("delete id1: false positive.") + } + pk1, err = w.GetPrivateKey(id1) + if err == nil { + t.Fatalf("retrieve the key pair: false positive.") + } + pk2, err = w.GetPrivateKey(id2) + if err != nil { + t.Fatalf("failed to retrieve the key pair: %s.", err) + } + if w.HasKeyPair(id1) { + t.Fatalf("failed delete non-existing identity: exist.") + } + if !w.HasKeyPair(id2) { + t.Fatalf("failed delete non-existing identity: pub2 does not exist.") + } + if pk1 != nil { + t.Fatalf("failed delete non-existing identity: first key exist.") + } + if pk2 == nil { + t.Fatalf("failed delete non-existing identity: second key does not exist.") + } + + // Delete second identity + done = w.DeleteKeyPair(id2) + if !done { + t.Fatalf("failed to delete id2.") + } + pk1, err = w.GetPrivateKey(id1) + if err == nil { + t.Fatalf("retrieve the key pair: false positive.") + } + pk2, err = w.GetPrivateKey(id2) + if err == nil { + t.Fatalf("retrieve the key pair: false positive.") + } + if w.HasKeyPair(id1) { + t.Fatalf("failed delete second identity: first identity exist.") + } + if w.HasKeyPair(id2) { + t.Fatalf("failed delete second identity: still exist.") + } + if pk1 != nil { + t.Fatalf("failed delete second identity: first key exist.") + } + if pk2 != nil { + t.Fatalf("failed delete second identity: second key exist.") + } +} + +func TestWhisperSymKeyManagement(t *testing.T) { + InitSingleTest() + + var err error + var k1, k2 []byte + w := New(&DefaultConfig) + id1 := string("arbitrary-string-1") + id2 := string("arbitrary-string-2") + + id1, err = w.GenerateSymKey() + if err != nil { + t.Fatalf("failed GenerateSymKey with seed %d: %s.", seed, err) + } + + k1, err = w.GetSymKey(id1) + if err != nil { + t.Fatalf("failed GetSymKey(id1).") + } + k2, err = w.GetSymKey(id2) + if err == nil { + t.Fatalf("failed GetSymKey(id2): false positive.") + } + if !w.HasSymKey(id1) { + t.Fatalf("failed HasSymKey(id1).") + } + if w.HasSymKey(id2) { + t.Fatalf("failed HasSymKey(id2): false positive.") + } + if k1 == nil { + t.Fatalf("first key does not exist.") + } + if k2 != nil { + t.Fatalf("second key still exist.") + } + + // add existing id, nothing should change + randomKey := make([]byte, aesKeyLength) + mrand.Read(randomKey) + id1, err = w.AddSymKeyDirect(randomKey) + if err != nil { + t.Fatalf("failed AddSymKey with seed %d: %s.", seed, err) + } + + k1, err = w.GetSymKey(id1) + if err != nil { + t.Fatalf("failed w.GetSymKey(id1).") + } + k2, err = w.GetSymKey(id2) + if err == nil { + t.Fatalf("failed w.GetSymKey(id2): false positive.") + } + if !w.HasSymKey(id1) { + t.Fatalf("failed w.HasSymKey(id1).") + } + if w.HasSymKey(id2) { + t.Fatalf("failed w.HasSymKey(id2): false positive.") + } + if k1 == nil { + t.Fatalf("first key does not exist.") + } + if !bytes.Equal(k1, randomKey) { + t.Fatalf("k1 != randomKey.") + } + if k2 != nil { + t.Fatalf("second key already exist.") + } + + id2, err = w.AddSymKeyDirect(randomKey) + if err != nil { + t.Fatalf("failed AddSymKey(id2) with seed %d: %s.", seed, err) + } + k1, err = w.GetSymKey(id1) + if err != nil { + t.Fatalf("failed w.GetSymKey(id1).") + } + k2, err = w.GetSymKey(id2) + if err != nil { + t.Fatalf("failed w.GetSymKey(id2).") + } + if !w.HasSymKey(id1) { + t.Fatalf("HasSymKey(id1) failed.") + } + if !w.HasSymKey(id2) { + t.Fatalf("HasSymKey(id2) failed.") + } + if k1 == nil { + t.Fatalf("k1 does not exist.") + } + if k2 == nil { + t.Fatalf("k2 does not exist.") + } + if !bytes.Equal(k1, k2) { + t.Fatalf("k1 != k2.") + } + if !bytes.Equal(k1, randomKey) { + t.Fatalf("k1 != randomKey.") + } + if len(k1) != aesKeyLength { + t.Fatalf("wrong length of k1.") + } + if len(k2) != aesKeyLength { + t.Fatalf("wrong length of k2.") + } + + w.DeleteSymKey(id1) + k1, err = w.GetSymKey(id1) + if err == nil { + t.Fatalf("failed w.GetSymKey(id1): false positive.") + } + if k1 != nil { + t.Fatalf("failed GetSymKey(id1): false positive.") + } + k2, err = w.GetSymKey(id2) + if err != nil { + t.Fatalf("failed w.GetSymKey(id2).") + } + if w.HasSymKey(id1) { + t.Fatalf("failed to delete first key: still exist.") + } + if !w.HasSymKey(id2) { + t.Fatalf("failed to delete first key: second key does not exist.") + } + if k1 != nil { + t.Fatalf("failed to delete first key.") + } + if k2 == nil { + t.Fatalf("failed to delete first key: second key is nil.") + } + + w.DeleteSymKey(id1) + w.DeleteSymKey(id2) + k1, err = w.GetSymKey(id1) + if err == nil { + t.Fatalf("failed w.GetSymKey(id1): false positive.") + } + k2, err = w.GetSymKey(id2) + if err == nil { + t.Fatalf("failed w.GetSymKey(id2): false positive.") + } + if k1 != nil || k2 != nil { + t.Fatalf("k1 or k2 is not nil") + } + if w.HasSymKey(id1) { + t.Fatalf("failed to delete second key: first key exist.") + } + if w.HasSymKey(id2) { + t.Fatalf("failed to delete second key: still exist.") + } + if k1 != nil { + t.Fatalf("failed to delete second key: first key is not nil.") + } + if k2 != nil { + t.Fatalf("failed to delete second key: second key is not nil.") + } + + randomKey = make([]byte, aesKeyLength+1) + mrand.Read(randomKey) + _, err = w.AddSymKeyDirect(randomKey) + if err == nil { + t.Fatalf("added the key with wrong size, seed %d.", seed) + } + + const password = "arbitrary data here" + id1, err = w.AddSymKeyFromPassword(password) + if err != nil { + t.Fatalf("failed AddSymKeyFromPassword(id1) with seed %d: %s.", seed, err) + } + id2, err = w.AddSymKeyFromPassword(password) + if err != nil { + t.Fatalf("failed AddSymKeyFromPassword(id2) with seed %d: %s.", seed, err) + } + k1, err = w.GetSymKey(id1) + if err != nil { + t.Fatalf("failed w.GetSymKey(id1).") + } + k2, err = w.GetSymKey(id2) + if err != nil { + t.Fatalf("failed w.GetSymKey(id2).") + } + if !w.HasSymKey(id1) { + t.Fatalf("HasSymKey(id1) failed.") + } + if !w.HasSymKey(id2) { + t.Fatalf("HasSymKey(id2) failed.") + } + if k1 == nil { + t.Fatalf("k1 does not exist.") + } + if k2 == nil { + t.Fatalf("k2 does not exist.") + } + if !bytes.Equal(k1, k2) { + t.Fatalf("k1 != k2.") + } + if len(k1) != aesKeyLength { + t.Fatalf("wrong length of k1.") + } + if len(k2) != aesKeyLength { + t.Fatalf("wrong length of k2.") + } + if !validateSymmetricKey(k2) { + t.Fatalf("key validation failed.") + } +} + +func TestExpiry(t *testing.T) { + InitSingleTest() + + w := New(&DefaultConfig) + w.SetMinimumPoW(0.0000001) + defer w.SetMinimumPoW(DefaultMinimumPoW) + w.Start(nil) + defer w.Stop() + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + + params.TTL = 1 + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := msg.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + + err = w.Send(env) + if err != nil { + t.Fatalf("failed to send envelope with seed %d: %s.", seed, err) + } + + // wait till received or timeout + var received, expired bool + for j := 0; j < 20; j++ { + time.Sleep(100 * time.Millisecond) + if len(w.Envelopes()) > 0 { + received = true + break + } + } + + if !received { + t.Fatalf("did not receive the sent envelope, seed: %d.", seed) + } + + // wait till expired or timeout + for j := 0; j < 20; j++ { + time.Sleep(100 * time.Millisecond) + if len(w.Envelopes()) == 0 { + expired = true + break + } + } + + if !expired { + t.Fatalf("expire failed, seed: %d.", seed) + } +} + +func TestCustomization(t *testing.T) { + InitSingleTest() + + w := New(&DefaultConfig) + defer w.SetMinimumPoW(DefaultMinimumPoW) + defer w.SetMaxMessageSize(DefaultMaxMessageSize) + w.Start(nil) + defer w.Stop() + + const smallPoW = 0.00001 + + f, err := generateFilter(t, true) + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + + params.KeySym = f.KeySym + params.Topic = BytesToTopic(f.Topics[2]) + params.PoW = smallPoW + params.TTL = 3600 * 24 // one day + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := msg.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + + err = w.Send(env) + if err == nil { + t.Fatalf("successfully sent envelope with PoW %.06f, false positive (seed %d).", env.PoW(), seed) + } + + w.SetMinimumPoW(smallPoW / 2) + err = w.Send(env) + if err != nil { + t.Fatalf("failed to send envelope with seed %d: %s.", seed, err) + } + + params.TTL++ + msg, err = NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err = msg.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + w.SetMaxMessageSize(uint32(env.size() - 1)) + err = w.Send(env) + if err == nil { + t.Fatalf("successfully sent oversized envelope (seed %d): false positive.", seed) + } + + w.SetMaxMessageSize(DefaultMaxMessageSize) + err = w.Send(env) + if err != nil { + t.Fatalf("failed to send second envelope with seed %d: %s.", seed, err) + } + + // wait till received or timeout + var received bool + for j := 0; j < 20; j++ { + time.Sleep(100 * time.Millisecond) + if len(w.Envelopes()) > 1 { + received = true + break + } + } + + if !received { + t.Fatalf("did not receive the sent envelope, seed: %d.", seed) + } + + // check w.messages() + id, err := w.Subscribe(f) + if err != nil { + t.Fatalf("failed subscribe with seed %d: %s.", seed, err) + } + time.Sleep(5 * time.Millisecond) + mail := f.Retrieve() + if len(mail) > 0 { + t.Fatalf("received premature mail") + } + + mail = w.Messages(id) + if len(mail) != 2 { + t.Fatalf("failed to get whisper messages") + } +} + +func TestSymmetricSendCycle(t *testing.T) { + InitSingleTest() + + w := New(&DefaultConfig) + defer w.SetMinimumPoW(DefaultMinimumPoW) + defer w.SetMaxMessageSize(DefaultMaxMessageSize) + w.Start(nil) + defer w.Stop() + + filter1, err := generateFilter(t, true) + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + filter1.PoW = DefaultMinimumPoW + + // Copy the first filter since some of its fields + // are randomly gnerated. + filter2 := &Filter{ + KeySym: filter1.KeySym, + Topics: filter1.Topics, + PoW: filter1.PoW, + AllowP2P: filter1.AllowP2P, + Messages: make(map[common.Hash]*ReceivedMessage), + } + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + + filter1.Src = ¶ms.Src.PublicKey + filter2.Src = ¶ms.Src.PublicKey + + params.KeySym = filter1.KeySym + params.Topic = BytesToTopic(filter1.Topics[2]) + params.PoW = filter1.PoW + params.WorkTime = 10 + params.TTL = 50 + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := msg.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + + _, err = w.Subscribe(filter1) + if err != nil { + t.Fatalf("failed subscribe 1 with seed %d: %s.", seed, err) + } + + _, err = w.Subscribe(filter2) + if err != nil { + t.Fatalf("failed subscribe 2 with seed %d: %s.", seed, err) + } + + err = w.Send(env) + if err != nil { + t.Fatalf("Failed sending envelope with PoW %.06f (seed %d): %s", env.PoW(), seed, err) + } + + // wait till received or timeout + var received bool + for j := 0; j < 200; j++ { + time.Sleep(10 * time.Millisecond) + if len(w.Envelopes()) > 0 { + received = true + break + } + } + + if !received { + t.Fatalf("did not receive the sent envelope, seed: %d.", seed) + } + + // check w.messages() + time.Sleep(5 * time.Millisecond) + mail1 := filter1.Retrieve() + mail2 := filter2.Retrieve() + if len(mail2) == 0 { + t.Fatalf("did not receive any email for filter 2") + } + if len(mail1) == 0 { + t.Fatalf("did not receive any email for filter 1") + } + +} + +func TestSymmetricSendWithoutAKey(t *testing.T) { + InitSingleTest() + + w := New(&DefaultConfig) + defer w.SetMinimumPoW(DefaultMinimumPoW) + defer w.SetMaxMessageSize(DefaultMaxMessageSize) + w.Start(nil) + defer w.Stop() + + filter, err := generateFilter(t, true) + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + filter.PoW = DefaultMinimumPoW + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + + filter.Src = nil + + params.KeySym = filter.KeySym + params.Topic = BytesToTopic(filter.Topics[2]) + params.PoW = filter.PoW + params.WorkTime = 10 + params.TTL = 50 + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := msg.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + + _, err = w.Subscribe(filter) + if err != nil { + t.Fatalf("failed subscribe 1 with seed %d: %s.", seed, err) + } + + err = w.Send(env) + if err != nil { + t.Fatalf("Failed sending envelope with PoW %.06f (seed %d): %s", env.PoW(), seed, err) + } + + // wait till received or timeout + var received bool + for j := 0; j < 200; j++ { + time.Sleep(10 * time.Millisecond) + if len(w.Envelopes()) > 0 { + received = true + break + } + } + + if !received { + t.Fatalf("did not receive the sent envelope, seed: %d.", seed) + } + + // check w.messages() + time.Sleep(5 * time.Millisecond) + mail := filter.Retrieve() + if len(mail) == 0 { + t.Fatalf("did not receive message in spite of not setting a public key") + } +} + +func TestSymmetricSendKeyMismatch(t *testing.T) { + InitSingleTest() + + w := New(&DefaultConfig) + defer w.SetMinimumPoW(DefaultMinimumPoW) + defer w.SetMaxMessageSize(DefaultMaxMessageSize) + w.Start(nil) + defer w.Stop() + + filter, err := generateFilter(t, true) + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + filter.PoW = DefaultMinimumPoW + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + + params.KeySym = filter.KeySym + params.Topic = BytesToTopic(filter.Topics[2]) + params.PoW = filter.PoW + params.WorkTime = 10 + params.TTL = 50 + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := msg.Wrap(params) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + + _, err = w.Subscribe(filter) + if err != nil { + t.Fatalf("failed subscribe 1 with seed %d: %s.", seed, err) + } + + err = w.Send(env) + if err != nil { + t.Fatalf("Failed sending envelope with PoW %.06f (seed %d): %s", env.PoW(), seed, err) + } + + // wait till received or timeout + var received bool + for j := 0; j < 200; j++ { + time.Sleep(10 * time.Millisecond) + if len(w.Envelopes()) > 0 { + received = true + break + } + } + + if !received { + t.Fatalf("did not receive the sent envelope, seed: %d.", seed) + } + + // check w.messages() + time.Sleep(5 * time.Millisecond) + mail := filter.Retrieve() + if len(mail) > 0 { + t.Fatalf("received a message when keys weren't matching") + } +}