Skip to content

Commit

Permalink
feat(middleware/csrf): TrustedOrigins using https://*.example.com sty…
Browse files Browse the repository at this point in the history
…le subdomains
  • Loading branch information
sixcolors committed Mar 18, 2024
1 parent 43dc60f commit 2c80db9
Show file tree
Hide file tree
Showing 5 changed files with 317 additions and 75 deletions.
31 changes: 29 additions & 2 deletions docs/api/middleware/csrf.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (h *Handler) DeleteToken(c fiber.Ctx) error
| Storage | `fiber.Storage` | Store is used to store the state of the middleware. | `nil` |
| Session | `*session.Store` | Session is used to store the state of the middleware. Overrides Storage if set. | `nil` |
| SessionKey | `string` | SessionKey is the key used to store the token in the session. | "csrfToken" |
| TrustedOrigins | `[]string` | TrustedOrigins is a list of trusted origins for unsafe requests. This supports subdomain matching, so you can use a value like "https://.example.com" to allow any subdomain of example.com to submit requests. | `[]` |
| TrustedOrigins | `[]string` | TrustedOrigins is a list of trusted origins for unsafe requests. This supports subdomain matching, so you can use a value like "https://*.example.com" to allow any subdomain of example.com to submit requests. | `[]` |

### Default Config

Expand Down Expand Up @@ -154,6 +154,34 @@ var ConfigDefault = Config{
}
```

### Trusted Origins

The `TrustedOrigins` option is used to specify a list of trusted origins for unsafe requests. This is useful when you want to allow requests from other origins. This supports matching subdomains at any level. This means you can use a value like `"https://*.example.com"` to allow any subdomain of `example.com` to submit requests, including multiple subdomain levels such as `"https://sub.sub.example.com"`.

#### Example with Explicit Origins

In the following example, the CSRF middleware will allow requests from `trusted.example.com`, in addition to the current host.

```go
app.Use(csrf.New(csrf.Config{
TrustedOrigins: []string{"https://trusted.example.com"},
}))
```

#### Example with Subdomain Matching

In the following example, the CSRF middleware will allow requests from any subdomain of `example.com`, in addition to the current host.

```go
app.Use(csrf.New(csrf.Config{
TrustedOrigins: []string{"https://*.example.com"},
}))
```

::caution
When using `TrustedOrigins` with subdomain matching, make sure you control and trust all the subdomains, including all subdomain levels. If not, an attacker could create a subdomain under a trusted origin and use it to send harmful requests.
:::

## Constants

```go
Expand Down Expand Up @@ -273,7 +301,6 @@ When HTTPS requests are protected by CSRF, referer checking is always carried ou
The Referer header is automatically included in requests by all modern browsers, including those made using the JS Fetch API. However, if you're making use of this middleware with a custom client, it's important to ensure that the client sends a valid Referer header.
:::


### Token Lifecycle

Tokens are valid until they expire or until they are deleted. By default, tokens are valid for 1 hour, and each subsequent request extends the expiration by 1 hour. The token only expires if the user doesn't make a request for the duration of the expiration time.
Expand Down
128 changes: 70 additions & 58 deletions middleware/csrf/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package csrf

import (
"errors"
"log"

Check failure on line 5 in middleware/csrf/csrf.go

View workflow job for this annotation

GitHub Actions / lint

import 'log' is not allowed from list 'all': logging is provided by `pkg/log` (depguard)
"net/url"
"reflect"
"strings"
Expand All @@ -24,7 +25,7 @@ var (

// Handler for CSRF middleware
type Handler struct {
config *Config
config Config
sessionManager *sessionManager
storageManager *storageManager
}
Expand Down Expand Up @@ -56,6 +57,36 @@ func New(config ...Config) fiber.Handler {
storageManager = newStorageManager(cfg.Storage)
}

// Pre-parse trusted origins
trustedOrigins := []string{}
trustedSubOrigins := []subdomain{}

for _, origin := range cfg.TrustedOrigins {
if i := strings.Index(origin, "://*."); i != -1 {
trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:])
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
log.Panicf("[CSRF] Invalid origin format in configuration: %s", origin)

Check warning on line 69 in middleware/csrf/csrf.go

View workflow job for this annotation

GitHub Actions / lint

deep-exit: calls to log.Panicf only in main() or init() functions (revive)
}
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
trustedSubOrigins = append(trustedSubOrigins, sd)
} else {
trimmedOrigin := strings.TrimSpace(origin)
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
log.Panicf("[CSRF] Invalid origin format in configuration: %s", origin)

Check warning on line 77 in middleware/csrf/csrf.go

View workflow job for this annotation

GitHub Actions / lint

deep-exit: calls to log.Panicf only in main() or init() functions (revive)
}
trustedOrigins = append(trustedOrigins, normalizedOrigin)
}
}

// Create the handler outside of the returned function
handler := &Handler{
config: cfg,
sessionManager: sessionManager,
storageManager: storageManager,
}

// Return new handler
return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
Expand All @@ -64,11 +95,7 @@ func New(config ...Config) fiber.Handler {
}

// Store the CSRF handler in the context
c.Locals(handlerKey, &Handler{
config: &cfg,
sessionManager: sessionManager,
storageManager: storageManager,
})
c.Locals(handlerKey, handler)

var token string

Expand All @@ -88,12 +115,12 @@ func New(config ...Config) fiber.Handler {
// Assume that anything not defined as 'safe' by RFC7231 needs protection

// Enforce an origin check for unsafe requests.
err := originMatchesHost(c, cfg.TrustedOrigins)
err := originMatchesHost(c, trustedOrigins, trustedSubOrigins)

// If there's no origin, enforce a referer check for HTTPS connections.
if errors.Is(err, errOriginNotFound) {
if c.Scheme() == "https" {
err = refererMatchesHost(c, cfg.TrustedOrigins)
err = refererMatchesHost(c, trustedOrigins, trustedSubOrigins)
} else {
// If it's not HTTPS, clear the error to allow the request to proceed.
err = nil
Expand Down Expand Up @@ -237,20 +264,15 @@ func setCSRFCookie(c fiber.Ctx, cfg Config, token string, expiry time.Duration)
// DeleteToken removes the token found in the context from the storage
// and expires the CSRF cookie
func (handler *Handler) DeleteToken(c fiber.Ctx) error {
// Get the config from the context
config := handler.config
if config == nil {
panic("CSRF Handler config not found in context")
}
// Extract token from the client request cookie
cookieToken := c.Cookies(config.CookieName)
cookieToken := c.Cookies(handler.config.CookieName)
if cookieToken == "" {
return config.ErrorHandler(c, ErrTokenNotFound)
return handler.config.ErrorHandler(c, ErrTokenNotFound)

Check warning on line 270 in middleware/csrf/csrf.go

View check run for this annotation

Codecov / codecov/patch

middleware/csrf/csrf.go#L270

Added line #L270 was not covered by tests
}
// Remove the token from storage
deleteTokenFromStorage(c, cookieToken, *config, handler.sessionManager, handler.storageManager)
deleteTokenFromStorage(c, cookieToken, handler.config, handler.sessionManager, handler.storageManager)
// Expire the cookie
expireCSRFCookie(c, *config)
expireCSRFCookie(c, handler.config)
return nil
}

Expand All @@ -262,8 +284,8 @@ func isFromCookie(extractor any) bool {
// originMatchesHost checks that the origin header matches the host header
// returns an error if the origin header is not present or is invalid
// returns nil if the origin header is valid
func originMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
origin := c.Get(fiber.HeaderOrigin)
func originMatchesHost(c fiber.Ctx, trustedOrigins []string, trustedSubOrigins []subdomain) error {
origin := strings.ToLower(c.Get(fiber.HeaderOrigin))
if origin == "" || origin == "null" { // "null" is set by some browsers when the origin is a secure context https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin#description
return errOriginNotFound
}
Expand All @@ -273,23 +295,31 @@ func originMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
return ErrOriginInvalid
}

if originURL.Host != c.Host() {
for _, trustedOrigin := range trustedOrigins {
if isTrustedSchemeAndDomain(trustedOrigin, origin) {
return nil
}
if originURL.Scheme == c.Scheme() && originURL.Host == c.Host() {
return nil
}

for _, trustedOrigin := range trustedOrigins {
if origin == trustedOrigin {
return nil
}
return ErrOriginNoMatch
}

return nil
for _, trustedSubOrigin := range trustedSubOrigins {
if trustedSubOrigin.match(origin) {
return nil
}
}

return ErrOriginNoMatch
}

// refererMatchesHost checks that the referer header matches the host header
// returns an error if the referer header is not present or is invalid
// returns nil if the referer header is valid
func refererMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
referer := c.Get(fiber.HeaderReferer)
func refererMatchesHost(c fiber.Ctx, trustedOrigins []string, trustedSubOrigins []subdomain) error {
referer := strings.ToLower(c.Get(fiber.HeaderReferer))

if referer == "" {
return ErrRefererNotFound
}
Expand All @@ -299,41 +329,23 @@ func refererMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
return ErrRefererInvalid
}

if refererURL.Host != c.Host() {
for _, trustedOrigin := range trustedOrigins {
if isTrustedSchemeAndDomain(trustedOrigin, referer) {
return nil
}
}
return ErrRefererNoMatch
}

return nil
}

// isTrustedSchemeAndDomain checks if the trustedProtoDomain is the same as the protoDomain
// or if the protoDomain is a subdomain of the trustedProtoDomain where trustedProtoDomain
// is prefixed with "https://." or "http://."
func isTrustedSchemeAndDomain(trustedProtoDomain, protoDomain string) bool {
if trustedProtoDomain == protoDomain {
return true
if refererURL.Scheme == c.Scheme() && refererURL.Host == c.Host() {
return nil
}

// Use constant prefixes for better readability and avoid magic numbers.
const httpsPrefix = "https://."
const httpPrefix = "http://."
referer = refererURL.String()

if strings.HasPrefix(trustedProtoDomain, httpsPrefix) {
trustedProtoDomain = trustedProtoDomain[len(httpsPrefix):]
protoDomain = strings.TrimPrefix(protoDomain, "https://")
return strings.HasSuffix(protoDomain, "."+trustedProtoDomain)
for _, trustedOrigin := range trustedOrigins {
if referer == trustedOrigin {
return nil
}
}

if strings.HasPrefix(trustedProtoDomain, httpPrefix) {
trustedProtoDomain = trustedProtoDomain[len(httpPrefix):]
protoDomain = strings.TrimPrefix(protoDomain, "http://")
return strings.HasSuffix(protoDomain, "."+trustedProtoDomain)
for _, trustedSubOrigin := range trustedSubOrigins {
if trustedSubOrigin.match(referer) {
return nil
}
}

return false
return ErrRefererNoMatch
}
59 changes: 44 additions & 15 deletions middleware/csrf/csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -733,18 +733,6 @@ func Test_CSRF_Origin(t *testing.T) {
h(ctx)
require.Equal(t, 403, ctx.Response.StatusCode())

// Test Correct Origin with path
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http")
ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com/action/items?gogogo=true")
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
h(ctx)
require.Equal(t, 200, ctx.Response.StatusCode())

// Test Wrong Origin
ctx.Request.Reset()
ctx.Response.Reset()
Expand All @@ -767,8 +755,8 @@ func Test_CSRF_TrustedOrigins(t *testing.T) {
TrustedOrigins: []string{
"http://safe.example.com",
"https://safe.example.com",
"http://.domain-1.com",
"https://.domain-1.com",
"http://*.domain-1.com",
"https://*.domain-1.com",
},
}))

Expand Down Expand Up @@ -872,6 +860,35 @@ func Test_CSRF_TrustedOrigins(t *testing.T) {
require.Equal(t, 403, ctx.Response.StatusCode())
}

func Test_CSRF_TrustedOrigins_InvalidOrigins(t *testing.T) {

Check failure on line 863 in middleware/csrf/csrf_test.go

View workflow job for this annotation

GitHub Actions / lint

Test_CSRF_TrustedOrigins_InvalidOrigins's subtests should call t.Parallel (tparallel)
t.Parallel()

tests := []struct {
name string
origin string
}{
{"No Scheme", "localhost"},
{"Wildcard", "https://*"},
{"Wildcard domain", "https://*example.com"},
{"File Scheme", "file://example.com"},
{"FTP Scheme", "ftp://example.com"},
{"Port Wildcard", "http://example.com:*"},
{"Multiple Wildcards", "https://*.*.com"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Panics(t, func() {
app := fiber.New()
app.Use(New(Config{
CookieSecure: true,
TrustedOrigins: []string{tt.origin},
}))
}, "Expected panic")
})
}
}

func Test_CSRF_Referer(t *testing.T) {
t.Parallel()
app := fiber.New()
Expand Down Expand Up @@ -979,6 +996,18 @@ func Test_CSRF_DeleteToken(t *testing.T) {
h := app.Handler()
ctx := &fasthttp.RequestCtx{}

// DeleteToken after token generation and remove the cookie
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.Set(HeaderName, "")
handler := HandlerFromContext(app.AcquireCtx(ctx))
if handler != nil {
ctx.Request.Header.DelAllCookies()
err := handler.DeleteToken(app.AcquireCtx(ctx))
require.ErrorIs(t, err, ErrTokenNotFound)
}
h(ctx)

// Generate CSRF token
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
Expand All @@ -991,7 +1020,7 @@ func Test_CSRF_DeleteToken(t *testing.T) {
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
handler := HandlerFromContext(app.AcquireCtx(ctx))
handler = HandlerFromContext(app.AcquireCtx(ctx))
if handler != nil {
if err := handler.DeleteToken(app.AcquireCtx(ctx)); err != nil {
t.Fatal(err)
Expand Down
Loading

0 comments on commit 2c80db9

Please sign in to comment.