diff --git a/router/pkg/cors/config.go b/router/pkg/cors/config.go index 136953c145..1a72366f29 100644 --- a/router/pkg/cors/config.go +++ b/router/pkg/cors/config.go @@ -1,8 +1,9 @@ package cors import ( + "maps" "net/http" - "strings" + "slices" ) type cors struct { @@ -12,13 +13,13 @@ type cors struct { allowOrigins []string normalHeaders http.Header preflightHeaders http.Header - wildcardOrigins [][]string + wildcardOrigins []*WildcardPattern handler http.Handler } var ( - maxRecursionDepth = 10 // Safeguard against deep recursion - DefaultSchemas = []string{ + maxWildcardOriginLength = 4096 // Maximum length of an origin string for it to be eligible for wildcard matching + DefaultSchemas = []string{ "http://", "https://", } @@ -55,7 +56,7 @@ func newCors(handler http.Handler, config Config) *cors { allowOrigins: normalize(config.AllowOrigins), normalHeaders: generateNormalHeaders(config), preflightHeaders: generatePreflightHeaders(config), - wildcardOrigins: config.parseWildcardRules(), + wildcardOrigins: config.parseNewWildcardRules(), handler: handler, } } @@ -102,10 +103,8 @@ func (cors *cors) validateOrigin(origin string) bool { if cors.allowAllOrigins { return true } - for _, value := range cors.allowOrigins { - if value == origin { - return true - } + if slices.Contains(cors.allowOrigins, origin) { + return true } if len(cors.wildcardOrigins) > 0 && cors.validateWildcardOrigin(origin) { return true @@ -117,67 +116,25 @@ func (cors *cors) validateOrigin(origin string) bool { } func (cors *cors) validateWildcardOrigin(origin string) bool { - for _, w := range cors.wildcardOrigins { - if matchOriginWithRule(origin, w, 0, map[string]bool{}) { - return true - } - } - return false -} - -// Recursive helper function with depth limit and memoization -func matchOriginWithRule(origin string, rule []string, depth int, memo map[string]bool) bool { - if depth > maxRecursionDepth { - return false // Exceeded recursion depth - } - - // Memoization key - key := origin + "|" + strings.Join(rule, "|") - if val, exists := memo[key]; exists { - return val - } - - if len(rule) == 0 { - // Successfully matched if origin is also fully consumed - return origin == "" - } - - part := rule[0] - - if part == "*" { - // Try to match the remaining rule by advancing in origin - for i := 0; i <= len(origin); i++ { - if matchOriginWithRule(origin[i:], rule[1:], depth+1, memo) { - memo[key] = true - return true - } - } - memo[key] = false + // Origin is >4KB, avoid matching it for performance + if len(origin) > maxWildcardOriginLength { return false } - // Check if the origin starts with the current part - if strings.HasPrefix(origin, part) { - // Recursively check the rest of the origin and rule - result := matchOriginWithRule(origin[len(part):], rule[1:], depth+1, memo) - memo[key] = result - return result + for _, w := range cors.wildcardOrigins { + if w.Match(origin) { + return true + } } - - memo[key] = false return false } func (cors *cors) handlePreflight(w http.ResponseWriter) { header := w.Header() - for key, value := range cors.preflightHeaders { - header[key] = value - } + maps.Copy(header, cors.preflightHeaders) } func (cors *cors) handleNormal(w http.ResponseWriter) { header := w.Header() - for key, value := range cors.normalHeaders { - header[key] = value - } + maps.Copy(header, cors.normalHeaders) } diff --git a/router/pkg/cors/cors.go b/router/pkg/cors/cors.go index f81a16288b..fd37c89aec 100644 --- a/router/pkg/cors/cors.go +++ b/router/pkg/cors/cors.go @@ -108,35 +108,16 @@ func (c *Config) Validate() error { return nil } -func (c *Config) parseWildcardRules() [][]string { - var wRules [][]string +func (c *Config) parseNewWildcardRules() []*WildcardPattern { + var wRules []*WildcardPattern for _, o := range c.AllowOrigins { if !strings.Contains(o, "*") { continue } - // Split origin by wildcard (*) - parts := strings.Split(o, "*") - - // If there’s no wildcard, skip this origin - if len(parts) == 1 { - continue - } - - // Generate rules for origins with multiple wildcard segments - var rule []string - for i, part := range parts { - if i > 0 { - rule = append(rule, "*") // Add wildcard indicator between segments - } - if part != "" { - rule = append(rule, part) - } - } - - // Add parsed rule to wRules - wRules = append(wRules, rule) + wp := Compile(o) + wRules = append(wRules, wp) } return wRules diff --git a/router/pkg/cors/cors_test.go b/router/pkg/cors/cors_test.go index bb964be43e..0a939dd744 100644 --- a/router/pkg/cors/cors_test.go +++ b/router/pkg/cors/cors_test.go @@ -2,6 +2,7 @@ package cors import ( "context" + "fmt" "net/http" "net/http/httptest" "strings" @@ -220,6 +221,29 @@ func TestGeneratePreflightHeaders_MaxAge(t *testing.T) { assert.Len(t, header, 2) } +func TestExtremeLengthOriginKillswitch(t *testing.T) { + cors := newCors(nil, Config{ + Enabled: true, + AllowOrigins: []string{"https://*.google.com"}, + }) + + shortSubdomain := strings.Repeat("a", 10) + longSubdomain := strings.Repeat("a", 500) + tooLongSubdomain := strings.Repeat("a", 4096) + + assert.True(t, cors.validateOrigin(fmt.Sprintf("https://%s.google.com", shortSubdomain))) + assert.True(t, cors.validateOrigin(fmt.Sprintf("https://%s.google.com", longSubdomain))) + assert.False(t, cors.validateOrigin(fmt.Sprintf("https://%s.google.com", tooLongSubdomain))) + + // Should not affect strict origins + cors = newCors(nil, Config{ + Enabled: true, + AllowOrigins: []string{fmt.Sprintf("https://%s.google.com", tooLongSubdomain)}, + }) + + assert.True(t, cors.validateOrigin(fmt.Sprintf("https://%s.google.com", tooLongSubdomain))) +} + func TestValidateOrigin(t *testing.T) { cors := newCors(nil, Config{ Enabled: true, @@ -519,29 +543,10 @@ func TestComplexWildcards(t *testing.T) { } for _, tc := range testCasesList { w := performRequest(router, "GET", tc.origin) - assert.Equal(t, tc.expectedCode, w.Code) + assert.Equalf(t, tc.expectedCode, w.Code, "expected %d for %s, got %d", tc.expectedCode, tc.origin, w.Code) } } -func TestMaxRecursionDepth(t *testing.T) { - router := newTestRouter(Config{ - Enabled: true, - AllowOrigins: []string{ - "https://*.example.*.*.com", // multiple sequential wildcards - "https://*.*.*.*.com", - }, - AllowMethods: []string{"GET"}, - }) - - maxRecursionDepth = 2 - w := performRequest(router, "GET", "https://subdomain.example.subdomain.example.com") - assert.Equal(t, 403, w.Code) - - maxRecursionDepth = 10 - w = performRequest(router, "GET", "https://subdomain.example.subdomain.example.com") - assert.Equal(t, 200, w.Code) -} - func TestDisabled(t *testing.T) { config := Config{ Enabled: true, @@ -561,35 +566,26 @@ func TestDisabled(t *testing.T) { assert.Equal(t, 200, w.Code) } -func BenchmarkCorsWithoutWildcards(b *testing.B) { - b.ReportAllocs() - b.ResetTimer() - - b.Run("without wildcards", func(b *testing.B) { +func BenchmarkCorsWithWildcards(b *testing.B) { + b.Run("with wildcards", func(b *testing.B) { router := newTestRouter(Config{ Enabled: true, AllowOrigins: []string{ - "https://*.wgexample.com", - "https://wgexample.com", - "https://*.wgexample.io:*", - "https://*.wgexample.org", - "https://*.d2grknavcceso7.amplifyapp.com", "https://*.example.*.*.com", // multiple sequential wildcards "https://*.*.*.*.com", }, AllowMethods: []string{"GET"}, }) - w := performRequest(router, "GET", "https://wgexample.com") - assert.Equal(b, 200, w.Code) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + w := performRequest(router, "GET", "https://subdomain.test.example.subdomain.example.co.whatgoeshere.woohoo.com") + assert.Equal(b, 200, w.Code) + } }) -} - -func BenchmarkCorsWithWildcards(b *testing.B) { - b.ReportAllocs() - b.ResetTimer() - b.Run("with wildcards", func(b *testing.B) { + b.Run("with massive wildcards", func(b *testing.B) { router := newTestRouter(Config{ Enabled: true, AllowOrigins: []string{ @@ -599,7 +595,30 @@ func BenchmarkCorsWithWildcards(b *testing.B) { AllowMethods: []string{"GET"}, }) - w := performRequest(router, "GET", "https://subdomain.test.example.subdomain.example.co.whatgoeshere.woohoo.com") - assert.Equal(b, 200, w.Code) + longString := strings.Repeat("a", 50000) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + w := performRequest(router, "GET", fmt.Sprintf("https://%[1]s.%[1]s.%[1]s.%[1]s.com", longString)) + assert.Equal(b, 200, w.Code) + } + }) + + b.Run("without wildcards", func(b *testing.B) { + router := newTestRouter(Config{ + Enabled: true, + AllowOrigins: []string{ + "https://wgexample.com", + }, + AllowMethods: []string{"GET"}, + }) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + w := performRequest(router, "GET", "https://wgexample.com") + assert.Equal(b, 200, w.Code) + } }) } diff --git a/router/pkg/cors/wildcard.go b/router/pkg/cors/wildcard.go new file mode 100644 index 0000000000..32774d0b02 --- /dev/null +++ b/router/pkg/cors/wildcard.go @@ -0,0 +1,129 @@ +package cors + +import ( + "bytes" + "regexp" + "strings" +) + +// WildcardPattern represents a wildcard pattern that can match strings containing '*' wildcards +type WildcardPattern struct { + pattern string + cards []card +} + +// card represents a literal string segment between wildcards +type card struct { + offset int + size int +} + +var repeatedWildcards = regexp.MustCompile(`\*+`) + +// NewWildcardPattern creates a new wildcard pattern from the given text +func Compile(pattern string) *WildcardPattern { + if pattern == "" { + return &WildcardPattern{ + pattern: pattern, + cards: make([]card, 0), + } + } + + pattern = repeatedWildcards.ReplaceAllString(pattern, "*") + + wp := &WildcardPattern{ + pattern: pattern, + cards: make([]card, 0), + } + + pos := strings.Index(pattern, "*") + if pos == -1 { + // No wildcards, just one card with the entire string + wp.cards = append(wp.cards, card{offset: 0, size: len(pattern)}) + return wp + } + + // Add first card (prefix before first '*') + wp.cards = append(wp.cards, card{offset: 0, size: pos}) + pos++ + + // Process middle segments between '*' characters + for { + pos2 := strings.Index(pattern[pos:], "*") + if pos2 == -1 { + break + } + pos2 += pos // Convert back to absolute position + if pos2 != pos { + // Non-empty segment between wildcards + wp.cards = append(wp.cards, card{offset: pos, size: pos2 - pos}) + } + pos = pos2 + 1 + } + + // Add last card (suffix after last '*') + wp.cards = append(wp.cards, card{offset: pos, size: len(pattern) - pos}) + + return wp +} + +// Match checks if the given string matches the wildcard pattern +func (wp *WildcardPattern) Match(s string) bool { + matched := wp.MatchBytes([]byte(s)) + return matched +} + +// MatchBytes checks if the given byte slice matches the wildcard pattern +func (wp *WildcardPattern) MatchBytes(data []byte) bool { + begin := 0 + end := len(data) + + numCards := len(wp.cards) + + // Handle empty pattern + if numCards == 0 { + return len(data) == 0 + } + + // Check anchored prefix card + firstCard := wp.cards[0] + if end-begin < firstCard.size { + return false + } + + if !bytes.Equal(data[begin:begin+firstCard.size], []byte(wp.pattern[firstCard.offset:firstCard.offset+firstCard.size])) { + return false + } + + begin += firstCard.size + + if numCards == 1 { + return begin == end + } + + // Check anchored suffix card + lastCard := wp.cards[numCards-1] + if end-begin < lastCard.size { + return false + } + + suffixPattern := wp.pattern[lastCard.offset : lastCard.offset+lastCard.size] + if string(data[end-lastCard.size:end]) != suffixPattern { + return false + } + end -= lastCard.size + + // Check unanchored infix cards + for i := 1; i < numCards-1; i++ { + card := wp.cards[i] + + // Find the pattern in the remaining data + idx := bytes.Index(data[begin:end], []byte(wp.pattern[card.offset:card.offset+card.size])) + if idx == -1 { + return false + } + begin += idx + card.size + } + + return true +} diff --git a/router/pkg/cors/wildcard_test.go b/router/pkg/cors/wildcard_test.go new file mode 100644 index 0000000000..ee147c5bd4 --- /dev/null +++ b/router/pkg/cors/wildcard_test.go @@ -0,0 +1,113 @@ +package cors + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWildcardCompile(t *testing.T) { + t.Run("should normalize repeated wildcard patterns", func(t *testing.T) { + compiled := Compile("**.example.com") + require.Equal(t, "*.example.com", compiled.pattern) + }) + + t.Run("should handle empty pattern", func(t *testing.T) { + compiled := Compile("") + require.Equal(t, "", compiled.pattern) + require.Equal(t, 0, len(compiled.cards)) + }) + + t.Run("should handle pattern without wildcards", func(t *testing.T) { + compiled := Compile("example.com") + require.Equal(t, "example.com", compiled.pattern) + require.Equal(t, 1, len(compiled.cards)) + }) + + t.Run("should handle single wildcard", func(t *testing.T) { + compiled := Compile("*") + require.Equal(t, "*", compiled.pattern) + require.Equal(t, 2, len(compiled.cards)) + }) +} + +func TestWildcardMatch(t *testing.T) { + t.Run("exact match without wildcards", func(t *testing.T) { + pattern := Compile("example.com") + require.True(t, pattern.Match("example.com")) + require.False(t, pattern.Match("test.com")) + require.False(t, pattern.Match("example.org")) + }) + + t.Run("single wildcard at start", func(t *testing.T) { + pattern := Compile("*.com") + require.True(t, pattern.Match("example.com")) + require.True(t, pattern.Match("test.com")) + require.True(t, pattern.Match(".com")) + require.False(t, pattern.Match("example.org")) + require.False(t, pattern.Match("com")) + }) + + t.Run("single wildcard at end", func(t *testing.T) { + pattern := Compile("example.*") + require.True(t, pattern.Match("example.com")) + require.True(t, pattern.Match("example.org")) + require.True(t, pattern.Match("example.")) + require.False(t, pattern.Match("test.com")) + require.False(t, pattern.Match("example")) + }) + + t.Run("single wildcard in middle", func(t *testing.T) { + pattern := Compile("api.*.com") + require.True(t, pattern.Match("api.v1.com")) + require.True(t, pattern.Match("api.test.com")) + require.True(t, pattern.Match("api..com")) + require.False(t, pattern.Match("api.com")) + require.False(t, pattern.Match("web.v1.com")) + }) + + t.Run("multiple wildcards", func(t *testing.T) { + pattern := Compile("*.api.*.com") + require.True(t, pattern.Match("sub.api.v1.com")) + require.True(t, pattern.Match("test.api.prod.com")) + require.False(t, pattern.Match("api.v1.com")) + require.False(t, pattern.Match("sub.api.com")) + }) + + t.Run("only wildcards", func(t *testing.T) { + pattern := Compile("*") + require.True(t, pattern.Match("anything")) + require.True(t, pattern.Match("")) + require.True(t, pattern.Match("a")) + }) + + t.Run("empty string patterns", func(t *testing.T) { + pattern := Compile("") + require.True(t, pattern.Match("")) + require.False(t, pattern.Match("anything")) + }) + + t.Run("normalized consecutive wildcards", func(t *testing.T) { + pattern := Compile("**.example.com") + require.True(t, pattern.Match("sub.example.com")) + require.True(t, pattern.Match("deep.sub.example.com")) + require.False(t, pattern.Match("example.org")) + }) +} + +func TestWildcardMatchBytes(t *testing.T) { + t.Run("should match byte slice same as string", func(t *testing.T) { + pattern := Compile("*.example.com") + testStr := "sub.example.com" + + require.Equal(t, pattern.Match(testStr), pattern.MatchBytes([]byte(testStr))) + }) + + t.Run("should handle empty byte slice", func(t *testing.T) { + pattern := Compile("*") + require.True(t, pattern.MatchBytes([]byte{})) + + pattern2 := Compile("test") + require.False(t, pattern2.MatchBytes([]byte{})) + }) +}