From 7a54cb39c7e2b37040771cf96a8ab5ebae49217c Mon Sep 17 00:00:00 2001
From: Andrew Lytvynov <andrew@gravitational.com>
Date: Fri, 18 Sep 2020 14:37:45 -0700
Subject: [PATCH] tsh: print kubernetes info in profile status

Print when k8s support is detected, and if so what users/groups are
used.
---
 lib/client/api.go | 44 +++++++++++++++++++++++++++++++++++---------
 tool/tsh/tsh.go   | 29 ++++++++++++++++++++---------
 2 files changed, 55 insertions(+), 18 deletions(-)

diff --git a/lib/client/api.go b/lib/client/api.go
index 70bce7f94daca..d09ebf0aeb052 100644
--- a/lib/client/api.go
+++ b/lib/client/api.go
@@ -53,6 +53,7 @@ import (
 	"github.com/gravitational/teleport/lib/session"
 	"github.com/gravitational/teleport/lib/shell"
 	"github.com/gravitational/teleport/lib/sshutils/scp"
+	"github.com/gravitational/teleport/lib/tlsca"
 	"github.com/gravitational/teleport/lib/utils"
 	"github.com/gravitational/teleport/lib/utils/agentconn"
 	"github.com/gravitational/teleport/lib/wrappers"
@@ -300,6 +301,16 @@ type ProfileStatus struct {
 	// Logins are the Linux accounts, also known as principals in OpenSSH terminology.
 	Logins []string
 
+	// KubeEnabled is true when this profile is configured to connect to a
+	// kubernetes cluster.
+	KubeEnabled bool
+
+	// KubeUsers are the kubernetes users used by this profile.
+	KubeUsers []string
+
+	// KubeGroups are the kubernetes groups used by this profile.
+	KubeGroups []string
+
 	// ValidUntil is the time at which this SSH certificate will expire.
 	ValidUntil time.Time
 
@@ -376,26 +387,26 @@ func readProfile(profileDir string, profileName string) (*ProfileStatus, error)
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}
-	keys, err := store.GetKey(profile.Name(), profile.Username)
+	key, err := store.GetKey(profile.Name(), profile.Username)
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}
-	publicKey, _, _, _, err := ssh.ParseAuthorizedKey(keys.Cert)
+	publicKey, _, _, _, err := ssh.ParseAuthorizedKey(key.Cert)
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}
-	cert, ok := publicKey.(*ssh.Certificate)
+	sshCert, ok := publicKey.(*ssh.Certificate)
 	if !ok {
 		return nil, trace.BadParameter("no certificate found")
 	}
 
 	// Extract from the certificate how much longer it will be valid for.
-	validUntil := time.Unix(int64(cert.ValidBefore), 0)
+	validUntil := time.Unix(int64(sshCert.ValidBefore), 0)
 
 	// Extract roles from certificate. Note, if the certificate is in old format,
 	// this will be empty.
 	var roles []string
-	rawRoles, ok := cert.Extensions[teleport.CertExtensionTeleportRoles]
+	rawRoles, ok := sshCert.Extensions[teleport.CertExtensionTeleportRoles]
 	if ok {
 		roles, err = services.UnmarshalCertRoles(rawRoles)
 		if err != nil {
@@ -407,7 +418,7 @@ func readProfile(profileDir string, profileName string) (*ProfileStatus, error)
 	// Extract traits from the certificate. Note if the certificate is in the
 	// old format, this will be empty.
 	var traits wrappers.Traits
-	rawTraits, ok := cert.Extensions[teleport.CertExtensionTeleportTraits]
+	rawTraits, ok := sshCert.Extensions[teleport.CertExtensionTeleportTraits]
 	if ok {
 		err = wrappers.UnmarshalTraits([]byte(rawTraits), &traits)
 		if err != nil {
@@ -416,7 +427,7 @@ func readProfile(profileDir string, profileName string) (*ProfileStatus, error)
 	}
 
 	var activeRequests services.RequestIDs
-	rawRequests, ok := cert.Extensions[teleport.CertExtensionTeleportActiveRequests]
+	rawRequests, ok := sshCert.Extensions[teleport.CertExtensionTeleportActiveRequests]
 	if ok {
 		if err := activeRequests.Unmarshal([]byte(rawRequests)); err != nil {
 			return nil, trace.Wrap(err)
@@ -426,7 +437,7 @@ func readProfile(profileDir string, profileName string) (*ProfileStatus, error)
 	// Extract extensions from certificate. This lists the abilities of the
 	// certificate (like can the user request a PTY, port forwarding, etc.)
 	var extensions []string
-	for ext := range cert.Extensions {
+	for ext := range sshCert.Extensions {
 		if ext == teleport.CertExtensionTeleportRoles ||
 			ext == teleport.CertExtensionTeleportTraits ||
 			ext == teleport.CertExtensionTeleportRouteToCluster ||
@@ -448,19 +459,34 @@ func readProfile(profileDir string, profileName string) (*ProfileStatus, error)
 		clusterName = profile.Name()
 	}
 
+	tlsCert, err := key.TLSCertificate()
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+	tlsID, err := tlsca.FromSubject(tlsCert.Subject, time.Time{})
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+
 	return &ProfileStatus{
 		ProxyURL: url.URL{
 			Scheme: "https",
 			Host:   profile.WebProxyAddr,
 		},
 		Username:       profile.Username,
-		Logins:         cert.ValidPrincipals,
+		Logins:         sshCert.ValidPrincipals,
 		ValidUntil:     validUntil,
 		Extensions:     extensions,
 		Roles:          roles,
 		Cluster:        clusterName,
 		Traits:         traits,
 		ActiveRequests: activeRequests,
+		// The TLS cert may have k8s users and groups even when proxy isn't
+		// configured to talk to a k8s cluster.
+		// RouteToCluster is only set when k8s support is enabled though.
+		KubeEnabled: tlsID.RouteToCluster != "" && (len(tlsID.KubernetesUsers) > 0 || len(tlsID.KubernetesGroups) > 0),
+		KubeUsers:   tlsID.KubernetesUsers,
+		KubeGroups:  tlsID.KubernetesGroups,
 	}, nil
 }
 
diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go
index 146cd98533d4e..78eab94e56f9a 100644
--- a/tool/tsh/tsh.go
+++ b/tool/tsh/tsh.go
@@ -1278,25 +1278,36 @@ func printStatus(debug bool, p *client.ProfileStatus, isActive bool) {
 		humanDuration = fmt.Sprintf("valid for %v", duration.Round(time.Minute))
 	}
 
-	fmt.Printf("%vProfile URL:  %v\n", prefix, p.ProxyURL.String())
-	fmt.Printf("  Logged in as: %v\n", p.Username)
+	fmt.Printf("%vProfile URL:       %v\n", prefix, p.ProxyURL.String())
+	fmt.Printf("  Logged in as:      %v\n", p.Username)
 	if p.Cluster != "" {
-		fmt.Printf("  Cluster:      %v\n", p.Cluster)
+		fmt.Printf("  Cluster:           %v\n", p.Cluster)
 	}
-	fmt.Printf("  Roles:        %v*\n", strings.Join(p.Roles, ", "))
+	fmt.Printf("  Roles:             %v*\n", strings.Join(p.Roles, ", "))
 	if debug {
 		for k, v := range p.Traits {
 			if count == 0 {
-				fmt.Printf("  Traits:       %v: %v\n", k, v)
+				fmt.Printf("  Traits:            %v: %v\n", k, v)
 			} else {
-				fmt.Printf("                %v: %v\n", k, v)
+				fmt.Printf("                     %v: %v\n", k, v)
 			}
 			count = count + 1
 		}
 	}
-	fmt.Printf("  Logins:       %v\n", strings.Join(p.Logins, ", "))
-	fmt.Printf("  Valid until:  %v [%v]\n", p.ValidUntil, humanDuration)
-	fmt.Printf("  Extensions:   %v\n", strings.Join(p.Extensions, ", "))
+	fmt.Printf("  Logins:            %v\n", strings.Join(p.Logins, ", "))
+	if p.KubeEnabled {
+		fmt.Printf("  Kubernetes:        enabled\n")
+	} else {
+		fmt.Printf("  Kubernetes:        disabled\n")
+	}
+	if p.KubeEnabled && len(p.KubeUsers) > 0 {
+		fmt.Printf("  Kubernetes users:  %v\n", strings.Join(p.KubeUsers, ", "))
+	}
+	if p.KubeEnabled && len(p.KubeGroups) > 0 {
+		fmt.Printf("  Kubernetes groups: %v\n", strings.Join(p.KubeGroups, ", "))
+	}
+	fmt.Printf("  Valid until:       %v [%v]\n", p.ValidUntil, humanDuration)
+	fmt.Printf("  Extensions:        %v\n", strings.Join(p.Extensions, ", "))
 
 	fmt.Printf("\n")
 }