Skip to content

Commit

Permalink
openapi3filter: add context to Validator Middleware's ErrFunc and Log…
Browse files Browse the repository at this point in the history
…Func functions (#953)

* add context to Validator Middleware's ErrFunc and LogFunc functions

* update existing ErrFunc and LogFunc instead of creating new ones and updated docs

---------

Co-authored-by: ap7u5 <[email protected]>
  • Loading branch information
crissi98 and ap7u5 authored Jun 2, 2024
1 parent f170f8c commit 45b4399
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 16 deletions.
4 changes: 2 additions & 2 deletions .github/docs/openapi3filter.txt
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ type ErrCode int
occur during validation. These may be used to write an appropriate response
in ErrFunc.

type ErrFunc func(w http.ResponseWriter, status int, code ErrCode, err error)
type ErrFunc func(ctx context.Context, w http.ResponseWriter, status int, code ErrCode, err error)
ErrFunc handles errors that may occur during validation.

type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter)
Expand All @@ -198,7 +198,7 @@ type Headerer interface {
Headerer, the provided headers will be applied to the response writer,
after the Content-Type is set.

type LogFunc func(message string, err error)
type LogFunc func(ctx context.Context, message string, err error)
LogFunc handles log messages that may occur during validation.

type Options struct {
Expand Down
28 changes: 15 additions & 13 deletions openapi3filter/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openapi3filter

import (
"bytes"
"context"
"io"
"log"
"net/http"
Expand All @@ -19,10 +20,10 @@ type Validator struct {
}

// ErrFunc handles errors that may occur during validation.
type ErrFunc func(w http.ResponseWriter, status int, code ErrCode, err error)
type ErrFunc func(ctx context.Context, w http.ResponseWriter, status int, code ErrCode, err error)

// LogFunc handles log messages that may occur during validation.
type LogFunc func(message string, err error)
type LogFunc func(ctx context.Context, message string, err error)

// ErrCode is used for classification of different types of errors that may
// occur during validation. These may be used to write an appropriate response
Expand Down Expand Up @@ -61,10 +62,10 @@ func (e ErrCode) responseText() string {
func NewValidator(router routers.Router, options ...ValidatorOption) *Validator {
v := &Validator{
router: router,
errFunc: func(w http.ResponseWriter, status int, code ErrCode, _ error) {
errFunc: func(_ context.Context, w http.ResponseWriter, status int, code ErrCode, _ error) {
http.Error(w, code.responseText(), status)
},
logFunc: func(message string, err error) {
logFunc: func(_ context.Context, message string, err error) {
log.Printf("%s: %v", message, err)
},
}
Expand Down Expand Up @@ -117,10 +118,11 @@ func ValidationOptions(options Options) ValidatorOption {
// request and response validation.
func (v *Validator) Middleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
route, pathParams, err := v.router.FindRoute(r)
if err != nil {
v.logFunc("validation error: failed to find route for "+r.URL.String(), err)
v.errFunc(w, http.StatusNotFound, ErrCodeCannotFindRoute, err)
v.logFunc(ctx, "validation error: failed to find route for "+r.URL.String(), err)
v.errFunc(ctx, w, http.StatusNotFound, ErrCodeCannotFindRoute, err)
return
}
requestValidationInput := &RequestValidationInput{
Expand All @@ -129,9 +131,9 @@ func (v *Validator) Middleware(h http.Handler) http.Handler {
Route: route,
Options: &v.options,
}
if err = ValidateRequest(r.Context(), requestValidationInput); err != nil {
v.logFunc("invalid request", err)
v.errFunc(w, http.StatusBadRequest, ErrCodeRequestInvalid, err)
if err = ValidateRequest(ctx, requestValidationInput); err != nil {
v.logFunc(ctx, "invalid request", err)
v.errFunc(ctx, w, http.StatusBadRequest, ErrCodeRequestInvalid, err)
return
}

Expand All @@ -144,22 +146,22 @@ func (v *Validator) Middleware(h http.Handler) http.Handler {

h.ServeHTTP(wr, r)

if err = ValidateResponse(r.Context(), &ResponseValidationInput{
if err = ValidateResponse(ctx, &ResponseValidationInput{
RequestValidationInput: requestValidationInput,
Status: wr.statusCode(),
Header: wr.Header(),
Body: io.NopCloser(bytes.NewBuffer(wr.bodyContents())),
Options: &v.options,
}); err != nil {
v.logFunc("invalid response", err)
v.logFunc(ctx, "invalid response", err)
if v.strict {
v.errFunc(w, http.StatusInternalServerError, ErrCodeResponseInvalid, err)
v.errFunc(ctx, w, http.StatusInternalServerError, ErrCodeResponseInvalid, err)
}
return
}

if err = wr.flushBodyContents(); err != nil {
v.logFunc("failed to write response", err)
v.logFunc(ctx, "failed to write response", err)
}
})
}
Expand Down
3 changes: 2 additions & 1 deletion openapi3filter/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openapi3filter_test

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -489,7 +490,7 @@ paths:
// testing a service against its spec in development and CI. In production,
// availability may be more important than strictness.
v := openapi3filter.NewValidator(router, openapi3filter.Strict(true),
openapi3filter.OnErr(func(w http.ResponseWriter, status int, code openapi3filter.ErrCode, err error) {
openapi3filter.OnErr(func(_ context.Context, w http.ResponseWriter, status int, code openapi3filter.ErrCode, err error) {
// Customize validation error responses to use JSON
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
Expand Down

0 comments on commit 45b4399

Please sign in to comment.