From c0f9d6af5f288180046d8d8d5b6168d9b14fbd65 Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Sat, 20 May 2023 21:16:45 +0100 Subject: [PATCH 1/2] Improve Kubernetes access test coverage This PR improves the Kubernetes Access test coverage from 61% to >73%. There are some other code paths that require special attention such as watching with filtering but those require more work. --- lib/cloud/clients.go | 2 +- lib/kube/grpc/grpc_test.go | 2 +- lib/kube/proxy/cluster_details.go | 27 ++- lib/kube/proxy/exec_test.go | 2 +- lib/kube/proxy/kube_creds.go | 53 ++++- lib/kube/proxy/kube_creds_test.go | 274 ++++++++++++++++++++++ lib/kube/proxy/moderated_sessions_test.go | 4 +- lib/kube/proxy/utils_testing.go | 153 ++++++++++-- 8 files changed, 472 insertions(+), 45 deletions(-) create mode 100644 lib/kube/proxy/kube_creds_test.go diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go index beccd6d28f614..f099d075c38c7 100644 --- a/lib/cloud/clients.go +++ b/lib/cloud/clients.go @@ -964,7 +964,7 @@ func (c *TestCloudClients) GetAzurePostgresClient(subscription string) (azure.DB // GetAzureKubernetesClient returns an AKS client for the specified subscription func (c *TestCloudClients) GetAzureKubernetesClient(subscription string) (azure.AKSClient, error) { - if len(c.AzurePostgresPerSub) != 0 { + if len(c.AzureAKSClientPerSub) != 0 { return c.AzureAKSClientPerSub[subscription], nil } return c.AzureAKSClient, nil diff --git a/lib/kube/grpc/grpc_test.go b/lib/kube/grpc/grpc_test.go index 86ac74b9b8299..fc0a22319a6b1 100644 --- a/lib/kube/grpc/grpc_test.go +++ b/lib/kube/grpc/grpc_test.go @@ -419,7 +419,7 @@ func initGRPCServer(t *testing.T, testCtx *kubeproxy.TestContext, listener net.L Signer: proxyAuthClient, AccessPoint: proxyAuthClient, Emitter: testCtx.Emitter, - KubeProxyAddr: testCtx.KubeServiceAddress(), + KubeProxyAddr: testCtx.KubeProxyAddress(), Authz: testCtx.Authz, }, ) diff --git a/lib/kube/proxy/cluster_details.go b/lib/kube/proxy/cluster_details.go index 
14bca0250cfa0..1a41313dc83aa 100644 --- a/lib/kube/proxy/cluster_details.go +++ b/lib/kube/proxy/cluster_details.go @@ -89,22 +89,25 @@ func getKubeClusterCredentials(ctx context.Context, cloudClients cloud.Clients, case cluster.IsKubeconfig(): return getStaticCredentialsFromKubeconfig(ctx, cluster, log, checker) case cluster.IsAzure(): - return getAzureCredentials(ctx, cloudClients, cluster, log, checker) + return getAzureCredentials(ctx, cloudClients, dynamicCredsConfig{kubeCluster: cluster, log: log, checker: checker}) case cluster.IsAWS(): - return getAWSCredentials(ctx, cloudClients, cluster, log, checker) + return getAWSCredentials(ctx, cloudClients, dynamicCredsConfig{kubeCluster: cluster, log: log, checker: checker}) case cluster.IsGCP(): - return getGCPCredentials(ctx, cloudClients, cluster, log, checker) + return getGCPCredentials(ctx, cloudClients, dynamicCredsConfig{kubeCluster: cluster, log: log, checker: checker}) default: return nil, trace.BadParameter("authentication method provided for cluster %q not supported", cluster.GetName()) } } // getAzureCredentials creates a dynamicCreds that generates and updates the access credentials to a AKS Kubernetes cluster. 
-func getAzureCredentials(ctx context.Context, cloudClients cloud.Clients, cluster types.KubeCluster, log *logrus.Entry, checker servicecfg.ImpersonationPermissionsChecker) (*dynamicKubeCreds, error) { +func getAzureCredentials(ctx context.Context, cloudClients cloud.Clients, cfg dynamicCredsConfig) (*dynamicKubeCreds, error) { // create a client that returns the credentials for kubeCluster - client := azureRestConfigClient(cloudClients) + cfg.client = azureRestConfigClient(cloudClients) - creds, err := newDynamicKubeCreds(ctx, cluster, log, client, checker) + creds, err := newDynamicKubeCreds( + ctx, + cfg, + ) return creds, trace.Wrap(err) } @@ -126,10 +129,10 @@ func azureRestConfigClient(cloudClients cloud.Clients) dynamicCredsClient { } // getAWSCredentials creates a dynamicKubeCreds that generates and updates the access credentials to a EKS kubernetes cluster. -func getAWSCredentials(ctx context.Context, cloudClients cloud.Clients, cluster types.KubeCluster, log *logrus.Entry, checker servicecfg.ImpersonationPermissionsChecker) (*dynamicKubeCreds, error) { +func getAWSCredentials(ctx context.Context, cloudClients cloud.Clients, cfg dynamicCredsConfig) (*dynamicKubeCreds, error) { // create a client that returns the credentials for kubeCluster - client := getAWSClientRestConfig(cloudClients) - creds, err := newDynamicKubeCreds(ctx, cluster, log, client, checker) + cfg.client = getAWSClientRestConfig(cloudClients) + creds, err := newDynamicKubeCreds(ctx, cfg) return creds, trace.Wrap(err) } @@ -237,10 +240,10 @@ func getStaticCredentialsFromKubeconfig(ctx context.Context, cluster types.KubeC } // getGCPCredentials creates a dynamicKubeCreds that generates and updates the access credentials to a GKE kubernetes cluster. 
-func getGCPCredentials(ctx context.Context, cloudClients cloud.Clients, cluster types.KubeCluster, log *logrus.Entry, checker servicecfg.ImpersonationPermissionsChecker) (*dynamicKubeCreds, error) { +func getGCPCredentials(ctx context.Context, cloudClients cloud.Clients, cfg dynamicCredsConfig) (*dynamicKubeCreds, error) { // create a client that returns the credentials for kubeCluster - client := gcpRestConfigClient(cloudClients) - creds, err := newDynamicKubeCreds(ctx, cluster, log, client, checker) + cfg.client = gcpRestConfigClient(cloudClients) + creds, err := newDynamicKubeCreds(ctx, cfg) return creds, trace.Wrap(err) } diff --git a/lib/kube/proxy/exec_test.go b/lib/kube/proxy/exec_test.go index 844a2d9746f6b..60742a6b3e62c 100644 --- a/lib/kube/proxy/exec_test.go +++ b/lib/kube/proxy/exec_test.go @@ -185,7 +185,7 @@ func TestExecKubeService(t *testing.T) { req, err := generateExecRequest( generateExecRequestConfig{ - addr: testCtx.KubeServiceAddress(), + addr: testCtx.KubeProxyAddress(), podName: podName, podNamespace: podNamespace, containerName: podContainerName, diff --git a/lib/kube/proxy/kube_creds.go b/lib/kube/proxy/kube_creds.go index 8e67cb3b9771e..dd2bd54ba8be1 100644 --- a/lib/kube/proxy/kube_creds.go +++ b/lib/kube/proxy/kube_creds.go @@ -24,7 +24,9 @@ import ( "time" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/client-go/transport" @@ -151,7 +153,7 @@ type dynamicCredsClient func(ctx context.Context, cluster types.KubeCluster) (cf // function and renews them whenever they are about to expire. 
type dynamicKubeCreds struct { ctx context.Context - renewTicker *time.Ticker + renewTicker clockwork.Ticker staticCreds *staticKubeCreds log logrus.FieldLogger closeC chan struct{} @@ -160,19 +162,50 @@ type dynamicKubeCreds struct { sync.RWMutex } +// dynamicCredsConfig contains configuration for dynamicKubeCreds. +type dynamicCredsConfig struct { + kubeCluster types.KubeCluster + log logrus.FieldLogger + client dynamicCredsClient + checker servicecfg.ImpersonationPermissionsChecker + clock clockwork.Clock +} + +func (d *dynamicCredsConfig) checkAndSetDefaults() error { + if d.kubeCluster == nil { + return trace.BadParameter("missing kubeCluster") + } + if d.log == nil { + return trace.BadParameter("missing log") + } + if d.client == nil { + return trace.BadParameter("missing client") + } + if d.checker == nil { + return trace.BadParameter("missing checker") + } + if d.clock == nil { + d.clock = clockwork.NewRealClock() + } + return nil +} + // newDynamicKubeCreds creates a new dynamicKubeCreds refresher and starts the // credentials refresher mechanism to renew them once they are about to expire. 
-func newDynamicKubeCreds(ctx context.Context, kubeCluster types.KubeCluster, log logrus.FieldLogger, client dynamicCredsClient, checker servicecfg.ImpersonationPermissionsChecker) (*dynamicKubeCreds, error) { +func newDynamicKubeCreds(ctx context.Context, cfg dynamicCredsConfig) (*dynamicKubeCreds, error) { + if err := cfg.checkAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } dyn := &dynamicKubeCreds{ ctx: ctx, - log: log, + log: cfg.log, closeC: make(chan struct{}), - client: client, - renewTicker: time.NewTicker(time.Hour), - checker: checker, + client: cfg.client, + renewTicker: cfg.clock.NewTicker(time.Hour), + checker: cfg.checker, } - if err := dyn.renewClientset(kubeCluster); err != nil { + if err := dyn.renewClientset(cfg.kubeCluster); err != nil { return nil, trace.Wrap(err) } @@ -181,9 +214,9 @@ func newDynamicKubeCreds(ctx context.Context, kubeCluster types.KubeCluster, log select { case <-dyn.closeC: return - case <-dyn.renewTicker.C: - if err := dyn.renewClientset(kubeCluster); err != nil { - log.WithError(err).Warnf("Unable to renew cluster %q credentials.", kubeCluster.GetName()) + case <-dyn.renewTicker.Chan(): + if err := dyn.renewClientset(cfg.kubeCluster); err != nil { + log.WithError(err).Warnf("Unable to renew cluster %q credentials.", cfg.kubeCluster.GetName()) } } } diff --git a/lib/kube/proxy/kube_creds_test.go b/lib/kube/proxy/kube_creds_test.go new file mode 100644 index 0000000000000..f39aa931cb20a --- /dev/null +++ b/lib/kube/proxy/kube_creds_test.go @@ -0,0 +1,274 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package proxy + +import ( + "context" + "encoding/base64" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/eks" + "github.com/aws/aws-sdk-go/service/eks/eksiface" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + authztypes "k8s.io/client-go/kubernetes/typed/authorization/v1" + "k8s.io/client-go/rest" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/azure" + "github.com/gravitational/teleport/lib/cloud/gcp" + "github.com/gravitational/teleport/lib/fixtures" +) + +func Test_DynamicKubeCreds(t *testing.T) { + t.Parallel() + log := logrus.New() + log.SetOutput(io.Discard) + awsKube, err := types.NewKubernetesClusterV3( + types.Metadata{ + Name: "aws", + }, + types.KubernetesClusterSpecV3{ + AWS: types.KubeAWS{ + Region: "us-west-2", + AccountID: "1234567890", + Name: "eks", + }, + }, + ) + require.NoError(t, err) + gkeKube, err := types.NewKubernetesClusterV3( + types.Metadata{ + Name: "gke", + }, + types.KubernetesClusterSpecV3{ + GCP: types.KubeGCP{ + Location: "us-west-2", + ProjectID: "1234567890", + Name: "eks", + }, + }, + ) + require.NoError(t, err) + aksKube, err := types.NewKubernetesClusterV3( + types.Metadata{ + Name: "aks", + }, + 
types.KubernetesClusterSpecV3{ + Azure: types.KubeAzure{ + TenantID: "id", + ResourceGroup: "1234567890", + ResourceName: "eks", + SubscriptionID: "12345", + }, + }, + ) + require.NoError(t, err) + + // mock sts client + stsMock := &stsMockClient{ + // u is used to presign the request + // here we just verify the pre-signed request includes this url. + u: &url.URL{ + Scheme: "https", + Host: "sts.amazonaws.com", + Path: "/?Action=GetCallerIdentity&Version=2011-06-15", + }, + } + + // mock clients + cloudclients := &cloud.TestCloudClients{ + STS: stsMock, + EKS: &eksMockClient{ + cluster: awsKube, + t: t, + }, + GCPGKE: &gkeMockCLient{kube: gkeKube, t: t}, + AzureAKSClientPerSub: map[string]azure.AKSClient{ + "12345": &azureMockCLient{ + kube: aksKube, + t: t, + }, + }, + } + + type args struct { + cluster types.KubeCluster + client dynamicCredsClient + validateBearerToken func(string) error + } + tests := []struct { + name string + args args + wantAddr string + }{ + { + name: "aws eks cluster", + args: args{ + cluster: awsKube, + client: getAWSClientRestConfig(cloudclients), + validateBearerToken: func(token string) error { + if token == "" { + return trace.BadParameter("missing bearer token") + } + tokens := strings.Split(token, ".") + if len(tokens) != 2 { + return trace.BadParameter("invalid bearer token") + } + if tokens[0] != "k8s-aws-v1" { + return trace.BadParameter("token must start with k8s-aws-v1") + } + dec, err := base64.RawStdEncoding.DecodeString(tokens[1]) + if err != nil { + return trace.Wrap(err) + } + if string(dec) != stsMock.u.String() { + return trace.BadParameter("invalid token payload") + } + return nil + }, + }, + wantAddr: "api.eks.us-west-2.amazonaws.com:443", + }, + { + name: "gcp gke cluster", + args: args{ + cluster: gkeKube, + client: gcpRestConfigClient(cloudclients), + validateBearerToken: func(_ string) error { return nil }, + }, + wantAddr: "api.gke.google.com:443", + }, + { + name: "azure aks cluster", + args: args{ + cluster: 
aksKube, + client: azureRestConfigClient(cloudclients), + validateBearerToken: func(_ string) error { return nil }, + }, + wantAddr: "api.aks.microsoft.com:443", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakeClock := clockwork.NewFakeClock() + got, err := newDynamicKubeCreds( + context.Background(), + dynamicCredsConfig{ + clock: fakeClock, + checker: func(_ context.Context, _ string, + _ authztypes.SelfSubjectAccessReviewInterface, + ) error { + return nil + }, + log: log, + kubeCluster: tt.args.cluster, + client: tt.args.client, + }, + ) + require.NoError(t, err) + require.Equal(t, got.getKubeRestConfig().CAData, []byte(fixtures.TLSCACertPEM)) + require.NoError(t, tt.args.validateBearerToken(got.getKubeRestConfig().BearerToken)) + require.Equal(t, got.getTargetAddr(), tt.wantAddr) + require.NoError(t, got.close()) + }) + } +} + +type eksMockClient struct { + eksiface.EKSAPI + cluster types.KubeCluster + t *testing.T +} + +func (e *eksMockClient) DescribeClusterWithContext(_ aws.Context, req *eks.DescribeClusterInput, _ ...request.Option) (*eks.DescribeClusterOutput, error) { + require.Equal(e.t, e.cluster.GetAWSConfig().Name, *req.Name) + return &eks.DescribeClusterOutput{ + Cluster: &eks.Cluster{ + Endpoint: aws.String("https://api.eks.us-west-2.amazonaws.com"), + Name: req.Name, + CertificateAuthority: &eks.Certificate{ + Data: aws.String(base64.RawStdEncoding.EncodeToString([]byte(fixtures.TLSCACertPEM))), + }, + }, + }, nil +} + +type stsMockClient struct { + stsiface.STSAPI + u *url.URL +} + +func (s *stsMockClient) GetCallerIdentityRequest(req *sts.GetCallerIdentityInput) (*request.Request, *sts.GetCallerIdentityOutput) { + return &request.Request{ + HTTPRequest: &http.Request{ + Header: http.Header{}, + URL: s.u, + }, + Operation: &request.Operation{}, + Handlers: request.Handlers{}, + }, nil +} + +type gkeMockCLient struct { + gcp.GKEClient + kube types.KubeCluster + t *testing.T +} + +func (g *gkeMockCLient) 
GetClusterRestConfig(ctx context.Context, cfg gcp.ClusterDetails) (*rest.Config, time.Time, error) { + require.Equal(g.t, g.kube.GetGCPConfig().Name, cfg.Name) + require.Equal(g.t, g.kube.GetGCPConfig().ProjectID, cfg.ProjectID) + require.Equal(g.t, g.kube.GetGCPConfig().Location, cfg.Location) + return &rest.Config{ + Host: "https://api.gke.google.com", + TLSClientConfig: rest.TLSClientConfig{ + CAData: []byte(fixtures.TLSCACertPEM), + }, + }, time.Now(), nil +} + +type azureMockCLient struct { + azure.AKSClient + kube types.KubeCluster + t *testing.T +} + +func (a *azureMockCLient) ClusterCredentials(ctx context.Context, cfg azure.ClusterCredentialsConfig) (*rest.Config, time.Time, error) { + require.Equal(a.t, a.kube.GetAzureConfig().ResourceName, cfg.ResourceName) + require.Equal(a.t, a.kube.GetAzureConfig().ResourceGroup, cfg.ResourceGroup) + require.Equal(a.t, a.kube.GetAzureConfig().TenantID, cfg.TenantID) + + return &rest.Config{ + Host: "https://api.aks.microsoft.com", + TLSClientConfig: rest.TLSClientConfig{ + CAData: []byte(fixtures.TLSCACertPEM), + }, + }, time.Now(), nil +} diff --git a/lib/kube/proxy/moderated_sessions_test.go b/lib/kube/proxy/moderated_sessions_test.go index 4e515c3d48deb..ca75fef1534b1 100644 --- a/lib/kube/proxy/moderated_sessions_test.go +++ b/lib/kube/proxy/moderated_sessions_test.go @@ -268,7 +268,7 @@ func TestModeratedSessions(t *testing.T) { } req, err := generateExecRequest( generateExecRequestConfig{ - addr: testCtx.KubeServiceAddress(), + addr: testCtx.KubeProxyAddress(), podName: podName, podNamespace: podNamespace, containerName: podContainerName, @@ -653,7 +653,7 @@ func TestInteractiveSessionsNoAuth(t *testing.T) { } req, err := generateExecRequest( generateExecRequestConfig{ - addr: testCtx.KubeServiceAddress(), + addr: testCtx.KubeProxyAddress(), podName: podName, podNamespace: podNamespace, containerName: podContainerName, diff --git a/lib/kube/proxy/utils_testing.go b/lib/kube/proxy/utils_testing.go index 
723f4be51b06b..d103b662f178e 100644 --- a/lib/kube/proxy/utils_testing.go +++ b/lib/kube/proxy/utils_testing.go @@ -23,12 +23,14 @@ import ( "net/http" "net/url" "path/filepath" + "strings" "testing" "time" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "k8s.io/client-go/kubernetes" @@ -50,6 +52,8 @@ import ( "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/kube/proxy/streamproto" "github.com/gravitational/teleport/lib/limiter" + "github.com/gravitational/teleport/lib/multiplexer" + "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" sessPkg "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/tlsca" @@ -63,9 +67,11 @@ type TestContext struct { AuthClient *auth.Client Authz authz.Authorizer KubeServer *TLSServer + KubeProxy *TLSServer Emitter *eventstest.ChannelEmitter Context context.Context - listener net.Listener + kubeServerListener net.Listener + kubeProxyListener net.Listener cancel context.CancelFunc heartbeatCtx context.Context heartbeatCancel context.CancelFunc @@ -161,7 +167,7 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo // TLS config for kube proxy and Kube service. serverIdentity, err := auth.NewServerIdentity(authServer.AuthServer, testCtx.HostID, types.RoleKube) require.NoError(t, err) - tlsConfig, err := serverIdentity.TLSConfig(nil) + kubeServiceTLSConfig, err := serverIdentity.TLSConfig(nil) require.NoError(t, err) // Create test audit events emitter. 
@@ -189,6 +195,13 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo features = cfg.ClusterFeatures } + testCtx.kubeServerListener, err = net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + testCtx.kubeProxyListener, err = net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + log := logrus.New() + log.SetLevel(logrus.DebugLevel) + // Create kubernetes service server. testCtx.KubeServer, err = NewTLSServer(TLSServerConfig{ ForwarderConfig: ForwarderConfig{ @@ -222,7 +235,7 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo ClusterFeatures: features, }, DynamicLabels: nil, - TLS: tlsConfig, + TLS: kubeServiceTLSConfig.Clone(), AccessPoint: client, LimiterConfig: limiter.Config{ MaxConnections: 1000, @@ -252,28 +265,111 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo GetRotation: func(role types.SystemRole) (*types.Rotation, error) { return &types.Rotation{}, nil }, ResourceMatchers: cfg.ResourceMatchers, OnReconcile: cfg.OnReconcile, + Log: log, + }) + require.NoError(t, err) + + // Create kubernetes proxy server. + kubeServersWatcher, err := services.NewKubeServerWatcher( + testCtx.Context, + services.KubeServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentKube, + Client: client, + }, + }, + ) + require.NoError(t, err) + t.Cleanup(kubeServersWatcher.Close) + + // TLS config for kube proxy and Kube service. + proxyServerIdentity, err := auth.NewServerIdentity(authServer.AuthServer, testCtx.HostID, types.RoleProxy) + require.NoError(t, err) + proxyTLSConfig, err := proxyServerIdentity.TLSConfig(nil) + require.NoError(t, err) + // Create kubernetes service server. 
+ testCtx.KubeProxy, err = NewTLSServer(TLSServerConfig{ + ForwarderConfig: ForwarderConfig{ + ReverseTunnelSrv: &reversetunnel.FakeServer{ + Sites: []reversetunnel.RemoteSite{ + &fakeRemoteSite{ + FakeRemoteSite: reversetunnel.NewFakeRemoteSite(testCtx.ClusterName, client), + idToAddr: map[string]string{ + testCtx.HostID: testCtx.kubeServerListener.Addr().String(), + }, + }, + }, + }, + Namespace: apidefaults.Namespace, + Keygen: keyGen, + ClusterName: testCtx.ClusterName, + Authz: testCtx.Authz, + // fileStreamer continues to write events after the server is shutdown and + // races against os.RemoveAll leading the test to fail. + // Using "node-sync" mode to write the events and session recordings + // directly to AuthClient solves the issue. + // We wrap the AuthClient with an events.TeeStreamer to send non-disk + // events like session.end to testCtx.emitter as well. + AuthClient: &fakeClient{ClientI: client, closeC: testCtx.closeSessionTrackers}, + // StreamEmitter is required although not used because we are using + // "node-sync" as session recording mode. + StreamEmitter: testCtx.Emitter, + DataDir: t.TempDir(), + CachingAuthClient: client, + HostID: testCtx.HostID, + Context: testCtx.Context, + KubeServiceType: ProxyService, + Component: teleport.ComponentKube, + LockWatcher: testCtx.lockWatcher, + Clock: clockwork.NewRealClock(), + ClusterFeatures: features, + ConnTLSConfig: proxyTLSConfig.Clone(), + PROXYSigner: &multiplexer.PROXYSigner{}, + }, + TLS: proxyTLSConfig.Clone(), + AccessPoint: client, + KubernetesServersWatcher: kubeServersWatcher, + LimiterConfig: limiter.Config{ + MaxConnections: 1000, + MaxNumberOfUsers: 1000, + }, + Log: log, }) require.NoError(t, err) // Waits for len(clusters) heartbeats to start waitForHeartbeats := len(cfg.Clusters) - testCtx.startKubeService(t) - + testCtx.startKubeServices(t) + // Wait for all clusters to be registered. 
for i := 0; i < waitForHeartbeats; i++ { <-heartbeatsWaitChannel } + // Wait for kube servers to be initialized. + kubeServersWatcher.WaitInitialization() + // Ensure watcher has the correct list of clusters. + require.Eventually(t, func() bool { + kubeServers, err := kubeServersWatcher.GetKubernetesServers(context.Background()) + return err == nil && len(kubeServers) == len(cfg.Clusters) + }, 3*time.Second, time.Millisecond*100) + return testCtx } -// startKubeService starts kube service to handle connections. -func (c *TestContext) startKubeService(t *testing.T) { - listener, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - c.listener = listener +// startKubeServices starts kube service and kube proxy to handle connections. +func (c *TestContext) startKubeServices(t *testing.T) { go func() { - err := c.KubeServer.Serve(listener) + err := c.KubeServer.Serve(c.kubeServerListener) + // ignore server closed error returned when .Close is called. + if errors.Is(err, http.ErrServerClosed) { + return + } + assert.NoError(t, err) + }() + + go func() { + err := c.KubeProxy.Serve(c.kubeProxyListener) // ignore server closed error returned when .Close is called. if errors.Is(err, http.ErrServerClosed) { return @@ -288,16 +384,17 @@ func (c *TestContext) Close() error { // errors when deprovisioning. c.heartbeatCancel() // kubeServer closes the listener - err := c.KubeServer.Close() + errKubeServer := c.KubeServer.Close() + errKubeProxy := c.KubeProxy.Close() authCErr := c.AuthClient.Close() authSErr := c.AuthServer.Close() c.cancel() - return trace.NewAggregate(err, authCErr, authSErr) + return trace.NewAggregate(errKubeServer, errKubeProxy, authCErr, authSErr) } -// KubeServiceAddress returns the address of the kube service -func (c *TestContext) KubeServiceAddress() string { - return c.listener.Addr().String() +// KubeProxyAddress returns the address of the kube proxy. 
+func (c *TestContext) KubeProxyAddress() string { + return c.kubeProxyListener.Addr().String() } // RoleSpec defiens the role name and kube details to be created. @@ -423,7 +520,7 @@ func (c *TestContext) GenTestKubeClientTLSCert(t *testing.T, userName, kubeClust ServerName: "teleport.cluster.local", } restConfig := &rest.Config{ - Host: "https://" + c.KubeServiceAddress(), + Host: "https://" + c.KubeProxyAddress(), TLSClientConfig: tlsClientConfig, } @@ -437,7 +534,7 @@ func (c *TestContext) GenTestKubeClientTLSCert(t *testing.T, userName, kubeClust func (c *TestContext) NewJoiningSession(cfg *rest.Config, sessionID string, mode types.SessionParticipantMode) (*streamproto.SessionStream, error) { ws, err := newWebSocketClient(cfg, http.MethodPost, &url.URL{ Scheme: "wss", - Host: c.KubeServiceAddress(), + Host: c.KubeProxyAddress(), Path: "/api/v1/teleport/join/" + sessionID, }) if err != nil { @@ -486,3 +583,23 @@ func (f *fakeClient) CreateSessionTracker(ctx context.Context, st types.SessionT return f.ClientI.CreateSessionTracker(ctx, st) } } + +// fakeRemoteSite is a fake remote site that uses a map to map server IDs to +// addresses to simulate reverse tunneling. +type fakeRemoteSite struct { + *reversetunnel.FakeRemoteSite + idToAddr map[string]string +} + +func (f *fakeRemoteSite) DialTCP(p reversetunnel.DialParams) (conn net.Conn, err error) { + // The server ID is the first part of the address. 
+ addr, ok := f.idToAddr[strings.Split(p.ServerID, ".")[0]] + if !ok { + return nil, trace.NotFound("server %q not found", p.ServerID) + } + conn, err = net.Dial("tcp", addr) + if err != nil { + panic(err) + } + return conn, nil +} From ebe00c641c37855b459b91233ca0a7aa7ed51f34 Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Tue, 23 May 2023 10:21:03 +0100 Subject: [PATCH 2/2] move mocks to mocks package --- lib/cloud/mocks/aws.go | 69 +++++++++- lib/cloud/mocks/azure.go | 57 ++++++++ lib/cloud/mocks/gcp.go | 31 +++++ lib/kube/proxy/cluster_details.go | 11 +- lib/kube/proxy/kube_creds.go | 30 +++-- lib/kube/proxy/kube_creds_test.go | 210 ++++++++++++++---------------- 6 files changed, 283 insertions(+), 125 deletions(-) create mode 100644 lib/cloud/mocks/azure.go diff --git a/lib/cloud/mocks/aws.go b/lib/cloud/mocks/aws.go index 90cd162937e66..a91ed26b127ae 100644 --- a/lib/cloud/mocks/aws.go +++ b/lib/cloud/mocks/aws.go @@ -18,11 +18,15 @@ package mocks import ( "context" + "net/http" + "net/url" "sync" "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/eks" + "github.com/aws/aws-sdk-go/service/eks/eksiface" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface" "github.com/aws/aws-sdk-go/service/iam" @@ -44,6 +48,7 @@ import ( type STSMock struct { stsiface.STSAPI ARN string + URL *url.URL assumedRoleARNs []string assumedRoleExternalIDs []string mu sync.Mutex @@ -96,6 +101,21 @@ func (m *STSMock) AssumeRoleWithContext(ctx aws.Context, in *sts.AssumeRoleInput }, nil } +func (m *STSMock) GetCallerIdentityRequest(req *sts.GetCallerIdentityInput) (*request.Request, *sts.GetCallerIdentityOutput) { + return &request.Request{ + HTTPRequest: &http.Request{ + Header: http.Header{}, + URL: m.URL, + }, + Operation: &request.Operation{ + Name: "GetCallerIdentity", + HTTPMethod: "POST", + HTTPPath: "/", + }, + Handlers: request.Handlers{}, + }, nil +} + 
// RDSMock mocks AWS RDS API. type RDSMock struct { rdsiface.RDSAPI @@ -208,6 +228,7 @@ func (m *RDSMock) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyD } return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.DBClusterIdentifier)) } + func (m *RDSMock) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) { if aws.StringValue(input.DBProxyName) == "" { return &rds.DescribeDBProxiesOutput{ @@ -223,6 +244,7 @@ func (m *RDSMock) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.Descr } return nil, trace.NotFound("proxy %v not found", aws.StringValue(input.DBProxyName)) } + func (m *RDSMock) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) { inputProxyName := aws.StringValue(input.DBProxyName) inputProxyEndpointName := aws.StringValue(input.DBProxyEndpointName) @@ -252,6 +274,7 @@ func (m *RDSMock) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rd } return &rds.DescribeDBProxyEndpointsOutput{DBProxyEndpoints: endpoints}, nil } + func (m *RDSMock) DescribeDBProxyTargetsWithContext(ctx aws.Context, input *rds.DescribeDBProxyTargetsInput, options ...request.Option) (*rds.DescribeDBProxyTargetsOutput, error) { // only mocking to return a port here return &rds.DescribeDBProxyTargetsOutput{ @@ -260,18 +283,21 @@ func (m *RDSMock) DescribeDBProxyTargetsWithContext(ctx aws.Context, input *rds. 
}}, }, nil } + func (m *RDSMock) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error { fn(&rds.DescribeDBProxiesOutput{ DBProxies: m.DBProxies, }, true) return nil } + func (m *RDSMock) DescribeDBProxyEndpointsPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, fn func(*rds.DescribeDBProxyEndpointsOutput, bool) bool, options ...request.Option) error { fn(&rds.DescribeDBProxyEndpointsOutput{ DBProxyEndpoints: m.DBProxyEndpoints, }, true) return nil } + func (m *RDSMock) ListTagsForResourceWithContext(ctx aws.Context, input *rds.ListTagsForResourceInput, options ...request.Option) (*rds.ListTagsForResourceOutput, error) { return &rds.ListTagsForResourceOutput{}, nil } @@ -379,6 +405,7 @@ func (m *RedshiftMock) GetClusterCredentialsWithContext(aws.Context, *redshift.G } return m.GetClusterCredentialsOutput, nil } + func (m *RedshiftMock) DescribeClustersWithContext(ctx aws.Context, input *redshift.DescribeClustersInput, options ...request.Option) (*redshift.DescribeClustersOutput, error) { if aws.StringValue(input.ClusterIdentifier) == "" { return &redshift.DescribeClustersOutput{ @@ -430,12 +457,15 @@ func (m *RDSMockUnauth) DescribeDBInstancesPagesWithContext(ctx aws.Context, inp func (m *RDSMockUnauth) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error { return trace.AccessDenied("unauthorized") } + func (m *RDSMockUnauth) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) { return nil, trace.AccessDenied("unauthorized") } + func (m *RDSMockUnauth) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) { 
return nil, trace.AccessDenied("unauthorized") } + func (m *RDSMockUnauth) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error { return trace.AccessDenied("unauthorized") } @@ -451,9 +481,11 @@ type RDSMockByDBType struct { func (m *RDSMockByDBType) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) { return m.DBInstances.DescribeDBInstancesWithContext(ctx, input, options...) } + func (m *RDSMockByDBType) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) { return m.DBInstances.ModifyDBInstanceWithContext(ctx, input, options...) } + func (m *RDSMockByDBType) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error { return m.DBInstances.DescribeDBInstancesPagesWithContext(ctx, input, fn, options...) } @@ -461,18 +493,23 @@ func (m *RDSMockByDBType) DescribeDBInstancesPagesWithContext(ctx aws.Context, i func (m *RDSMockByDBType) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) { return m.DBClusters.DescribeDBClustersWithContext(ctx, input, options...) } + func (m *RDSMockByDBType) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) { return m.DBClusters.ModifyDBClusterWithContext(ctx, input, options...) 
} + func (m *RDSMockByDBType) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error { return m.DBClusters.DescribeDBClustersPagesWithContext(aws, input, fn, options...) } + func (m *RDSMockByDBType) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) { return m.DBProxies.DescribeDBProxiesWithContext(ctx, input, options...) } + func (m *RDSMockByDBType) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) { return m.DBProxies.DescribeDBProxyEndpointsWithContext(ctx, input, options...) } + func (m *RDSMockByDBType) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error { return m.DBProxies.DescribeDBProxiesPagesWithContext(ctx, input, fn, options...) 
} @@ -535,6 +572,7 @@ func (m *ElastiCacheMock) AddMockUser(user *elasticache.User, tagsMap map[string m.Users = append(m.Users, user) m.addTags(aws.StringValue(user.ARN), tagsMap) } + func (m *ElastiCacheMock) addTags(arn string, tagsMap map[string]string) { if m.TagsByARN == nil { m.TagsByARN = make(map[string][]*elasticache.Tag) @@ -572,12 +610,14 @@ func (m *ElastiCacheMock) DescribeReplicationGroupsWithContext(_ aws.Context, in } return nil, trace.NotFound("ElastiCache %v not found", aws.StringValue(input.ReplicationGroupId)) } + func (m *ElastiCacheMock) DescribeReplicationGroupsPagesWithContext(_ aws.Context, _ *elasticache.DescribeReplicationGroupsInput, fn func(*elasticache.DescribeReplicationGroupsOutput, bool) bool, _ ...request.Option) error { fn(&elasticache.DescribeReplicationGroupsOutput{ ReplicationGroups: m.ReplicationGroups, }, true) return nil } + func (m *ElastiCacheMock) DescribeUsersPagesWithContext(_ aws.Context, _ *elasticache.DescribeUsersInput, fn func(*elasticache.DescribeUsersOutput, bool) bool, _ ...request.Option) error { fn(&elasticache.DescribeUsersOutput{ Users: m.Users, @@ -588,9 +628,11 @@ func (m *ElastiCacheMock) DescribeUsersPagesWithContext(_ aws.Context, _ *elasti func (m *ElastiCacheMock) DescribeCacheClustersPagesWithContext(aws.Context, *elasticache.DescribeCacheClustersInput, func(*elasticache.DescribeCacheClustersOutput, bool) bool, ...request.Option) error { return trace.AccessDenied("unauthorized") } + func (m *ElastiCacheMock) DescribeCacheSubnetGroupsPagesWithContext(aws.Context, *elasticache.DescribeCacheSubnetGroupsInput, func(*elasticache.DescribeCacheSubnetGroupsOutput, bool) bool, ...request.Option) error { return trace.AccessDenied("unauthorized") } + func (m *ElastiCacheMock) ListTagsForResourceWithContext(_ aws.Context, input *elasticache.ListTagsForResourceInput, _ ...request.Option) (*elasticache.TagListMessage, error) { if m.TagsByARN == nil { return nil, trace.NotFound("no tags") @@ -605,6 +647,7 @@ func 
(m *ElastiCacheMock) ListTagsForResourceWithContext(_ aws.Context, input *e TagList: tags, }, nil } + func (m *ElastiCacheMock) ModifyUserWithContext(_ aws.Context, input *elasticache.ModifyUserInput, opts ...request.Option) (*elasticache.ModifyUserOutput, error) { for _, user := range m.Users { if aws.StringValue(user.UserId) == aws.StringValue(input.UserId) { @@ -627,6 +670,7 @@ func (m *MemoryDBMock) AddMockUser(user *memorydb.User, tagsMap map[string]strin m.Users = append(m.Users, user) m.addTags(aws.StringValue(user.ARN), tagsMap) } + func (m *MemoryDBMock) addTags(arn string, tagsMap map[string]string) { if m.TagsByARN == nil { m.TagsByARN = make(map[string][]*memorydb.Tag) @@ -641,11 +685,12 @@ func (m *MemoryDBMock) addTags(arn string, tagsMap map[string]string) { } m.TagsByARN[arn] = tags } + func (m *MemoryDBMock) DescribeSubnetGroupsWithContext(aws.Context, *memorydb.DescribeSubnetGroupsInput, ...request.Option) (*memorydb.DescribeSubnetGroupsOutput, error) { return nil, trace.AccessDenied("unauthorized") } -func (m *MemoryDBMock) DescribeClustersWithContext(_ aws.Context, input *memorydb.DescribeClustersInput, _ ...request.Option) (*memorydb.DescribeClustersOutput, error) { +func (m *MemoryDBMock) DescribeClustersWithContext(_ aws.Context, input *memorydb.DescribeClustersInput, _ ...request.Option) (*memorydb.DescribeClustersOutput, error) { if aws.StringValue(input.ClusterName) == "" { return &memorydb.DescribeClustersOutput{ Clusters: m.Clusters, @@ -661,6 +706,7 @@ func (m *MemoryDBMock) DescribeClustersWithContext(_ aws.Context, input *memoryd } return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.ClusterName)) } + func (m *MemoryDBMock) ListTagsWithContext(_ aws.Context, input *memorydb.ListTagsInput, _ ...request.Option) (*memorydb.ListTagsOutput, error) { if m.TagsByARN == nil { return nil, trace.NotFound("no tags") @@ -675,11 +721,13 @@ func (m *MemoryDBMock) ListTagsWithContext(_ aws.Context, input *memorydb.ListTa TagList: 
tags, }, nil } + func (m *MemoryDBMock) DescribeUsersWithContext(aws.Context, *memorydb.DescribeUsersInput, ...request.Option) (*memorydb.DescribeUsersOutput, error) { return &memorydb.DescribeUsersOutput{ Users: m.Users, }, nil } + func (m *MemoryDBMock) UpdateUserWithContext(_ aws.Context, input *memorydb.UpdateUserInput, opts ...request.Option) (*memorydb.UpdateUserOutput, error) { for _, user := range m.Users { if aws.StringValue(user.Name) == aws.StringValue(input.UserName) { @@ -778,3 +826,22 @@ func RedshiftGetClusterCredentialsOutput(user, password string, clock clockwork. Expiration: aws.Time(clock.Now().Add(15 * time.Minute)), } } + +// EKSMock is a mock EKS client. +type EKSMock struct { + eksiface.EKSAPI + Clusters []*eks.Cluster + Notify chan struct{} +} + +func (e *EKSMock) DescribeClusterWithContext(_ aws.Context, req *eks.DescribeClusterInput, _ ...request.Option) (*eks.DescribeClusterOutput, error) { + defer func() { + e.Notify <- struct{}{} + }() + for _, cluster := range e.Clusters { + if aws.StringValue(req.Name) == aws.StringValue(cluster.Name) { + return &eks.DescribeClusterOutput{Cluster: cluster}, nil + } + } + return nil, trace.NotFound("cluster %v not found", aws.StringValue(req.Name)) +} diff --git a/lib/cloud/mocks/azure.go b/lib/cloud/mocks/azure.go new file mode 100644 index 0000000000000..37385572dd439 --- /dev/null +++ b/lib/cloud/mocks/azure.go @@ -0,0 +1,57 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package mocks + +import ( + "context" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "k8s.io/client-go/rest" + + "github.com/gravitational/teleport/lib/cloud/azure" +) + +// AKSClusterEntry is an entry in the AKSMock.Clusters list. +type AKSClusterEntry struct { + azure.ClusterCredentialsConfig + Config *rest.Config + TTL time.Duration +} + +// AKSMock implements the azure.AKSClient interface for tests. +type AKSMock struct { + azure.AKSClient + Clusters []AKSClusterEntry + Notify chan struct{} + Clock clockwork.Clock +} + +func (a *AKSMock) ClusterCredentials(ctx context.Context, cfg azure.ClusterCredentialsConfig) (*rest.Config, time.Time, error) { + defer func() { + a.Notify <- struct{}{} + }() + for _, cluster := range a.Clusters { + if cluster.ClusterCredentialsConfig.ResourceGroup == cfg.ResourceGroup && + cluster.ClusterCredentialsConfig.ResourceName == cfg.ResourceName && + cluster.ClusterCredentialsConfig.TenantID == cfg.TenantID { + return cluster.Config, a.Clock.Now().Add(cluster.TTL), nil + } + } + return nil, time.Now(), trace.NotFound("cluster not found") +} diff --git a/lib/cloud/mocks/gcp.go b/lib/cloud/mocks/gcp.go index 58aaedb55c6a9..3dc076aca4fa6 100644 --- a/lib/cloud/mocks/gcp.go +++ b/lib/cloud/mocks/gcp.go @@ -19,8 +19,12 @@ package mocks import ( "context" "crypto/tls" + "time" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" sqladmin "google.golang.org/api/sqladmin/v1beta4" + "k8s.io/client-go/rest" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud/gcp" @@ -48,3 +52,30 @@ func (g *GCPSQLAdminClientMock) GetDatabaseInstance(ctx context.Context, db type func (g *GCPSQLAdminClientMock) GenerateEphemeralCert(ctx context.Context, db types.Database, identity tlsca.Identity) (*tls.Certificate, error) { return g.EphemeralCert, nil } + +// GKEClusterEntry is an entry in the GKEMock.Clusters list. 
+type GKEClusterEntry struct { + gcp.ClusterDetails + Config *rest.Config + TTL time.Duration +} + +// GKEMock implements the gcp.GKEClient interface for tests. +type GKEMock struct { + gcp.GKEClient + Clusters []GKEClusterEntry + Notify chan struct{} + Clock clockwork.Clock +} + +func (g *GKEMock) GetClusterRestConfig(ctx context.Context, cfg gcp.ClusterDetails) (*rest.Config, time.Time, error) { + defer func() { + g.Notify <- struct{}{} + }() + for _, cluster := range g.Clusters { + if cluster.ClusterDetails == cfg { + return cluster.Config, g.Clock.Now().Add(cluster.TTL), nil + } + } + return nil, time.Now(), trace.NotFound("cluster not found") +} diff --git a/lib/kube/proxy/cluster_details.go b/lib/kube/proxy/cluster_details.go index 1a41313dc83aa..988ec5969c514 100644 --- a/lib/kube/proxy/cluster_details.go +++ b/lib/kube/proxy/cluster_details.go @@ -24,6 +24,7 @@ import ( "github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts/stsiface" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" @@ -131,13 +132,13 @@ func azureRestConfigClient(cloudClients cloud.Clients) dynamicCredsClient { // getAWSCredentials creates a dynamicKubeCreds that generates and updates the access credentials to a EKS kubernetes cluster. func getAWSCredentials(ctx context.Context, cloudClients cloud.Clients, cfg dynamicCredsConfig) (*dynamicKubeCreds, error) { // create a client that returns the credentials for kubeCluster - cfg.client = getAWSClientRestConfig(cloudClients) + cfg.client = getAWSClientRestConfig(cloudClients, cfg.clock) creds, err := newDynamicKubeCreds(ctx, cfg) return creds, trace.Wrap(err) } // getAWSClientRestConfig creates a dynamicCredsClient that generates returns credentials to EKS clusters. 
-func getAWSClientRestConfig(cloudClients cloud.Clients) dynamicCredsClient { +func getAWSClientRestConfig(cloudClients cloud.Clients, clock clockwork.Clock) dynamicCredsClient { return func(ctx context.Context, cluster types.KubeCluster) (*rest.Config, time.Time, error) { // TODO(gavin): support assume_role_arn for AWS EKS. region := cluster.GetAWSConfig().Region @@ -168,7 +169,7 @@ func getAWSClientRestConfig(cloudClients cloud.Clients) dynamicCredsClient { return nil, time.Time{}, trace.Wrap(err) } - token, exp, err := genAWSToken(stsClient, cluster.GetAWSConfig().Name) + token, exp, err := genAWSToken(stsClient, cluster.GetAWSConfig().Name, clock) if err != nil { return nil, time.Time{}, trace.Wrap(err) } @@ -185,7 +186,7 @@ func getAWSClientRestConfig(cloudClients cloud.Clients) dynamicCredsClient { // genAWSToken creates an AWS token to access EKS clusters. // Logic from https://github.com/aws/aws-cli/blob/6c0d168f0b44136fc6175c57c090d4b115437ad1/awscli/customizations/eks/get_token.py#L211-L229 -func genAWSToken(stsClient stsiface.STSAPI, clusterID string) (string, time.Time, error) { +func genAWSToken(stsClient stsiface.STSAPI, clusterID string, clock clockwork.Clock) (string, time.Time, error) { const ( // The sts GetCallerIdentity request is valid for 15 minutes regardless of this parameters value after it has been // signed. 
@@ -212,7 +213,7 @@ func genAWSToken(stsClient stsiface.STSAPI, clusterID string) (string, time.Time } // Set token expiration to 1 minute before the presigned URL expires for some cushion - tokenExpiration := time.Now().Local().Add(presignedURLExpiration - 1*time.Minute) + tokenExpiration := clock.Now().Add(presignedURLExpiration - 1*time.Minute) return v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedURLString)), tokenExpiration, nil } diff --git a/lib/kube/proxy/kube_creds.go b/lib/kube/proxy/kube_creds.go index dd2bd54ba8be1..00503305c7c34 100644 --- a/lib/kube/proxy/kube_creds.go +++ b/lib/kube/proxy/kube_creds.go @@ -26,7 +26,6 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" - log "github.com/sirupsen/logrus" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/client-go/transport" @@ -159,16 +158,19 @@ type dynamicKubeCreds struct { closeC chan struct{} client dynamicCredsClient checker servicecfg.ImpersonationPermissionsChecker + clock clockwork.Clock sync.RWMutex + wg sync.WaitGroup } // dynamicCredsConfig contains configuration for dynamicKubeCreds. 
type dynamicCredsConfig struct { - kubeCluster types.KubeCluster - log logrus.FieldLogger - client dynamicCredsClient - checker servicecfg.ImpersonationPermissionsChecker - clock clockwork.Clock + kubeCluster types.KubeCluster + log logrus.FieldLogger + client dynamicCredsClient + checker servicecfg.ImpersonationPermissionsChecker + clock clockwork.Clock + initialRenewInterval time.Duration } func (d *dynamicCredsConfig) checkAndSetDefaults() error { @@ -187,6 +189,9 @@ func (d *dynamicCredsConfig) checkAndSetDefaults() error { if d.clock == nil { d.clock = clockwork.NewRealClock() } + if d.initialRenewInterval == 0 { + d.initialRenewInterval = time.Hour + } return nil } @@ -201,22 +206,24 @@ func newDynamicKubeCreds(ctx context.Context, cfg dynamicCredsConfig) (*dynamicK log: cfg.log, closeC: make(chan struct{}), client: cfg.client, - renewTicker: cfg.clock.NewTicker(time.Hour), + renewTicker: cfg.clock.NewTicker(cfg.initialRenewInterval), checker: cfg.checker, + clock: cfg.clock, } if err := dyn.renewClientset(cfg.kubeCluster); err != nil { return nil, trace.Wrap(err) } - + dyn.wg.Add(1) go func() { + defer dyn.wg.Done() for { select { case <-dyn.closeC: return case <-dyn.renewTicker.Chan(): if err := dyn.renewClientset(cfg.kubeCluster); err != nil { - log.WithError(err).Warnf("Unable to renew cluster %q credentials.", cfg.kubeCluster.GetName()) + logrus.WithError(err).Warnf("Unable to renew cluster %q credentials.", cfg.kubeCluster.GetName()) } } } @@ -263,6 +270,8 @@ func (d *dynamicKubeCreds) wrapTransport(rt http.RoundTripper) (http.RoundTrippe func (d *dynamicKubeCreds) close() error { close(d.closeC) + d.wg.Wait() + d.renewTicker.Stop() return nil } @@ -289,7 +298,8 @@ func (d *dynamicKubeCreds) renewClientset(cluster types.KubeCluster) error { d.staticCreds = creds // prepares the next renew cycle if !exp.IsZero() { - d.renewTicker.Reset(time.Until(exp) / 2) + reset := exp.Sub(d.clock.Now()) / 2 + d.renewTicker.Reset(reset) } return nil } diff --git 
a/lib/kube/proxy/kube_creds_test.go b/lib/kube/proxy/kube_creds_test.go index f39aa931cb20a..90fc0c014ef6d 100644 --- a/lib/kube/proxy/kube_creds_test.go +++ b/lib/kube/proxy/kube_creds_test.go @@ -1,5 +1,5 @@ /* -Copyright 2022 Gravitational, Inc. +Copyright 2023 Gravitational, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,19 +19,13 @@ package proxy import ( "context" "encoding/base64" - "io" - "net/http" "net/url" "strings" "testing" "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/eks" - "github.com/aws/aws-sdk-go/service/eks/eksiface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" @@ -43,13 +37,24 @@ import ( "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/cloud/gcp" + "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/fixtures" ) +// Test_DynamicKubeCreds tests the dynamic kube credentials generator for +// AWS, GCP, and Azure clusters accessed using their respective IAM credentials. +// This test mocks the cloud provider clients and the STS client to generate +// rest.Config objects for each cluster. It also tests the renewal of the +// credentials when they expire. 
func Test_DynamicKubeCreds(t *testing.T) { t.Parallel() - log := logrus.New() - log.SetOutput(io.Discard) + var ( + fakeClock = clockwork.NewFakeClock() + log = logrus.New() + notify = make(chan struct{}, 1) + ttl = 14 * time.Minute + ) + awsKube, err := types.NewKubernetesClusterV3( types.Metadata{ Name: "aws", @@ -71,7 +76,7 @@ func Test_DynamicKubeCreds(t *testing.T) { GCP: types.KubeGCP{ Location: "us-west-2", ProjectID: "1234567890", - Name: "eks", + Name: "gke", }, }, ) @@ -84,7 +89,7 @@ func Test_DynamicKubeCreds(t *testing.T) { Azure: types.KubeAzure{ TenantID: "id", ResourceGroup: "1234567890", - ResourceName: "eks", + ResourceName: "aks-name", SubscriptionID: "12345", }, }, @@ -92,28 +97,70 @@ func Test_DynamicKubeCreds(t *testing.T) { require.NoError(t, err) // mock sts client - stsMock := &stsMockClient{ - // u is used to presign the request - // here we just verify the pre-signed request includes this url. - u: &url.URL{ - Scheme: "https", - Host: "sts.amazonaws.com", - Path: "/?Action=GetCallerIdentity&Version=2011-06-15", - }, + u := &url.URL{ + Scheme: "https", + Host: "sts.amazonaws.com", + Path: "/?Action=GetCallerIdentity&Version=2011-06-15", } - // mock clients cloudclients := &cloud.TestCloudClients{ - STS: stsMock, - EKS: &eksMockClient{ - cluster: awsKube, - t: t, + STS: &mocks.STSMock{ + // u is used to presign the request + // here we just verify the pre-signed request includes this url. 
+ URL: u, + }, + EKS: &mocks.EKSMock{ + Notify: notify, + Clusters: []*eks.Cluster{ + { + Endpoint: aws.String("https://api.eks.us-west-2.amazonaws.com"), + Name: aws.String(awsKube.GetAWSConfig().Name), + CertificateAuthority: &eks.Certificate{ + Data: aws.String(base64.RawStdEncoding.EncodeToString([]byte(fixtures.TLSCACertPEM))), + }, + }, + }, + }, + GCPGKE: &mocks.GKEMock{ + Notify: notify, + Clock: fakeClock, + Clusters: []mocks.GKEClusterEntry{ + { + Config: &rest.Config{ + Host: "https://api.gke.google.com", + TLSClientConfig: rest.TLSClientConfig{ + CAData: []byte(fixtures.TLSCACertPEM), + }, + }, + ClusterDetails: gcp.ClusterDetails{ + Name: gkeKube.GetGCPConfig().Name, + ProjectID: gkeKube.GetGCPConfig().ProjectID, + Location: gkeKube.GetGCPConfig().Location, + }, + TTL: ttl, + }, + }, }, - GCPGKE: &gkeMockCLient{kube: gkeKube, t: t}, AzureAKSClientPerSub: map[string]azure.AKSClient{ - "12345": &azureMockCLient{ - kube: aksKube, - t: t, + "12345": &mocks.AKSMock{ + Notify: notify, + Clock: fakeClock, + Clusters: []mocks.AKSClusterEntry{ + { + Config: &rest.Config{ + Host: "https://api.aks.microsoft.com", + TLSClientConfig: rest.TLSClientConfig{ + CAData: []byte(fixtures.TLSCACertPEM), + }, + }, + TTL: ttl, + ClusterCredentialsConfig: azure.ClusterCredentialsConfig{ + ResourceName: aksKube.GetAzureConfig().ResourceName, + ResourceGroup: aksKube.GetAzureConfig().ResourceGroup, + TenantID: aksKube.GetAzureConfig().TenantID, + }, + }, + }, }, }, } @@ -132,7 +179,7 @@ func Test_DynamicKubeCreds(t *testing.T) { name: "aws eks cluster", args: args{ cluster: awsKube, - client: getAWSClientRestConfig(cloudclients), + client: getAWSClientRestConfig(cloudclients, fakeClock), validateBearerToken: func(token string) error { if token == "" { return trace.BadParameter("missing bearer token") @@ -148,7 +195,7 @@ func Test_DynamicKubeCreds(t *testing.T) { if err != nil { return trace.Wrap(err) } - if string(dec) != stsMock.u.String() { + if string(dec) != u.String() { 
return trace.BadParameter("invalid token payload") } return nil @@ -177,7 +224,6 @@ func Test_DynamicKubeCreds(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - fakeClock := clockwork.NewFakeClock() got, err := newDynamicKubeCreds( context.Background(), dynamicCredsConfig{ @@ -187,88 +233,34 @@ func Test_DynamicKubeCreds(t *testing.T) { ) error { return nil }, - log: log, - kubeCluster: tt.args.cluster, - client: tt.args.client, + log: log, + kubeCluster: tt.args.cluster, + client: tt.args.client, + initialRenewInterval: ttl / 2, }, ) require.NoError(t, err) - require.Equal(t, got.getKubeRestConfig().CAData, []byte(fixtures.TLSCACertPEM)) - require.NoError(t, tt.args.validateBearerToken(got.getKubeRestConfig().BearerToken)) - require.Equal(t, got.getTargetAddr(), tt.wantAddr) + select { + case <-notify: + case <-time.After(5 * time.Second): + t.Fatalf("timeout waiting for cluster to be ready") + } + for i := 0; i < 10; i++ { + require.Equal(t, got.getKubeRestConfig().CAData, []byte(fixtures.TLSCACertPEM)) + require.NoError(t, tt.args.validateBearerToken(got.getKubeRestConfig().BearerToken)) + require.Equal(t, got.getTargetAddr(), tt.wantAddr) + fakeClock.BlockUntil(1) + fakeClock.Advance(ttl / 2) + // notify receives a signal when the cloud client is invoked. + // this is used to test that the credentials are refreshed each time + // they are about to expire. 
+ select { + case <-notify: + case <-time.After(5 * time.Second): + t.Fatalf("timeout waiting for cluster to be ready, i=%d", i) + } + } require.NoError(t, got.close()) }) } } - -type eksMockClient struct { - eksiface.EKSAPI - cluster types.KubeCluster - t *testing.T -} - -func (e *eksMockClient) DescribeClusterWithContext(_ aws.Context, req *eks.DescribeClusterInput, _ ...request.Option) (*eks.DescribeClusterOutput, error) { - require.Equal(e.t, e.cluster.GetAWSConfig().Name, *req.Name) - return &eks.DescribeClusterOutput{ - Cluster: &eks.Cluster{ - Endpoint: aws.String("https://api.eks.us-west-2.amazonaws.com"), - Name: req.Name, - CertificateAuthority: &eks.Certificate{ - Data: aws.String(base64.RawStdEncoding.EncodeToString([]byte(fixtures.TLSCACertPEM))), - }, - }, - }, nil -} - -type stsMockClient struct { - stsiface.STSAPI - u *url.URL -} - -func (s *stsMockClient) GetCallerIdentityRequest(req *sts.GetCallerIdentityInput) (*request.Request, *sts.GetCallerIdentityOutput) { - return &request.Request{ - HTTPRequest: &http.Request{ - Header: http.Header{}, - URL: s.u, - }, - Operation: &request.Operation{}, - Handlers: request.Handlers{}, - }, nil -} - -type gkeMockCLient struct { - gcp.GKEClient - kube types.KubeCluster - t *testing.T -} - -func (g *gkeMockCLient) GetClusterRestConfig(ctx context.Context, cfg gcp.ClusterDetails) (*rest.Config, time.Time, error) { - require.Equal(g.t, g.kube.GetGCPConfig().Name, cfg.Name) - require.Equal(g.t, g.kube.GetGCPConfig().ProjectID, cfg.ProjectID) - require.Equal(g.t, g.kube.GetGCPConfig().Location, cfg.Location) - return &rest.Config{ - Host: "https://api.gke.google.com", - TLSClientConfig: rest.TLSClientConfig{ - CAData: []byte(fixtures.TLSCACertPEM), - }, - }, time.Now(), nil -} - -type azureMockCLient struct { - azure.AKSClient - kube types.KubeCluster - t *testing.T -} - -func (a *azureMockCLient) ClusterCredentials(ctx context.Context, cfg azure.ClusterCredentialsConfig) (*rest.Config, time.Time, error) { - 
require.Equal(a.t, a.kube.GetAzureConfig().ResourceName, cfg.ResourceName) - require.Equal(a.t, a.kube.GetAzureConfig().ResourceGroup, cfg.ResourceGroup) - require.Equal(a.t, a.kube.GetAzureConfig().TenantID, cfg.TenantID) - - return &rest.Config{ - Host: "https://api.aks.microsoft.com", - TLSClientConfig: rest.TLSClientConfig{ - CAData: []byte(fixtures.TLSCACertPEM), - }, - }, time.Now(), nil -}