diff --git a/integration/proxy/proxy_helpers.go b/integration/proxy/proxy_helpers.go index 7e1dff01fbd0e..47ecddcab5b49 100644 --- a/integration/proxy/proxy_helpers.go +++ b/integration/proxy/proxy_helpers.go @@ -22,20 +22,25 @@ import ( "crypto/tls" "crypto/x509/pkix" "encoding/json" + "fmt" + "io" "net" "net/http" "net/http/httptest" "net/url" "path/filepath" + "strings" "testing" "time" + "github.com/aws/aws-sdk-go/aws/credentials" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jackc/pgconn" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/tools/clientcmd" @@ -46,7 +51,9 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/integration/helpers" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" @@ -55,6 +62,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db/postgres" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" + awsutils "github.com/gravitational/teleport/lib/utils/aws" ) type Suite struct { @@ -692,3 +700,75 @@ func mustParseURL(t *testing.T, rawURL string) *url.URL { require.NoError(t, err) return u } + +// fakeSTSClient is a fake HTTP client used to fake STS responses when Auth +// server sends out pre-signed STS requests for IAM join verification. +type fakeSTSClient struct { + accountID string + arn string + credentials *credentials.Credentials +} + +func (f fakeSTSClient) Do(req *http.Request) (*http.Response, error) { + if err := awsutils.VerifyAWSSignature(req, f.credentials); err != nil { + return nil, trace.Wrap(err) + } + response := fmt.Sprintf(`{"GetCallerIdentityResponse": {"GetCallerIdentityResult": {"Account": "%s", "Arn": "%s" }}}`, f.accountID, f.arn) + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(response)), + }, nil +} + +func mustCreateIAMJoinProvisionToken(t *testing.T, name, awsAccountID, allowedARN string) types.ProvisionToken { + t.Helper() + + provisionToken, err := types.NewProvisionTokenFromSpec( + name, + time.Now().Add(time.Hour), + types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleNode}, + Allow: []*types.TokenRule{ + { + AWSAccount: awsAccountID, + AWSARN: allowedARN, + }, + }, + JoinMethod: types.JoinMethodIAM, + }, + ) + require.NoError(t, err) + return provisionToken +} + +func mustRegisterUsingIAMMethod(t *testing.T, proxyAddr utils.NetAddr, token string, credentials *credentials.Credentials) { + t.Helper() + + cred, err := credentials.Get() + require.NoError(t, err) + + t.Setenv("AWS_ACCESS_KEY_ID", cred.AccessKeyID) + t.Setenv("AWS_SECRET_ACCESS_KEY", cred.SecretAccessKey) + t.Setenv("AWS_SESSION_TOKEN", cred.SessionToken) + t.Setenv("AWS_REGION", "us-west-2") + + privateKey, err := ssh.ParseRawPrivateKey([]byte(fixtures.SSHCAPrivateKey)) + require.NoError(t, err) + pubTLS, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(privateKey) + require.NoError(t, err) + + node := uuid.NewString() + _, err = auth.Register(auth.RegisterParams{ + Token: token, + ID: auth.IdentityID{ + Role: types.RoleNode, + HostUUID: node, + NodeName: node, + }, + ProxyServer: proxyAddr, + JoinMethod: types.JoinMethodIAM, + PublicTLSKey: pubTLS, + PublicSSHKey: []byte(fixtures.SSHCAPublicKey), + }) + require.NoError(t, err, trace.DebugReport(err)) +} diff --git a/integration/proxy/proxy_test.go b/integration/proxy/proxy_test.go index 32aa114a59d35..f03b8f8b94ebd 100644 --- a/integration/proxy/proxy_test.go +++ b/integration/proxy/proxy_test.go @@ -29,6 +29,7 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go/aws/credentials" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -46,6 +47,7 @@ import ( "github.com/gravitational/teleport/integration/helpers" "github.com/gravitational/teleport/integration/kube" "github.com/gravitational/teleport/lib" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/testauthority" libclient "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" @@ -1505,3 +1507,39 @@ func TestALPNProxyHTTPProxyBasicAuthDial(t *testing.T) { require.NoError(t, helpers.WaitForNodeCount(context.Background(), rc, rc.Secrets.SiteName, 1)) require.Greater(t, ph.Count(), 0) } + +// TestALPNSNIProxyGRPCInsecure tests ALPN protocol ProtocolProxyGRPCInsecure +// by registering a node with IAM join method. +func TestALPNSNIProxyGRPCInsecure(t *testing.T) { + lib.SetInsecureDevMode(true) + defer lib.SetInsecureDevMode(false) + + nodeAccount := "123456789012" + nodeRoleARN := "arn:aws:iam::123456789012:role/test" + nodeCredentials := credentials.NewStaticCredentials("FAKE_ID", "FAKE_KEY", "FAKE_TOKEN") + provisionToken := mustCreateIAMJoinProvisionToken(t, "iam-join-token", nodeAccount, nodeRoleARN) + + suite := newSuite(t, + withRootClusterConfig(rootClusterStandardConfig(t), func(config *servicecfg.Config) { + config.Auth.BootstrapResources = []types.Resource{provisionToken} + config.Auth.ServerOptions = []auth.ServerOption{auth.WithHTTPClientForAWSSTS(fakeSTSClient{ + accountID: nodeAccount, + arn: nodeRoleARN, + credentials: nodeCredentials, + })} + }), + withLeafClusterConfig(leafClusterStandardConfig(t)), + ) + + // Test register through Proxy. + mustRegisterUsingIAMMethod(t, suite.root.Config.Proxy.WebAddr, provisionToken.GetName(), nodeCredentials) + + // Test register through Proxy behind a L7 load balancer. + t.Run("ALPN conn upgrade", func(t *testing.T) { + albProxy := mustStartMockALBProxy(t, suite.root.Config.Proxy.WebAddr.Addr) + albAddr, err := utils.ParseAddr(albProxy.Addr().String()) + require.NoError(t, err) + + mustRegisterUsingIAMMethod(t, *albAddr, provisionToken.GetName(), nodeCredentials) + }) +} diff --git a/lib/auth/auth.go b/lib/auth/auth.go index c3fcee6f9ef8e..91406c1d4baea 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -593,6 +593,10 @@ type Server struct { loginHooksMu sync.RWMutex // loginHooks are a list of hooks that will be called on login. loginHooks []LoginHook + + // httpClientForAWSSTS overwrites the default HTTP client used for making + // STS requests. + httpClientForAWSSTS stsClient } // SetSAMLService registers svc as the SAMLService that provides the SAML @@ -5398,3 +5402,12 @@ func DefaultDNSNamesForRole(role types.SystemRole) []string { } return nil } + +// WithHTTPClientForAWSSTS is a ServerOption that overwrites default HTTP +// client used for STS requests. +func WithHTTPClientForAWSSTS(client stsClient) ServerOption { + return func(s *Server) error { + s.httpClientForAWSSTS = client + return nil + } +} diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index 19af0f15273bb..0018454a99186 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -985,7 +985,7 @@ func NewServerIdentity(clt *Server, hostID string, role types.SystemRole) (*Iden &proto.HostCertsRequest{ HostID: hostID, NodeName: hostID, - Role: types.RoleAuth, + Role: role, PublicTLSKey: publicTLS, PublicSSHKey: pub, }) diff --git a/lib/auth/join_iam.go b/lib/auth/join_iam.go index cfa6f2fa2dbf7..b63ef32567283 100644 --- a/lib/auth/join_iam.go +++ b/lib/auth/join_iam.go @@ -196,21 +196,12 @@ type stsClient interface { Do(*http.Request) (*http.Response, error) } -type stsClientKey struct{} - -// stsClientFromContext allows the default http client to be overridden for tests -func stsClientFromContext(ctx context.Context) stsClient { - client, ok := ctx.Value(stsClientKey{}).(stsClient) - if ok { - return client - } - return http.DefaultClient -} - // executeSTSIdentityRequest sends the sts:GetCallerIdentity HTTP request to the // AWS API, parses the response, and returns the awsIdentity -func executeSTSIdentityRequest(ctx context.Context, req *http.Request) (*awsIdentity, error) { - client := stsClientFromContext(ctx) +func executeSTSIdentityRequest(ctx context.Context, client stsClient, req *http.Request) (*awsIdentity, error) { + if client == nil { + client = http.DefaultClient + } // set the http request context so it can be canceled req = req.WithContext(ctx) @@ -312,7 +303,7 @@ func (a *Server) checkIAMRequest(ctx context.Context, challenge string, req *pro // send the signed request to the public AWS API and get the node identity // from the response - identity, err := executeSTSIdentityRequest(ctx, identityRequest) + identity, err := executeSTSIdentityRequest(ctx, a.httpClientForAWSSTS, identityRequest) if err != nil { return trace.Wrap(err) } diff --git a/lib/auth/join_iam_test.go b/lib/auth/join_iam_test.go index c67b40f67ea18..3a8ef165b9fda 100644 --- a/lib/auth/join_iam_test.go +++ b/lib/auth/join_iam_test.go @@ -490,6 +490,9 @@ func TestAuth_RegisterUsingIAMMethod(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { + // Set mock client. + a.httpClientForAWSSTS = tc.stsClient + // add token to auth server token, err := types.NewProvisionTokenFromSpec( tc.tokenName, @@ -503,7 +506,6 @@ func TestAuth_RegisterUsingIAMMethod(t *testing.T) { requestContext := context.Background() requestContext = authz.ContextWithClientAddr(requestContext, &net.IPAddr{}) - requestContext = context.WithValue(requestContext, stsClientKey{}, tc.stsClient) _, err = a.RegisterUsingIAMMethod(requestContext, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) { templateInput := defaultIdentityRequestTemplateInput(challenge) diff --git a/lib/auth/register.go b/lib/auth/register.go index 4dbba6bc270ec..f4ccdc2de5a9b 100644 --- a/lib/auth/register.go +++ b/lib/auth/register.go @@ -18,12 +18,14 @@ package auth import ( "context" + "crypto/tls" "crypto/x509" "os" "time" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "golang.org/x/exp/slices" "golang.org/x/net/http2" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -32,6 +34,8 @@ import ( "github.com/gravitational/teleport/api/breaker" "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/metadata" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib" @@ -300,7 +304,7 @@ func registerThroughProxy(token string, params RegisterParams) (*proto.Certs, er switch params.JoinMethod { case types.JoinMethodIAM, types.JoinMethodAzure: // IAM and Azure join methods require gRPC client - conn, err := proxyJoinServiceConn(params) + conn, err := proxyJoinServiceConn(params, lib.IsInsecureDevMode()) if err != nil { return nil, trace.Wrap(err) } @@ -402,19 +406,44 @@ func registerThroughAuth(token string, params RegisterParams) (*proto.Certs, err // proxyJoinServiceConn attempts to connect to the join service running on the // proxy. The Proxy's TLS cert will be verified using the host's root CA pool // (PKI) unless the --insecure flag was passed. -func proxyJoinServiceConn(params RegisterParams) (*grpc.ClientConn, error) { +func proxyJoinServiceConn(params RegisterParams, insecure bool) (*grpc.ClientConn, error) { tlsConfig := utils.TLSConfig(params.CipherSuites) tlsConfig.Time = params.Clock.Now // set NextProtos for TLS routing, the actual protocol will be h2 tlsConfig.NextProtos = []string{string(common.ProtocolProxyGRPCInsecure), http2.NextProtoTLS} - if lib.IsInsecureDevMode() { + if insecure { tlsConfig.InsecureSkipVerify = true log.Warnf("Joining cluster without validating the identity of the Proxy Server.") } + // Check if proxy is behind a load balancer. If so, the connection upgrade + // will verify the load balancer's cert using system cert pool. This + // provides the same level of security as the client only verifies Proxy's + // web cert against system cert pool when connection upgrade is not + // required. + // + // With the ALPN connection upgrade, the tunneled TLS Routing request will + // skip verify as the Proxy server will present its host cert which is not + // fully verifiable at this point since the client does not have the host + // CAs yet before completing registration. + alpnConnUpgrade := client.IsALPNConnUpgradeRequired(getHostAddresses(params)[0], insecure) + if alpnConnUpgrade && !insecure { + tlsConfig.InsecureSkipVerify = true + tlsConfig.VerifyConnection = verifyALPNUpgradedConn(params.Clock) + } + + dialer := client.NewDialer( + context.Background(), + apidefaults.DefaultIdleTimeout, + apidefaults.DefaultIOTimeout, + client.WithInsecureSkipVerify(insecure), + client.WithALPNConnUpgrade(alpnConnUpgrade), + ) + conn, err := grpc.Dial( getHostAddresses(params)[0], + grpc.WithContextDialer(client.GRPCContextDialer(dialer)), grpc.WithUnaryInterceptor(metadata.UnaryClientInterceptor), grpc.WithStreamInterceptor(metadata.StreamClientInterceptor), grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), @@ -422,6 +451,25 @@ func proxyJoinServiceConn(params RegisterParams) (*grpc.ClientConn, error) { return conn, trace.Wrap(err) } +// verifyALPNUpgradedConn is a tls.Config.VerifyConnection callback function +// used by the tunneled TLS Routing request to verify the host cert of a Proxy +// behind a L7 load balancer. +// +// Since the client has not obtained the cluster CAs at this point, the +// presented cert cannot be fully verified yet. For now, this function only +// checks if "teleport.cluster.local" is present as one of the DNS names and +// verifies the cert is not expired. +func verifyALPNUpgradedConn(clock clockwork.Clock) func(tls.ConnectionState) error { + return func(server tls.ConnectionState) error { + for _, cert := range server.PeerCertificates { + if slices.Contains(cert.DNSNames, constants.APIDomain) && clock.Now().Before(cert.NotAfter) { + return nil + } + } + return trace.AccessDenied("server is not a Teleport proxy or server certificate is expired") + } +} + // insecureRegisterClient attempts to connects to the Auth Server using the // CA on disk. If no CA is found on disk, Teleport will not verify the Auth // Server it is connecting to. diff --git a/lib/auth/register_test.go b/lib/auth/register_test.go new file mode 100644 index 0000000000000..674d41eef7eca --- /dev/null +++ b/lib/auth/register_test.go @@ -0,0 +1,74 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "crypto/tls" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/fixtures" + "github.com/gravitational/teleport/lib/utils" +) + +func TestVerifyALPNUpgradedConn(t *testing.T) { + t.Parallel() + + auth := newTestTLSServer(t) + proxy, err := NewServerIdentity(auth.Auth(), "test-proxy", types.RoleProxy) + require.NoError(t, err) + + tests := []struct { + name string + serverCert []byte + clock clockwork.Clock + checkError require.ErrorAssertionFunc + }{ + { + name: "proxy verified", + serverCert: proxy.TLSCertBytes, + clock: auth.Clock(), + checkError: require.NoError, + }, + { + name: "proxy expired", + serverCert: proxy.TLSCertBytes, + clock: clockwork.NewFakeClockAt(auth.Clock().Now().Add(defaults.CATTL + time.Hour)), + checkError: require.Error, + }, + { + name: "not proxy", + serverCert: []byte(fixtures.TLSCACertPEM), + clock: auth.Clock(), + checkError: require.Error, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + serverCert, err := utils.ReadCertificates(test.serverCert) + require.NoError(t, err) + + test.checkError(t, verifyALPNUpgradedConn(test.clock)(tls.ConnectionState{ + PeerCertificates: serverCert, + })) + }) + } +} diff --git a/lib/service/service.go b/lib/service/service.go index 117847984e7fc..9b1358ff12040 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -1623,7 +1623,7 @@ func (process *TeleportProcess) initAuthService() error { FIPS: cfg.FIPS, LoadAllCAs: cfg.Auth.LoadAllCAs, Clock: cfg.Clock, - }, func(as *auth.Server) error { + }, append(cfg.Auth.ServerOptions, func(as *auth.Server) error { if !process.Config.CachePolicy.Enabled { return nil } @@ -1641,7 +1641,7 @@ func (process *TeleportProcess) initAuthService() error { as.Cache = cache return nil - }) + })...) if err != nil { return trace.Wrap(err) } diff --git a/lib/service/servicecfg/auth.go b/lib/service/servicecfg/auth.go index 039d1927d723b..e73e0cfc2d5af 100644 --- a/lib/service/servicecfg/auth.go +++ b/lib/service/servicecfg/auth.go @@ -19,6 +19,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/keystore" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/limiter" @@ -101,6 +102,9 @@ type AuthConfig struct { // Clock is the clock instance auth uses. Typically you'd only want to set // this during testing. Clock clockwork.Clock + + // ServerOptions is a list of auth.Init options used in test. + ServerOptions []auth.ServerOption } // HostedPluginsConfig configures the hosted plugin runtime.