diff --git a/rpc/client.go b/rpc/client.go index 6cae87be4c51..bc0375c5219a 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -32,6 +32,7 @@ import ( ) var ( + ErrBadResult = errors.New("bad result in JSON-RPC response") ErrClientQuit = errors.New("client is closed") ErrNoResult = errors.New("no result in JSON-RPC response") ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow") @@ -316,7 +317,10 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str case len(resp.Result) == 0: return ErrNoResult default: - return json.Unmarshal(resp.Result, &result) + if result == nil { + return nil + } + return json.Unmarshal(resp.Result, result) } } diff --git a/rpc/client_test.go b/rpc/client_test.go index 56bc74b5d083..ea8e8b70a470 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -18,6 +18,8 @@ package rpc import ( "context" + "encoding/json" + "errors" "fmt" "math/rand" "net" @@ -35,6 +37,73 @@ import ( "github.com/davecgh/go-spew/spew" ) +// This test checks calling a method that returns 'null'. +func TestClientNullResponse(t *testing.T) { + server := newTestServer() + defer server.Stop() + + client := DialInProc(server) + defer client.Close() + + var result json.RawMessage + if err := client.Call(&result, "test_null"); err != nil { + t.Fatal(err) + } + if result == nil { + t.Fatal("Expected non-nil result") + } + if !reflect.DeepEqual(result, json.RawMessage("null")) { + t.Errorf("Expected null, got %s", result) + } +} + +func TestClientBatchRequest_len(t *testing.T) { + b, err := json.Marshal([]jsonrpcMessage{ + {Version: "2.0", ID: json.RawMessage("1"), Method: "foo", Result: json.RawMessage(`"0x1"`)}, + {Version: "2.0", ID: json.RawMessage("2"), Method: "bar", Result: json.RawMessage(`"0x2"`)}, + }) + if err != nil { + t.Fatal("failed to encode jsonrpc message:", err) + } + s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + _, err := rw.Write(b) + if err != nil { + t.Error("failed to write response:", err) + } + })) + t.Cleanup(s.Close) + + client, err := Dial(s.URL) + if err != nil { + t.Fatal("failed to dial test server:", err) + } + defer client.Close() + + t.Run("too-few", func(t *testing.T) { + batch := []BatchElem{ + {Method: "foo"}, + {Method: "bar"}, + {Method: "baz"}, + } + ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) + defer cancelFn() + if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) { + t.Errorf("expected %q but got: %v", ErrBadResult, err) + } + }) + + t.Run("too-many", func(t *testing.T) { + batch := []BatchElem{ + {Method: "foo"}, + } + ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) + defer cancelFn() + if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) { + t.Errorf("expected %q but got: %v", ErrBadResult, err) + } + }) +} + func TestClientRequest(t *testing.T) { server := newTestServer() defer server.Stop() diff --git a/rpc/http.go b/rpc/http.go index 68dd4be6c9b5..e43597a992ac 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -170,6 +170,9 @@ func (c *Client) sendBatchHTTP(ctx context.Context, op *requestOp, msgs []*jsonr if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil { return err } + if len(respmsgs) != len(msgs) { + return fmt.Errorf("batch has %d requests but response has %d: %w", len(msgs), len(respmsgs), ErrBadResult) + } for i := 0; i < len(respmsgs); i++ { op.resp <- &respmsgs[i] } diff --git a/rpc/server_test.go b/rpc/server_test.go index 2958312c8a15..1af285bc84da 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -45,7 +45,7 @@ func TestServerRegisterName(t *testing.T) { t.Fatalf("Expected service calc to be registered") } - wantCallbacks := 9 + wantCallbacks := 10 if len(svc.callbacks) != wantCallbacks { t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks)) } diff --git a/rpc/testservice_test.go b/rpc/testservice_test.go index 62afc1df44f4..22bd64d2f639 100644 --- a/rpc/testservice_test.go +++ b/rpc/testservice_test.go @@ -84,6 +84,10 @@ func (s *testService) Sleep(ctx context.Context, duration time.Duration) { time.Sleep(duration) } +func (s *testService) Null() any { + return nil +} + func (s *testService) Block(ctx context.Context) error { <-ctx.Done() return errors.New("context canceled in testservice_block")