Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ import (
"github.com/gravitational/teleport/lib/githubactions"
"github.com/gravitational/teleport/lib/gitlab"
"github.com/gravitational/teleport/lib/inventory"
kubeutils "github.com/gravitational/teleport/lib/kube/utils"
"github.com/gravitational/teleport/lib/kubernetestoken"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/loginrule"
Expand Down Expand Up @@ -3311,9 +3310,29 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types.
// If the certificate is targeting a trusted Teleport cluster, it is the
// responsibility of the cluster to ensure its existence.
if req.routeToCluster == clusterName && req.kubernetesCluster != "" {
if err := kubeutils.CheckKubeCluster(a.closeCtx, a, req.kubernetesCluster); err != nil {
found, _, err := a.UnifiedResourceCache.IterateUnifiedResources(a.closeCtx, func(rwl types.ResourceWithLabels) (bool, error) {
if rwl.GetKind() != types.KindKubeServer {
return false, nil
}

ks, ok := rwl.(types.KubeServer)
if !ok {
return false, nil
}

return ks.GetCluster().GetName() == req.kubernetesCluster, nil
}, &proto.ListUnifiedResourcesRequest{
Kinds: []string{types.KindKubeServer},
SortBy: types.SortBy{Field: services.SortByName},
Limit: 1,
})
if err != nil {
return nil, trace.Wrap(err)
}

if len(found) == 0 {
return nil, trace.BadParameter("Kubernetes cluster %q is not registered in this Teleport cluster; you can list registered Kubernetes clusters using 'tsh kube ls'", req.kubernetesCluster)
}
}

// See which database names and users this user is allowed to use.
Expand Down
107 changes: 105 additions & 2 deletions lib/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,25 @@ func newTestPack(
}
p.a.SetLockWatcher(lockWatcher)

// set cluster name
err = p.a.SetClusterName(p.clusterName)
urc, err := services.NewUnifiedResourceCache(ctx, services.UnifiedResourceCacheConfig{
Clock: p.a.clock,
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentAuth,
Client: p.a,
},
ResourceGetter: p.a,
})
if err != nil {
return p, trace.Wrap(err)
}

p.a.SetUnifiedResourcesCache(urc)

// set cluster name
if err := p.a.SetClusterName(p.clusterName); err != nil {
return p, trace.Wrap(err)
}

// set static tokens
staticTokens, err := types.NewStaticTokens(types.StaticTokensSpecV2{
StaticTokens: []types.ProvisionTokenV1{},
Expand Down Expand Up @@ -3004,6 +3017,96 @@ func TestGenerateUserCertWithHardwareKeySupport(t *testing.T) {
}
}

func TestGenerateKubernetesUserCert(t *testing.T) {
ctx := context.Background()
p, err := newTestPack(ctx, t.TempDir())
require.NoError(t, err)

user, _, err := CreateUserAndRole(p.a, "test-user", []string{}, nil)
require.NoError(t, err)

rc, err := types.NewRemoteCluster("leaf")
require.NoError(t, err)
_, err = p.a.CreateRemoteCluster(ctx, rc)
require.NoError(t, err)

kubeCluster, err := types.NewKubernetesClusterV3(types.Metadata{Name: "kube-cluster"}, types.KubernetesClusterSpecV3{})
require.NoError(t, err)
kubeServer, err := types.NewKubernetesServerV3FromCluster(kubeCluster, "foo", "1")
require.NoError(t, err)
_, err = p.a.UpsertKubernetesServer(ctx, kubeServer)
require.NoError(t, err)

// Wait for cache propagation of the kubernetes resources before proceeding with the tests.
require.EventuallyWithT(t, func(t *assert.CollectT) {
found, _, err := p.a.UnifiedResourceCache.IterateUnifiedResources(ctx, func(rwl types.ResourceWithLabels) (bool, error) {
if rwl.GetKind() != types.KindKubeServer {
return false, nil
}

ks, ok := rwl.(types.KubeServer)
if !ok {
return false, nil
}

return ks.GetCluster().GetName() == kubeCluster.GetName(), nil
}, &proto.ListUnifiedResourcesRequest{
Kinds: []string{types.KindKubeServer},
SortBy: types.SortBy{Field: services.SortByName},
Limit: 1,
})

assert.NoError(t, err)
assert.Len(t, found, 1)
}, 10*time.Second, 100*time.Millisecond)

accessInfo := services.AccessInfoFromUserState(user)
accessChecker, err := services.NewAccessChecker(accessInfo, p.clusterName.GetClusterName(), p.a)
require.NoError(t, err)

_, sshPubKey, _, tlsPubKey := newSSHAndTLSKeyPairs(t)

for _, tt := range []struct {
name string
teleportCluster string
kubernetesCluster string
assertErr require.ErrorAssertionFunc
}{
{
name: "leaf clusters not validated",
teleportCluster: "leaf",
kubernetesCluster: "foo",
assertErr: require.NoError,
},
{
name: "kubernetes cluster not registered",
teleportCluster: p.clusterName.GetClusterName(),
kubernetesCluster: "foo",
assertErr: require.Error,
},
{
name: "kubernetes cluster registered",
teleportCluster: p.clusterName.GetClusterName(),
kubernetesCluster: kubeCluster.GetName(),
assertErr: require.NoError,
},
} {
t.Run(tt.name, func(t *testing.T) {
certReq := certRequest{
user: user,
checker: accessChecker,
sshPublicKey: sshPubKey,
tlsPublicKey: tlsPubKey,
routeToCluster: tt.teleportCluster,
kubernetesCluster: tt.kubernetesCluster,
}

_, err = p.a.generateUserCert(ctx, certReq)
tt.assertErr(t, err)
})
}
}

func TestNewWebSession(t *testing.T) {
t.Parallel()
ctx := context.Background()
Expand Down
59 changes: 0 additions & 59 deletions lib/kube/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"context"
"encoding/hex"
"errors"
"slices"
"strings"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -148,48 +147,6 @@ func EncodeClusterName(clusterName string) string {
return "k" + hex.EncodeToString([]byte(clusterName))
}

// KubeServicesPresence fetches a list of registered kubernetes servers.
// It's a subset of services.Presence.
type KubeServicesPresence interface {
// GetKubernetesServers returns a list of registered kubernetes servers.
GetKubernetesServers(context.Context) ([]types.KubeServer, error)
}

// KubeClusterNames returns a sorted list of unique kubernetes cluster
// names registered in p.
//
// DELETE IN 11.0.0, replaced by ListKubeClustersWithFilters
func KubeClusterNames(ctx context.Context, p KubeServicesPresence) ([]string, error) {
kss, err := p.GetKubernetesServers(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
return extractAndSortKubeClusterNames(kss), nil
}

func extractAndSortKubeClusterNames(kubeServers []types.KubeServer) []string {
kubeClusters := extractAndSortKubeClusters(kubeServers)
kubeClusterNames := make([]string, len(kubeClusters))
for i := range kubeClusters {
kubeClusterNames[i] = kubeClusters[i].GetName()
}

return kubeClusterNames
}

// KubeClusters returns a sorted list of unique kubernetes clusters
// registered in p.
//
// DELETE IN 11.0.0, replaced by ListKubeClustersWithFilters
func KubeClusters(ctx context.Context, p KubeServicesPresence) ([]types.KubeCluster, error) {
kubeServers, err := p.GetKubernetesServers(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

return extractAndSortKubeClusters(kubeServers), nil
}

// ListKubeClustersWithFilters returns a sorted list of unique kubernetes clusters
// registered in p.
func ListKubeClustersWithFilters(ctx context.Context, p client.GetResourcesClient, req proto.ListResourcesRequest) ([]types.KubeCluster, error) {
Expand Down Expand Up @@ -245,19 +202,3 @@ func GetKubeAgentVersion(ctx context.Context, pinger Pinger, clusterFeatures pro

return strings.TrimPrefix(agentVersion, "v"), nil
}

// CheckKubeCluster validates kubeClusterName is registered with this Teleport cluster.
func CheckKubeCluster(ctx context.Context, p KubeServicesPresence, kubeClusterName string) error {
if kubeClusterName == "" {
return trace.BadParameter("kube cluster name should not be empty.")
}
kubeClusterNames, err := KubeClusterNames(ctx, p)
if err != nil {
return trace.Wrap(err, "failed to get list of available Kubernetes clusters.")
}
if !slices.Contains(kubeClusterNames, kubeClusterName) {
return trace.BadParameter("Kubernetes cluster %q is not registered in this Teleport cluster; you can list registered Kubernetes clusters using 'tsh kube ls'", kubeClusterName)
}

return nil
}
82 changes: 0 additions & 82 deletions lib/kube/utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,74 +26,9 @@ import (
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/automaticupgrades"
)

func TestCheckKubeCluster(t *testing.T) {
t.Parallel()
ctx := context.Background()

kubeServers := []types.KubeServer{
kubeServer(t, "k8s-1", "server1", "uuuid"),
kubeServer(t, "k8s-2", "server1", "uuuid"),
kubeServer(t, "k8s-3", "server1", "uuuid"),
kubeServer(t, "k8s-4", "server1", "uuuid"),
}

tests := []struct {
desc string
services []types.KubeServer
kubeCluster string
assertErr require.ErrorAssertionFunc
}{
{
desc: "valid cluster name",
services: kubeServers,
kubeCluster: "k8s-4",
assertErr: require.NoError,
},
{
desc: "invalid cluster name",
services: kubeServers,
kubeCluster: "k8s-5",
assertErr: require.Error,
},
{
desc: "no registered clusters",
services: []types.KubeServer{},
kubeCluster: "k8s-1",
assertErr: require.Error,
},
{
desc: "empty cluster provided",
services: kubeServers,
kubeCluster: "",
assertErr: require.Error,
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
err := CheckKubeCluster(ctx, mockKubeServicesPresence(tt.services), tt.kubeCluster)
tt.assertErr(t, err)
})
}
}

type mockKubeServicesPresence []types.KubeServer

func (p mockKubeServicesPresence) GetKubernetesServers(context.Context) ([]types.KubeServer, error) {
return p, nil
}

func kubeServer(t *testing.T, kubeCluster, hostname, hostID string) types.KubeServer {
cluster, err := types.NewKubernetesClusterV3(types.Metadata{Name: kubeCluster}, types.KubernetesClusterSpecV3{})
require.NoError(t, err)
server, err := types.NewKubernetesServerV3FromCluster(cluster, hostname, hostID)
require.NoError(t, err)
return server
}

func TestGetAgentVersion(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -162,20 +97,3 @@ type pinger struct {
func (p *pinger) Ping(ctx context.Context) (proto.PingResponse, error) {
return p.pingFn(ctx)
}

func TestExtractAndSortKubeClusterNames(t *testing.T) {
t.Parallel()

server1 := kubeServer(t, "watermelon", "server1", "uuuid")

server2 := kubeServer(t, "watermelon", "server1", "uuuid")

server3 := kubeServer(t, "banana", "server2", "uuuid2")

server4 := kubeServer(t, "apple", "server2", "uuuid2")

server5 := kubeServer(t, "pear", "server2", "uuuid2")

names := extractAndSortKubeClusterNames(types.KubeServers{server1, server2, server3, server4, server5})
require.Equal(t, []string{"apple", "banana", "pear", "watermelon"}, names)
}
20 changes: 2 additions & 18 deletions tool/tctl/common/auth_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ import (
"github.com/gravitational/teleport/lib/client/identityfile"
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/defaults"
kubeutils "github.com/gravitational/teleport/lib/kube/utils"
"github.com/gravitational/teleport/lib/service/servicecfg"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -317,7 +316,6 @@ func (a *AuthCommand) GenerateKeys(ctx context.Context, clusterAPI authCommandCl
// certificateSigner is an interface for the methods used by GenerateAndSignKeys
// to sign certificates using the Auth Server.
type certificateSigner interface {
kubeutils.KubeServicesPresence
GenerateDatabaseCert(context.Context, *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error)
GenerateUserCerts(ctx context.Context, req proto.UserCertsRequest) (*proto.Certs, error)
GenerateWindowsDesktopCert(context.Context, *proto.WindowsDesktopCertRequest) (*proto.WindowsDesktopCertResponse, error)
Expand Down Expand Up @@ -931,7 +929,7 @@ func (a *AuthCommand) generateUserKeys(ctx context.Context, clusterAPI certifica
}
keyRing.ClusterName = a.leafCluster

if err := a.checkKubeCluster(ctx, clusterAPI); err != nil {
if err := a.checkKubeCluster(); err != nil {
return trace.Wrap(err)
}

Expand Down Expand Up @@ -1092,7 +1090,7 @@ func (a *AuthCommand) checkLeafCluster(clusterAPI certificateSigner) error {
return trace.BadParameter("couldn't find leaf cluster named %q", a.leafCluster)
}

func (a *AuthCommand) checkKubeCluster(ctx context.Context, clusterAPI certificateSigner) error {
func (a *AuthCommand) checkKubeCluster() error {
if a.kubeCluster == "" {
return nil
}
Expand All @@ -1105,20 +1103,6 @@ func (a *AuthCommand) checkKubeCluster(ctx context.Context, clusterAPI certifica
return nil
}

localCluster, err := clusterAPI.GetClusterName()
if err != nil {
return trace.Wrap(err)
}
if localCluster.GetClusterName() != a.leafCluster {
// Skip validation on remote clusters, since we don't know their
// registered kube clusters.
return nil
}

if err := kubeutils.CheckKubeCluster(ctx, clusterAPI, a.kubeCluster); err != nil {
return trace.Wrap(err)
}

return nil
}

Expand Down
Loading