Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
t.Helper()
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(srv.Close)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
return srv
}

Expand Down Expand Up @@ -705,8 +706,9 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {

tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(proxySrv.Close)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)

proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
Expand Down Expand Up @@ -1136,8 +1138,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) {

tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(proxySrv.Close)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)

proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
Expand Down
12 changes: 11 additions & 1 deletion management/internals/server/boot.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {

func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())
Expand Down Expand Up @@ -203,6 +203,16 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
})
}

func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
return Create(s, func() *nbgrpc.PKCEVerifierStore {
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
if err != nil {
log.Fatalf("failed to create PKCE verifier store: %v", err)
}
return pkceStore
})
}

func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
return Create(s, func() accesslogs.Manager {
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
Expand Down
1 change: 0 additions & 1 deletion management/internals/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ func (s *BaseServer) Stop() error {
_ = s.certManager.Listener().Close()
}
s.GRPCServer().Stop()
s.ReverseProxyGRPCServer().Close()
if s.proxyAuthClose != nil {
s.proxyAuthClose()
s.proxyAuthClose = nil
Expand Down
61 changes: 61 additions & 0 deletions management/internals/shared/grpc/pkce_verifier.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package grpc

import (
"context"
"fmt"
"time"

"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus"

nbcache "github.com/netbirdio/netbird/management/server/cache"
)

// PKCEVerifierStore manages PKCE verifiers for OAuth flows.
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
type PKCEVerifierStore struct {
cache *cache.Cache[string]
ctx context.Context
}

// NewPKCEVerifierStore creates a PKCE verifier store with automatic backend selection
func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) {
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
if err != nil {
return nil, fmt.Errorf("failed to create cache store: %w", err)
}

return &PKCEVerifierStore{
cache: cache.New[string](cacheStore),
ctx: ctx,
}, nil
}

// Store saves a PKCE verifier associated with an OAuth state parameter.
// The verifier is stored with the specified TTL and will be automatically deleted after expiration.
func (s *PKCEVerifierStore) Store(state, verifier string, ttl time.Duration) error {
if err := s.cache.Set(s.ctx, state, verifier, store.WithExpiration(ttl)); err != nil {
return fmt.Errorf("failed to store PKCE verifier: %w", err)
}

log.Debugf("Stored PKCE verifier for state (expires in %s)", ttl)
return nil
}

// LoadAndDelete retrieves and removes a PKCE verifier for the given state.
// Returns the verifier and true if found, or empty string and false if not found.
// This enforces single-use semantics for PKCE verifiers.
func (s *PKCEVerifierStore) LoadAndDelete(state string) (string, bool) {
verifier, err := s.cache.Get(s.ctx, state)
if err != nil {
log.Debugf("PKCE verifier not found for state")
return "", false
}

if err := s.cache.Delete(s.ctx, state); err != nil {
log.Warnf("Failed to delete PKCE verifier for state: %v", err)
}

return verifier, true
Comment thread
pascal-fischer marked this conversation as resolved.
}
59 changes: 10 additions & 49 deletions management/internals/shared/grpc/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,12 @@ type ProxyServiceServer struct {
// OIDC configuration for proxy authentication
oidcConfig ProxyOIDCConfig

// TODO: use database to store these instead?
// pkceVerifiers stores PKCE code verifiers keyed by OAuth state.
// Entries expire after pkceVerifierTTL to prevent unbounded growth.
pkceVerifiers sync.Map
pkceCleanupCancel context.CancelFunc
// Store for PKCE verifiers
pkceVerifierStore *PKCEVerifierStore
Comment thread
pascal-fischer marked this conversation as resolved.
}

const pkceVerifierTTL = 10 * time.Minute

type pkceEntry struct {
verifier string
createdAt time.Time
}

// proxyConnection represents a connected proxy
type proxyConnection struct {
proxyID string
Expand All @@ -108,42 +100,21 @@ type proxyConnection struct {
}

// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx := context.Background()
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
oidcConfig: oidcConfig,
tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
peersManager: peersManager,
usersManager: usersManager,
proxyManager: proxyMgr,
pkceCleanupCancel: cancel,
}
go s.cleanupPKCEVerifiers(ctx)
go s.cleanupStaleProxies(ctx)
return s
}

// cleanupPKCEVerifiers periodically removes expired PKCE verifiers.
func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
ticker := time.NewTicker(pkceVerifierTTL)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
now := time.Now()
s.pkceVerifiers.Range(func(key, value any) bool {
if entry, ok := value.(pkceEntry); ok && now.Sub(entry.createdAt) > pkceVerifierTTL {
s.pkceVerifiers.Delete(key)
}
return true
})
}
}
}

// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
ticker := time.NewTicker(5 * time.Minute)
Expand All @@ -160,11 +131,6 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
}
}

// Close stops background goroutines.
func (s *ProxyServiceServer) Close() {
s.pkceCleanupCancel()
}

func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
s.serviceManager = manager
}
Expand Down Expand Up @@ -795,7 +761,10 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum)

codeVerifier := oauth2.GenerateVerifier()
s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()})
if err := s.pkceVerifierStore.Store(state, codeVerifier, pkceVerifierTTL); err != nil {
log.WithContext(ctx).Errorf("failed to store PKCE verifier: %v", err)
return nil, status.Errorf(codes.Internal, "store PKCE verifier: %v", err)
}

return &proto.GetOIDCURLResponse{
Url: (&oauth2.Config{
Expand Down Expand Up @@ -832,18 +801,10 @@ func (s *ProxyServiceServer) generateHMAC(input string) string {
// ValidateState validates the state parameter from an OAuth callback.
// Returns the original redirect URL if valid, or an error if invalid.
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
v, ok := s.pkceVerifiers.LoadAndDelete(state)
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
if !ok {
return "", "", errors.New("no verifier for state")
}
entry, ok := v.(pkceEntry)
if !ok {
return "", "", errors.New("invalid verifier for state")
}
if time.Since(entry.createdAt) > pkceVerifierTTL {
return "", "", errors.New("PKCE verifier expired")
}
verifier = entry.verifier

// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
parts := strings.Split(state, "|")
Expand Down
55 changes: 43 additions & 12 deletions management/internals/shared/grpc/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@ import (
"crypto/rand"
"encoding/base64"
"strings"
"sync"
"testing"
"time"

"sync"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -94,11 +93,16 @@ func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpda
}

func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)

pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)

s := &ProxyServiceServer{
tokenStore: tokenStore,
tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
}
s.SetProxyController(newTestProxyController())

Expand Down Expand Up @@ -151,11 +155,16 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
}

func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)

pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)

s := &ProxyServiceServer{
tokenStore: tokenStore,
tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
}
s.SetProxyController(newTestProxyController())

Expand Down Expand Up @@ -185,11 +194,16 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
}

func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)

pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)

s := &ProxyServiceServer{
tokenStore: tokenStore,
tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
}
s.SetProxyController(newTestProxyController())

Expand Down Expand Up @@ -241,10 +255,15 @@ func generateState(s *ProxyServiceServer, redirectURL string) string {
}

func TestOAuthState_NeverTheSame(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)

s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
pkceVerifierStore: pkceStore,
}

redirectURL := "https://app.example.com/callback"
Expand All @@ -265,31 +284,43 @@ func TestOAuthState_NeverTheSame(t *testing.T) {
}

func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)

s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
pkceVerifierStore: pkceStore,
}

// Old format had only 2 parts: base64(url)|hmac
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute)
require.NoError(t, err)

_, _, err := s.ValidateState("base64url|hmac")
_, _, err = s.ValidateState("base64url|hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state format")
}

func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)

s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
pkceVerifierStore: pkceStore,
}

// Store with tampered HMAC
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute)
require.NoError(t, err)

_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
_, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state signature")
}
5 changes: 4 additions & 1 deletion management/internals/shared/grpc/validate_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)

proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)

proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService.SetServiceManager(serviceManager)

createTestProxies(t, ctx, testStore)
Expand Down
2 changes: 1 addition & 1 deletion management/server/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3133,7 +3133,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}

proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err
Expand Down
Loading
Loading