diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 2c4c1372e21..ab899e0bf84 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -31,6 +31,7 @@ type store interface { type proxyManager interface { GetActiveClusterAddresses(ctx context.Context) ([]string, error) + GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool @@ -71,8 +72,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d var ret []*domain.Domain // Add connected proxy clusters as free domains. - // The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io"). - allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + // For BYOP accounts, only their own cluster is returned; otherwise shared clusters. + allowList, err := m.getClusterAllowList(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err) return nil, err @@ -126,8 +127,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName return nil, status.NewPermissionDeniedError() } - // Verify the target cluster is in the available clusters - allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + // Verify the target cluster is in the available clusters for this account + allowList, err := m.getClusterAllowList(ctx, accountID) if err != nil { return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err) } @@ -273,7 +274,7 @@ func (m Manager) GetClusterDomains() []string { // For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain. // For custom domains, the cluster is determined by checking the registered custom domain's target cluster. func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) { - allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + allowList, err := m.getClusterAllowList(ctx, accountID) if err != nil { return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err) } @@ -298,6 +299,17 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain) } +func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) { + byopAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID) + if err != nil { + return nil, fmt.Errorf("get BYOP cluster addresses: %w", err) + } + if len(byopAddresses) > 0 { + return byopAddresses, nil + } + return m.proxyManager.GetActiveClusterAddresses(ctx) +} + func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) { bestCluster := "" bestLen := -1 diff --git a/management/internals/modules/reverseproxy/domain/manager/manager_test.go b/management/internals/modules/reverseproxy/domain/manager/manager_test.go new file mode 100644 index 00000000000..fdeb0765ff4 --- /dev/null +++ b/management/internals/modules/reverseproxy/domain/manager/manager_test.go @@ -0,0 +1,110 @@ +package manager + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockProxyManager struct { + getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error) + getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error) +} + +func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { + if m.getActiveClusterAddressesFunc != nil { + return m.getActiveClusterAddressesFunc(ctx) + } + return nil, nil +} + +func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + if m.getActiveClusterAddressesForAccountFunc != nil { + return m.getActiveClusterAddressesForAccountFunc(ctx, accountID) + } + return nil, nil +} + +func (m *mockProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { + return nil +} + +func (m *mockProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool { + return nil +} + +func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool { + return nil +} + +func TestGetClusterAllowList_BYOPProxy(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) { + assert.Equal(t, "acc-123", accID) + return []string{"byop.example.com"}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist") + return nil, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"byop.example.com"}, result) +} + +func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return nil, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result) +} + +func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return nil, errors.New("db error") + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails") + return nil, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "BYOP cluster addresses") +} + +func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return []string{}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return []string{"eu.proxy.netbird.io"}, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"eu.proxy.netbird.io"}, result) +} + diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index aa7cd8630e5..e491e2bbcb4 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -11,15 +11,19 @@ import ( // Manager defines the interface for proxy operations type Manager interface { - Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error + Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) error Disconnect(ctx context.Context, proxyID string) error Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveClusterAddresses(ctx context.Context) ([]string, error) - GetActiveClusters(ctx context.Context) ([]Cluster, error) + GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool CleanupStale(ctx context.Context, inactivityDuration time.Duration) error + GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) + CountAccountProxies(ctx context.Context, accountID string) (int64, error) + IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) + DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error } // OIDCValidationConfig contains the OIDC configuration needed for token validation. diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go index d13334e83b0..58e612f6e3f 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -13,13 +13,19 @@ import ( // store defines the interface for proxy persistence operations type store interface { SaveProxy(ctx context.Context, p *proxy.Proxy) error + DisconnectProxy(ctx context.Context, proxyID string) error UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) - GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) + GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error + GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) + CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) + IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) + DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error } // Manager handles all proxy operations @@ -43,7 +49,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) { // Connect registers a new proxy connection in the database. // capabilities may be nil for old proxies that do not report them. -func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error { +func (m *Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string, capabilities *proxy.Capabilities) error { now := time.Now() var caps proxy.Capabilities if capabilities != nil { @@ -53,9 +59,10 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress ID: proxyID, ClusterAddress: clusterAddress, IPAddress: ipAddress, + AccountID: accountID, LastSeen: now, ConnectedAt: &now, - Status: "connected", + Status: proxy.StatusConnected, Capabilities: caps, } @@ -74,16 +81,8 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress } // Disconnect marks a proxy as disconnected in the database -func (m Manager) Disconnect(ctx context.Context, proxyID string) error { - now := time.Now() - p := &proxy.Proxy{ - ID: proxyID, - Status: "disconnected", - DisconnectedAt: &now, - LastSeen: now, - } - - if err := m.store.SaveProxy(ctx, p); err != nil { +func (m *Manager) Disconnect(ctx context.Context, proxyID string) error { + if err := m.store.DisconnectProxy(ctx, proxyID); err != nil { log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err) return err } @@ -96,7 +95,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID string) error { } // Heartbeat updates the proxy's last seen timestamp -func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +func (m *Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err) return err @@ -108,7 +107,7 @@ func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddre } // GetActiveClusterAddresses returns all unique cluster addresses for active proxies -func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { +func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { addresses, err := m.store.GetActiveProxyClusterAddresses(ctx) if err != nil { log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err) @@ -117,16 +116,6 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error return addresses, nil } -// GetActiveClusters returns all active proxy clusters with their connected proxy count. -func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) { - clusters, err := m.store.GetActiveProxyClusters(ctx) - if err != nil { - log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err) - return nil, err - } - return clusters, nil -} - // ClusterSupportsCustomPorts returns whether any active proxy in the cluster // supports custom ports. Returns nil when no proxy has reported capabilities. func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { @@ -146,10 +135,44 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string } // CleanupStale removes proxies that haven't sent heartbeat in the specified duration -func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { +func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil { log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err) return err } return nil } + +func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err) + return nil, err + } + return addresses, nil +} + +func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) { + return m.store.GetProxyByAccountID(ctx, accountID) +} + +func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) { + return m.store.CountProxiesByAccountID(ctx, accountID) +} + +func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) { + conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID) + if err != nil { + return false, err + } + return !conflicting, nil +} + +func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + if err := m.store.DeleteAccountCluster(ctx, clusterAddress, accountID); err != nil { + log.WithContext(ctx).Errorf("failed to delete cluster %s for account %s: %v", clusterAddress, accountID, err) + return err + } + return nil +} + diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager_test.go b/management/internals/modules/reverseproxy/proxy/manager/manager_test.go new file mode 100644 index 00000000000..8bbb275ff85 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/manager/manager_test.go @@ -0,0 +1,336 @@ +package manager + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" +) + +type mockStore struct { + saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error + disconnectProxyFunc func(ctx context.Context, proxyID string) error + updateProxyHeartbeatFunc func(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error) + getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error) + cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error + getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error) + countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error) + isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error) + deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error +} + +func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { + if m.saveProxyFunc != nil { + return m.saveProxyFunc(ctx, p) + } + return nil +} +func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID string) error { + if m.disconnectProxyFunc != nil { + return m.disconnectProxyFunc(ctx, proxyID) + } + return nil +} +func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { + if m.updateProxyHeartbeatFunc != nil { + return m.updateProxyHeartbeatFunc(ctx, proxyID, clusterAddress, ipAddress) + } + return nil +} +func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) { + if m.getActiveProxyClusterAddressesFunc != nil { + return m.getActiveProxyClusterAddressesFunc(ctx) + } + return nil, nil +} +func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + if m.getActiveProxyClusterAddressesForAccFunc != nil { + return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID) + } + return nil, nil +} +func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) { + return nil, nil +} +func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error { + if m.cleanupStaleProxiesFunc != nil { + return m.cleanupStaleProxiesFunc(ctx, d) + } + return nil +} +func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) { + if m.getProxyByAccountIDFunc != nil { + return m.getProxyByAccountIDFunc(ctx, accountID) + } + return nil, fmt.Errorf("proxy not found for account %s", accountID) +} +func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) { + if m.countProxiesByAccountIDFunc != nil { + return m.countProxiesByAccountIDFunc(ctx, accountID) + } + return 0, nil +} +func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) { + if m.isClusterAddressConflictingFunc != nil { + return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID) + } + return false, nil +} +func (m *mockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + if m.deleteAccountClusterFunc != nil { + return m.deleteAccountClusterFunc(ctx, clusterAddress, accountID) + } + return nil +} +func (m *mockStore) GetClusterSupportsCustomPorts(_ context.Context, _ string) *bool { + return nil +} +func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *bool { + return nil +} +func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool { + return nil +} + +func newTestManager(s store) *Manager { + meter := noop.NewMeterProvider().Meter("test") + m, err := NewManager(s, meter) + if err != nil { + panic(err) + } + return m +} + +func TestConnect_WithAccountID(t *testing.T) { + accountID := "acc-123" + + var savedProxy *proxy.Proxy + s := &mockStore{ + saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error { + savedProxy = p + return nil + }, + } + + mgr := newTestManager(s) + err := mgr.Connect(context.Background(), "proxy-1", "cluster.example.com", "10.0.0.1", &accountID, nil) + require.NoError(t, err) + + require.NotNil(t, savedProxy) + assert.Equal(t, "proxy-1", savedProxy.ID) + assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress) + assert.Equal(t, "10.0.0.1", savedProxy.IPAddress) + assert.Equal(t, &accountID, savedProxy.AccountID) + assert.Equal(t, proxy.StatusConnected, savedProxy.Status) + assert.NotNil(t, savedProxy.ConnectedAt) +} + +func TestConnect_WithoutAccountID(t *testing.T) { + var savedProxy *proxy.Proxy + s := &mockStore{ + saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error { + savedProxy = p + return nil + }, + } + + mgr := newTestManager(s) + err := mgr.Connect(context.Background(), "proxy-1", "eu.proxy.netbird.io", "10.0.0.1", nil, nil) + require.NoError(t, err) + + require.NotNil(t, savedProxy) + assert.Nil(t, savedProxy.AccountID) + assert.Equal(t, proxy.StatusConnected, savedProxy.Status) +} + +func TestConnect_StoreError(t *testing.T) { + s := &mockStore{ + saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error { + return errors.New("db error") + }, + } + + mgr := newTestManager(s) + err := mgr.Connect(context.Background(), "proxy-1", "cluster.example.com", "10.0.0.1", nil, nil) + assert.Error(t, err) +} + +func TestIsClusterAddressAvailable(t *testing.T) { + tests := []struct { + name string + conflicting bool + storeErr error + wantResult bool + wantErr bool + }{ + { + name: "available - no conflict", + conflicting: false, + wantResult: true, + }, + { + name: "not available - conflict exists", + conflicting: true, + wantResult: false, + }, + { + name: "store error", + storeErr: errors.New("db error"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &mockStore{ + isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) { + return tt.conflicting, tt.storeErr + }, + } + + mgr := newTestManager(s) + result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123") + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantResult, result) + }) + } +} + +func TestCountAccountProxies(t *testing.T) { + tests := []struct { + name string + count int64 + storeErr error + wantCount int64 + wantErr bool + }{ + { + name: "no proxies", + count: 0, + wantCount: 0, + }, + { + name: "one proxy", + count: 1, + wantCount: 1, + }, + { + name: "store error", + storeErr: errors.New("db error"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &mockStore{ + countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) { + return tt.count, tt.storeErr + }, + } + + mgr := newTestManager(s) + count, err := mgr.CountAccountProxies(context.Background(), "acc-123") + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantCount, count) + }) + } +} + +func TestGetAccountProxy(t *testing.T) { + accountID := "acc-123" + + t.Run("found", func(t *testing.T) { + expected := &proxy.Proxy{ + ID: "proxy-1", + ClusterAddress: "byop.example.com", + AccountID: &accountID, + Status: proxy.StatusConnected, + } + s := &mockStore{ + getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) { + assert.Equal(t, accountID, accID) + return expected, nil + }, + } + + mgr := newTestManager(s) + p, err := mgr.GetAccountProxy(context.Background(), accountID) + require.NoError(t, err) + assert.Equal(t, expected, p) + }) + + t.Run("not found", func(t *testing.T) { + s := &mockStore{ + getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) { + return nil, errors.New("not found") + }, + } + + mgr := newTestManager(s) + _, err := mgr.GetAccountProxy(context.Background(), accountID) + assert.Error(t, err) + }) +} + +func TestDeleteAccountCluster(t *testing.T) { + t.Run("success", func(t *testing.T) { + var deletedCluster, deletedAccount string + s := &mockStore{ + deleteAccountClusterFunc: func(_ context.Context, clusterAddress, accountID string) error { + deletedCluster = clusterAddress + deletedAccount = accountID + return nil + }, + } + + mgr := newTestManager(s) + err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123") + require.NoError(t, err) + assert.Equal(t, "cluster.example.com", deletedCluster) + assert.Equal(t, "acc-123", deletedAccount) + }) + + t.Run("store error", func(t *testing.T) { + s := &mockStore{ + deleteAccountClusterFunc: func(_ context.Context, _, _ string) error { + return errors.New("db error") + }, + } + + mgr := newTestManager(s) + err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123") + assert.Error(t, err) + }) +} + +func TestGetActiveClusterAddressesForAccount(t *testing.T) { + expected := []string{"byop.example.com"} + s := &mockStore{ + getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) { + assert.Equal(t, "acc-123", accID) + return expected, nil + }, + } + + mgr := newTestManager(s) + result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, expected, result) +} diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index 282ca0ba5a0..5d43fae7aec 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -93,17 +93,17 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte } // Connect mocks base method. -func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error { +func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities) + ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, accountID, capabilities) ret0, _ := ret[0].(error) return ret0 } // Connect indicates an expected call of Connect. -func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, accountID, capabilities interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, accountID, capabilities) } // Disconnect mocks base method. @@ -135,19 +135,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx) } -// GetActiveClusters mocks base method. -func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) { +func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetActiveClusters", ctx) - ret0, _ := ret[0].([]Cluster) + ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID) + ret0, _ := ret[0].([]string) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetActiveClusters indicates an expected call of GetActiveClusters. -func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID) } // Heartbeat mocks base method. @@ -164,6 +162,65 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAdd return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress) } +// GetAccountProxy mocks base method. +func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID) + ret0, _ := ret[0].(*Proxy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountProxy indicates an expected call of GetAccountProxy. +func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID) +} + +// CountAccountProxies mocks base method. +func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountAccountProxies indicates an expected call of CountAccountProxies. +func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID) +} + +// IsClusterAddressAvailable mocks base method. +func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable. +func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID) +} + +// DeleteAccountCluster mocks base method. +func (m *MockManager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAccountCluster indicates an expected call of DeleteAccountCluster. +func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID) +} + // MockController is a mock of Controller interface. type MockController struct { ctrl *gomock.Controller diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index 339c82446ca..eaff09aa6be 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -1,6 +1,13 @@ package proxy -import "time" +import ( + "time" +) + +const ( + StatusConnected = "connected" + StatusDisconnected = "disconnected" +) // Capabilities describes what a proxy can handle, as reported via gRPC. // Nil fields mean the proxy never reported this capability. @@ -20,6 +27,7 @@ type Proxy struct { ID string `gorm:"primaryKey;type:varchar(255)"` ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"` IPAddress string `gorm:"type:varchar(45)"` + AccountID *string `gorm:"type:varchar(255);index:idx_proxy_account_id"` LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` ConnectedAt *time.Time DisconnectedAt *time.Time @@ -35,6 +43,8 @@ func (Proxy) TableName() string { // Cluster represents a group of proxy nodes serving the same address. type Cluster struct { + ID string Address string ConnectedProxies int + SelfHosted bool } diff --git a/management/internals/modules/reverseproxy/proxytoken/handler.go b/management/internals/modules/reverseproxy/proxytoken/handler.go new file mode 100644 index 00000000000..728cdf723ce --- /dev/null +++ b/management/internals/modules/reverseproxy/proxytoken/handler.go @@ -0,0 +1,195 @@ +package proxytoken + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/gorilla/mux" + + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" +) + +type handler struct { + store store.Store + permissionsManager permissions.Manager +} + +func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) { + h := &handler{store: s, permissionsManager: permissionsManager} + router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS") + router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS") + router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS") +} + +func (h *handler) createToken(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create) + if err != nil { + util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w) + return + } + if !ok { + util.WriteErrorResponse("permission denied", http.StatusForbidden, w) + return + } + + var req api.ProxyTokenRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + if req.Name == "" || len(req.Name) > 255 { + util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w) + return + } + + var expiresIn time.Duration + if req.ExpiresIn != nil { + if *req.ExpiresIn < 0 { + util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w) + return + } + if *req.ExpiresIn > 0 { + expiresIn = time.Duration(*req.ExpiresIn) * time.Second + } + } + + accountID := userAuth.AccountId + generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId) + if err != nil { + util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w) + return + } + + if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil { + util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w) + return + } + + resp := toProxyTokenCreatedResponse(generated) + util.WriteJSONObject(r.Context(), w, resp) +} + +func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read) + if err != nil { + util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w) + return + } + if !ok { + util.WriteErrorResponse("permission denied", http.StatusForbidden, w) + return + } + + tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId) + if err != nil { + util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w) + return + } + + resp := make([]api.ProxyToken, 0, len(tokens)) + for _, token := range tokens { + resp = append(resp, toProxyTokenResponse(token)) + } + + util.WriteJSONObject(r.Context(), w, resp) +} + +func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete) + if err != nil { + util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w) + return + } + if !ok { + util.WriteErrorResponse("permission denied", http.StatusForbidden, w) + return + } + + tokenID := mux.Vars(r)["tokenId"] + if tokenID == "" { + util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w) + return + } + + token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID) + if err != nil { + if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound { + util.WriteErrorResponse("token not found", http.StatusNotFound, w) + } else { + util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w) + } + return + } + + if token.AccountID == nil || *token.AccountID != userAuth.AccountId { + util.WriteErrorResponse("token not found", http.StatusNotFound, w) + return + } + + if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil { + util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} + +func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken { + resp := api.ProxyToken{ + Id: token.ID, + Name: token.Name, + Revoked: token.Revoked, + } + if !token.CreatedAt.IsZero() { + resp.CreatedAt = token.CreatedAt + } + if token.ExpiresAt != nil { + resp.ExpiresAt = token.ExpiresAt + } + if token.LastUsed != nil { + resp.LastUsed = token.LastUsed + } + return resp +} + +func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated { + base := toProxyTokenResponse(&generated.ProxyAccessToken) + plainToken := string(generated.PlainToken) + return api.ProxyTokenCreated{ + Id: base.Id, + Name: base.Name, + CreatedAt: base.CreatedAt, + ExpiresAt: base.ExpiresAt, + LastUsed: base.LastUsed, + Revoked: base.Revoked, + PlainToken: plainToken, + } +} diff --git a/management/internals/modules/reverseproxy/proxytoken/handler_test.go b/management/internals/modules/reverseproxy/proxytoken/handler_test.go new file mode 100644 index 00000000000..a2875290981 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxytoken/handler_test.go @@ -0,0 +1,275 @@ +package proxytoken + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func authContext(accountID, userID string) context.Context { + return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{ + AccountId: accountID, + UserId: userID, + }) +} + +func TestCreateToken_AccountScoped(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + var savedToken *types.ProxyAccessToken + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, token *types.ProxyAccessToken) error { + savedToken = token + return nil + }, + ) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + body := `{"name": "my-token"}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext(accountID, "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + var resp api.ProxyTokenCreated + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + + assert.NotEmpty(t, resp.PlainToken) + assert.Equal(t, "my-token", resp.Name) + assert.False(t, resp.Revoked) + + require.NotNil(t, savedToken) + require.NotNil(t, savedToken.AccountID) + assert.Equal(t, accountID, *savedToken.AccountID) + assert.Equal(t, "user-1", savedToken.CreatedBy) +} + +func TestCreateToken_WithExpiration(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var savedToken *types.ProxyAccessToken + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, token *types.ProxyAccessToken) error { + savedToken = token + return nil + }, + ) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + body := `{"name": "expiring-token", "expires_in": 3600}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext("acc-123", "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + require.NotNil(t, savedToken) + require.NotNil(t, savedToken.ExpiresAt) + assert.True(t, savedToken.ExpiresAt.After(time.Now())) +} + +func TestCreateToken_EmptyName(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil) + + h := &handler{ + permissionsManager: permsMgr, + } + + body := `{"name": ""}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext("acc-123", "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestCreateToken_PermissionDenied(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil) + + h := &handler{ + permissionsManager: permsMgr, + } + + body := `{"name": "test"}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext("acc-123", "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusForbidden, w.Code) +} + +func TestListTokens(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + now := time.Now() + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{ + {ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false}, + {ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true}, + }, nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil) + req = req.WithContext(authContext(accountID, "user-1")) + w := httptest.NewRecorder() + + h.listTokens(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + var resp []api.ProxyToken + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + require.Len(t, resp, 2) + assert.Equal(t, "tok-1", resp[0].Id) + assert.False(t, resp[0].Revoked) + assert.Equal(t, "tok-2", resp[1].Id) + assert.True(t, resp[1].Revoked) +} + +func TestRevokeToken_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{ + ID: "tok-1", + Name: "test-token", + AccountID: &accountID, + }, nil) + mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil) + req = req.WithContext(authContext(accountID, "user-1")) + req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"}) + w := httptest.NewRecorder() + + h.revokeToken(w, req) + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestRevokeToken_WrongAccount(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + otherAccount := "acc-other" + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{ + ID: "tok-1", + AccountID: &otherAccount, + }, nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil) + req = req.WithContext(authContext("acc-123", "user-1")) + req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"}) + w := httptest.NewRecorder() + + h.revokeToken(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) +} + +func TestRevokeToken_ManagementWideToken(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{ + ID: "tok-1", + AccountID: nil, + }, nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil) + req = req.WithContext(authContext("acc-123", "user-1")) + req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"}) + w := httptest.NewRecorder() + + h.revokeToken(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) +} diff --git a/management/internals/modules/reverseproxy/service/interface.go b/management/internals/modules/reverseproxy/service/interface.go index a49cbea3572..6a94aa32bb1 100644 --- a/management/internals/modules/reverseproxy/service/interface.go +++ b/management/internals/modules/reverseproxy/service/interface.go @@ -10,6 +10,7 @@ import ( type Manager interface { GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) + DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) @@ -28,4 +29,5 @@ type Manager interface { RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error StartExposeReaper(ctx context.Context) + GetServiceByDomain(ctx context.Context, domain string) (*Service, error) } diff --git a/management/internals/modules/reverseproxy/service/interface_mock.go b/management/internals/modules/reverseproxy/service/interface_mock.go index cc5ccbb8e81..83b2162ed7b 100644 --- a/management/internals/modules/reverseproxy/service/interface_mock.go +++ b/management/internals/modules/reverseproxy/service/interface_mock.go @@ -79,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID) } +// DeleteAccountCluster mocks base method. +func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAccountCluster indicates an expected call of DeleteAccountCluster. +func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress) +} + // DeleteService mocks base method. func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { m.ctrl.T.Helper() @@ -138,6 +152,21 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID) } +// GetServiceByDomain mocks base method. +func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain) + ret0, _ := ret[0].(*Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServiceByDomain indicates an expected call of GetServiceByDomain. +func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain) +} + // GetGlobalServices mocks base method. func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/service/manager/api.go b/management/internals/modules/reverseproxy/service/manager/api.go index cd81efa88dd..08272077c1a 100644 --- a/management/internals/modules/reverseproxy/service/manager/api.go +++ b/management/internals/modules/reverseproxy/service/manager/api.go @@ -35,6 +35,7 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma accesslogsmanager.RegisterEndpoints(router, accessLogsManager) router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS") + router.HandleFunc("/reverse-proxies/clusters/{clusterAddress}", h.deleteCluster).Methods("DELETE", "OPTIONS") router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS") router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS") @@ -195,10 +196,33 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) { apiClusters := make([]api.ProxyCluster, 0, len(clusters)) for _, c := range clusters { apiClusters = append(apiClusters, api.ProxyCluster{ + Id: c.ID, Address: c.Address, ConnectedProxies: c.ConnectedProxies, + SelfHosted: c.SelfHosted, }) } util.WriteJSONObject(r.Context(), w, apiClusters) } + +func (h *handler) deleteCluster(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + clusterAddress := mux.Vars(r)["clusterAddress"] + if clusterAddress == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "cluster address is required"), w) + return + } + + if err := h.manager.DeleteAccountCluster(r.Context(), userAuth.AccountId, userAuth.UserId, clusterAddress); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index ed9d4201be2..1cec7a8c2b9 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -120,7 +120,21 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin return nil, status.NewPermissionDeniedError() } - return m.store.GetActiveProxyClusters(ctx) + return m.store.GetActiveProxyClusters(ctx, accountID) +} + +// DeleteAccountCluster removes all proxy registrations for the given cluster address +// owned by the account. +func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + return m.store.DeleteAccountCluster(ctx, clusterAddress, accountID) } func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { @@ -984,6 +998,10 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]* return services, nil } +func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) { target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID) if err != nil { diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 54ac8ab182e..be572aa692a 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -433,7 +433,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { t.Helper() tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t)) pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t)) - srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil) return srv } @@ -712,7 +712,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) - proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) require.NoError(t, err) @@ -1135,7 +1135,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) { tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) - proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) require.NoError(t, err) diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 2b40c0aad9c..44b7fdc32a0 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -193,7 +193,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer { - proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager()) + proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store()) s.AfterInit(func(s *BaseServer) { proxyService.SetServiceManager(s.ServiceManager()) proxyService.SetProxyController(s.ServiceProxyController()) diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index a5e352e7508..82e49ee7878 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "errors" "fmt" + "net" "net/http" "net/url" "strings" @@ -47,6 +48,11 @@ type ProxyOIDCConfig struct { KeysLocation string } +// ProxyTokenChecker checks whether a proxy access token is still valid. +type ProxyTokenChecker interface { + IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) +} + // ProxyServiceServer implements the ProxyService gRPC server type ProxyServiceServer struct { proto.UnimplementedProxyServiceServer @@ -75,6 +81,9 @@ type ProxyServiceServer struct { // Store for one-time authentication tokens tokenStore *OneTimeTokenStore + // Checker for proxy access token validity + tokenChecker ProxyTokenChecker + // OIDC configuration for proxy authentication oidcConfig ProxyOIDCConfig @@ -90,6 +99,8 @@ const pkceVerifierTTL = 10 * time.Minute type proxyConnection struct { proxyID string address string + accountID *string + tokenID string capabilities *proto.ProxyCapabilities stream proto.ProxyService_GetMappingUpdateServer sendChan chan *proto.GetMappingUpdateResponse @@ -97,8 +108,19 @@ type proxyConnection struct { cancel context.CancelFunc } +func enforceAccountScope(ctx context.Context, requestAccountID string) error { + token := GetProxyTokenFromContext(ctx) + if token == nil || token.AccountID == nil { + return nil + } + if requestAccountID == "" || *token.AccountID != requestAccountID { + return status.Errorf(codes.PermissionDenied, "account-scoped token cannot access account %s", requestAccountID) + } + return nil +} + // NewProxyServiceServer creates a new proxy service server. -func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { +func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer { ctx, cancel := context.WithCancel(context.Background()) s := &ProxyServiceServer{ accessLogManager: accessLogMgr, @@ -108,6 +130,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT peersManager: peersManager, usersManager: usersManager, proxyManager: proxyMgr, + tokenChecker: tokenChecker, cancel: cancel, } go s.cleanupStaleProxies(ctx) @@ -166,10 +189,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest return status.Errorf(codes.InvalidArgument, "proxy address is invalid") } + var accountID *string + token := GetProxyTokenFromContext(ctx) + if token != nil && token.AccountID != nil { + accountID = token.AccountID + + available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID) + if err != nil { + return status.Errorf(codes.Internal, "check cluster address: %v", err) + } + if !available { + return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress) + } + } + + var tokenID string + if token != nil { + tokenID = token.ID + } + connCtx, cancel := context.WithCancel(ctx) conn := &proxyConnection{ proxyID: proxyID, address: proxyAddress, + accountID: accountID, + tokenID: tokenID, capabilities: req.GetCapabilities(), stream: stream, sendChan: make(chan *proto.GetMappingUpdateResponse, 100), @@ -177,12 +221,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest cancel: cancel, } - s.connectedProxies.Store(proxyID, conn) - if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { - log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) - } - - // Register proxy in database with capabilities var caps *proxy.Capabilities if c := req.GetCapabilities(); c != nil { caps = &proxy.Capabilities{ @@ -191,19 +229,25 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest SupportsCrowdsec: c.SupportsCrowdsec, } } - if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil { - log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) - s.connectedProxies.Delete(proxyID) - if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil { - log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr) + if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, accountID, caps); err != nil { + if accountID != nil { + cancel() + return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err) } - return status.Errorf(codes.Internal, "register proxy in database: %v", err) + log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err) + } + + s.connectedProxies.Store(proxyID, conn) + if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { + log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) } + log.WithFields(log.Fields{ "proxy_id": proxyID, "address": proxyAddress, "cluster_addr": proxyAddress, + "account_id": accountID, "total_proxies": len(s.GetConnectedProxies()), }).Info("Proxy registered in cluster") defer func() { @@ -228,7 +272,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest go s.sender(conn, errChan) // Start heartbeat goroutine - go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo) + go s.heartbeat(connCtx, conn, peerInfo) select { case err := <-errChan: @@ -238,16 +282,28 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest } } -// heartbeat updates the proxy's last_seen timestamp every minute -func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) { +func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, ipAddress string) { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: - if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { - log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err) + if err := s.proxyManager.Heartbeat(ctx, conn.proxyID, conn.address, ipAddress); err != nil { + log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", conn.proxyID, err) + } + + if conn.tokenID != "" && s.tokenChecker != nil { + valid, err := s.tokenChecker.IsProxyAccessTokenValid(ctx, conn.tokenID) + if err != nil { + log.WithContext(ctx).Warnf("failed to check token validity for proxy %s: %v", conn.proxyID, err) + continue + } + if !valid { + log.WithContext(ctx).Warnf("proxy %s token revoked or expired, disconnecting", conn.proxyID) + conn.cancel() + return + } } case <-ctx.Done(): return @@ -255,8 +311,6 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddr } } -// sendSnapshot sends the initial snapshot of services to the connecting proxy. -// Only entries matching the proxy's cluster address are sent. func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { if !isProxyAddressValid(conn.address) { return fmt.Errorf("proxy address is invalid") @@ -289,7 +343,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec } func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) { - services, err := s.serviceManager.GetGlobalServices(ctx) + var services []*rpservice.Service + var err error + if conn.accountID != nil { + services, err = s.serviceManager.GetAccountServices(ctx, *conn.accountID) + } else { + services, err = s.serviceManager.GetGlobalServices(ctx) + } if err != nil { return nil, fmt.Errorf("get services from store: %w", err) } @@ -318,8 +378,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn * return mappings, nil } -// isProxyAddressValid validates a proxy address +// isProxyAddressValid validates a proxy address (domain name or IP address) func isProxyAddressValid(addr string) bool { + if addr == "" { + return false + } + if net.ParseIP(addr) != nil { + return true + } _, err := domain.ValidateDomains([]string{addr}) return err == nil } @@ -343,6 +409,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) { accessLog := req.GetLog() + if err := enforceAccountScope(ctx, accessLog.GetAccountId()); err != nil { + return nil, err + } + fields := log.Fields{ "service_id": accessLog.GetServiceId(), "account_id": accessLog.GetAccountId(), @@ -380,11 +450,32 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA // Management should call this when services are created/updated/removed. // For create/update operations a unique one-time auth token is generated per // proxy so that every replica can independently authenticate with management. +// BYOP proxies only receive updates for their own account's services. func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) { log.Debugf("Broadcasting service update to all connected proxy servers") + updateAccountIDs := make(map[string]struct{}) + for _, m := range update.Mapping { + if m.AccountId != "" { + updateAccountIDs[m.AccountId] = struct{}{} + } + } s.connectedProxies.Range(func(key, value interface{}) bool { conn := value.(*proxyConnection) - resp := s.perProxyMessage(update, conn.proxyID) + connUpdate := update + if conn.accountID != nil && len(updateAccountIDs) > 0 { + if _, ok := updateAccountIDs[*conn.accountID]; !ok { + return true + } + filtered := filterMappingsForAccount(update.Mapping, *conn.accountID) + if len(filtered) == 0 { + return true + } + connUpdate = &proto.GetMappingUpdateResponse{ + Mapping: filtered, + InitialSyncComplete: update.InitialSyncComplete, + } + } + resp := s.perProxyMessage(connUpdate, conn.proxyID) if resp == nil { return true } @@ -398,6 +489,26 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes }) } +// ForceDisconnect cancels the gRPC stream for a connected proxy, causing it to disconnect. +func (s *ProxyServiceServer) ForceDisconnect(proxyID string) { + if connVal, ok := s.connectedProxies.Load(proxyID); ok { + conn := connVal.(*proxyConnection) + conn.cancel() + s.connectedProxies.Delete(proxyID) + log.WithFields(log.Fields{"proxyID": proxyID}).Info("force disconnected proxy") + } +} + +func filterMappingsForAccount(mappings []*proto.ProxyMapping, accountID string) []*proto.ProxyMapping { + var filtered []*proto.ProxyMapping + for _, m := range mappings { + if m.AccountId == accountID { + filtered = append(filtered, m) + } + } + return filtered +} + // GetConnectedProxies returns a list of connected proxy IDs func (s *ProxyServiceServer) GetConnectedProxies() []string { var proxies []string @@ -466,6 +577,9 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd continue } conn := connVal.(*proxyConnection) + if conn.accountID != nil && update.AccountId != "" && *conn.accountID != update.AccountId { + continue + } if !proxyAcceptsMapping(conn, update) { log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id) continue @@ -549,6 +663,10 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { } func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) if err != nil { log.WithContext(ctx).Debugf("failed to get service from store: %v", err) @@ -668,6 +786,10 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic // SendStatusUpdate handles status updates from proxy clients. func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + accountID := req.GetAccountId() serviceID := req.GetServiceId() protoStatus := req.GetStatus() @@ -738,6 +860,10 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status { // CreateProxyPeer handles proxy peer creation with one-time token authentication func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + serviceID := req.GetServiceId() accountID := req.GetAccountId() token := req.GetToken() @@ -792,6 +918,10 @@ func strPtr(s string) *string { } func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + redirectURL, err := url.Parse(req.GetRedirectUrl()) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err) @@ -920,21 +1050,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL // GenerateSessionToken creates a signed session JWT for the given domain and user. func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) { - // Find the service by domain to get its signing key - services, err := s.serviceManager.GetGlobalServices(ctx) + service, err := s.getServiceByDomain(ctx, domain) if err != nil { - return "", fmt.Errorf("get services: %w", err) - } - - var service *rpservice.Service - for _, svc := range services { - if svc.Domain == domain { - service = svc - break - } - } - if service == nil { - return "", fmt.Errorf("service not found for domain: %s", domain) + return "", fmt.Errorf("service not found for domain %s: %w", domain, err) } if service.SessionPrivateKey == "" { @@ -1032,6 +1150,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val }, nil } + if err := enforceAccountScope(ctx, service.AccountID); err != nil { + return nil, err + } + pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey) if err != nil { log.WithFields(log.Fields{ @@ -1115,18 +1237,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val } func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) { - services, err := s.serviceManager.GetGlobalServices(ctx) - if err != nil { - return nil, fmt.Errorf("get services: %w", err) - } - - for _, service := range services { - if service.Domain == domain { - return service, nil - } - } - - return nil, fmt.Errorf("service not found for domain: %s", domain) + return s.serviceManager.GetServiceByDomain(ctx, domain) } func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error { diff --git a/management/internals/shared/grpc/proxy_address_test.go b/management/internals/shared/grpc/proxy_address_test.go new file mode 100644 index 00000000000..824a5722602 --- /dev/null +++ b/management/internals/shared/grpc/proxy_address_test.go @@ -0,0 +1,29 @@ +package grpc + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsProxyAddressValid(t *testing.T) { + tests := []struct { + name string + addr string + valid bool + }{ + {name: "valid domain", addr: "eu.proxy.netbird.io", valid: true}, + {name: "valid subdomain", addr: "byop.proxy.example.com", valid: true}, + {name: "valid IPv4", addr: "10.0.0.1", valid: true}, + {name: "valid IPv4 public", addr: "203.0.113.10", valid: true}, + {name: "valid IPv6", addr: "::1", valid: true}, + {name: "valid IPv6 full", addr: "2001:db8::1", valid: true}, + {name: "empty string", addr: "", valid: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.valid, isProxyAddressValid(tt.addr)) + }) + } +} diff --git a/management/internals/shared/grpc/proxy_auth.go b/management/internals/shared/grpc/proxy_auth.go index dd593dfa079..9888e8eee18 100644 --- a/management/internals/shared/grpc/proxy_auth.go +++ b/management/internals/shared/grpc/proxy_auth.go @@ -153,9 +153,6 @@ func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types return nil, status.Errorf(codes.Unauthenticated, "invalid token") } - // TODO: Enforce AccountID scope for "bring your own proxy" feature. - // Currently tokens are management-wide; AccountID field is reserved for future use. - if !token.IsValid() { return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked") } diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index 0fa9a0dc1d4..46dad5b5608 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -53,6 +53,10 @@ func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, return nil } +func (m *mockReverseProxyManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error { + return nil +} + func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error { return nil } @@ -91,6 +95,20 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {} +func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) { + if m.err != nil { + return nil, m.err + } + for _, services := range m.proxiesByAccount { + for _, svc := range services { + if svc.Domain == domain { + return svc, nil + } + } + } + return nil, errors.New("service not found for domain: " + domain) +} + func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { return nil, nil } diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index de4e96d9375..d4755f7d5ac 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -12,9 +12,12 @@ import ( cachestore "github.com/eko/gocache/lib/v4/store" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" nbcache "github.com/netbirdio/netbird/management/server/cache" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -313,6 +316,58 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { assert.Contains(t, err.Error(), "invalid state format") } +func scopedCtx(accountID string) context.Context { + token := &types.ProxyAccessToken{ + ID: "token-1", + AccountID: &accountID, + } + return context.WithValue(context.Background(), ProxyTokenContextKey, token) +} + +func globalCtx() context.Context { + token := &types.ProxyAccessToken{ + ID: "token-global", + } + return context.WithValue(context.Background(), ProxyTokenContextKey, token) +} + +func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) { + err := enforceAccountScope(scopedCtx("acc-1"), "acc-1") + assert.NoError(t, err) +} + +func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) { + err := enforceAccountScope(scopedCtx("acc-1"), "acc-2") + require.Error(t, err) + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.PermissionDenied, st.Code()) +} + +func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) { + err := enforceAccountScope(scopedCtx("acc-1"), "") + require.Error(t, err) + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.PermissionDenied, st.Code()) +} + +func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) { + err := enforceAccountScope(globalCtx(), "acc-1") + assert.NoError(t, err) + + err = enforceAccountScope(globalCtx(), "acc-2") + assert.NoError(t, err) + + err = enforceAccountScope(globalCtx(), "") + assert.NoError(t, err) +} + +func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) { + err := enforceAccountScope(context.Background(), "acc-1") + assert.NoError(t, err) +} + func TestValidateState_RejectsInvalidHMAC(t *testing.T) { ctx := context.Background() pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index d1d7fc8b7fd..6cd95f988e7 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup { tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) - proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) + proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil) proxyService.SetServiceManager(serviceManager) createTestProxies(t, ctx, testStore) @@ -318,13 +318,17 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {} +func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { return nil, nil } type testValidateSessionProxyManager struct{} -func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error { +func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error { return nil } @@ -340,6 +344,10 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co return nil, nil } +func (m *testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) { + return nil, nil +} + func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) { return nil, nil } @@ -348,6 +356,22 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time return nil } +func (m *testValidateSessionProxyManager) GetAccountProxy(_ context.Context, _ string) (*proxy.Proxy, error) { + return nil, nil +} + +func (m *testValidateSessionProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) { + return 0, nil +} + +func (m *testValidateSessionProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) { + return true, nil +} + +func (m *testValidateSessionProxyManager) DeleteProxy(_ context.Context, _ string) error { + return nil +} + func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { return nil } diff --git a/management/server/account_test.go b/management/server/account_test.go index 756c4242168..435a1b949fe 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3074,7 +3074,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU return nil, nil, err } - proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager) + proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil) proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{}) if err != nil { return nil, nil, err diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 56b2d820354..b7a9db1d727 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken" reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -146,6 +147,9 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks if serviceManager != nil && reverseProxyDomainManager != nil { reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router) } + + proxytoken.RegisterEndpoints(accountManager.GetStore(), permissionsManager, router) + // Register OAuth callback handler for proxy authentication if proxyGRPCServer != nil { oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies) diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index c99acab63a8..30d8aa0e794 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -216,6 +216,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { nil, usersManager, nil, + nil, ) proxyService.SetServiceManager(&testServiceManager{store: testStore}) @@ -389,6 +390,10 @@ func (m *testServiceManager) DeleteService(_ context.Context, _, _, _ string) er return nil } +func (m *testServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error { + return nil +} + func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error { return nil } @@ -435,6 +440,10 @@ func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ stri func (m *testServiceManager) StartExposeReaper(_ context.Context) {} +func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { return nil, nil } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 1a8b83c7eed..3c4ea98d02a 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -109,7 +109,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee if err != nil { t.Fatalf("Failed to create proxy manager: %v", err) } - proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) + proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil) domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) if err != nil { @@ -238,7 +238,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin if err != nil { t.Fatalf("Failed to create proxy manager: %v", err) } - proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) + proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil) domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) if err != nil { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 0a716d08d07..64916455489 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -4495,6 +4495,47 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e return nil } +func (s *SqlStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var tokens []*types.ProxyAccessToken + result := tx.Where("account_id = ?", accountID).Find(&tokens) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "get proxy access tokens by account: %v", result.Error) + } + + return tokens, nil +} + +func (s *SqlStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) { + token, err := s.GetProxyAccessTokenByID(ctx, LockingStrengthNone, tokenID) + if err != nil { + return false, err + } + return token.IsValid(), nil +} + +func (s *SqlStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var token types.ProxyAccessToken + result := tx.Take(&token, idQueryCondition, tokenID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "proxy access token not found") + } + return nil, status.Errorf(status.Internal, "get proxy access token by ID: %v", result.Error) + } + + return &token, nil +} + // MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token. func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error { result := s.db.Model(&types.ProxyAccessToken{}). @@ -5437,13 +5478,29 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { return nil } -// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist +func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID string) error { + now := time.Now() + result := s.db. + Model(&proxy.Proxy{}). + Where("id = ?", proxyID). + Updates(map[string]interface{}{ + "status": proxy.StatusDisconnected, + "disconnected_at": now, + "last_seen": now, + }) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to disconnect proxy: %v", result.Error) + return status.Errorf(status.Internal, "failed to disconnect proxy") + } + return nil +} + func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { now := time.Now() result := s.db. Model(&proxy.Proxy{}). - Where("id = ? AND status = ?", proxyID, "connected"). + Where("id = ? AND status = ?", proxyID, proxy.StatusConnected). Update("last_seen", now) if result.Error != nil { @@ -5469,13 +5526,15 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd return nil } -// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies +// GetActiveProxyClusterAddresses returns the unique cluster addresses of active +// shared proxies (those without an account scope). BYOP cluster addresses are +// excluded; use GetActiveProxyClusterAddressesForAccount to retrieve them. func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) { var addresses []string result := s.db. Model(&proxy.Proxy{}). - Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). + Where("account_id IS NULL AND status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)). Distinct("cluster_address"). Pluck("cluster_address", &addresses) @@ -5487,13 +5546,75 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string return addresses, nil } -// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count. -func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { +func (s *SqlStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + var addresses []string + + result := s.db. + Model(&proxy.Proxy{}). + Where("account_id = ? AND status = ? AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)). + Distinct("cluster_address"). + Pluck("cluster_address", &addresses) + + if result.Error != nil { + return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses for account") + } + + return addresses, nil +} + +func (s *SqlStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) { + var p proxy.Proxy + result := s.db.Where("account_id = ?", accountID).Take(&p) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "proxy not found for account") + } + return nil, status.Errorf(status.Internal, "get proxy by account ID: %v", result.Error) + } + return &p, nil +} + +func (s *SqlStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) { + var count int64 + result := s.db.Model(&proxy.Proxy{}).Where("account_id = ?", accountID).Count(&count) + if result.Error != nil { + return 0, status.Errorf(status.Internal, "count proxies by account ID: %v", result.Error) + } + return count, nil +} + +func (s *SqlStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) { + var count int64 + result := s.db. + Model(&proxy.Proxy{}). + Where("cluster_address = ? AND (account_id IS NULL OR account_id != ?)", clusterAddress, accountID). + Count(&count) + if result.Error != nil { + return false, status.Errorf(status.Internal, "check cluster address conflict: %v", result.Error) + } + return count > 0, nil +} + +func (s *SqlStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + result := s.db. + Where("cluster_address = ? AND account_id = ?", clusterAddress, accountID). + Delete(&proxy.Proxy{}) + if result.Error != nil { + return status.Errorf(status.Internal, "delete account cluster: %v", result.Error) + } + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "cluster not found") + } + return nil +} + +func (s *SqlStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) { var clusters []proxy.Cluster result := s.db.Model(&proxy.Proxy{}). - Select("cluster_address as address, COUNT(*) as connected_proxies"). - Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). + Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies, COUNT(account_id) > 0 as self_hosted"). + Where("status = ? AND last_seen > ? AND (account_id IS NULL OR account_id = ?)", + proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold), accountID). Group("cluster_address"). Scan(&clusters) diff --git a/management/server/store/store.go b/management/server/store/store.go index 0d8b0678a99..8bb84c2bbf7 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -114,6 +114,9 @@ type Store interface { GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) + GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) + GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) + IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error RevokeProxyAccessToken(ctx context.Context, tokenID string) error MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error @@ -284,13 +287,19 @@ type Store interface { DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error SaveProxy(ctx context.Context, proxy *proxy.Proxy) error + DisconnectProxy(ctx context.Context, proxyID string) error UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) - GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) + GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error + GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) + CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) + IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) + DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) @@ -494,6 +503,9 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key") }, + func(db *gorm.DB) error { + return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique") + }, } } diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index beee13d9631..d199a1210ca 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -165,19 +165,6 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration) } -// GetClusterSupportsCrowdSec mocks base method. -func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr) - ret0, _ := ret[0].(*bool) - return ret0 -} - -// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec. -func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) -} // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() @@ -236,6 +223,21 @@ func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID) } +// CountProxiesByAccountID mocks base method. +func (m *MockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountProxiesByAccountID", ctx, accountID) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountProxiesByAccountID indicates an expected call of CountProxiesByAccountID. +func (mr *MockStoreMockRecorder) CountProxiesByAccountID(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountProxiesByAccountID", reflect.TypeOf((*MockStore)(nil).CountProxiesByAccountID), ctx, accountID) +} + // CreateAccessLog mocks base method. func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error { m.ctrl.T.Helper() @@ -574,6 +576,20 @@ func (mr *MockStoreMockRecorder) DeletePostureChecks(ctx, accountID, postureChec return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockStore)(nil).DeletePostureChecks), ctx, accountID, postureChecksID) } +// DeleteAccountCluster mocks base method. +func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAccountCluster indicates an expected call of DeleteAccountCluster. +func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID) +} + // DeleteRoute mocks base method. func (m *MockStore) DeleteRoute(ctx context.Context, accountID, routeID string) error { m.ctrl.T.Helper() @@ -714,6 +730,20 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID) } +// DisconnectProxy mocks base method. +func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DisconnectProxy indicates an expected call of DisconnectProxy. +func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID) +} + // EphemeralServiceExists mocks base method. func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) { m.ctrl.T.Helper() @@ -1300,19 +1330,34 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx) } +// GetActiveProxyClusterAddressesForAccount mocks base method. +func (m *MockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveProxyClusterAddressesForAccount", ctx, accountID) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveProxyClusterAddressesForAccount indicates an expected call of GetActiveProxyClusterAddressesForAccount. +func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID) +} + // GetActiveProxyClusters mocks base method. -func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { +func (m *MockStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx) + ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx, accountID) ret0, _ := ret[0].([]proxy.Cluster) ret1, _ := ret[1].(error) return ret0, ret1 } // GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters. -func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx, accountID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx, accountID) } // GetAllAccounts mocks base method. @@ -1388,6 +1433,20 @@ func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr) } +// GetClusterSupportsCrowdSec mocks base method. +func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec. +func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) +} + // GetClusterSupportsCustomPorts mocks base method. func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { m.ctrl.T.Helper() @@ -1957,6 +2016,51 @@ func (mr *MockStoreMockRecorder) GetProxyAccessTokenByHashedToken(ctx, lockStren return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByHashedToken", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByHashedToken), ctx, lockStrength, hashedToken) } +// GetProxyAccessTokenByID mocks base method. +func (m *MockStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types2.ProxyAccessToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyAccessTokenByID", ctx, lockStrength, tokenID) + ret0, _ := ret[0].(*types2.ProxyAccessToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyAccessTokenByID indicates an expected call of GetProxyAccessTokenByID. +func (mr *MockStoreMockRecorder) GetProxyAccessTokenByID(ctx, lockStrength, tokenID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByID), ctx, lockStrength, tokenID) +} + +// GetProxyAccessTokensByAccountID mocks base method. +func (m *MockStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types2.ProxyAccessToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyAccessTokensByAccountID", ctx, lockStrength, accountID) + ret0, _ := ret[0].([]*types2.ProxyAccessToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyAccessTokensByAccountID indicates an expected call of GetProxyAccessTokensByAccountID. +func (mr *MockStoreMockRecorder) GetProxyAccessTokensByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokensByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokensByAccountID), ctx, lockStrength, accountID) +} + +// GetProxyByAccountID mocks base method. +func (m *MockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyByAccountID", ctx, accountID) + ret0, _ := ret[0].(*proxy.Proxy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyByAccountID indicates an expected call of GetProxyByAccountID. +func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID) +} + // GetResourceGroups mocks base method. func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) { m.ctrl.T.Helper() @@ -2389,6 +2493,21 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID) } +// IsClusterAddressConflicting mocks base method. +func (m *MockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsClusterAddressConflicting", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsClusterAddressConflicting indicates an expected call of IsClusterAddressConflicting. +func (mr *MockStoreMockRecorder) IsClusterAddressConflicting(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressConflicting", reflect.TypeOf((*MockStore)(nil).IsClusterAddressConflicting), ctx, clusterAddress, accountID) +} + // IsPrimaryAccount mocks base method. func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) { m.ctrl.T.Helper() @@ -2405,6 +2524,21 @@ func (mr *MockStoreMockRecorder) IsPrimaryAccount(ctx, accountID interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPrimaryAccount", reflect.TypeOf((*MockStore)(nil).IsPrimaryAccount), ctx, accountID) } +// IsProxyAccessTokenValid mocks base method. +func (m *MockStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsProxyAccessTokenValid", ctx, tokenID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsProxyAccessTokenValid indicates an expected call of IsProxyAccessTokenValid. +func (mr *MockStoreMockRecorder) IsProxyAccessTokenValid(ctx, tokenID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsProxyAccessTokenValid", reflect.TypeOf((*MockStore)(nil).IsProxyAccessTokenValid), ctx, tokenID) +} + // ListCustomDomains mocks base method. func (m *MockStore) ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) { m.ctrl.T.Helper() diff --git a/proxy/management_byop_integration_test.go b/proxy/management_byop_integration_test.go new file mode 100644 index 00000000000..c0fbe682ab2 --- /dev/null +++ b/proxy/management_byop_integration_test.go @@ -0,0 +1,409 @@ +package proxy + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + grpcstatus "google.golang.org/grpc/status" + + proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/management/proto" +) + +type byopTestSetup struct { + store store.Store + proxyService *nbgrpc.ProxyServiceServer + grpcServer *grpc.Server + grpcAddr string + cleanup func() + + accountA string + accountB string + accountAToken types.PlainProxyToken + accountBToken types.PlainProxyToken + accountACluster string + accountBCluster string +} + +func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup { + t.Helper() + ctx := context.Background() + + testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + + accountAID := "byop-account-a" + accountBID := "byop-account-b" + + for _, acc := range []*types.Account{ + {Id: accountAID, Domain: "a.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()}, + {Id: accountBID, Domain: "b.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()}, + } { + require.NoError(t, testStore.SaveAccount(ctx, acc)) + } + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + pubKey := base64.StdEncoding.EncodeToString(pub) + privKey := base64.StdEncoding.EncodeToString(priv) + + clusterA := "byop-a.proxy.test" + clusterB := "byop-b.proxy.test" + + services := []*service.Service{ + { + ID: "svc-a1", AccountID: accountAID, Name: "App A1", + Domain: "app1." + clusterA, ProxyCluster: clusterA, Enabled: true, + SessionPrivateKey: privKey, SessionPublicKey: pubKey, + Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.1", Port: 8080, Protocol: "http", TargetId: "peer-a1", TargetType: "peer", Enabled: true}}, + }, + { + ID: "svc-a2", AccountID: accountAID, Name: "App A2", + Domain: "app2." + clusterA, ProxyCluster: clusterA, Enabled: true, + SessionPrivateKey: privKey, SessionPublicKey: pubKey, + Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.2", Port: 8080, Protocol: "http", TargetId: "peer-a2", TargetType: "peer", Enabled: true}}, + }, + { + ID: "svc-b1", AccountID: accountBID, Name: "App B1", + Domain: "app1." + clusterB, ProxyCluster: clusterB, Enabled: true, + SessionPrivateKey: privKey, SessionPublicKey: pubKey, + Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.3", Port: 8080, Protocol: "http", TargetId: "peer-b1", TargetType: "peer", Enabled: true}}, + }, + } + for _, svc := range services { + require.NoError(t, testStore.CreateService(ctx, svc)) + } + + tokenA, err := types.CreateNewProxyAccessToken("byop-token-a", 0, &accountAID, "admin-a") + require.NoError(t, err) + require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenA.ProxyAccessToken)) + + tokenB, err := types.CreateNewProxyAccessToken("byop-token-b", 0, &accountBID, "admin-b") + require.NoError(t, err) + require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenB.ProxyAccessToken)) + + cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) + + meter := noop.NewMeterProvider().Meter("test") + realProxyManager, err := proxymanager.NewManager(testStore, meter) + require.NoError(t, err) + + oidcConfig := nbgrpc.ProxyOIDCConfig{ + Issuer: "https://fake-issuer.example.com", + ClientID: "test-client", + HMACKey: []byte("test-hmac-key"), + } + + usersManager := users.NewManager(testStore) + + proxyService := nbgrpc.NewProxyServiceServer( + &testAccessLogManager{}, + tokenStore, + pkceStore, + oidcConfig, + nil, + usersManager, + realProxyManager, + nil, + ) + + svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore} + proxyService.SetServiceManager(svcMgr) + + proxyController := &testProxyController{} + proxyService.SetProxyController(proxyController) + + _, streamInterceptor, authClose := nbgrpc.NewProxyAuthInterceptors(testStore) + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + grpcServer := grpc.NewServer(grpc.StreamInterceptor(streamInterceptor)) + proto.RegisterProxyServiceServer(grpcServer, proxyService) + + go func() { + if err := grpcServer.Serve(lis); err != nil { + t.Logf("gRPC server error: %v", err) + } + }() + + return &byopTestSetup{ + store: testStore, + proxyService: proxyService, + grpcServer: grpcServer, + grpcAddr: lis.Addr().String(), + cleanup: func() { + grpcServer.GracefulStop() + authClose() + storeCleanup() + }, + accountA: accountAID, + accountB: accountBID, + accountAToken: tokenA.PlainToken, + accountBToken: tokenB.PlainToken, + accountACluster: clusterA, + accountBCluster: clusterB, + } +} + +func byopContext(ctx context.Context, token types.PlainProxyToken) context.Context { + md := metadata.Pairs("authorization", "Bearer "+string(token)) + return metadata.NewOutgoingContext(ctx, md) +} + +func receiveBYOPMappings(t *testing.T, stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping { + t.Helper() + var mappings []*proto.ProxyMapping + for { + msg, err := stream.Recv() + require.NoError(t, err) + mappings = append(mappings, msg.GetMapping()...) + if msg.GetInitialSyncComplete() { + break + } + } + return mappings +} + +func TestIntegration_BYOPProxy_ReceivesOnlyAccountServices(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel() + + stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-a", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + mappings := receiveBYOPMappings(t, stream) + + assert.Len(t, mappings, 2, "BYOP proxy should receive only account A's 2 services") + for _, m := range mappings { + assert.Equal(t, setup.accountA, m.GetAccountId(), "all mappings should belong to account A") + t.Logf("received mapping: id=%s domain=%s account=%s", m.GetId(), m.GetDomain(), m.GetAccountId()) + } + + ids := map[string]bool{} + for _, m := range mappings { + ids[m.GetId()] = true + } + assert.True(t, ids["svc-a1"], "should contain svc-a1") + assert.True(t, ids["svc-a2"], "should contain svc-a2") + assert.False(t, ids["svc-b1"], "should NOT contain account B's svc-b1") +} + +func TestIntegration_BYOPProxy_AccountBReceivesOnlyItsServices(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second) + defer cancel() + + stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-b", + Version: "test-v1", + Address: setup.accountBCluster, + }) + require.NoError(t, err) + + mappings := receiveBYOPMappings(t, stream) + + assert.Len(t, mappings, 1, "BYOP proxy B should receive only 1 service") + assert.Equal(t, "svc-b1", mappings[0].GetId()) + assert.Equal(t, setup.accountB, mappings[0].GetAccountId()) +} + +func TestIntegration_BYOPProxy_MultiplePerAccount(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel1() + + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-a-first", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + mappings1 := receiveBYOPMappings(t, stream1) + assert.Len(t, mappings1, 2, "first BYOP proxy should receive account A's 2 services") + + ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-a-second", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + mappings2 := receiveBYOPMappings(t, stream2) + assert.Len(t, mappings2, 2, "second BYOP proxy from same account should also receive the 2 services") + for _, m := range mappings2 { + assert.Equal(t, setup.accountA, m.GetAccountId()) + } +} + +func TestIntegration_BYOPProxy_ClusterAddressConflict(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel1() + + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-a-cluster", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + _ = receiveBYOPMappings(t, stream1) + + ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-b-conflict", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + _, err = stream2.Recv() + require.Error(t, err) + + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.AlreadyExists, st.Code(), "cluster address conflict should return AlreadyExists") + t.Logf("expected rejection: %s", st.Message()) +} + +func TestIntegration_BYOPProxy_SameProxyReconnects(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + proxyID := "byop-proxy-reconnect" + + ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: proxyID, + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + firstMappings := receiveBYOPMappings(t, stream1) + cancel1() + + time.Sleep(200 * time.Millisecond) + + ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{ + ProxyId: proxyID, + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + secondMappings := receiveBYOPMappings(t, stream2) + + assert.Equal(t, len(firstMappings), len(secondMappings), "reconnect should receive same mappings") + + firstIDs := map[string]bool{} + for _, m := range firstMappings { + firstIDs[m.GetId()] = true + } + for _, m := range secondMappings { + assert.True(t, firstIDs[m.GetId()], "mapping %s should be present on reconnect", m.GetId()) + } +} + +func TestIntegration_BYOPProxy_UnauthenticatedRejected(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "no-auth-proxy", + Version: "test-v1", + Address: "some.cluster.io", + }) + require.NoError(t, err) + + _, err = stream.Recv() + require.Error(t, err) + + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Unauthenticated, st.Code()) +} diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 4b1ecf922f0..9374a17c74e 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "encoding/base64" "errors" + "fmt" "net" "sync" "sync/atomic" @@ -140,6 +141,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { nil, usersManager, proxyManager, + nil, ) // Use store-backed service manager @@ -201,7 +203,7 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, // testProxyManager is a mock implementation of proxy.Manager for testing. type testProxyManager struct{} -func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error { +func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *nbproxy.Capabilities) error { return nil } @@ -217,6 +219,10 @@ func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]strin return nil, nil } +func (m *testProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) { + return nil, nil +} + func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) { return nil, nil } @@ -237,6 +243,22 @@ func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) erro return nil } +func (m *testProxyManager) GetAccountProxy(_ context.Context, accountID string) (*nbproxy.Proxy, error) { + return nil, fmt.Errorf("proxy not found for account %s", accountID) +} + +func (m *testProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) { + return 0, nil +} + +func (m *testProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) { + return true, nil +} + +func (m *testProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error { + return nil +} + // testProxyController is a mock implementation of rpservice.ProxyController for testing. type testProxyController struct{} @@ -290,6 +312,10 @@ func (m *storeBackedServiceManager) DeleteService(ctx context.Context, accountID return nil } +func (m *storeBackedServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error { + return nil +} + func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error { return nil } @@ -336,6 +362,10 @@ func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _, func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {} +func (m *storeBackedServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { return nil, nil } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 0b855db676b..63e149e4924 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -3323,10 +3323,64 @@ components: example: false required: - enabled + ProxyTokenRequest: + type: object + properties: + name: + type: string + description: Human-readable token name + example: "my-proxy-token" + expires_in: + type: integer + minimum: 0 + description: Token expiration in seconds (0 = never expires) + example: 0 + required: + - name + ProxyToken: + type: object + properties: + id: + type: string + name: + type: string + expires_at: + type: string + format: date-time + created_at: + type: string + format: date-time + last_used: + type: string + format: date-time + revoked: + type: boolean + required: + - id + - name + - created_at + - revoked + ProxyTokenCreated: + type: object + description: Returned on creation — plain_token is shown only once + allOf: + - $ref: '#/components/schemas/ProxyToken' + - type: object + properties: + plain_token: + type: string + description: The plain text token (shown only once) + example: "nbx_abc123..." + required: + - plain_token ProxyCluster: type: object description: A proxy cluster represents a group of proxy nodes serving the same address properties: + id: + type: string + description: Unique identifier of a proxy in this cluster + example: "chlfq4q5r8kc73b0qjpg" address: type: string description: Cluster address used for CNAME targets @@ -3335,9 +3389,15 @@ components: type: integer description: Number of proxy nodes connected in this cluster example: 3 + self_hosted: + type: boolean + description: Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner + example: false required: + - id - address - connected_proxies + - self_hosted ReverseProxyDomainType: type: string description: Type of Reverse Proxy Domain @@ -11317,6 +11377,111 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/reverse-proxies/clusters/{clusterAddress}: + delete: + summary: Delete a self-hosted proxy cluster + description: Removes all self-hosted (BYOP) proxy registrations for the given cluster address owned by the account. + tags: [ Services ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: clusterAddress + required: true + schema: + type: string + description: The address of the proxy cluster + responses: + '200': + description: Proxy cluster deleted successfully + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + /api/reverse-proxies/proxy-tokens: + get: + summary: List Proxy Tokens + description: Returns all proxy access tokens for the account + tags: [ Self-Hosted Proxies ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of proxy tokens + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/ProxyToken' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Proxy Token + description: Generate an account-scoped proxy access token for self-hosted proxy registration + tags: [ Self-Hosted Proxies ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ProxyTokenRequest' + responses: + '200': + description: Proxy token created (plain token shown once) + content: + application/json: + schema: + $ref: '#/components/schemas/ProxyTokenCreated' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/reverse-proxies/proxy-tokens/{tokenId}: + delete: + summary: Revoke a Proxy Token + description: Revoke an account-scoped proxy access token + tags: [ Self-Hosted Proxies ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: tokenId + required: true + schema: + type: string + description: The unique identifier of the proxy token + responses: + '200': + description: Token revoked + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" /api/reverse-proxies/services: get: summary: List all Services diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 0317b8183cb..d9e136e0836 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -3761,11 +3761,49 @@ type ProxyAccessLogsResponse struct { // ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address type ProxyCluster struct { + // Id Unique identifier of a proxy in this cluster + Id string `json:"id"` + // Address Cluster address used for CNAME targets Address string `json:"address"` // ConnectedProxies Number of proxy nodes connected in this cluster ConnectedProxies int `json:"connected_proxies"` + + // SelfHosted Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner + SelfHosted bool `json:"self_hosted"` +} + +// ProxyToken defines model for ProxyToken. +type ProxyToken struct { + CreatedAt time.Time `json:"created_at"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Id string `json:"id"` + LastUsed *time.Time `json:"last_used,omitempty"` + Name string `json:"name"` + Revoked bool `json:"revoked"` +} + +// ProxyTokenCreated defines model for ProxyTokenCreated. +type ProxyTokenCreated struct { + CreatedAt time.Time `json:"created_at"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Id string `json:"id"` + LastUsed *time.Time `json:"last_used,omitempty"` + Name string `json:"name"` + + // PlainToken The plain text token (shown only once) + PlainToken string `json:"plain_token"` + Revoked bool `json:"revoked"` +} + +// ProxyTokenRequest defines model for ProxyTokenRequest. +type ProxyTokenRequest struct { + // ExpiresIn Token expiration in seconds (0 = never expires) + ExpiresIn *int `json:"expires_in,omitempty"` + + // Name Human-readable token name + Name string `json:"name"` } // Resource defines model for Resource. @@ -5127,6 +5165,9 @@ type PutApiPostureChecksPostureCheckIdJSONRequestBody = PostureCheckUpdate // PostApiReverseProxiesDomainsJSONRequestBody defines body for PostApiReverseProxiesDomains for application/json ContentType. type PostApiReverseProxiesDomainsJSONRequestBody = ReverseProxyDomainRequest +// PostApiReverseProxiesProxyTokensJSONRequestBody defines body for PostApiReverseProxiesProxyTokens for application/json ContentType. +type PostApiReverseProxiesProxyTokensJSONRequestBody = ProxyTokenRequest + // PostApiReverseProxiesServicesJSONRequestBody defines body for PostApiReverseProxiesServices for application/json ContentType. type PostApiReverseProxiesServicesJSONRequestBody = ServiceRequest