diff --git a/openapi3/schema.go b/openapi3/schema.go index 45350eced..214f4f635 100644 --- a/openapi3/schema.go +++ b/openapi3/schema.go @@ -1358,6 +1358,19 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value return schema.expectedType(settings, TypeObject) } + if settings.asreq || settings.asrep { + for propName, propSchema := range schema.Properties { + if value[propName] == nil { + if dlft := propSchema.Value.Default; dlft != nil { + value[propName] = dlft + if f := settings.defaultsSet; f != nil { + settings.onceSettingDefaults.Do(f) + } + } + } + } + } + var me MultiError // "properties" diff --git a/openapi3/schema_validation_settings.go b/openapi3/schema_validation_settings.go index 71db5f237..cb4c142a4 100644 --- a/openapi3/schema_validation_settings.go +++ b/openapi3/schema_validation_settings.go @@ -1,5 +1,9 @@ package openapi3 +import ( + "sync" +) + // SchemaValidationOption describes options a user has when validating request / response bodies. type SchemaValidationOption func(*schemaValidationSettings) @@ -7,6 +11,9 @@ type schemaValidationSettings struct { failfast bool multiError bool asreq, asrep bool // exclusive (XOR) fields + + onceSettingDefaults sync.Once + defaultsSet func() } // FailFast returns schema validation errors quicker. @@ -25,6 +32,11 @@ func VisitAsResponse() SchemaValidationOption { return func(s *schemaValidationSettings) { s.asreq, s.asrep = false, true } } +// DefaultsSet executes the given callback (once) IFF schema validation set default values. +func DefaultsSet(f func()) SchemaValidationOption { + return func(s *schemaValidationSettings) { s.defaultsSet = f } +} + func newSchemaValidationSettings(opts ...SchemaValidationOption) *schemaValidationSettings { settings := &schemaValidationSettings{} for _, opt := range opts { diff --git a/openapi3filter/req_resp_decoder.go b/openapi3filter/req_resp_decoder.go index 0408d8da3..79706437c 100644 --- a/openapi3filter/req_resp_decoder.go +++ b/openapi3filter/req_resp_decoder.go @@ -868,7 +868,11 @@ const prefixUnsupportedCT = "unsupported content type" // decodeBody returns a decoded body. // The function returns ParseError when a body is invalid. -func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn) (interface{}, error) { +func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn) ( + string, + interface{}, + error, +) { contentType := header.Get(headerCT) if contentType == "" { if _, ok := body.(*multipart.Part); ok { @@ -878,16 +882,16 @@ func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, mediaType := parseMediaType(contentType) decoder, ok := bodyDecoders[mediaType] if !ok { - return nil, &ParseError{ + return "", nil, &ParseError{ Kind: KindUnsupportedFormat, Reason: fmt.Sprintf("%s %q", prefixUnsupportedCT, mediaType), } } value, err := decoder(body, header, schema, encFn) if err != nil { - return nil, err + return "", nil, err } - return value, nil + return mediaType, value, nil } func init() { @@ -1036,7 +1040,7 @@ func multipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.S } var value interface{} - if value, err = decodeBody(part, http.Header(part.Header), valueSchema, subEncFn); err != nil { + if _, value, err = decodeBody(part, http.Header(part.Header), valueSchema, subEncFn); err != nil { if v, ok := err.(*ParseError); ok { return nil, &ParseError{path: []interface{}{name}, Cause: v} } diff --git a/openapi3filter/req_resp_decoder_test.go b/openapi3filter/req_resp_decoder_test.go index de93547b5..c733bd028 100644 --- a/openapi3filter/req_resp_decoder_test.go +++ b/openapi3filter/req_resp_decoder_test.go @@ -1280,7 +1280,7 @@ func TestDecodeBody(t *testing.T) { } return tc.encoding[name] } - got, err := decodeBody(tc.body, h, schemaRef, encFn) + _, got, err := decodeBody(tc.body, h, schemaRef, encFn) if tc.wantErr != nil { require.Error(t, err) @@ -1350,7 +1350,7 @@ func TestRegisterAndUnregisterBodyDecoder(t *testing.T) { body := strings.NewReader("foo,bar") schema := openapi3.NewArraySchema().WithItems(openapi3.NewStringSchema()).NewRef() encFn := func(string) *openapi3.Encoding { return nil } - got, err := decodeBody(body, h, schema, encFn) + _, got, err := decodeBody(body, h, schema, encFn) require.NoError(t, err) require.Equal(t, []string{"foo", "bar"}, got) @@ -1360,7 +1360,7 @@ func TestRegisterAndUnregisterBodyDecoder(t *testing.T) { originalDecoder = RegisteredBodyDecoder(contentType) require.Nil(t, originalDecoder) - _, err = decodeBody(body, h, schema, encFn) + _, _, err = decodeBody(body, h, schema, encFn) require.Equal(t, &ParseError{ Kind: KindUnsupportedFormat, Reason: prefixUnsupportedCT + ` "text/csv"`, diff --git a/openapi3filter/req_resp_encoder.go b/openapi3filter/req_resp_encoder.go new file mode 100644 index 000000000..b6429d6d8 --- /dev/null +++ b/openapi3filter/req_resp_encoder.go @@ -0,0 +1,27 @@ +package openapi3filter + +import ( + "encoding/json" + "fmt" +) + +func encodeBody(body interface{}, mediaType string) ([]byte, error) { + encoder, ok := bodyEncoders[mediaType] + if !ok { + return nil, &ParseError{ + Kind: KindUnsupportedFormat, + Reason: fmt.Sprintf("%s %q", prefixUnsupportedCT, mediaType), + } + } + return encoder(body) +} + +type bodyEncoder func(body interface{}) ([]byte, error) + +var bodyEncoders = map[string]bodyEncoder{ + "application/json": jsonBodyEncoder, +} + +func jsonBodyEncoder(body interface{}) ([]byte, error) { + return json.Marshal(body) +} diff --git a/openapi3filter/validate_request.go b/openapi3filter/validate_request.go index fae6b09f9..db845c0be 100644 --- a/openapi3filter/validate_request.go +++ b/openapi3filter/validate_request.go @@ -142,6 +142,30 @@ func ValidateParameter(ctx context.Context, input *RequestValidationInput, param } schema = parameter.Schema.Value } + + // Set default value if needed + if value == nil && schema != nil && schema.Default != nil { + value = schema.Default + req := input.Request + switch parameter.In { + case openapi3.ParameterInPath: + // TODO: no idea how to handle this + case openapi3.ParameterInQuery: + q := req.URL.Query() + q.Add(parameter.Name, fmt.Sprintf("%v", value)) + req.URL.RawQuery = q.Encode() + case openapi3.ParameterInHeader: + req.Header.Add(parameter.Name, fmt.Sprintf("%v", value)) + case openapi3.ParameterInCookie: + req.AddCookie(&http.Cookie{ + Name: parameter.Name, + Value: fmt.Sprintf("%v", value), + }) + default: + return fmt.Errorf("unsupported parameter's 'in': %s", parameter.In) + } + } + // Validate a parameter's value and presence. if parameter.Required && !found { return &RequestError{Input: input, Parameter: parameter, Reason: ErrInvalidRequired.Error(), Err: ErrInvalidRequired} @@ -230,7 +254,7 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req } encFn := func(name string) *openapi3.Encoding { return contentType.Encoding[name] } - value, err := decodeBody(bytes.NewReader(data), req.Header, contentType.Schema, encFn) + mediaType, value, err := decodeBody(bytes.NewReader(data), req.Header, contentType.Schema, encFn) if err != nil { return &RequestError{ Input: input, @@ -240,8 +264,10 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req } } - opts := make([]openapi3.SchemaValidationOption, 0, 2) // 2 potential opts here + defaultsSet := false + opts := make([]openapi3.SchemaValidationOption, 0, 3) // 3 potential opts here opts = append(opts, openapi3.VisitAsRequest()) + opts = append(opts, openapi3.DefaultsSet(func() { defaultsSet = true })) if options.MultiError { opts = append(opts, openapi3.MultiErrors()) } @@ -255,6 +281,21 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req Err: err, } } + + if defaultsSet { + var err error + if data, err = encodeBody(value, mediaType); err != nil { + return &RequestError{ + Input: input, + RequestBody: requestBody, + Reason: "rewriting failed", + Err: err, + } + } + // Put the data back into the input + req.Body = ioutil.NopCloser(bytes.NewReader(data)) + } + return nil } diff --git a/openapi3filter/validate_response.go b/openapi3filter/validate_response.go index 7cb713ace..f19123e53 100644 --- a/openapi3filter/validate_response.go +++ b/openapi3filter/validate_response.go @@ -111,7 +111,7 @@ func ValidateResponse(ctx context.Context, input *ResponseValidationInput) error input.SetBodyBytes(data) encFn := func(name string) *openapi3.Encoding { return contentType.Encoding[name] } - value, err := decodeBody(bytes.NewBuffer(data), input.Header, contentType.Schema, encFn) + _, value, err := decodeBody(bytes.NewBuffer(data), input.Header, contentType.Schema, encFn) if err != nil { return &ResponseError{ Input: input, diff --git a/openapi3filter/validate_set_default_test.go b/openapi3filter/validate_set_default_test.go new file mode 100644 index 000000000..bacffe529 --- /dev/null +++ b/openapi3filter/validate_set_default_test.go @@ -0,0 +1,561 @@ +package openapi3filter + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + "net/url" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + legacyrouter "github.com/getkin/kin-openapi/routers/legacy" + "github.com/stretchr/testify/require" +) + +func TestValidatingRequestParameterAndSetDefault(t *testing.T) { + const spec = `{ + "openapi": "3.0.3", + "info": { + "version": "1.0.0", + "title": "title", + "description": "desc", + "contact": { + "email": "email" + } + }, + "paths": { + "/accounts": { + "get": { + "description": "Create a new account", + "parameters": [ + { + "in": "query", + "name": "q1", + "schema": { + "type": "string", + "default": "Q" + } + }, + { + "in": "query", + "name": "q2", + "schema": { + "type": "string", + "default": "Q" + } + }, + { + "in": "query", + "name": "q3", + "schema": { + "type": "string" + } + }, + { + "in": "header", + "name": "h1", + "schema": { + "type": "boolean", + "default": true + } + }, + { + "in": "header", + "name": "h2", + "schema": { + "type": "boolean", + "default": true + } + }, + { + "in": "header", + "name": "h3", + "schema": { + "type": "boolean" + } + }, + { + "in": "cookie", + "name": "c1", + "schema": { + "type": "integer", + "default": 128 + } + }, + { + "in": "cookie", + "name": "c2", + "schema": { + "type": "integer", + "default": 128 + } + }, + { + "in": "cookie", + "name": "c3", + "schema": { + "type": "integer" + } + } + ], + "responses": { + "201": { + "description": "Successfully created a new account" + }, + "400": { + "description": "The server could not understand the request due to invalid syntax", + } + } + } + } + } +} +` + + sl := openapi3.NewLoader() + doc, err := sl.LoadFromData([]byte(spec)) + require.NoError(t, err) + err = doc.Validate(sl.Context) + require.NoError(t, err) + router, err := legacyrouter.NewRouter(doc) + require.NoError(t, err) + + httpReq, err := http.NewRequest(http.MethodGet, "/accounts", nil) + require.NoError(t, err) + + params := &url.Values{ + "q2": []string{"from_request"}, + } + httpReq.URL.RawQuery = params.Encode() + httpReq.Header.Set("h2", "false") + httpReq.AddCookie(&http.Cookie{Name: "c2", Value: "1024"}) + + route, pathParams, err := router.FindRoute(httpReq) + require.NoError(t, err) + + err = ValidateRequest(sl.Context, &RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + }) + require.NoError(t, err) + + // Unset default values in URL were set + require.Equal(t, "Q", httpReq.URL.Query().Get("q1")) + // Unset default values in headers were set + require.Equal(t, "true", httpReq.Header.Get("h1")) + // Unset default values in cookies were set + cookie, err := httpReq.Cookie("c1") + require.NoError(t, err) + require.Equal(t, "128", cookie.Value) + + // All values from request were retained + require.Equal(t, "from_request", httpReq.URL.Query().Get("q2")) + require.Equal(t, "false", httpReq.Header.Get("h2")) + cookie, err = httpReq.Cookie("c2") + require.NoError(t, err) + require.Equal(t, "1024", cookie.Value) + + // Not set value to parameters without default value + require.Equal(t, "", httpReq.URL.Query().Get("q3")) + require.Equal(t, "", httpReq.Header.Get("h3")) + _, err = httpReq.Cookie("c3") + require.Equal(t, http.ErrNoCookie, err) +} + +func TestValidateRequestBodyAndSetDefault(t *testing.T) { + const spec = `{ + "openapi": "3.0.3", + "info": { + "version": "1.0.0", + "title": "title", + "description": "desc", + "contact": { + "email": "email" + } + }, + "paths": { + "/accounts": { + "post": { + "description": "Create a new account", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["id"], + "properties": { + "id": { + "type": "string", + "pattern": "[0-9a-v]+$", + "minLength": 20, + "maxLength": 20 + }, + "name": { + "type": "string", + "default": "default" + }, + "code": { + "type": "integer", + "default": 123 + }, + "all": { + "type": "boolean", + "default": false + }, + "page": { + "type": "object", + "properties": { + "num": { + "type": "integer", + "default": 1 + }, + "size": { + "type": "integer", + "default": 10 + }, + "order": { + "type": "string", + "enum": ["asc", "desc"], + "default": "desc" + } + } + }, + "filters": { + "type": "array", + "nullable": true, + "items": { + "type": "object", + "properties": { + "field": { + "type": "string", + "default": "name" + }, + "op": { + "type": "string", + "enum": ["eq", "ne"], + "default": "eq" + }, + "value": { + "type": "integer", + "default": 123 + } + } + } + } + } + } + } + } + }, + "responses": { + "201": { + "description": "Successfully created a new account" + }, + "400": { + "description": "The server could not understand the request due to invalid syntax", + } + } + } + } + } +}` + sl := openapi3.NewLoader() + doc, err := sl.LoadFromData([]byte(spec)) + require.NoError(t, err) + err = doc.Validate(sl.Context) + require.NoError(t, err) + router, err := legacyrouter.NewRouter(doc) + require.NoError(t, err) + + type page struct { + Num int `json:"num,omitempty"` + Size int `json:"size,omitempty"` + Order string `json:"order,omitempty"` + } + type filter struct { + Field string `json:"field,omitempty"` + OP string `json:"op,omitempty"` + Value int `json:"value,omitempty"` + } + type body struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Code int `json:"code,omitempty"` + All bool `json:"all,omitempty"` + Page *page `json:"page,omitempty"` + Filters []filter `json:"filters,omitempty"` + } + + testCases := []struct { + name string + body body + bodyAssertion func(t *testing.T, body string) + }{ + { + name: "only id", + body: body{ + ID: "bt6kdc3d0cvp6u8u3ft0", + }, + bodyAssertion: func(t *testing.T, body string) { + require.JSONEq(t, `{"id":"bt6kdc3d0cvp6u8u3ft0", "name": "default", "code": 123, "all": false}`, body) + }, + }, + { + name: "id & name", + body: body{ + ID: "bt6kdc3d0cvp6u8u3ft0", + Name: "non-default", + }, + bodyAssertion: func(t *testing.T, body string) { + require.JSONEq(t, `{"id":"bt6kdc3d0cvp6u8u3ft0", "name": "non-default", "code": 123, "all": false}`, body) + }, + }, + { + name: "id & name & code", + body: body{ + ID: "bt6kdc3d0cvp6u8u3ft0", + Name: "non-default", + Code: 456, + }, + bodyAssertion: func(t *testing.T, body string) { + require.JSONEq(t, `{"id":"bt6kdc3d0cvp6u8u3ft0", "name": "non-default", "code": 456, "all": false}`, body) + }, + }, + { + name: "id & name & code & all", + body: body{ + ID: "bt6kdc3d0cvp6u8u3ft0", + Name: "non-default", + Code: 456, + All: true, + }, + bodyAssertion: func(t *testing.T, body string) { + require.JSONEq(t, `{"id":"bt6kdc3d0cvp6u8u3ft0", "name": "non-default", "code": 456, "all": true}`, body) + }, + }, + { + name: "id & page(num)", + body: body{ + ID: "bt6kdc3d0cvp6u8u3ft0", + Page: &page{ + Num: 10, + }, + }, + bodyAssertion: func(t *testing.T, body string) { + require.JSONEq(t, ` +{ + "id": "bt6kdc3d0cvp6u8u3ft0", + "name": "default", + "code": 123, + "all": false, + "page": { + "num": 10, + "size": 10, + "order": "desc" + } +} + `, body) + }, + }, + { + name: "id & page(num & order)", + body: body{ + ID: "bt6kdc3d0cvp6u8u3ft0", + Page: &page{ + Num: 10, + Order: "asc", + }, + }, + bodyAssertion: func(t *testing.T, body string) { + require.JSONEq(t, ` +{ + "id": "bt6kdc3d0cvp6u8u3ft0", + "name": "default", + "code": 123, + "all": false, + "page": { + "num": 10, + "size": 10, + "order": "asc" + } +} + `, body) + }, + }, + { + name: "id & page & filters(one element and contains field)", + body: body{ + ID: "bt6kdc3d0cvp6u8u3ft0", + Page: &page{ + Num: 10, + Order: "asc", + }, + Filters: []filter{ + { + Field: "code", + }, + }, + }, + bodyAssertion: func(t *testing.T, body string) { + require.JSONEq(t, ` +{ + "id": "bt6kdc3d0cvp6u8u3ft0", + "name": "default", + "code": 123, + "all": false, + "page": { + "num": 10, + "size": 10, + "order": "asc" + }, + "filters": [ + { + "field": "code", + "op": "eq", + "value": 123 + } + ] +} + `, body) + }, + }, + { + name: "id & page & filters(one element and contains field & op & value)", + body: body{ + ID: "bt6kdc3d0cvp6u8u3ft0", + Page: &page{ + Num: 10, + Order: "asc", + }, + Filters: []filter{ + { + Field: "code", + OP: "ne", + Value: 456, + }, + }, + }, + bodyAssertion: func(t *testing.T, body string) { + require.JSONEq(t, ` +{ + "id": "bt6kdc3d0cvp6u8u3ft0", + "name": "default", + "code": 123, + "all": false, + "page": { + "num": 10, + "size": 10, + "order": "asc" + }, + "filters": [ + { + "field": "code", + "op": "ne", + "value": 456 + } + ] +} + `, body) + }, + }, + { + name: "id & page & filters(multiple elements)", + body: body{ + ID: "bt6kdc3d0cvp6u8u3ft0", + Page: &page{ + Num: 10, + Order: "asc", + }, + Filters: []filter{ + { + Value: 456, + }, + { + OP: "ne", + }, + { + Field: "code", + Value: 456, + }, + { + OP: "ne", + Value: 789, + }, + { + Field: "code", + OP: "ne", + Value: 456, + }, + }, + }, + bodyAssertion: func(t *testing.T, body string) { + require.JSONEq(t, ` +{ + "id": "bt6kdc3d0cvp6u8u3ft0", + "name": "default", + "code": 123, + "all": false, + "page": { + "num": 10, + "size": 10, + "order": "asc" + }, + "filters": [ + { + "field": "name", + "op": "eq", + "value": 456 + }, + { + "field": "name", + "op": "ne", + "value": 123 + }, + { + "field": "code", + "op": "eq", + "value": 456 + }, + { + "field": "name", + "op": "ne", + "value": 789 + }, + { + "field": "code", + "op": "ne", + "value": 456 + } + ] +} + `, body) + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + b, err := json.Marshal(tc.body) + require.NoError(t, err) + httpReq, err := http.NewRequest(http.MethodPost, "/accounts", bytes.NewReader(b)) + require.NoError(t, err) + httpReq.Header.Add(headerCT, "application/json") + + route, pathParams, err := router.FindRoute(httpReq) + require.NoError(t, err) + + err = ValidateRequest(sl.Context, &RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + }) + require.NoError(t, err) + + validatedReqBody, err := ioutil.ReadAll(httpReq.Body) + require.NoError(t, err) + tc.bodyAssertion(t, string(validatedReqBody)) + }) + } +}