Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 16 additions & 59 deletions router/pkg/cors/config.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package cors

import (
"maps"
"net/http"
"strings"
"slices"
)

type cors struct {
Expand All @@ -12,13 +13,13 @@ type cors struct {
allowOrigins []string
normalHeaders http.Header
preflightHeaders http.Header
wildcardOrigins [][]string
wildcardOrigins []*WildcardPattern
handler http.Handler
}

var (
maxRecursionDepth = 10 // Safeguard against deep recursion
DefaultSchemas = []string{
maxWildcardOriginLength = 4096 // Maximum length of an origin string for it to be eligible for wildcard matching
DefaultSchemas = []string{
"http://",
"https://",
}
Expand Down Expand Up @@ -55,7 +56,7 @@ func newCors(handler http.Handler, config Config) *cors {
allowOrigins: normalize(config.AllowOrigins),
normalHeaders: generateNormalHeaders(config),
preflightHeaders: generatePreflightHeaders(config),
wildcardOrigins: config.parseWildcardRules(),
wildcardOrigins: config.parseNewWildcardRules(),
handler: handler,
}
}
Expand Down Expand Up @@ -102,10 +103,8 @@ func (cors *cors) validateOrigin(origin string) bool {
if cors.allowAllOrigins {
return true
}
for _, value := range cors.allowOrigins {
if value == origin {
return true
}
if slices.Contains(cors.allowOrigins, origin) {
return true
}
if len(cors.wildcardOrigins) > 0 && cors.validateWildcardOrigin(origin) {
return true
Expand All @@ -117,67 +116,25 @@ func (cors *cors) validateOrigin(origin string) bool {
}

func (cors *cors) validateWildcardOrigin(origin string) bool {
for _, w := range cors.wildcardOrigins {
if matchOriginWithRule(origin, w, 0, map[string]bool{}) {
return true
}
}
return false
}

// Recursive helper function with depth limit and memoization
func matchOriginWithRule(origin string, rule []string, depth int, memo map[string]bool) bool {
if depth > maxRecursionDepth {
return false // Exceeded recursion depth
}

// Memoization key
key := origin + "|" + strings.Join(rule, "|")
if val, exists := memo[key]; exists {
return val
}

if len(rule) == 0 {
// Successfully matched if origin is also fully consumed
return origin == ""
}

part := rule[0]

if part == "*" {
// Try to match the remaining rule by advancing in origin
for i := 0; i <= len(origin); i++ {
if matchOriginWithRule(origin[i:], rule[1:], depth+1, memo) {
memo[key] = true
return true
}
}
memo[key] = false
// Origin is >4KB, avoid matching it for performance
if len(origin) > maxWildcardOriginLength {
return false
}

// Check if the origin starts with the current part
if strings.HasPrefix(origin, part) {
// Recursively check the rest of the origin and rule
result := matchOriginWithRule(origin[len(part):], rule[1:], depth+1, memo)
memo[key] = result
return result
for _, w := range cors.wildcardOrigins {
if w.Match(origin) {
return true
}
}

memo[key] = false
return false
}

func (cors *cors) handlePreflight(w http.ResponseWriter) {
header := w.Header()
for key, value := range cors.preflightHeaders {
header[key] = value
}
maps.Copy(header, cors.preflightHeaders)
}

func (cors *cors) handleNormal(w http.ResponseWriter) {
header := w.Header()
for key, value := range cors.normalHeaders {
header[key] = value
}
maps.Copy(header, cors.normalHeaders)
}
27 changes: 4 additions & 23 deletions router/pkg/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,35 +108,16 @@ func (c *Config) Validate() error {
return nil
}

func (c *Config) parseWildcardRules() [][]string {
var wRules [][]string
func (c *Config) parseNewWildcardRules() []*WildcardPattern {
var wRules []*WildcardPattern

for _, o := range c.AllowOrigins {
if !strings.Contains(o, "*") {
continue
}

// Split origin by wildcard (*)
parts := strings.Split(o, "*")

// If there’s no wildcard, skip this origin
if len(parts) == 1 {
continue
}

// Generate rules for origins with multiple wildcard segments
var rule []string
for i, part := range parts {
if i > 0 {
rule = append(rule, "*") // Add wildcard indicator between segments
}
if part != "" {
rule = append(rule, part)
}
}

// Add parsed rule to wRules
wRules = append(wRules, rule)
wp := Compile(o)
wRules = append(wRules, wp)
}

return wRules
Expand Down
99 changes: 59 additions & 40 deletions router/pkg/cors/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cors

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
Expand Down Expand Up @@ -220,6 +221,29 @@ func TestGeneratePreflightHeaders_MaxAge(t *testing.T) {
assert.Len(t, header, 2)
}

func TestExtremeLengthOriginKillswitch(t *testing.T) {
cors := newCors(nil, Config{
Enabled: true,
AllowOrigins: []string{"https://*.google.com"},
})

shortSubdomain := strings.Repeat("a", 10)
longSubdomain := strings.Repeat("a", 500)
tooLongSubdomain := strings.Repeat("a", 4096)

assert.True(t, cors.validateOrigin(fmt.Sprintf("https://%s.google.com", shortSubdomain)))
assert.True(t, cors.validateOrigin(fmt.Sprintf("https://%s.google.com", longSubdomain)))
assert.False(t, cors.validateOrigin(fmt.Sprintf("https://%s.google.com", tooLongSubdomain)))

// Should not affect strict origins
cors = newCors(nil, Config{
Enabled: true,
AllowOrigins: []string{fmt.Sprintf("https://%s.google.com", tooLongSubdomain)},
})

assert.True(t, cors.validateOrigin(fmt.Sprintf("https://%s.google.com", tooLongSubdomain)))
}

func TestValidateOrigin(t *testing.T) {
cors := newCors(nil, Config{
Enabled: true,
Expand Down Expand Up @@ -519,29 +543,10 @@ func TestComplexWildcards(t *testing.T) {
}
for _, tc := range testCasesList {
w := performRequest(router, "GET", tc.origin)
assert.Equal(t, tc.expectedCode, w.Code)
assert.Equalf(t, tc.expectedCode, w.Code, "expected %d for %s, got %d", tc.expectedCode, tc.origin, w.Code)
}
}

func TestMaxRecursionDepth(t *testing.T) {
router := newTestRouter(Config{
Enabled: true,
AllowOrigins: []string{
"https://*.example.*.*.com", // multiple sequential wildcards
"https://*.*.*.*.com",
},
AllowMethods: []string{"GET"},
})

maxRecursionDepth = 2
w := performRequest(router, "GET", "https://subdomain.example.subdomain.example.com")
assert.Equal(t, 403, w.Code)

maxRecursionDepth = 10
w = performRequest(router, "GET", "https://subdomain.example.subdomain.example.com")
assert.Equal(t, 200, w.Code)
}

func TestDisabled(t *testing.T) {
config := Config{
Enabled: true,
Expand All @@ -561,35 +566,26 @@ func TestDisabled(t *testing.T) {
assert.Equal(t, 200, w.Code)
}

func BenchmarkCorsWithoutWildcards(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()

b.Run("without wildcards", func(b *testing.B) {
func BenchmarkCorsWithWildcards(b *testing.B) {
b.Run("with wildcards", func(b *testing.B) {
router := newTestRouter(Config{
Enabled: true,
AllowOrigins: []string{
"https://*.wgexample.com",
"https://wgexample.com",
"https://*.wgexample.io:*",
"https://*.wgexample.org",
"https://*.d2grknavcceso7.amplifyapp.com",
"https://*.example.*.*.com", // multiple sequential wildcards
"https://*.*.*.*.com",
},
AllowMethods: []string{"GET"},
})

w := performRequest(router, "GET", "https://wgexample.com")
assert.Equal(b, 200, w.Code)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := performRequest(router, "GET", "https://subdomain.test.example.subdomain.example.co.whatgoeshere.woohoo.com")
assert.Equal(b, 200, w.Code)
}
})
}

func BenchmarkCorsWithWildcards(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()

b.Run("with wildcards", func(b *testing.B) {
b.Run("with massive wildcards", func(b *testing.B) {
router := newTestRouter(Config{
Enabled: true,
AllowOrigins: []string{
Expand All @@ -599,7 +595,30 @@ func BenchmarkCorsWithWildcards(b *testing.B) {
AllowMethods: []string{"GET"},
})

w := performRequest(router, "GET", "https://subdomain.test.example.subdomain.example.co.whatgoeshere.woohoo.com")
assert.Equal(b, 200, w.Code)
longString := strings.Repeat("a", 50000)

b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := performRequest(router, "GET", fmt.Sprintf("https://%[1]s.%[1]s.%[1]s.%[1]s.com", longString))
assert.Equal(b, 200, w.Code)
}
})

b.Run("without wildcards", func(b *testing.B) {
router := newTestRouter(Config{
Enabled: true,
AllowOrigins: []string{
"https://wgexample.com",
},
AllowMethods: []string{"GET"},
})

b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := performRequest(router, "GET", "https://wgexample.com")
assert.Equal(b, 200, w.Code)
}
})
}
Loading
Loading