diff --git a/rpc/server.go b/rpc/server.go index 33564249820c..6511edd6ffd4 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -55,6 +55,7 @@ type Server struct { batchItemLimit int batchResponseLimit int httpBodyLimit int + readLimit int64 } // NewServer creates a new server instance with no registered handlers. @@ -62,8 +63,9 @@ func NewServer() *Server { server := &Server{ idgen: randomIDGenerator(), codecs: make(map[ServerCodec]struct{}), - httpBodyLimit: defaultBodyLimit, denyList: make(map[string]struct{}), + httpBodyLimit: defaultBodyLimit, + readLimit: wsDefaultReadLimit, } server.run.Store(true) // Register the default service providing meta information about the RPC service such @@ -91,6 +93,13 @@ func (s *Server) SetHTTPBodyLimit(limit int) { s.httpBodyLimit = limit } +// SetReadLimits sets the limit for max message size for Websocket requests. +// +// This method should be called before processing any requests via Websocket server. +func (s *Server) SetReadLimits(limit int64) { + s.readLimit = limit +} + // RegisterName creates a service for the given receiver type under the given name. When no // methods on the given receiver match the criteria to be either an RPC method or a // subscription an error is returned. Otherwise a new service is created and added to the diff --git a/rpc/server_test.go b/rpc/server_test.go index 9ee545d81ade..83c36a02f2a0 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -19,8 +19,10 @@ package rpc import ( "bufio" "bytes" + "context" "io" "net" + "net/http/httptest" "os" "path/filepath" "strings" @@ -202,3 +204,115 @@ func TestServerBatchResponseSizeLimit(t *testing.T) { } } } + +func TestServerSetReadLimits(t *testing.T) { + t.Parallel() + + // Test different read limits + testCases := []struct { + name string + readLimit int64 + testSize int + shouldFail bool + }{ + { + name: "small limit with small request - should succeed", + readLimit: 2048, + testSize: 500, // Small request data + shouldFail: false, + }, + { + name: "small limit with large request - should fail", + readLimit: 2048, + testSize: 5000, // Large request data that should exceed limit + shouldFail: true, + }, + { + name: "medium limit with medium request - should succeed", + readLimit: 10240, + testSize: 5000, // Medium request data + shouldFail: false, + }, + { + name: "medium limit with large request - should fail", + readLimit: 10240, + testSize: 20000, // Large request data + shouldFail: true, + }, + { + name: "large limit with large request - should succeed", + readLimit: 50000, + testSize: 20000, // Large request data that should fit + shouldFail: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create server and set read limits + srv := newTestServer() + srv.SetReadLimits(tc.readLimit) + defer srv.Stop() + + // Start HTTP server with WebSocket handler + httpsrv := httptest.NewServer(srv.WebsocketHandler([]string{"*"})) + defer httpsrv.Close() + + wsURL := "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") + + // Connect WebSocket client + client, err := DialOptions(context.Background(), wsURL) + if err != nil { + t.Fatalf("can't dial: %v", err) + } + defer client.Close() + + // Create large request data - this is what will be limited + largeString := strings.Repeat("A", tc.testSize) + + // Send the large string as a parameter in the request + var result echoResult + err = client.Call(&result, "test_echo", largeString, 42, &echoArgs{S: "test"}) + + if tc.shouldFail { + // Expecting an error due to read limit exceeded + if err == nil { + t.Fatalf("expected error for request size %d with limit %d, but got none", tc.testSize, tc.readLimit) + } + // Check if it's the expected message size limit error + if !strings.Contains(err.Error(), "message too big") { + t.Fatalf("expected 'message too big' error, got: %v", err) + } + } else { + // Expecting success + if err != nil { + t.Fatalf("unexpected error for request size %d with limit %d: %v", tc.testSize, tc.readLimit, err) + } + // Verify the response is correct - the echo should return our string + if result.String != largeString { + t.Fatalf("expected echo result to match input") + } + } + }) + } +} + +// Test that SetReadLimits properly updates the server's readerLimit field +func TestServerSetReadLimitsField(t *testing.T) { + server := NewServer() + + // Test initial default value + if server.readLimit != wsDefaultReadLimit { + t.Errorf("expected initial readerLimit to be %d, got %d", wsDefaultReadLimit, server.readLimit) + } + + // Test setting different values + testValues := []int64{1024, 10240, 102400, 1048576} + + for _, expectedLimit := range testValues { + server.SetReadLimits(expectedLimit) + if server.readLimit != expectedLimit { + t.Errorf("expected readerLimit to be %d after SetReadLimits, got %d", expectedLimit, server.readLimit) + } + } +} diff --git a/rpc/websocket.go b/rpc/websocket.go index 9f67caf859f1..42c45a591112 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -60,7 +60,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { log.Debug("WebSocket upgrade failed", "err", err) return } - codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit) + codec := newWebsocketCodec(conn, r.Host, r.Header, s.readLimit) s.ServeCodec(codec, 0) }) }