diff --git a/lib/asciitable/table.go b/lib/asciitable/table.go index f5ec9b42e3acd..f7c91009e9e70 100644 --- a/lib/asciitable/table.go +++ b/lib/asciitable/table.go @@ -25,6 +25,7 @@ import ( "strings" "text/tabwriter" + "golang.org/x/exp/slices" "golang.org/x/term" ) @@ -208,6 +209,30 @@ func (t *Table) IsHeadless() bool { return true } +// SortRowsBy sorts the table rows with the given column indices as the sorting +// key, optionally performing a stable sort. Column indices out of range are +// ignored - it is the caller's responsibility to ensure the indices are in +// range. +func (t *Table) SortRowsBy(colIdxKey []int, stable bool) { + lessFn := func(a, b []string) bool { + for _, col := range colIdxKey { + limit := min(len(a), len(b)) + if col >= limit { + continue + } + if a[col] != b[col] { + return a[col] < b[col] + } + } + return false + } + if stable { + slices.SortStableFunc(t.rows, lessFn) + } else { + slices.SortFunc(t.rows, lessFn) + } +} + func min(a, b int) int { if a < b { return a diff --git a/lib/kube/kubeconfig/context_overrride.go b/lib/kube/kubeconfig/context_overwrite.go similarity index 100% rename from lib/kube/kubeconfig/context_overrride.go rename to lib/kube/kubeconfig/context_overwrite.go diff --git a/lib/kube/kubeconfig/context_overrride_test.go b/lib/kube/kubeconfig/context_overwrite_test.go similarity index 100% rename from lib/kube/kubeconfig/context_overrride_test.go rename to lib/kube/kubeconfig/context_overwrite_test.go diff --git a/lib/kube/kubeconfig/kubeconfig.go b/lib/kube/kubeconfig/kubeconfig.go index 9b79935f91dae..c6c81b314d201 100644 --- a/lib/kube/kubeconfig/kubeconfig.go +++ b/lib/kube/kubeconfig/kubeconfig.go @@ -25,6 +25,7 @@ import ( "github.com/gravitational/trace" "github.com/sirupsen/logrus" "golang.org/x/exp/maps" + "k8s.io/apimachinery/pkg/runtime" "k8s.io/client-go/tools/clientcmd" clientcmdapi "k8s.io/client-go/tools/clientcmd/api" @@ -37,6 +38,12 @@ var log = logrus.WithFields(logrus.Fields{ 
trace.Component: teleport.ComponentKubeClient, }) +const ( + // teleportKubeClusterNameExtension is the name of the extension that + // contains the Teleport Kube cluster name. + teleportKubeClusterNameExtension = "teleport.kube.name" +) + // Values are Teleport user data needed to generate kubeconfig entries. type Values struct { // TeleportClusterName is used to name kubeconfig sections ("context", "cluster" and @@ -173,7 +180,7 @@ func Update(path string, v Values, storeAllCAs bool) error { } config.AuthInfos[authName] = authInfo - setContext(config.Contexts, contextName, clusterName, authName, v.Namespace) + setContext(config.Contexts, contextName, clusterName, authName, c, v.Namespace) } if v.SelectCluster != "" { contextName := ContextName(v.TeleportClusterName, v.SelectCluster) @@ -199,9 +206,9 @@ func Update(path string, v Values, storeAllCAs bool) error { clusterName := v.TeleportClusterName contextName := clusterName - + var kubeClusterName string if len(v.KubeClusters) == 1 { - kubeClusterName := v.KubeClusters[0] + kubeClusterName = v.KubeClusters[0] contextName = ContextName(clusterName, kubeClusterName) } @@ -222,7 +229,7 @@ func Update(path string, v Values, storeAllCAs bool) error { ClientCertificateData: v.Credentials.TLSCert, ClientKeyData: rsaKeyPEM, } - setContext(config.Contexts, contextName, clusterName, contextName, v.Namespace) + setContext(config.Contexts, contextName, clusterName, contextName, kubeClusterName, v.Namespace) setSelectedExtension(config.Contexts, config.CurrentContext, clusterName) config.CurrentContext = contextName } else if !trace.IsBadParameter(err) { @@ -234,7 +241,7 @@ func Update(path string, v Values, storeAllCAs bool) error { return Save(path, *config) } -func setContext(contexts map[string]*clientcmdapi.Context, name, cluster, auth string, namespace string) { +func setContext(contexts map[string]*clientcmdapi.Context, name, cluster, auth, kubeName, namespace string) { lastContext := contexts[name] newContext := 
&clientcmdapi.Context{ Cluster: cluster, @@ -245,6 +252,16 @@ func setContext(contexts map[string]*clientcmdapi.Context, name, cluster, auth s newContext.Extensions = lastContext.Extensions } + if newContext.Extensions == nil { + newContext.Extensions = make(map[string]runtime.Object) + } + if kubeName != "" { + newContext.Extensions[teleportKubeClusterNameExtension] = &runtime.Unknown{ + // We need to wrap the kubeName in quotes to make sure it is parsed as a string. + Raw: []byte(fmt.Sprintf("%q", kubeName)), + } + } + // If a user specifies the default namespace we should override it. // Otherwise we should carry the namespace previously defined for the context. if len(namespace) > 0 { @@ -395,13 +412,29 @@ func ContextName(teleportCluster, kubeCluster string) string { // KubeClusterFromContext extracts the kubernetes cluster name from context // name generated by this package. -func KubeClusterFromContext(contextName, teleportCluster string) string { - // If context name doesn't start with teleport cluster name, it was not +func KubeClusterFromContext(contextName string, ctx *clientcmdapi.Context, teleportCluster string) string { + switch { + // If the context name starts with teleport cluster name, it was // generated by tsh. - if !strings.HasPrefix(contextName, teleportCluster+"-") { + case strings.HasPrefix(contextName, teleportCluster+"-"): + return strings.TrimPrefix(contextName, teleportCluster+"-") + // If the context cluster matches teleport cluster, it was generated by + // tsh using --set-context-override flag. + case ctx != nil && ctx.Cluster == teleportCluster: + if v, ok := ctx.Extensions[teleportKubeClusterNameExtension]; ok { + if raw, ok := v.(*runtime.Unknown); ok && trimQuotes(string(raw.Raw)) != "" { + // The value is a JSON string, so we need to trim the quotes. 
+ return trimQuotes(string(raw.Raw)) + } + } + return contextName + default: return "" } - return strings.TrimPrefix(contextName, teleportCluster+"-") +} + +func trimQuotes(s string) string { + return strings.TrimSuffix(strings.TrimPrefix(s, "\""), "\"") } // SelectContext switches the active kubeconfig context to point to the @@ -475,7 +508,10 @@ func SelectedKubeCluster(path, teleportCluster string) (string, error) { return "", trace.Wrap(err) } - if kubeCluster := KubeClusterFromContext(kubeconfig.CurrentContext, teleportCluster); kubeCluster != "" { + if kubeCluster := KubeClusterFromContext( + kubeconfig.CurrentContext, + kubeconfig.Contexts[kubeconfig.CurrentContext], + teleportCluster); kubeCluster != "" { return kubeCluster, nil } return "", trace.NotFound("default context does not belong to Teleport") diff --git a/lib/kube/kubeconfig/kubeconfig_test.go b/lib/kube/kubeconfig/kubeconfig_test.go index 3bedebb45a7c6..96358ce3b4b51 100644 --- a/lib/kube/kubeconfig/kubeconfig_test.go +++ b/lib/kube/kubeconfig/kubeconfig_test.go @@ -319,8 +319,13 @@ func TestUpdateWithExec(t *testing.T) { Cluster: clusterName, AuthInfo: authInfoName, LocationOfOrigin: kubeconfigPath, - Extensions: map[string]runtime.Object{}, - Namespace: tt.namespace, + Extensions: map[string]runtime.Object{ + teleportKubeClusterNameExtension: &runtime.Unknown{ + Raw: []byte(fmt.Sprintf("%q", kubeCluster)), + ContentType: "application/json", + }, + }, + Namespace: tt.namespace, } config, err := Load(kubeconfigPath) require.NoError(t, err) @@ -386,7 +391,12 @@ func TestUpdateWithExecAndProxy(t *testing.T) { Cluster: clusterName, AuthInfo: contextName, LocationOfOrigin: kubeconfigPath, - Extensions: map[string]runtime.Object{}, + Extensions: map[string]runtime.Object{ + teleportKubeClusterNameExtension: &runtime.Unknown{ + Raw: []byte(fmt.Sprintf("%q", kubeCluster)), + ContentType: "application/json", + }, + }, } config, err := Load(kubeconfigPath) @@ -577,3 +587,75 @@ func genUserKey(hostname 
string) (*client.Key, []byte, error) { }}, }, caCert, nil } + +func TestKubeClusterFromContext(t *testing.T) { + type args struct { + contextName string + ctx *clientcmdapi.Context + teleportCluster string + } + tests := []struct { + name string + args args + want string + }{ + { + name: "context name is cluster name", + args: args{ + contextName: "cluster1", + ctx: &clientcmdapi.Context{Cluster: "cluster1"}, + teleportCluster: "cluster1", + }, + want: "cluster1", + }, + { + name: "context name is {teleport-cluster}-cluster name", + args: args{ + contextName: "telecluster-cluster1", + ctx: &clientcmdapi.Context{Cluster: "cluster1"}, + teleportCluster: "telecluster", + }, + want: "cluster1", + }, + { + name: "context name is {kube-cluster} name", + args: args{ + contextName: "cluster1", + ctx: &clientcmdapi.Context{Cluster: "telecluster"}, + teleportCluster: "telecluster", + }, + want: "cluster1", + }, + { + name: "kube cluster name is set in extension", + args: args{ + contextName: "cluster1", + ctx: &clientcmdapi.Context{ + Cluster: "telecluster", + Extensions: map[string]runtime.Object{ + teleportKubeClusterNameExtension: &runtime.Unknown{ + Raw: []byte("\"another\""), + }, + }, + }, + teleportCluster: "telecluster", + }, + want: "another", + }, + { + name: "context isn't from teleport", + args: args{ + contextName: "cluster1", + ctx: &clientcmdapi.Context{Cluster: "someothercluster"}, + teleportCluster: "telecluster", + }, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := KubeClusterFromContext(tt.args.contextName, tt.args.ctx, tt.args.teleportCluster) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/lib/kube/kubeconfig/localproxy.go b/lib/kube/kubeconfig/localproxy.go index 4b35f0415796d..04ba30ecad79e 100644 --- a/lib/kube/kubeconfig/localproxy.go +++ b/lib/kube/kubeconfig/localproxy.go @@ -131,8 +131,8 @@ func LocalProxyClustersFromDefaultConfig(defaultConfig *clientcmdapi.Config, clu continue } - for 
contextName, context := range defaultConfig.Contexts { - if context.Cluster != teleportClusterName { + for contextName, ctx := range defaultConfig.Contexts { + if ctx.Cluster != teleportClusterName { continue } auth, found := defaultConfig.AuthInfos[contextName] @@ -142,8 +142,8 @@ func LocalProxyClustersFromDefaultConfig(defaultConfig *clientcmdapi.Config, clu clusters = append(clusters, LocalProxyCluster{ TeleportCluster: teleportClusterName, - KubeCluster: KubeClusterFromContext(contextName, teleportClusterName), - Namespace: context.Namespace, + KubeCluster: KubeClusterFromContext(contextName, ctx, teleportClusterName), + Namespace: ctx.Namespace, Impersonate: auth.Impersonate, ImpersonateGroups: auth.ImpersonateGroups, }) @@ -178,7 +178,7 @@ func FindTeleportClusterForLocalProxy(defaultConfig *clientcmdapi.Config, cluste return LocalProxyCluster{ TeleportCluster: context.Cluster, - KubeCluster: KubeClusterFromContext(contextName, context.Cluster), + KubeCluster: KubeClusterFromContext(contextName, context, context.Cluster), Namespace: context.Namespace, Impersonate: auth.Impersonate, ImpersonateGroups: auth.ImpersonateGroups, diff --git a/tool/common/common.go b/tool/common/common.go index a4cc39fe4f252..30ae3bccc56a5 100644 --- a/tool/common/common.go +++ b/tool/common/common.go @@ -171,3 +171,19 @@ func FormatLabels(labels map[string]string, verbose bool) string { namespaced = append(namespaced, teleportNamespaced...) return strings.Join(append(result, namespaced...), ",") } + +// FormatResourceName returns the resource's name or its name as originally +// discovered in the cloud by the Teleport Discovery Service. +// In verbose mode, it always returns the resource name. +// In non-verbose mode, if the resource came from discovery and has the +// discovered name label, it returns the discovered name. +func FormatResourceName(r types.ResourceWithLabels, verbose bool) string { + if !verbose { + // return the (shorter) discovered name in non-verbose mode. 
+ discoveredName, ok := r.GetAllLabels()[types.DiscoveredNameLabel] + if ok && discoveredName != "" { + return discoveredName + } + } + return r.GetName() +} diff --git a/tool/tctl/common/collection.go b/tool/tctl/common/collection.go index 932043abc767a..0389717af9c24 100644 --- a/tool/tctl/common/collection.go +++ b/tool/tctl/common/collection.go @@ -688,7 +688,7 @@ func (c *databaseServerCollection) writeText(w io.Writer, verbose bool) error { labels := common.FormatLabels(server.GetDatabase().GetAllLabels(), verbose) rows = append(rows, []string{ server.GetHostname(), - server.GetDatabase().GetName(), + common.FormatResourceName(server.GetDatabase(), verbose), server.GetDatabase().GetProtocol(), server.GetDatabase().GetURI(), labels, @@ -702,6 +702,8 @@ func (c *databaseServerCollection) writeText(w io.Writer, verbose bool) error { } else { t = asciitable.MakeTableWithTruncatedColumn(headers, rows, "Labels") } + // stable sort by hostname then by name. + t.SortRowsBy([]int{0, 1}, true) _, err := t.AsBuffer().WriteTo(w) return trace.Wrap(err) } @@ -730,7 +732,10 @@ func (c *databaseCollection) writeText(w io.Writer, verbose bool) error { for _, database := range c.databases { labels := common.FormatLabels(database.GetAllLabels(), verbose) rows = append(rows, []string{ - database.GetName(), database.GetProtocol(), database.GetURI(), labels, + common.FormatResourceName(database, verbose), + database.GetProtocol(), + database.GetURI(), + labels, }) } headers := []string{"Name", "Protocol", "URI", "Labels"} @@ -740,6 +745,8 @@ func (c *databaseCollection) writeText(w io.Writer, verbose bool) error { } else { t = asciitable.MakeTableWithTruncatedColumn(headers, rows, "Labels") } + // stable sort by name. 
+ t.SortRowsBy([]int{0}, true) _, err := t.AsBuffer().WriteTo(w) return trace.Wrap(err) } @@ -870,7 +877,7 @@ func (c *kubeServerCollection) writeText(w io.Writer, verbose bool) error { } labels := common.FormatLabels(kube.GetAllLabels(), verbose) rows = append(rows, []string{ - kube.GetName(), + common.FormatResourceName(kube, verbose), labels, server.GetTeleportVersion(), }) @@ -883,6 +890,8 @@ func (c *kubeServerCollection) writeText(w io.Writer, verbose bool) error { } else { t = asciitable.MakeTableWithTruncatedColumn(headers, rows, "Labels") } + // stable sort by cluster name. + t.SortRowsBy([]int{0}, true) _, err := t.AsBuffer().WriteTo(w) return trace.Wrap(err) @@ -916,12 +925,12 @@ func (c *kubeClusterCollection) resources() (r []types.Resource) { // cluster4 owner=cluster4,region=southcentralus,resource-group=cluster4,subscription-id=subID // If verbose is disabled, labels column can be truncated to fit into the console. func (c *kubeClusterCollection) writeText(w io.Writer, verbose bool) error { - sort.Sort(types.KubeClusters(c.clusters)) var rows [][]string for _, cluster := range c.clusters { labels := common.FormatLabels(cluster.GetAllLabels(), verbose) rows = append(rows, []string{ - cluster.GetName(), labels, + common.FormatResourceName(cluster, verbose), + labels, }) } headers := []string{"Name", "Labels"} @@ -931,6 +940,8 @@ func (c *kubeClusterCollection) writeText(w io.Writer, verbose bool) error { } else { t = asciitable.MakeTableWithTruncatedColumn(headers, rows, "Labels") } + // stable sort by name. 
+ t.SortRowsBy([]int{0}, true) _, err := t.AsBuffer().WriteTo(w) return trace.Wrap(err) } diff --git a/tool/tctl/common/collection_test.go b/tool/tctl/common/collection_test.go index 88af8c71f8a95..9c54303b8d76f 100644 --- a/tool/tctl/common/collection_test.go +++ b/tool/tctl/common/collection_test.go @@ -21,19 +21,25 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/asciitable" "github.com/gravitational/teleport/tool/common" ) var ( - staticLabels = map[string]string{ + staticLabelsFixture = map[string]string{ "label1": "val1", "label2": "val2", "label3": "val3", } + longLabelFixture = map[string]string{ + "ultra_long_label_for_teleport_collection_text_table_formatting": "ultra_long_label_for_teleport_collection_text_table_formatting", + } ) func TestDatabaseResourceMatchersToString(t *testing.T) { @@ -51,80 +57,242 @@ func TestDatabaseResourceMatchersToString(t *testing.T) { require.Equal(t, databaseResourceMatchersToString(resMatch), "(Labels: x=[y])") } -func Test_kubeClusterCollection_writeText(t *testing.T) { - extraLabel := map[string]string{ - "ultra_long_label_for_teleport_kubernetes_list_kube_clusters_method": "ultra_long_label_value_for_teleport_kubernetes_list_kube_clusters_method", +type writeTextTest struct { + collection ResourceCollection + wantVerboseTable func() string + wantNonVerboseTable func() string +} + +func (test *writeTextTest) run(t *testing.T) { + t.Helper() + t.Run("verbose mode", func(t *testing.T) { + t.Helper() + w := &bytes.Buffer{} + err := test.collection.writeText(w, true) + require.NoError(t, err) + diff := cmp.Diff(test.wantVerboseTable(), w.String()) + require.Empty(t, diff) + }) + t.Run("non-verbose mode", func(t *testing.T) { + t.Helper() + w := &bytes.Buffer{} + err := test.collection.writeText(w, false) + 
require.NoError(t, err) + diff := cmp.Diff(test.wantNonVerboseTable(), w.String()) + require.Empty(t, diff) + }) +} + +func TestResourceCollection_writeText(t *testing.T) { + t.Run("kube clusters", testKubeClusterCollection_writeText) + t.Run("kube servers", testKubeServerCollection_writeText) + t.Run("databases", testDatabaseCollection_writeText) + t.Run("database servers", testDatabaseServerCollection_writeText) +} + +func testKubeClusterCollection_writeText(t *testing.T) { + eksDiscoveredNameLabel := map[string]string{ + types.DiscoveredNameLabel: "cluster3", } kubeClusters := []types.KubeCluster{ mustCreateNewKubeCluster(t, "cluster1", nil), - mustCreateNewKubeCluster(t, "cluster2", extraLabel), - mustCreateNewKubeCluster(t, "afirstCluster", extraLabel), + mustCreateNewKubeCluster(t, "cluster2", longLabelFixture), + mustCreateNewKubeCluster(t, "afirstCluster", longLabelFixture), + mustCreateNewKubeCluster(t, "cluster3-eks-us-west-1-123456789012", eksDiscoveredNameLabel), } - type fields struct { - verbose bool - } - tests := []struct { - name string - fields fields - wantTable func() string - }{ - { - name: "non-verbose mode", - fields: fields{verbose: false}, - wantTable: func() string { - table := asciitable.MakeTableWithTruncatedColumn( - []string{"Name", "Labels"}, - [][]string{ - {"afirstCluster", formatTestLabels(staticLabels, extraLabel, false)}, - {"cluster1", formatTestLabels(staticLabels, nil, false)}, - {"cluster2", formatTestLabels(staticLabels, extraLabel, false)}, - }, - "Labels") - return table.AsBuffer().String() - }, + test := writeTextTest{ + collection: &kubeClusterCollection{clusters: kubeClusters}, + wantNonVerboseTable: func() string { + table := asciitable.MakeTableWithTruncatedColumn( + []string{"Name", "Labels"}, + [][]string{ + {"afirstCluster", formatTestLabels(staticLabelsFixture, longLabelFixture, false)}, + {"cluster1", formatTestLabels(staticLabelsFixture, nil, false)}, + {"cluster2", formatTestLabels(staticLabelsFixture, 
longLabelFixture, false)}, + {"cluster3", formatTestLabels(staticLabelsFixture, eksDiscoveredNameLabel, false)}, + }, + "Labels") + return table.AsBuffer().String() }, - { - name: "verbose mode", - fields: fields{verbose: true}, - wantTable: func() string { - table := asciitable.MakeTable( - []string{"Name", "Labels"}, - []string{"afirstCluster", formatTestLabels(staticLabels, extraLabel, true)}, - []string{"cluster1", formatTestLabels(staticLabels, nil, true)}, - []string{"cluster2", formatTestLabels(staticLabels, extraLabel, true)}, - ) - return table.AsBuffer().String() - }, + wantVerboseTable: func() string { + table := asciitable.MakeTable( + []string{"Name", "Labels"}, + []string{"afirstCluster", formatTestLabels(staticLabelsFixture, longLabelFixture, true)}, + []string{"cluster1", formatTestLabels(staticLabelsFixture, nil, true)}, + []string{"cluster2", formatTestLabels(staticLabelsFixture, longLabelFixture, true)}, + []string{"cluster3-eks-us-west-1-123456789012", formatTestLabels(staticLabelsFixture, eksDiscoveredNameLabel, true)}, + ) + return table.AsBuffer().String() }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c := &kubeClusterCollection{ - clusters: kubeClusters, - } - w := &bytes.Buffer{} - err := c.writeText(w, tt.fields.verbose) - require.NoError(t, err) - require.Contains(t, w.String(), tt.wantTable()) - }) - } + test.run(t) } -func mustCreateNewKubeCluster(t *testing.T, name string, extraStaticLabels map[string]string) types.KubeCluster { - labels := make(map[string]string) +func testKubeServerCollection_writeText(t *testing.T) { + eksDiscoveredNameLabel := map[string]string{ + types.DiscoveredNameLabel: "cluster3", + } + kubeServers := []types.KubeServer{ + mustCreateNewKubeServer(t, "cluster1", "_", nil), + mustCreateNewKubeServer(t, "cluster2", "_", longLabelFixture), + mustCreateNewKubeServer(t, "afirstCluster", "_", longLabelFixture), + mustCreateNewKubeServer(t, "cluster3-eks-us-west-1-123456789012", "_", 
eksDiscoveredNameLabel), + } + test := writeTextTest{ + collection: &kubeServerCollection{servers: kubeServers}, + wantNonVerboseTable: func() string { + table := asciitable.MakeTableWithTruncatedColumn( + []string{"Cluster", "Labels", "Version"}, + [][]string{ + {"afirstCluster", formatTestLabels(staticLabelsFixture, longLabelFixture, false), api.Version}, + {"cluster1", formatTestLabels(staticLabelsFixture, nil, false), api.Version}, + {"cluster2", formatTestLabels(staticLabelsFixture, longLabelFixture, false), api.Version}, + {"cluster3", formatTestLabels(staticLabelsFixture, eksDiscoveredNameLabel, false), api.Version}, + }, + "Labels") + return table.AsBuffer().String() + }, + wantVerboseTable: func() string { + table := asciitable.MakeTable( + []string{"Cluster", "Labels", "Version"}, + []string{"afirstCluster", formatTestLabels(staticLabelsFixture, longLabelFixture, true), api.Version}, + []string{"cluster1", formatTestLabels(staticLabelsFixture, nil, true), api.Version}, + []string{"cluster2", formatTestLabels(staticLabelsFixture, longLabelFixture, true), api.Version}, + []string{"cluster3-eks-us-west-1-123456789012", formatTestLabels(staticLabelsFixture, eksDiscoveredNameLabel, true), api.Version}, + ) + return table.AsBuffer().String() + }, + } + test.run(t) +} - for k, v := range staticLabels { - labels[k] = v +func testDatabaseCollection_writeText(t *testing.T) { + rdsDiscoveredNameLabel := map[string]string{ + types.DiscoveredNameLabel: "database", + } + rdsURI := "database.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432" + databases := []types.Database{ + mustCreateNewDatabase(t, "database-A", "mysql", "localhost:3306", nil), + mustCreateNewDatabase(t, "database-B", "postgres", "localhost:5432", longLabelFixture), + mustCreateNewDatabase(t, "afirstDatabase", "redis", "localhost:6379", longLabelFixture), + mustCreateNewDatabase(t, "database-rds-us-west-1-123456789012", "postgres", + rdsURI, + rdsDiscoveredNameLabel), + } + test := writeTextTest{ + 
collection: &databaseCollection{databases: databases}, + wantNonVerboseTable: func() string { + table := asciitable.MakeTableWithTruncatedColumn( + []string{"Name", "Protocol", "URI", "Labels"}, + [][]string{ + {"afirstDatabase", "redis", "localhost:6379", formatTestLabels(staticLabelsFixture, longLabelFixture, false)}, + {"database", "postgres", rdsURI, formatTestLabels(staticLabelsFixture, rdsDiscoveredNameLabel, false)}, + {"database-A", "mysql", "localhost:3306", formatTestLabels(staticLabelsFixture, nil, false)}, + {"database-B", "postgres", "localhost:5432", formatTestLabels(staticLabelsFixture, longLabelFixture, false)}, + }, + "Labels") + return table.AsBuffer().String() + }, + wantVerboseTable: func() string { + table := asciitable.MakeTable( + []string{"Name", "Protocol", "URI", "Labels"}, + []string{"afirstDatabase", "redis", "localhost:6379", formatTestLabels(staticLabelsFixture, longLabelFixture, true)}, + []string{"database-A", "mysql", "localhost:3306", formatTestLabels(staticLabelsFixture, nil, true)}, + []string{"database-B", "postgres", "localhost:5432", formatTestLabels(staticLabelsFixture, longLabelFixture, true)}, + []string{"database-rds-us-west-1-123456789012", "postgres", rdsURI, formatTestLabels(staticLabelsFixture, rdsDiscoveredNameLabel, true)}, + ) + return table.AsBuffer().String() + }, } + test.run(t) +} - for k, v := range extraStaticLabels { - labels[k] = v +func testDatabaseServerCollection_writeText(t *testing.T) { + rdsDiscoveredNameLabel := map[string]string{ + types.DiscoveredNameLabel: "database", } + rdsURI := "database.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432" + dbServers := []types.DatabaseServer{ + mustCreateNewDatabaseServer(t, "database-A", "mysql", "localhost:3306", nil), + mustCreateNewDatabaseServer(t, "database-B", "postgres", "localhost:5432", longLabelFixture), + mustCreateNewDatabaseServer(t, "afirstDatabase", "redis", "localhost:6379", longLabelFixture), + mustCreateNewDatabaseServer(t, 
"database-rds-us-west-1-123456789012", "postgres", + rdsURI, + rdsDiscoveredNameLabel), + } + test := writeTextTest{ + collection: &databaseServerCollection{servers: dbServers}, + wantNonVerboseTable: func() string { + table := asciitable.MakeTableWithTruncatedColumn( + []string{"Host", "Name", "Protocol", "URI", "Labels", "Version"}, + [][]string{ + {"some-host", "afirstDatabase", "redis", "localhost:6379", formatTestLabels(staticLabelsFixture, longLabelFixture, false), api.Version}, + {"some-host", "database", "postgres", rdsURI, formatTestLabels(staticLabelsFixture, rdsDiscoveredNameLabel, false), api.Version}, + {"some-host", "database-A", "mysql", "localhost:3306", formatTestLabels(staticLabelsFixture, nil, false), api.Version}, + {"some-host", "database-B", "postgres", "localhost:5432", formatTestLabels(staticLabelsFixture, longLabelFixture, false), api.Version}, + }, + "Labels") + return table.AsBuffer().String() + }, + wantVerboseTable: func() string { + table := asciitable.MakeTable( + []string{"Host", "Name", "Protocol", "URI", "Labels", "Version"}, + []string{"some-host", "afirstDatabase", "redis", "localhost:6379", formatTestLabels(staticLabelsFixture, longLabelFixture, true), api.Version}, + []string{"some-host", "database-A", "mysql", "localhost:3306", formatTestLabels(staticLabelsFixture, nil, true), api.Version}, + []string{"some-host", "database-B", "postgres", "localhost:5432", formatTestLabels(staticLabelsFixture, longLabelFixture, true), api.Version}, + []string{"some-host", "database-rds-us-west-1-123456789012", "postgres", rdsURI, formatTestLabels(staticLabelsFixture, rdsDiscoveredNameLabel, true), api.Version}, + ) + return table.AsBuffer().String() + }, + } + test.run(t) +} + +func mustCreateNewDatabase(t *testing.T, name, protocol, uri string, extraStaticLabels map[string]string) *types.DatabaseV3 { + t.Helper() + db, err := types.NewDatabaseV3( + types.Metadata{ + Name: name, + Labels: makeTestLabels(extraStaticLabels), + }, + 
types.DatabaseSpecV3{ + Protocol: protocol, + URI: uri, + DynamicLabels: map[string]types.CommandLabelV2{ + "date": { + Period: types.NewDuration(1 * time.Second), + Command: []string{"date"}, + Result: "Tue 11 Oct 2022 10:21:58 WEST", + }, + }, + }, + ) + require.NoError(t, err) + return db +} +func mustCreateNewDatabaseServer(t *testing.T, name, protocol, uri string, extraStaticLabels map[string]string) types.DatabaseServer { + t.Helper() + dbServer, err := types.NewDatabaseServerV3( + types.Metadata{ + Name: name, + Labels: makeTestLabels(extraStaticLabels), + }, types.DatabaseServerSpecV3{ + HostID: "some-hostid", + Hostname: "some-host", + Database: mustCreateNewDatabase(t, name, protocol, uri, extraStaticLabels), + }) + require.NoError(t, err) + + return dbServer +} + +func mustCreateNewKubeCluster(t *testing.T, name string, extraStaticLabels map[string]string) *types.KubernetesClusterV3 { + t.Helper() cluster, err := types.NewKubernetesClusterV3( types.Metadata{ Name: name, - Labels: labels, + Labels: makeTestLabels(extraStaticLabels), }, types.KubernetesClusterSpecV3{ DynamicLabels: map[string]types.CommandLabelV2{ @@ -140,6 +308,14 @@ func mustCreateNewKubeCluster(t *testing.T, name string, extraStaticLabels map[s return cluster } +func mustCreateNewKubeServer(t *testing.T, name, hostname string, extraStaticLabels map[string]string) *types.KubernetesServerV3 { + t.Helper() + cluster := mustCreateNewKubeCluster(t, name, extraStaticLabels) + kubeServer, err := types.NewKubernetesServerV3FromCluster(cluster, hostname, uuid.New().String()) + require.NoError(t, err) + return kubeServer +} + func formatTestLabels(l1, l2 map[string]string, verbose bool) string { labels := map[string]string{ "date": "Tue 11 Oct 2022 10:21:58 WEST", @@ -153,3 +329,14 @@ func formatTestLabels(l1, l2 map[string]string, verbose bool) string { } return common.FormatLabels(labels, verbose) } + +func makeTestLabels(extraStaticLabels map[string]string) map[string]string { + labels := 
make(map[string]string) + for k, v := range staticLabelsFixture { + labels[k] = v + } + for k, v := range extraStaticLabels { + labels[k] = v + } + return labels +} diff --git a/tool/tctl/common/db_command.go b/tool/tctl/common/db_command.go index e3eae83178113..f0f5799de7d98 100644 --- a/tool/tctl/common/db_command.go +++ b/tool/tctl/common/db_command.go @@ -84,24 +84,17 @@ func (c *DBCommand) ListDatabases(ctx context.Context, clt auth.ClientI) error { return trace.Wrap(err) } - var servers []types.DatabaseServer - resources, err := client.GetResourcesWithFilters(ctx, clt, proto.ListResourcesRequest{ + servers, err := client.GetAllResources[types.DatabaseServer](ctx, clt, &proto.ListResourcesRequest{ ResourceType: types.KindDatabaseServer, Labels: labels, PredicateExpression: c.predicateExpr, SearchKeywords: libclient.ParseSearchKeywords(c.searchKeywords, ','), }) - switch { - case err != nil: + if err != nil { if utils.IsPredicateError(err) { return trace.Wrap(utils.PredicateError{Err: err}) } return trace.Wrap(err) - default: - servers, err = types.ResourcesWithLabels(resources).AsDatabaseServers() - if err != nil { - return trace.Wrap(err) - } } coll := &databaseServerCollection{servers: servers} diff --git a/tool/tctl/common/helpers_test.go b/tool/tctl/common/helpers_test.go index e602b9a1d7e40..2380a94e63fdd 100644 --- a/tool/tctl/common/helpers_test.go +++ b/tool/tctl/common/helpers_test.go @@ -264,6 +264,9 @@ func makeAndRunTestAuthServer(t *testing.T, opts ...testServerOptionFunc) (auth } func waitForDatabases(t *testing.T, auth *service.TeleportProcess, dbs []servicecfg.Database) { + if len(dbs) == 0 { + return + } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() for { diff --git a/tool/tctl/common/kube_command.go b/tool/tctl/common/kube_command.go index 87b96e27365a2..d89face482bf8 100644 --- a/tool/tctl/common/kube_command.go +++ b/tool/tctl/common/kube_command.go @@ -25,8 +25,13 @@ import ( 
"github.com/gravitational/trace" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" + libclient "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/service/servicecfg" + "github.com/gravitational/teleport/lib/utils" ) // KubeCommand implements "tctl kube" group of commands. @@ -36,6 +41,10 @@ type KubeCommand struct { // format is the output format (text or yaml) format string + searchKeywords string + predicateExpr string + labels string + // verbose sets whether full table output should be shown for labels verbose bool @@ -49,8 +58,11 @@ func (c *KubeCommand) Initialize(app *kingpin.Application, config *servicecfg.Co kube := app.Command("kube", "Operate on registered Kubernetes clusters.") c.kubeList = kube.Command("ls", "List all Kubernetes clusters registered with the cluster.") + c.kubeList.Arg("labels", labelHelp).StringVar(&c.labels) c.kubeList.Flag("format", "Output format, 'text', 'json', or 'yaml'").Default(teleport.Text).StringVar(&c.format) c.kubeList.Flag("verbose", "Verbose table output, shows full label output").Short('v').BoolVar(&c.verbose) + c.kubeList.Flag("search", searchHelp).StringVar(&c.searchKeywords) + c.kubeList.Flag("query", queryHelp).StringVar(&c.predicateExpr) } // TryRun attempts to run subcommands like "kube ls". @@ -66,10 +78,22 @@ func (c *KubeCommand) TryRun(ctx context.Context, cmd string, client auth.Client // ListKube prints the list of kube clusters that have recently sent heartbeats // to the cluster. 
-func (c *KubeCommand) ListKube(ctx context.Context, client auth.ClientI) error { +func (c *KubeCommand) ListKube(ctx context.Context, clt auth.ClientI) error { + labels, err := libclient.ParseLabelSpec(c.labels) + if err != nil { + return trace.Wrap(err) + } - kubes, err := client.GetKubernetesServers(ctx) + kubes, err := client.GetAllResources[types.KubeServer](ctx, clt, &proto.ListResourcesRequest{ + ResourceType: types.KindKubeServer, + Labels: labels, + PredicateExpression: c.predicateExpr, + SearchKeywords: libclient.ParseSearchKeywords(c.searchKeywords, ','), + }) if err != nil { + if utils.IsPredicateError(err) { + return trace.Wrap(utils.PredicateError{Err: err}) + } return trace.Wrap(err) } diff --git a/tool/tctl/common/resource_command.go b/tool/tctl/common/resource_command.go index 2f9853499bc35..13aa003c26069 100644 --- a/tool/tctl/common/resource_command.go +++ b/tool/tctl/common/resource_command.go @@ -24,6 +24,7 @@ import ( "math" "os" "sort" + "strings" "time" "github.com/alecthomas/kingpin/v2" @@ -1093,23 +1094,23 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client auth.ClientI) (err } fmt.Printf("lock %q has been deleted\n", name) case types.KindDatabaseServer: - dbServers, err := client.GetDatabaseServers(ctx, apidefaults.Namespace) + servers, err := client.GetDatabaseServers(ctx, apidefaults.Namespace) if err != nil { return trace.Wrap(err) } - deleted := false - for _, server := range dbServers { - if server.GetName() == rc.ref.Name { - if err := client.DeleteDatabaseServer(ctx, apidefaults.Namespace, server.GetHostID(), server.GetName()); err != nil { - return trace.Wrap(err) - } - deleted = true - } + resDesc := "database server" + servers = filterByNameOrPrefix(servers, rc.ref.Name) + name, err := getOneResourceNameToDelete(servers, rc.ref, resDesc) + if err != nil { + return trace.Wrap(err) } - if !deleted { - return trace.NotFound("database server %q not found", rc.ref.Name) + for _, s := range servers { + err := 
client.DeleteDatabaseServer(ctx, apidefaults.Namespace, s.GetHostID(), name) + if err != nil { + return trace.Wrap(err) + } } - fmt.Printf("database server %q has been deleted\n", rc.ref.Name) + fmt.Printf("%s %q has been deleted\n", resDesc, name) case types.KindNetworkRestrictions: if err = resetNetworkRestrictions(ctx, client); err != nil { return trace.Wrap(err) @@ -1121,15 +1122,35 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client auth.ClientI) (err } fmt.Printf("application %q has been deleted\n", rc.ref.Name) case types.KindDatabase: - if err = client.DeleteDatabase(ctx, rc.ref.Name); err != nil { + databases, err := client.GetDatabases(ctx) + if err != nil { + return trace.Wrap(err) + } + resDesc := "database" + databases = filterByNameOrPrefix(databases, rc.ref.Name) + name, err := getOneResourceNameToDelete(databases, rc.ref, resDesc) + if err != nil { + return trace.Wrap(err) + } + if err := client.DeleteDatabase(ctx, name); err != nil { return trace.Wrap(err) } - fmt.Printf("database %q has been deleted\n", rc.ref.Name) + fmt.Printf("%s %q has been deleted\n", resDesc, name) case types.KindKubernetesCluster: - if err = client.DeleteKubernetesCluster(ctx, rc.ref.Name); err != nil { + clusters, err := client.GetKubernetesClusters(ctx) + if err != nil { return trace.Wrap(err) } - fmt.Printf("kubernetes cluster %q has been deleted\n", rc.ref.Name) + resDesc := "kubernetes cluster" + clusters = filterByNameOrPrefix(clusters, rc.ref.Name) + name, err := getOneResourceNameToDelete(clusters, rc.ref, resDesc) + if err != nil { + return trace.Wrap(err) + } + if err := client.DeleteKubernetesCluster(ctx, name); err != nil { + return trace.Wrap(err) + } + fmt.Printf("%s %q has been deleted\n", resDesc, name) case types.KindWindowsDesktopService: if err = client.DeleteWindowsDesktopService(ctx, rc.ref.Name); err != nil { return trace.Wrap(err) @@ -1182,23 +1203,23 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client auth.ClientI) (err } 
fmt.Printf("%s '%s/%s' has been deleted\n", types.KindCertAuthority, rc.ref.SubKind, rc.ref.Name) case types.KindKubeServer: - kubeServers, err := client.GetKubernetesServers(ctx) + servers, err := client.GetKubernetesServers(ctx) if err != nil { return trace.Wrap(err) } - deleted := false - for _, server := range kubeServers { - if server.GetName() == rc.ref.Name { - if err := client.DeleteKubernetesServer(ctx, server.GetHostID(), server.GetName()); err != nil { - return trace.Wrap(err) - } - deleted = true - } + resDesc := "kubernetes server" + servers = filterByNameOrPrefix(servers, rc.ref.Name) + name, err := getOneResourceNameToDelete(servers, rc.ref, resDesc) + if err != nil { + return trace.Wrap(err) } - if !deleted { - return trace.NotFound("kubernetes server %q not found", rc.ref.Name) + for _, s := range servers { + err := client.DeleteKubernetesServer(ctx, s.GetHostID(), name) + if err != nil { + return trace.Wrap(err) + } } - fmt.Printf("kubernetes server %q has been deleted\n", rc.ref.Name) + fmt.Printf("%s %q has been deleted\n", resDesc, name) case types.KindUIConfig: err := client.DeleteUIConfig(ctx) if err != nil { @@ -1658,16 +1679,11 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client auth.Client return &databaseServerCollection{servers: servers}, nil } - var out []types.DatabaseServer - for _, server := range servers { - if server.GetName() == rc.ref.Name { - out = append(out, server) - } - } - if len(out) == 0 { + servers = filterByNameOrPrefix(servers, rc.ref.Name) + if len(servers) == 0 { return nil, trace.NotFound("database server %q not found", rc.ref.Name) } - return &databaseServerCollection{servers: out}, nil + return &databaseServerCollection{servers: servers}, nil case types.KindKubeServer: servers, err := client.GetKubernetesServers(ctx) if err != nil { @@ -1676,17 +1692,14 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client auth.Client if rc.ref.Name == "" { return 
&kubeServerCollection{servers: servers}, nil } - - var out []types.KubeServer - for _, server := range servers { - if server.GetName() == rc.ref.Name || server.GetHostname() == rc.ref.Name { - out = append(out, server) - } + altNameFn := func(r types.KubeServer) string { + return r.GetHostname() } - if len(out) == 0 { + servers = filterByNameOrPrefix(servers, rc.ref.Name, altNameFn) + if len(servers) == 0 { return nil, trace.NotFound("kubernetes server %q not found", rc.ref.Name) } - return &kubeServerCollection{servers: out}, nil + return &kubeServerCollection{servers: servers}, nil case types.KindAppServer: servers, err := client.GetApplicationServers(ctx, rc.namespace) @@ -1727,31 +1740,31 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client auth.Client } return &appCollection{apps: []types.Application{app}}, nil case types.KindDatabase: + databases, err := client.GetDatabases(ctx) + if err != nil { + return nil, trace.Wrap(err) + } if rc.ref.Name == "" { - databases, err := client.GetDatabases(ctx) - if err != nil { - return nil, trace.Wrap(err) - } return &databaseCollection{databases: databases}, nil } - database, err := client.GetDatabase(ctx, rc.ref.Name) + databases = filterByNameOrPrefix(databases, rc.ref.Name) + if len(databases) == 0 { + return nil, trace.NotFound("database %q not found", rc.ref.Name) + } + return &databaseCollection{databases: databases}, nil + case types.KindKubernetesCluster: + clusters, err := client.GetKubernetesClusters(ctx) if err != nil { return nil, trace.Wrap(err) } - return &databaseCollection{databases: []types.Database{database}}, nil - case types.KindKubernetesCluster: if rc.ref.Name == "" { - clusters, err := client.GetKubernetesClusters(ctx) - if err != nil { - return nil, trace.Wrap(err) - } return &kubeClusterCollection{clusters: clusters}, nil } - cluster, err := client.GetKubernetesCluster(ctx, rc.ref.Name) - if err != nil { - return nil, trace.Wrap(err) + clusters = filterByNameOrPrefix(clusters, 
rc.ref.Name) + if len(clusters) == 0 { + return nil, trace.NotFound("kubernetes cluster %q not found", rc.ref.Name) } - return &kubeClusterCollection{clusters: []types.KubeCluster{cluster}}, nil + return &kubeClusterCollection{clusters: clusters}, nil case types.KindWindowsDesktopService: services, err := client.GetWindowsDesktopServices(ctx) if err != nil { @@ -2130,3 +2143,111 @@ func findDeviceByIDOrTag(ctx context.Context, remote devicepb.DeviceTrustService return nil, trace.BadParameter("found multiple devices for asset tag %q, please retry using the device ID instead", idOrTag) } + +// keepFn is a predicate function that returns true if a resource should be +// retained by filterResources. +type keepFn[T types.ResourceWithLabels] func(T) bool + +// filterResources takes a list of resources and returns a filtered list of +// resources for which the `keep` predicate function returns true. +func filterResources[T types.ResourceWithLabels](resources []T, keep keepFn[T]) []T { + out := make([]T, 0, len(resources)) + for _, r := range resources { + if keep(r) { + out = append(out, r) + } + } + return out +} + +// altNameFn is a func that returns an alternative name for a resource. +type altNameFn[T types.ResourceWithLabels] func(T) string + +// filterByNameOrPrefix filters resources by name or a prefix of the name. +// It prefers exact name filtering first - if none of the resource names match +// exactly (i.e. all of the resources are filtered out), then it retries and +// filters the resources by prefix of resource name instead. +// This is to avoid an annoying UX, for example: +// resources: [foo, foobar] +// $ tctl rm foo <- should select foo by exact name instead of matching both by +// prefix "foo". +func filterByNameOrPrefix[T types.ResourceWithLabels](resources []T, prefixOrName string, extra ...altNameFn[T]) []T { + // prefer exact names + out := filterByName(resources, prefixOrName, extra...) 
+ if len(out) == 0 { + // fallback to looking for prefixes + out = filterByPrefix(resources, prefixOrName, extra...) + } + return out +} + +// filterByName filters resources by exact name match. +func filterByName[T types.ResourceWithLabels](resources []T, name string, altNameFns ...altNameFn[T]) []T { + return filterResources(resources, func(r T) bool { + if r.GetName() == name { + return true + } + for _, altName := range altNameFns { + if altName(r) == name { + return true + } + } + return false + }) +} + +// filterByPrefix filters resources by a prefix of the resource name. +func filterByPrefix[T types.ResourceWithLabels](resources []T, prefix string, altNameFns ...altNameFn[T]) []T { + return filterResources(resources, func(r T) bool { + if strings.HasPrefix(r.GetName(), prefix) { + return true + } + for _, altName := range altNameFns { + if strings.HasPrefix(altName(r), prefix) { + return true + } + } + return false + }) +} + +// getOneResourceNameToDelete checks a list of resources to ensure there is +// exactly one resource name among them, and returns that name or an error. +// Heartbeat resources can have the same name but different host ID, so this +// still allows a user to delete multiple heartbeats of the same name, for +// example `$ tctl rm db_server/someDB`. +func getOneResourceNameToDelete[T types.ResourceWithLabels](rs []T, ref services.Ref, resDesc string) (string, error) { + seen := make(map[string]struct{}) + for _, r := range rs { + seen[r.GetName()] = struct{}{} + } + switch len(seen) { + case 1: // need exactly one. 
+ return rs[0].GetName(), nil + case 0: + return "", trace.NotFound("%v %q not found", resDesc, ref.Name) + default: + names := make([]string, 0, len(rs)) + for _, r := range rs { + names = append(names, r.GetName()) + } + msg := formatAmbiguousDeleteMessage(ref, resDesc, names) + return "", trace.BadParameter(msg) + } +} + +// formatAmbiguousDeleteMessage returns a formatted message when a user is +// attempting to delete multiple resources by an ambiguous prefix of the +// resource names. +func formatAmbiguousDeleteMessage(ref services.Ref, resDesc string, names []string) string { + slices.Sort(names) + // choose an actual resource for the example in the error. + exampleRef := ref + exampleRef.Name = names[0] + return fmt.Sprintf(`%s matches multiple %vs as a name prefix: +%v + +Use either a full resource name or an unambiguous prefix, for example: +$ tctl rm %s`, + ref.String(), resDesc, strings.Join(names, "\n"), exampleRef.String()) +} diff --git a/tool/tctl/common/resource_command_test.go b/tool/tctl/common/resource_command_test.go index af931355c12d8..090d3b5dce7ef 100644 --- a/tool/tctl/common/resource_command_test.go +++ b/tool/tctl/common/resource_command_test.go @@ -39,6 +39,7 @@ import ( "github.com/gravitational/teleport/lib/config" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" + "github.com/gravitational/teleport/lib/services" ) // TestDatabaseServerResource tests tctl db_server rm/get commands. 
@@ -76,6 +77,12 @@ func TestDatabaseServerResource(t *testing.T) { CACertFile: caCertFilePath, }, }, + { + Name: "db3", + Description: "Example MySQL", + Protocol: "mysql", + URI: "localhost:33308", + }, }, }, Proxy: config.Proxy{ @@ -93,7 +100,26 @@ func TestDatabaseServerResource(t *testing.T) { }, } - wantDB, err := types.NewDatabaseV3(types.Metadata{ + db1, err := types.NewDatabaseV3(types.Metadata{ + Name: "example", + Description: "Example MySQL", + Labels: map[string]string{types.OriginLabel: types.OriginConfigFile}, + }, types.DatabaseSpecV3{ + Protocol: defaults.ProtocolMySQL, + URI: "localhost:33306", + CACert: fixtures.TLSCACertPEM, + AdminUser: &types.DatabaseAdminUser{ + Name: "", + }, + TLS: types.DatabaseTLS{ + Mode: types.DatabaseTLSMode_VERIFY_FULL, + ServerName: "db.example.com", + CACert: fixtures.TLSCACertPEM, + }, + }) + require.NoError(t, err) + + db2, err := types.NewDatabaseV3(types.Metadata{ Name: "example2", Description: "Example PostgreSQL", Labels: map[string]string{types.OriginLabel: types.OriginConfigFile}, @@ -112,6 +138,25 @@ func TestDatabaseServerResource(t *testing.T) { }) require.NoError(t, err) + db3, err := types.NewDatabaseV3(types.Metadata{ + Name: "db3", + Description: "Example MySQL", + Labels: map[string]string{types.OriginLabel: types.OriginConfigFile}, + }, types.DatabaseSpecV3{ + Protocol: defaults.ProtocolMySQL, + URI: "localhost:33308", + CACert: fixtures.TLSCACertPEM, + AdminUser: &types.DatabaseAdminUser{ + Name: "", + }, + TLS: types.DatabaseTLS{ + Mode: types.DatabaseTLSMode_VERIFY_FULL, + ServerName: "db.example.com", + CACert: fixtures.TLSCACertPEM, + }, + }) + require.NoError(t, err) + _ = makeAndRunTestAuthServer(t, withFileConfig(fileConfig), withFileDescriptors(dynAddr.descriptors)) var out []*types.DatabaseServerV3 @@ -120,128 +165,53 @@ func TestDatabaseServerResource(t *testing.T) { buff, err := runResourceCommand(t, fileConfig, []string{"get", types.KindDatabaseServer, "--format=json"}) 
require.NoError(t, err) mustDecodeJSON(t, buff, &out) - require.Len(t, out, 2) - - wantServer := fmt.Sprintf("%v/%v", types.KindDatabaseServer, wantDB.GetName()) + require.Len(t, out, 3) // get specific database server + wantServer := fmt.Sprintf("%v/%v", types.KindDatabaseServer, db2.GetName()) buff, err = runResourceCommand(t, fileConfig, []string{"get", wantServer, "--format=json"}) require.NoError(t, err) mustDecodeJSON(t, buff, &out) require.Len(t, out, 1) gotDB := out[0].GetDatabase() - require.Empty(t, cmp.Diff([]types.Database{wantDB}, []types.Database{gotDB}, + require.Empty(t, cmp.Diff([]types.Database{db2}, []types.Database{gotDB}, cmpopts.IgnoreFields(types.Metadata{}, "ID", "Namespace", "Expires"), )) - // remove database server - _, err = runResourceCommand(t, fileConfig, []string{"rm", wantServer}) - require.NoError(t, err) - - _, err = runResourceCommand(t, fileConfig, []string{"get", wantServer, "--format=json"}) - require.Error(t, err) - require.IsType(t, &trace.NotFoundError{}, err.(*trace.TraceErr).OrigError()) - - buff, err = runResourceCommand(t, fileConfig, []string{"get", "db", "--format=json"}) + // get database servers by prefix of name + wantServersPrefix := fmt.Sprintf("%v/%v", types.KindDatabaseServer, "exam") + buff, err = runResourceCommand(t, fileConfig, []string{"get", wantServersPrefix, "--format=json"}) require.NoError(t, err) mustDecodeJSON(t, buff, &out) - require.Len(t, out, 0) -} - -// TestDatabaseResource tests tctl commands that manage database resources. 
-func TestDatabaseResource(t *testing.T) { - dynAddr := newDynamicServiceAddr(t) - - fileConfig := &config.FileConfig{ - Global: config.Global{ - DataDir: t.TempDir(), - }, - Databases: config.Databases{ - Service: config.Service{ - EnabledFlag: "true", - }, - }, - Proxy: config.Proxy{ - Service: config.Service{ - EnabledFlag: "true", - }, - WebAddr: dynAddr.webAddr, - TunAddr: dynAddr.tunnelAddr, - }, - Auth: config.Auth{ - Service: config.Service{ - EnabledFlag: "true", - ListenAddress: dynAddr.authAddr, - }, - }, - } - - makeAndRunTestAuthServer(t, withFileConfig(fileConfig), withFileDescriptors(dynAddr.descriptors)) - - dbA, err := types.NewDatabaseV3(types.Metadata{ - Name: "db-a", - Labels: map[string]string{types.OriginLabel: types.OriginDynamic}, - }, types.DatabaseSpecV3{ - Protocol: defaults.ProtocolPostgres, - URI: "localhost:5432", - }) - require.NoError(t, err) - - dbB, err := types.NewDatabaseV3(types.Metadata{ - Name: "db-b", - Labels: map[string]string{types.OriginLabel: types.OriginDynamic}, - }, types.DatabaseSpecV3{ - Protocol: defaults.ProtocolMySQL, - URI: "localhost:3306", - TLS: types.DatabaseTLS{ - Mode: types.DatabaseTLSMode_VERIFY_CA, - }, - }) - require.NoError(t, err) - - var out []*types.DatabaseV3 - - // Initially there are no databases. - buf, err := runResourceCommand(t, fileConfig, []string{"get", types.KindDatabase, "--format=json"}) - require.NoError(t, err) - mustDecodeJSON(t, buf, &out) - require.Len(t, out, 0) - - // Create the databases. - dbYAMLPath := filepath.Join(t.TempDir(), "db.yaml") - require.NoError(t, os.WriteFile(dbYAMLPath, []byte(dbYAML), 0644)) - _, err = runResourceCommand(t, fileConfig, []string{"create", dbYAMLPath}) - require.NoError(t, err) - - // Fetch the databases, should have 2. 
- buf, err = runResourceCommand(t, fileConfig, []string{"get", types.KindDatabase, "--format=json"}) - require.NoError(t, err) - mustDecodeJSON(t, buf, &out) require.Len(t, out, 2) - require.Empty(t, cmp.Diff([]*types.DatabaseV3{dbA, dbB}, out, - cmpopts.IgnoreFields(types.Metadata{}, "ID", "Namespace"), + gotDBs := types.DatabaseServers{out[0], out[1]}.ToDatabases() + require.Empty(t, cmp.Diff([]types.Database{db1, db2}, gotDBs, + cmpopts.IgnoreFields(types.Metadata{}, "ID", "Namespace", "Expires"), )) - // Fetch specific database. - buf, err = runResourceCommand(t, fileConfig, []string{"get", fmt.Sprintf("%v/db-b", types.KindDatabase), "--format=json"}) + // remove database servers by prefix is an error + _, err = runResourceCommand(t, fileConfig, []string{"rm", wantServersPrefix}) + require.ErrorContains(t, err, "db_server/exam matches multiple database servers") + + // remove database server by name + _, err = runResourceCommand(t, fileConfig, []string{"rm", wantServer}) require.NoError(t, err) - mustDecodeJSON(t, buf, &out) - require.Len(t, out, 1) - require.Empty(t, cmp.Diff([]*types.DatabaseV3{dbB}, out, - cmpopts.IgnoreFields(types.Metadata{}, "ID", "Namespace"), - )) - // Remove a database. - _, err = runResourceCommand(t, fileConfig, []string{"rm", fmt.Sprintf("%v/db-a", types.KindDatabase)}) + _, err = runResourceCommand(t, fileConfig, []string{"get", wantServer, "--format=json"}) + require.Error(t, err) + require.IsType(t, &trace.NotFoundError{}, err.(*trace.TraceErr).OrigError()) + + // remove database server by prefix name. + _, err = runResourceCommand(t, fileConfig, []string{"rm", wantServersPrefix}) require.NoError(t, err) - // Fetch all databases again, should have 1. 
- buf, err = runResourceCommand(t, fileConfig, []string{"get", types.KindDatabase, "--format=json"}) + buff, err = runResourceCommand(t, fileConfig, []string{"get", "db_server", "--format=json"}) require.NoError(t, err) - mustDecodeJSON(t, buf, &out) + mustDecodeJSON(t, buff, &out) require.Len(t, out, 1) - require.Empty(t, cmp.Diff([]*types.DatabaseV3{dbB}, out, - cmpopts.IgnoreFields(types.Metadata{}, "ID", "Namespace"), + gotDBs = types.DatabaseServers{out[0]}.ToDatabases() + require.Empty(t, cmp.Diff([]types.Database{db3}, gotDBs, + cmpopts.IgnoreFields(types.Metadata{}, "ID", "Namespace", "Expires"), )) } @@ -441,98 +411,6 @@ func TestIntegrationResource(t *testing.T) { }) } -// TestAppResource tests tctl commands that manage application resources. -func TestAppResource(t *testing.T) { - dynAddr := newDynamicServiceAddr(t) - - fileConfig := &config.FileConfig{ - Global: config.Global{ - DataDir: t.TempDir(), - Logger: config.Log{ - Severity: "debug", - }, - }, - Apps: config.Apps{ - Service: config.Service{ - EnabledFlag: "true", - }, - }, - Proxy: config.Proxy{ - Service: config.Service{ - EnabledFlag: "true", - }, - WebAddr: dynAddr.webAddr, - TunAddr: dynAddr.tunnelAddr, - }, - Auth: config.Auth{ - Service: config.Service{ - EnabledFlag: "true", - ListenAddress: dynAddr.authAddr, - }, - }, - } - - makeAndRunTestAuthServer(t, withFileConfig(fileConfig), withFileDescriptors(dynAddr.descriptors)) - - appA, err := types.NewAppV3(types.Metadata{ - Name: "appA", - Labels: map[string]string{types.OriginLabel: types.OriginDynamic}, - }, types.AppSpecV3{ - URI: "localhost1", - }) - require.NoError(t, err) - - appB, err := types.NewAppV3(types.Metadata{ - Name: "appB", - Labels: map[string]string{types.OriginLabel: types.OriginDynamic}, - }, types.AppSpecV3{ - URI: "localhost2", - }) - require.NoError(t, err) - - var out []*types.AppV3 - - // Initially there are no apps. 
- buf, err := runResourceCommand(t, fileConfig, []string{"get", types.KindApp, "--format=json"}) - require.NoError(t, err) - mustDecodeJSON(t, buf, &out) - require.Len(t, out, 0) - - // Create the apps. - appYAMLPath := filepath.Join(t.TempDir(), "app.yaml") - require.NoError(t, os.WriteFile(appYAMLPath, []byte(appYAML), 0644)) - _, err = runResourceCommand(t, fileConfig, []string{"create", appYAMLPath}) - require.NoError(t, err) - - // Fetch the apps, should have 2. - buf, err = runResourceCommand(t, fileConfig, []string{"get", types.KindApp, "--format=json"}) - require.NoError(t, err) - mustDecodeJSON(t, buf, &out) - require.Empty(t, cmp.Diff([]*types.AppV3{appA, appB}, out, - cmpopts.IgnoreFields(types.Metadata{}, "ID", "Namespace"), - )) - - // Fetch specific app. - buf, err = runResourceCommand(t, fileConfig, []string{"get", fmt.Sprintf("%v/appB", types.KindApp), "--format=json"}) - require.NoError(t, err) - mustDecodeJSON(t, buf, &out) - require.Empty(t, cmp.Diff([]*types.AppV3{appB}, out, - cmpopts.IgnoreFields(types.Metadata{}, "ID", "Namespace"), - )) - - // Remove an app. - _, err = runResourceCommand(t, fileConfig, []string{"rm", fmt.Sprintf("%v/appA", types.KindApp)}) - require.NoError(t, err) - - // Fetch all apps again, should have 1. 
- buf, err = runResourceCommand(t, fileConfig, []string{"get", types.KindApp, "--format=json"}) - require.NoError(t, err) - mustDecodeJSON(t, buf, &out) - require.Empty(t, cmp.Diff([]*types.AppV3{appB}, out, - cmpopts.IgnoreFields(types.Metadata{}, "ID", "Namespace"), - )) -} - func TestCreateLock(t *testing.T) { dynAddr := newDynamicServiceAddr(t) fileConfig := &config.FileConfig{ @@ -648,34 +526,70 @@ const ( dbYAML = `kind: db version: v3 metadata: - name: db-a + name: foo spec: - protocol: "postgres" - uri: "localhost:5432" + protocol: "mysql" + uri: "localhost:3306" + tls: + mode: "verify-ca" --- kind: db version: v3 metadata: - name: db-b + name: foo-bar spec: - protocol: "mysql" - uri: "localhost:3306" + protocol: "postgres" + uri: "localhost:5433" tls: - mode: "verify-ca"` + mode: "verify-full" +--- +kind: db +version: v3 +metadata: + name: foo-bar-baz +spec: + protocol: "postgres" + uri: "localhost:5432"` appYAML = `kind: app version: v3 metadata: - name: appA + name: foo spec: uri: "localhost1" --- kind: app version: v3 metadata: - name: appB + name: foo-bar spec: - uri: "localhost2"` + uri: "localhost2" +--- +kind: app +version: v3 +metadata: + name: foo-bar-baz +spec: + uri: "localhost3"` + + kubeYAML = ` +kind: kube_cluster +version: v3 +metadata: + name: foo +spec: {} +--- +kind: kube_cluster +version: v3 +metadata: + name: foo-bar +spec: {} +--- +kind: kube_cluster +version: v3 +metadata: + name: foo-bar-baz +spec: {}` lockYAML = `kind: lock version: v2 @@ -925,6 +839,336 @@ func TestUpsertVerb(t *testing.T) { } } +type dynamicResourceTest[T types.ResourceWithLabels] struct { + kind string + resourceYAML string + fooResource T + fooBarResource T + fooBarBazResource T + runPrefixNameChecks bool +} + +func (test *dynamicResourceTest[T]) setup(t *testing.T) *config.FileConfig { + t.Helper() + requireResource := func(t *testing.T, r T, name string) { + t.Helper() + require.NotNil(t, r, "dynamicResourceTest requires a resource named %q", name) + 
require.Equal(t, r.GetName(), name, "dynamicResourceTest requires a resource named %q", name) + } + requireResource(t, test.fooResource, "foo") + requireResource(t, test.fooBarResource, "foo-bar") + requireResource(t, test.fooBarBazResource, "foo-bar-baz") + dynAddr := newDynamicServiceAddr(t) + fileConfig := &config.FileConfig{ + Global: config.Global{ + DataDir: t.TempDir(), + }, + Proxy: config.Proxy{ + Service: config.Service{ + EnabledFlag: "true", + }, + WebAddr: dynAddr.webAddr, + TunAddr: dynAddr.tunnelAddr, + }, + Auth: config.Auth{ + Service: config.Service{ + EnabledFlag: "true", + ListenAddress: dynAddr.authAddr, + }, + }, + } + _ = makeAndRunTestAuthServer(t, withFileConfig(fileConfig), withFileDescriptors(dynAddr.descriptors)) + return fileConfig +} + +func (test *dynamicResourceTest[T]) run(t *testing.T) { + t.Helper() + fileConfig := test.setup(t) + var out []T + + // Initially there are no resources. + buf, err := runResourceCommand(t, fileConfig, []string{"get", test.kind, "--format=json"}) + require.NoError(t, err) + mustDecodeJSON(t, buf, &out) + require.Len(t, out, 0) + + // Create the resources. + yamlPath := filepath.Join(t.TempDir(), "resources.yaml") + require.NoError(t, os.WriteFile(yamlPath, []byte(test.resourceYAML), 0644)) + _, err = runResourceCommand(t, fileConfig, []string{"create", yamlPath}) + require.NoError(t, err) + + // Fetch all resources. + buf, err = runResourceCommand(t, fileConfig, []string{"get", test.kind, "--format=json"}) + require.NoError(t, err) + mustDecodeJSON(t, buf, &out) + require.Len(t, out, 3) + require.Empty(t, cmp.Diff([]T{test.fooResource, test.fooBarResource, test.fooBarBazResource}, out, + cmpopts.IgnoreFields(types.Metadata{}, "ID", "Namespace"), + )) + + // Fetch specific resource. 
+ buf, err = runResourceCommand(t, fileConfig, + []string{"get", fmt.Sprintf("%v/%v", test.kind, test.fooResource.GetName()), "--format=json"}) + require.NoError(t, err) + mustDecodeJSON(t, buf, &out) + require.Len(t, out, 1) + require.Empty(t, cmp.Diff([]T{test.fooResource}, out, + cmpopts.IgnoreFields(types.Metadata{}, "ID", "Namespace"), + )) + + // Remove a resource. + _, err = runResourceCommand(t, fileConfig, []string{"rm", fmt.Sprintf("%v/%v", test.kind, test.fooBarResource.GetName())}) + require.NoError(t, err) + + // Fetch all resources again. + buf, err = runResourceCommand(t, fileConfig, []string{"get", test.kind, "--format=json"}) + require.NoError(t, err) + mustDecodeJSON(t, buf, &out) + require.Len(t, out, 2) + require.Empty(t, cmp.Diff([]T{test.fooResource, test.fooBarBazResource}, out, + cmpopts.IgnoreFields(types.Metadata{}, "ID", "Namespace"), + )) + + if !test.runPrefixNameChecks { + return + } + + // Test prefix name behavior. + // Removing multiple resources ("foo" and "foo-bar-baz") by prefix name is an error. + _, err = runResourceCommand(t, fileConfig, []string{"rm", fmt.Sprintf("%v/%v", test.kind, "f")}) + require.ErrorContains(t, err, "matches multiple") + + // Remove "foo-bar-baz" resource by a prefix of its name. + _, err = runResourceCommand(t, fileConfig, []string{"rm", fmt.Sprintf("%v/%v", test.kind, "foo-bar-b")}) + require.NoError(t, err) + + // Fetch all resources again. + buf, err = runResourceCommand(t, fileConfig, []string{"get", test.kind, "--format=json"}) + require.NoError(t, err) + mustDecodeJSON(t, buf, &out) + require.Len(t, out, 1) + require.Empty(t, cmp.Diff([]T{test.fooResource}, out, + cmpopts.IgnoreFields(types.Metadata{}, "ID", "Namespace"), + )) +} + +// TestDatabaseResource tests tctl commands that manage database resources. 
+func TestDatabaseResource(t *testing.T) { + t.Parallel() + dbFoo, err := types.NewDatabaseV3(types.Metadata{ + Name: "foo", + Labels: map[string]string{types.OriginLabel: types.OriginDynamic}, + }, types.DatabaseSpecV3{ + Protocol: defaults.ProtocolMySQL, + URI: "localhost:3306", + TLS: types.DatabaseTLS{ + Mode: types.DatabaseTLSMode_VERIFY_CA, + }, + }) + require.NoError(t, err) + dbFooBar, err := types.NewDatabaseV3(types.Metadata{ + Name: "foo-bar", + Labels: map[string]string{types.OriginLabel: types.OriginDynamic}, + }, types.DatabaseSpecV3{ + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5433", + TLS: types.DatabaseTLS{ + Mode: types.DatabaseTLSMode_VERIFY_FULL, + }, + }) + require.NoError(t, err) + dbFooBarBaz, err := types.NewDatabaseV3(types.Metadata{ + Name: "foo-bar-baz", + Labels: map[string]string{types.OriginLabel: types.OriginDynamic}, + }, types.DatabaseSpecV3{ + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + }) + require.NoError(t, err) + + require.NoError(t, err) + test := dynamicResourceTest[*types.DatabaseV3]{ + kind: types.KindDatabase, + resourceYAML: dbYAML, + fooResource: dbFoo, + fooBarResource: dbFooBar, + fooBarBazResource: dbFooBarBaz, + runPrefixNameChecks: true, + } + test.run(t) +} + +// TestKubeClusterResource tests tctl commands that manage dynamic kube cluster resources. 
+func TestKubeClusterResource(t *testing.T) { + t.Parallel() + kubeFoo, err := types.NewKubernetesClusterV3(types.Metadata{ + Name: "foo", + Labels: map[string]string{types.OriginLabel: types.OriginDynamic}, + }, types.KubernetesClusterSpecV3{}) + require.NoError(t, err) + kubeFooBar, err := types.NewKubernetesClusterV3(types.Metadata{ + Name: "foo-bar", + Labels: map[string]string{types.OriginLabel: types.OriginDynamic}, + }, types.KubernetesClusterSpecV3{}) + require.NoError(t, err) + kubeFooBarBaz, err := types.NewKubernetesClusterV3(types.Metadata{ + Name: "foo-bar-baz", + Labels: map[string]string{types.OriginLabel: types.OriginDynamic}, + }, types.KubernetesClusterSpecV3{}) + require.NoError(t, err) + test := dynamicResourceTest[*types.KubernetesClusterV3]{ + kind: types.KindKubernetesCluster, + resourceYAML: kubeYAML, + fooResource: kubeFoo, + fooBarResource: kubeFooBar, + fooBarBazResource: kubeFooBarBaz, + runPrefixNameChecks: true, + } + test.run(t) +} + +// TestAppResource tests tctl commands that manage application resources. 
+func TestAppResource(t *testing.T) { + t.Parallel() + appFoo, err := types.NewAppV3(types.Metadata{ + Name: "foo", + Labels: map[string]string{types.OriginLabel: types.OriginDynamic}, + }, types.AppSpecV3{ + URI: "localhost1", + }) + require.NoError(t, err) + appFooBar, err := types.NewAppV3(types.Metadata{ + Name: "foo-bar", + Labels: map[string]string{types.OriginLabel: types.OriginDynamic}, + }, types.AppSpecV3{ + URI: "localhost2", + }) + require.NoError(t, err) + appFooBarBaz, err := types.NewAppV3(types.Metadata{ + Name: "foo-bar-baz", + Labels: map[string]string{types.OriginLabel: types.OriginDynamic}, + }, types.AppSpecV3{ + URI: "localhost3", + }) + require.NoError(t, err) + test := dynamicResourceTest[*types.AppV3]{ + kind: types.KindApp, + resourceYAML: appYAML, + fooResource: appFoo, + fooBarResource: appFooBar, + fooBarBazResource: appFooBarBaz, + } + test.run(t) +} + +func TestGetOneResourceNameToDelete(t *testing.T) { + foo1 := mustCreateNewKubeServer(t, "foo", "host-foo", nil) + foo2 := mustCreateNewKubeServer(t, "foo", "host-foo", nil) + fooBar := mustCreateNewKubeServer(t, "foo-bar", "host-foo-bar", nil) + baz := mustCreateNewKubeServer(t, "baz", "host-baz", nil) + tests := []struct { + desc string + refName string + wantErrContains string + resources []types.KubeServer + wantName string + }{ + { + desc: "one resource is ok", + refName: "baz", + resources: []types.KubeServer{baz}, + wantName: "baz", + }, + { + desc: "multiple resources with same name is ok", + refName: "foo", + resources: []types.KubeServer{foo1, foo2}, + wantName: "foo", + }, + { + desc: "zero resources is an error", + refName: "xxx", + wantErrContains: `kubernetes server "xxx" not found`, + }, + { + desc: "multiple resources with different names is an error", + refName: "f", + resources: []types.KubeServer{foo1, foo2, fooBar}, + wantErrContains: "matches multiple", + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + ref := services.Ref{Kind: 
types.KindKubeServer, Name: test.refName} + resDesc := "kubernetes server" + name, err := getOneResourceNameToDelete(test.resources, ref, resDesc) + if test.wantErrContains != "" { + require.ErrorContains(t, err, test.wantErrContains) + return + } + require.Equal(t, test.wantName, name) + }) + } +} + +func TestFilterByNameOrPrefix(t *testing.T) { + foo1 := mustCreateNewKubeServer(t, "foo", "host-foo", nil) + foo2 := mustCreateNewKubeServer(t, "foo", "host-foo", nil) + fooBar := mustCreateNewKubeServer(t, "foo-bar", "host-foo-bar", nil) + baz := mustCreateNewKubeServer(t, "baz", "host-baz", nil) + resources := []types.KubeServer{ + foo1, foo2, fooBar, baz, + } + hostNameGetter := func(ks types.KubeServer) string { return ks.GetHostname() } + tests := []struct { + desc string + filter string + altNameGetters []altNameFn[types.KubeServer] + want []types.KubeServer + }{ + { + desc: "filters by exact name first", + filter: "foo", + want: []types.KubeServer{foo1, foo2}, + }, + { + desc: "filters by prefix name", + filter: "fo", + want: []types.KubeServer{foo1, foo2, fooBar}, + }, + { + desc: "checks alt names for exact matches first", + filter: "host-foo", + altNameGetters: []altNameFn[types.KubeServer]{hostNameGetter}, + want: []types.KubeServer{foo1, foo2}, + }, + { + desc: "checks alt names for prefix matches", + filter: "host-f", + altNameGetters: []altNameFn[types.KubeServer]{hostNameGetter}, + want: []types.KubeServer{foo1, foo2, fooBar}, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + got := filterByNameOrPrefix(resources, test.filter, test.altNameGetters...) 
+ require.Empty(t, cmp.Diff(test.want, got)) + }) + } +} + +func TestFormatAmbiguousDeleteMessage(t *testing.T) { + ref := services.Ref{Kind: types.KindDatabase, Name: "x"} + resDesc := "database" + names := []string{"xbbb", "xaaa", "xccc", "xb"} + got := formatAmbiguousDeleteMessage(ref, resDesc, names) + require.Contains(t, got, "db/x matches multiple databases", "should have formatted the ref used and pluralized the resource description") + wantSortedNames := strings.Join([]string{"xaaa", "xb", "xbbb", "xccc"}, "\n") + require.Contains(t, got, wantSortedNames, "should have sorted the matching names") + require.Contains(t, got, "$ tctl rm db/xaaa", "should have contained an example command") +} + // requireEqual creates an assertion function with a bound `expected` value // for use with table-driven tests func requireEqual(expected interface{}) require.ValueAssertionFunc { diff --git a/tool/tsh/access_request.go b/tool/tsh/access_request.go index a7930159d90a1..a313cc4a294b6 100644 --- a/tool/tsh/access_request.go +++ b/tool/tsh/access_request.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/asciitable" "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/kube/kubeconfig" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/tool/common" @@ -385,7 +386,7 @@ func onRequestSearch(cf *CLIConf) error { // If KubeCluster not provided try to read it from kubeconfig. 
if cf.KubernetesCluster == "" { - cf.KubernetesCluster = selectedKubeCluster(tc.SiteName) + cf.KubernetesCluster, _ = kubeconfig.SelectedKubeCluster(getKubeConfigPath(cf, ""), tc.SiteName) } if cf.ResourceKind == types.KindKubePod && cf.KubernetesCluster == "" { return trace.BadParameter("when searching for Pods, --kube-cluster cannot be empty") @@ -472,7 +473,7 @@ func onRequestSearch(cf *CLIConf) error { resourceIDs = append(resourceIDs, resourceID) row = []string{ - resource.GetName(), + common.FormatResourceName(resource, cf.Verbose), r.Spec.Namespace, common.FormatLabels(resource.GetAllLabels(), cf.Verbose), resourceID, @@ -494,7 +495,7 @@ func onRequestSearch(cf *CLIConf) error { hostName = r.GetHostname() } row = []string{ - resource.GetName(), + common.FormatResourceName(resource, cf.Verbose), hostName, common.FormatLabels(resource.GetAllLabels(), cf.Verbose), resourceID, diff --git a/tool/tsh/db.go b/tool/tsh/db.go index a857d6d6816a7..47fe1e7e81a12 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -251,12 +251,20 @@ func onDatabaseLogin(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - dbInfo, err := newDatabaseInfo(cf, tc, nil) + profile, err := tc.ProfileStatus() + if err != nil { + return trace.Wrap(err) + } + routes, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + dbInfo, err := getDatabaseInfo(cf, tc, routes) if err != nil { return trace.Wrap(err) } - database, err := dbInfo.GetDatabase(cf, tc) + database, err := dbInfo.GetDatabase(cf.Context, tc) if err != nil { return trace.Wrap(err) } @@ -352,7 +360,7 @@ func onDatabaseLogout(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - logout, _, err := filterActiveDatabases(cf.Context, tc, activeRoutes) + databases, err := getDatabasesForLogout(cf, tc, activeRoutes) if err != nil { return trace.Wrap(err) } @@ -361,12 +369,12 @@ func onDatabaseLogout(cf *CLIConf) error { log.Info("Note: an identity file is in use (`-i ...`); will 
only update database config files.") } - for _, db := range logout { + for _, db := range databases { if err := databaseLogout(tc, db, profile.IsVirtual); err != nil { return trace.Wrap(err) } } - msg, err := makeLogoutMessage(cf, logout, activeRoutes) + msg, err := makeLogoutMessage(cf, databases, activeRoutes) if err != nil { return trace.Wrap(err) } @@ -378,21 +386,16 @@ func onDatabaseLogout(cf *CLIConf) error { // result of "tsh db logout". func makeLogoutMessage(cf *CLIConf, logout, activeRoutes []tlsca.RouteToDatabase) (string, error) { switch len(logout) { - case 0: - selectors := resourceSelectors{ - kind: "database", - name: cf.DatabaseService, - labels: cf.Labels, - query: cf.PredicateExpression, - } - if selectors.IsEmpty() { - return "", trace.NotFound("Not logged into any databases") - } - return "", trace.NotFound("Not logged into %v", selectors) case 1: return fmt.Sprintf("Logged out of database %v", logout[0].ServiceName), nil case len(activeRoutes): return "Logged out of all databases", nil + case 0: + selectors := newDatabaseResourceSelectors(cf) + if selectors.IsEmpty() { + return "", trace.NotFound("Not logged into any databases") + } + return "", trace.NotFound("Not logged into %s", selectors) default: names := make([]string, 0, len(logout)) for _, route := range logout { @@ -429,7 +432,15 @@ func onDatabaseEnv(cf *CLIConf) error { return trace.Wrap(err) } - database, err := pickActiveDatabase(cf, tc) + profile, err := tc.ProfileStatus() + if err != nil { + return trace.Wrap(err) + } + routes, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + database, err := pickActiveDatabase(cf, tc, routes) if err != nil { return trace.Wrap(err) } @@ -487,7 +498,11 @@ func onDatabaseConfig(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - database, err := pickActiveDatabase(cf, tc) + routes, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + database, 
err := pickActiveDatabase(cf, tc, routes) if err != nil { return trace.Wrap(err) } @@ -694,7 +709,7 @@ func prepareLocalProxyOptions(arg *localProxyConfig) ([]alpnproxy.LocalProxyConf // To set correct MySQL server version DB proxy needs additional protocol. if !arg.tunnel && arg.dbInfo.Protocol == defaults.ProtocolMySQL { - db, err := arg.dbInfo.GetDatabase(arg.cf, arg.tc) + db, err := arg.dbInfo.GetDatabase(arg.cf.Context, arg.tc) if err != nil { return nil, trace.Wrap(err) } @@ -750,7 +765,11 @@ func onDatabaseConnect(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - dbInfo, err := getDatabaseInfo(cf, tc) + routes, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + dbInfo, err := getDatabaseInfo(cf, tc, routes) if err != nil { return trace.Wrap(err) } @@ -808,66 +827,44 @@ func onDatabaseConnect(cf *CLIConf) error { // getDatabaseInfo fetches information about the database from tsh profile if DB // is active in profile and no labels or predicate query are given. // Otherwise, the ListDatabases endpoint is called. -func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient) (*databaseInfo, error) { - haveSelectors := tc.DatabaseService != "" || len(tc.Labels) > 0 || tc.PredicateExpression != "" - if !haveSelectors { - // if selectors are given, we might incur an extra ListDatabases API - // call here to match against an active database. - // So try to pick an active database only when we don't have - // selectors. - if route, err := pickActiveDatabase(cf, tc); err == nil { - return newDatabaseInfo(cf, tc, route) - } else if err != nil && !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - } - return newDatabaseInfo(cf, tc, nil) -} - -// newDatabaseInfo makes a new databaseInfo from the given route to the db. -// It checks the route and sets defaults as needed for protocol, db user, or db -// name. 
If the route is not given or the remote database is needed for setting -// a default, the database is retrieved by calling ListDatabases API and cached. -func newDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, route *tlsca.RouteToDatabase) (*databaseInfo, error) { - dbInfo := &databaseInfo{} - if route != nil { - dbInfo.RouteToDatabase = *route - // the only way we're going to have all this info populated is from an - // active cert. - if dbInfo.ServiceName != "" && dbInfo.Protocol != "" && - dbInfo.Username != "" && dbInfo.Database != "" { - return dbInfo, nil +func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, routes []tlsca.RouteToDatabase) (*databaseInfo, error) { + if route, err := maybePickActiveDatabase(cf, routes); err == nil && route != nil { + info := &databaseInfo{RouteToDatabase: *route, isActive: true} + return info, info.checkAndSetDefaults(cf, tc) + } else if err != nil { + if trace.IsNotFound(err) { + return nil, trace.BadParameter("please specify a database service by name, --labels, or --query") } + return nil, trace.Wrap(err) } - db, err := dbInfo.GetDatabase(cf, tc) + + db, err := getDatabaseByNameOrDiscoveredName(cf, tc, routes) if err != nil { return nil, trace.Wrap(err) } - // now ensure the route name and protocol matches the db we fetched. - dbInfo.ServiceName = db.GetName() - dbInfo.Protocol = db.GetProtocol() - return dbInfo, dbInfo.checkAndSetPrincipalDefaults(cf, tc, db) -} -// checkAndSetPrincipalDefaults checks the db route (schema) name and username, -// and sets them to defaults if necessary. -func (d *databaseInfo) checkAndSetPrincipalDefaults(cf *CLIConf, tc *client.TeleportClient, db types.Database) error { - profile, err := tc.ProfileStatus() - if err != nil { - return trace.Wrap(err) + info := &databaseInfo{ + database: db, + RouteToDatabase: tlsca.RouteToDatabase{ + ServiceName: db.GetName(), + Protocol: db.GetProtocol(), + }, + } + // check for an active route now that we have the full db name. 
+ if route, ok := findActiveDatabase(db.GetName(), routes); ok { + info.RouteToDatabase = route + info.isActive = true + } + if err := info.checkAndSetDefaults(cf, tc); err != nil { + return nil, trace.Wrap(err) } + return info, nil +} - // if either user or db name isn't given as a cli flag, try to populate - // user/db name from an active db cert. - if cf.DatabaseUser == "" || cf.DatabaseName == "" { - routes, err := profile.DatabasesForCluster(tc.SiteName) - if err != nil { - return trace.Wrap(err) - } - if route, ok := findActiveDatabase(d.ServiceName, routes); ok { - d.Username = route.Username - d.Database = route.Database - } +// checkAndSetDefaults checks the db route, applies cli flags, and sets defaults. +func (d *databaseInfo) checkAndSetDefaults(cf *CLIConf, tc *client.TeleportClient) error { + if d.ServiceName == "" { + return trace.BadParameter("missing database service name") } if cf.DatabaseUser != "" { d.Username = cf.DatabaseUser @@ -875,14 +872,32 @@ func (d *databaseInfo) checkAndSetPrincipalDefaults(cf *CLIConf, tc *client.Tele if cf.DatabaseName != "" { d.Database = cf.DatabaseName } + db, err := d.GetDatabase(cf.Context, tc) + if err != nil { + if d.isActive && trace.IsNotFound(err) && strings.Contains(err.Error(), d.ServiceName) { + hint := formatStaleDBCert(cf.SiteName, d.ServiceName) + return trace.Wrap(err, hint) + } + return trace.Wrap(err) + } + // ensure the route protocol matches the db. + d.Protocol = db.GetProtocol() + + needDBUser := d.Username == "" && role.RequireDatabaseUserMatcher(d.Protocol) + needDBName := d.Database == "" && role.RequireDatabaseNameMatcher(d.Protocol) + if !needDBUser && !needDBName { + return nil + } + // If database has admin user defined, we're most likely using automatic // user provisioning so default to Teleport username unless database // username was provided explicitly. 
- if d.Username == "" && db.GetAdminUser() != "" { + if needDBUser && db.GetAdminUser() != "" { log.Debugf("Defaulting to Teleport username %q as database username.", tc.Username) d.Username = tc.Username + needDBUser = false } - if d.Username == "" { + if needDBUser { switch d.Protocol { // When generating certificate for MongoDB access, database username must // be encoded into it. This is required to be able to tell which database @@ -922,66 +937,56 @@ type databaseInfo struct { // database corresponds to the db route and may be nil, so use GetDatabase // instead of accessing it directly. database types.Database + // isActive indicates an active database matched this db info. + isActive bool mu sync.Mutex } // GetDatabase returns the cached database or fetches it using the db route and // caches the result. -func (d *databaseInfo) GetDatabase(cf *CLIConf, tc *client.TeleportClient) (types.Database, error) { - if d.ServiceName == "" && cf.DatabaseService == "" && - len(tc.Labels) == 0 && tc.PredicateExpression == "" { - return nil, trace.BadParameter("specify a database service by name, --labels, or --query") - } +func (d *databaseInfo) GetDatabase(ctx context.Context, tc *client.TeleportClient) (types.Database, error) { d.mu.Lock() defer d.mu.Unlock() if d.database != nil { return d.database, nil } // holding mutex across the api call to avoid multiple redundant api calls. - var databases types.Databases - var err error - name := d.ServiceName - if name != "" { - databases, err = listDatabasesByName(cf.Context, tc, name) - } else { - name = cf.DatabaseService - // search by prefix if the db name comes from cli flag instead of cert. 
- databases, err = listDatabasesByPrefix(cf.Context, tc, name) - } + database, err := getDatabase(ctx, tc, d.ServiceName) if err != nil { return nil, trace.Wrap(err) } - db, err := chooseOneDatabase(cf, name, databases) - if err != nil { - return nil, trace.Wrap(err) - } - - d.database = db + d.database = database return d.database, nil } -// chooseOneDatabase is a helper func for GetDatabase that returns either the -// only database in a list of databases or returns a database that matches the -// nameOrPrefix exactly, otherwise an error. -func chooseOneDatabase(cf *CLIConf, nameOrPrefix string, databases types.Databases) (types.Database, error) { - if len(databases) == 1 { - return databases[0], nil - } - // Check if nameOrPrefix matches any database exactly and, if so, choose +// chooseOneDatabase is a helper func that returns either the only database in a +// list of databases or returns a database that matches the selector name +// or unambiguous discovered name exactly, otherwise an error. +func chooseOneDatabase(cf *CLIConf, databases types.Databases) (types.Database, error) { + selectors := newDatabaseResourceSelectors(cf) + // Check if the name matches any database exactly and, if so, choose // that database over any others. for _, db := range databases { - if db.GetName() == nameOrPrefix { + if db.GetName() == selectors.name { + log.Debugf("Selected database %q by exact name match", db.GetName()) return db, nil } } + // look for a single database with a matching discovered name label. + if dbs := findDatabasesByDiscoveredName(databases, selectors.name); len(dbs) > 0 { + names := make([]string, 0, len(dbs)) + for _, db := range dbs { + names = append(names, db.GetName()) + } + log.Debugf("Choosing amongst databases (%v) by discovered name", names) + databases = dbs + } + if len(databases) == 1 { + log.Debugf("Selected database %q", databases[0].GetName()) + return databases[0], nil + } // error - we need exactly one database. 
- selectors := resourceSelectors{ - kind: "database", - name: nameOrPrefix, - labels: cf.Labels, - query: cf.PredicateExpression, - } if len(databases) == 0 { return nil, trace.NotFound( "%v not found, use '%v' to see registered databases", selectors, @@ -991,57 +996,72 @@ func chooseOneDatabase(cf *CLIConf, nameOrPrefix string, databases types.Databas return nil, trace.BadParameter(errMsg) } -// listActiveDatabases lists databases that match active (logged in) databases. -func listActiveDatabases(ctx context.Context, tc *client.TeleportClient, routes []tlsca.RouteToDatabase) (types.Databases, error) { - names := make([]string, 0, len(routes)) - for _, r := range routes { - names = append(names, fmt.Sprintf("(name == %q)", r.ServiceName)) +// findDatabasesByDiscoveredName returns all databases that have a discovered +// name label that matches the given name. +func findDatabasesByDiscoveredName(databases types.Databases, name string) types.Databases { + var out types.Databases + for _, db := range databases { + discoveredName, ok := db.GetLabel(types.DiscoveredNameLabel) + if ok && discoveredName == name { + out = append(out, db) + } } - predicate := strings.Join(names, "||") - return listDatabasesWithPredicate(ctx, tc, predicate) + return out } -// listDatabasesByName lists database that match a given name. -func listDatabasesByName(ctx context.Context, tc *client.TeleportClient, name string) (types.Databases, error) { - predicate := fmt.Sprintf("name == %q", name) - return listDatabasesWithPredicate(ctx, tc, predicate) +// getDatabase gets a database using its full name. 
+func getDatabase(ctx context.Context, tc *client.TeleportClient, name string) (types.Database, error) { + matchName := makeNamePredicate(name) + databases, err := listDatabasesWithPredicate(ctx, tc, matchName) + if err != nil { + return nil, trace.Wrap(err) + } + if len(databases) == 0 { + return nil, trace.NotFound("database %q not found among registered databases in cluster %v", name, tc.SiteName) + } + return databases[0], nil } -// listDatabasesByPrefix lists databases that match a given name prefix. -func listDatabasesByPrefix(ctx context.Context, tc *client.TeleportClient, prefix string) (types.Databases, error) { - predicate := fmt.Sprintf(`hasPrefix(name, %q)`, prefix) - databases, err := listDatabasesWithPredicate(ctx, tc, predicate) - if err == nil || !utils.IsPredicateError(err) { - return databases, trace.Wrap(err) - } - // predicate error from using hasPrefix expression. - // fallback to listing without the hasPrefix predicate and filtering - // on client side for backwards compatibility. - databases, err = listDatabasesWithPredicate(ctx, tc, "") +// getDatabaseByNameOrDiscoveredName fetches a database that unambiguously +// matches a given name or a discovered name label. +func getDatabaseByNameOrDiscoveredName(cf *CLIConf, tc *client.TeleportClient, activeRoutes []tlsca.RouteToDatabase) (types.Database, error) { + predicate := makeDiscoveredNameOrNamePredicate(cf.DatabaseService) + databases, err := listDatabasesWithPredicate(cf.Context, tc, predicate) if err != nil { return nil, trace.Wrap(err) } + if activeDBs := filterActiveDatabases(activeRoutes, databases); len(activeDBs) > 0 { + names := make([]string, 0, len(activeDBs)) + for _, db := range activeDBs { + names = append(names, db.GetName()) + } + log.Debugf("Choosing a database amongst active databases (%v)", names) + // preferentially choose from active databases if any of them match. 
+ return chooseOneDatabase(cf, activeDBs) + } + return chooseOneDatabase(cf, databases) +} + +func filterActiveDatabases(routes []tlsca.RouteToDatabase, databases types.Databases) types.Databases { + databasesByName := databases.ToMap() var out types.Databases - for _, db := range databases { - if strings.HasPrefix(db.GetName(), prefix) { + for _, route := range routes { + if db, ok := databasesByName[route.ServiceName]; ok { out = append(out, db) } } - return out, nil + return out } // listDatabasesWithPredicate is a helper func for listing databases using // a given additional predicate expression. If the teleport client already // has a predicate expression, the predicates are combined with a logical AND. func listDatabasesWithPredicate(ctx context.Context, tc *client.TeleportClient, predicate string) (types.Databases, error) { - if predicate == "" { - predicate = tc.PredicateExpression - } else if tc.PredicateExpression != "" { - predicate = fmt.Sprintf("(%v) && (%v)", predicate, tc.PredicateExpression) - } var databases []types.Database err := client.RetryWithRelogin(ctx, tc, func() error { var err error + predicate := makePredicateConjunction(predicate, tc.PredicateExpression) + log.Debugf("Listing databases with predicate (%v) and labels %v", predicate, tc.Labels) databases, err = tc.ListDatabases(ctx, &proto.ListResourcesRequest{ Namespace: tc.Namespace, ResourceType: types.KindDatabaseServer, @@ -1053,6 +1073,57 @@ func listDatabasesWithPredicate(ctx context.Context, tc *client.TeleportClient, return databases, trace.Wrap(err) } +func makeDiscoveredNameOrNamePredicate(name string) string { + matchName := makeNamePredicate(name) + matchDiscoveredName := makeDiscoveredNamePredicate(name) + return makePredicateDisjunction(matchName, matchDiscoveredName) +} + +func makeDiscoveredNamePredicate(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "" + } + return fmt.Sprintf(`labels[%q] == %q`, types.DiscoveredNameLabel, name) +} + 
+func makeNamePredicate(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "" + } + return fmt.Sprintf(`name == %q`, name) +} + +// makePredicateConjunction combines two predicate expressions into one +// expression as a conjunction (logical AND) of the expressions. +func makePredicateConjunction(a, b string) string { + return combinePredicateExpressions(a, b, "&&") +} + +// makePredicateDisjunction combines two predicate expressions into one +// expression as a disjunction (logical OR) of the expressions. +func makePredicateDisjunction(a, b string) string { + return combinePredicateExpressions(a, b, "||") +} + +// combinePredicateExpressions combines two predicate expressions into one +// expression with the given operator. +func combinePredicateExpressions(a, b, op string) string { + a = strings.TrimSpace(a) + b = strings.TrimSpace(b) + switch { + case a == "": + return b + case b == "": + return a + case a == b: + return a + default: + return fmt.Sprintf("(%v) %v (%v)", a, op, b) + } +} + func needDatabaseRelogin(cf *CLIConf, tc *client.TeleportClient, route tlsca.RouteToDatabase, profile *client.ProfileStatus, requires *dbLocalProxyRequirement) (bool, error) { if (requires.localProxy && requires.tunnel) || isLocalProxyTunnelRequested(cf) { switch route.Protocol { @@ -1175,123 +1246,73 @@ func isMFADatabaseAccessRequired(ctx context.Context, tc *client.TeleportClient, // // If logged into multiple databases, returns an error unless one specified // explicitly via --db flag. 
-func pickActiveDatabase(cf *CLIConf, tc *client.TeleportClient) (*tlsca.RouteToDatabase, error) { - profile, err := tc.ProfileStatus() - if err != nil { +func pickActiveDatabase(cf *CLIConf, tc *client.TeleportClient, activeRoutes []tlsca.RouteToDatabase) (*tlsca.RouteToDatabase, error) { + if route, err := maybePickActiveDatabase(cf, activeRoutes); err == nil && route != nil { + return route, nil + } else if err != nil { return nil, trace.Wrap(err) } - - routes, err := profile.DatabasesForCluster(tc.SiteName) - if err != nil { - return nil, trace.Wrap(err) - } - - if len(routes) == 0 { - return nil, trace.NotFound("please login using 'tsh db login' first") + // check if any active database can possibly match. + selectors := newDatabaseResourceSelectors(cf) + if routes := filterRoutesByPrefix(activeRoutes, selectors.name); len(routes) == 0 { + // no match is possible. + return nil, trace.NotFound(formatDBNotLoggedIn(cf.SiteName, selectors)) } - routes, databases, err := filterActiveDatabases(cf.Context, tc, routes) + db, err := getDatabaseByNameOrDiscoveredName(cf, tc, activeRoutes) if err != nil { return nil, trace.Wrap(err) } + if route, ok := findActiveDatabase(db.GetName(), activeRoutes); ok { + return &route, nil + } + return nil, trace.NotFound(formatDBNotLoggedIn(cf.SiteName, selectors)) +} - if len(routes) != 1 { - // error - we need exactly one route. - selectors := resourceSelectors{ - kind: "database", - name: cf.DatabaseService, - labels: cf.Labels, - query: cf.PredicateExpression, - } - if len(routes) == 0 { - return nil, trace.NotFound("not logged into %v", selectors) - } - if len(databases) == 0 { - // if not already given, try to fetch them so we can print full - // the full `tsh db ls -v` table of ambiguously matching active DBs. 
- databases, err = listActiveDatabases(cf.Context, tc, routes) - if err != nil { - return nil, trace.Wrap(err) +// maybePickActiveDatabase tries to pick a database automatically when selectors +// are not given, or by an exact name match of an active database when neither +// labels nor query are given. +// The route returned may be nil, indicating an active route could not be +// picked. +func maybePickActiveDatabase(cf *CLIConf, activeRoutes []tlsca.RouteToDatabase) (*tlsca.RouteToDatabase, error) { + selectors := newDatabaseResourceSelectors(cf) + if selectors.query == "" && selectors.labels == "" { + if selectors.name == "" { + switch len(activeRoutes) { + case 0: + return nil, trace.NotFound(formatDBNotLoggedIn(cf.SiteName, selectors)) + case 1: + log.Debugf("Auto-selecting the only active database %q", activeRoutes[0].ServiceName) + return &activeRoutes[0], nil + default: + return nil, trace.BadParameter(formatChooseActiveDB(activeRoutes)) } } - errMsg := formatAmbiguousDB(cf, selectors, databases) - return nil, trace.BadParameter(errMsg) - } - - route := &routes[0] - // If database user or name were provided on the CLI, - // override the default ones. - if cf.DatabaseUser != "" { - route.Username = cf.DatabaseUser - } - if cf.DatabaseName != "" { - route.Database = cf.DatabaseName - } - return route, nil -} - -// filterActiveDatabases takes a list of active database routes and returns a -// filtered list and, possibly, their corresponding types.Databases. -// Callers should therefore not assume that the types.Databases are populated. -// Filtering is done by matching on database name prefix, label, and query -// predicate selectors from the Teleport client. -// If an active database name matches exactly, all other active databases are -// filtered out - this is to avoid requiring additional selectors -// when a user gives an exact database name. 
-func filterActiveDatabases(ctx context.Context, tc *client.TeleportClient, activeRoutes []tlsca.RouteToDatabase) ([]tlsca.RouteToDatabase, types.Databases, error) { - if len(activeRoutes) == 0 { - // nothing to filter - return nil, nil, nil - } - prefix := tc.DatabaseService - if len(tc.Labels) == 0 && tc.PredicateExpression == "" { - // when we have a name but don't have label or predicate query, look for - // a route that matches the name exactly to maybe avoid calling - // ListDatabases API below. - if route, ok := findActiveDatabase(prefix, activeRoutes); ok { - return []tlsca.RouteToDatabase{route}, nil, nil + if route, ok := findActiveDatabase(selectors.name, activeRoutes); ok { + log.Debugf("Selected active database %q by name", route.ServiceName) + return &route, nil } } + return nil, nil +} - // make a ListDatabases API call filtered by prefix name - databases, err := listDatabasesByPrefix(ctx, tc, prefix) - if err != nil { - return nil, nil, trace.Wrap(err) - } - databasesByName := databases.ToMap() - - // when a database matches the prefix fully, look for a - // corresponding active route. - if db, ok := databasesByName[prefix]; ok { - for _, route := range activeRoutes { - if route.ServiceName == db.GetName() { - return []tlsca.RouteToDatabase{route}, types.Databases{db}, nil - } - } - // no active route, but return the fetched databases if the caller is - // interested. - return nil, databases, nil +// getDatabasesForLogout selects databases for logout in "tsh db logout". +func getDatabasesForLogout(cf *CLIConf, tc *client.TeleportClient, activeRoutes []tlsca.RouteToDatabase) ([]tlsca.RouteToDatabase, error) { + selectors := newDatabaseResourceSelectors(cf) + if selectors.IsEmpty() { + // if db name, labels, query was not given, logout of all databases. + return activeRoutes, nil } - - // otherwise, just filter routes to those that match the names of the - // databases. 
- var selectedRoutes []tlsca.RouteToDatabase - var activeDBs types.Databases - for _, route := range activeRoutes { - if db, ok := databasesByName[route.ServiceName]; ok { - selectedRoutes = append(selectedRoutes, route) - activeDBs = append(activeDBs, db) - } + route, err := pickActiveDatabase(cf, tc, activeRoutes) + if err != nil { + return nil, trace.Wrap(err) } - return selectedRoutes, activeDBs, nil + return []tlsca.RouteToDatabase{*route}, nil } // findActiveDatabase returns a database route and a bool indicating whether // the route was found. func findActiveDatabase(name string, activeRoutes []tlsca.RouteToDatabase) (tlsca.RouteToDatabase, bool) { - if name == "" && len(activeRoutes) == 1 { - return activeRoutes[0], true - } for _, r := range activeRoutes { if r.ServiceName == name { return r, true @@ -1300,11 +1321,58 @@ func findActiveDatabase(name string, activeRoutes []tlsca.RouteToDatabase) (tlsc return tlsca.RouteToDatabase{}, false } +func filterRoutesByPrefix(routes []tlsca.RouteToDatabase, prefix string) []tlsca.RouteToDatabase { + var out []tlsca.RouteToDatabase + for _, r := range routes { + if strings.HasPrefix(r.ServiceName, prefix) { + out = append(out, r) + } + } + return out +} + +func formatStaleDBCert(clusterFlag, name string) string { + return fmt.Sprintf("you are logged into a database that no longer exists in the cluster (remove it with '%v %v')", + formatDatabaseLogoutCommand(clusterFlag), name) +} + +func formatChooseActiveDB(routes []tlsca.RouteToDatabase) string { + var services []string + for _, r := range routes { + services = append(services, r.ServiceName) + } + return fmt.Sprintf("multiple databases are available (%v), please specify one by name, --labels, or --query", + strings.Join(services, ", ")) +} + +func formatDBNotLoggedIn(clusterFlag string, selectors resourceSelectors) string { + if selectors.IsEmpty() { + return fmt.Sprintf( + "please login using '%v' first (use '%v' to see registered databases)", + 
formatDatabaseLoginCommand(clusterFlag), + formatDatabaseListCommand(clusterFlag), + ) + } + return fmt.Sprintf("not logged into %s", selectors) +} + +func formatDatabaseLogoutCommand(clusterFlag string) string { + return formatTSHCommand("tsh db logout", clusterFlag) +} + +func formatDatabaseLoginCommand(clusterFlag string) string { + return formatTSHCommand("tsh db login", clusterFlag) +} + func formatDatabaseListCommand(clusterFlag string) string { + return formatTSHCommand("tsh db ls", clusterFlag) +} + +func formatTSHCommand(cmd, clusterFlag string) string { if clusterFlag == "" { - return "tsh db ls" + return cmd } - return fmt.Sprintf("tsh db ls --cluster=%v", clusterFlag) + return fmt.Sprintf("%v --cluster=%v", cmd, clusterFlag) } // formatDatabaseConnectCommand formats an appropriate database connection @@ -1525,7 +1593,8 @@ func formatAmbiguousDB(cf *CLIConf, selectors resourceSelectors, matchedDBs type showDatabasesAsText(&sb, cf.SiteName, matchedDBs, activeDBs, checker, verbose) listCommand := formatDatabaseListCommand(cf.SiteName) - return formatAmbiguityErrTemplate(cf, selectors, listCommand, sb.String()) + fullNameExample := matchedDBs[0].GetName() + return formatAmbiguityErrTemplate(cf, selectors, listCommand, sb.String(), fullNameExample) } // resourceSelectors is a helper struct for gathering up the selectors for a @@ -1560,14 +1629,24 @@ func (r resourceSelectors) IsEmpty() bool { return r.name == "" && r.labels == "" && r.query == "" } +func newDatabaseResourceSelectors(cf *CLIConf) resourceSelectors { + return resourceSelectors{ + kind: "database", + name: cf.DatabaseService, + labels: cf.Labels, + query: cf.PredicateExpression, + } +} + // formatAmbiguityErrTemplate is a helper func that formats an ambiguous // resource error message. 
-func formatAmbiguityErrTemplate(cf *CLIConf, selectors resourceSelectors, listCommand, matchTable string) string { +func formatAmbiguityErrTemplate(cf *CLIConf, selectors resourceSelectors, listCommand, matchTable, fullNameExample string) string { data := map[string]any{ "command": cf.CommandWithBinary(), "listCommand": strings.TrimSpace(listCommand), "kind": strings.TrimSpace(selectors.kind), "matchTable": strings.TrimSpace(matchTable), + "example": strings.TrimSpace(fullNameExample), } if !selectors.IsEmpty() { data["selectors"] = strings.TrimSpace(selectors.String()) @@ -1642,7 +1721,7 @@ multiple {{ .kind }}s are available: {{ .matchTable }} Hint: use '{{ .listCommand }} -v' or '{{ .listCommand }} --format=[json|yaml]' to list all {{ .kind }}s with full details. -Hint: try selecting the {{ .kind }} with a more specific name (ex: {{ .command }} full-{{ .kind }}-name). +Hint: try selecting the {{ .kind }} with a more specific name (ex: {{ .command }} {{ .example }}). Hint: try selecting the {{ .kind }} with additional --labels or --query predicate. 
`)) ) diff --git a/tool/tsh/db_test.go b/tool/tsh/db_test.go index c6f9109cfad85..9ae4903b9e94f 100644 --- a/tool/tsh/db_test.go +++ b/tool/tsh/db_test.go @@ -58,8 +58,7 @@ func TestTshDB(t *testing.T) { testenv.WithInsecureDevMode(t, true) t.Run("Login", testDatabaseLogin) t.Run("List", testListDatabase) - t.Run("FilterActiveDatabases", testFilterActiveDatabases) - t.Run("DatabaseInfo", testDatabaseInfo) + t.Run("DatabaseSelection", testDatabaseSelection) } // testDatabaseLogin tests "tsh db login" command and verifies "tsh db @@ -83,13 +82,6 @@ func testDatabaseLogin(t *testing.T) { cfg.Databases.Enabled = true cfg.Databases.Databases = []servicecfg.Database{ { - Name: "postgres", - Protocol: defaults.ProtocolPostgres, - URI: "localhost:5432", - StaticLabels: map[string]string{ - "env": "local", - }, - }, { Name: "postgres-rds-us-west-1-123456789012", Protocol: defaults.ProtocolPostgres, URI: "localhost:5432", @@ -105,22 +97,6 @@ func testDatabaseLogin(t *testing.T) { InstanceID: "postgres", }, }, - }, { - Name: "postgres-rds-us-west-2-123456789012", - Protocol: defaults.ProtocolPostgres, - URI: "localhost:5432", - StaticLabels: map[string]string{ - types.DiscoveredNameLabel: "postgres", - "region": "us-west-2", - "env": "prod", - }, - AWS: servicecfg.DatabaseAWS{ - AccountID: "123456789012", - Region: "us-west-2", - RDS: servicecfg.DatabaseAWSRDS{ - InstanceID: "postgres", - }, - }, }, { Name: "mysql", Protocol: defaults.ProtocolMySQL, @@ -214,17 +190,20 @@ func testDatabaseLogin(t *testing.T) { expectErrForEnvCmd: true, // "tsh db env" not supported for DynamoDB. }, { - name: "postgres", - databaseName: "postgres", - // the full db name is also a prefix of other dbs, but a full name - // match should take precedence over prefix matches. 
+ name: "by full name", + databaseName: "postgres-rds-us-west-1-123456789012", + expectCertsLen: 1, + }, + { + name: "by discovered name", + databaseName: "postgres-rds-us-west-1-123456789012", dbSelectors: []string{"postgres"}, expectCertsLen: 1, }, { name: "by labels", - databaseName: "postgres", - dbSelectors: []string{"--labels", "env=local"}, + databaseName: "postgres-rds-us-west-1-123456789012", + dbSelectors: []string{"--labels", "region=us-west-1"}, expectCertsLen: 1, }, { @@ -233,12 +212,6 @@ func testDatabaseLogin(t *testing.T) { dbSelectors: []string{"--query", `labels.env=="prod" && labels.region == "us-west-1"`}, expectCertsLen: 1, }, - { - name: "by prefix name", - databaseName: "postgres-rds-us-west-2-123456789012", - dbSelectors: []string{"postgres-rds-us-west-2"}, - expectCertsLen: 1, - }, } // Note: keystore currently races when multiple tsh clients work in the @@ -529,6 +502,18 @@ func testListDatabase(t *testing.T) { require.Contains(t, captureStdout.String(), "leaf-postgres") } +func TestFormatDatabaseLoginCommand(t *testing.T) { + t.Parallel() + + t.Run("default", func(t *testing.T) { + require.Equal(t, "tsh db login", formatDatabaseLoginCommand("")) + }) + + t.Run("with cluster flag", func(t *testing.T) { + require.Equal(t, "tsh db login --cluster=leaf", formatDatabaseLoginCommand("leaf")) + }) +} + func TestFormatDatabaseListCommand(t *testing.T) { t.Parallel() @@ -793,357 +778,51 @@ func TestFormatDatabaseConnectArgs(t *testing.T) { } } -func testFilterActiveDatabases(t *testing.T) { +func TestResourceSelectors(t *testing.T) { t.Parallel() - // setup some databases and "active" routes to test filtering - - // databases that all have a name starting with with "foo" - fooDB1, fooRoute1 := makeDBConfigAndRoute("foo", map[string]string{"env": "dev", "svc": "fooer"}) - fooDB2, fooRoute2 := makeDBConfigAndRoute("foo-us-west-1-123456789012", map[string]string{"env": "prod", "region": "us-west-1"}) - fooDB3, fooRoute3 := 
makeDBConfigAndRoute("foo-westus-11111", map[string]string{"env": "prod", "region": "westus"}) - - // databases that all have a name starting with with "bar" - barDB1, barRoute1 := makeDBConfigAndRoute("bar", map[string]string{"env": "dev", "svc": "barrer"}) - barDB2, barRoute2 := makeDBConfigAndRoute("bar-us-west-1-123456789012", map[string]string{"env": "prod", "region": "us-west-1"}) - - // databases that all have a name starting with with "baz" - bazDB1, bazRoute1 := makeDBConfigAndRoute("baz", map[string]string{"env": "dev", "svc": "bazzer"}) - bazDB2, bazRoute2 := makeDBConfigAndRoute("baz2", map[string]string{"env": "prod", "svc": "bazzer"}) - routes := []tlsca.RouteToDatabase{ - fooRoute1, fooRoute2, fooRoute3, - barRoute1, barRoute2, - bazRoute1, bazRoute2, - } - s := newTestSuite(t, - withRootConfigFunc(func(cfg *servicecfg.Config) { - cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) - cfg.Databases.Enabled = true - cfg.Databases.Databases = []servicecfg.Database{ - fooDB1, fooDB2, fooDB3, - barDB1, barDB2, - bazDB1, bazDB2, - } - }), - ) - - // Log into Teleport cluster. 
- tmpHomePath, _ := mustLogin(t, s) + t.Run("formatting", testResourceSelectorsFormatting) + t.Run("IsEmpty", testResourceSelectorsIsEmpty) +} +func testResourceSelectorsIsEmpty(t *testing.T) { + t.Parallel() tests := []struct { - name, - dbNamePrefix, - labels, - query string - wantAPICall bool - overrideActiveRoutes []tlsca.RouteToDatabase - overrideAPIDatabasesCheckFn func(t *testing.T, databases types.Databases) - wantRoutes []tlsca.RouteToDatabase + desc string + selectors resourceSelectors + wantEmpty bool }{ { - name: "by exact name that is a prefix of others", - dbNamePrefix: fooRoute1.ServiceName, - wantAPICall: false, - wantRoutes: []tlsca.RouteToDatabase{fooRoute1}, - }, - { - name: "by exact name of inactive route that is a prefix of active routes", - dbNamePrefix: fooRoute1.ServiceName, - overrideActiveRoutes: []tlsca.RouteToDatabase{ - fooRoute2, fooRoute3, - barRoute1, barRoute2, - bazRoute1, bazRoute2, - }, - wantAPICall: true, - overrideAPIDatabasesCheckFn: func(t *testing.T, databases types.Databases) { - t.Helper() - require.NotNil(t, databases) - databasesByName := databases.ToMap() - require.Contains(t, databasesByName, fooRoute1.ServiceName) - require.Contains(t, databasesByName, fooRoute2.ServiceName) - require.Contains(t, databasesByName, fooRoute3.ServiceName) - }, - // the inactive route got filtered out, but active routes shouldn't - // have been matched by prefix either. 
- wantRoutes: nil, - }, - { - name: "by exact name that is not a prefix of others", - dbNamePrefix: fooRoute2.ServiceName, - wantAPICall: false, - wantRoutes: []tlsca.RouteToDatabase{fooRoute2}, + desc: "no fields set", + selectors: resourceSelectors{}, + wantEmpty: true, }, { - name: "by exact name that is a prefix of others with overlapping labels", - dbNamePrefix: bazRoute1.ServiceName, - labels: "svc=bazzer", - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{bazRoute1}, + desc: "kind field set", + selectors: resourceSelectors{kind: "x"}, + wantEmpty: true, }, { - name: "by name prefix", - dbNamePrefix: "ba", - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{barRoute1, barRoute2, bazRoute1, bazRoute2}, + desc: "name field set", + selectors: resourceSelectors{name: "x"}, }, { - name: "by labels", - labels: "env=dev", - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{fooRoute1, barRoute1, bazRoute1}, + desc: "labels field set", + selectors: resourceSelectors{labels: "x"}, }, { - name: "by query", - query: `labels.env == "dev"`, - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{fooRoute1, barRoute1, bazRoute1}, - }, - { - name: "by name prefix and labels", - dbNamePrefix: "fo", - labels: "env=prod", - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{fooRoute2, fooRoute3}, - }, - { - name: "by name prefix and query", - dbNamePrefix: "fo", - query: `labels.region == "us-west-1"`, - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{fooRoute2}, - }, - { - name: "by labels and query", - labels: "env=dev", - query: `hasPrefix(name, "baz")`, - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{bazRoute1}, - }, - { - name: "by name prefix and labels and query", - dbNamePrefix: "fo", - labels: "env=prod", - query: `labels.region == "westus"`, - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{fooRoute3}, + desc: "query field set", + selectors: resourceSelectors{query: "x"}, }, } - for _, tt := range tests { - 
tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - cf := &CLIConf{ - Context: ctx, - HomePath: tmpHomePath, - DatabaseService: tt.dbNamePrefix, - Labels: tt.labels, - PredicateExpression: tt.query, - } - tc, err := makeClient(cf) - require.NoError(t, err) - activeRoutes := routes - if tt.overrideActiveRoutes != nil { - activeRoutes = tt.overrideActiveRoutes - } - gotRoutes, dbs, err := filterActiveDatabases(ctx, tc, activeRoutes) - require.NoError(t, err) - require.Empty(t, cmp.Diff(tt.wantRoutes, gotRoutes)) - if tt.wantAPICall { - if tt.overrideAPIDatabasesCheckFn != nil { - tt.overrideAPIDatabasesCheckFn(t, dbs) - } else { - require.Equal(t, len(tt.wantRoutes), len(dbs), - "returned routes should have corresponding types.Databases") - for i := range tt.wantRoutes { - require.Equal(t, gotRoutes[i].ServiceName, dbs[i].GetName(), - "route %v does not match corresponding types.Database", i) - } - } - return - } - require.Zero(t, len(dbs), "unexpected API call to ListDatabases") + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + require.Equal(t, test.wantEmpty, test.selectors.IsEmpty()) }) } } -func testDatabaseInfo(t *testing.T) { +func testResourceSelectorsFormatting(t *testing.T) { t.Parallel() - alice, err := types.NewUser("alice@example.com") - require.NoError(t, err) - defaultDBUser := "admin" - defaultDBName := "default" - // add multiple allowed db names/users, to prevent default selection. - // these tests should use the db name/username from either cli flag or - // active cert only. 
- alice.SetDatabaseUsers([]string{defaultDBUser, "foo"}) - alice.SetDatabaseNames([]string{defaultDBName, "bar"}) - alice.SetRoles([]string{"access"}) - databases := []servicecfg.Database{ - { - Name: "postgres", - Protocol: defaults.ProtocolPostgres, - URI: "localhost:5432", - StaticLabels: map[string]string{ - "env": "local", - }, - }, { - Name: "postgres-2", - Protocol: defaults.ProtocolPostgres, - URI: "localhost:5432", - StaticLabels: map[string]string{ - "env": "local", - }, - }, { - Name: "mysql", - Protocol: defaults.ProtocolMySQL, - URI: "localhost:3306", - }, { - Name: "cassandra", - Protocol: defaults.ProtocolCassandra, - URI: "localhost:9042", - }, { - Name: "snowflake", - Protocol: defaults.ProtocolSnowflake, - URI: "localhost.snowflakecomputing.com", - }, { - Name: "mongo", - Protocol: defaults.ProtocolMongoDB, - URI: "localhost:27017", - }, { - Name: "mssql", - Protocol: defaults.ProtocolSQLServer, - URI: "localhost:1433", - }, { - Name: "dynamodb", - Protocol: defaults.ProtocolDynamoDB, - URI: "", // uri can be blank for DynamoDB, it will be derived from the region and requests. - AWS: servicecfg.DatabaseAWS{ - AccountID: "123456789012", - ExternalID: "123123123", - Region: "us-west-1", - }, - }} - s := newTestSuite(t, - withRootConfigFunc(func(cfg *servicecfg.Config) { - cfg.Auth.BootstrapResources = append(cfg.Auth.BootstrapResources, alice) - cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) - // separate MySQL port with TLS routing. - // set the public address to be sure even on v2+, tsh clients will see the separate port. - mySQLAddr := localListenerAddr() - cfg.Proxy.MySQLAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: mySQLAddr} - cfg.Proxy.MySQLPublicAddrs = []utils.NetAddr{{AddrNetwork: "tcp", Addr: mySQLAddr}} - cfg.Databases.Enabled = true - cfg.Databases.Databases = databases - }), - ) - s.user = alice - // Log into Teleport cluster. 
- tmpHomePath, _ := mustLogin(t, s) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - t.Run("newDatabaseInfo", func(t *testing.T) { - for _, db := range databases { - require.NotEmpty(t, db.Name) - require.NotEmpty(t, db.Protocol) - route := tlsca.RouteToDatabase{ - ServiceName: db.Name, - Protocol: db.Protocol, - Username: defaultDBUser, - Database: defaultDBName, - } - t.Run(route.ServiceName, func(t *testing.T) { - t.Run("with route", func(t *testing.T) { - cf := &CLIConf{ - Context: ctx, - TracingProvider: tracing.NoopProvider(), - HomePath: tmpHomePath, - tracer: tracing.NoopTracer(teleport.ComponentTSH), - } - tc, err := makeClient(cf) - require.NoError(t, err) - dbInfo, err := newDatabaseInfo(cf, tc, &route) - require.NoError(t, err) - require.Nil(t, dbInfo.database, "with an active cert the database should not have been fetched") - db, err := dbInfo.GetDatabase(cf, tc) - require.NoError(t, err) - if route.Protocol == defaults.ProtocolDynamoDB { - // v13 specific. We remove the dynamodb schema name from the route since it's not supported. 
- require.Equal(t, route.ServiceName, dbInfo.ServiceName) - require.Equal(t, route.Protocol, dbInfo.Protocol) - require.Equal(t, route.Username, dbInfo.Username) - } else { - require.Equal(t, route, dbInfo.RouteToDatabase) - } - require.Equal(t, route.ServiceName, db.GetName()) - require.Equal(t, route.Protocol, db.GetProtocol()) - require.Equal(t, dbInfo.database, db, "database should have been fetched and cached") - }) - t.Run("without route", func(t *testing.T) { - err = Run(ctx, []string{"db", "login", route.ServiceName, - "--db-user", route.Username, - "--db-name", route.Database, - }, setHomePath(tmpHomePath)) - require.NoError(t, err) - cf := &CLIConf{ - Context: ctx, - TracingProvider: tracing.NoopProvider(), - HomePath: tmpHomePath, - tracer: tracing.NoopTracer(teleport.ComponentTSH), - DatabaseService: route.ServiceName, - } - tc, err := makeClient(cf) - require.NoError(t, err) - dbInfo, err := newDatabaseInfo(cf, tc, nil) - require.NoError(t, err) - require.NotNil(t, dbInfo.database, "without an active cert the database should have been fetched") - db, err := dbInfo.GetDatabase(cf, tc) - require.NoError(t, err) - if route.Protocol == defaults.ProtocolDynamoDB { - // v13 specific. We remove the dynamodb schema name from the route since it's not supported. - require.Equal(t, route.ServiceName, dbInfo.ServiceName) - require.Equal(t, route.Protocol, dbInfo.Protocol) - require.Equal(t, route.Username, dbInfo.Username) - } else { - require.Equal(t, route, dbInfo.RouteToDatabase) - } - require.Equal(t, route.ServiceName, db.GetName()) - require.Equal(t, route.Protocol, db.GetProtocol()) - require.Equal(t, dbInfo.database, db, "cached database should be the same") - }) - }) - } - }) - t.Run("getDatabaseInfo", func(t *testing.T) { - // login to "postgres-2" db. 
- err = Run(ctx, []string{"db", "login", "postgres-2"}, setHomePath(tmpHomePath)) - require.NoError(t, err) - cf := &CLIConf{ - Context: ctx, - HomePath: tmpHomePath, - // select the other db, "postgres", which was not logged into. - DatabaseService: "postgres", - // v13 specific: set the db name/username because it won't be - // set by default until v14+. - DatabaseUser: defaultDBUser, - DatabaseName: defaultDBName, - } - tc, err := makeClient(cf) - require.NoError(t, err) - dbInfo, err := getDatabaseInfo(cf, tc) - require.NoError(t, err) - require.NotNil(t, dbInfo) - // verify that the active login route for "postgres-2" was not used - // instead of fetching info for the "postgres" db. - require.Equal(t, "postgres", dbInfo.ServiceName) - require.Equal(t, defaults.ProtocolPostgres, dbInfo.Protocol) - require.Equal(t, defaultDBUser, dbInfo.Username) - require.Equal(t, defaultDBName, dbInfo.Database) - require.NotNil(t, dbInfo.database) - }) -} - -func TestResourceSelectorsFormatting(t *testing.T) { tests := []struct { testName string selectors resourceSelectors @@ -1207,7 +886,12 @@ func makeDBConfigAndRoute(name string, staticLabels map[string]string) (servicec URI: "localhost:5432", StaticLabels: staticLabels, } - route := tlsca.RouteToDatabase{ServiceName: name} + route := tlsca.RouteToDatabase{ + ServiceName: name, + Protocol: defaults.ProtocolPostgres, + Username: "alice", + Database: "postgres", + } return db, route } @@ -1237,6 +921,22 @@ func TestChooseOneDatabase(t *testing.T) { URI: "uri", }) require.NoError(t, err) + db3, err := types.NewDatabaseV3(types.Metadata{ + Name: "my-db-with-some-suffix", + Labels: map[string]string{"foo": "bar", types.DiscoveredNameLabel: "my-db"}, + }, types.DatabaseSpecV3{ + Protocol: "protocol", + URI: "uri", + }) + require.NoError(t, err) + db4, err := types.NewDatabaseV3(types.Metadata{ + Name: "my-db-with-some-other-suffix", + Labels: map[string]string{"foo": "bar", types.DiscoveredNameLabel: "my-db"}, + }, 
types.DatabaseSpecV3{ + Protocol: "protocol", + URI: "uri", + }) + require.NoError(t, err) tests := []struct { desc string databases types.Databases @@ -1253,6 +953,11 @@ func TestChooseOneDatabase(t *testing.T) { databases: types.Databases{db0, db1, db2}, wantDB: db0, }, + { + desc: "multiple databases to choose from with unambiguous discovered name match", + databases: types.Databases{db1, db2, db3}, + wantDB: db3, + }, { desc: "zero databases to choose from is an error", wantErrContains: `database "my-db" with labels "foo=bar" with query (hasPrefix(name, "my-db")) not found, use 'tsh db ls --cluster=local-site'`, @@ -1262,6 +967,11 @@ func TestChooseOneDatabase(t *testing.T) { databases: types.Databases{db1, db2}, wantErrContains: `database "my-db" with labels "foo=bar" with query (hasPrefix(name, "my-db")) matches multiple databases`, }, + { + desc: "ambiguous discovered name databases is an error", + databases: types.Databases{db3, db4}, + wantErrContains: `database "my-db" with labels "foo=bar" with query (hasPrefix(name, "my-db")) matches multiple databases`, + }, } ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -1270,11 +980,13 @@ func TestChooseOneDatabase(t *testing.T) { cf := &CLIConf{ Context: ctx, TracingProvider: tracing.NoopProvider(), + tracer: tracing.NoopTracer(teleport.ComponentTSH), + DatabaseService: "my-db", Labels: "foo=bar", PredicateExpression: `hasPrefix(name, "my-db")`, SiteName: "local-site", } - db, err := chooseOneDatabase(cf, "my-db", test.databases) + db, err := chooseOneDatabase(cf, test.databases) if test.wantErrContains != "" { require.ErrorContains(t, err, test.wantErrContains) return @@ -1285,3 +997,494 @@ func TestChooseOneDatabase(t *testing.T) { }) } } + +func TestMaybePickActiveDatabase(t *testing.T) { + t.Parallel() + x := tlsca.RouteToDatabase{ServiceName: "x"} + y := tlsca.RouteToDatabase{ServiceName: "y"} + z := tlsca.RouteToDatabase{ServiceName: "z"} + tests := []struct { + desc string + 
svcName, labels, query string + routes []tlsca.RouteToDatabase + wantRoute *tlsca.RouteToDatabase + wantErr string + }{ + { + desc: "does nothing if labels given", + routes: []tlsca.RouteToDatabase{x}, + svcName: "x", + labels: "env=dev", + }, + { + desc: "does nothing if query given", + svcName: "x", + routes: []tlsca.RouteToDatabase{x}, + query: `name == "x"`, + }, + { + desc: "picks an active route by name", + svcName: "y", + routes: []tlsca.RouteToDatabase{x, y, z}, + wantRoute: &y, + }, + { + desc: "does nothing if only unmatched name is given", + svcName: "y", + routes: []tlsca.RouteToDatabase{x, z}, + }, + { + desc: "picks the only active route without selectors", + routes: []tlsca.RouteToDatabase{x}, + wantRoute: &x, + }, + { + desc: "no routes and no selectors is an error", + routes: []tlsca.RouteToDatabase{}, + wantErr: "please login", + }, + { + desc: "many routes and no selectors is an error", + routes: []tlsca.RouteToDatabase{x, y, z}, + wantErr: "multiple databases", + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + cf := &CLIConf{ + DatabaseService: test.svcName, + Labels: test.labels, + PredicateExpression: test.query, + } + route, err := maybePickActiveDatabase(cf, test.routes) + if test.wantErr != "" { + require.ErrorContains(t, err, test.wantErr) + return + } + require.NoError(t, err) + require.Equal(t, test.wantRoute, route) + }) + } +} + +func TestFindActiveDatabase(t *testing.T) { + t.Parallel() + x := tlsca.RouteToDatabase{ServiceName: "x", Protocol: "postgres", Username: "alice", Database: "postgres"} + y := tlsca.RouteToDatabase{ServiceName: "y", Protocol: "postgres", Username: "alice", Database: "postgres"} + z := tlsca.RouteToDatabase{ServiceName: "z", Protocol: "postgres", Username: "alice", Database: "postgres"} + tests := []struct { + desc string + name string + routes []tlsca.RouteToDatabase + wantOK bool + wantRoute tlsca.RouteToDatabase + }{ + { + desc: "zero routes", + name: "x", + }, + { + desc: "no 
name with zero routes", + }, + { + desc: "no name with one route", + routes: []tlsca.RouteToDatabase{x}, + }, + { + desc: "no name with many routes", + routes: []tlsca.RouteToDatabase{x, y}, + }, + { + desc: "name in routes", + name: "x", + routes: []tlsca.RouteToDatabase{x, y}, + wantOK: true, + wantRoute: x, + }, + { + desc: "name not in routes", + name: "x", + routes: []tlsca.RouteToDatabase{y, z}, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + route, ok := findActiveDatabase(test.name, test.routes) + require.Equal(t, test.wantOK, ok) + require.Equal(t, test.wantRoute, route) + }) + } +} + +// testDatabaseSelection tests database selection by name, prefix name, labels, +// query, etc. +func testDatabaseSelection(t *testing.T) { + t.Parallel() + // setup some databases and "active" routes to test filtering + + // databases that all have a name starting with with "foo" + fooDB1, fooRoute1 := makeDBConfigAndRoute("foo", map[string]string{"env": "dev", "svc": "fooer"}) + fooRDSDB, fooRDSRoute := makeDBConfigAndRoute("foo-rds-us-west-1-123456789012", map[string]string{"env": "prod", "region": "us-west-1", types.DiscoveredNameLabel: "foo-rds"}) + fooRDSCustomDB, fooRDSCustomRoute := makeDBConfigAndRoute("foo-rds-custom-us-west-1-123456789012", map[string]string{"env": "prod", "region": "us-west-1", types.DiscoveredNameLabel: "foo-rds-custom"}) + // a route that isn't registered anymore, like when a user has logged into + // a db that isn't registered in the cluster anymore. 
+ _, staleRoute := makeDBConfigAndRoute("stale", map[string]string{"env": "dev", "svc": "fooer"}) + + // databases that all have a name starting with with "bar" + barRDSDB1, barRDSRoute1 := makeDBConfigAndRoute("bar-rds-us-west-1-123456789012", map[string]string{"env": "prod", "region": "us-west-1", types.DiscoveredNameLabel: "bar-rds"}) + barRDSDB2, barRDSRoute2 := makeDBConfigAndRoute("bar-rds-us-west-2-123456789012", map[string]string{"env": "prod", "region": "us-west-2", types.DiscoveredNameLabel: "bar-rds"}) + + activeRoutes := []tlsca.RouteToDatabase{ + fooRoute1, fooRDSRoute, fooRDSCustomRoute, staleRoute, + barRDSRoute1, barRDSRoute2, + } + + alice, err := types.NewUser("alice@example.com") + require.NoError(t, err) + alice.SetDatabaseUsers([]string{"alice", "bob"}) + alice.SetDatabaseNames([]string{"postgres", "other"}) + alice.SetRoles([]string{"access"}) + s := newTestSuite(t, + withRootConfigFunc(func(cfg *servicecfg.Config) { + cfg.Auth.BootstrapResources = append(cfg.Auth.BootstrapResources, alice) + cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) + cfg.Databases.Enabled = true + cfg.Databases.Databases = []servicecfg.Database{ + fooDB1, fooRDSDB, fooRDSCustomDB, + barRDSDB1, barRDSDB2, + } + }), + ) + s.user = alice + + // Log into Teleport cluster. 
+ tmpHomePath, _ := mustLogin(t, s) + + t.Run("GetDatabasesForLogout", func(t *testing.T) { + t.Parallel() + tests := []struct { + name, + svcName, + labels, + query string + wantRoutes []tlsca.RouteToDatabase + wantErr string + }{ + { + name: "by exact name", + svcName: fooRDSRoute.ServiceName, + wantRoutes: []tlsca.RouteToDatabase{fooRDSRoute}, + }, + { + name: "by exact discovered name", + svcName: "foo-rds", + wantRoutes: []tlsca.RouteToDatabase{fooRDSRoute}, + }, + { + name: "by labels", + labels: "region=us-west-2", + wantRoutes: []tlsca.RouteToDatabase{barRDSRoute2}, + }, + { + name: "by query", + query: `labels.region == "us-west-2"`, + wantRoutes: []tlsca.RouteToDatabase{barRDSRoute2}, + }, + { + name: "by exact name of unregistered database", + svcName: staleRoute.ServiceName, + wantRoutes: []tlsca.RouteToDatabase{staleRoute}, + }, + { + name: "by exact discovered name that is ambiguous", + svcName: "bar-rds", + wantErr: "matches multiple", + }, + { + name: "by exact discovered name with labels", + svcName: "bar-rds", + labels: "region=us-west-1", + wantRoutes: []tlsca.RouteToDatabase{barRDSRoute1}, + }, + { + name: "by exact discovered name with query", + svcName: "bar-rds", + query: `labels.region == "us-west-1"`, + wantRoutes: []tlsca.RouteToDatabase{barRDSRoute1}, + }, + { + name: "all", + wantRoutes: activeRoutes, + }, + } + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + cf := &CLIConf{ + Context: ctx, + HomePath: tmpHomePath, + DatabaseService: tt.svcName, + Labels: tt.labels, + PredicateExpression: tt.query, + } + tc, err := makeClient(cf) + require.NoError(t, err) + gotRoutes, err := getDatabasesForLogout(cf, tc, activeRoutes) + if tt.wantErr != "" { + require.ErrorContains(t, err, tt.wantErr) + return + } + require.NoError(t, err) + require.Empty(t, cmp.Diff(tt.wantRoutes, gotRoutes)) + }) + } + }) + + 
t.Run("GetDatabaseInfo", func(t *testing.T) { + t.Parallel() + tests := []struct { + desc string + svcName, labels, query string + dbUser, dbName string + activeRoutes []tlsca.RouteToDatabase + wantRoute tlsca.RouteToDatabase + wantActive bool + wantErr string + }{ + { + desc: "by exact name", + svcName: "foo", + dbUser: "alice", + dbName: "postgres", + wantRoute: fooRoute1, + }, + { + desc: "by exact name of active db", + svcName: "foo", + activeRoutes: []tlsca.RouteToDatabase{fooRoute1}, + wantRoute: fooRoute1, + wantActive: true, + }, + { + desc: "by exact name of active db overriding user and schema", + svcName: "foo", + dbUser: "bob", + dbName: "other", + activeRoutes: []tlsca.RouteToDatabase{fooRoute1}, + wantRoute: tlsca.RouteToDatabase{ServiceName: "foo", Protocol: "postgres", Username: "bob", Database: "other"}, + wantActive: true, + }, + { + desc: "by exact name that is a prefix of an active db", + svcName: "foo", + dbUser: "alice", + dbName: "postgres", + activeRoutes: []tlsca.RouteToDatabase{fooRDSRoute}, + wantRoute: fooRoute1, + }, + { + desc: "by exact discovered name", + svcName: "foo-rds", + dbUser: "alice", + dbName: "postgres", + wantRoute: fooRDSRoute, + }, + { + desc: "by labels", + labels: "env=dev,svc=fooer", + dbUser: "alice", + dbName: "postgres", + wantRoute: fooRoute1, + }, + { + desc: "by labels and active route", + labels: "env=dev,svc=fooer", + activeRoutes: []tlsca.RouteToDatabase{fooRoute1}, + wantRoute: fooRoute1, + wantActive: true, + }, + { + desc: "by query", + query: `name=="foo" && labels.env=="dev" && labels.svc=="fooer"`, + dbUser: "alice", + dbName: "postgres", + wantRoute: fooRoute1, + }, + { + desc: "by query and active route", + query: `name == "foo" && labels.env=="dev" && labels.svc=="fooer"`, + activeRoutes: []tlsca.RouteToDatabase{fooRoute1}, + wantRoute: fooRoute1, + wantActive: true, + }, + { + desc: "by ambiguous exact discovered name", + svcName: "bar-rds", + wantErr: "matches multiple databases", + }, + { + desc: 
"resolves ambiguous exact discovered name by label", + svcName: "bar-rds", + labels: "region=us-west-1", + dbUser: "alice", + dbName: "postgres", + wantRoute: barRDSRoute1, + }, + { + desc: "resolves ambiguous exact discovered name by query", + svcName: "bar-rds", + query: `labels.region=="us-west-2"`, + dbUser: "alice", + dbName: "postgres", + wantRoute: barRDSRoute2, + }, + { + desc: "by name of db that does not exist", + svcName: "foo-rds-", + wantErr: `"foo-rds-" not found, use 'tsh db ls' to see registered databases`, + }, + { + desc: "by name of db that does not exist and is not active", + svcName: "foo-rds-", + activeRoutes: []tlsca.RouteToDatabase{fooRDSRoute}, + wantErr: `"foo-rds-" not found, use 'tsh db ls' to see registered databases`, + }, + { + desc: "by ambiguous labels", + labels: "region=us-west-1", + wantErr: "matches multiple databases", + }, + { + desc: "by ambiguous query", + query: `labels.region == "us-west-1"`, + wantErr: "matches multiple databases", + }, + { + desc: "by exact name of unregistered database", + svcName: staleRoute.ServiceName, + activeRoutes: []tlsca.RouteToDatabase{staleRoute}, + wantErr: `you are logged into a database that no longer exists in the cluster`, + }, + // cases without selectors should try choose to from active databases + { + desc: "no selectors with one active registered db", + activeRoutes: []tlsca.RouteToDatabase{fooRDSRoute}, + wantRoute: fooRDSRoute, + wantActive: true, + }, + { + desc: "no selectors with zero active registered db", + activeRoutes: []tlsca.RouteToDatabase{staleRoute}, + wantErr: `you are logged into a database that no longer exists in the cluster`, + }, + { + desc: "no selectors with multiple active registered db", + activeRoutes: []tlsca.RouteToDatabase{fooRoute1, fooRDSRoute}, + wantErr: "multiple databases are available", + }, + } + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + for _, test := range tests { + test := test + t.Run(test.desc, func(t 
*testing.T) { + t.Parallel() + cf := &CLIConf{ + Context: ctx, + HomePath: tmpHomePath, + DatabaseService: test.svcName, + Labels: test.labels, + PredicateExpression: test.query, + DatabaseUser: test.dbUser, + DatabaseName: test.dbName, + } + tc, err := makeClient(cf) + require.NoError(t, err) + info, err := getDatabaseInfo(cf, tc, test.activeRoutes) + if test.wantErr != "" { + require.ErrorContains(t, err, test.wantErr) + return + } + require.NoError(t, err) + require.Equal(t, test.wantRoute, info.RouteToDatabase) + db, err := info.GetDatabase(cf.Context, tc) + require.NoError(t, err) + require.Equal(t, info.ServiceName, db.GetName()) + require.Equal(t, info.Protocol, db.GetProtocol()) + require.Equal(t, db, info.database, "database should have been fetched and cached") + require.Equal(t, test.wantActive, info.isActive) + }) + } + }) + + t.Run("PickActiveDatabase", func(t *testing.T) { + t.Parallel() + tests := []struct { + desc string + activeRoutes []tlsca.RouteToDatabase + dbName string + wantRoute tlsca.RouteToDatabase + wantErr string + }{ + { + desc: "pick active db without selector", + activeRoutes: []tlsca.RouteToDatabase{barRDSRoute1}, + wantRoute: barRDSRoute1, + }, + { + desc: "pick active db with discovered name selector", + activeRoutes: []tlsca.RouteToDatabase{fooRDSRoute, barRDSRoute1}, + dbName: "foo-rds", + wantRoute: fooRDSRoute, + }, + { + desc: "pick active db with exact name selector", + activeRoutes: []tlsca.RouteToDatabase{fooRDSRoute, barRDSRoute1}, + dbName: fooRDSRoute.ServiceName, + wantRoute: fooRDSRoute, + }, + { + desc: "pick inactive db with selector", + dbName: "foo-rds", + activeRoutes: []tlsca.RouteToDatabase{barRDSRoute1}, + wantErr: `not logged into database "foo-rds"`, + }, + { + desc: "no active db", + activeRoutes: []tlsca.RouteToDatabase{}, + wantErr: "please login using 'tsh db login' first", + }, + { + desc: "multiple active db without selector", + activeRoutes: []tlsca.RouteToDatabase{fooRDSRoute, barRDSRoute1}, + 
wantErr: "multiple databases are available", + }, + } + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + cf := &CLIConf{ + Context: ctx, + HomePath: tmpHomePath, + DatabaseService: test.dbName, + } + tc, err := makeClient(cf) + require.NoError(t, err) + route, err := pickActiveDatabase(cf, tc, test.activeRoutes) + if test.wantErr != "" { + require.ErrorContains(t, err, test.wantErr) + return + } + require.NoError(t, err) + require.NotNil(t, route) + require.Equal(t, test.wantRoute, *route) + }) + } + }) +} diff --git a/tool/tsh/kube.go b/tool/tsh/kube.go index 121e8db93730b..06728a73c6951 100644 --- a/tool/tsh/kube.go +++ b/tool/tsh/kube.go @@ -53,15 +53,19 @@ import ( "k8s.io/kubectl/pkg/util/term" "github.com/gravitational/teleport" + apiclient "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keypaths" "github.com/gravitational/teleport/lib/asciitable" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/kube/kubeconfig" kubeutils "github.com/gravitational/teleport/lib/kube/utils" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/tool/common" @@ -974,52 +978,74 @@ func (c *kubeLSCommand) run(cf *CLIConf) error { return trace.Wrap(err) } - selectedCluster := selectedKubeCluster(currentTeleportCluster) + // Ignore errors from fetching the current cluster, since it's not + // mandatory to have a cluster selected or even to have a 
kubeconfig file. + selectedCluster, _ := kubeconfig.SelectedKubeCluster(getKubeConfigPath(cf, ""), currentTeleportCluster) + err = c.showKubeClusters(cf.Stdout(), kubeClusters, selectedCluster) + return trace.Wrap(err) +} + +func (c *kubeLSCommand) showKubeClusters(w io.Writer, kubeClusters types.KubeClusters, selectedCluster string) error { format := strings.ToLower(c.format) switch format { case teleport.Text, "": - var ( - t asciitable.Table - columns = []string{"Kube Cluster Name", "Labels", "Selected"} - rows [][]string - ) - - for _, cluster := range kubeClusters { - var selectedMark string - if cluster.GetName() == selectedCluster { - selectedMark = "*" - } - rows = append(rows, []string{ - cluster.GetName(), - common.FormatLabels(cluster.GetAllLabels(), c.verbose), - selectedMark, - }) - } - - if c.quiet { - t = asciitable.MakeHeadlessTable(2) - for _, row := range rows { - t.AddRow(row[:2]) - } - } else if c.verbose { - t = asciitable.MakeTable(columns, rows...) - } else { - t = asciitable.MakeTableWithTruncatedColumn(columns, rows, "Labels") - } - fmt.Fprintln(cf.Stdout(), t.AsBuffer().String()) + out := formatKubeClustersAsText(kubeClusters, selectedCluster, c.quiet, c.verbose) + fmt.Fprintln(w, out) case teleport.JSON, teleport.YAML: + sort.Sort(kubeClusters) out, err := serializeKubeClusters(kubeClusters, selectedCluster, format) if err != nil { return trace.Wrap(err) } - fmt.Fprintln(cf.Stdout(), out) + fmt.Fprintln(w, out) default: - return trace.BadParameter("unsupported format %q", cf.Format) + return trace.BadParameter("unsupported format %q", c.format) } - return nil } +func getKubeClusterTextRow(kc types.KubeCluster, selectedCluster string, verbose bool) []string { + var selectedMark string + var row []string + if selectedCluster != "" && kc.GetName() == selectedCluster { + selectedMark = "*" + } + displayName := common.FormatResourceName(kc, verbose) + labels := common.FormatLabels(kc.GetAllLabels(), verbose) + row = append(row, displayName, 
labels, selectedMark) + return row +} + +func formatKubeClustersAsText(kubeClusters types.KubeClusters, selectedCluster string, quiet, verbose bool) string { + var ( + columns = []string{"Kube Cluster Name", "Labels", "Selected"} + t asciitable.Table + rows [][]string + ) + + for _, cluster := range kubeClusters { + r := getKubeClusterTextRow(cluster, selectedCluster, verbose) + rows = append(rows, r) + } + + switch { + case quiet: + // no column headers and only include the cluster name and labels. + t = asciitable.MakeHeadlessTable(2) + for _, row := range rows { + t.AddRow(row) + } + case verbose: + t = asciitable.MakeTable(columns, rows...) + default: + t = asciitable.MakeTableWithTruncatedColumn(columns, rows, "Labels") + } + + // stable sort by kube cluster name. + t.SortRowsBy([]int{0}, true) + return t.AsBuffer().String() +} + func serializeKubeClusters(kubeClusters []types.KubeCluster, selectedCluster, format string) (string, error) { type cluster struct { KubeClusterName string `json:"kube_cluster_name"` @@ -1077,27 +1103,13 @@ func (c *kubeLSCommand) runAllClusters(cf *CLIConf) error { return trace.Wrap(err) } - sort.Sort(listings) - format := strings.ToLower(c.format) switch format { case teleport.Text, "": - var t asciitable.Table - if cf.Quiet { - t = asciitable.MakeHeadlessTable(3) - } else { - t = asciitable.MakeTable([]string{"Proxy", "Cluster", "Kube Cluster Name", "Labels"}) - } - for _, listing := range listings { - t.AddRow([]string{ - listing.Proxy, - listing.Cluster, - listing.KubeCluster.GetName(), - common.FormatLabels(listing.KubeCluster.GetAllLabels(), c.verbose), - }) - } - fmt.Fprintln(cf.Stdout(), t.AsBuffer().String()) + out := formatKubeListingsAsText(listings, c.quiet, c.verbose) + fmt.Fprintln(cf.Stdout(), out) case teleport.JSON, teleport.YAML: + sort.Sort(listings) out, err := serializeKubeListings(listings, format) if err != nil { return trace.Wrap(err) @@ -1110,6 +1122,37 @@ func (c *kubeLSCommand) runAllClusters(cf *CLIConf) 
error { return nil } +func formatKubeListingsAsText(listings kubeListings, quiet, verbose bool) string { + var ( + columns = []string{"Proxy", "Cluster", "Kube Cluster Name", "Labels"} + t asciitable.Table + rows [][]string + ) + for _, listing := range listings { + r := append([]string{ + listing.Proxy, + listing.Cluster, + }, getKubeClusterTextRow(listing.KubeCluster, "", verbose)...) + rows = append(rows, r) + } + + switch { + case quiet: + // quiet, so no column headers. + t = asciitable.MakeHeadlessTable(4) + for _, row := range rows { + t.AddRow(row) + } + case verbose: + t = asciitable.MakeTable(columns, rows...) + default: + t = asciitable.MakeTableWithTruncatedColumn(columns, rows, "Labels") + } + // stable sort by proxy, then cluster, then kube cluster name. + t.SortRowsBy([]int{0, 1, 2}, true) + return t.AsBuffer().String() +} + func serializeKubeListings(kubeListings []kubeListing, format string) (string, error) { var out []byte var err error @@ -1121,24 +1164,19 @@ func serializeKubeListings(kubeListings []kubeListing, format string) (string, e return string(out), trace.Wrap(err) } -func selectedKubeCluster(currentTeleportCluster string) string { - kc, err := kubeconfig.Load("") - if err != nil { - log.WithError(err).Warning("Failed parsing existing kubeconfig") - return "" - } - return kubeconfig.KubeClusterFromContext(kc.CurrentContext, currentTeleportCluster) -} - type kubeLoginCommand struct { *kingpin.CmdClause - kubeCluster string - siteName string - impersonateUser string - impersonateGroups []string - namespace string - all bool - overrideContextName string + kubeCluster string + siteName string + impersonateUser string + impersonateGroups []string + namespace string + all bool + overrideContextName string + disableAccessRequest bool + requestReason string + labels string + predicateExpression string } func newKubeLoginCommand(parent *kingpin.CmdClause) *kubeLoginCommand { @@ -1147,92 +1185,235 @@ func newKubeLoginCommand(parent 
*kingpin.CmdClause) *kubeLoginCommand { } c.Flag("cluster", clusterHelp).Short('c').StringVar(&c.siteName) c.Arg("kube-cluster", "Name of the Kubernetes cluster to login to. Check 'tsh kube ls' for a list of available clusters.").StringVar(&c.kubeCluster) + c.Flag("labels", labelHelp).StringVar(&c.labels) + c.Flag("query", queryHelp).StringVar(&c.predicateExpression) c.Flag("as", "Configure custom Kubernetes user impersonation.").StringVar(&c.impersonateUser) c.Flag("as-groups", "Configure custom Kubernetes group impersonation.").StringsVar(&c.impersonateGroups) // TODO (tigrato): move this back to namespace once teleport drops the namespace flag. c.Flag("kube-namespace", "Configure the default Kubernetes namespace.").Short('n').StringVar(&c.namespace) - c.Flag("all", "Generate a kubeconfig with every cluster the user has access to.").BoolVar(&c.all) + c.Flag("all", "Generate a kubeconfig with every cluster the user has access to. Mutually exclusive with --labels or --query.").BoolVar(&c.all) c.Flag("set-context-name", "Define a custom context name. To use it with --all include \"{{.KubeName}}\""). // Use the default context name template if --set-context-name is not set. // This works as an hint to the user that the context name can be customized. Default(kubeconfig.ContextName("{{.ClusterName}}", "{{.KubeName}}")). StringVar(&c.overrideContextName) + c.Flag("request-reason", "Reason for requesting access").StringVar(&c.requestReason) + c.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&c.disableAccessRequest) return c } func (c *kubeLoginCommand) run(cf *CLIConf) error { - if c.kubeCluster == "" && !c.all { + switch { + case c.all && (c.labels != "" || c.predicateExpression != ""): + return trace.BadParameter("cannot use --labels or --query with --all") + case !c.all && c.getSelectors().IsEmpty(): return trace.BadParameter("kube-cluster name is required. 
Check 'tsh kube ls' for a list of available clusters.") } - // If --all and --set-context-name are set, ensure that the template is valid - // and can produce distinct context names for each cluster before proceeding. - if err := kubeconfig.CheckContextOverrideTemplate(c.overrideContextName); err != nil && c.all { - return trace.Wrap(err) + // If --all, --query, or --labels and --set-context-name are set, ensure + // that the template is valid and can produce distinct context names for + // each cluster before proceeding. + if c.all || c.labels != "" || c.predicateExpression != "" { + err := kubeconfig.CheckContextOverrideTemplate(c.overrideContextName) + if err != nil { + return trace.Wrap(err) + } } - // Set CLIConf.KubernetesCluster so that the kube cluster's context is automatically selected. - cf.KubernetesCluster = c.kubeCluster + // NOTE: in case relogin-retry logic is used, we want to avoid having + // cf.KubernetesCluster set because kube cluster selection by prefix name is + // not supported in that flow + // (it's equivalent to tsh login --kube-cluster=). + // We will set that flag later, after listing the kube clusters and choosing + // one by prefix/labels/query (if a cluster name/prefix was given). + cf.Labels = c.labels + cf.PredicateExpression = c.predicateExpression + cf.SiteName = c.siteName cf.kubernetesImpersonationConfig = impersonationConfig{ kubernetesUser: c.impersonateUser, kubernetesGroups: c.impersonateGroups, } cf.kubeNamespace = c.namespace + cf.disableAccessRequest = c.disableAccessRequest + cf.RequestReason = c.requestReason cf.ListAll = c.all - tc, err := makeClient(cf) if err != nil { return trace.Wrap(err) } - // Check that this kube cluster exists. - currentTeleportCluster, kubeClusters, err := fetchKubeClusters(cf.Context, tc) + + var kubeStatus *kubernetesStatus + err = retryWithAccessRequest(cf, tc, func() error { + // Check that this kube cluster exists. 
+ var err error + kubeStatus, err = fetchKubeStatus(cf.Context, tc) + if err != nil { + return trace.Wrap(err) + } + err = c.checkClusterSelection(cf, tc, kubeStatus.kubeClusters) + if err != nil { + if trace.IsNotFound(err) { + // rewrap not found error as access denied, so we can retry + // fetching clusters with an access request. + return trace.AccessDenied(err.Error()) + } + return trace.Wrap(err) + } + return nil + }, c.accessRequestForKubeCluster, c.selectorsOrWildcard()) if err != nil { return trace.Wrap(err) } - clusterNames := kubeClustersToStrings(kubeClusters) - // If the user is trying to login to a specific cluster, check that it exists. - if c.kubeCluster != "" && !slices.Contains(clusterNames, c.kubeCluster) { - return trace.NotFound("kubernetes cluster %q not found, check 'tsh kube ls' for a list of known clusters", c.kubeCluster) - } // Update default kubeconfig file located at ~/.kube/config or the value of // KUBECONFIG env var even if the context exists. - if err := updateKubeConfig(cf, tc, "", c.overrideContextName); err != nil { + if err := updateKubeConfig(cf, tc, "", c.overrideContextName, kubeStatus); err != nil { return trace.Wrap(err) } // Generate a profile specific kubeconfig which can be used // by setting the kubeconfig environment variable (with `tsh env`) profileKubeconfigPath := keypaths.KubeConfigPath( - profile.FullProfilePath(cf.HomePath), tc.WebProxyHost(), tc.Username, currentTeleportCluster, c.kubeCluster, + profile.FullProfilePath(cf.HomePath), tc.WebProxyHost(), tc.Username, kubeStatus.teleportClusterName, c.kubeCluster, ) - if err := updateKubeConfig(cf, tc, profileKubeconfigPath, c.overrideContextName); err != nil { + if err := updateKubeConfig(cf, tc, profileKubeconfigPath, c.overrideContextName, kubeStatus); err != nil { return trace.Wrap(err) } - c.printUserMessage(cf, tc) + c.printUserMessage(cf, tc, kubeClustersToStrings(kubeStatus.kubeClusters)) + return nil +} + +func (c *kubeLoginCommand) selectorsOrWildcard() 
string { + selectors := c.getSelectors() + if !selectors.IsEmpty() { + return selectors.String() + } + if c.all { + return "*" + } + return "" +} + +// checkClusterSelection checks the kube clusters selected by user input. +func (c *kubeLoginCommand) checkClusterSelection(cf *CLIConf, tc *client.TeleportClient, clusters types.KubeClusters) error { + clusters = matchClustersByName(c.kubeCluster, clusters) + err := checkClusterSelection(cf, clusters, c.kubeCluster) + if err != nil { + return trace.Wrap(err) + } + if c.kubeCluster != "" && len(clusters) == 1 { + // Populate settings using the selected kube cluster, which contains + // the full cluster name. + c.kubeCluster = clusters[0].GetName() + // Set CLIConf.KubernetesCluster so that the kube cluster's context + // is automatically selected. + cf.KubernetesCluster = c.kubeCluster + tc.KubernetesCluster = c.kubeCluster + } return nil } -func (c *kubeLoginCommand) printUserMessage(cf *CLIConf, tc *client.TeleportClient) { +func checkClusterSelection(cf *CLIConf, clusters types.KubeClusters, name string) error { + switch { + // If the user is trying to login to a specific cluster, check that it + // exists and that a cluster matched the name/prefix unambiguously. + case name != "" && len(clusters) == 1: + return nil + // allow multiple selection without a name. + case name == "" && len(clusters) > 0: + return nil + } + + // anything else is an error. 
+ selectors := resourceSelectors{ + kind: "kubernetes cluster", + name: name, + labels: cf.Labels, + query: cf.PredicateExpression, + } + if len(clusters) == 0 { + return trace.NotFound(formatKubeNotFound(cf.SiteName, selectors)) + } + errMsg := formatAmbiguousKubeCluster(cf, selectors, clusters) + return trace.BadParameter(errMsg) +} + +func (c *kubeLoginCommand) getSelectors() resourceSelectors { + return resourceSelectors{ + kind: "kubernetes cluster", + name: c.kubeCluster, + labels: c.labels, + query: c.predicateExpression, + } +} + +func matchClustersByName(nameOrPrefix string, clusters types.KubeClusters) types.KubeClusters { + if nameOrPrefix == "" { + return clusters + } + + // look for an exact full name match. + for _, kc := range clusters { + if kc.GetName() == nameOrPrefix { + return types.KubeClusters{kc} + } + } + + // or look for exact "discovered name" matches. + if clusters, ok := findKubeClustersByDiscoveredName(clusters, nameOrPrefix); ok { + return clusters + } + + // or just filter by prefix. + var prefixMatches types.KubeClusters + for _, kc := range clusters { + if strings.HasPrefix(kc.GetName(), nameOrPrefix) { + prefixMatches = append(prefixMatches, kc) + } + } + return prefixMatches +} + +func findKubeClustersByDiscoveredName(clusters types.KubeClusters, name string) (types.KubeClusters, bool) { + var out types.KubeClusters + for _, kc := range clusters { + discoveredName, ok := kc.GetLabel(types.DiscoveredNameLabel) + if ok && discoveredName == name { + out = append(out, kc) + } + } + return out, len(out) > 0 +} + +func (c *kubeLoginCommand) printUserMessage(cf *CLIConf, tc *client.TeleportClient, names []string) { if tc.Profile().RequireKubeLocalProxy() { - c.printLocalProxyUserMessage(cf) + c.printLocalProxyUserMessage(cf, names) return } - if c.kubeCluster != "" { + switch { + case c.kubeCluster != "": fmt.Fprintf(cf.Stdout(), "Logged into Kubernetes cluster %q. 
Try 'kubectl version' to test the connection.\n", c.kubeCluster) - } else { + case c.labels != "" || c.predicateExpression != "": + fmt.Fprintf(cf.Stdout(), `Logged into Kubernetes clusters: +%v + +Select a context and try 'kubectl version' to test the connection. +`, strings.Join(names, "\n")) + case c.all: fmt.Fprintf(cf.Stdout(), "Created kubeconfig with every Kubernetes cluster available. Select a context and try 'kubectl version' to test the connection.\n") } } -func (c *kubeLoginCommand) printLocalProxyUserMessage(cf *CLIConf) { +func (c *kubeLoginCommand) printLocalProxyUserMessage(cf *CLIConf, names []string) { switch { case c.kubeCluster != "": fmt.Fprintf(cf.Stdout(), `Logged into Kubernetes cluster %q.`, c.kubeCluster) - - default: + case c.labels != "" || c.predicateExpression != "": + fmt.Fprintf(cf.Stdout(), `Logged into Kubernetes clusters: +%v`, strings.Join(names, "\n")) + case c.all: fmt.Fprintf(cf.Stdout(), "Logged into all Kubernetes clusters available.") } @@ -1356,12 +1537,14 @@ func buildKubeConfigUpdate(cf *CLIConf, kubeStatus *kubernetesStatus, overrideCo clusterNames := kubeClustersToStrings(kubeStatus.kubeClusters) // Validate if cf.KubernetesCluster is part of the returned list of clusters - if cf.KubernetesCluster != "" && !slices.Contains(clusterNames, cf.KubernetesCluster) { - return nil, trace.NotFound("Kubernetes cluster %q is not registered in this Teleport cluster; you can list registered Kubernetes clusters using 'tsh kube ls'.", cf.KubernetesCluster) - } - // If ListAll is not enabled, update only cf.KubernetesCluster cluster. 
- if cf.KubernetesCluster != "" && !cf.ListAll { - clusterNames = []string{cf.KubernetesCluster} + if cf.KubernetesCluster != "" { + if !slices.Contains(clusterNames, cf.KubernetesCluster) { + return nil, trace.NotFound("Kubernetes cluster %q is not registered in this Teleport cluster; you can list registered Kubernetes clusters using 'tsh kube ls'.", cf.KubernetesCluster) + } + // If ListAll or labels/query is not enabled, update only cf.KubernetesCluster cluster. + if !cf.ListAll && cf.Labels == "" && cf.PredicateExpression == "" { + clusterNames = []string{cf.KubernetesCluster} + } } v.KubeClusters = clusterNames @@ -1389,7 +1572,7 @@ type impersonationConfig struct { // updateKubeConfig adds Teleport configuration to the users's kubeconfig based on the CLI // parameters and the kubernetes services in the current Teleport cluster. If no path for // the kubeconfig is given, it will use environment values or known defaults to get a path. -func updateKubeConfig(cf *CLIConf, tc *client.TeleportClient, path string, overrideContext string) error { +func updateKubeConfig(cf *CLIConf, tc *client.TeleportClient, path, overrideContext string, status *kubernetesStatus) error { // Fetch proxy's advertised ports to check for k8s support. 
if _, err := tc.Ping(cf.Context); err != nil { return trace.Wrap(err) @@ -1399,16 +1582,11 @@ func updateKubeConfig(cf *CLIConf, tc *client.TeleportClient, path string, overr return nil } - kubeStatus, err := fetchKubeStatus(cf.Context, tc) - if err != nil { - return trace.Wrap(err) - } - if cf.Proxy == "" { cf.Proxy = tc.WebProxyAddr } - values, err := buildKubeConfigUpdate(cf, kubeStatus, overrideContext) + values, err := buildKubeConfigUpdate(cf, status, overrideContext) if err != nil { return trace.Wrap(err) } @@ -1458,3 +1636,89 @@ func init() { clientauthv1beta1.AddToScheme(kubeScheme) clientauthentication.AddToScheme(kubeScheme) } + +// accessRequestForKubeCluster attempts to create a resource access request for the case +// where "tsh kube login" was attempted and access was denied +func (c *kubeLoginCommand) accessRequestForKubeCluster(ctx context.Context, cf *CLIConf, tc *client.TeleportClient) (types.AccessRequest, error) { + clt, err := tc.ConnectToCluster(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + defer clt.Close() + + kubes, err := apiclient.GetAllResources[types.KubeCluster](ctx, clt.AuthClient, &proto.ListResourcesRequest{ + Namespace: apidefaults.Namespace, + ResourceType: types.KindKubernetesCluster, + UseSearchAsRoles: true, + PredicateExpression: tc.PredicateExpression, + Labels: tc.Labels, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + if err := c.checkClusterSelection(cf, tc, kubes); err != nil { + return nil, trace.Wrap(err) + } + + requestResourceIDs := make([]types.ResourceID, len(kubes)) + for i, kube := range kubes { + requestResourceIDs[i] = types.ResourceID{ + ClusterName: tc.SiteName, + Kind: types.KindKubernetesCluster, + Name: kube.GetName(), + } + } + + // Roles to request will be automatically determined on the backend. 
+ req, err := services.NewAccessRequestWithResources(tc.Username, nil /* roles */, requestResourceIDs) + if err != nil { + return nil, trace.Wrap(err) + } + + // Set the DryRun flag and send the request to auth for full validation. If + // the user has no search_as_roles or is not allowed to connect to the Kube cluster + // we will get an error here. + req.SetDryRun(true) + req.SetRequestReason("Dry run, this request will not be created. If you see this, there is a bug.") + if err := tc.WithRootClusterClient(ctx, func(clt auth.ClientI) error { + return trace.Wrap(clt.CreateAccessRequest(ctx, req)) + }); err != nil { + return nil, trace.Wrap(err) + } + req.SetDryRun(false) + req.SetRequestReason("") + + return req, nil +} + +// formatAmbiguousKubeCluster is a helper func that formats an ambiguous kube +// cluster error message. +func formatAmbiguousKubeCluster(cf *CLIConf, selectors resourceSelectors, kubeClusters types.KubeClusters) string { + // dont mark the selected cluster + selectedCluster := "" + // verbose output to show full names and labels + quiet := false + verbose := true + table := formatKubeClustersAsText(kubeClusters, selectedCluster, quiet, verbose) + listCommand := formatKubeListCommand(cf.SiteName) + fullNameExample := kubeClusters[0].GetName() + return formatAmbiguityErrTemplate(cf, selectors, listCommand, table, fullNameExample) +} + +func formatKubeNotFound(clusterFlag string, selectors resourceSelectors) string { + listCmd := formatKubeListCommand(clusterFlag) + if selectors.IsEmpty() { + return fmt.Sprintf("no kubernetes clusters found, check '%v' for a list of known clusters", + listCmd) + } + return fmt.Sprintf("%v not found, check '%v' for a list of known clusters", + selectors, listCmd) +} + +func formatKubeListCommand(clusterFlag string) string { + if clusterFlag == "" { + return "tsh kube ls" + } + return fmt.Sprintf("tsh kube ls --cluster=%v", clusterFlag) +} diff --git a/tool/tsh/kube_proxy.go b/tool/tsh/kube_proxy.go index 
72348c6cb09c6..f30667302e758 100644 --- a/tool/tsh/kube_proxy.go +++ b/tool/tsh/kube_proxy.go @@ -36,6 +36,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/asciitable" "github.com/gravitational/teleport/lib/auth/native" @@ -58,6 +59,9 @@ type proxyKubeCommand struct { namespace string port string format string + + labels string + predicateExpression string } func newProxyKubeCommand(parent *kingpin.CmdClause) *proxyKubeCommand { @@ -73,10 +77,15 @@ func newProxyKubeCommand(parent *kingpin.CmdClause) *proxyKubeCommand { c.Flag("kube-namespace", "Configure the default Kubernetes namespace.").Short('n').StringVar(&c.namespace) c.Flag("port", "Specifies the source port used by the proxy listener").Short('p').StringVar(&c.port) c.Flag("format", envVarFormatFlagDescription()).Short('f').Default(envVarDefaultFormat()).EnumVar(&c.format, envVarFormats...) + c.Flag("labels", labelHelp).StringVar(&c.labels) + c.Flag("query", queryHelp).StringVar(&c.predicateExpression) return c } func (c *proxyKubeCommand) run(cf *CLIConf) error { + cf.Labels = c.labels + cf.PredicateExpression = c.predicateExpression + cf.SiteName = c.siteName tc, err := makeClient(cf) if err != nil { return trace.Wrap(err) @@ -111,22 +120,36 @@ func (c *proxyKubeCommand) run(cf *CLIConf) error { } func (c *proxyKubeCommand) prepare(cf *CLIConf, tc *client.TeleportClient) (*clientcmdapi.Config, kubeconfig.LocalProxyClusters, error) { - defaultConfig, err := kubeconfig.Load("") + defaultConfig, err := kubeconfig.Load(getKubeConfigPath(cf, "")) if err != nil { return nil, nil, trace.Wrap(err) } // Use kube clusters from arg. 
- if len(c.kubeClusters) > 0 { - if c.siteName == "" { - c.siteName = tc.SiteName + if len(c.kubeClusters) > 0 || cf.Labels != "" || cf.PredicateExpression != "" { + _, kubeClusters, err := fetchKubeClusters(cf.Context, tc) + if err != nil { + return nil, nil, trace.Wrap(err) + } + switch len(c.kubeClusters) { + case 0: + // if no names are given, check just the labels/predicate selection. + if err := checkClusterSelection(cf, kubeClusters, ""); err != nil { + return nil, nil, trace.Wrap(err) + } + default: + // otherwise, check that each name matches exactly one kube cluster. + matchMap := matchClustersByNames(kubeClusters, c.kubeClusters...) + if err := checkMultipleClusterSelections(cf, matchMap); err != nil { + return nil, nil, trace.Wrap(err) + } + kubeClusters = combineMatchedClusters(matchMap) } - var clusters kubeconfig.LocalProxyClusters - for _, kubeCluster := range c.kubeClusters { + for _, kc := range kubeClusters { clusters = append(clusters, kubeconfig.LocalProxyCluster{ - TeleportCluster: c.siteName, - KubeCluster: kubeCluster, + TeleportCluster: tc.SiteName, + KubeCluster: kc.GetName(), Impersonate: c.impersonateUser, ImpersonateGroups: c.impersonateGroups, Namespace: c.namespace, @@ -524,6 +547,39 @@ func issueKubeCert(ctx context.Context, tc *client.TeleportClient, proxy *client return cert, nil } +// checkMultipleClusterSelections takes a map of name selectors to matched +// clusters and checks that each matching is valid. +func checkMultipleClusterSelections(cf *CLIConf, matchMap map[string]types.KubeClusters) error { + for name, clusters := range matchMap { + err := checkClusterSelection(cf, clusters, name) + if err != nil { + return trace.Wrap(err) + } + } + return nil +} + +// combineMatchedClusters takes a map from name selector +// to matched clusters and combines all the matched clusters into a deduplicated +// slice. 
+func combineMatchedClusters(matchMap map[string]types.KubeClusters) types.KubeClusters { + var out types.KubeClusters + for _, clusters := range matchMap { + out = append(out, clusters...) + } + return types.DeduplicateKubeClusters(out) +} + +// matchClustersByNames maps each name to the clusters it matches by exact name +// or by prefix. +func matchClustersByNames(clusters types.KubeClusters, names ...string) map[string]types.KubeClusters { + matchesForNames := make(map[string]types.KubeClusters) + for _, name := range names { + matchesForNames[name] = matchClustersByName(name, clusters) + } + return matchesForNames +} + // proxyKubeTemplate is the message that gets printed to a user when a kube proxy is started. var proxyKubeTemplate = template.Must(template.New(""). Funcs(template.FuncMap{ diff --git a/tool/tsh/kube_proxy_test.go b/tool/tsh/kube_proxy_test.go index 3fe13d98b7e8e..423013f0f7df2 100644 --- a/tool/tsh/kube_proxy_test.go +++ b/tool/tsh/kube_proxy_test.go @@ -25,8 +25,10 @@ import ( "path" "strings" "testing" + "time" "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" @@ -34,9 +36,14 @@ import ( clientcmdapi "k8s.io/client-go/tools/clientcmd/api" "github.com/gravitational/teleport/api/types" + apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/keypaths" "github.com/gravitational/teleport/lib/kube/kubeconfig" + "github.com/gravitational/teleport/lib/service/servicecfg" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/tool/teleport/testenv" ) func (p *kubeTestPack) testProxyKube(t *testing.T) { @@ -86,6 +93,233 @@ func (p *kubeTestPack) testProxyKube(t *testing.T) { }) } +func TestProxyKubeComplexSelectors(t *testing.T) { + 
testenv.WithInsecureDevMode(t, true) + testenv.WithResyncInterval(t, 0) + kubeFoo := "foo" + kubeFooBar := "foo-bar" + kubeBaz := "baz-qux" + kubeBazEKS := "baz-eks-us-west-1-123456789012" + kubeFooLeaf := "foo" + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + s := newTestSuite(t, + withRootConfigFunc(func(cfg *servicecfg.Config) { + cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) + cfg.SSH.Enabled = false + cfg.Kube.Enabled = true + cfg.Kube.ListenAddr = utils.MustParseAddr(localListenerAddr()) + cfg.Kube.KubeconfigPath = newKubeConfigFile(t, kubeFoo, kubeFooBar, kubeBaz) + cfg.Kube.StaticLabels = map[string]string{"env": "root"} + cfg.Kube.ResourceMatchers = []services.ResourceMatcher{{ + Labels: map[string]apiutils.Strings{"*": {"*"}}, + }} + }), + withLeafCluster(), + withLeafConfigFunc( + func(cfg *servicecfg.Config) { + cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) + cfg.SSH.Enabled = false + cfg.Kube.Enabled = true + cfg.Kube.ListenAddr = utils.MustParseAddr(localListenerAddr()) + cfg.Kube.KubeconfigPath = newKubeConfigFile(t, kubeFooLeaf) + cfg.Kube.StaticLabels = map[string]string{"env": "leaf"} + }, + ), + withValidationFunc(func(s *suite) bool { + rootClusters, err := s.root.GetAuthServer().GetKubernetesServers(ctx) + require.NoError(t, err) + leafClusters, err := s.leaf.GetAuthServer().GetKubernetesServers(ctx) + require.NoError(t, err) + return len(rootClusters) == 3 && len(leafClusters) == 1 + }), + ) + // setup a fake "discovered" kube cluster by adding a discovered name label + // to a dynamic kube cluster. 
+ kc, err := types.NewKubernetesClusterV3( + types.Metadata{ + Name: kubeBazEKS, + Labels: map[string]string{ + types.DiscoveredNameLabel: "baz", + types.OriginLabel: types.OriginDynamic, + }, + }, + types.KubernetesClusterSpecV3{ + Kubeconfig: newKubeConfig(t, kubeBazEKS), + }, + ) + require.NoError(t, err) + err = s.root.GetAuthServer().CreateKubernetesCluster(ctx, kc) + require.NoError(t, err) + require.EventuallyWithT(t, func(c *assert.CollectT) { + servers, err := s.root.GetAuthServer().GetKubernetesServers(ctx) + assert.NoError(c, err) + for _, ks := range servers { + if ks.GetName() == kubeBazEKS { + return + } + } + assert.Fail(c, "kube server not found") + }, time.Second*10, time.Millisecond*500, "failed to find dynamically created kube cluster %v", kubeBazEKS) + + rootClusterName := s.root.Config.Auth.ClusterName.GetClusterName() + leafClusterName := s.leaf.Config.Auth.ClusterName.GetClusterName() + + tests := []struct { + desc string + makeValidateCmdFn func(*testing.T) func(*exec.Cmd) error + args []string + wantErr string + }{ + { + desc: "with full name", + makeValidateCmdFn: func(t *testing.T) func(*exec.Cmd) error { + return func(cmd *exec.Cmd) error { + config := kubeConfigFromCmdEnv(t, cmd) + checkKubeLocalProxyConfig(t, s, config, rootClusterName, kubeFoo) + return nil + } + }, + args: []string{kubeFoo, "--insecure"}, + }, + { + desc: "with discovered name", + makeValidateCmdFn: func(t *testing.T) func(*exec.Cmd) error { + return func(cmd *exec.Cmd) error { + config := kubeConfigFromCmdEnv(t, cmd) + checkKubeLocalProxyConfig(t, s, config, rootClusterName, kubeBazEKS) + return nil + } + }, + args: []string{"baz", "--insecure"}, + }, + { + desc: "with prefix name", + makeValidateCmdFn: func(t *testing.T) func(*exec.Cmd) error { + return func(cmd *exec.Cmd) error { + config := kubeConfigFromCmdEnv(t, cmd) + checkKubeLocalProxyConfig(t, s, config, rootClusterName, kubeFooBar) + return nil + } + }, + args: []string{"foo-b", "--insecure"}, + }, + { + 
desc: "with labels", + makeValidateCmdFn: func(t *testing.T) func(*exec.Cmd) error { + return func(cmd *exec.Cmd) error { + config := kubeConfigFromCmdEnv(t, cmd) + checkKubeLocalProxyConfig(t, s, config, rootClusterName, kubeFoo) + checkKubeLocalProxyConfig(t, s, config, rootClusterName, kubeFooBar) + checkKubeLocalProxyConfig(t, s, config, rootClusterName, kubeBaz) + return nil + } + }, + args: []string{"--labels", "env=root", "--insecure"}, + }, + { + desc: "with query", + makeValidateCmdFn: func(t *testing.T) func(*exec.Cmd) error { + return func(cmd *exec.Cmd) error { + config := kubeConfigFromCmdEnv(t, cmd) + checkKubeLocalProxyConfig(t, s, config, rootClusterName, kubeFoo) + return nil + } + }, + args: []string{"--query", `labels["env"]=="root"`, "--insecure"}, + }, + { + desc: "with labels, query, and prefix", + makeValidateCmdFn: func(t *testing.T) func(*exec.Cmd) error { + return func(cmd *exec.Cmd) error { + config := kubeConfigFromCmdEnv(t, cmd) + checkKubeLocalProxyConfig(t, s, config, rootClusterName, kubeFoo) + return nil + } + }, + args: []string{ + "--labels", "env=root", + "--query", `name == "foo"`, + "f", // prefix of "foo". + "--insecure", + }, + }, + { + desc: "in leaf cluster with prefix name", + makeValidateCmdFn: func(t *testing.T) func(*exec.Cmd) error { + return func(cmd *exec.Cmd) error { + config := kubeConfigFromCmdEnv(t, cmd) + checkKubeLocalProxyConfig(t, s, config, leafClusterName, kubeFooLeaf) + return nil + } + }, + args: []string{ + "--cluster", leafClusterName, + "--insecure", + "f", // prefix of "foo" kube cluster in leaf teleport cluster. + }, + }, + { + desc: "ambiguous name prefix is an error", + args: []string{ + "f", // prefix of foo, foo-bar in root cluster. 
+ "--insecure", + }, + wantErr: `kubernetes cluster "f" matches multiple`, + }, + { + desc: "zero name matches is an error", + args: []string{ + "xxx", + "--insecure", + }, + wantErr: `kubernetes cluster "xxx" not found`, + }, + { + desc: "zero label matches is an error", + args: []string{ + "--labels", "env=nonexistent", + "--insecure", + }, + wantErr: `kubernetes cluster with labels "env=nonexistent" not found`, + }, + { + desc: "zero query matches is an error", + args: []string{ + "--query", `labels["env"]=="nonexistent"`, + "--insecure", + }, + wantErr: `kubernetes cluster with query (labels["env"]=="nonexistent") not found`, + }, + } + + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + // login for each parallel test to avoid races when multiple tsh + // clients work in the same profile dir. + tshHome, _ := mustLogin(t, s) + // Set kubeconfig to a non-exist file to avoid loading other things. + kubeConfigPath := path.Join(tshHome, "kube-config") + var cmdRunner func(*exec.Cmd) error + if test.makeValidateCmdFn != nil { + cmdRunner = test.makeValidateCmdFn(t) + } + err := Run(ctx, append([]string{"proxy", "kube", "--port", ports.Pop()}, test.args...), + setCmdRunner(cmdRunner), + setHomePath(tshHome), + setKubeConfigPath(kubeConfigPath), + ) + if test.wantErr != "" { + require.ErrorContains(t, err, test.wantErr) + return + } + require.NoError(t, err) + }) + } +} + func kubeConfigFromCmdEnv(t *testing.T, cmd *exec.Cmd) *clientcmdapi.Config { t.Helper() @@ -118,6 +352,9 @@ func sendRequestToKubeLocalProxy(t *testing.T, config *clientcmdapi.Config, tele contextName := kubeconfig.ContextName(teleportCluster, kubeCluster) + require.NotNil(t, config) + require.NotNil(t, config.Clusters) + require.Contains(t, config.Clusters, contextName) proxyURL, err := url.Parse(config.Clusters[contextName].ProxyURL) require.NoError(t, err) diff --git a/tool/tsh/kube_test.go b/tool/tsh/kube_test.go index 
231107c0da5d5..d4f5007c64901 100644 --- a/tool/tsh/kube_test.go +++ b/tool/tsh/kube_test.go @@ -21,20 +21,31 @@ import ( "context" "fmt" "path/filepath" - "sort" + "reflect" "strings" "testing" + "time" + "github.com/google/go-cmp/cmp" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" + "golang.org/x/sync/errgroup" "k8s.io/client-go/tools/clientcmd" clientcmdapi "k8s.io/client-go/tools/clientcmd/api" + "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/keypaths" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/asciitable" + "github.com/gravitational/teleport/lib/kube/kubeconfig" kubeserver "github.com/gravitational/teleport/lib/kube/proxy/testing/kube_server" + "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/tool/common" + "github.com/gravitational/teleport/tool/teleport/testenv" ) func TestKube(t *testing.T) { @@ -54,8 +65,6 @@ type kubeTestPack struct { rootKubeCluster1 string rootKubeCluster2 string leafKubeCluster string - serviceLabels map[string]string - formatedLabels string } func setupKubeTestPack(t *testing.T) *kubeTestPack { @@ -64,12 +73,18 @@ func setupKubeTestPack(t *testing.T) *kubeTestPack { ctx := context.Background() rootKubeCluster1 := "root-cluster" rootKubeCluster2 := "first-cluster" - leafKubeCluster := "leaf-cluster" - serviceLabels := map[string]string{ + // mock a discovered kube cluster name in the leaf Teleport cluster. 
+ leafKubeCluster := "leaf-cluster-some-suffix-added-by-discovery-service" + rootLabels := map[string]string{ "label1": "val1", "ultra_long_label_for_teleport_kubernetes_service_list_kube_clusters_method": "ultra_long_label_value_for_teleport_kubernetes_service_list_kube_clusters_method", } - formatedLabels := formatServiceLabels(serviceLabels) + leafLabels := map[string]string{ + "label1": "val1", + "ultra_long_label_for_teleport_kubernetes_service_list_kube_clusters_method": "ultra_long_label_value_for_teleport_kubernetes_service_list_kube_clusters_method", + // mock a discovered kube cluster in the leaf Teleport cluster. + types.DiscoveredNameLabel: "leaf-cluster", + } s := newTestSuite(t, withRootConfigFunc(func(cfg *servicecfg.Config) { @@ -77,7 +92,7 @@ func setupKubeTestPack(t *testing.T) *kubeTestPack { cfg.Kube.Enabled = true cfg.Kube.ListenAddr = utils.MustParseAddr(localListenerAddr()) cfg.Kube.KubeconfigPath = newKubeConfigFile(t, rootKubeCluster1, rootKubeCluster2) - cfg.Kube.StaticLabels = serviceLabels + cfg.Kube.StaticLabels = rootLabels }), withLeafCluster(), withLeafConfigFunc( @@ -86,6 +101,7 @@ func setupKubeTestPack(t *testing.T) *kubeTestPack { cfg.Kube.Enabled = true cfg.Kube.ListenAddr = utils.MustParseAddr(localListenerAddr()) cfg.Kube.KubeconfigPath = newKubeConfigFile(t, leafKubeCluster) + cfg.Kube.StaticLabels = leafLabels }, ), withValidationFunc(func(s *suite) bool { @@ -105,12 +121,18 @@ func setupKubeTestPack(t *testing.T) *kubeTestPack { rootKubeCluster1: rootKubeCluster1, rootKubeCluster2: rootKubeCluster2, leafKubeCluster: leafKubeCluster, - serviceLabels: serviceLabels, - formatedLabels: formatedLabels, } } func (p *kubeTestPack) testListKube(t *testing.T) { + staticRootLabels := p.suite.root.Config.Kube.StaticLabels + formattedRootLabels := common.FormatLabels(staticRootLabels, false) + formattedRootLabelsVerbose := common.FormatLabels(staticRootLabels, true) + + staticLeafLabels := p.suite.leaf.Config.Kube.StaticLabels + 
formattedLeafLabels := common.FormatLabels(staticLeafLabels, false) + formattedLeafLabelsVerbose := common.FormatLabels(staticLeafLabels, true) + tests := []struct { name string args []string @@ -124,7 +146,10 @@ func (p *kubeTestPack) testListKube(t *testing.T) { // p.rootKubeCluster1 ("root-cluster") after sorting. table := asciitable.MakeTableWithTruncatedColumn( []string{"Kube Cluster Name", "Labels", "Selected"}, - [][]string{{p.rootKubeCluster2, p.formatedLabels, ""}, {p.rootKubeCluster1, p.formatedLabels, ""}}, + [][]string{ + {p.rootKubeCluster2, formattedRootLabels, ""}, + {p.rootKubeCluster1, formattedRootLabels, ""}, + }, "Labels") return table.AsBuffer().String() }, @@ -135,8 +160,8 @@ func (p *kubeTestPack) testListKube(t *testing.T) { wantTable: func() string { table := asciitable.MakeTable( []string{"Kube Cluster Name", "Labels", "Selected"}, - []string{p.rootKubeCluster2, p.formatedLabels, ""}, - []string{p.rootKubeCluster1, p.formatedLabels, ""}) + []string{p.rootKubeCluster2, formattedRootLabelsVerbose, ""}, + []string{p.rootKubeCluster1, formattedRootLabelsVerbose, ""}) return table.AsBuffer().String() }, }, @@ -145,8 +170,8 @@ func (p *kubeTestPack) testListKube(t *testing.T) { args: []string{"--quiet"}, wantTable: func() string { table := asciitable.MakeHeadlessTable(2) - table.AddRow([]string{p.rootKubeCluster2, p.formatedLabels, ""}) - table.AddRow([]string{p.rootKubeCluster1, p.formatedLabels, ""}) + table.AddRow([]string{p.rootKubeCluster2, formattedRootLabels, ""}) + table.AddRow([]string{p.rootKubeCluster1, formattedRootLabels, ""}) return table.AsBuffer().String() }, @@ -154,17 +179,47 @@ func (p *kubeTestPack) testListKube(t *testing.T) { { name: "list all clusters including leaf clusters", args: []string{"--all"}, + wantTable: func() string { + table := asciitable.MakeTableWithTruncatedColumn( + []string{"Proxy", "Cluster", "Kube Cluster Name", "Labels"}, + [][]string{ + // "leaf-cluster" should be displayed instead of the + // full 
leaf cluster name, since it is mocked as a + // discovered resource and the discovered resource name + // is displayed in non-verbose mode. + {p.root.Config.Proxy.WebAddr.String(), "leaf1", "leaf-cluster", formattedLeafLabels}, + {p.root.Config.Proxy.WebAddr.String(), "root", p.rootKubeCluster2, formattedRootLabels}, + {p.root.Config.Proxy.WebAddr.String(), "root", p.rootKubeCluster1, formattedRootLabels}, + }, + "Labels", + ) + return table.AsBuffer().String() + }, + }, + { + name: "list all clusters including leaf clusters with complete list of labels", + args: []string{"--all", "--verbose"}, wantTable: func() string { table := asciitable.MakeTable( []string{"Proxy", "Cluster", "Kube Cluster Name", "Labels"}, - - []string{p.root.Config.Proxy.WebAddr.String(), "leaf1", p.leafKubeCluster, ""}, - []string{p.root.Config.Proxy.WebAddr.String(), "root", p.rootKubeCluster2, p.formatedLabels}, - []string{p.root.Config.Proxy.WebAddr.String(), "root", p.rootKubeCluster1, p.formatedLabels}, + []string{p.root.Config.Proxy.WebAddr.String(), "leaf1", p.leafKubeCluster, formattedLeafLabelsVerbose}, + []string{p.root.Config.Proxy.WebAddr.String(), "root", p.rootKubeCluster2, formattedRootLabelsVerbose}, + []string{p.root.Config.Proxy.WebAddr.String(), "root", p.rootKubeCluster1, formattedRootLabelsVerbose}, ) return table.AsBuffer().String() }, }, + { + name: "list all clusters including leaf clusters in headless table", + args: []string{"--all", "--quiet"}, + wantTable: func() string { + table := asciitable.MakeHeadlessTable(4) + table.AddRow([]string{p.root.Config.Proxy.WebAddr.String(), "leaf1", "leaf-cluster", formattedLeafLabels}) + table.AddRow([]string{p.root.Config.Proxy.WebAddr.String(), "root", p.rootKubeCluster2, formattedRootLabels}) + table.AddRow([]string{p.root.Config.Proxy.WebAddr.String(), "root", p.rootKubeCluster1, formattedRootLabels}) + return table.AsBuffer().String() + }, + }, } for _, tc := range tests { @@ -183,13 +238,315 @@ func (p *kubeTestPack) 
testListKube(t *testing.T) { tc.args..., ), setCopyStdout(captureStdout), + + // set a custom empty kube config for each test, as we do + // not want parallel (or even shuffled sequential) tests + // potentially racing on the same config + setKubeConfigPath(filepath.Join(t.TempDir(), "kubeconfig")), + ) + require.NoError(t, err) + got := strings.TrimSpace(captureStdout.String()) + want := strings.TrimSpace(tc.wantTable()) + diff := cmp.Diff(want, got) + require.Empty(t, diff) + }) + } +} + +func TestKubeLogin(t *testing.T) { + modules.SetTestModules(t, + &modules.TestModules{ + TestBuildType: modules.BuildEnterprise, + TestFeatures: modules.Features{ + Kubernetes: true, + }, + }, + ) + testenv.WithInsecureDevMode(t, true) + testenv.WithResyncInterval(t, 0) + t.Run("complex filters", testKubeLoginWithFilters) + t.Run("access request", testKubeLoginAccessRequest) +} + +func testKubeLoginWithFilters(t *testing.T) { + t.Parallel() + ctx := context.Background() + kubeFoo := "foo" + kubeFooBar := "foo-bar" + kubeBaz := "baz" + staticLabels := map[string]string{ + "env": "root", + } + allKubes := []string{kubeFoo, kubeFooBar, kubeBaz} + + s := newTestSuite(t, + withRootConfigFunc(func(cfg *servicecfg.Config) { + cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) + cfg.Kube.Enabled = true + cfg.Kube.ListenAddr = utils.MustParseAddr(localListenerAddr()) + cfg.Kube.KubeconfigPath = newKubeConfigFile(t, allKubes...) 
+ cfg.Kube.StaticLabels = staticLabels + }), + withValidationFunc(func(s *suite) bool { + rootClusters, err := s.root.GetAuthServer().GetKubernetesServers(ctx) + require.NoError(t, err) + return len(rootClusters) == 3 + }), + ) + + tests := []struct { + desc string + args []string + wantLoggedIn []string + wantSelected string + wantErrContains string + }{ + { + desc: "login with exact name and set current context", + args: []string{"foo"}, + wantLoggedIn: []string{"foo"}, + wantSelected: "foo", + }, + { + desc: "login with prefix name and set current context", + args: []string{"foo-b"}, + wantLoggedIn: []string{"foo-bar"}, + wantSelected: "foo-bar", + }, + { + desc: "login with all", + args: []string{"--all"}, + wantLoggedIn: []string{"foo", "foo-bar", "baz"}, + wantSelected: "", + }, + { + desc: "login with labels", + args: []string{"--labels", "env=root"}, + wantLoggedIn: []string{"foo", "foo-bar", "baz"}, + wantSelected: "", + }, + { + desc: "login with query", + args: []string{"--query", `name == "foo"`}, + wantLoggedIn: []string{"foo"}, + wantSelected: "", + }, + { + desc: "login to multiple with all and set current context by name", + args: []string{"foo", "--all"}, + wantLoggedIn: []string{"foo", "foo-bar", "baz"}, + wantSelected: "foo", + }, + { + desc: "login to multiple with labels and set current context by name", + args: []string{"foo", "--labels", "env=root"}, + wantLoggedIn: []string{"foo", "foo-bar", "baz"}, + wantSelected: "foo", + }, + { + desc: "login to multiple with query and set current context by prefix name", + args: []string{"foo-b", "--query", `name == "foo-bar" || name == "foo"`}, + wantLoggedIn: []string{"foo", "foo-bar"}, + wantSelected: "foo-bar", + }, + { + desc: "all with labels is an error", + args: []string{"xxx", "--all", "--labels", `env=root`}, + wantErrContains: "cannot use", + }, + { + desc: "all with query is an error", + args: []string{"xxx", "--all", "--query", `name == "foo-bar" || name == "foo"`}, + wantErrContains: 
"cannot use", + }, + { + desc: "missing required args is an error", + args: []string{}, + wantErrContains: "required", + }, + } + + tshHome, _ := mustLogin(t, s) + webProxyAddr, err := utils.ParseAddr(s.root.Config.Proxy.WebAddr.String()) + require.NoError(t, err) + // profile kube config path depends on web proxy host + webProxyHost := webProxyAddr.Host() + + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + // clone the login dir for each parallel test to avoid profile kube config file races. + tshHome := mustCloneTempDir(t, tshHome) + kubeConfigPath := filepath.Join(t.TempDir(), "kubeconfig") + err := Run( + context.Background(), + append([]string{ + "--insecure", + "kube", + "login", + }, + test.args..., + ), + setHomePath(tshHome), + // set a custom empty kube config for each test, as we do + // not want parallel (or even shuffled sequential) tests + // potentially racing on the same config + setKubeConfigPath(kubeConfigPath), ) + if test.wantErrContains != "" { + require.ErrorContains(t, err, test.wantErrContains) + return + } + require.NoError(t, err) + + config, err := kubeconfig.Load(kubeConfigPath) require.NoError(t, err) - require.Contains(t, captureStdout.String(), tc.wantTable()) + if test.wantSelected == "" { + require.Empty(t, config.CurrentContext) + } else { + require.Equal(t, kubeconfig.ContextName("root", test.wantSelected), config.CurrentContext) + } + for _, name := range allKubes { + contextName := kubeconfig.ContextName("root", name) + if !slices.Contains(test.wantLoggedIn, name) { + require.NotContains(t, config.AuthInfos, contextName, "unexpected kube cluster %v in config update", name) + return + } + require.Contains(t, config.AuthInfos, contextName, "kube cluster %v not in config update", name) + authInfo := config.AuthInfos[contextName] + require.NotNil(t, authInfo) + require.Contains(t, authInfo.Exec.Args, fmt.Sprintf("--kube-cluster=%v", name)) + } + + // ensure the profile config only 
contains one + profileKubeConfigPath := keypaths.KubeConfigPath( + profile.FullProfilePath(tshHome), + webProxyHost, + s.user.GetName(), + s.root.Config.Auth.ClusterName.GetClusterName(), + test.wantSelected, + ) + profileConfig, err := kubeconfig.Load(profileKubeConfigPath) + require.NoError(t, err) + for _, name := range allKubes { + contextName := kubeconfig.ContextName("root", name) + if name != test.wantSelected { + require.NotContains(t, profileConfig.AuthInfos, contextName, "unexpected kube cluster %v in profile config update", name) + return + } + require.Contains(t, profileConfig.AuthInfos, contextName, "kube cluster %v not in profile config update", name) + authInfo := profileConfig.AuthInfos[contextName] + require.NotNil(t, authInfo) + require.Contains(t, authInfo.Exec.Args, fmt.Sprintf("--kube-cluster=%v", name)) + } }) } } +func testKubeLoginAccessRequest(t *testing.T) { + t.Parallel() + const ( + roleName = "requester" + kubeCluster = "root-cluster" + ) + // Create a role that allows the user to request access to the cluster but + // not to access it directly. + role, err := types.NewRole( + roleName, + types.RoleSpecV6{ + Allow: types.RoleConditions{ + Request: &types.AccessRequestConditions{ + SearchAsRoles: []string{"access"}, + }, + }, + }, + ) + require.NoError(t, err) + + s := newTestSuite(t, + withRootConfigFunc(func(cfg *servicecfg.Config) { + cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) + // reconfig the user to use the new role instead of the default ones + // User is the second bootstrap resource. + user, ok := cfg.Auth.BootstrapResources[1].(types.User) + require.True(t, ok) + user.SetRoles([]string{roleName}) + // Add the role to the list of bootstrap resources. + cfg.Auth.BootstrapResources = append(cfg.Auth.BootstrapResources, role) + + // Enable kube and set the kubeconfig path. 
+ cfg.Kube.Enabled = true + cfg.Kube.ListenAddr = utils.MustParseAddr(localListenerAddr()) + cfg.Kube.KubeconfigPath = newKubeConfigFile(t, kubeCluster) + }), + withValidationFunc(func(s *suite) bool { + // Check if the kube cluster was added. + rootClusters, err := s.root.GetAuthServer().GetKubernetesServers(context.Background()) + require.NoError(t, err) + return len(rootClusters) == 1 + }), + ) + // login as the user. + tshHome, kubeConfig := mustLogin(t, s) + + // Run the login command in a goroutine so we can check if the access + // request was created and approved. + // The goroutine will exit when the access request is approved. + wg := &errgroup.Group{} + wg.Go(func() error { + err := Run( + context.Background(), + []string{ + "--insecure", + "kube", + "login", + // use a prefix of the kube cluster name + "root-c", + "--request-reason", + "test", + }, + setHomePath(tshHome), + setKubeConfigPath(kubeConfig), + ) + return trace.Wrap(err) + }) + // Wait for the access request to be created and finally approve it. + var accessRequestID string + require.Eventually(t, func() bool { + accessRequests, err := s.root.GetAuthServer(). + GetAccessRequests( + context.Background(), + types.AccessRequestFilter{State: types.RequestState_PENDING}, + ) + if err != nil || len(accessRequests) != 1 { + return false + } + + equal := reflect.DeepEqual( + accessRequests[0].GetRequestedResourceIDs(), + []types.ResourceID{ + { + ClusterName: s.root.Config.Auth.ClusterName.GetClusterName(), + Kind: types.KindKubernetesCluster, + Name: kubeCluster, + }, + }, + ) + accessRequestID = accessRequests[0].GetName() + + return equal + }, 10*time.Second, 500*time.Millisecond) + // Approve the access request to release the login command lock. 
+ err = s.root.GetAuthServer().SetAccessRequestState(context.Background(), types.AccessRequestUpdate{ + RequestID: accessRequestID, + State: types.RequestState_APPROVED, + }) + require.NoError(t, err) + // Wait for the login command to exit after the request is approved + require.NoError(t, wg.Wait()) +} + func newKubeConfigFile(t *testing.T, clusterNames ...string) string { tmpDir := t.TempDir() @@ -212,14 +569,23 @@ func newKubeConfigFile(t *testing.T, clusterNames ...string) string { return kubeConfigLocation } -func formatServiceLabels(labels map[string]string) string { - labelSlice := make([]string, 0, len(labels)) - for key, value := range labels { - labelSlice = append(labelSlice, fmt.Sprintf("%s=%s", key, value)) +func newKubeConfig(t *testing.T, name string) []byte { + kubeConf := clientcmdapi.NewConfig() + + kubeConf.Clusters[name] = &clientcmdapi.Cluster{ + Server: newKubeSelfSubjectServer(t), + InsecureSkipTLSVerify: true, } + kubeConf.AuthInfos[name] = &clientcmdapi.AuthInfo{} - sort.Strings(labelSlice) - return strings.Join(labelSlice, ",") + kubeConf.Contexts[name] = &clientcmdapi.Context{ + Cluster: name, + AuthInfo: name, + } + + buf, err := clientcmd.Write(*kubeConf) + require.NoError(t, err) + return buf } func newKubeSelfSubjectServer(t *testing.T) string { diff --git a/tool/tsh/kubectl.go b/tool/tsh/kubectl.go index 49618782827e7..adebbf80ef12f 100644 --- a/tool/tsh/kubectl.go +++ b/tool/tsh/kubectl.go @@ -379,7 +379,11 @@ func getKubeClusterName(args []string, teleportClusterName string) (string, erro kubeName, err := kubeconfig.SelectedKubeCluster(kubeconfigLocation, teleportClusterName) return kubeName, trace.Wrap(err) } - kubeName := kubeconfig.KubeClusterFromContext(selectedContext, teleportClusterName) + kc, err := kubeconfig.Load(kubeconfigLocation) + if err != nil { + return "", trace.Wrap(err) + } + kubeName := kubeconfig.KubeClusterFromContext(selectedContext, kc.Contexts[selectedContext], teleportClusterName) if kubeName == "" { 
return "", trace.BadParameter("selected context %q does not belong to Teleport cluster %q", selectedContext, teleportClusterName) } @@ -497,7 +501,6 @@ func shouldUseKubeLocalProxy(cf *CLIConf, kubectlArgs []string) (*clientcmdapi.C return nil, nil, false } return defaultConfig, kubeconfig.LocalProxyClusters{kubeCluster}, true - } func isKubectlConfigCommand(kubectlCommand *cobra.Command, args []string) bool { diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index c5668837c729d..ee2d1416b13ec 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -365,7 +365,11 @@ func onProxyCommandDB(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - dbInfo, err := getDatabaseInfo(cf, tc) + routes, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + dbInfo, err := getDatabaseInfo(cf, tc, routes) if err != nil { return trace.Wrap(err) } @@ -492,7 +496,7 @@ func onProxyCommandDB(cf *CLIConf) error { func maybeAddDBUserPassword(cf *CLIConf, tc *libclient.TeleportClient, dbInfo *databaseInfo, opts []dbcmd.ConnectCommandFunc) ([]dbcmd.ConnectCommandFunc, error) { if dbInfo.Protocol == defaults.ProtocolCassandra { - db, err := dbInfo.GetDatabase(cf, tc) + db, err := dbInfo.GetDatabase(cf.Context, tc) if err != nil { return nil, trace.Wrap(err) } @@ -901,7 +905,7 @@ var dbProxyTpl = template.Must(template.New("").Parse(`Started DB proxy on {{.ad {{if .randomPort}}To avoid port randomization, you can choose the listening port using the --port flag. 
{{end}} ` + dbProxyConnectAd + ` -Use following credentials to connect to the {{.database}} proxy: +Use the following credentials to connect to the {{.database}} proxy: ca_file={{.ca}} cert_file={{.cert}} key_file={{.key}} diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index 448fff8647fbd..be3e98b5a1d8b 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -179,7 +179,7 @@ type CLIConf struct { ProxyJump string // --local flag for ssh LocalExec bool - // SiteName specifies remote site go login to + // SiteName specifies remote site to login to. SiteName string // KubernetesCluster specifies the kubernetes cluster to login to. KubernetesCluster string @@ -609,7 +609,12 @@ func initLogger(cf *CLIConf) { } } -// Run executes TSH client. same as main() but easier to test +// Run executes TSH client. same as main() but easier to test. Note that this +// function modifies global state in `tsh` (e.g. the system logger), and WILL +// ALSO MODIFY EXTERNAL SHARED STATE in its default configuration (e.g. the +// $HOME/.tsh dir, $KUBECONFIG, etc). +// +// DO NOT RUN TESTS that call Run() in parallel (unless you taken precautions). 
func Run(ctx context.Context, args []string, opts ...cliOption) error { cf := CLIConf{ Context: ctx, @@ -2795,23 +2800,15 @@ func formatUsersForDB(database types.Database, accessChecker services.AccessChec return fmt.Sprintf("%v, except: %v", dbUsers.Allowed, dbUsers.Denied) } -func getDiscoveredName(r types.ResourceWithLabels) (string, bool) { - name, ok := r.GetAllLabels()[types.DiscoveredNameLabel] - return name, ok -} - func getDatabaseRow(proxy, cluster, clusterFlag string, database types.Database, active []tlsca.RouteToDatabase, accessChecker services.AccessChecker, verbose bool) []string { name := database.GetName() - printName := name - if d, ok := getDiscoveredName(database); ok && !verbose && d != name { - printName = d - } + displayName := common.FormatResourceName(database, verbose) var connect string for _, a := range active { if a.ServiceName == name { - a.ServiceName = printName - // format the db name with the print name - printName = formatActiveDB(a) + a.ServiceName = displayName + // format the db name with the display name + displayName = formatActiveDB(a) // then revert it for connect string a.ServiceName = name switch a.Protocol { @@ -2833,7 +2830,7 @@ func getDatabaseRow(proxy, cluster, clusterFlag string, database types.Database, labels := common.FormatLabels(database.GetAllLabels(), verbose) if verbose { row = append(row, - printName, + displayName, database.GetDescription(), database.GetProtocol(), database.GetType(), @@ -2844,7 +2841,7 @@ func getDatabaseRow(proxy, cluster, clusterFlag string, database types.Database, ) } else { row = append(row, - printName, + displayName, database.GetDescription(), formatUsersForDB(database, accessChecker), labels, @@ -3029,7 +3026,10 @@ func serializeClusters(rootCluster clusterInfo, leafClusters []clusterInfo, form // accessRequestForSSH attempts to create a resource access request for the case // where "tsh ssh" was attempted and access was denied -func accessRequestForSSH(ctx context.Context, tc 
*client.TeleportClient) (types.AccessRequest, error) { +func accessRequestForSSH(ctx context.Context, _ *CLIConf, tc *client.TeleportClient) (types.AccessRequest, error) { + if tc.Host == "" { + return nil, trace.BadParameter("no host specified") + } clt, err := tc.ConnectToCluster(ctx) if err != nil { return nil, trace.Wrap(err) @@ -3091,19 +3091,25 @@ func accessRequestForSSH(ctx context.Context, tc *client.TeleportClient) (types. return req, nil } -func retryWithAccessRequest(cf *CLIConf, tc *client.TeleportClient, fn func() error) error { +func retryWithAccessRequest( + cf *CLIConf, + tc *client.TeleportClient, + fn func() error, + onAccessRequestCreator func(ctx context.Context, cf *CLIConf, tc *client.TeleportClient) (types.AccessRequest, error), + resource string, +) error { origErr := fn() - if cf.disableAccessRequest || !trace.IsAccessDenied(origErr) || tc.Host == "" { + if cf.disableAccessRequest || !trace.IsAccessDenied(origErr) { // Return if --disable-access-request was specified. // Return the original error if it's not AccessDenied. // Quit now if we don't have a hostname. return trace.Wrap(origErr) } - // Try to construct an access request for this node. - req, err := accessRequestForSSH(cf.Context, tc) + // Try to construct an access request for this resource. + req, err := onAccessRequestCreator(cf.Context, cf, tc) if err != nil { - // We can't request access to the node or we couldn't query the ID. Log + // We can't request access to the resource or we couldn't query the ID. Log // a short debug message in case this is unexpected, but return the // original AccessDenied error from the ssh attempt which is likely to // be far more relevant to the user. @@ -3114,7 +3120,7 @@ func retryWithAccessRequest(cf *CLIConf, tc *client.TeleportClient, fn func() er // Print and log the original AccessDenied error. 
fmt.Fprintln(os.Stderr, utils.UserMessageFromError(origErr)) - fmt.Fprintf(os.Stdout, "You do not currently have access to %s@%s, attempting to request access.\n\n", tc.HostLogin, tc.Host) + fmt.Fprintf(os.Stdout, "You do not currently have access to %q, attempting to request access.\n\n", resource) requestReason := cf.RequestReason if requestReason == "" { @@ -3205,7 +3211,10 @@ func onSSH(cf *CLIConf) error { return trace.Wrap(err) } return nil - }) + }, + accessRequestForSSH, + fmt.Sprintf("%s@%s", tc.HostLogin, tc.Host), + ) // Exit with the same exit status as the failed command. if tc.ExitStatus != 0 { var exitErr *common.ExitCodeError @@ -4176,6 +4185,7 @@ func makeProfileInfo(p *client.ProfileStatus, env map[string]string, isActive bo } } + selectedKubeCluster, _ := kubeconfig.SelectedKubeCluster("", p.Cluster) out := &profileInfo{ ProxyURL: p.ProxyURL.String(), Username: p.Username, @@ -4185,7 +4195,7 @@ func makeProfileInfo(p *client.ProfileStatus, env map[string]string, isActive bo Traits: p.Traits, Logins: logins, KubernetesEnabled: p.KubeEnabled, - KubernetesCluster: selectedKubeCluster(p.Cluster), + KubernetesCluster: selectedKubeCluster, KubernetesUsers: p.KubeUsers, KubernetesGroups: p.KubeGroups, Databases: p.DatabaseServices(), @@ -4652,7 +4662,7 @@ func onEnvironment(cf *CLIConf) error { fmt.Printf("unset %v\n", kubeClusterEnvVar) fmt.Printf("unset %v\n", teleport.EnvKubeConfig) case !cf.unsetEnvironment: - kubeName := selectedKubeCluster(profile.Cluster) + kubeName, _ := kubeconfig.SelectedKubeCluster("", profile.Cluster) fmt.Printf("export %v=%v\n", proxyEnvVar, profile.ProxyURL.Host) fmt.Printf("export %v=%v\n", clusterEnvVar, profile.Cluster) if kubeName != "" { @@ -4677,7 +4687,7 @@ func serializeEnvironment(profile *client.ProfileStatus, format string) (string, proxyEnvVar: profile.ProxyURL.Host, clusterEnvVar: profile.Cluster, } - kubeName := selectedKubeCluster(profile.Cluster) + kubeName, _ := kubeconfig.SelectedKubeCluster("", 
profile.Cluster) if kubeName != "" { env[kubeClusterEnvVar] = kubeName env[teleport.EnvKubeConfig] = profile.KubeConfigPath(kubeName) @@ -4854,7 +4864,15 @@ func updateKubeConfigOnLogin(cf *CLIConf, tc *client.TeleportClient, opts ...upd if len(cf.KubernetesCluster) == 0 { return nil } - err := updateKubeConfig(cf, tc, "" /* update the default kubeconfig */, "" /* do not override the context name */) + kubeStatus, err := fetchKubeStatus(cf.Context, tc) + if err != nil { + return trace.Wrap(err) + } + // update the default kubeconfig + kubeConfigPath := "" + // do not override the context name + overrideContextName := "" + err = updateKubeConfig(cf, tc, kubeConfigPath, overrideContextName, kubeStatus) return trace.Wrap(err) }