Skip to content

Commit

Permalink
Support positional argument in wrapped handlers (#59)
Browse files Browse the repository at this point in the history
Add a new top-level function handler.Positional. Like Check, it accepts a
function to be wrapped as a jrpc2.Handler. Unlike Check, this function allows
positional arguments, which it implements by constructing a wrapper that takes
a synthetic struct type as its argument, and redirects the fields of the struct
to the positional arguments of the original function when the wrapper is called.
  • Loading branch information
creachadair authored Oct 24, 2021
1 parent 884e7d1 commit 98d5ca4
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 6 deletions.
24 changes: 22 additions & 2 deletions handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package handler

import (
"bytes"
"context"
"encoding/json"
"errors"
Expand Down Expand Up @@ -92,6 +93,8 @@ var (
errType = reflect.TypeOf((*error)(nil)).Elem() // type error
reqType = reflect.TypeOf((*jrpc2.Request)(nil)) // type *jrpc2.Request

strictType = reflect.TypeOf((*interface{ DisallowUnknownFields() })(nil)).Elem()

errNoParameters = &jrpc2.Error{Code: code.InvalidParams, Message: "no parameters accepted"}
)

Expand All @@ -102,6 +105,7 @@ type FuncInfo struct {
IsVariadic bool // true if the function is variadic on its argument
Result reflect.Type // the non-error result type, or nil
ReportsError bool // true if the function reports an error
strictFields bool // enforce strict field checking

fn interface{} // the original function value
}
Expand Down Expand Up @@ -136,6 +140,12 @@ func (fi *FuncInfo) Wrap() Func {
return Func(f)
}

// If strict field checking is desired, ensure arguments are wrapped.
wrapArg := func(v reflect.Value) interface{} { return v.Interface() }
if fi.strictFields && !fi.Argument.Implements(strictType) {
wrapArg = func(v reflect.Value) interface{} { return &strict{v.Interface()} }
}

// Construct a function to unpack the parameters from the request message,
// based on the signature of the user's callback.
var newInput func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error)
Expand All @@ -160,7 +170,7 @@ func (fi *FuncInfo) Wrap() Func {
// Case 3a: The function wants a pointer to its argument value.
newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) {
in := reflect.New(fi.Argument.Elem())
if err := req.UnmarshalParams(in.Interface()); err != nil {
if err := req.UnmarshalParams(wrapArg(in)); err != nil {
return nil, jrpc2.Errorf(code.InvalidParams, "invalid parameters: %v", err)
}
return []reflect.Value{ctx, in}, nil
Expand All @@ -169,7 +179,7 @@ func (fi *FuncInfo) Wrap() Func {
// Case 3b: The function wants a bare argument value.
newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) {
in := reflect.New(fi.Argument) // we still need a pointer to unmarshal
if err := req.UnmarshalParams(in.Interface()); err != nil {
if err := req.UnmarshalParams(wrapArg(in)); err != nil {
return nil, jrpc2.Errorf(code.InvalidParams, "invalid parameters: %v", err)
}
// Indirect the pointer back off for the callee.
Expand Down Expand Up @@ -370,3 +380,13 @@ func filterJSONError(tag, want string, err error) error {
}
return err
}

// strict is a wrapper for an arbitrary value that enforces strict field
// checking when unmarshaling from JSON.
type strict struct{ v interface{} }

func (s *strict) UnmarshalJSON(data []byte) error {
dec := json.NewDecoder(bytes.NewReader(data))
dec.DisallowUnknownFields()
return dec.Decode(s.v)
}
123 changes: 119 additions & 4 deletions handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ type argStruct struct {
B int `json:"bravo"`
}

// Verify that the New function correctly handles the various type signatures
// Verify that the CHeck function correctly handles the various type signatures
// it's advertised to support, and not others.
func TestNew(t *testing.T) {
func TestCheck(t *testing.T) {
tests := []struct {
v interface{}
bad bool
Expand Down Expand Up @@ -66,9 +66,52 @@ func TestNew(t *testing.T) {
for _, test := range tests {
got, err := handler.Check(test.v)
if !test.bad && err != nil {
t.Errorf("newHandler(%T): unexpected error: %v", test.v, err)
t.Errorf("Check(%T): unexpected error: %v", test.v, err)
} else if test.bad && err == nil {
t.Errorf("newHandler(%T): got %+v, want error", test.v, got)
t.Errorf("Check(%T): got %+v, want error", test.v, got)
}
}
}

// Verify that the Positional function correctly handles its cases.
func TestPositional(t *testing.T) {
tests := []struct {
v interface{}
n []string
bad bool
}{
{v: nil, bad: true}, // nil value
{v: "not a function", bad: true}, // not a function

// Things that should work.
{v: func(context.Context) error { return nil }},
{v: func(context.Context) int { return 1 }},
{v: func(context.Context, bool) bool { return false },
n: []string{"isTrue"}},
{v: func(context.Context, int, int) int { return 0 },
n: []string{"a", "b"}},
{v: func(context.Context, string, int, []float64) int { return 0 },
n: []string{"a", "b", "c"}},

// Things that should not work.
{v: func() error { return nil }, bad: true}, // no parameters
{v: func(int) int { return 0 }, bad: true}, // first argument not context
{v: func(context.Context, string) error { return nil },
n: nil, bad: true}, // not enough names
{v: func(context.Context, string, string, string) error { return nil },
n: []string{"x", "y"}, bad: true}, // too many names
{v: func(context.Context, string, ...float64) int { return 0 },
n: []string{"goHome", "youAreDrunk"}, bad: true}, // variadic

// N.B. Other cases are covered by TestCheck. The cases here are only
// those that Positional checks for explicitly.
}
for _, test := range tests {
got, err := handler.Positional(test.v, test.n...)
if !test.bad && err != nil {
t.Errorf("Positional(%T, %q): unexpected error: %v", test.v, test.n, err)
} else if test.bad && err == nil {
t.Errorf("Positional(%T, %q): got %+v, want error", test.v, test.n, got)
}
}
}
Expand Down Expand Up @@ -102,6 +145,48 @@ func TestNew_pointerRegression(t *testing.T) {
}
}

// Verify that positional arguments are decoded properly.
func TestPositional_decode(t *testing.T) {
fi, err := handler.Positional(func(ctx context.Context, a, b int) int {
return a + b
}, "first", "second")
if err != nil {
t.Fatalf("Positional: unexpected error: %v", err)
}
call := fi.Wrap()
tests := []struct {
input string
want int
bad bool
}{
{`{"jsonrpc":"2.0","id":1,"method":"add","params":{"first":5,"second":3}}`, 8, false},
{`{"jsonrpc":"2.0","id":2,"method":"add","params":{"first":5}}`, 5, false},
{`{"jsonrpc":"2.0","id":3,"method":"add","params":{"second":3}}`, 3, false},
{`{"jsonrpc":"2.0","id":4,"method":"add","params":{}}`, 0, false},
{`{"jsonrpc":"2.0","id":5,"method":"add","params":null}`, 0, false},
{`{"jsonrpc":"2.0","id":6,"method":"add"}`, 0, false},

{`{"jsonrpc":"2.0","id":6,"method":"add","params":["wrong", "type"]}`, 0, true},
{`{"jsonrpc":"2.0","id":6,"method":"add","params":{"unknown":"field"}}`, 0, true},
}
for _, test := range tests {
req, err := jrpc2.ParseRequests([]byte(test.input))
if err != nil {
t.Fatalf("ParseRequests %#q: unexpected error: %v", test.input, err)
}
got, err := call(context.Background(), req[0])
if !test.bad {
if err != nil {
t.Errorf("Call %#q: unexpected error: %v", test.input, err)
} else if z := got.(int); z != test.want {
t.Errorf("Call %#q: got %d, want %d", test.input, z, test.want)
}
} else if test.bad && err == nil {
t.Errorf("Call %#q: got %v, want error", test.input, got)
}
}
}

func ExampleCheck() {
fi, err := handler.Check(func(_ context.Context, ss []string) int { return len(ss) })
if err != nil {
Expand Down Expand Up @@ -339,3 +424,33 @@ func ExampleObj_unmarshal() {
// Output:
// uid=501, name="P. T. Barnum"
}

func ExamplePositional() {
fn := func(ctx context.Context, name string, age int, accurate bool) error {
fmt.Printf("%s is %d years old (fact check: %v)\n", name, age, accurate)
return nil
}
fi, err := handler.Positional(fn, "name", "age", "accurate")
if err != nil {
log.Fatalf("Positional: %v", err)
}
req, err := jrpc2.ParseRequests([]byte(`{
"jsonrpc": "2.0",
"id": 1,
"method": "foo",
"params": {
"name": "Dennis",
"age": 37,
"accurate": true
}
}`))
if err != nil {
log.Fatalf("Parse: %v", err)
}
call := fi.Wrap()
if _, err := call(context.Background(), req[0]); err != nil {
log.Fatalf("Call: %v", err)
}
// Output:
// Dennis is 37 years old (fact check: true)
}
135 changes: 135 additions & 0 deletions handler/positional.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package handler

import (
"errors"
"fmt"
"reflect"
)

// Positional checks whether fn can serve as a jrpc2.Handler. The concrete
// value of fn must be a function with one of the following type signature
// schemes:
//
// func(context.Context, X1, x2, ..., Xn) (Y, error)
// func(context.Context, X1, x2, ..., Xn) Y
// func(context.Context, X1, x2, ..., Xn) error
//
// For JSON-marshalable types X_i and Y. If fn does not have one of these
// forms, Positional reports an error. The given names must match the number of
// non-context arguments exactly. Variadic functions are not supported.
//
// In contrast to Check, this function allows any number of arguments, but the
// caller must provide names for them. Positional creates an anonymous struct
// type whose fields correspond to the non-context arguments of fn. The names
// are used as the JSON field keys for the corresponding parameters.
//
// When converted into a handler.Func, the wrapped function accepts a JSON
// object with the field keys named. For example, given:
//
// func add(ctx context.Context, x, y int) int { return x + y }
//
// fi, err := handler.Positional(add, "first", "second")
// // ...
// call := fi.Wrap()
//
// the resulting JSON-RPC handler accepts a parameter object like:
//
// {"first": 17, "second": 23}
//
// where "first" is mapped to argument x and "second" to argument y. Unknown
// field keys generate an error.
func Positional(fn interface{}, names ...string) (*FuncInfo, error) {
if fn == nil {
return nil, errors.New("nil function")
}

fv := reflect.ValueOf(fn)
if fv.Kind() != reflect.Func {
return nil, errors.New("not a function")
}
ft := fv.Type()
if np := ft.NumIn(); np == 0 {
return nil, errors.New("wrong number of parameters")
} else if ft.In(0) != ctxType {
return nil, errors.New("first parameter is not context.Context")
} else if np == 1 {
// If the context is the only argument, there is nothing to do.
return Check(fn)
} else if ft.IsVariadic() {
return nil, errors.New("variadic functions are not supported")
}

// Reaching here, we have at least one non-context argument.
atype, err := makeArgType(ft, names)
if err != nil {
return nil, err
}
fi, err := Check(makeCaller(ft, fv, atype))
if err == nil {
fi.strictFields = true
}
return fi, err
}

// makeArgType creates a struct type whose fields match the parameters of t,
// with JSON struct tags corresponding to the given names.
//
// Preconditions: t is a function with len(names)+1 arguments.
func makeArgType(t reflect.Type, names []string) (reflect.Type, error) {
if t.NumIn()-1 != len(names) {
return nil, fmt.Errorf("got %d names for %d inputs", len(names), t.NumIn()-1)
}

// TODO(creachadair): I wanted to implement the strictFielder interface on
// the generated struct instead of having extra magic in the wrapper.
// However, it is not now possible to add methods to a type constructed by
// reflection.
//
// Embedding an anonymous field that exposes the method doesn't work for
// JSON unmarshaling: The base struct will have the method, but its pointer
// will not, probably related to https://github.com/golang/go/issues/15924.
// JSON unmarshaling requires a pointer to its argument.
//
// For now, I worked around this by adding a hook into the wrapper compiler.

var fields []reflect.StructField
for i, name := range names {
tag := `json:"-"`
if name != "" && name != "-" {
tag = fmt.Sprintf(`json:"%s,omitempty"`, name)
}
fields = append(fields, reflect.StructField{
Name: fmt.Sprintf("P_%d", i+1),
Type: t.In(i + 1),
Tag: reflect.StructTag(tag),
})
}
return reflect.StructOf(fields), nil
}

// makeCaller creates a wrapper function that takes a context and an atype as
// arguments, and calls fv with the context and the struct fields unpacked into
// positional arguments.
//
// Preconditions: fv is a function and atype is its argument struct.
func makeCaller(ft reflect.Type, fv reflect.Value, atype reflect.Type) interface{} {
atypes := []reflect.Type{ctxType, atype}

otypes := make([]reflect.Type, ft.NumOut())
for i := 0; i < ft.NumOut(); i++ {
otypes[i] = ft.Out(i)
}

wtype := reflect.FuncOf(atypes, otypes, false)
wrap := reflect.MakeFunc(wtype, func(args []reflect.Value) []reflect.Value {
cargs := []reflect.Value{args[0]} // ctx

// Unpack the struct fields into positional arguments.
st := args[1]
for i := 0; i < st.NumField(); i++ {
cargs = append(cargs, st.Field(i))
}
return fv.Call(cargs)
})
return wrap.Interface()
}

0 comments on commit 98d5ca4

Please sign in to comment.