diff --git a/lib/auth/middleware.go b/lib/auth/middleware.go index 27ae5457ce918..82f7982645e38 100644 --- a/lib/auth/middleware.go +++ b/lib/auth/middleware.go @@ -26,9 +26,11 @@ import ( "math" "net" "net/http" + "os" "slices" "time" + "github.com/coreos/go-semver/semver" "github.com/gravitational/oxy/ratelimit" "github.com/gravitational/trace" om "github.com/grpc-ecosystem/go-grpc-middleware/providers/openmetrics/v2" @@ -42,6 +44,7 @@ import ( "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/metadata" "github.com/gravitational/teleport/api/types" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/grpc/interceptors" @@ -163,14 +166,20 @@ func NewTLSServer(ctx context.Context, cfg TLSServerConfig) (*TLSServer, error) return nil, trace.Wrap(err) } + var oldestSupportedVersion *semver.Version + if os.Getenv("TELEPORT_UNSTABLE_REJECT_OLD_CLIENTS") == "yes" { + oldestSupportedVersion = &teleport.MinClientSemVersion + } + // authMiddleware authenticates request assuming TLS client authentication // adds authentication information to the context // and passes it to the API server authMiddleware := &Middleware{ - ClusterName: localClusterName.GetClusterName(), - AcceptedUsage: cfg.AcceptedUsage, - Limiter: limiter, - GRPCMetrics: grpcMetrics, + ClusterName: localClusterName.GetClusterName(), + AcceptedUsage: cfg.AcceptedUsage, + Limiter: limiter, + GRPCMetrics: grpcMetrics, + OldestSupportedVersion: oldestSupportedVersion, } apiServer, err := NewAPIServer(&cfg.APIConfig) @@ -366,6 +375,10 @@ type Middleware struct { // This is used by the proxy to forward the identity of the user who // connected to the proxy to the next hop. EnableCredentialsForwarding bool + // OldestSupportedVersion optionally allows the middleware to reject any connections + // originated from a client that is using an unsupported version. If not set, then no + // rejection occurs. + OldestSupportedVersion *semver.Version } // Wrap sets next handler in chain @@ -404,6 +417,40 @@ func getCustomRate(endpoint string) *ratelimit.RateSet { return nil } +// ValidateClientVersion inspects the client version for the connection and terminates +// the [IdentityInfo.Conn] if the client is unsupported. Requires the [Middleware.OldestSupportedVersion] +// to be configured before any connection rejection occurs. +func (a *Middleware) ValidateClientVersion(ctx context.Context, info IdentityInfo) error { + if a.OldestSupportedVersion == nil { + return nil + } + + clientVersionString, versionExists := metadata.ClientVersionFromContext(ctx) + if !versionExists { + return nil + } + + logger := log.WithFields(logrus.Fields{"identity": info.IdentityGetter.GetIdentity().Username, "version": clientVersionString}) + clientVersion, err := semver.NewVersion(clientVersionString) + if err != nil { + logger.WithError(err).Warn("Failed to determine client version") + if err := info.Conn.Close(); err != nil { + logger.WithError(err).Warn("Failed to close client connection") + } + return trace.AccessDenied("client version is unsupported") + } + + if clientVersion.LessThan(*a.OldestSupportedVersion) { + logger.Info("Terminating connection of client using unsupported version") + if err := info.Conn.Close(); err != nil { + logger.WithError(err).Warn("Failed to close client connection") + } + return trace.AccessDenied("client version is unsupported") + } + + return nil +} + // withAuthenticatedUser returns a new context with the ContextUser field set to // the caller's user identity as authenticated by their client TLS certificate. func (a *Middleware) withAuthenticatedUser(ctx context.Context) (context.Context, error) { @@ -423,6 +470,10 @@ func (a *Middleware) withAuthenticatedUser(ctx context.Context) (context.Context case IdentityInfo: connState = &info.TLSInfo.State identityGetter = info.IdentityGetter + + if err := a.ValidateClientVersion(ctx, info); err != nil { + return nil, trace.Wrap(err) + } // credentials.TLSInfo is provided if the grpc server is configured with // credentials.NewTLS. case credentials.TLSInfo: diff --git a/lib/auth/middleware_test.go b/lib/auth/middleware_test.go index ecc7630a5b9fc..f0f56fbae1f94 100644 --- a/lib/auth/middleware_test.go +++ b/lib/auth/middleware_test.go @@ -31,10 +31,14 @@ import ( "testing" "time" + "github.com/coreos/go-semver/semver" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/services" @@ -655,3 +659,77 @@ func (h *fakeHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { require.Empty(h.t, r.Header.Get(TeleportImpersonateUserHeader)) require.Empty(h.t, r.Header.Get(TeleportImpersonateIPHeader)) } + +type fakeConn struct { + net.Conn +} + +func (f fakeConn) Close() error { + return nil +} + +func TestValidateClientVersion(t *testing.T) { + cases := []struct { + name string + middleware Middleware + clientVersion string + errAssertion func(t *testing.T, err error) + }{ + { + name: "rejection disabled", + errAssertion: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + { + name: "rejection enabled and client version not specified", + middleware: Middleware{OldestSupportedVersion: &teleport.MinClientSemVersion}, + errAssertion: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + { + name: "client rejected", + middleware: Middleware{OldestSupportedVersion: &teleport.MinClientSemVersion}, + clientVersion: semver.Version{Major: teleport.SemVersion.Major - 2}.String(), + errAssertion: func(t *testing.T, err error) { + require.True(t, trace.IsAccessDenied(err), "got %T, expected access denied error", err) + }, + }, + { + name: "valid client v-1", + middleware: Middleware{OldestSupportedVersion: &teleport.MinClientSemVersion}, + clientVersion: semver.Version{Major: teleport.SemVersion.Major - 1}.String(), + errAssertion: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + { + name: "valid client v-0", + middleware: Middleware{OldestSupportedVersion: &teleport.MinClientSemVersion}, + clientVersion: semver.Version{Major: teleport.SemVersion.Major}.String(), + errAssertion: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + { + name: "invalid client version", + middleware: Middleware{OldestSupportedVersion: &teleport.MinClientSemVersion}, + clientVersion: "abc123", + errAssertion: func(t *testing.T, err error) { + require.True(t, trace.IsAccessDenied(err), "got %T, expected access denied error", err) + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + if tt.clientVersion != "" { + ctx = metadata.NewIncomingContext(ctx, metadata.New(map[string]string{"version": tt.clientVersion})) + } + + tt.errAssertion(t, tt.middleware.ValidateClientVersion(ctx, IdentityInfo{Conn: fakeConn{}, IdentityGetter: TestBuiltin(types.RoleNode).I})) + }) + } +} diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index 274ebfe26fb30..09556380ce9d7 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -47,6 +47,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/metadata" "github.com/gravitational/teleport/api/types" eventtypes "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/wrappers" @@ -66,6 +67,66 @@ import ( "github.com/gravitational/teleport/lib/utils" ) +func TestRejectedClients(t *testing.T) { + t.Setenv("TELEPORT_UNSTABLE_REJECT_OLD_CLIENTS", "yes") + + server, err := NewTestAuthServer(TestAuthServerConfig{ + Dir: t.TempDir(), + ClusterName: "cluster", + Clock: clockwork.NewFakeClock(), + }) + require.NoError(t, err) + + user, _, err := CreateUserAndRole(server.AuthServer, "user", []string{"role"}, nil) + require.NoError(t, err) + + tlsServer, err := server.NewTestTLSServer() + require.NoError(t, err) + defer tlsServer.Close() + + tlsConfig, err := tlsServer.ClientTLSConfig(TestUser(user.GetName())) + require.NoError(t, err) + + clt, err := NewClient(client.Config{ + DialInBackground: true, + Addrs: []string{tlsServer.Addr().String()}, + Credentials: []client.Credentials{ + client.LoadTLS(tlsConfig), + }, + CircuitBreakerConfig: breaker.NoopBreakerConfig(), + }) + require.NoError(t, err) + defer clt.Close() + + t.Run("reject old version", func(t *testing.T) { + version := teleport.MinClientSemVersion + version.Major-- + ctx := context.WithValue(context.Background(), metadata.DisableInterceptors{}, struct{}{}) + ctx = metadata.AddMetadataToContext(ctx, map[string]string{ + metadata.VersionKey: version.String(), + }) + resp, err := clt.Ping(ctx) + require.True(t, trace.IsConnectionProblem(err)) + require.Equal(t, proto.PingResponse{}, resp) + }) + + t.Run("allow valid versions", func(t *testing.T) { + version := teleport.MinClientSemVersion + version.Major-- + for i := 0; i < 5; i++ { + version.Major++ + + ctx := context.WithValue(context.Background(), metadata.DisableInterceptors{}, struct{}{}) + ctx = metadata.AddMetadataToContext(ctx, map[string]string{ + metadata.VersionKey: version.String(), + }) + resp, err := clt.Ping(ctx) + require.NoError(t, err) + require.NotNil(t, resp) + } + }) +} + // TestRemoteBuiltinRole tests remote builtin role // that gets mapped to remote proxy readonly role func TestRemoteBuiltinRole(t *testing.T) { diff --git a/lib/auth/transport_credentials.go b/lib/auth/transport_credentials.go index 880b4e83a480c..082153e9dd9f7 100644 --- a/lib/auth/transport_credentials.go +++ b/lib/auth/transport_credentials.go @@ -139,6 +139,8 @@ type IdentityInfo struct { // [TransportCredentialsConfig.Authorizer] provided to [NewTransportCredentials] // was nil. AuthContext *authz.Context + // Conn is the underlying [net.Conn] of the gRPC connection. + Conn net.Conn } // ServerHandshake does the authentication handshake for servers. It returns @@ -179,6 +181,7 @@ func (c *TransportCredentials) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ TLSInfo: tlsInfo, IdentityGetter: identityGetter, AuthContext: authCtx, + Conn: conn, }, nil }