diff --git a/cmd/geth/main.go b/cmd/geth/main.go index 8d4e6228d..5c97cef4c 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -193,6 +193,7 @@ var ( utils.RPCGlobalGasCapFlag, utils.RPCGlobalEVMTimeoutFlag, utils.RPCGlobalTxFeeCapFlag, + utils.RPCGlobalLogQueryLimit, utils.AllowUnprotectedTxs, utils.MaxBlockRangeFlag, utils.BatchRequestLimit, diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go index 0345b59fa..c43187f71 100644 --- a/cmd/geth/usage.go +++ b/cmd/geth/usage.go @@ -160,6 +160,7 @@ var AppHelpFlagGroups = []flags.FlagGroup{ utils.RPCGlobalGasCapFlag, utils.RPCGlobalEVMTimeoutFlag, utils.RPCGlobalTxFeeCapFlag, + utils.RPCGlobalLogQueryLimit, utils.AllowUnprotectedTxs, utils.JSpathFlag, utils.ExecFlag, diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 297e672f9..98bfae808 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -560,6 +560,11 @@ var ( Usage: "Sets a cap on transaction fee (in ether) that can be sent via the RPC APIs (0 = no cap)", Value: ethconfig.Defaults.RPCTxFeeCap, } + RPCGlobalLogQueryLimit = &cli.IntFlag{ + Name: "rpc.logquerylimit", + Usage: "Maximum number of alternative addresses or topics allowed per search position in eth_getLogs filter criteria (0 = no cap)", + Value: ethconfig.Defaults.LogQueryLimit, + } // Logging and debug settings EthStatsURLFlag = cli.StringFlag{ Name: "ethstats", @@ -1707,6 +1712,9 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *ethconfig.Config) { if ctx.IsSet(CacheLogSizeFlag.Name) { cfg.FilterLogCacheSize = ctx.Int(CacheLogSizeFlag.Name) } + if ctx.IsSet(RPCGlobalLogQueryLimit.Name) { + cfg.LogQueryLimit = ctx.Int(RPCGlobalLogQueryLimit.Name) + } if !ctx.Bool(SnapshotFlag.Name) { // If snap-sync is requested, this flag is also required if cfg.SyncMode == downloader.SnapSync { @@ -1982,7 +1990,8 @@ func RegisterGraphQLService(stack *node.Node, backend ethapi.Backend, filterSyst func RegisterFilterAPI(stack *node.Node, backend ethapi.Backend, ethcfg *ethconfig.Config) *filters.FilterSystem { isLightClient := ethcfg.SyncMode == downloader.LightSync filterSystem := filters.NewFilterSystem(backend, filters.Config{ - LogCacheSize: ethcfg.FilterLogCacheSize, + LogCacheSize: ethcfg.FilterLogCacheSize, + LogQueryLimit: ethcfg.LogQueryLimit, }) stack.RegisterAPIs([]rpc.API{{ Namespace: "eth", diff --git a/core/state/statedb_hooked.go b/core/state/statedb_hooked.go index 5a2407c6c..75d29e8ae 100644 --- a/core/state/statedb_hooked.go +++ b/core/state/statedb_hooked.go @@ -174,7 +174,10 @@ func (s *hookedStateDB) SetCode(address common.Address, code []byte) []byte { if len(prev) != 0 { prevHash = crypto.Keccak256Hash(prev) } - s.hooks.OnCodeChange(address, prevHash, prev, crypto.Keccak256Hash(code), code) + newHash := crypto.Keccak256Hash(code) + if prevHash != newHash { + s.hooks.OnCodeChange(address, prevHash, prev, newHash, code) + } } return prev } diff --git a/core/state/statedb_hooked_test.go b/core/state/statedb_hooked_test.go new file mode 100644 index 000000000..7fd669d8a --- /dev/null +++ b/core/state/statedb_hooked_test.go @@ -0,0 +1,41 @@ +package state + +import ( + "testing" + + "github.com/morph-l2/go-ethereum/common" + "github.com/morph-l2/go-ethereum/core/rawdb" + "github.com/morph-l2/go-ethereum/core/tracing" +) + +func TestSetCodeNoChangeDoesNotHook(t *testing.T) { + statedb, err := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) + if err != nil { + t.Fatal(err) + } + addr := common.HexToAddress("0x1111111111111111111111111111111111111111") + var calls int + hooked := NewHookedState(statedb, &tracing.Hooks{ + OnCodeChange: func(common.Address, common.Hash, []byte, common.Hash, []byte) { + calls++ + }, + }) + + code := []byte{1, 2, 3} + hooked.SetCode(addr, code) + if calls != 1 { + t.Fatalf("first SetCode should hook once, got %d", calls) + } + hooked.SetCode(addr, append([]byte(nil), code...)) + if calls != 1 { + t.Fatalf("unchanged SetCode should not hook, got %d", calls) + } + hooked.SetCode(addr, nil) + if calls != 2 { + t.Fatalf("clearing code should hook once, got %d", calls) + } + hooked.SetCode(addr, nil) + if calls != 2 { + t.Fatalf("repeated empty SetCode should not hook, got %d", calls) + } +} diff --git a/core/vm/evm.go b/core/vm/evm.go index 146029d87..05bccf2a5 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -500,7 +500,9 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64, if err == nil { createDataGas := uint64(len(ret)) * params.CreateDataGas if contract.UseGas(createDataGas, evm.Config.Tracer, tracing.GasChangeCallCodeStorage) { - evm.StateDB.SetCode(address, ret) + if len(ret) > 0 { + evm.StateDB.SetCode(address, ret) + } } else { err = ErrCodeStoreOutOfGas } diff --git a/core/vm/evm_create_test.go b/core/vm/evm_create_test.go new file mode 100644 index 000000000..b903aa5e2 --- /dev/null +++ b/core/vm/evm_create_test.go @@ -0,0 +1,86 @@ +package vm + +import ( + "math/big" + "testing" + + "github.com/morph-l2/go-ethereum/common" + "github.com/morph-l2/go-ethereum/core/rawdb" + statedb "github.com/morph-l2/go-ethereum/core/state" + "github.com/morph-l2/go-ethereum/params" +) + +type setCodeCountingStateDB struct { + StateDB + setCodeCalls int +} + +func (db *setCodeCountingStateDB) SetCode(addr common.Address, code []byte) []byte { + db.setCodeCalls++ + return db.StateDB.SetCode(addr, code) +} + +func newCreateTestEVM(t *testing.T) (*EVM, *setCodeCountingStateDB, common.Address) { + t.Helper() + + base, err := statedb.New(common.Hash{}, statedb.NewDatabase(rawdb.NewMemoryDatabase()), nil) + if err != nil { + t.Fatal(err) + } + caller := common.HexToAddress("0x1111111111111111111111111111111111111111") + base.CreateAccount(caller) + + wrapped := &setCodeCountingStateDB{StateDB: base} + blockCtx := BlockContext{ + CanTransfer: func(StateDB, common.Address, *big.Int) bool { return true }, + Transfer: func(StateDB, common.Address, common.Address, *big.Int) {}, + GetHash: func(uint64) common.Hash { return common.Hash{} }, + BlockNumber: big.NewInt(0), + Time: big.NewInt(0), + GasLimit: 10_000_000, + BaseFee: big.NewInt(params.InitialBaseFee), + } + txCtx := TxContext{ + Origin: caller, + GasPrice: big.NewInt(params.InitialBaseFee), + } + return NewEVM(blockCtx, txCtx, wrapped, params.TestChainConfig, Config{}), wrapped, caller +} + +func TestCreateEmptyReturnSkipsSetCode(t *testing.T) { + evm, statedb, caller := newCreateTestEVM(t) + + ret, _, _, err := evm.Create(AccountRef(caller), nil, 1_000_000, new(big.Int)) + if err != nil { + t.Fatal(err) + } + if len(ret) != 0 { + t.Fatalf("expected empty create return, got %x", ret) + } + if statedb.setCodeCalls != 0 { + t.Fatalf("empty create called SetCode %d times", statedb.setCodeCalls) + } +} + +func TestCreateNonEmptyReturnSetsCode(t *testing.T) { + evm, statedb, caller := newCreateTestEVM(t) + + code := []byte{ + byte(PUSH1), 0x2a, + byte(PUSH1), 0x00, + byte(MSTORE), + byte(PUSH1), 0x20, + byte(PUSH1), 0x00, + byte(RETURN), + } + ret, _, _, err := evm.Create(AccountRef(caller), code, 1_000_000, new(big.Int)) + if err != nil { + t.Fatal(err) + } + if len(ret) != 32 { + t.Fatalf("expected non-empty create return, got %x", ret) + } + if statedb.setCodeCalls != 1 { + t.Fatalf("non-empty create called SetCode %d times", statedb.setCodeCalls) + } +} diff --git a/eth/ethconfig/config.go b/eth/ethconfig/config.go index 7e1f14549..4a87017dc 100644 --- a/eth/ethconfig/config.go +++ b/eth/ethconfig/config.go @@ -83,6 +83,7 @@ var Defaults = Config{ TrieTimeout: 60 * time.Minute, SnapshotCache: 102, FilterLogCacheSize: 32, + LogQueryLimit: 1000, Miner: miner.DefaultConfig, TxPool: core.DefaultTxPoolConfig, RPCGasCap: 50000000, @@ -169,6 +170,9 @@ type Config struct { // This is the number of blocks for which logs will be cached in the filter system. FilterLogCacheSize int + // This is the maximum number of addresses or topics allowed in filter criteria. + LogQueryLimit int + // Mining options Miner miner.Config diff --git a/eth/ethconfig/gen_config.go b/eth/ethconfig/gen_config.go index 0c0e04157..d131bec5d 100644 --- a/eth/ethconfig/gen_config.go +++ b/eth/ethconfig/gen_config.go @@ -49,6 +49,7 @@ func (c Config) MarshalTOML() (interface{}, error) { SnapshotCache int Preimages bool FilterLogCacheSize int + LogQueryLimit int Miner miner.Config Ethash ethash.Config TxPool core.TxPoolConfig @@ -99,6 +100,7 @@ func (c Config) MarshalTOML() (interface{}, error) { enc.SnapshotCache = c.SnapshotCache enc.Preimages = c.Preimages enc.FilterLogCacheSize = c.FilterLogCacheSize + enc.LogQueryLimit = c.LogQueryLimit enc.Miner = c.Miner enc.Ethash = c.Ethash enc.TxPool = c.TxPool @@ -153,6 +155,7 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { SnapshotCache *int Preimages *bool FilterLogCacheSize *int + LogQueryLimit *int Miner *miner.Config Ethash *ethash.Config TxPool *core.TxPoolConfig @@ -268,6 +271,9 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { if dec.FilterLogCacheSize != nil { c.FilterLogCacheSize = *dec.FilterLogCacheSize } + if dec.LogQueryLimit != nil { + c.LogQueryLimit = *dec.LogQueryLimit + } if dec.Miner != nil { c.Miner = *dec.Miner } diff --git a/eth/fetcher/tx_fetcher.go b/eth/fetcher/tx_fetcher.go index efab9af6c..233666a3f 100644 --- a/eth/fetcher/tx_fetcher.go +++ b/eth/fetcher/tx_fetcher.go @@ -674,6 +674,10 @@ func (f *TxFetcher) loop() { if len(f.announced[hash]) == 0 { delete(f.announced, hash) } + delete(f.alternates[hash], drop.peer) + if len(f.alternates[hash]) == 0 { + delete(f.alternates, hash) + } } delete(f.announces, drop.peer) } diff --git a/eth/fetcher/tx_fetcher_test.go b/eth/fetcher/tx_fetcher_test.go index 4d0ff8afe..bdcf9b08e 100644 --- a/eth/fetcher/tx_fetcher_test.go +++ b/eth/fetcher/tx_fetcher_test.go @@ -1121,6 +1121,44 @@ func TestTransactionFetcherDropRescheduling(t *testing.T) { }) } +func TestTransactionFetcherDropAlternates(t *testing.T) { + testTransactionFetcherParallel(t, txFetcherTest{ + init: func() *TxFetcher { + return NewTxFetcher( + func(common.Hash) bool { return false }, + func(txs []*types.Transaction) []error { + return make([]error, len(txs)) + }, + func(string, []common.Hash) error { return nil }, + ) + }, + steps: []interface{}{ + doTxNotify{peer: "A", hashes: []common.Hash{{0x01}}}, + doWait{time: txArriveTimeout, step: true}, + doTxNotify{peer: "B", hashes: []common.Hash{{0x01}}}, + isScheduled{ + tracking: map[string][]common.Hash{ + "A": {{0x01}}, + "B": {{0x01}}, + }, + fetching: map[string][]common.Hash{ + "A": {{0x01}}, + }, + }, + doDrop("B"), + isWaiting(nil), + isScheduled{ + tracking: map[string][]common.Hash{ + "A": {{0x01}}, + }, + fetching: map[string][]common.Hash{ + "A": {{0x01}}, + }, + }, + }, + }) +} + // This test reproduces a crash caught by the fuzzer. The root cause was a // dangling transaction timing out and clashing on readd with a concurrently // announced one. diff --git a/eth/filters/api.go b/eth/filters/api.go index 8bcc845b0..e14c70d09 100644 --- a/eth/filters/api.go +++ b/eth/filters/api.go @@ -34,6 +34,13 @@ import ( "github.com/morph-l2/go-ethereum/rpc" ) +var ( + errExceedMaxTopics = errors.New("exceed max topics") + errExceedLogQueryLimit = errors.New("exceed max addresses or topics per search position") +) + +const maxTopics = 4 + // filter is a helper struct that holds meta information over the filter type // and associated subscription in the event system. type filter struct { @@ -56,6 +63,7 @@ type FilterAPI struct { filters map[rpc.ID]*filter timeout time.Duration maxBlockRange int64 + logQueryLimit int } // NewFilterAPI returns a new FilterAPI instance. @@ -66,6 +74,7 @@ func NewFilterAPI(system *FilterSystem, lightMode bool, maxBlockRange int64) *Fi filters: make(map[rpc.ID]*filter), timeout: system.cfg.Timeout, maxBlockRange: maxBlockRange, + logQueryLimit: system.cfg.LogQueryLimit, } go api.timeoutLoop(system.cfg.Timeout) @@ -118,6 +127,7 @@ func (api *FilterAPI) NewPendingTransactionFilter(fullTx *bool) rpc.ID { api.filtersMu.Unlock() go func() { + defer pendingTxSub.Unsubscribe() for { select { case pTx := <-pendingTxs: @@ -201,6 +211,7 @@ func (api *FilterAPI) NewBlockFilter() rpc.ID { api.filtersMu.Unlock() go func() { + defer headerSub.Unsubscribe() for { select { case h := <-headers: @@ -318,6 +329,7 @@ func (api *FilterAPI) NewFilter(crit FilterCriteria) (rpc.ID, error) { api.filtersMu.Unlock() go func() { + defer logsSub.Unsubscribe() for { select { case l := <-logs: @@ -342,6 +354,9 @@ func (api *FilterAPI) NewFilter(crit FilterCriteria) (rpc.ID, error) { // // https://eth.wiki/json-rpc/API#eth_getlogs func (api *FilterAPI) GetLogs(ctx context.Context, crit FilterCriteria) ([]*types.Log, error) { + if err := validateLogQuery(ethereum.FilterQuery(crit), api.logQueryLimit); err != nil { + return nil, err + } var filter *Filter if crit.BlockHash != nil { // Block filter requested, construct a single-shot filter @@ -367,6 +382,24 @@ func (api *FilterAPI) GetLogs(ctx context.Context, crit FilterCriteria) ([]*type return returnLogs(logs), err } +func validateLogQuery(crit ethereum.FilterQuery, logQueryLimit int) error { + if len(crit.Topics) > maxTopics { + return errExceedMaxTopics + } + if logQueryLimit == 0 { + return nil + } + if len(crit.Addresses) > logQueryLimit { + return errExceedLogQueryLimit + } + for _, topics := range crit.Topics { + if len(topics) > logQueryLimit { + return errExceedLogQueryLimit + } + } + return nil +} + // UninstallFilter removes the filter with the given filter id. // // https://eth.wiki/json-rpc/API#eth_uninstallfilter @@ -560,6 +593,9 @@ func (args *FilterCriteria) UnmarshalJSON(data []byte) error { // topics is an array consisting of strings and/or arrays of strings. // JSON null values are converted to common.Hash{} and ignored by the filter manager. if len(raw.Topics) > 0 { + if len(raw.Topics) > maxTopics { + return errExceedMaxTopics + } args.Topics = make([][]common.Hash, len(raw.Topics)) for i, t := range raw.Topics { switch topic := t.(type) { diff --git a/eth/filters/filter_system.go b/eth/filters/filter_system.go index 511e973ff..c73ce3be3 100644 --- a/eth/filters/filter_system.go +++ b/eth/filters/filter_system.go @@ -41,8 +41,9 @@ import ( // Config represents the configuration of the filter system. type Config struct { - LogCacheSize int // maximum number of cached blocks (default: 32) - Timeout time.Duration // how long filters stay active (default: 5min) + LogCacheSize int // maximum number of cached blocks (default: 32) + Timeout time.Duration // how long filters stay active (default: 5min) + LogQueryLimit int // maximum number of addresses or topics allowed in filter criteria } func (cfg Config) withDefaults() Config { @@ -267,6 +268,9 @@ func (es *EventSystem) subscribe(sub *subscription) *Subscription { // given criteria to the given logs channel. Default value for the from and to // block is "latest". If the fromBlock > toBlock an error is returned. func (es *EventSystem) SubscribeLogs(crit ethereum.FilterQuery, logs chan []*types.Log) (*Subscription, error) { + if err := validateLogQuery(crit, es.sys.cfg.LogQueryLimit); err != nil { + return nil, err + } var from, to rpc.BlockNumber if crit.FromBlock == nil { from = rpc.LatestBlockNumber diff --git a/eth/filters/filter_system_test.go b/eth/filters/filter_system_test.go index 5ff782750..b6a19b63c 100644 --- a/eth/filters/filter_system_test.go +++ b/eth/filters/filter_system_test.go @@ -18,6 +18,7 @@ package filters import ( "context" + "errors" "fmt" "math/big" "math/rand" @@ -355,6 +356,68 @@ func TestPendingTxFilterFullTx(t *testing.T) { } } +func TestPollingFilterErrCleanup(t *testing.T) { + db := rawdb.NewMemoryDatabase() + _, sys := newTestFilterSystem(t, db, Config{}) + api := NewFilterAPI(sys, false, ethconfig.Defaults.MaxBlockRange) + + tests := []struct { + name string + create func(t *testing.T) rpc.ID + }{ + { + name: "pending transactions", + create: func(t *testing.T) rpc.ID { + return api.NewPendingTransactionFilter(nil) + }, + }, + { + name: "blocks", + create: func(t *testing.T) rpc.ID { + return api.NewBlockFilter() + }, + }, + { + name: "logs", + create: func(t *testing.T) rpc.ID { + id, err := api.NewFilter(FilterCriteria{}) + if err != nil { + t.Fatal(err) + } + return id + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id := tt.create(t) + + api.filtersMu.Lock() + f := api.filters[id] + api.filtersMu.Unlock() + if f == nil { + t.Fatalf("filter %s was not installed", id) + } + + f.s.Unsubscribe() + deadline := time.Now().Add(time.Second) + for { + api.filtersMu.Lock() + _, exists := api.filters[id] + api.filtersMu.Unlock() + if !exists { + return + } + if time.Now().After(deadline) { + t.Fatalf("filter %s was not cleaned up", id) + } + time.Sleep(10 * time.Millisecond) + } + }) + } +} + // TestLogFilterCreation test whether a given filter criteria makes sense. // If not it must return an error. func TestLogFilterCreation(t *testing.T) { @@ -445,6 +508,97 @@ func TestInvalidGetLogsRequest(t *testing.T) { } } +func TestExceedLogQueryLimit(t *testing.T) { + db := rawdb.NewMemoryDatabase() + _, sys := newTestFilterSystem(t, db, Config{LogQueryLimit: 1}) + api := NewFilterAPI(sys, false, ethconfig.Defaults.MaxBlockRange) + + addr1 := common.HexToAddress("0x1111111111111111111111111111111111111111") + addr2 := common.HexToAddress("0x2222222222222222222222222222222222222222") + topic1 := common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111") + topic2 := common.HexToHash("0x2222222222222222222222222222222222222222222222222222222222222222") + + for _, tt := range []struct { + name string + crit FilterCriteria + want error + }{ + { + name: "addresses", + crit: FilterCriteria{Addresses: []common.Address{ + addr1, + addr2, + }}, + want: errExceedLogQueryLimit, + }, + { + name: "topic alternatives", + crit: FilterCriteria{Topics: [][]common.Hash{{ + topic1, + topic2, + }}}, + want: errExceedLogQueryLimit, + }, + { + name: "topic positions", + crit: FilterCriteria{Topics: make([][]common.Hash, maxTopics+1)}, + want: errExceedMaxTopics, + }, + } { + t.Run(tt.name+"/getLogs", func(t *testing.T) { + if _, err := api.GetLogs(context.Background(), tt.crit); !errors.Is(err, tt.want) { + t.Fatalf("error mismatch: got %v, want %v", err, tt.want) + } + }) + t.Run(tt.name+"/newFilter", func(t *testing.T) { + if _, err := api.NewFilter(tt.crit); !errors.Is(err, tt.want) { + t.Fatalf("error mismatch: got %v, want %v", err, tt.want) + } + }) + t.Run(tt.name+"/subscribeLogs", func(t *testing.T) { + logs := make(chan []*types.Log) + if _, err := api.events.SubscribeLogs(ethereum.FilterQuery(tt.crit), logs); !errors.Is(err, tt.want) { + t.Fatalf("error mismatch: got %v, want %v", err, tt.want) + } + }) + } +} + +func TestLogQueryLimitZeroDisablesWidthLimit(t *testing.T) { + db := rawdb.NewMemoryDatabase() + _, sys := newTestFilterSystem(t, db, Config{LogQueryLimit: 0}) + api := NewFilterAPI(sys, false, ethconfig.Defaults.MaxBlockRange) + + crit := FilterCriteria{ + Addresses: []common.Address{ + common.HexToAddress("0x1111111111111111111111111111111111111111"), + common.HexToAddress("0x2222222222222222222222222222222222222222"), + }, + Topics: [][]common.Hash{{ + common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111"), + common.HexToHash("0x2222222222222222222222222222222222222222222222222222222222222222"), + }}, + } + + id, err := api.NewFilter(crit) + if err != nil { + t.Fatalf("NewFilter rejected unlimited-width criteria: %v", err) + } + api.UninstallFilter(id) + + logs := make(chan []*types.Log) + sub, err := api.events.SubscribeLogs(ethereum.FilterQuery(crit), logs) + if err != nil { + t.Fatalf("SubscribeLogs rejected unlimited-width criteria: %v", err) + } + sub.Unsubscribe() + + tooManyPositions := FilterCriteria{Topics: make([][]common.Hash, maxTopics+1)} + if _, err := api.NewFilter(tooManyPositions); !errors.Is(err, errExceedMaxTopics) { + t.Fatalf("topic position limit mismatch: got %v, want %v", err, errExceedMaxTopics) + } +} + func TestGetLogsRange(t *testing.T) { var ( db = rawdb.NewMemoryDatabase() diff --git a/eth/protocols/eth/handler_test.go b/eth/protocols/eth/handler_test.go index 3adb40136..2bcab249e 100644 --- a/eth/protocols/eth/handler_test.go +++ b/eth/protocols/eth/handler_test.go @@ -17,10 +17,12 @@ package eth import ( + "errors" "math" "math/big" "math/rand" "testing" + "time" "github.com/morph-l2/go-ethereum/common" "github.com/morph-l2/go-ethereum/consensus/ethash" @@ -107,6 +109,83 @@ func (b *testBackend) Handle(*Peer, Packet) error { panic("data processing tests should be done in the handler package") } +type duplicateTxBackend struct { + handled bool +} + +func (b *duplicateTxBackend) Chain() *core.BlockChain { panic("not implemented") } +func (b *duplicateTxBackend) TxPool() TxPool { panic("not implemented") } +func (b *duplicateTxBackend) AcceptTxs() bool { return true } +func (b *duplicateTxBackend) RunPeer(*Peer, Handler) error { + panic("not implemented") +} +func (b *duplicateTxBackend) PeerInfo(enode.ID) interface{} { panic("not implemented") } +func (b *duplicateTxBackend) Handle(*Peer, Packet) error { + b.handled = true + return nil +} + +type duplicateTxDecoder struct { + packet interface{} +} + +func (d duplicateTxDecoder) Decode(val interface{}) error { + switch out := val.(type) { + case *TransactionsPacket: + *out = d.packet.(TransactionsPacket) + case *PooledTransactionsPacket66: + *out = d.packet.(PooledTransactionsPacket66) + default: + panic("unexpected decode target") + } + return nil +} + +func (d duplicateTxDecoder) Time() time.Time { return time.Now() } + +func newDuplicateTxTestPeer() *Peer { + return &Peer{ + id: "duplicate-tx-test-peer", + version: ETH66, + knownTxs: newKnownCache(maxKnownTxs), + } +} + +func TestHandleTransactionsDuplicate(t *testing.T) { + tx := types.NewTransaction(0, common.Address{1}, big.NewInt(1), 21000, big.NewInt(1), nil) + backend := new(duplicateTxBackend) + peer := newDuplicateTxTestPeer() + + err := handleTransactions(backend, duplicateTxDecoder{ + packet: TransactionsPacket{tx, tx}, + }, peer) + if !errors.Is(err, errDecode) { + t.Fatalf("error mismatch: got %v, want %v", err, errDecode) + } + if backend.handled { + t.Fatal("duplicate transaction packet was forwarded to backend") + } +} + +func TestHandlePooledTransactions66Duplicate(t *testing.T) { + tx := types.NewTransaction(0, common.Address{1}, big.NewInt(1), 21000, big.NewInt(1), nil) + backend := new(duplicateTxBackend) + peer := newDuplicateTxTestPeer() + + err := handlePooledTransactions66(backend, duplicateTxDecoder{ + packet: PooledTransactionsPacket66{ + RequestId: 1, + PooledTransactionsPacket: PooledTransactionsPacket{tx, tx}, + }, + }, peer) + if !errors.Is(err, errDecode) { + t.Fatalf("error mismatch: got %v, want %v", err, errDecode) + } + if backend.handled { + t.Fatal("duplicate pooled transaction packet was forwarded to backend") + } +} + // Tests that block headers can be retrieved from a remote chain based on user queries. func TestGetBlockHeaders66(t *testing.T) { testGetBlockHeaders(t, ETH66) } diff --git a/eth/protocols/eth/handlers.go b/eth/protocols/eth/handlers.go index 92f2918e3..26b7f0640 100644 --- a/eth/protocols/eth/handlers.go +++ b/eth/protocols/eth/handlers.go @@ -420,12 +420,18 @@ func handleTransactions(backend Backend, msg Decoder, peer *Peer) error { if err := msg.Decode(&txs); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } + seen := make(map[common.Hash]struct{}) for i, tx := range txs { // Validate and mark the remote transaction if tx == nil { return fmt.Errorf("%w: transaction %d is nil", errDecode, i) } - peer.markTransaction(tx.Hash()) + hash := tx.Hash() + if _, ok := seen[hash]; ok { + return fmt.Errorf("%w: transaction %d is duplicate", errDecode, i) + } + seen[hash] = struct{}{} + peer.markTransaction(hash) } return backend.Handle(peer, &txs) } @@ -440,12 +446,18 @@ func handlePooledTransactions66(backend Backend, msg Decoder, peer *Peer) error if err := msg.Decode(&txs); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } + seen := make(map[common.Hash]struct{}) for i, tx := range txs.PooledTransactionsPacket { // Validate and mark the remote transaction if tx == nil { return fmt.Errorf("%w: transaction %d is nil", errDecode, i) } - peer.markTransaction(tx.Hash()) + hash := tx.Hash() + if _, ok := seen[hash]; ok { + return fmt.Errorf("%w: transaction %d is duplicate", errDecode, i) + } + seen[hash] = struct{}{} + peer.markTransaction(hash) } requestTracker.Fulfil(peer.id, peer.version, PooledTransactionsMsg, txs.RequestId) diff --git a/graphql/graphql_test.go b/graphql/graphql_test.go index 410b8136f..bbe70ed5a 100644 --- a/graphql/graphql_test.go +++ b/graphql/graphql_test.go @@ -17,6 +17,7 @@ package graphql import ( + "context" "fmt" "io/ioutil" "math/big" @@ -54,11 +55,56 @@ func TestBuildSchema(t *testing.T) { t.Fatalf("could not create new node: %v", err) } // Make sure the schema can be parsed and matched up to the object model. - if err := newHandler(stack, nil, nil, []string{}, []string{}); err != nil { + if _, err := newHandler(stack, nil, nil, []string{}, []string{}); err != nil { t.Errorf("Could not construct GraphQL handler: %v", err) } } +func TestGraphQLMaxDepth(t *testing.T) { + ddir, err := ioutil.TempDir("", "graphql-maxdepth") + if err != nil { + t.Fatalf("failed to create temporary datadir: %v", err) + } + conf := node.DefaultConfig + conf.DataDir = ddir + stack, err := node.New(&conf) + if err != nil { + t.Fatalf("could not create new node: %v", err) + } + h, err := newHandler(stack, nil, nil, []string{}, []string{}) + if err != nil { + t.Fatalf("could not construct GraphQL handler: %v", err) + } + + introspectionQueryWithOfTypes := func(ofTypes int) string { + var b strings.Builder + b.WriteString(`{__type(name:"Block"){fields{type{`) + for i := 0; i < ofTypes; i++ { + b.WriteString("ofType{") + } + b.WriteString("name") + for i := 0; i < ofTypes; i++ { + b.WriteString("}") + } + b.WriteString("}}}}") + return b.String() + } + + // __type -> fields -> type -> ofType... -> name places the leaf field at maxQueryDepth. + res := h.Schema.Exec(context.Background(), introspectionQueryWithOfTypes(maxQueryDepth-4), "", nil) + if len(res.Errors) != 0 { + t.Fatalf("expected query at max depth to succeed, got %v", res.Errors) + } + + res = h.Schema.Exec(context.Background(), introspectionQueryWithOfTypes(maxQueryDepth-3), "", nil) + for _, err := range res.Errors { + if err.Rule == "MaxDepthExceeded" { + return + } + } + t.Fatalf("expected max depth exceeded error, got %v", res.Errors) +} + // Tests that a graphQL request is successfully handled when graphql is enabled on the specified endpoint func TestGraphQLBlockSerialization(t *testing.T) { stack := createNode(t, true, false) diff --git a/graphql/service.go b/graphql/service.go index 9295aee19..6c63f2b90 100644 --- a/graphql/service.go +++ b/graphql/service.go @@ -33,6 +33,9 @@ import ( "github.com/morph-l2/go-ethereum/rpc" ) +// maxQueryDepth limits the maximum field nesting depth allowed in GraphQL queries. +const maxQueryDepth = 20 + type handler struct { Schema *graphql.Schema } @@ -106,17 +109,18 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // New constructs a new GraphQL service instance. func New(stack *node.Node, backend ethapi.Backend, filterSystem *filters.FilterSystem, cors, vhosts []string) error { - return newHandler(stack, backend, filterSystem, cors, vhosts) + _, err := newHandler(stack, backend, filterSystem, cors, vhosts) + return err } // newHandler returns a new `http.Handler` that will answer GraphQL queries. // It additionally exports an interactive query browser on the / endpoint. -func newHandler(stack *node.Node, backend ethapi.Backend, filterSystem *filters.FilterSystem, cors, vhosts []string) error { +func newHandler(stack *node.Node, backend ethapi.Backend, filterSystem *filters.FilterSystem, cors, vhosts []string) (*handler, error) { q := Resolver{backend, filterSystem} - s, err := graphql.ParseSchema(schema, &q) + s, err := graphql.ParseSchema(schema, &q, graphql.MaxDepth(maxQueryDepth)) if err != nil { - return err + return nil, err } h := handler{Schema: s} handler := node.NewHTTPHandlerStack(h, cors, vhosts, nil) @@ -125,5 +129,5 @@ func newHandler(stack *node.Node, backend ethapi.Backend, filterSystem *filters. stack.RegisterHandler("GraphQL", "/graphql", handler) stack.RegisterHandler("GraphQL", "/graphql/", handler) - return nil + return &h, nil } diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 1fcd600ec..8491f0aba 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -783,6 +783,15 @@ type StorageResult struct { // GetProof returns the Merkle-proof for a given account and optionally some storage keys. func (s *PublicBlockChainAPI) GetProof(ctx context.Context, address common.Address, storageKeys []string, blockNrOrHash rpc.BlockNumberOrHash) (*AccountResult, error) { + keys := make([]common.Hash, len(storageKeys)) + for i, hexKey := range storageKeys { + key, err := decodeHash(hexKey) + if err != nil { + return nil, err + } + keys[i] = key + } + state, _, err := s.b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) if state == nil || err != nil { return nil, err @@ -810,10 +819,7 @@ func (s *PublicBlockChainAPI) GetProof(ctx context.Context, address common.Addre // create the proof for the storageKeys for i, hexKey := range storageKeys { - key, err := decodeHash(hexKey) - if err != nil { - return nil, err - } + key := keys[i] if storageTrie != nil { proof, storageError := state.GetStorageProof(address, key) if storageError != nil { @@ -853,13 +859,13 @@ func decodeHash(s string) (common.Hash, error) { if (len(s) & 1) > 0 { s = "0" + s } + if len(s) > 64 { + return common.Hash{}, fmt.Errorf("hex string too long, want at most 32 bytes") + } b, err := hex.DecodeString(s) if err != nil { return common.Hash{}, fmt.Errorf("hex string invalid") } - if len(b) > 32 { - return common.Hash{}, fmt.Errorf("hex string too long, want at most 32 bytes") - } return common.BytesToHash(b), nil } diff --git a/internal/ethapi/api_morph_test.go b/internal/ethapi/api_morph_test.go index e2c5c0e98..a605e2986 100644 --- a/internal/ethapi/api_morph_test.go +++ b/internal/ethapi/api_morph_test.go @@ -4,11 +4,13 @@ import ( "context" "crypto/ecdsa" "math/big" + "strings" "testing" "github.com/morph-l2/go-ethereum/common" "github.com/morph-l2/go-ethereum/common/hexutil" "github.com/morph-l2/go-ethereum/core/rawdb" + "github.com/morph-l2/go-ethereum/core/state" "github.com/morph-l2/go-ethereum/core/types" "github.com/morph-l2/go-ethereum/crypto" "github.com/morph-l2/go-ethereum/ethdb" @@ -32,6 +34,100 @@ func makeTestRef(b byte) common.Reference { return ref } +func TestDecodeHash(t *testing.T) { + valid64 := strings.Repeat("1", 64) + for _, tt := range []struct { + name string + input string + want common.Hash + wantErr string + }{ + { + name: "64 hex", + input: valid64, + want: common.HexToHash("0x" + valid64), + }, + { + name: "65 hex", + input: strings.Repeat("1", 65), + wantErr: "hex string too long, want at most 32 bytes", + }, + { + name: "66 hex", + input: strings.Repeat("1", 66), + wantErr: "hex string too long, want at most 32 bytes", + }, + { + name: "0x plus 65 hex", + input: "0x" + strings.Repeat("1", 65), + wantErr: "hex string too long, want at most 32 bytes", + }, + { + name: "odd legal input", + input: "abc", + want: common.HexToHash("0x0abc"), + }, + { + name: "short invalid hex", + input: "zz", + wantErr: "hex string invalid", + }, + } { + t.Run(tt.name, func(t *testing.T) { + got, err := decodeHash(tt.input) + if tt.wantErr != "" { + if err == nil { + t.Fatalf("expected error %q", tt.wantErr) + } + if err.Error() != tt.wantErr { + t.Fatalf("error mismatch: got %q, want %q", err.Error(), tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.want { + t.Fatalf("hash mismatch: got %s, want %s", got, tt.want) + } + }) + } +} + +type proofStorageKeyDecodeBackend struct { + Backend + stateCalls int +} + +func (m *proofStorageKeyDecodeBackend) StateAndHeaderByNumberOrHash(context.Context, rpc.BlockNumberOrHash) (*state.StateDB, *types.Header, error) { + m.stateCalls++ + return nil, nil, nil +} + +func TestGetProofInvalidStorageKeyBeforeStateAccess(t *testing.T) { + backend := new(proofStorageKeyDecodeBackend) + api := &PublicBlockChainAPI{b: backend} + + _, err := api.GetProof( + context.Background(), + common.Address{}, + []string{ + strings.Repeat("0", 64), + strings.Repeat("1", 65), + }, + rpc.BlockNumberOrHashWithNumber(rpc.LatestBlockNumber), + ) + if err == nil { + t.Fatal("expected storage key decode error") + } + if err.Error() != "hex string too long, want at most 32 bytes" { + t.Fatalf("unexpected error: %v", err) + } + if backend.stateCalls != 0 { + t.Fatalf("state accessed before storage key validation: %d calls", backend.stateCalls) + } +} + func uint64Ptr(v uint64) *hexutil.Uint64 { h := hexutil.Uint64(v) return &h @@ -155,7 +251,7 @@ type mockSetDefaultsBackend struct { header *types.Header } -func (m *mockSetDefaultsBackend) CurrentHeader() *types.Header { return m.header } +func (m *mockSetDefaultsBackend) CurrentHeader() *types.Header { return m.header } func (m *mockSetDefaultsBackend) ChainConfig() *params.ChainConfig { return m.chainConfig } func uint16VersionPtr(v uint8) *hexutil.Uint16 { @@ -303,9 +399,9 @@ func TestSetDefaults_MorphTxVersionHeuristic(t *testing.T) { wantVersion: uint16Ref(types.MorphTxVersion1), }, { - name: "jade fork: no MorphTx fields → not MorphTx (version nil)", - headTime: 1000, - modify: func(args *TransactionArgs) {}, + name: "jade fork: no MorphTx fields → not MorphTx (version nil)", + headTime: 1000, + modify: func(args *TransactionArgs) {}, wantVersion: nil, }, @@ -392,16 +488,16 @@ func TestSetDefaults_MorphTxVersionHeuristic(t *testing.T) { wantErr: true, }, { - name: "jade fork: explicit V1 + FeeTokenID=0 + FeeLimit=nil → ok", - headTime: 1000, + name: "jade fork: explicit V1 + FeeTokenID=0 + FeeLimit=nil → ok", + headTime: 1000, modify: func(args *TransactionArgs) { args.Version = uint16VersionPtr(types.MorphTxVersion1) }, wantVersion: uint16Ref(types.MorphTxVersion1), }, { - name: "jade fork: explicit V1 + FeeTokenID=0 + FeeLimit=0 → ok", - headTime: 1000, + name: "jade fork: explicit V1 + FeeTokenID=0 + FeeLimit=0 → ok", + headTime: 1000, modify: func(args *TransactionArgs) { fid := hexutil.Uint16(0) args.FeeTokenID = &fid diff --git a/rpc/client_opt.go b/rpc/client_opt.go index 5bef08cca..3fa045a9b 100644 --- a/rpc/client_opt.go +++ b/rpc/client_opt.go @@ -34,7 +34,8 @@ type clientConfig struct { httpAuth HTTPAuth // WebSocket options - wsDialer *websocket.Dialer + wsDialer *websocket.Dialer + wsMessageSizeLimit *int64 // wsMessageSizeLimit nil = default, 0 = no limit // RPC handler options idgen func() ID @@ -66,6 +67,14 @@ func WithWebsocketDialer(dialer websocket.Dialer) ClientOption { }) } +// WithWebsocketMessageSizeLimit configures the websocket message size limit used by the RPC +// client. Passing a limit of 0 means no limit. +func WithWebsocketMessageSizeLimit(messageSizeLimit int64) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.wsMessageSizeLimit = &messageSizeLimit + }) +} + // WithHeader configures HTTP headers set by the RPC client. Headers set using this option // will be used for both HTTP and WebSocket connections. func WithHeader(key, value string) ClientOption { diff --git a/rpc/server.go b/rpc/server.go index 02dbec579..5d708bff7 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -55,11 +55,12 @@ type Server struct { compressionLevel int batchItemLimit int batchResponseLimit int + wsReadLimit int64 } // NewServer creates a new server instance with no registered handlers. func NewServer() *Server { - server := &Server{idgen: randomIDGenerator(), codecs: mapset.NewSet(), run: 1} + server := &Server{idgen: randomIDGenerator(), codecs: mapset.NewSet(), run: 1, wsReadLimit: wsDefaultReadLimit} // Register the default service providing meta information about the RPC service such // as the services and methods it offers. rpcService := &RPCService{server} @@ -78,6 +79,13 @@ func (s *Server) SetBatchLimits(itemLimit, maxResponseSize int) { s.batchResponseLimit = maxResponseSize } +// SetWebsocketReadLimit sets the limit for max message size for Websocket requests. +// +// This method should be called before processing any requests via Websocket server. +func (s *Server) SetWebsocketReadLimit(limit int64) { + s.wsReadLimit = limit +} + // RegisterName creates a service for the given receiver type under the given name. When no // methods on the given receiver match the criteria to be either a RPC method or a // subscription an error is returned. Otherwise a new service is created and added to the diff --git a/rpc/server_test.go b/rpc/server_test.go index 433aaebb0..40741700e 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -19,9 +19,11 @@ package rpc import ( "bufio" "bytes" + "context" "io" "io/ioutil" "net" + "net/http/httptest" "path/filepath" "strings" "testing" @@ -112,6 +114,89 @@ func runTestScript(t *testing.T, file string) { } } +func TestServerWebsocketReadLimit(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + readLimit int64 + testSize int + shouldFail bool + }{ + { + name: "limit with small request", + readLimit: 4096, + testSize: 256, + shouldFail: false, + }, + { + name: "limit with large request", + readLimit: 256, + testSize: 1024, + shouldFail: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + srv := newTestServer() + srv.SetWebsocketReadLimit(tc.readLimit) + defer srv.Stop() + + httpsrv := httptest.NewServer(srv.WebsocketHandler([]string{"*"})) + defer httpsrv.Close() + + wsURL := "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") + client, err := DialOptions(context.Background(), wsURL) + if err != nil { + t.Fatalf("can't dial: %v", err) + } + defer client.Close() + + var result echoResult + err = client.Call(&result, "test_echo", strings.Repeat("A", tc.testSize), 42, &echoArgs{S: "test"}) + if tc.shouldFail && err == nil { + t.Fatalf("expected error for request size %d with limit %d", tc.testSize, tc.readLimit) + } + if !tc.shouldFail && err != nil { + t.Fatalf("unexpected error for request size %d with limit %d: %v", tc.testSize, tc.readLimit, err) + } + }) + } +} + +func TestServerWebsocketDefaultReadLimit(t *testing.T) { + t.Parallel() + + srv := newTestServer() + if err := srv.RegisterName("sink", new(wsReadLimitSinkService)); err != nil { + t.Fatal(err) + } + defer srv.Stop() + + httpsrv := httptest.NewServer(srv.WebsocketHandler([]string{"*"})) + defer httpsrv.Close() + + wsURL := "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") + client, err := DialOptions(context.Background(), wsURL) + if err != nil { + t.Fatalf("can't dial: %v", err) + } + defer client.Close() + + var result bool + err = client.Call(&result, "sink_accept", strings.Repeat("A", wsDefaultReadLimit+1024)) + if err == nil { + t.Fatalf("expected default websocket read limit to reject request larger than %d bytes", wsDefaultReadLimit) + } +} + +type wsReadLimitSinkService struct{} + +func (wsReadLimitSinkService) Accept(_ string) bool { + return true +} + // This test checks that responses are delivered for very short-lived connections that // only carry a single request. func TestServerShortLivedConn(t *testing.T) { diff --git a/rpc/websocket.go b/rpc/websocket.go index 548728754..2c189cc6c 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -40,7 +40,7 @@ const ( wsPingInterval = 60 * time.Second wsPingWriteTimeout = 5 * time.Second wsPongTimeout = 30 * time.Second - wsMessageSizeLimit = 10 * 15 * 1024 * 1024 + wsDefaultReadLimit = 32 * 1024 * 1024 ) var wsBufferPool = new(sync.Pool) @@ -67,7 +67,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { if enableCompression { _ = conn.SetCompressionLevel(s.compressionLevel) } - codec := newWebsocketCodec(conn, r.Host, r.Header) + codec := newWebsocketCodec(conn, r.Host, r.Header, s.wsReadLimit) s.ServeCodec(codec, 0) }) } @@ -253,7 +253,11 @@ func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, er } return nil, hErr } - return newWebsocketCodec(conn, dialURL, header), nil + messageSizeLimit := int64(wsDefaultReadLimit) + if cfg.wsMessageSizeLimit != nil && *cfg.wsMessageSizeLimit >= 0 { + messageSizeLimit = *cfg.wsMessageSizeLimit + } + return newWebsocketCodec(conn, dialURL, header, messageSizeLimit), nil } return connect, nil } @@ -284,8 +288,8 @@ type websocketCodec struct { pingReset chan struct{} } -func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec { - conn.SetReadLimit(wsMessageSizeLimit) +func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header, readLimit int64) ServerCodec { + conn.SetReadLimit(readLimit) conn.SetPongHandler(func(appData string) error { conn.SetReadDeadline(time.Time{}) return nil diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index bfd7c7909..ff94a6eb9 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -117,6 +117,60 @@ func TestWebsocketLargeCall(t *testing.T) { } } +// This test checks whether the websocket message size limit option is obeyed. +func TestWebsocketLargeRead(t *testing.T) { + t.Parallel() + + var ( + srv = newTestServer() + httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) + wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") + buffer = 64 + ) + defer srv.Stop() + defer httpsrv.Close() + if err := srv.RegisterName("repeat", new(repeatService)); err != nil { + t.Fatal(err) + } + + for _, tt := range []struct { + size int + limit int + err bool + }{ + {200, 200, false}, + {2048, 1024, true}, + {wsDefaultReadLimit + buffer, 0, false}, + } { + t.Run("", func(t *testing.T) { + limit := tt.limit + if limit != 0 { + limit += buffer + } + client, err := DialOptions(context.Background(), wsURL, WithWebsocketMessageSizeLimit(int64(limit))) + if err != nil { + t.Fatalf("failed to dial test server: %v", err) + } + defer client.Close() + + var res string + err = client.Call(&res, "repeat_repeat", "A", tt.size) + if tt.err && err == nil { + t.Fatalf("expected error, got none") + } + if !tt.err && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +type repeatService struct{} + +func (repeatService) Repeat(str string, count int) string { + return strings.Repeat(str, count) +} + func TestWebsocketPeerInfo(t *testing.T) { var ( s = newTestServer() @@ -210,7 +264,7 @@ func TestClientWebsocketLargeMessage(t *testing.T) { defer srv.Stop() defer httpsrv.Close() - respLength := wsMessageSizeLimit - 50 + respLength := wsDefaultReadLimit - 50 srv.RegisterName("test", largeRespService{respLength}) c, err := DialWebsocket(context.Background(), wsURL, "")