Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion rpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,17 @@ type Server struct {
batchItemLimit int
batchResponseLimit int
httpBodyLimit int
readLimit int64
}

// NewServer creates a new server instance with no registered handlers.
func NewServer() *Server {
server := &Server{
idgen: randomIDGenerator(),
codecs: make(map[ServerCodec]struct{}),
httpBodyLimit: defaultBodyLimit,
denyList: make(map[string]struct{}),
httpBodyLimit: defaultBodyLimit,
readLimit: wsDefaultReadLimit,
}
server.run.Store(true)
// Register the default service providing meta information about the RPC service such
Expand Down Expand Up @@ -91,6 +93,13 @@ func (s *Server) SetHTTPBodyLimit(limit int) {
s.httpBodyLimit = limit
}

// SetReadLimits sets the limit for max message size for Websocket requests.
//
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: empty white line

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was just following the other conventions

// This method should be called before processing any requests via Websocket server.
func (s *Server) SetReadLimits(limit int64) {
s.readLimit = limit
}

// RegisterName creates a service for the given receiver type under the given name. When no
// methods on the given receiver match the criteria to be either an RPC method or a
// subscription an error is returned. Otherwise a new service is created and added to the
Expand Down
114 changes: 114 additions & 0 deletions rpc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ package rpc
import (
"bufio"
"bytes"
"context"
"io"
"net"
"net/http/httptest"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -202,3 +204,115 @@ func TestServerBatchResponseSizeLimit(t *testing.T) {
}
}
}

func TestServerSetReadLimits(t *testing.T) {
t.Parallel()

// Test different read limits
testCases := []struct {
name string
readLimit int64
testSize int
shouldFail bool
}{
{
name: "small limit with small request - should succeed",
readLimit: 2048,
testSize: 500, // Small request data
shouldFail: false,
},
{
name: "small limit with large request - should fail",
readLimit: 2048,
testSize: 5000, // Large request data that should exceed limit
shouldFail: true,
},
{
name: "medium limit with medium request - should succeed",
readLimit: 10240,
testSize: 5000, // Medium request data
shouldFail: false,
},
{
name: "medium limit with large request - should fail",
readLimit: 10240,
testSize: 20000, // Large request data
shouldFail: true,
},
{
name: "large limit with large request - should succeed",
readLimit: 50000,
testSize: 20000, // Large request data that should fit
shouldFail: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create server and set read limits
srv := newTestServer()
srv.SetReadLimits(tc.readLimit)
defer srv.Stop()

// Start HTTP server with WebSocket handler
httpsrv := httptest.NewServer(srv.WebsocketHandler([]string{"*"}))
defer httpsrv.Close()

wsURL := "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")

// Connect WebSocket client
client, err := DialOptions(context.Background(), wsURL)
if err != nil {
t.Fatalf("can't dial: %v", err)
}
defer client.Close()

// Create large request data - this is what will be limited
largeString := strings.Repeat("A", tc.testSize)

// Send the large string as a parameter in the request
var result echoResult
err = client.Call(&result, "test_echo", largeString, 42, &echoArgs{S: "test"})

if tc.shouldFail {
// Expecting an error due to read limit exceeded
if err == nil {
t.Fatalf("expected error for request size %d with limit %d, but got none", tc.testSize, tc.readLimit)
}
// Check if it's the expected message size limit error
if !strings.Contains(err.Error(), "message too big") {
t.Fatalf("expected 'message too big' error, got: %v", err)
}
} else {
// Expecting success
if err != nil {
t.Fatalf("unexpected error for request size %d with limit %d: %v", tc.testSize, tc.readLimit, err)
}
// Verify the response is correct - the echo should return our string
if result.String != largeString {
t.Fatalf("expected echo result to match input")
}
}
})
}
}

// Test that SetReadLimits properly updates the server's readerLimit field
func TestServerSetReadLimitsField(t *testing.T) {
server := NewServer()

// Test initial default value
if server.readLimit != wsDefaultReadLimit {
t.Errorf("expected initial readerLimit to be %d, got %d", wsDefaultReadLimit, server.readLimit)
}

// Test setting different values
testValues := []int64{1024, 10240, 102400, 1048576}

for _, expectedLimit := range testValues {
server.SetReadLimits(expectedLimit)
if server.readLimit != expectedLimit {
t.Errorf("expected readerLimit to be %d after SetReadLimits, got %d", expectedLimit, server.readLimit)
}
}
}
2 changes: 1 addition & 1 deletion rpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
log.Debug("WebSocket upgrade failed", "err", err)
return
}
codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit)
codec := newWebsocketCodec(conn, r.Host, r.Header, s.readLimit)
s.ServeCodec(codec, 0)
})
}
Expand Down