Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
337 changes: 333 additions & 4 deletions router-tests/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"io"
"net/http"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -56,6 +58,333 @@ func TestAuthentication(t *testing.T) {
})
})

t.Run("unknown kid refresh blocks when burst exceeded", func(t *testing.T) {
t.Parallel()

authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)

authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{
{
URL: authServer.JWKSURL(),
RefreshInterval: 10 * time.Second,
RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{
Enabled: true,
Interval: 1 * time.Second,
Burst: 1,
},
},
})

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(core.NewAccessController(authenticators, true)),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
token, err := authServer.TokenForKID("unknown_kid", nil, true)
require.NoError(t, err)

header := http.Header{"Authorization": []string{"Bearer " + token}}

res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res1.Body.Close() }()
require.Equal(t, http.StatusUnauthorized, res1.StatusCode)
_, err = io.ReadAll(res1.Body)
require.NoError(t, err)

start := time.Now()
res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res2.Body.Close() }()
elapsed := time.Since(start)

require.True(t, elapsed >= 700*time.Millisecond)
require.Equal(t, http.StatusUnauthorized, res2.StatusCode)
data, err := io.ReadAll(res2.Body)
require.NoError(t, err)
require.JSONEq(t, unauthorizedExpectedData, string(data))
})
})

t.Run("unknown kid refresh does not block when burst not exceeded", func(t *testing.T) {
t.Parallel()

authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)

authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{
{
URL: authServer.JWKSURL(),
RefreshInterval: 10 * time.Second,
RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{
Enabled: true,
Interval: 1 * time.Second,
Burst: 1,
},
},
})

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(core.NewAccessController(authenticators, true)),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
token, err := authServer.TokenForKID("unknown_kid", nil, true)
require.NoError(t, err)
header := http.Header{"Authorization": []string{"Bearer " + token}}

res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res.Body.Close() }()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
_, err = io.ReadAll(res.Body)
require.NoError(t, err)

// Wait for interval so next refresh is within burst budget
time.Sleep(1200 * time.Millisecond)

start := time.Now()
res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res2.Body.Close() }()
elapsed := time.Since(start)
require.True(t, elapsed < 100*time.Millisecond)
require.Equal(t, http.StatusUnauthorized, res2.StatusCode)
data, err := io.ReadAll(res2.Body)
require.NoError(t, err)
require.JSONEq(t, unauthorizedExpectedData, string(data))
})
})

// Since the rate limiter knows that the limit will definitely be exceeded it exits
// immediately without waiting
t.Run("unknown kid refresh interval exceeding max wait returns immediately", func(t *testing.T) {
t.Parallel()

authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)

authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{
{
URL: authServer.JWKSURL(),
RefreshInterval: 10 * time.Second,
RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{
Enabled: true,
Interval: 1 * time.Second, // next token available in ~1s
Burst: 1,
MaxWait: 700 * time.Millisecond, // cap wait well below interval
},
},
})

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(core.NewAccessController(authenticators, true)),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
token, err := authServer.TokenForKID("unknown_kid", nil, true)
require.NoError(t, err)

header := http.Header{"Authorization": []string{"Bearer " + token}}

res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res1.Body.Close() }()
require.Equal(t, http.StatusUnauthorized, res1.StatusCode)
_, err = io.ReadAll(res1.Body)
require.NoError(t, err)

// Next call should exceed max wait so should return immediately
start := time.Now()
res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res2.Body.Close() }()
elapsed := time.Since(start)
require.True(t, elapsed < 100*time.Millisecond)
require.Equal(t, http.StatusUnauthorized, res2.StatusCode)
data, err := io.ReadAll(res2.Body)
require.NoError(t, err)
require.JSONEq(t, unauthorizedExpectedData, string(data))
})
})

t.Run("unknown kid refresh exceeding burst waits until interval when max wait larger", func(t *testing.T) {
t.Parallel()

authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)

authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{
{
URL: authServer.JWKSURL(),
RefreshInterval: 10 * time.Second,
RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{
Enabled: true,
Interval: 1 * time.Second,
Burst: 1,
MaxWait: 2 * time.Second, // larger than interval, so it can wait until next token
},
},
})

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(core.NewAccessController(authenticators, true)),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
token, err := authServer.TokenForKID("unknown_kid", nil, true)
require.NoError(t, err)

header := http.Header{"Authorization": []string{"Bearer " + token}}

res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res1.Body.Close() }()
require.Equal(t, http.StatusUnauthorized, res1.StatusCode)
_, err = io.ReadAll(res1.Body)
require.NoError(t, err)

start := time.Now()
res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res2.Body.Close() }()
elapsed := time.Since(start)

require.True(t, elapsed >= 600*time.Millisecond)
require.Equal(t, http.StatusUnauthorized, res2.StatusCode)
data, err := io.ReadAll(res2.Body)
require.NoError(t, err)
require.JSONEq(t, unauthorizedExpectedData, string(data))
})
})

// After consuming the single burst token, launch multiple requests in parallel.
// Each should block if the max limit has not been accumulated
t.Run("unknown kid refresh parallel exceeding burst waits up to max wait", func(t *testing.T) {
t.Parallel()

authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)

const waitEntries = 4

authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{
{
URL: authServer.JWKSURL(),
RefreshInterval: 10 * time.Second,
RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{
Enabled: true,
Interval: 1 * time.Second,
Burst: 1,
MaxWait: waitEntries * time.Second,
},
},
})

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(core.NewAccessController(authenticators, true)),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
token, err := authServer.TokenForKID("unknown_kid", nil, true)
require.NoError(t, err)

header := http.Header{"Authorization": []string{"Bearer " + token}}

// Send initial request to use up the burst token
res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res1.Body.Close() }()
require.Equal(t, http.StatusUnauthorized, res1.StatusCode)
_, err = io.ReadAll(res1.Body)
require.NoError(t, err)

var elapsedFastCounter atomic.Int64
var wg sync.WaitGroup

for range waitEntries + 1 {
wg.Add(1)

go func() {
defer wg.Done()

start := time.Now()
res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res2.Body.Close() }()

elapsed := time.Since(start)

if elapsed < 100*time.Millisecond {
elapsedFastCounter.Add(1)
}

require.True(t, elapsed < 50*time.Millisecond || elapsed >= 700*time.Millisecond)
require.Equal(t, http.StatusUnauthorized, res2.StatusCode)
data, err := io.ReadAll(res2.Body)
require.NoError(t, err)
require.JSONEq(t, unauthorizedExpectedData, string(data))
}()
}

wg.Wait()

// We only exit early on the 5th request as by the 5th request we have accumulated
// enough tokens to exceed the max wait duration
require.Equal(t, 1, int(elapsedFastCounter.Load()))
})
})

t.Run("authentication should not block with unknown kid when refresh is disabled", func(t *testing.T) {
t.Parallel()

authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)

authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{
{
URL: authServer.JWKSURL(),
RefreshInterval: 100 * time.Millisecond,
},
})

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(core.NewAccessController(authenticators, true)),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
// Create a token signed with a valid key but with an unknown kid header
token, err := authServer.TokenForKID("unknown_kid", nil, true)
require.NoError(t, err)

maxDuration := 4 * time.Second
testenv.AwaitFunc(t, maxDuration, func() {
for range 5 {
func() {
header := http.Header{
"Authorization": []string{"Bearer " + token},
}
res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res.Body.Close() }()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader))
data, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.JSONEq(t, unauthorizedExpectedData, string(data))
}()
}
})
})
})

t.Run("invalid token", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -1182,7 +1511,7 @@ func TestAlgorithmMismatch(t *testing.T) {

authenticators := []authentication.Authenticator{authenticator}

token, err := authServer.TokenForKID(crypto.KID(), nil)
token, err := authServer.TokenForKID(crypto.KID(), nil, false)
require.NoError(t, err)

return token, authenticators
Expand Down Expand Up @@ -1307,7 +1636,7 @@ func TestOidcDiscovery(t *testing.T) {
tokens := make(map[string]string)

for _, c := range crypto {
token, err := authServer.TokenForKID(c.KID(), nil)
token, err := authServer.TokenForKID(c.KID(), nil, false)
require.NoError(t, err)

tokens[c.KID()] = token
Expand Down Expand Up @@ -1421,7 +1750,7 @@ func TestMultipleKeys(t *testing.T) {
tokens := make(map[string]string)

for _, c := range crypto {
token, err := authServer.TokenForKID(c.KID(), nil)
token, err := authServer.TokenForKID(c.KID(), nil, false)
require.NoError(t, err)

tokens[c.KID()] = token
Expand Down Expand Up @@ -1604,7 +1933,7 @@ func TestSupportedAlgorithms(t *testing.T) {

authenticators := []authentication.Authenticator{authenticator}

token, err := authServer.TokenForKID(crypto.KID(), nil)
token, err := authServer.TokenForKID(crypto.KID(), nil, false)
require.NoError(t, err)

return token, authenticators
Expand Down
2 changes: 1 addition & 1 deletion router-tests/header_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ func TestHeaderSetWithExpression(t *testing.T) {
authenticator, err := authentication.NewHttpHeaderAuthenticator(authOptions)
require.NoError(t, err)

token, err := authServer.TokenForKID(rsa1.KID(), map[string]any{"user_id": "TestId"})
token, err := authServer.TokenForKID(rsa1.KID(), map[string]any{"user_id": "TestId"}, false)
require.NoError(t, err)

testenv.Run(t, &testenv.Config{
Expand Down
Loading
Loading