diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 6f8e805043fb7..d9ca4183ec729 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -59,6 +59,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -4200,9 +4201,44 @@ func (a *ServerWithRoles) ProcessKubeCSR(req KubeCSR) (*KubeCSRResponse, error) if !a.hasBuiltinRole(types.RoleProxy) { return nil, trace.AccessDenied("this request can be only executed by a proxy") } + clusterName, err := a.GetClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + proxyClusterName := a.context.Identity.GetIdentity().TeleportCluster + identityClusterName, err := extractOriginalClusterNameFromCSR(req) + if err != nil { + return nil, trace.Wrap(err) + } + if proxyClusterName != "" && + proxyClusterName != clusterName.GetClusterName() && + proxyClusterName != identityClusterName { + log.WithFields( + logrus.Fields{ + "proxy_cluster_name": proxyClusterName, + "identity_cluster_name": identityClusterName, + }, + ).Warn("KubeCSR request denied because the proxy and identity clusters didn't match") + return nil, trace.AccessDenied("can not sign certs for users via a different cluster proxy") + } return a.authServer.ProcessKubeCSR(req) } +func extractOriginalClusterNameFromCSR(req KubeCSR) (string, error) { + csr, err := tlsca.ParseCertificateRequestPEM(req.CSR) + if err != nil { + return "", trace.Wrap(err) + } + + // Extract identity from the CSR. Pass zero time for id.Expiry, it won't be + // used here. + id, err := tlsca.FromSubject(csr.Subject, time.Time{}) + if err != nil { + return "", trace.Wrap(err) + } + return id.TeleportCluster, nil +} + // GetDatabaseServers returns all registered database servers. func (a *ServerWithRoles) GetDatabaseServers(ctx context.Context, namespace string, opts ...services.MarshalOption) ([]types.DatabaseServer, error) { if err := a.action(namespace, types.KindDatabaseServer, types.VerbList, types.VerbRead); err != nil { diff --git a/lib/auth/middleware_test.go b/lib/auth/middleware_test.go index e3fc68e307f60..7545fd24e0c38 100644 --- a/lib/auth/middleware_test.go +++ b/lib/auth/middleware_test.go @@ -194,7 +194,6 @@ func TestMiddlewareGetUser(t *testing.T) { } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - m := &Middleware{ AccessPoint: s, } @@ -350,8 +349,10 @@ func TestWrapContextWithUser(t *testing.T) { } conn := &testConn{ - state: tls.ConnectionState{PeerCertificates: tt.peers, - HandshakeComplete: !tt.needsHandshake}, + state: tls.ConnectionState{ + PeerCertificates: tt.peers, + HandshakeComplete: !tt.needsHandshake, + }, remoteAddr: utils.MustParseAddr("127.0.0.1:4242"), }