diff --git a/integration/proxy/proxy_helpers.go b/integration/proxy/proxy_helpers.go index a1e3d16f0c27c..147b6de599ca4 100644 --- a/integration/proxy/proxy_helpers.go +++ b/integration/proxy/proxy_helpers.go @@ -40,8 +40,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" + "golang.org/x/exp/maps" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" clientcmdapi "k8s.io/client-go/tools/clientcmd/api" @@ -55,20 +58,25 @@ import ( "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" + "github.com/gravitational/teleport/lib/kube/kubeconfig" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/alpnproxy" + "github.com/gravitational/teleport/lib/srv/alpnproxy/common" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" + "github.com/gravitational/teleport/lib/srv/db/mysql" "github.com/gravitational/teleport/lib/srv/db/postgres" + "github.com/gravitational/teleport/lib/teleterm/gateway" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" awsutils "github.com/gravitational/teleport/lib/utils/aws" ) type Suite struct { - root *helpers.TeleInstance - leaf *helpers.TeleInstance + root *helpers.TeleInstance + leaf *helpers.TeleInstance + username string } type suiteOptions struct { @@ -125,12 +133,14 @@ func newSuite(t *testing.T, opts ...proxySuiteOptionsFunc) *Suite { } lCfg.Listeners = options.leafClusterListeners(t, &lCfg.Fds) lc := helpers.NewInstance(t, lCfg) + user := helpers.MustGetCurrentUser(t) + suite := &Suite{ - root: rc, - leaf: lc, + root: rc, + leaf: lc, + username: user.Username, } - user := helpers.MustGetCurrentUser(t) for _, role := range options.rootClusterRoles { rc.AddUserWithRole(user.Username, role) } @@ -711,3 +721,67 @@ func mustFindKubePod(t *testing.T, tc *client.TeleportClient) { require.Equal(t, types.KindKubePod, response.Resources[0].Kind) require.Equal(t, kubePodName, response.Resources[0].GetName()) } + +func mustConnectDatabaseGateway(t *testing.T, gw gateway.Gateway) { + t.Helper() + + dbGateway, err := gateway.AsDatabase(gw) + require.NoError(t, err) + + // Open a new connection. + client, err := mysql.MakeTestClientWithoutTLS( + net.JoinHostPort(gw.LocalAddress(), gw.LocalPort()), + dbGateway.RouteToDatabase()) + require.NoError(t, err) + + // Execute a query. + result, err := client.Execute("select 1") + require.NoError(t, err) + require.Equal(t, mysql.TestQueryResponse, result) + + // Disconnect. + require.NoError(t, client.Close()) +} + +func kubeClientForLocalProxy(t *testing.T, kubeconfigPath, teleportCluster, kubeCluster string) *kubernetes.Clientset { + t.Helper() + + config, err := kubeconfig.Load(kubeconfigPath) + require.NoError(t, err) + + contextName := kubeconfig.ContextName(teleportCluster, kubeCluster) + require.Contains(t, maps.Keys(config.Clusters), contextName) + proxyURL, err := url.Parse(config.Clusters[contextName].ProxyURL) + require.NoError(t, err) + + tlsClientConfig := rest.TLSClientConfig{ + CAData: config.Clusters[contextName].CertificateAuthorityData, + CertData: config.AuthInfos[contextName].ClientCertificateData, + KeyData: config.AuthInfos[contextName].ClientKeyData, + ServerName: common.KubeLocalProxySNI(teleportCluster, kubeCluster), + } + client, err := kubernetes.NewForConfig(&rest.Config{ + Host: "https://" + teleportCluster, + TLSClientConfig: tlsClientConfig, + Proxy: http.ProxyURL(proxyURL), + }) + require.NoError(t, err) + return client +} + +func mustGetKubePod(t *testing.T, client *kubernetes.Clientset, wantPodName string) { + t.Helper() + + resp, err := client.CoreV1().Pods("default").List(context.Background(), metav1.ListOptions{}) + require.NoError(t, err) + require.Equal(t, len(resp.Items), 1) + require.Equal(t, wantPodName, resp.Items[0].GetName()) +} + +func mustGetProfileName(t *testing.T, webProxyAddr string) string { + t.Helper() + + profileName, _, err := net.SplitHostPort(webProxyAddr) + require.NoError(t, err) + return profileName +} diff --git a/integration/proxy/proxy_test.go b/integration/proxy/proxy_test.go index 75df089885a94..8f6bf12e47e2f 100644 --- a/integration/proxy/proxy_test.go +++ b/integration/proxy/proxy_test.go @@ -434,9 +434,7 @@ func TestALPNSNIProxyKube(t *testing.T) { }, }) require.NoError(t, err) - resp, err := k8Client.CoreV1().Pods("default").List(context.Background(), metav1.ListOptions{}) - require.NoError(t, err) - require.Equal(t, 1, len(resp.Items), "pods item length mismatch") + mustGetKubePod(t, k8Client, kubePodName) }) } @@ -507,9 +505,7 @@ func TestALPNSNIProxyKubeV2Leaf(t *testing.T) { }) require.NoError(t, err) - resp, err := k8Client.CoreV1().Pods("default").List(context.Background(), metav1.ListOptions{}) - require.NoError(t, err) - require.Equal(t, 1, len(resp.Items), "pods item length mismatch") + mustGetKubePod(t, k8Client, kubePodName) } func TestKubeIPPinning(t *testing.T) { diff --git a/integration/proxy/teleterm_test.go b/integration/proxy/teleterm_test.go index 92d1d32b6625a..ac71dc1e4f633 100644 --- a/integration/proxy/teleterm_test.go +++ b/integration/proxy/teleterm_test.go @@ -17,6 +17,7 @@ package proxy import ( "context" "net" + "sync" "testing" "time" @@ -26,69 +27,93 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "k8s.io/client-go/kubernetes" + "github.com/gravitational/teleport/api/types" api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" dbhelpers "github.com/gravitational/teleport/integration/db" "github.com/gravitational/teleport/integration/helpers" + "github.com/gravitational/teleport/integration/kube" + "github.com/gravitational/teleport/lib" libclient "github.com/gravitational/teleport/lib/client" - "github.com/gravitational/teleport/lib/srv/db/mysql" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/service" + "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/teleterm/api/uri" "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/daemon" - "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/teleterm/gateway" + "github.com/gravitational/teleport/lib/utils" ) // testTeletermGatewaysCertRenewal is run from within TestALPNSNIProxyDatabaseAccess to amortize the // cost of setting up clusters in tests. func testTeletermGatewaysCertRenewal(t *testing.T, pack *dbhelpers.DatabasePack) { - rootClusterName, _, err := net.SplitHostPort(pack.Root.Cluster.Web) - require.NoError(t, err) - - creds, err := helpers.GenerateUserCreds(helpers.UserCredsRequest{ - Process: pack.Root.Cluster.Process, - Username: pack.Root.User.GetName(), - }) - require.NoError(t, err) - t.Run("root cluster", func(t *testing.T) { - databaseURI := uri.NewClusterURI(rootClusterName). + profileName := mustGetProfileName(t, pack.Root.Cluster.Web) + databaseURI := uri.NewClusterURI(profileName). AppendDB(pack.Root.MysqlService.Name) - testGatewayCertRenewal(t, pack, "", creds, databaseURI) + testDBGatewayCertRenewal(t, pack, "", databaseURI) }) t.Run("leaf cluster", func(t *testing.T) { + profileName := mustGetProfileName(t, pack.Root.Cluster.Web) leafClusterName := pack.Leaf.Cluster.Secrets.SiteName - databaseURI := uri.NewClusterURI(rootClusterName). + databaseURI := uri.NewClusterURI(profileName). AppendLeafCluster(leafClusterName). AppendDB(pack.Leaf.MysqlService.Name) - testGatewayCertRenewal(t, pack, "", creds, databaseURI) + testDBGatewayCertRenewal(t, pack, "", databaseURI) }) t.Run("ALPN connection upgrade", func(t *testing.T) { // Make a mock ALB which points to the Teleport Proxy Service. Then // ALPN local proxies will point to this ALB instead. albProxy := helpers.MustStartMockALBProxy(t, pack.Root.Cluster.Web) - databaseURI := uri.NewClusterURI(rootClusterName). + // Note that profile name is taken from tc.WebProxyAddr. Use + // albProxy.Addr() as profile name in case it's different from + // pack.Root.Cluster.Web (e.g. 127.0.0.1 vs localhost). + profileName := mustGetProfileName(t, albProxy.Addr().String()) + databaseURI := uri.NewClusterURI(profileName). AppendDB(pack.Root.MysqlService.Name) - testGatewayCertRenewal(t, pack, albProxy.Addr().String(), creds, databaseURI) + testDBGatewayCertRenewal(t, pack, albProxy.Addr().String(), databaseURI) }) } -func testGatewayCertRenewal(t *testing.T, pack *dbhelpers.DatabasePack, albAddr string, creds *helpers.UserCreds, databaseURI uri.ResourceURI) { - tc, err := pack.Root.Cluster.NewClientWithCreds(helpers.ClientConfig{ - Login: pack.Root.User.GetName(), - Cluster: pack.Root.Cluster.Secrets.SiteName, +func testDBGatewayCertRenewal(t *testing.T, pack *dbhelpers.DatabasePack, albAddr string, databaseURI uri.ResourceURI) { + t.Helper() + + testGatewayCertRenewal( + t, + pack.Root.Cluster, + pack.Root.User.GetName(), + albAddr, + daemon.CreateGatewayParams{ + TargetURI: databaseURI.String(), + TargetUser: pack.Root.User.GetName(), + }, + mustConnectDatabaseGateway, + ) +} + +type testGatewayConnectionFunc func(*testing.T, gateway.Gateway) + +func testGatewayCertRenewal(t *testing.T, inst *helpers.TeleInstance, username, albAddr string, params daemon.CreateGatewayParams, testConnection testGatewayConnectionFunc) { + t.Helper() + + tc, err := inst.NewClient(helpers.ClientConfig{ + Login: username, + Cluster: inst.Secrets.SiteName, ALBAddr: albAddr, - }, *creds) + }) require.NoError(t, err) + // Save the profile yaml file to disk as NewClientWithCreds doesn't do that by itself. err = tc.SaveProfile(false /* makeCurrent */) require.NoError(t, err) fakeClock := clockwork.NewFakeClockAt(time.Now()) - storage, err := clusters.NewStorage(clusters.Config{ Dir: tc.KeysDir, InsecureSkipVerify: tc.InsecureSkipVerify, @@ -111,67 +136,35 @@ func testGatewayCertRenewal(t *testing.T, pack *dbhelpers.DatabasePack, albAddr // Create a mock tshd events service server and have the daemon connect to it, // like it would during normal initialization of the app. - - tshdEventsService, tshEventsServerAddr := newMockTSHDEventsServiceServer(t, tc, pack) + tshdEventsService, tshEventsServerAddr := newMockTSHDEventsServiceServer(t, tc, inst, username) err = daemonService.UpdateAndDialTshdEventsServerAddress(tshEventsServerAddr) require.NoError(t, err) // Here the test setup ends and actual test code starts. - - gateway, err := daemonService.CreateGateway(context.Background(), daemon.CreateGatewayParams{ - TargetURI: databaseURI.String(), - TargetUser: "root", - }) + gateway, err := daemonService.CreateGateway(context.Background(), params) require.NoError(t, err, trace.DebugReport(err)) - // Open a new connection. - route := tlsca.RouteToDatabase{ - ServiceName: pack.Root.MysqlService.Name, - Protocol: pack.Root.MysqlService.Protocol, - Username: "root", - } - client, err := mysql.MakeTestClientWithoutTLS( - net.JoinHostPort(gateway.LocalAddress(), gateway.LocalPort()), - route) - require.NoError(t, err) - - // Execute a query. - result, err := client.Execute("select 1") - require.NoError(t, err) - require.Equal(t, mysql.TestQueryResponse, result) - - // Disconnect. - require.NoError(t, client.Close()) + testConnection(t, gateway) // Advance the fake clock to simulate the db cert expiry inside the middleware. fakeClock.Advance(time.Hour * 48) + // Overwrite user certs with expired ones to simulate the user cert expiry. expiredCreds, err := helpers.GenerateUserCreds(helpers.UserCredsRequest{ - Process: pack.Root.Cluster.Process, - Username: pack.Root.User.GetName(), + Process: inst.Process, + Username: username, TTL: -time.Hour, }) require.NoError(t, err) - err = helpers.SetupUserCreds(tc, pack.Root.Cluster.Config.Proxy.SSHAddr.Addr, *expiredCreds) + err = helpers.SetupUserCreds(tc, inst.Config.Proxy.SSHAddr.Addr, *expiredCreds) require.NoError(t, err) // Open a new connection. - // This should trigger the relogin flow. The middleware will notice that the db cert has expired - // and then it will attempt to reissue the db cert using an expired user cert. + // This should trigger the relogin flow. The middleware will notice that the cert has expired + // and then it will attempt to reissue the user cert using an expired user cert. // The mocked tshdEventsClient will issue a valid user cert, save it to disk, and the middleware // will let the connection through. - client, err = mysql.MakeTestClientWithoutTLS( - net.JoinHostPort(gateway.LocalAddress(), gateway.LocalPort()), - route) - require.NoError(t, err) - - // Execute a query. - result, err = client.Execute("select 1") - require.NoError(t, err) - require.Equal(t, mysql.TestQueryResponse, result) - - // Disconnect. - require.NoError(t, client.Close()) + testConnection(t, gateway) require.Equal(t, 1, tshdEventsService.callCounts["Relogin"], "Unexpected number of calls to TSHDEventsClient.Relogin") @@ -183,16 +176,18 @@ type mockTSHDEventsService struct { *api.UnimplementedTshdEventsServiceServer tc *libclient.TeleportClient - pack *dbhelpers.DatabasePack + inst *helpers.TeleInstance + username string callCounts map[string]int } -func newMockTSHDEventsServiceServer(t *testing.T, tc *libclient.TeleportClient, pack *dbhelpers.DatabasePack) (service *mockTSHDEventsService, addr string) { +func newMockTSHDEventsServiceServer(t *testing.T, tc *libclient.TeleportClient, inst *helpers.TeleInstance, username string) (service *mockTSHDEventsService, addr string) { t.Helper() tshdEventsService := &mockTSHDEventsService{ tc: tc, - pack: pack, + inst: inst, + username: username, callCounts: make(map[string]int), } @@ -216,13 +211,13 @@ func newMockTSHDEventsServiceServer(t *testing.T, tc *libclient.TeleportClient, func (c *mockTSHDEventsService) Relogin(context.Context, *api.ReloginRequest) (*api.ReloginResponse, error) { c.callCounts["Relogin"]++ creds, err := helpers.GenerateUserCreds(helpers.UserCredsRequest{ - Process: c.pack.Root.Cluster.Process, - Username: c.pack.Root.User.GetName(), + Process: c.inst.Process, + Username: c.username, }) if err != nil { return nil, trace.Wrap(err) } - err = helpers.SetupUserCreds(c.tc, c.pack.Root.Cluster.Config.Proxy.SSHAddr.Addr, *creds) + err = helpers.SetupUserCreds(c.tc, c.inst.Config.Proxy.SSHAddr.Addr, *creds) if err != nil { return nil, trace.Wrap(err) } @@ -234,3 +229,124 @@ func (c *mockTSHDEventsService) SendNotification(context.Context, *api.SendNotif c.callCounts["SendNotification"]++ return &api.SendNotificationResponse{}, nil } + +// TestTeletermKubeGateway tests making kube API calls against Teleterm kube +// gateway and reissuing certs. +// +// Note that this test does NOT reuse existing kube test setups as IP Pinning +// is enabled in those tests. User certs with pinned IPs are injected during +// those tests, which is not feasible for Teleterm daemon flow. +func TestTeletermKubeGateway(t *testing.T) { + lib.SetInsecureDevMode(true) + defer lib.SetInsecureDevMode(false) + + const ( + localK8SNI = "kube.teleport.cluster.local" + k8User = "alice@example.com" + k8RoleName = "kubemaster" + ) + + kubeAPIMockSvr := startKubeAPIMock(t) + kubeConfigPath := mustCreateKubeConfigFile(t, k8ClientConfig(kubeAPIMockSvr.URL, localK8SNI)) + + username := helpers.MustGetCurrentUser(t).Username + kubeRoleSpec := types.RoleSpecV6{ + Allow: types.RoleConditions{ + Logins: []string{username}, + KubernetesLabels: types.Labels{types.Wildcard: []string{types.Wildcard}}, + KubeGroups: []string{kube.TestImpersonationGroup}, + KubeUsers: []string{k8User}, + KubernetesResources: []types.KubernetesResource{ + { + Kind: types.KindKubePod, Name: types.Wildcard, Namespace: types.Wildcard, Verbs: []string{types.Wildcard}, + }, + }, + }, + } + kubeRole, err := types.NewRole(k8RoleName, kubeRoleSpec) + require.NoError(t, err) + suite := newSuite(t, + withRootClusterConfig(rootClusterStandardConfig(t), func(config *servicecfg.Config) { + config.Version = defaults.TeleportConfigVersionV2 + config.Proxy.Kube.Enabled = true + config.Kube.Enabled = true + config.Kube.KubeconfigPath = kubeConfigPath + config.Kube.ListenAddr = utils.MustParseAddr( + helpers.NewListener(t, service.ListenerKube, &config.FileDescriptors)) + }), + withLeafClusterConfig(leafClusterStandardConfig(t), func(config *servicecfg.Config) { + config.Version = defaults.TeleportConfigVersionV2 + config.Proxy.Kube.Enabled = true + config.Kube.Enabled = true + config.Kube.KubeconfigPath = kubeConfigPath + config.Kube.ListenAddr = utils.MustParseAddr( + helpers.NewListener(t, service.ListenerKube, &config.FileDescriptors)) + }), + withRootClusterRoles(kubeRole), + withLeafClusterRoles(kubeRole), + withRootAndLeafTrustedClusterReset(), + withTrustedCluster(), + ) + + t.Run("root", func(t *testing.T) { + profileName := mustGetProfileName(t, suite.root.Web) + kubeURI := uri.NewClusterURI(profileName).AppendKube(kubeClusterName) + testKubeGatewayCertRenewal(t, suite, "", kubeURI) + }) + t.Run("leaf", func(t *testing.T) { + profileName := mustGetProfileName(t, suite.root.Web) + kubeURI := uri.NewClusterURI(profileName).AppendLeafCluster(suite.leaf.Secrets.SiteName).AppendKube(kubeClusterName) + testKubeGatewayCertRenewal(t, suite, "", kubeURI) + }) + t.Run("ALPN connection upgrade", func(t *testing.T) { + // Make a mock ALB which points to the Teleport Proxy Service. Then + // ALPN local proxies will point to this ALB instead. + albProxy := helpers.MustStartMockALBProxy(t, suite.root.Web) + + // Note that profile name is taken from tc.WebProxyAddr. Use + // albProxy.Addr() as profile name in case it's different from + // suite.root.Web (e.g. 127.0.0.1 vs localhost). + profileName := mustGetProfileName(t, albProxy.Addr().String()) + + kubeURI := uri.NewClusterURI(profileName).AppendKube(kubeClusterName) + testKubeGatewayCertRenewal(t, suite, albProxy.Addr().String(), kubeURI) + }) +} + +func testKubeGatewayCertRenewal(t *testing.T, suite *Suite, albAddr string, kubeURI uri.ResourceURI) { + t.Helper() + + var client *kubernetes.Clientset + var clientOnce sync.Once + + kubeCluster := kubeURI.GetKubeName() + teleportCluster := suite.root.Secrets.SiteName + if kubeURI.GetLeafClusterName() != "" { + teleportCluster = kubeURI.GetLeafClusterName() + } + + testKubeConnection := func(t *testing.T, gw gateway.Gateway) { + t.Helper() + + clientOnce.Do(func() { + kubeGateway, err := gateway.AsKube(gw) + require.NoError(t, err) + + client = kubeClientForLocalProxy(t, kubeGateway.KubeconfigPath(), teleportCluster, kubeCluster) + }) + + mustGetKubePod(t, client, kubePodName) + } + + testGatewayCertRenewal( + t, + suite.root, + suite.username, + albAddr, + daemon.CreateGatewayParams{ + TargetURI: kubeURI.String(), + }, + testKubeConnection, + ) + +} diff --git a/lib/client/profile.go b/lib/client/profile.go index ea8460137a153..991daa9a06126 100644 --- a/lib/client/profile.go +++ b/lib/client/profile.go @@ -511,6 +511,20 @@ func (p *ProfileStatus) KubeConfigPath(name string) string { return keypaths.KubeConfigPath(p.Dir, p.Name, p.Username, p.Cluster, name) } +// KubeCertPathForCluster returns path to the specified kube access certificate +// for this profile, for the specified cluster name. +// +// It's kept in /keys//-kube//-x509.pem +func (p *ProfileStatus) KubeCertPathForCluster(teleportCluster, kubeCluster string) string { + if teleportCluster == "" { + teleportCluster = p.Cluster + } + if path, ok := p.virtualPathFromEnv(VirtualPathKubernetes, VirtualPathKubernetesParams(kubeCluster)); ok { + return path + } + return keypaths.KubeCertPath(p.Dir, p.Name, p.Username, teleportCluster, kubeCluster) +} + // DatabaseServices returns a list of database service names for this profile. func (p *ProfileStatus) DatabaseServices() (result []string) { for _, db := range p.Databases { diff --git a/lib/teleterm/clusters/cluster_databases.go b/lib/teleterm/clusters/cluster_databases.go index 9843446c2feb3..55ac925c7012d 100644 --- a/lib/teleterm/clusters/cluster_databases.go +++ b/lib/teleterm/clusters/cluster_databases.go @@ -151,8 +151,8 @@ func (c *Cluster) GetDatabases(ctx context.Context, r *api.GetDatabasesRequest) return response, nil } -// ReissueDBCerts issues new certificates for specific DB access and saves them to disk. -func (c *Cluster) ReissueDBCerts(ctx context.Context, routeToDatabase tlsca.RouteToDatabase) error { +// reissueDBCerts issues new certificates for specific DB access and saves them to disk. +func (c *Cluster) reissueDBCerts(ctx context.Context, routeToDatabase tlsca.RouteToDatabase) error { // When generating certificate for MongoDB access, database username must // be encoded into it. This is required to be able to tell which database // user to authenticate the connection as. diff --git a/lib/teleterm/clusters/cluster_gateways.go b/lib/teleterm/clusters/cluster_gateways.go index fca6f1d66b3c0..557e77aa0d14f 100644 --- a/lib/teleterm/clusters/cluster_gateways.go +++ b/lib/teleterm/clusters/cluster_gateways.go @@ -43,6 +43,21 @@ type CreateGatewayParams struct { // CreateGateway creates a gateway func (c *Cluster) CreateGateway(ctx context.Context, params CreateGatewayParams) (gateway.Gateway, error) { + switch { + case params.TargetURI.IsDB(): + gateway, err := c.createDBGateway(ctx, params) + return gateway, trace.Wrap(err) + + case params.TargetURI.IsKube(): + gateway, err := c.createKubeGateway(ctx, params) + return gateway, trace.Wrap(err) + + default: + return nil, trace.NotImplemented("gateway not supported for %v", params.TargetURI) + } +} + +func (c *Cluster) createDBGateway(ctx context.Context, params CreateGatewayParams) (gateway.Gateway, error) { db, err := c.GetDatabase(ctx, params.TargetURI) if err != nil { return nil, trace.Wrap(err) @@ -54,7 +69,7 @@ func (c *Cluster) CreateGateway(ctx context.Context, params CreateGatewayParams) Username: params.TargetUser, } - if err := c.ReissueDBCerts(ctx, routeToDatabase); err != nil { + if err := c.reissueDBCerts(ctx, routeToDatabase); err != nil { return nil, trace.Wrap(err) } @@ -83,3 +98,49 @@ func (c *Cluster) CreateGateway(ctx context.Context, params CreateGatewayParams) return gw, nil } + +func (c *Cluster) createKubeGateway(ctx context.Context, params CreateGatewayParams) (gateway.Gateway, error) { + kube := params.TargetURI.GetKubeName() + + if err := c.reissueKubeCert(ctx, kube); err != nil { + return nil, trace.Wrap(err) + } + + // TODO support TargetUser (--as), TargetGroups (--as-groups), TargetSubresourceName (--kube-namespace). + gw, err := gateway.New(gateway.Config{ + LocalPort: params.LocalPort, + TargetURI: params.TargetURI, + TargetName: kube, + KeyPath: c.status.KeyPath(), + CertPath: c.status.KubeCertPathForCluster(c.clusterClient.SiteName, kube), + Insecure: c.clusterClient.InsecureSkipVerify, + WebProxyAddr: c.clusterClient.WebProxyAddr, + Log: c.Log, + CLICommandProvider: params.CLICommandProvider, + TCPPortAllocator: params.TCPPortAllocator, + OnExpiredCert: params.OnExpiredCert, + Clock: c.clock, + TLSRoutingConnUpgradeRequired: c.clusterClient.TLSRoutingConnUpgradeRequired, + RootClusterCACertPoolFunc: c.clusterClient.RootClusterCACertPool, + ClusterName: c.Name, + Username: c.status.Username, + ProfileDir: c.status.Dir, + }) + return gw, trace.Wrap(err) +} + +// ReissueGatewayCerts reissues certificate for provided gateway. +func (c *Cluster) ReissueGatewayCerts(ctx context.Context, g gateway.Gateway) error { + switch { + case g.TargetURI().IsDB(): + db, err := gateway.AsDatabase(g) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(c.reissueDBCerts(ctx, db.RouteToDatabase())) + case g.TargetURI().IsKube(): + return trace.Wrap(c.reissueKubeCert(ctx, g.TargetName())) + default: + return nil + } +} diff --git a/lib/teleterm/clusters/cluster_kubes.go b/lib/teleterm/clusters/cluster_kubes.go index b56e92c2e4e8c..fc9fa91fea726 100644 --- a/lib/teleterm/clusters/cluster_kubes.go +++ b/lib/teleterm/clusters/cluster_kubes.go @@ -105,3 +105,24 @@ type GetKubesResponse struct { // // TotalCount is the total number of resources available as a whole. TotalCount int } + +// reissueKubeCert issue new certificates for kube cluster and saves them to disk. +func (c *Cluster) reissueKubeCert(ctx context.Context, kubeCluster string) error { + return trace.Wrap(AddMetadataToRetryableError(ctx, func() error { + // Refresh the certs to account for clusterClient.SiteName pointing at a leaf cluster. + err := c.clusterClient.ReissueUserCerts(ctx, client.CertCacheKeep, client.ReissueParams{ + RouteToCluster: c.clusterClient.SiteName, + AccessRequests: c.status.ActiveRequests.AccessRequests, + }) + if err != nil { + return trace.Wrap(err) + } + + // Fetch the certs for the kube cluster. + return trace.Wrap(c.clusterClient.ReissueUserCerts(ctx, client.CertCacheKeep, client.ReissueParams{ + RouteToCluster: c.clusterClient.SiteName, + KubernetesCluster: kubeCluster, + AccessRequests: c.status.ActiveRequests.AccessRequests, + })) + })) +} diff --git a/lib/teleterm/daemon/config.go b/lib/teleterm/daemon/config.go index 7efa6e5ac33eb..e18fa0236730b 100644 --- a/lib/teleterm/daemon/config.go +++ b/lib/teleterm/daemon/config.go @@ -38,7 +38,6 @@ type Config struct { PrehogAddr string GatewayCreator GatewayCreator - TCPPortAllocator gateway.TCPPortAllocator DBCLICommandProvider gateway.CLICommandProvider KubeCLICommandProvider gateway.CLICommandProvider // CreateTshdEventsClientCredsFunc lazily creates creds for the tshd events server ran by the @@ -60,10 +59,6 @@ func (c *Config) CheckAndSetDefaults() error { c.GatewayCreator = clusters.NewGatewayCreator(c.Storage) } - if c.TCPPortAllocator == nil { - c.TCPPortAllocator = gateway.NetTCPPortAllocator{} - } - if c.Log == nil { c.Log = logrus.NewEntry(logrus.StandardLogger()).WithField(trace.Component, "daemon") } diff --git a/lib/teleterm/daemon/daemon.go b/lib/teleterm/daemon/daemon.go index 16e5113c79466..1a528abcae5b6 100644 --- a/lib/teleterm/daemon/daemon.go +++ b/lib/teleterm/daemon/daemon.go @@ -253,6 +253,10 @@ func (s *Service) createGateway(ctx context.Context, params CreateGatewayParams) return nil, trace.Wrap(err) } + if gateway, ok := s.shouldReuseGateway(targetURI); ok { + return gateway, nil + } + cliCommandProvider, err := s.getGatewayCLICommandProvider(targetURI) if err != nil { return nil, trace.Wrap(err) @@ -264,7 +268,6 @@ func (s *Service) createGateway(ctx context.Context, params CreateGatewayParams) TargetSubresourceName: params.TargetSubresourceName, LocalPort: params.LocalPort, CLICommandProvider: cliCommandProvider, - TCPPortAllocator: s.cfg.TCPPortAllocator, OnExpiredCert: s.reissueGatewayCerts, } @@ -313,12 +316,7 @@ func (s *Service) reissueGatewayCerts(ctx context.Context, g gateway.Gateway) er return trace.Wrap(err) } - // TODO(greedy52) move cluster.ReissueDBCerts to cluster.ReissueGatewayCerts - db, err := gateway.AsDatabase(g) - if err != nil { - return trace.Wrap(err) - } - if err := cluster.ReissueDBCerts(ctx, db.RouteToDatabase()); err != nil { + if err := cluster.ReissueGatewayCerts(ctx, g); err != nil { return trace.Wrap(err) } @@ -732,6 +730,24 @@ func (s *Service) CreateConnectMyComputerRole(ctx context.Context, req *api.Crea return response, trace.Wrap(err) } +func (s *Service) shouldReuseGateway(targetURI uri.ResourceURI) (gateway.Gateway, bool) { + // A single gateway can be shared for all terminals of the same kube + // cluster. + if targetURI.IsKube() { + return s.findGatewayByTargetURI(targetURI) + } + return nil, false +} + +func (s *Service) findGatewayByTargetURI(targetURI uri.ResourceURI) (gateway.Gateway, bool) { + for _, gateway := range s.gateways { + if gateway.TargetURI() == targetURI { + return gateway, true + } + } + return nil, false +} + // Service is the daemon service type Service struct { cfg *Config diff --git a/lib/teleterm/daemon/daemon_test.go b/lib/teleterm/daemon/daemon_test.go index 6cef04d96e8e3..7b57e643c63a6 100644 --- a/lib/teleterm/daemon/daemon_test.go +++ b/lib/teleterm/daemon/daemon_test.go @@ -38,8 +38,9 @@ import ( ) type mockGatewayCreator struct { - t *testing.T - callCount int + t *testing.T + callCount int + tcpPortAllocator gateway.TCPPortAllocator } func (m *mockGatewayCreator) CreateGateway(ctx context.Context, params clusters.CreateGatewayParams) (gateway.Gateway, error) { @@ -51,20 +52,21 @@ func (m *mockGatewayCreator) CreateGateway(ctx context.Context, params clusters. }) keyPairPaths := gatewaytest.MustGenAndSaveCert(m.t, tlsca.Identity{ - Username: params.TargetUser, + Username: "user", Groups: []string{"test-group"}, RouteToDatabase: tlsca.RouteToDatabase{ ServiceName: params.TargetURI.GetDbName(), Protocol: defaults.ProtocolPostgres, Username: params.TargetUser, }, + KubernetesCluster: params.TargetURI.GetKubeName(), }) gateway, err := gateway.New(gateway.Config{ LocalPort: params.LocalPort, TargetURI: params.TargetURI, TargetUser: params.TargetUser, - TargetName: params.TargetURI.GetDbName(), + TargetName: params.TargetURI.GetDbName() + params.TargetURI.GetKubeName(), TargetSubresourceName: params.TargetSubresourceName, Protocol: defaults.ProtocolPostgres, CertPath: keyPairPaths.CertPath, @@ -72,7 +74,8 @@ func (m *mockGatewayCreator) CreateGateway(ctx context.Context, params clusters. Insecure: true, WebProxyAddr: hs.Listener.Addr().String(), CLICommandProvider: params.CLICommandProvider, - TCPPortAllocator: params.TCPPortAllocator, + TCPPortAllocator: m.tcpPortAllocator, + ProfileDir: m.t.TempDir(), }) if err != nil { return nil, trace.Wrap(err) @@ -95,16 +98,18 @@ type gatewayCRUDTestContext struct { func TestGatewayCRUD(t *testing.T) { t.Parallel() tests := []struct { - name string - gatewayNamesToCreate []string + name string + gatewayNamesToCreate []string + appendGatewayTargetURI func(name string) uri.ResourceURI // tcpPortAllocator is an optional field which lets us provide a custom // gatewaytest.MockTCPPortAllocator with some ports already in use. tcpPortAllocator *gatewaytest.MockTCPPortAllocator testFunc func(*testing.T, *gatewayCRUDTestContext, *Service) }{ { - name: "create then find", - gatewayNamesToCreate: []string{"gateway"}, + name: "create then find", + gatewayNamesToCreate: []string{"gateway"}, + appendGatewayTargetURI: uri.NewClusterURI("foo").AppendDB, testFunc: func(t *testing.T, c *gatewayCRUDTestContext, daemon *Service) { createdGateway := c.nameToGateway["gateway"] foundGateway, err := daemon.findGateway(createdGateway.URI().String()) @@ -113,8 +118,9 @@ func TestGatewayCRUD(t *testing.T) { }, }, { - name: "ListGateways", - gatewayNamesToCreate: []string{"gateway1", "gateway2"}, + name: "ListGateways", + gatewayNamesToCreate: []string{"gateway1", "gateway2"}, + appendGatewayTargetURI: uri.NewClusterURI("foo").AppendDB, testFunc: func(t *testing.T, c *gatewayCRUDTestContext, daemon *Service) { gateways := daemon.ListGateways() gatewayURIs := map[uri.ResourceURI]struct{}{} @@ -129,8 +135,9 @@ func TestGatewayCRUD(t *testing.T) { }, }, { - name: "RemoveGateway", - gatewayNamesToCreate: []string{"gatewayToRemove", "gatewayToKeep"}, + name: "RemoveGateway", + gatewayNamesToCreate: []string{"gatewayToRemove", "gatewayToKeep"}, + appendGatewayTargetURI: uri.NewClusterURI("foo").AppendDB, testFunc: func(t *testing.T, c *gatewayCRUDTestContext, daemon *Service) { gatewayToRemove := c.nameToGateway["gatewayToRemove"] gatewayToKeep := c.nameToGateway["gatewayToKeep"] @@ -145,8 +152,9 @@ func TestGatewayCRUD(t *testing.T) { }, }, { - name: "SetGatewayLocalPort closes previous gateway if new port is free", - gatewayNamesToCreate: []string{"gateway"}, + name: "SetGatewayLocalPort closes previous gateway if new port is free", + gatewayNamesToCreate: []string{"gateway"}, + appendGatewayTargetURI: uri.NewClusterURI("foo").AppendDB, testFunc: func(t *testing.T, c *gatewayCRUDTestContext, daemon *Service) { oldGateway := c.nameToGateway["gateway"] oldListener := c.mockTCPPortAllocator.RecentListener() @@ -171,9 +179,10 @@ func TestGatewayCRUD(t *testing.T) { }, }, { - name: "SetGatewayLocalPort doesn't close or modify previous gateway if new port is occupied", - gatewayNamesToCreate: []string{"gateway"}, - tcpPortAllocator: &gatewaytest.MockTCPPortAllocator{PortsInUse: []string{"12345"}}, + name: "SetGatewayLocalPort doesn't close or modify previous gateway if new port is occupied", + gatewayNamesToCreate: []string{"gateway"}, + appendGatewayTargetURI: uri.NewClusterURI("foo").AppendDB, + tcpPortAllocator: &gatewaytest.MockTCPPortAllocator{PortsInUse: []string{"12345"}}, testFunc: func(t *testing.T, c *gatewayCRUDTestContext, daemon *Service) { gateway := c.nameToGateway["gateway"] gatewayAddress := net.JoinHostPort(gateway.LocalAddress(), gateway.LocalPort()) @@ -190,8 +199,9 @@ func TestGatewayCRUD(t *testing.T) { }, }, { - name: "SetGatewayLocalPort is a noop if new port is equal to old port", - gatewayNamesToCreate: []string{"gateway"}, + name: "SetGatewayLocalPort is a noop if new port is equal to old port", + gatewayNamesToCreate: []string{"gateway"}, + appendGatewayTargetURI: uri.NewClusterURI("foo").AppendDB, testFunc: func(t *testing.T, c *gatewayCRUDTestContext, daemon *Service) { gateway := c.nameToGateway["gateway"] localPort := gateway.LocalPort() @@ -203,6 +213,19 @@ func TestGatewayCRUD(t *testing.T) { require.Equal(t, 1, c.mockTCPPortAllocator.CallCount) }, }, + { + name: "CreateGateway returns existing kube gateway if targetURI is the same", + gatewayNamesToCreate: []string{"kube-gateway"}, + appendGatewayTargetURI: uri.NewClusterURI("foo").AppendKube, + testFunc: func(t *testing.T, c *gatewayCRUDTestContext, daemon *Service) { + wantGateway := c.nameToGateway["kube-gateway"] + actualGateway, err := daemon.CreateGateway(context.Background(), CreateGatewayParams{ + TargetURI: wantGateway.TargetURI().String(), + }) + require.NoError(t, err) + require.Equal(t, wantGateway, actualGateway) + }, + }, } for _, tt := range tests { @@ -215,7 +238,10 @@ func TestGatewayCRUD(t *testing.T) { } homeDir := t.TempDir() - mockGatewayCreator := &mockGatewayCreator{t: t} + mockGatewayCreator := &mockGatewayCreator{ + t: t, + tcpPortAllocator: tt.tcpPortAllocator, + } storage, err := clusters.NewStorage(clusters.Config{ Dir: homeDir, @@ -224,9 +250,8 @@ func TestGatewayCRUD(t *testing.T) { require.NoError(t, err) daemon, err := New(Config{ - Storage: storage, - GatewayCreator: mockGatewayCreator, - TCPPortAllocator: tt.tcpPortAllocator, + Storage: storage, + GatewayCreator: mockGatewayCreator, }) require.NoError(t, err) @@ -235,7 +260,7 @@ func TestGatewayCRUD(t *testing.T) { for _, gatewayName := range tt.gatewayNamesToCreate { gatewayName := gatewayName gateway, err := daemon.CreateGateway(context.Background(), CreateGatewayParams{ - TargetURI: uri.NewClusterURI("foo").AppendDB(gatewayName).String(), + TargetURI: tt.appendGatewayTargetURI(gatewayName).String(), TargetUser: "alice", TargetSubresourceName: "", LocalPort: "",