Skip to content

Commit

Permalink
Added ErrorHandler and ErrorHandlerWithContext in CSRF middleware (#2257
Browse files Browse the repository at this point in the history
)

* feat: add error handler to csrf middleware

Co-authored-by: Mojtaba Arezoomand <[email protected]>
  • Loading branch information
mojixcoder and mojixcoder authored Sep 1, 2022
1 parent 534bbb8 commit d77e8c0
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
18 changes: 16 additions & 2 deletions middleware/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,13 @@ type (
// Indicates SameSite mode of the CSRF cookie.
// Optional. Default value SameSiteDefaultMode.
CookieSameSite http.SameSite `yaml:"cookie_same_site"`

// ErrorHandler defines a function which is executed for returning custom errors.
ErrorHandler CSRFErrorHandler
}

// CSRFErrorHandler is a function which is executed for creating custom errors.
CSRFErrorHandler func(err error, c echo.Context) error
)

// ErrCSRFInvalid is returned when CSRF check fails
Expand Down Expand Up @@ -154,8 +160,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
lastTokenErr = ErrCSRFInvalid
}
}
var finalErr error
if lastTokenErr != nil {
return lastTokenErr
finalErr = lastTokenErr
} else if lastExtractorErr != nil {
// ugly part to preserve backwards compatible errors. someone could rely on them
if lastExtractorErr == errQueryExtractorValueMissing {
Expand All @@ -167,7 +174,14 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
} else {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error())
}
return lastExtractorErr
finalErr = lastExtractorErr
}

if finalErr != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(finalErr, c)
}
return finalErr
}
}

Expand Down
22 changes: 22 additions & 0 deletions middleware/csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,25 @@ func TestCSRFConfig_skipper(t *testing.T) {
})
}
}

func TestCSRFErrorHandling(t *testing.T) {
cfg := CSRFConfig{
ErrorHandler: func(err error, c echo.Context) error {
return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed")
},
}

e := echo.New()
e.POST("/", func(c echo.Context) error {
return c.String(http.StatusNotImplemented, "should not end up here")
})

e.Use(CSRFWithConfig(cfg))

req := httptest.NewRequest(http.MethodPost, "/", nil)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)

assert.Equal(t, http.StatusTeapot, res.Code)
assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String())
}

0 comments on commit d77e8c0

Please sign in to comment.