Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(middleware/csrf): TrustedOrigins using https://*.example.com style subdomains #2925

Merged
merged 13 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions docs/api/middleware/csrf.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (h *Handler) DeleteToken(c fiber.Ctx) error
| Storage | `fiber.Storage` | Store is used to store the state of the middleware. | `nil` |
| Session | `*session.Store` | Session is used to store the state of the middleware. Overrides Storage if set. | `nil` |
| SessionKey | `string` | SessionKey is the key used to store the token in the session. | "csrfToken" |
| TrustedOrigins | `[]string` | TrustedOrigins is a list of trusted origins for unsafe requests. This supports subdomain matching, so you can use a value like "https://.example.com" to allow any subdomain of example.com to submit requests. | `[]` |
| TrustedOrigins | `[]string` | TrustedOrigins is a list of trusted origins for unsafe requests. This supports subdomain matching, so you can use a value like "https://*.example.com" to allow any subdomain of example.com to submit requests. | `[]` |

### Default Config

Expand Down Expand Up @@ -154,6 +154,36 @@ var ConfigDefault = Config{
}
```

### Trusted Origins

The `TrustedOrigins` option is used to specify a list of trusted origins for unsafe requests. This is useful when you want to allow requests from other origins. This supports matching subdomains at any level. This means you can use a value like `"https://*.example.com"` to allow any subdomain of `example.com` to submit requests, including multiple subdomain levels such as `"https://sub.sub.example.com"`.

To ensure that the provided `TrustedOrigins` origins are correctly formatted, this middleware validates and normalizes them. It checks for valid schemes, i.e., HTTP or HTTPS, and it will automatically remove trailing slashes. If the provided origin is invalid, the middleware will panic.

#### Example with Explicit Origins

In the following example, the CSRF middleware will allow requests from `trusted.example.com`, in addition to the current host.

```go
app.Use(csrf.New(csrf.Config{
TrustedOrigins: []string{"https://trusted.example.com"},
}))
```

#### Example with Subdomain Matching

In the following example, the CSRF middleware will allow requests from any subdomain of `example.com`, in addition to the current host.

```go
app.Use(csrf.New(csrf.Config{
TrustedOrigins: []string{"https://*.example.com"},
}))
```

::caution
When using `TrustedOrigins` with subdomain matching, make sure you control and trust all the subdomains, including all subdomain levels. If not, an attacker could create a subdomain under a trusted origin and use it to send harmful requests.
:::

## Constants

```go
Expand Down Expand Up @@ -273,7 +303,6 @@ When HTTPS requests are protected by CSRF, referer checking is always carried ou
The Referer header is automatically included in requests by all modern browsers, including those made using the JS Fetch API. However, if you're making use of this middleware with a custom client, it's important to ensure that the client sends a valid Referer header.
:::


### Token Lifecycle

Tokens are valid until they expire or until they are deleted. By default, tokens are valid for 1 hour, and each subsequent request extends the expiration by 1 hour. The token only expires if the user doesn't make a request for the duration of the expiration time.
Expand Down
127 changes: 69 additions & 58 deletions middleware/csrf/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

// Handler for CSRF middleware
type Handler struct {
config *Config
config Config
sessionManager *sessionManager
storageManager *storageManager
}
Expand Down Expand Up @@ -56,6 +56,36 @@
storageManager = newStorageManager(cfg.Storage)
}

// Pre-parse trusted origins
trustedOrigins := []string{}
trustedSubOrigins := []subdomain{}

for _, origin := range cfg.TrustedOrigins {
if i := strings.Index(origin, "://*."); i != -1 {
trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:])
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
panic("[CSRF] Invalid origin format in configuration:" + origin)
}
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
trustedSubOrigins = append(trustedSubOrigins, sd)
} else {
trimmedOrigin := strings.TrimSpace(origin)
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
panic("[CSRF] Invalid origin format in configuration:" + origin)
}
trustedOrigins = append(trustedOrigins, normalizedOrigin)
}
}
sixcolors marked this conversation as resolved.
Show resolved Hide resolved

// Create the handler outside of the returned function
handler := &Handler{
config: cfg,
sessionManager: sessionManager,
storageManager: storageManager,
}

// Return new handler
return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
Expand All @@ -64,11 +94,7 @@
}

// Store the CSRF handler in the context
c.Locals(handlerKey, &Handler{
config: &cfg,
sessionManager: sessionManager,
storageManager: storageManager,
})
c.Locals(handlerKey, handler)

var token string

Expand All @@ -88,12 +114,12 @@
// Assume that anything not defined as 'safe' by RFC7231 needs protection

// Enforce an origin check for unsafe requests.
err := originMatchesHost(c, cfg.TrustedOrigins)
err := originMatchesHost(c, trustedOrigins, trustedSubOrigins)

// If there's no origin, enforce a referer check for HTTPS connections.
if errors.Is(err, errOriginNotFound) {
if c.Scheme() == "https" {
err = refererMatchesHost(c, cfg.TrustedOrigins)
err = refererMatchesHost(c, trustedOrigins, trustedSubOrigins)
} else {
// If it's not HTTPS, clear the error to allow the request to proceed.
err = nil
Expand Down Expand Up @@ -237,20 +263,15 @@
// DeleteToken removes the token found in the context from the storage
// and expires the CSRF cookie
func (handler *Handler) DeleteToken(c fiber.Ctx) error {
// Get the config from the context
config := handler.config
if config == nil {
panic("CSRF Handler config not found in context")
}
// Extract token from the client request cookie
cookieToken := c.Cookies(config.CookieName)
cookieToken := c.Cookies(handler.config.CookieName)
if cookieToken == "" {
return config.ErrorHandler(c, ErrTokenNotFound)
return handler.config.ErrorHandler(c, ErrTokenNotFound)

Check warning on line 269 in middleware/csrf/csrf.go

View check run for this annotation

Codecov / codecov/patch

middleware/csrf/csrf.go#L269

Added line #L269 was not covered by tests
}
// Remove the token from storage
deleteTokenFromStorage(c, cookieToken, *config, handler.sessionManager, handler.storageManager)
deleteTokenFromStorage(c, cookieToken, handler.config, handler.sessionManager, handler.storageManager)
// Expire the cookie
expireCSRFCookie(c, *config)
expireCSRFCookie(c, handler.config)
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
return nil
}

Expand All @@ -262,8 +283,8 @@
// originMatchesHost checks that the origin header matches the host header
// returns an error if the origin header is not present or is invalid
// returns nil if the origin header is valid
func originMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
origin := c.Get(fiber.HeaderOrigin)
func originMatchesHost(c fiber.Ctx, trustedOrigins []string, trustedSubOrigins []subdomain) error {
origin := strings.ToLower(c.Get(fiber.HeaderOrigin))
ReneWerner87 marked this conversation as resolved.
Show resolved Hide resolved
if origin == "" || origin == "null" { // "null" is set by some browsers when the origin is a secure context https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin#description
return errOriginNotFound
}
Expand All @@ -273,23 +294,31 @@
return ErrOriginInvalid
}

if originURL.Host != c.Host() {
for _, trustedOrigin := range trustedOrigins {
if isTrustedSchemeAndDomain(trustedOrigin, origin) {
return nil
}
if originURL.Scheme == c.Scheme() && originURL.Host == c.Host() {
return nil
}

for _, trustedOrigin := range trustedOrigins {
if origin == trustedOrigin {
return nil
}
return ErrOriginNoMatch
}

return nil
for _, trustedSubOrigin := range trustedSubOrigins {
if trustedSubOrigin.match(origin) {
return nil
}
}

return ErrOriginNoMatch
}

// refererMatchesHost checks that the referer header matches the host header
// returns an error if the referer header is not present or is invalid
// returns nil if the referer header is valid
func refererMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
referer := c.Get(fiber.HeaderReferer)
func refererMatchesHost(c fiber.Ctx, trustedOrigins []string, trustedSubOrigins []subdomain) error {
referer := strings.ToLower(c.Get(fiber.HeaderReferer))
ReneWerner87 marked this conversation as resolved.
Show resolved Hide resolved

if referer == "" {
return ErrRefererNotFound
}
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -299,41 +328,23 @@
return ErrRefererInvalid
}

if refererURL.Host != c.Host() {
for _, trustedOrigin := range trustedOrigins {
if isTrustedSchemeAndDomain(trustedOrigin, referer) {
return nil
}
}
return ErrRefererNoMatch
}

return nil
}

// isTrustedSchemeAndDomain checks if the trustedProtoDomain is the same as the protoDomain
// or if the protoDomain is a subdomain of the trustedProtoDomain where trustedProtoDomain
// is prefixed with "https://." or "http://."
func isTrustedSchemeAndDomain(trustedProtoDomain, protoDomain string) bool {
if trustedProtoDomain == protoDomain {
return true
if refererURL.Scheme == c.Scheme() && refererURL.Host == c.Host() {
return nil
}

// Use constant prefixes for better readability and avoid magic numbers.
const httpsPrefix = "https://."
const httpPrefix = "http://."
referer = refererURL.String()

if strings.HasPrefix(trustedProtoDomain, httpsPrefix) {
trustedProtoDomain = trustedProtoDomain[len(httpsPrefix):]
protoDomain = strings.TrimPrefix(protoDomain, "https://")
return strings.HasSuffix(protoDomain, "."+trustedProtoDomain)
for _, trustedOrigin := range trustedOrigins {
if referer == trustedOrigin {
return nil
}
}

if strings.HasPrefix(trustedProtoDomain, httpPrefix) {
trustedProtoDomain = trustedProtoDomain[len(httpPrefix):]
protoDomain = strings.TrimPrefix(protoDomain, "http://")
return strings.HasSuffix(protoDomain, "."+trustedProtoDomain)
for _, trustedSubOrigin := range trustedSubOrigins {
if trustedSubOrigin.match(referer) {
return nil
}
}

return false
return ErrRefererNoMatch
}
Loading
Loading