From 094fa43d1f0bf660aff2600dd8c9663485c2ce3f Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Sun, 19 Mar 2023 16:33:27 -0400 Subject: [PATCH 01/27] ALPN connect test improvements --- api/defaults/defaults.go | 9 +++ api/profile/profile.go | 7 ++ lib/client/api.go | 52 +++++++++++---- lib/srv/alpnproxy/conn_upgrade.go | 74 ++++++++++++++++++--- lib/srv/alpnproxy/conn_upgrade_test.go | 59 +++++++++++++++- lib/srv/alpnproxy/local_proxy_config_opt.go | 18 +++++ tool/tsh/app.go | 3 +- tool/tsh/app_aws.go | 2 +- tool/tsh/app_azure.go | 2 +- tool/tsh/app_gcp.go | 2 +- tool/tsh/db.go | 6 +- tool/tsh/db_test.go | 23 +++++++ tool/tsh/proxy.go | 11 +-- 13 files changed, 232 insertions(+), 36 deletions(-) diff --git a/api/defaults/defaults.go b/api/defaults/defaults.go index 95f06e3602fd8..c19dd43f33501 100644 --- a/api/defaults/defaults.go +++ b/api/defaults/defaults.go @@ -145,4 +145,13 @@ const ( const ( // TunnelPublicAddrEnvar optionally specifies the alternative reverse tunnel address. TunnelPublicAddrEnvar = "TELEPORT_TUNNEL_PUBLIC_ADDR" + + // TLSRoutingConnUpgradeEnvVar overwrites the test result for deciding if + // ALPN connection upgrade is required. + // + // Sample values: + // true + // =yes,=no + // 0,=1 + TLSRoutingConnUpgradeEnvVar = "TELEPORT_TLS_ROUTING_CONN_UPGRADE" ) diff --git a/api/profile/profile.go b/api/profile/profile.go index 5746ac19bb61a..b06046e760d06 100644 --- a/api/profile/profile.go +++ b/api/profile/profile.go @@ -85,6 +85,13 @@ type Profile struct { // all proxy services are exposed on a single TLS listener (Proxy Web Listener). TLSRoutingEnabled bool `yaml:"tls_routing_enabled,omitempty"` + // TLSRoutingConnUpgradeRequired indicates that ALPN connection upgrades + // are required for making TLS routing requests. + // + // Note that this is applicable to the Proxy's Web port regardless of + // whether the Proxy is in single-port or multi-port configuration. + TLSRoutingConnUpgradeRequired bool `yaml:"tls_routing_conn_upgrade_required,omitempty"` + // AuthConnector (like "google", "passwordless"). // Equivalent to the --auth tsh flag. AuthConnector string `yaml:"auth_connector,omitempty"` diff --git a/lib/client/api.go b/lib/client/api.go index b45dd46e19811..7a07199ef3265 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -74,6 +74,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/shell" + "github.com/gravitational/teleport/lib/srv/alpnproxy" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/sshutils/scp" @@ -374,6 +375,13 @@ type Config struct { // all proxy services are exposed on a single TLS listener (Proxy Web Listener). TLSRoutingEnabled bool + // TLSRoutingConnUpgradeRequired indicates that ALPN connection upgrades + // are required for making TLS routing requests. + // + // Note that this is applicable to the Proxy's Web port regardless of + // whether the Proxy is in single-port or multi-port configuration. + TLSRoutingConnUpgradeRequired bool + // Reason is a reason attached to started sessions meant to describe their intent. Reason string @@ -594,6 +602,7 @@ func (c *Config) LoadProfile(ps ProfileStore, proxyAddr string) error { c.MySQLProxyAddr = profile.MySQLProxyAddr c.MongoProxyAddr = profile.MongoProxyAddr c.TLSRoutingEnabled = profile.TLSRoutingEnabled + c.TLSRoutingConnUpgradeRequired = profile.TLSRoutingConnUpgradeRequired c.KeysDir = profile.Dir c.AuthConnector = profile.AuthConnector c.LoadAllCAs = profile.LoadAllCAs @@ -613,6 +622,10 @@ func (c *Config) LoadProfile(ps ProfileStore, proxyAddr string) error { log.Warnf("Unable to parse dynamic port forwarding in user profile: %v.", err) } + if required, ok := alpnproxy.OverwriteALPNConnUpgradeRequiredByEnv(c.WebProxyAddr); ok { + c.TLSRoutingConnUpgradeRequired = required + } + log.Infof("ALPN connection upgrade required for %q: %v.", c.WebProxyAddr, c.TLSRoutingConnUpgradeRequired) return nil } @@ -624,20 +637,21 @@ func (c *Config) SaveProfile(makeCurrent bool) error { } p := &profile.Profile{ - Username: c.Username, - WebProxyAddr: c.WebProxyAddr, - SSHProxyAddr: c.SSHProxyAddr, - KubeProxyAddr: c.KubeProxyAddr, - PostgresProxyAddr: c.PostgresProxyAddr, - MySQLProxyAddr: c.MySQLProxyAddr, - MongoProxyAddr: c.MongoProxyAddr, - ForwardedPorts: c.LocalForwardPorts.String(), - SiteName: c.SiteName, - TLSRoutingEnabled: c.TLSRoutingEnabled, - AuthConnector: c.AuthConnector, - MFAMode: c.AuthenticatorAttachment.String(), - LoadAllCAs: c.LoadAllCAs, - PrivateKeyPolicy: c.PrivateKeyPolicy, + Username: c.Username, + WebProxyAddr: c.WebProxyAddr, + SSHProxyAddr: c.SSHProxyAddr, + KubeProxyAddr: c.KubeProxyAddr, + PostgresProxyAddr: c.PostgresProxyAddr, + MySQLProxyAddr: c.MySQLProxyAddr, + MongoProxyAddr: c.MongoProxyAddr, + ForwardedPorts: c.LocalForwardPorts.String(), + SiteName: c.SiteName, + TLSRoutingEnabled: c.TLSRoutingEnabled, + TLSRoutingConnUpgradeRequired: c.TLSRoutingConnUpgradeRequired, + AuthConnector: c.AuthConnector, + MFAMode: c.AuthenticatorAttachment.String(), + LoadAllCAs: c.LoadAllCAs, + PrivateKeyPolicy: c.PrivateKeyPolicy, } if err := c.ClientStore.SaveProfile(p, makeCurrent); err != nil { @@ -838,6 +852,13 @@ func (c *Config) DatabaseProxyHostPort(db tlsca.RouteToDatabase) (string, int) { return c.WebProxyHostPort() } +// DoesDatabseUserWebProxyHostPort returns true if database is using web port. +func (c *Config) DoesDatabseUserWebProxyHostPort(db tlsca.RouteToDatabase) bool { + dbHost, dbPort := c.DatabaseProxyHostPort(db) + webHost, webPort := c.WebProxyHostPort() + return dbHost == webHost && dbPort == webPort +} + // GetKubeTLSServerName returns k8s server name used in KUBECONFIG to leverage TLS Routing. func GetKubeTLSServerName(k8host string) string { isIPFormat := net.ParseIP(k8host) != nil @@ -3057,6 +3078,9 @@ func (tc *TeleportClient) Login(ctx context.Context) (*Key, error) { return nil, trace.Wrap(err) } + // Perform the ALPN test once at login. + tc.TLSRoutingConnUpgradeRequired = alpnproxy.IsALPNConnUpgradeRequired(tc.WebProxyAddr, tc.InsecureSkipVerify) + // Get the SSHLoginFunc that matches client and cluster settings. sshLoginFunc, err := tc.getSSHLoginFunc(pr) if err != nil { diff --git a/lib/srv/alpnproxy/conn_upgrade.go b/lib/srv/alpnproxy/conn_upgrade.go index e980ee50aba28..a3c8cb7718a74 100644 --- a/lib/srv/alpnproxy/conn_upgrade.go +++ b/lib/srv/alpnproxy/conn_upgrade.go @@ -20,9 +20,11 @@ import ( "bufio" "context" "crypto/tls" + "errors" "net" "net/http" "net/url" + "os" "strings" "github.com/gravitational/trace" @@ -31,6 +33,7 @@ import ( "github.com/gravitational/teleport" apiclient "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" ) @@ -46,6 +49,10 @@ import ( // Proxy Service to establish a tunnel for the originally planned traffic to // preserve the ALPN and SNI information. func IsALPNConnUpgradeRequired(addr string, insecure bool) bool { + if result, ok := OverwriteALPNConnUpgradeRequiredByEnv(addr); ok { + return result + } + netDialer := &net.Dialer{ Timeout: defaults.DefaultIOTimeout, } @@ -55,14 +62,14 @@ func IsALPNConnUpgradeRequired(addr string, insecure bool) bool { } testConn, err := tls.DialWithDialer(netDialer, "tcp", addr, tlsConfig) if err != nil { - // If dialing TLS fails for any reason, we assume connection upgrade is - // not required so it will fallback to original connection method. - // - // This includes handshake failures where both peers support ALPN but - // no common protocol is getting negotiated. We may have to revisit - // this situation or make it configurable if we have to get through a - // middleman with this behavior. For now, we are only interested in the - // case where the middleman does not support ALPN. + if isRemoteNoALPNError(err) { + logrus.Debugf("ALPN connection upgrade required for %q: %v. No ALPN protocol is negotiated by the server.", addr, true) + return true + } + + // If dialing TLS fails for any other reason, we assume connection + // upgrade is not required so it will fallback to original connection + // method. logrus.Infof("ALPN connection upgrade test failed for %q: %v.", addr, err) return false } @@ -75,6 +82,57 @@ func IsALPNConnUpgradeRequired(addr string, insecure bool) bool { return result } +func isRemoteNoALPNError(err error) bool { + if err = errors.Unwrap(err); err == nil { + return false + } + netOpError, ok := err.(*net.OpError) + if !ok { + return false + } + return netOpError.Op == "remote error" && strings.Contains(netOpError.Err.Error(), "tls: no application protocol") +} + +// OverwriteALPNConnUpgradeRequiredByEnv overwrites ALPN connection upgrade +// requirement by environment variable. +func OverwriteALPNConnUpgradeRequiredByEnv(addr string) (bool, bool) { + envValue := os.Getenv(defaults.TLSRoutingConnUpgradeEnvVar) + if envValue == "" { + return false, false + } + result := isALPNConnUpgradeRequiredByEnv(addr, envValue) + logrus.WithField(defaults.TLSRoutingConnUpgradeEnvVar, envValue).Debugf("ALPN connection upgrade required for %q: %v.", addr, result) + return result, true +} + +func isALPNConnUpgradeRequiredByEnv(addr, envValue string) bool { + tokens := strings.FieldsFunc(envValue, func(r rune) bool { + return r == ';' || r == ',' + }) + + var upgradeRequiredForAll bool + for _, token := range tokens { + switch { + case strings.ContainsRune(token, '='): + if _, boolText, ok := strings.Cut(token, addr+"="); ok { + upgradeRequiredForAddr, err := utils.ParseBool(boolText) + if err != nil { + logrus.Debugf("Failed to parse %v: %v", envValue, err) + } + return upgradeRequiredForAddr + } + + default: + if boolValue, err := utils.ParseBool(token); err != nil { + logrus.Debugf("Failed to parse %v: %v", envValue, err) + } else { + upgradeRequiredForAll = boolValue + } + } + } + return upgradeRequiredForAll +} + // alpnConnUpgradeDialer makes an "HTTP" upgrade call to the Proxy Service then // tunnels the connection with this connection upgrade. type alpnConnUpgradeDialer struct { diff --git a/lib/srv/alpnproxy/conn_upgrade_test.go b/lib/srv/alpnproxy/conn_upgrade_test.go index 2326f72dd7ea0..c01d01eb4d131 100644 --- a/lib/srv/alpnproxy/conn_upgrade_test.go +++ b/lib/srv/alpnproxy/conn_upgrade_test.go @@ -44,21 +44,31 @@ func TestIsALPNConnUpgradeRequired(t *testing.T) { tests := []struct { name string serverProtos []string + insecure bool expectedResult bool }{ { - name: "upgrade required", + name: "upgrade required (handshake success)", serverProtos: nil, // Use nil for NextProtos to simulate no ALPN support. + insecure: true, expectedResult: true, }, { name: "upgrade not required (proto negotiated)", serverProtos: []string{string(common.ProtocolReverseTunnel)}, + insecure: true, expectedResult: false, }, { - name: "upgrade not required (handshake error)", + name: "upgrade required (handshake with no ALPN error)", serverProtos: []string{"unknown"}, + insecure: true, + expectedResult: true, + }, + { + name: "upgrade not required (other handshake error)", + serverProtos: []string{string(common.ProtocolReverseTunnel)}, + insecure: false, // to cause handshake error expectedResult: false, }, } @@ -66,7 +76,50 @@ func TestIsALPNConnUpgradeRequired(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { server := mustStartMockALPNServer(t, test.serverProtos) - require.Equal(t, test.expectedResult, IsALPNConnUpgradeRequired(server.Addr().String(), true)) + require.Equal(t, test.expectedResult, IsALPNConnUpgradeRequired(server.Addr().String(), test.insecure)) + }) + } +} + +func TestIsALPNConnUpgradeRequiredByEnv(t *testing.T) { + t.Parallel() + + addr := "example.teleport.com:443" + tests := []struct { + name string + envValue string + require require.BoolAssertionFunc + }{ + { + name: "upgraded required (for all addr)", + envValue: "yes", + require: require.True, + }, + { + name: "upgraded required (for target addr)", + envValue: "0;example.teleport.com:443=1", + require: require.True, + }, + { + name: "upgraded not required (for all addr)", + envValue: "false", + require: require.False, + }, + { + name: "upgraded not required (no addr match)", + envValue: "another.teleport.com:443=true", + require: require.False, + }, + { + name: "upgraded not required (for target addr)", + envValue: "another.teleport.com:443=true,example.teleport.com:443=false", + require: require.False, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.require(t, isALPNConnUpgradeRequiredByEnv(addr, test.envValue)) }) } } diff --git a/lib/srv/alpnproxy/local_proxy_config_opt.go b/lib/srv/alpnproxy/local_proxy_config_opt.go index a426870b344b0..9d0d8b1a79e2c 100644 --- a/lib/srv/alpnproxy/local_proxy_config_opt.go +++ b/lib/srv/alpnproxy/local_proxy_config_opt.go @@ -42,12 +42,30 @@ type GetClusterCACertPoolFunc func(ctx context.Context) (*x509.CertPool, error) func WithALPNConnUpgradeTest(ctx context.Context, getClusterCertPool GetClusterCACertPoolFunc) LocalProxyConfigOpt { return func(config *LocalProxyConfig) error { config.ALPNConnUpgradeRequired = IsALPNConnUpgradeRequired(config.RemoteProxyAddr, config.InsecureSkipVerify) + return trace.Wrap(WithClusterCAsIfConnUpgrade(ctx, getClusterCertPool)(config)) + } +} + +// WithClusterCAsIfConnUpgrade is a LocalProxyConfigOpt that fetches the +// cluster CAs when ALPN connection upgrades are required. +func WithClusterCAsIfConnUpgrade(ctx context.Context, getClusterCertPool GetClusterCACertPoolFunc) LocalProxyConfigOpt { + return func(config *LocalProxyConfig) error { if !config.ALPNConnUpgradeRequired { return nil } // If ALPN connection upgrade is required, explicitly use the cluster // CAs since the tunneled TLS routing connection serves the Host cert. + if config.RootCAs == nil && getClusterCertPool != nil { + return trace.Wrap(WithClusterCAs(ctx, getClusterCertPool)(config)) + } + return nil + } +} + +// WithClusterCAs is a LocalProxyConfigOpt that fetches the cluster CAs. +func WithClusterCAs(ctx context.Context, getClusterCertPool GetClusterCACertPoolFunc) LocalProxyConfigOpt { + return func(config *LocalProxyConfig) error { clusterCAs, err := getClusterCertPool(ctx) if err != nil { return trace.Wrap(err) diff --git a/tool/tsh/app.go b/tool/tsh/app.go index 4f118913ecd92..da7eeb0601127 100644 --- a/tool/tsh/app.go +++ b/tool/tsh/app.go @@ -37,7 +37,6 @@ import ( "github.com/gravitational/teleport/lib/asciitable" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/srv/alpnproxy" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -194,7 +193,7 @@ func onAppLogin(cf *CLIConf) error { } func localProxyRequiredForApp(tc *client.TeleportClient) bool { - return alpnproxy.IsALPNConnUpgradeRequired(tc.WebProxyAddr, tc.InsecureSkipVerify) + return tc.TLSRoutingConnUpgradeRequired } // appLoginTpl is the message that gets printed to a user upon successful login diff --git a/tool/tsh/app_aws.go b/tool/tsh/app_aws.go index b36da4f476b8f..81b96d6127c1a 100644 --- a/tool/tsh/app_aws.go +++ b/tool/tsh/app_aws.go @@ -277,7 +277,7 @@ func (a *awsApp) startLocalALPNProxy(port string) error { a.localALPNProxy, err = alpnproxy.NewLocalProxy( makeBasicLocalProxyConfig(a.cf, tc, listener), alpnproxy.WithClientCerts(appCerts), - alpnproxy.WithALPNConnUpgradeTest(a.cf.Context, tc.RootClusterCACertPool), + alpnproxy.WithClusterCAsIfConnUpgrade(a.cf.Context, tc.RootClusterCACertPool), alpnproxy.WithHTTPMiddleware(&alpnproxy.AWSAccessMiddleware{ AWSCredentials: cred, }), diff --git a/tool/tsh/app_azure.go b/tool/tsh/app_azure.go index 08525817b5ecd..34b2e7cce57d0 100644 --- a/tool/tsh/app_azure.go +++ b/tool/tsh/app_azure.go @@ -241,7 +241,7 @@ func (a *azureApp) startLocalALPNProxy(port string) error { a.localALPNProxy, err = alpnproxy.NewLocalProxy( makeBasicLocalProxyConfig(a.cf, tc, listener), alpnproxy.WithClientCerts(appCerts), - alpnproxy.WithALPNConnUpgradeTest(a.cf.Context, tc.RootClusterCACertPool), + alpnproxy.WithClusterCAsIfConnUpgrade(a.cf.Context, tc.RootClusterCACertPool), alpnproxy.WithHTTPMiddleware(&alpnproxy.AzureMSIMiddleware{ Key: wsPK, Secret: a.msiSecret, diff --git a/tool/tsh/app_gcp.go b/tool/tsh/app_gcp.go index 7da274e8ce700..f716da692f98c 100644 --- a/tool/tsh/app_gcp.go +++ b/tool/tsh/app_gcp.go @@ -325,7 +325,7 @@ func (a *gcpApp) startLocalALPNProxy(port string) error { a.localALPNProxy, err = alpnproxy.NewLocalProxy( makeBasicLocalProxyConfig(a.cf, tc, listener), alpnproxy.WithClientCerts(appCerts), - alpnproxy.WithALPNConnUpgradeTest(a.cf.Context, tc.RootClusterCACertPool), + alpnproxy.WithClusterCAsIfConnUpgrade(a.cf.Context, tc.RootClusterCACertPool), alpnproxy.WithHTTPMiddleware(&alpnproxy.AuthorizationCheckerMiddleware{ Secret: a.secret, }), diff --git a/tool/tsh/db.go b/tool/tsh/db.go index 986e6de54e846..a790c98885f14 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -661,7 +661,7 @@ func prepareLocalProxyOptions(arg *localProxyConfig) ([]alpnproxy.LocalProxyConf opts := []alpnproxy.LocalProxyConfigOpt{ alpnproxy.WithDatabaseProtocol(arg.route.Protocol), - alpnproxy.WithALPNConnUpgradeTest(arg.cf.Context, arg.tc.RootClusterCACertPool), + alpnproxy.WithClusterCAsIfConnUpgrade(arg.cf.Context, arg.tc.RootClusterCACertPool), } if !arg.tunnel && arg.route.Protocol == defaults.ProtocolPostgres { @@ -1110,6 +1110,10 @@ func getDBLocalProxyRequirement(tc *client.TeleportClient, route *tlsca.RouteToD out.addLocalProxyWithTunnel(formatKeyPolicyReason(tc.PrivateKeyPolicy)) } + if tc.TLSRoutingConnUpgradeRequired && tc.DoesDatabseUserWebProxyHostPort(*route) { + out.addLocalProxy("Teleport Proxy is behind a load balancer.") + } + switch route.Protocol { case defaults.ProtocolSnowflake, defaults.ProtocolDynamoDB, diff --git a/tool/tsh/db_test.go b/tool/tsh/db_test.go index 22765efdb3afe..5175190e5a0f8 100644 --- a/tool/tsh/db_test.go +++ b/tool/tsh/db_test.go @@ -255,6 +255,7 @@ func TestLocalProxyRequirement(t *testing.T) { tests := map[string]struct { clusterAuthPref types.AuthPreference route *tlsca.RouteToDatabase + fakeSetup func(*client.TeleportClient) wantLocalProxy bool wantTunnel bool }{ @@ -277,6 +278,25 @@ func TestLocalProxyRequirement(t *testing.T) { wantLocalProxy: true, wantTunnel: true, }, + "local proxy not required for separate port": { + clusterAuthPref: defaultAuthPref, + fakeSetup: func(tc *client.TeleportClient) { + tc.TLSRoutingEnabled = false + tc.TLSRoutingConnUpgradeRequired = true + tc.PostgresProxyAddr = "separate.postgres.hostport:8888" + }, + wantLocalProxy: false, + wantTunnel: false, + }, + "local proxy required if behind lb": { + clusterAuthPref: defaultAuthPref, + fakeSetup: func(tc *client.TeleportClient) { + tc.TLSRoutingEnabled = true + tc.TLSRoutingConnUpgradeRequired = true + }, + wantLocalProxy: true, + wantTunnel: false, + }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { @@ -291,6 +311,9 @@ func TestLocalProxyRequirement(t *testing.T) { } tc, err := makeClient(cf, false) require.NoError(t, err) + if tt.fakeSetup != nil { + tt.fakeSetup(tc) + } route := &tlsca.RouteToDatabase{ ServiceName: "foo-db", Protocol: "postgres", diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index 0960bdd32d85f..0e61bb9225298 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -574,7 +574,7 @@ func onProxyCommandApp(cf *CLIConf) error { makeBasicLocalProxyConfig(cf, tc, listener), alpnproxy.WithALPNProtocol(alpnProtocolForApp(app)), alpnproxy.WithClientCerts(appCerts), - alpnproxy.WithALPNConnUpgradeTest(cf.Context, tc.RootClusterCACertPool), + alpnproxy.WithClusterCAsIfConnUpgrade(cf.Context, tc.RootClusterCACertPool), ) if err != nil { if cerr := listener.Close(); cerr != nil { @@ -813,10 +813,11 @@ func isLocalProxyTunnelRequested(cf *CLIConf) bool { func makeBasicLocalProxyConfig(cf *CLIConf, tc *libclient.TeleportClient, listener net.Listener) alpnproxy.LocalProxyConfig { return alpnproxy.LocalProxyConfig{ - RemoteProxyAddr: tc.WebProxyAddr, - InsecureSkipVerify: cf.InsecureSkipVerify, - ParentContext: cf.Context, - Listener: listener, + RemoteProxyAddr: tc.WebProxyAddr, + InsecureSkipVerify: cf.InsecureSkipVerify, + ParentContext: cf.Context, + Listener: listener, + ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired, } } From 7e848cf249f9abc27c8ba96f8426de1f4ab5af0f Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Tue, 21 Mar 2023 14:29:55 -0400 Subject: [PATCH 02/27] fix typos --- lib/client/api.go | 6 +++--- lib/srv/alpnproxy/conn_upgrade.go | 6 +++--- lib/srv/alpnproxy/local_proxy_config_opt.go | 5 +---- tool/tsh/db.go | 2 +- tool/tsh/db_test.go | 10 +++++----- 5 files changed, 13 insertions(+), 16 deletions(-) diff --git a/lib/client/api.go b/lib/client/api.go index 7a07199ef3265..ea9c4912d6773 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -622,7 +622,7 @@ func (c *Config) LoadProfile(ps ProfileStore, proxyAddr string) error { log.Warnf("Unable to parse dynamic port forwarding in user profile: %v.", err) } - if required, ok := alpnproxy.OverwriteALPNConnUpgradeRequiredByEnv(c.WebProxyAddr); ok { + if required, ok := alpnproxy.OverwriteALPNConnUpgradeRequirementByEnv(c.WebProxyAddr); ok { c.TLSRoutingConnUpgradeRequired = required } log.Infof("ALPN connection upgrade required for %q: %v.", c.WebProxyAddr, c.TLSRoutingConnUpgradeRequired) @@ -852,8 +852,8 @@ func (c *Config) DatabaseProxyHostPort(db tlsca.RouteToDatabase) (string, int) { return c.WebProxyHostPort() } -// DoesDatabseUserWebProxyHostPort returns true if database is using web port. -func (c *Config) DoesDatabseUserWebProxyHostPort(db tlsca.RouteToDatabase) bool { +// DoesDatabseUseWebProxyHostPort returns true if database is using web port. +func (c *Config) DoesDatabseUseWebProxyHostPort(db tlsca.RouteToDatabase) bool { dbHost, dbPort := c.DatabaseProxyHostPort(db) webHost, webPort := c.WebProxyHostPort() return dbHost == webHost && dbPort == webPort diff --git a/lib/srv/alpnproxy/conn_upgrade.go b/lib/srv/alpnproxy/conn_upgrade.go index a3c8cb7718a74..84aad2d1e0857 100644 --- a/lib/srv/alpnproxy/conn_upgrade.go +++ b/lib/srv/alpnproxy/conn_upgrade.go @@ -49,7 +49,7 @@ import ( // Proxy Service to establish a tunnel for the originally planned traffic to // preserve the ALPN and SNI information. func IsALPNConnUpgradeRequired(addr string, insecure bool) bool { - if result, ok := OverwriteALPNConnUpgradeRequiredByEnv(addr); ok { + if result, ok := OverwriteALPNConnUpgradeRequirementByEnv(addr); ok { return result } @@ -93,9 +93,9 @@ func isRemoteNoALPNError(err error) bool { return netOpError.Op == "remote error" && strings.Contains(netOpError.Err.Error(), "tls: no application protocol") } -// OverwriteALPNConnUpgradeRequiredByEnv overwrites ALPN connection upgrade +// OverwriteALPNConnUpgradeRequirementByEnv overwrites ALPN connection upgrade // requirement by environment variable. -func OverwriteALPNConnUpgradeRequiredByEnv(addr string) (bool, bool) { +func OverwriteALPNConnUpgradeRequirementByEnv(addr string) (bool, bool) { envValue := os.Getenv(defaults.TLSRoutingConnUpgradeEnvVar) if envValue == "" { return false, false diff --git a/lib/srv/alpnproxy/local_proxy_config_opt.go b/lib/srv/alpnproxy/local_proxy_config_opt.go index 9d0d8b1a79e2c..c12837a769eba 100644 --- a/lib/srv/alpnproxy/local_proxy_config_opt.go +++ b/lib/srv/alpnproxy/local_proxy_config_opt.go @@ -56,10 +56,7 @@ func WithClusterCAsIfConnUpgrade(ctx context.Context, getClusterCertPool GetClus // If ALPN connection upgrade is required, explicitly use the cluster // CAs since the tunneled TLS routing connection serves the Host cert. - if config.RootCAs == nil && getClusterCertPool != nil { - return trace.Wrap(WithClusterCAs(ctx, getClusterCertPool)(config)) - } - return nil + return trace.Wrap(WithClusterCAs(ctx, getClusterCertPool)(config)) } } diff --git a/tool/tsh/db.go b/tool/tsh/db.go index a790c98885f14..d307c586dad2e 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -1110,7 +1110,7 @@ func getDBLocalProxyRequirement(tc *client.TeleportClient, route *tlsca.RouteToD out.addLocalProxyWithTunnel(formatKeyPolicyReason(tc.PrivateKeyPolicy)) } - if tc.TLSRoutingConnUpgradeRequired && tc.DoesDatabseUserWebProxyHostPort(*route) { + if tc.TLSRoutingConnUpgradeRequired && tc.DoesDatabseUseWebProxyHostPort(*route) { out.addLocalProxy("Teleport Proxy is behind a load balancer.") } diff --git a/tool/tsh/db_test.go b/tool/tsh/db_test.go index 5175190e5a0f8..55331ee623728 100644 --- a/tool/tsh/db_test.go +++ b/tool/tsh/db_test.go @@ -255,7 +255,7 @@ func TestLocalProxyRequirement(t *testing.T) { tests := map[string]struct { clusterAuthPref types.AuthPreference route *tlsca.RouteToDatabase - fakeSetup func(*client.TeleportClient) + setupTC func(*client.TeleportClient) wantLocalProxy bool wantTunnel bool }{ @@ -280,7 +280,7 @@ func TestLocalProxyRequirement(t *testing.T) { }, "local proxy not required for separate port": { clusterAuthPref: defaultAuthPref, - fakeSetup: func(tc *client.TeleportClient) { + setupTC: func(tc *client.TeleportClient) { tc.TLSRoutingEnabled = false tc.TLSRoutingConnUpgradeRequired = true tc.PostgresProxyAddr = "separate.postgres.hostport:8888" @@ -290,7 +290,7 @@ func TestLocalProxyRequirement(t *testing.T) { }, "local proxy required if behind lb": { clusterAuthPref: defaultAuthPref, - fakeSetup: func(tc *client.TeleportClient) { + setupTC: func(tc *client.TeleportClient) { tc.TLSRoutingEnabled = true tc.TLSRoutingConnUpgradeRequired = true }, @@ -311,8 +311,8 @@ func TestLocalProxyRequirement(t *testing.T) { } tc, err := makeClient(cf, false) require.NoError(t, err) - if tt.fakeSetup != nil { - tt.fakeSetup(tc) + if tt.setupTC != nil { + tt.setupTC(tc) } route := &tlsca.RouteToDatabase{ ServiceName: "foo-db", From 2c9b116128c5b896ea76332c5a0b3be6f9b53cb3 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Tue, 21 Mar 2023 14:48:19 -0400 Subject: [PATCH 03/27] remove extra period --- tool/tsh/db.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tool/tsh/db.go b/tool/tsh/db.go index d307c586dad2e..a3799075bc542 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -1111,7 +1111,7 @@ func getDBLocalProxyRequirement(tc *client.TeleportClient, route *tlsca.RouteToD } if tc.TLSRoutingConnUpgradeRequired && tc.DoesDatabseUseWebProxyHostPort(*route) { - out.addLocalProxy("Teleport Proxy is behind a load balancer.") + out.addLocalProxy("Teleport Proxy is behind a load balancer") } switch route.Protocol { From 54019f384223102fd1680db419e9cb8d6015242a Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Tue, 28 Mar 2023 13:56:32 -0400 Subject: [PATCH 04/27] simplify error check --- lib/srv/alpnproxy/conn_upgrade.go | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/lib/srv/alpnproxy/conn_upgrade.go b/lib/srv/alpnproxy/conn_upgrade.go index 84aad2d1e0857..47de915374063 100644 --- a/lib/srv/alpnproxy/conn_upgrade.go +++ b/lib/srv/alpnproxy/conn_upgrade.go @@ -83,14 +83,8 @@ func IsALPNConnUpgradeRequired(addr string, insecure bool) bool { } func isRemoteNoALPNError(err error) bool { - if err = errors.Unwrap(err); err == nil { - return false - } - netOpError, ok := err.(*net.OpError) - if !ok { - return false - } - return netOpError.Op == "remote error" && strings.Contains(netOpError.Err.Error(), "tls: no application protocol") + var opErr *net.OpError + return errors.As(err, &opErr) && opErr.Op == "remote error" && strings.Contains(opErr.Err.Error(), "tls: no application protocol") } // OverwriteALPNConnUpgradeRequirementByEnv overwrites ALPN connection upgrade From d98dd80d2e6d8e190ca4ef573a574875297b3b6f Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Tue, 28 Mar 2023 10:09:13 -0400 Subject: [PATCH 05/27] moving things over --- .../client}/alpnproxy/conn_upgrade.go | 17 +++-- .../client}/alpnproxy/conn_upgrade_test.go | 28 +++----- {lib/srv => api/client}/alpnproxy/dialer.go | 12 +--- api/constants/constants.go | 12 ++++ api/fixtures/fixtures.go | 66 +++++++++++++++++++ constants.go | 12 ---- lib/client/api.go | 2 +- lib/srv/alpnproxy/local_proxy.go | 9 +-- lib/srv/alpnproxy/local_proxy_config_opt.go | 3 +- lib/web/conn_upgrade.go | 8 +-- 10 files changed, 111 insertions(+), 58 deletions(-) rename {lib/srv => api/client}/alpnproxy/conn_upgrade.go (92%) rename {lib/srv => api/client}/alpnproxy/conn_upgrade_test.go (85%) rename {lib/srv => api/client}/alpnproxy/dialer.go (84%) create mode 100644 api/fixtures/fixtures.go diff --git a/lib/srv/alpnproxy/conn_upgrade.go b/api/client/alpnproxy/conn_upgrade.go similarity index 92% rename from lib/srv/alpnproxy/conn_upgrade.go rename to api/client/alpnproxy/conn_upgrade.go index 47de915374063..739123e2bdc67 100644 --- a/lib/srv/alpnproxy/conn_upgrade.go +++ b/api/client/alpnproxy/conn_upgrade.go @@ -30,11 +30,10 @@ import ( "github.com/gravitational/trace" "github.com/sirupsen/logrus" - "github.com/gravitational/teleport" - apiclient "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/utils" - "github.com/gravitational/teleport/lib/srv/alpnproxy/common" ) // IsALPNConnUpgradeRequired returns true if a tunnel is required through a HTTP @@ -57,7 +56,7 @@ func IsALPNConnUpgradeRequired(addr string, insecure bool) bool { Timeout: defaults.DefaultIOTimeout, } tlsConfig := &tls.Config{ - NextProtos: []string{string(common.ProtocolReverseTunnel)}, + NextProtos: []string{string(constants.ALPNSNIProtocolReverseTunnel)}, InsecureSkipVerify: insecure, } testConn, err := tls.DialWithDialer(netDialer, "tcp", addr, tlsConfig) @@ -130,12 +129,12 @@ func isALPNConnUpgradeRequiredByEnv(addr, envValue string) bool { // alpnConnUpgradeDialer makes an "HTTP" upgrade call to the Proxy Service then // tunnels the connection with this connection upgrade. type alpnConnUpgradeDialer struct { - dialer apiclient.ContextDialer + dialer client.ContextDialer tlsConfig *tls.Config } // newALPNConnUpgradeDialer creates a new alpnConnUpgradeDialer. -func newALPNConnUpgradeDialer(dialer apiclient.ContextDialer, tlsConfig *tls.Config) ContextDialer { +func newALPNConnUpgradeDialer(dialer client.ContextDialer, tlsConfig *tls.Config) client.ContextDialer { return &alpnConnUpgradeDialer{ dialer: dialer, tlsConfig: tlsConfig, @@ -172,7 +171,7 @@ func (d alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr st err = upgradeConnThroughWebAPI(tlsConn, url.URL{ Host: addr, Scheme: "https", - Path: teleport.WebAPIConnUpgrade, + Path: constants.WebAPIConnUpgrade, }) if err != nil { defer tlsConn.Close() @@ -188,7 +187,7 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL) error { } // For now, only "alpn" is supported. - req.Header.Add(teleport.WebAPIConnUpgradeHeader, teleport.WebAPIConnUpgradeTypeALPN) + req.Header.Add(constants.WebAPIConnUpgradeHeader, constants.WebAPIConnUpgradeTypeALPN) // Send the request and check if upgrade is successful. if err = req.Write(conn); err != nil { @@ -204,7 +203,7 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL) error { if http.StatusNotFound == resp.StatusCode { return trace.NotImplemented( "connection upgrade call to %q failed with status code %v. Please upgrade the server and try again.", - teleport.WebAPIConnUpgrade, + constants.WebAPIConnUpgrade, resp.StatusCode, ) } diff --git a/lib/srv/alpnproxy/conn_upgrade_test.go b/api/client/alpnproxy/conn_upgrade_test.go similarity index 85% rename from lib/srv/alpnproxy/conn_upgrade_test.go rename to api/client/alpnproxy/conn_upgrade_test.go index c01d01eb4d131..d2f76b984a1c1 100644 --- a/lib/srv/alpnproxy/conn_upgrade_test.go +++ b/api/client/alpnproxy/conn_upgrade_test.go @@ -20,7 +20,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "crypto/x509/pkix" "errors" "net" "net/http" @@ -31,11 +30,9 @@ import ( "github.com/stretchr/testify/require" - "github.com/gravitational/teleport" - apiclient "github.com/gravitational/teleport/api/client" - "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/srv/alpnproxy/common" - "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/fixtures" ) func TestIsALPNConnUpgradeRequired(t *testing.T) { @@ -55,7 +52,7 @@ func TestIsALPNConnUpgradeRequired(t *testing.T) { }, { name: "upgrade not required (proto negotiated)", - serverProtos: []string{string(common.ProtocolReverseTunnel)}, + serverProtos: []string{string(constants.ALPNSNIProtocolReverseTunnel)}, insecure: true, expectedResult: false, }, @@ -67,7 +64,7 @@ func TestIsALPNConnUpgradeRequired(t *testing.T) { }, { name: "upgrade not required (other handshake error)", - serverProtos: []string{string(common.ProtocolReverseTunnel)}, + serverProtos: []string{string(constants.ALPNSNIProtocolReverseTunnel)}, insecure: false, // to cause handshake error expectedResult: false, }, @@ -138,7 +135,7 @@ func TestALPNConnUpgradeDialer(t *testing.T) { pool.AddCert(server.Certificate()) tlsConfig := &tls.Config{RootCAs: pool} - preDialer := apiclient.NewDialer(ctx, 0, 5*time.Second, apiclient.WithTLSConfig(tlsConfig)) + preDialer := client.NewDialer(ctx, 0, 5*time.Second, client.WithTLSConfig(tlsConfig)) dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig) conn, err := dialer.DialContext(ctx, "tcp", addr.Host) require.NoError(t, err) @@ -158,7 +155,7 @@ func TestALPNConnUpgradeDialer(t *testing.T) { require.NoError(t, err) tlsConfig := &tls.Config{InsecureSkipVerify: true} - preDialer := apiclient.NewDialer(ctx, 0, 5*time.Second, apiclient.WithTLSConfig(tlsConfig)) + preDialer := client.NewDialer(ctx, 0, 5*time.Second, client.WithTLSConfig(tlsConfig)) dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig) _, err = dialer.DialContext(ctx, "tcp", addr.Host) require.Error(t, err) @@ -207,12 +204,7 @@ func mustStartMockALPNServer(t *testing.T, supportedProtos []string) *mockALPNSe listener.Close() }) - caKey, caCert, err := tlsca.GenerateSelfSignedCA(pkix.Name{ - CommonName: "localhost", - }, []string{"localhost"}, defaults.CATTL) - require.NoError(t, err) - - cert, err := tls.X509KeyPair(caCert, caKey) + cert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM)) require.NoError(t, err) m := &mockALPNServer{ @@ -228,8 +220,8 @@ func mustStartMockALPNServer(t *testing.T, supportedProtos []string) *mockALPNSe // upgrade request and sends back some data inside the tunnel. func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, teleport.WebAPIConnUpgrade, r.URL.Path) - require.Equal(t, upgradeType, r.Header.Get(teleport.WebAPIConnUpgradeHeader)) + require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path) + require.Equal(t, upgradeType, r.Header.Get(constants.WebAPIConnUpgradeHeader)) hj, ok := w.(http.Hijacker) require.True(t, ok) diff --git a/lib/srv/alpnproxy/dialer.go b/api/client/alpnproxy/dialer.go similarity index 84% rename from lib/srv/alpnproxy/dialer.go rename to api/client/alpnproxy/dialer.go index f4b3e183bb97a..c73e22cae8590 100644 --- a/lib/srv/alpnproxy/dialer.go +++ b/api/client/alpnproxy/dialer.go @@ -24,15 +24,9 @@ import ( "github.com/gravitational/trace" - apiclient "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client" ) -// ContextDialer represents network dialer interface that uses context -type ContextDialer interface { - // DialContext is a function that dials the specified address - DialContext(ctx context.Context, network, addr string) (net.Conn, error) -} - // ALPNDialerConfig is the config for ALPNDialer. type ALPNDialerConfig struct { // KeepAlivePeriod defines period between keep alives. @@ -54,7 +48,7 @@ type ALPNDialer struct { } // NewALPNDialer creates a new ALPNDialer. -func NewALPNDialer(cfg ALPNDialerConfig) ContextDialer { +func NewALPNDialer(cfg ALPNDialerConfig) client.ContextDialer { return &ALPNDialer{ cfg: cfg, } @@ -66,7 +60,7 @@ func (d ALPNDialer) DialContext(ctx context.Context, network, addr string) (net. return nil, trace.BadParameter("missing TLS config") } - dialer := apiclient.NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout, apiclient.WithTLSConfig(d.cfg.TLSConfig)) + dialer := client.NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout, client.WithTLSConfig(d.cfg.TLSConfig)) if d.cfg.ALPNConnUpgradeRequired { dialer = newALPNConnUpgradeDialer(dialer, &tls.Config{ InsecureSkipVerify: d.cfg.TLSConfig.InsecureSkipVerify, diff --git a/api/constants/constants.go b/api/constants/constants.go index 50c409cd6aca1..47b3b6c513c29 100644 --- a/api/constants/constants.go +++ b/api/constants/constants.go @@ -402,3 +402,15 @@ const ( // TimeoutGetClusterAlerts is the timeout for grabbing cluster alerts from tctl and tsh TimeoutGetClusterAlerts = time.Millisecond * 500 ) + +const ( + // WebAPIConnUpgrade is the HTTP web API to make the connection upgrade + // call. + WebAPIConnUpgrade = "/webapi/connectionupgrade" + // WebAPIConnUpgradeHeader is the header used to indicate the requested + // connection upgrade types in the connection upgrade API. + WebAPIConnUpgradeHeader = "Upgrade" + // WebAPIConnUpgradeTypeALPN is a connection upgrade type that specifies + // the upgraded connection should be handled by the ALPN handler. + WebAPIConnUpgradeTypeALPN = "alpn" +) diff --git a/api/fixtures/fixtures.go b/api/fixtures/fixtures.go new file mode 100644 index 0000000000000..51cfe7fd30018 --- /dev/null +++ b/api/fixtures/fixtures.go @@ -0,0 +1,66 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fixtures + +const ( + TLSCACertPEM = `-----BEGIN CERTIFICATE----- +MIIDKjCCAhKgAwIBAgIQJtJDJZZBkg/afM8d2ZJCTjANBgkqhkiG9w0BAQsFADBA +MRUwEwYDVQQKEwxUZWxlcG9ydCBPU1MxJzAlBgNVBAMTHnRlbGVwb3J0LmxvY2Fs +aG9zdC5sb2NhbGRvbWFpbjAeFw0xNzA1MDkxOTQwMzZaFw0yNzA1MDcxOTQwMzZa +MEAxFTATBgNVBAoTDFRlbGVwb3J0IE9TUzEnMCUGA1UEAxMedGVsZXBvcnQubG9j +YWxob3N0LmxvY2FsZG9tYWluMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAuKFLaf2iII/xDR+m2Yj6PnUEa+qzqwxsdLUjnunFZaAXG+hZm4Ml80SCiBgI +gTHQlJyLIkTtuRoH5aeMyz1ERUCtii4ZsTqDrjjUybxP4r+4HVX6m34s6hwEr8Fi +fts9pMp4iS3tQguRc28gPdDo/T6VrJTVYUfUUsNDRtIrlB5O9igqqLnuaY9eqGi4 +PUx0G0wRYJpRywoj8G0IkpfQTiX+CAC7dt5ws7ZrnGqCNBLGi5bGsaMmptVbsSEp +1TenntF54V1iR49IV5JqDhm1S0HmkleoJzKdc+6sP/xNepz9PJzuF9d9NubTLWgB +sK28YItcmWHdHXD/ODxVaehRjwIDAQABoyAwHjAOBgNVHQ8BAf8EBAMCB4AwDAYD +VR0TAQH/BAIwADANBgkqhkiG9w0BAQsFAAOCAQEAAVU6sNBdj76saHwOxGSdnEqQ +o2tMuR3msSM4F6wFK2UkKepsD7CYIf/PzNSNUqA5JIEUVeMqGyiHuAbU4C655nT1 +IyJX1D/+r73sSp5jbIpQm2xoQGZnj6g/Kltw8OSOAw+DsMF/PLVqoWJp07u6ew/m +NxWsJKcZ5k+q4eMxci9mKRHHqsquWKXzQlURMNFI+mGaFwrKM4dmzaR0BEc+ilSx +QqUvQ74smsLK+zhNikmgjlGC5ob9g8XkhVAkJMAh2rb9onDNiRl68iAgczP88mXu +vN/o98dypzsPxXmw6tkDqIRPUAUbh465rlY5sKMmRgXi2rUfl/QV5nbozUo/HQ== +-----END CERTIFICATE-----` + TLSCAKeyPEM = `-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAuKFLaf2iII/xDR+m2Yj6PnUEa+qzqwxsdLUjnunFZaAXG+hZ +m4Ml80SCiBgIgTHQlJyLIkTtuRoH5aeMyz1ERUCtii4ZsTqDrjjUybxP4r+4HVX6 +m34s6hwEr8Fifts9pMp4iS3tQguRc28gPdDo/T6VrJTVYUfUUsNDRtIrlB5O9igq +qLnuaY9eqGi4PUx0G0wRYJpRywoj8G0IkpfQTiX+CAC7dt5ws7ZrnGqCNBLGi5bG +saMmptVbsSEp1TenntF54V1iR49IV5JqDhm1S0HmkleoJzKdc+6sP/xNepz9PJzu +F9d9NubTLWgBsK28YItcmWHdHXD/ODxVaehRjwIDAQABAoIBABy4orWrShRMsA/9 +k4QVpfAfXf+3tBlwxlJld1QaQ6XqgI3L2FyzyyyLxM6NBo2qhSsJKy+6j0yTOxVD +ukhHkJ5BUH3FbCPA2Yk5uAhl7ft1HZwaqvCTcUM99pCswbjAPFetU5DrfxQeHpNZ +fyd+ny/+E2SUhpkqhmIVlBqpSTQyOywbiEvZ6ZiFmncdHhXaCy3YZsylrKUGPzsJ +jfU2iOE167eTOIjPStsaoCPv9jLSyy2OvuNNudS+Y1qkFz8ZGvPp+HB+Iig+AlAE +7KMzNrIW7PlHTDgUly1cRCl3+84yE2mJ97+hHiEy//HIwVDUpI529i2hMYM/u4qz +Wso/2tkCgYEA2FdE4bmCrZiA9eS8qobwGLE1+MJME4YwfJkynZUHHX93xORPQ66e +WYpN7/xbMvBDa8LZZYVTNVtZ/SkEUaTb5NQW2zXKoIutk1PFBb8NbA0m8Ss/mOJA +d5nUYGr987O9fRh1yP9TksBshHB/5A8U2UG8MFFCNvJTZDPRkuSlMiUCgYEA2nnb +hAJrhY7PaF6jdfimGvvponkUiEbWLppg7/SjgPg+QgqIwuLybryXyOAp+TEnNzgU +ujAjhNtIiyB/B13TDxOgUgWUWPbPvUAWGEvwI9h+RLie1umGHd48G1NR76fwqSf1 +y7z3YRnq8vCdz8ywB3o5GO6SH6QkMJBIxfIMlKMCgYA55akOi7oYQT8KD4waSwCI +ayyZhU4cz4W8Yrd0CsUbtNhVvhAked/w8J2JA01Y5Yn1lfDeRX8OQYNkyAxa2Tbs +F4KCafPvYVIzonCQ6B9sclygoEVl4e8E0wtOPnP2O30TtG8ZOpOgK5UfIIhpfUvE +FN6LQ8PntpRwtZl5qW04bQKBgGnHhFxHG64fthZPdA9jY3E/NSCgRSuyOHN59aNY +rG1+RA6PsSXC4iRxlYAB4PCxNs6KjaaUNi5WSaprAnYbnFv5Ya802l20qmJ0C/6Z +jdydLo2xYd6mVHRTrICCd/J0OpZ8LYsGpDPUa6hSjeYVscj9CXYj1IYTYB5PTZzh +k+vHAoGBAJyA+RtBF5m64/TqhZFcesTtnpWaRhQ50xXnNVF3W1eKGPtdTDKOaENA +LJxgC1GdoEz2ilXW802H9QrdKf9GPqxwi2TVzfO6pzWkdZcmbItu+QCCFz+co+r8 ++ki49FmlfbR5YVPN+8X40aLQB4xDkCHwRwTkrigzWQhIOv8NAhDA +-----END RSA PRIVATE KEY-----` + // Backwards-compatibility alias for teleport.e + SigningCertPEM = TLSCACertPEM +) diff --git a/constants.go b/constants.go index 86640f05a9c9f..18d8b708a737d 100644 --- a/constants.go +++ b/constants.go @@ -799,15 +799,3 @@ const UserSingleUseCertTTL = time.Minute // StandardHTTPSPort is the default port used for the https URI scheme, // cf. RFC 7230 ยง 2.7.2. const StandardHTTPSPort = 443 - -const ( - // WebAPIConnUpgrade is the HTTP web API to make the connection upgrade - // call. - WebAPIConnUpgrade = "/webapi/connectionupgrade" - // WebAPIConnUpgradeHeader is the header used to indicate the requested - // connection upgrade types in the connection upgrade API. - WebAPIConnUpgradeHeader = "Upgrade" - // WebAPIConnUpgradeTypeALPN is a connection upgrade type that specifies - // the upgraded connection should be handled by the ALPN handler. - WebAPIConnUpgradeTypeALPN = "alpn" -) diff --git a/lib/client/api.go b/lib/client/api.go index ea9c4912d6773..2bd0ae1778ea2 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -47,6 +47,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/alpnproxy" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" @@ -74,7 +75,6 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/shell" - "github.com/gravitational/teleport/lib/srv/alpnproxy" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/sshutils/scp" diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index 56ce4c533b82c..a2f156c58516a 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -34,6 +34,7 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/exp/slices" + apiclient "github.com/gravitational/teleport/api/client/alpnproxy" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" commonApp "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/tlsca" @@ -230,7 +231,7 @@ func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamC return trace.Wrap(err) } - tlsConn, err := DialALPN(ctx, l.cfg.RemoteProxyAddr, l.getALPNDialerConfig(certs)) + tlsConn, err := apiclient.DialALPN(ctx, l.cfg.RemoteProxyAddr, l.getALPNDialerConfig(certs)) if err != nil { return trace.Wrap(err) } @@ -255,8 +256,8 @@ func (l *LocalProxy) Close() error { return nil } -func (l *LocalProxy) getALPNDialerConfig(certs []tls.Certificate) ALPNDialerConfig { - return ALPNDialerConfig{ +func (l *LocalProxy) getALPNDialerConfig(certs []tls.Certificate) apiclient.ALPNDialerConfig { + return apiclient.ALPNDialerConfig{ ALPNConnUpgradeRequired: l.cfg.ALPNConnUpgradeRequired, TLSConfig: &tls.Config{ NextProtos: common.ProtocolsToString(l.cfg.Protocols), @@ -306,7 +307,7 @@ func (l *LocalProxy) StartHTTPAccessProxy(ctx context.Context) error { http.Error(w, http.StatusText(code), code) }, Transport: &http.Transport{ - DialTLSContext: NewALPNDialer(l.getALPNDialerConfig(l.getCerts())).DialContext, + DialTLSContext: apiclient.NewALPNDialer(l.getALPNDialerConfig(l.getCerts())).DialContext, }, } err := http.Serve(l.cfg.Listener, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { diff --git a/lib/srv/alpnproxy/local_proxy_config_opt.go b/lib/srv/alpnproxy/local_proxy_config_opt.go index c12837a769eba..021ba4bf340e7 100644 --- a/lib/srv/alpnproxy/local_proxy_config_opt.go +++ b/lib/srv/alpnproxy/local_proxy_config_opt.go @@ -24,6 +24,7 @@ import ( "github.com/gravitational/trace" + apiclient "github.com/gravitational/teleport/api/client/alpnproxy" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" ) @@ -41,7 +42,7 @@ type GetClusterCACertPoolFunc func(ctx context.Context) (*x509.CertPool, error) // already been set. func WithALPNConnUpgradeTest(ctx context.Context, getClusterCertPool GetClusterCACertPoolFunc) LocalProxyConfigOpt { return func(config *LocalProxyConfig) error { - config.ALPNConnUpgradeRequired = IsALPNConnUpgradeRequired(config.RemoteProxyAddr, config.InsecureSkipVerify) + config.ALPNConnUpgradeRequired = apiclient.IsALPNConnUpgradeRequired(config.RemoteProxyAddr, config.InsecureSkipVerify) return trace.Wrap(WithClusterCAsIfConnUpgrade(ctx, getClusterCertPool)(config)) } } diff --git a/lib/web/conn_upgrade.go b/lib/web/conn_upgrade.go index 9f1638a74a2d2..cdfe3dae0556a 100644 --- a/lib/web/conn_upgrade.go +++ b/lib/web/conn_upgrade.go @@ -25,17 +25,17 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/lib/utils" ) // selectConnectionUpgrade selects the requested upgrade type and returns the // corresponding handler. func (h *Handler) selectConnectionUpgrade(r *http.Request) (string, ConnectionHandler, error) { - upgrades := r.Header.Values(teleport.WebAPIConnUpgradeHeader) + upgrades := r.Header.Values(constants.WebAPIConnUpgradeHeader) for _, upgradeType := range upgrades { switch upgradeType { - case teleport.WebAPIConnUpgradeTypeALPN: + case constants.WebAPIConnUpgradeTypeALPN: return upgradeType, h.upgradeALPN, nil } } @@ -90,7 +90,7 @@ func (h *Handler) upgradeALPN(ctx context.Context, conn net.Conn) error { func writeUpgradeResponse(w io.Writer, upgradeType string) error { header := make(http.Header) - header.Add(teleport.WebAPIConnUpgradeHeader, upgradeType) + header.Add(constants.WebAPIConnUpgradeHeader, upgradeType) response := &http.Response{ Status: http.StatusText(http.StatusSwitchingProtocols), StatusCode: http.StatusSwitchingProtocols, From f3399cb1d73838f9b9e0259ce42aa93c2cf92404 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Tue, 28 Mar 2023 16:01:14 -0400 Subject: [PATCH 06/27] tsh dials --- api/client/{alpnproxy/dialer.go => alpn.go} | 35 +++- .../conn_upgrade.go => alpn_conn_upgrade.go} | 7 +- ...rade_test.go => alpn_conn_upgrade_test.go} | 7 +- api/client/client.go | 14 +- api/client/contextdialer.go | 39 ++-- lib/client/api.go | 27 ++- lib/client/client.go | 1 + lib/srv/alpnproxy/local_proxy.go | 10 +- lib/srv/alpnproxy/local_proxy_config_opt.go | 4 +- lib/utils/proxy/proxy.go | 184 +++--------------- 10 files changed, 124 insertions(+), 204 deletions(-) rename api/client/{alpnproxy/dialer.go => alpn.go} (75%) rename api/client/{alpnproxy/conn_upgrade.go => alpn_conn_upgrade.go} (96%) rename api/client/{alpnproxy/conn_upgrade_test.go => alpn_conn_upgrade_test.go} (96%) diff --git a/api/client/alpnproxy/dialer.go b/api/client/alpn.go similarity index 75% rename from api/client/alpnproxy/dialer.go rename to api/client/alpn.go index c73e22cae8590..500410c53a361 100644 --- a/api/client/alpnproxy/dialer.go +++ b/api/client/alpn.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package alpnproxy +package client import ( "context" @@ -23,8 +23,6 @@ import ( "time" "github.com/gravitational/trace" - - "github.com/gravitational/teleport/api/client" ) // ALPNDialerConfig is the config for ALPNDialer. @@ -48,22 +46,41 @@ type ALPNDialer struct { } // NewALPNDialer creates a new ALPNDialer. -func NewALPNDialer(cfg ALPNDialerConfig) client.ContextDialer { +func NewALPNDialer(cfg ALPNDialerConfig) ContextDialer { return &ALPNDialer{ cfg: cfg, } } -// DialContext implements ContextDialer. -func (d ALPNDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { +func (d *ALPNDialer) getTLSConfig(addr string) (*tls.Config, error) { if d.cfg.TLSConfig == nil { return nil, trace.BadParameter("missing TLS config") } + if d.cfg.TLSConfig.ServerName != "" { + return d.cfg.TLSConfig, nil + } + + tlsConfig := d.cfg.TLSConfig.Clone() + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, trace.Wrap(err) + } + tlsConfig.ServerName = host + return tlsConfig, nil +} + +// DialContext implements ContextDialer. +func (d *ALPNDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + tlsConfig, err := d.getTLSConfig(addr) + if err != nil { + return nil, trace.Wrap(err) + } - dialer := client.NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout, client.WithTLSConfig(d.cfg.TLSConfig)) + // TODO support system proxy. + dialer := NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout) if d.cfg.ALPNConnUpgradeRequired { dialer = newALPNConnUpgradeDialer(dialer, &tls.Config{ - InsecureSkipVerify: d.cfg.TLSConfig.InsecureSkipVerify, + InsecureSkipVerify: tlsConfig.InsecureSkipVerify, }) } @@ -72,7 +89,7 @@ func (d ALPNDialer) DialContext(ctx context.Context, network, addr string) (net. return nil, trace.Wrap(err) } - tlsConn := tls.Client(conn, d.cfg.TLSConfig) + tlsConn := tls.Client(conn, tlsConfig) if err := tlsConn.HandshakeContext(ctx); err != nil { defer tlsConn.Close() return nil, trace.Wrap(err) diff --git a/api/client/alpnproxy/conn_upgrade.go b/api/client/alpn_conn_upgrade.go similarity index 96% rename from api/client/alpnproxy/conn_upgrade.go rename to api/client/alpn_conn_upgrade.go index 739123e2bdc67..8ac0cd2bef397 100644 --- a/api/client/alpnproxy/conn_upgrade.go +++ b/api/client/alpn_conn_upgrade.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package alpnproxy +package client import ( "bufio" @@ -30,7 +30,6 @@ import ( "github.com/gravitational/trace" "github.com/sirupsen/logrus" - "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/utils" @@ -129,12 +128,12 @@ func isALPNConnUpgradeRequiredByEnv(addr, envValue string) bool { // alpnConnUpgradeDialer makes an "HTTP" upgrade call to the Proxy Service then // tunnels the connection with this connection upgrade. type alpnConnUpgradeDialer struct { - dialer client.ContextDialer + dialer ContextDialer tlsConfig *tls.Config } // newALPNConnUpgradeDialer creates a new alpnConnUpgradeDialer. -func newALPNConnUpgradeDialer(dialer client.ContextDialer, tlsConfig *tls.Config) client.ContextDialer { +func newALPNConnUpgradeDialer(dialer ContextDialer, tlsConfig *tls.Config) ContextDialer { return &alpnConnUpgradeDialer{ dialer: dialer, tlsConfig: tlsConfig, diff --git a/api/client/alpnproxy/conn_upgrade_test.go b/api/client/alpn_conn_upgrade_test.go similarity index 96% rename from api/client/alpnproxy/conn_upgrade_test.go rename to api/client/alpn_conn_upgrade_test.go index d2f76b984a1c1..4c02b3ebfebb6 100644 --- a/api/client/alpnproxy/conn_upgrade_test.go +++ b/api/client/alpn_conn_upgrade_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package alpnproxy +package client import ( "context" @@ -30,7 +30,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/fixtures" ) @@ -135,7 +134,7 @@ func TestALPNConnUpgradeDialer(t *testing.T) { pool.AddCert(server.Certificate()) tlsConfig := &tls.Config{RootCAs: pool} - preDialer := client.NewDialer(ctx, 0, 5*time.Second, client.WithTLSConfig(tlsConfig)) + preDialer := NewDialer(ctx, 0, 5*time.Second) dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig) conn, err := dialer.DialContext(ctx, "tcp", addr.Host) require.NoError(t, err) @@ -155,7 +154,7 @@ func TestALPNConnUpgradeDialer(t *testing.T) { require.NoError(t, err) tlsConfig := &tls.Config{InsecureSkipVerify: true} - preDialer := client.NewDialer(ctx, 0, 5*time.Second, client.WithTLSConfig(tlsConfig)) + preDialer := NewDialer(ctx, 0, 5*time.Second) dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig) _, err = dialer.DialContext(ctx, "tcp", addr.Host) require.Error(t, err) diff --git a/api/client/client.go b/api/client/client.go index 49993bb5902bb..99d6f8566ad94 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -329,6 +329,12 @@ type ( // authConnect connects to the Teleport Auth Server directly. func authConnect(ctx context.Context, params connectParams) (*Client, error) { dialer := NewDialer(ctx, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, WithTLSConfig(params.tlsConfig)) + if params.cfg.IsALPNConnUpgradeRequired(params.addr, params.cfg.InsecureAddressDiscovery) { + dialer = newALPNConnUpgradeDialer(dialer, &tls.Config{ + InsecureSkipVerify: params.cfg.InsecureAddressDiscovery, + }) + } + clt := newClient(params.cfg, dialer, params.tlsConfig) if err := clt.dialGRPC(ctx, params.addr); err != nil { return nil, trace.Wrap(err, "failed to connect to addr %v as an auth server", params.addr) @@ -367,7 +373,7 @@ func tlsRoutingConnect(ctx context.Context, params connectParams) (*Client, erro if params.sshConfig == nil { return nil, trace.BadParameter("must provide ssh client config") } - dialer := newTLSRoutingTunnelDialer(*params.sshConfig, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, params.addr, params.cfg.InsecureAddressDiscovery) + dialer := newTLSRoutingTunnelDialer(*params.sshConfig, params) clt := newClient(params.cfg, dialer, params.tlsConfig) if err := clt.dialGRPC(ctx, params.addr); err != nil { return nil, trace.Wrap(err, "failed to connect to addr %v with TLS Routing dialer", params.addr) @@ -526,6 +532,9 @@ type Config struct { CircuitBreakerConfig breaker.Config // Context is the base context to use for dialing. If not provided context.Background is used Context context.Context + // IsALPNConnUpgradeRequired is a callback function to check whether + // connection upgrade is required for TLS routing. + IsALPNConnUpgradeRequired func(addr string, insecure bool) bool } // CheckAndSetDefaults checks and sets default config values. @@ -559,6 +568,9 @@ func (c *Config) CheckAndSetDefaults() error { if !c.DialInBackground { c.DialOpts = append(c.DialOpts, grpc.WithBlock()) } + if c.IsALPNConnUpgradeRequired == nil { + c.IsALPNConnUpgradeRequired = IsALPNConnUpgradeRequired + } return nil } diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index e8e5d544b2884..82e6ce497a485 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -136,9 +136,14 @@ func newTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Dur // newTLSRoutingTunnelDialer makes a reverse tunnel TLS Routing dialer to connect to an Auth server // through the SSH reverse tunnel on the proxy. -func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool) ContextDialer { +func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) ContextDialer { return ContextDialerFunc(func(ctx context.Context, network, addr string) (conn net.Conn, err error) { - resp, err := webclient.Find(&webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure}) + insecure := params.cfg.InsecureAddressDiscovery + resp, err := webclient.Find(&webclient.Config{ + Context: ctx, + ProxyAddr: params.addr, + Insecure: insecure, + }) if err != nil { return nil, trace.Wrap(err) } @@ -152,30 +157,20 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeou return nil, trace.Wrap(err) } - dialer := &net.Dialer{ - Timeout: dialTimeout, - KeepAlive: keepAlivePeriod, - } - conn, err = dialer.DialContext(ctx, network, tunnelAddr) - if err != nil { - return nil, trace.Wrap(err) - } - - host, _, err := webclient.ParseHostPort(tunnelAddr) - if err != nil { - return nil, trace.Wrap(err) - } - - tlsConn := tls.Client(conn, &tls.Config{ - NextProtos: []string{constants.ALPNSNIProtocolReverseTunnel}, - InsecureSkipVerify: insecure, - ServerName: host, + tlsConn, err := DialALPN(ctx, tunnelAddr, ALPNDialerConfig{ + ALPNConnUpgradeRequired: params.cfg.IsALPNConnUpgradeRequired(tunnelAddr, insecure), + DialTimeout: params.cfg.DialTimeout, + KeepAlivePeriod: params.cfg.KeepAlivePeriod, + TLSConfig: &tls.Config{ + NextProtos: []string{constants.ALPNSNIProtocolReverseTunnel}, + InsecureSkipVerify: insecure, + }, }) - if err := tlsConn.Handshake(); err != nil { + if err != nil { return nil, trace.Wrap(err) } - sconn, err := sshConnect(ctx, tlsConn, ssh, dialTimeout, tunnelAddr) + sconn, err := sshConnect(ctx, tlsConn, ssh, params.cfg.DialTimeout, tunnelAddr) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/client/api.go b/lib/client/api.go index 2bd0ae1778ea2..512d570ab1f06 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -47,7 +47,6 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client" - "github.com/gravitational/teleport/api/client/alpnproxy" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" @@ -622,7 +621,7 @@ func (c *Config) LoadProfile(ps ProfileStore, proxyAddr string) error { log.Warnf("Unable to parse dynamic port forwarding in user profile: %v.", err) } - if required, ok := alpnproxy.OverwriteALPNConnUpgradeRequirementByEnv(c.WebProxyAddr); ok { + if required, ok := client.OverwriteALPNConnUpgradeRequirementByEnv(c.WebProxyAddr); ok { c.TLSRoutingConnUpgradeRequired = required } log.Infof("ALPN connection upgrade required for %q: %v.", c.WebProxyAddr, c.TLSRoutingConnUpgradeRequired) @@ -2898,7 +2897,17 @@ func makeProxySSHClientWithTLSWrapper(ctx context.Context, tc *TeleportClient, s } tlsConfig.NextProtos = []string{string(alpncommon.ProtocolProxySSH)} - dialer := proxy.DialerFromEnvironment(tc.Config.WebProxyAddr, proxy.WithALPNDialer(tlsConfig)) + + alpnConfig := client.ALPNDialerConfig{ + TLSConfig: tlsConfig, + ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired, + DialTimeout: sshConfig.Timeout, + } + if proxyAddr != tc.WebProxyAddr { + alpnConfig.ALPNConnUpgradeRequired = client.IsALPNConnUpgradeRequired(proxyAddr, tlsConfig.InsecureSkipVerify) + } + + dialer := proxy.DialerFromEnvironment(tc.Config.WebProxyAddr, proxy.WithALPNDialer(client.NewALPNDialer(alpnConfig))) return dialer.Dial(ctx, "tcp", proxyAddr, sshConfig) } @@ -3079,7 +3088,7 @@ func (tc *TeleportClient) Login(ctx context.Context) (*Key, error) { } // Perform the ALPN test once at login. - tc.TLSRoutingConnUpgradeRequired = alpnproxy.IsALPNConnUpgradeRequired(tc.WebProxyAddr, tc.InsecureSkipVerify) + tc.TLSRoutingConnUpgradeRequired = client.IsALPNConnUpgradeRequired(tc.WebProxyAddr, tc.InsecureSkipVerify) // Get the SSHLoginFunc that matches client and cluster settings. sshLoginFunc, err := tc.getSSHLoginFunc(pr) @@ -4449,6 +4458,7 @@ func (tc *TeleportClient) NewKubernetesServiceClient(ctx context.Context, cluste Credentials: []client.Credentials{ client.LoadTLS(tlsConfig), }, + IsALPNConnUpgradeRequired: tc.isALPNConnUpgradeRequired, }) if err != nil { return nil, trace.Wrap(err) @@ -4456,6 +4466,15 @@ func (tc *TeleportClient) NewKubernetesServiceClient(ctx context.Context, cluste return kubeproto.NewKubeServiceClient(clt.GetConnection()), nil } +func (tc *TeleportClient) isALPNConnUpgradeRequired(addr string, insecure bool) bool { + // Use cached value. + if addr == tc.WebProxyAddr { + return tc.TLSRoutingConnUpgradeRequired + } + // Do a test for other addresses. + return client.IsALPNConnUpgradeRequired(addr, insecure) +} + // RootClusterCACertPool returns a *x509.CertPool with the root cluster CA. func (tc *TeleportClient) RootClusterCACertPool(ctx context.Context) (*x509.CertPool, error) { _, span := tc.Tracer.Start( diff --git a/lib/client/client.go b/lib/client/client.go index ff2110e6fcc8d..eba9a81642a87 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -1108,6 +1108,7 @@ func (proxy *ProxyClient) ConnectToAuthServiceThroughALPNSNIProxy(ctx context.Co }, ALPNSNIAuthDialClusterName: clusterName, CircuitBreakerConfig: breaker.NoopBreakerConfig(), + IsALPNConnUpgradeRequired: proxy.teleportClient.isALPNConnUpgradeRequired, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index a2f156c58516a..1fd3e6f42a76f 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -34,7 +34,7 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/exp/slices" - apiclient "github.com/gravitational/teleport/api/client/alpnproxy" + "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" commonApp "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/tlsca" @@ -231,7 +231,7 @@ func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamC return trace.Wrap(err) } - tlsConn, err := apiclient.DialALPN(ctx, l.cfg.RemoteProxyAddr, l.getALPNDialerConfig(certs)) + tlsConn, err := client.DialALPN(ctx, l.cfg.RemoteProxyAddr, l.getALPNDialerConfig(certs)) if err != nil { return trace.Wrap(err) } @@ -256,8 +256,8 @@ func (l *LocalProxy) Close() error { return nil } -func (l *LocalProxy) getALPNDialerConfig(certs []tls.Certificate) apiclient.ALPNDialerConfig { - return apiclient.ALPNDialerConfig{ +func (l *LocalProxy) getALPNDialerConfig(certs []tls.Certificate) client.ALPNDialerConfig { + return client.ALPNDialerConfig{ ALPNConnUpgradeRequired: l.cfg.ALPNConnUpgradeRequired, TLSConfig: &tls.Config{ NextProtos: common.ProtocolsToString(l.cfg.Protocols), @@ -307,7 +307,7 @@ func (l *LocalProxy) StartHTTPAccessProxy(ctx context.Context) error { http.Error(w, http.StatusText(code), code) }, Transport: &http.Transport{ - DialTLSContext: apiclient.NewALPNDialer(l.getALPNDialerConfig(l.getCerts())).DialContext, + DialTLSContext: client.NewALPNDialer(l.getALPNDialerConfig(l.getCerts())).DialContext, }, } err := http.Serve(l.cfg.Listener, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { diff --git a/lib/srv/alpnproxy/local_proxy_config_opt.go b/lib/srv/alpnproxy/local_proxy_config_opt.go index 021ba4bf340e7..736ef7a620f12 100644 --- a/lib/srv/alpnproxy/local_proxy_config_opt.go +++ b/lib/srv/alpnproxy/local_proxy_config_opt.go @@ -24,7 +24,7 @@ import ( "github.com/gravitational/trace" - apiclient "github.com/gravitational/teleport/api/client/alpnproxy" + "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" ) @@ -42,7 +42,7 @@ type GetClusterCACertPoolFunc func(ctx context.Context) (*x509.CertPool, error) // already been set. func WithALPNConnUpgradeTest(ctx context.Context, getClusterCertPool GetClusterCACertPoolFunc) LocalProxyConfigOpt { return func(config *LocalProxyConfig) error { - config.ALPNConnUpgradeRequired = apiclient.IsALPNConnUpgradeRequired(config.RemoteProxyAddr, config.InsecureSkipVerify) + config.ALPNConnUpgradeRequired = client.IsALPNConnUpgradeRequired(config.RemoteProxyAddr, config.InsecureSkipVerify) return trace.Wrap(WithClusterCAsIfConnUpgrade(ctx, getClusterCertPool)(config)) } } diff --git a/lib/utils/proxy/proxy.go b/lib/utils/proxy/proxy.go index 464665bf1c151..ed4d5a3d691d3 100644 --- a/lib/utils/proxy/proxy.go +++ b/lib/utils/proxy/proxy.go @@ -18,7 +18,6 @@ package proxy import ( "context" - "crypto/tls" "net" "net/url" "time" @@ -28,11 +27,11 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client" apiclient "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/observability/tracing" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" apiutils "github.com/gravitational/teleport/api/utils" - "github.com/gravitational/teleport/lib/utils" ) var log = logrus.WithFields(logrus.Fields{ @@ -61,24 +60,7 @@ func (d directDial) dialALPNWithDeadline(ctx context.Context, network string, ad ctx, span := tracing.DefaultProvider().Tracer("dialer").Start(ctx, "directDial/dialALPNWithDeadline") defer span.End() - dialer := &net.Dialer{ - Timeout: config.Timeout, - } - address, err := utils.ParseAddr(addr) - if err != nil { - return nil, trace.Wrap(err) - } - conf, err := d.getTLSConfig(address) - if err != nil { - return nil, trace.Wrap(err) - } - - tlsDialer := tls.Dialer{ - NetDialer: dialer, - Config: conf, - } - - tlsConn, err := tlsDialer.DialContext(ctx, network, addr) + tlsConn, err := d.alpnDialer.DialContext(ctx, network, addr) if err != nil { return nil, trace.Wrap(err) } @@ -95,28 +77,13 @@ type Dialer interface { } type directDial struct { - // insecure is whether to skip certificate validation. - insecure bool - // tlsRoutingEnabled indicates that proxy is running in TLSRouting mode. - tlsRoutingEnabled bool - // tlsConfig is the TLS config to use. - tlsConfig *tls.Config -} - -// getTLSConfig configures the dialers TLS config for a specified address. -func (d directDial) getTLSConfig(addr *utils.NetAddr) (*tls.Config, error) { - if d.tlsConfig == nil { - return nil, trace.BadParameter("TLS config was nil") - } - tlsConfig := d.tlsConfig.Clone() - tlsConfig.ServerName = addr.Host() - tlsConfig.InsecureSkipVerify = d.insecure - return tlsConfig, nil + // alpnDialer is the dialer used for TLS routing. + alpnDialer client.ContextDialer } // Dial calls ssh.Dial directly. func (d directDial) Dial(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) { - if d.tlsRoutingEnabled { + if d.alpnDialer != nil { client, err := d.dialALPNWithDeadline(ctx, network, addr, config) if err != nil { return nil, trace.Wrap(err) @@ -132,31 +99,13 @@ func (d directDial) Dial(ctx context.Context, network string, addr string, confi // DialTimeout acts like Dial but takes a timeout. func (d directDial) DialTimeout(ctx context.Context, network, address string, timeout time.Duration) (net.Conn, error) { - dialer := &net.Dialer{ - Timeout: timeout, - } - - if d.tlsRoutingEnabled { - addr, err := utils.ParseAddr(address) - if err != nil { - return nil, trace.Wrap(err) - } - conf, err := d.getTLSConfig(addr) - if err != nil { - return nil, trace.Wrap(err) + dialer := d.alpnDialer + if dialer == nil { + dialer = &net.Dialer{ + Timeout: timeout, } - - tlsDialer := tls.Dialer{ - NetDialer: dialer, - Config: conf, - } - - tlsConn, err := tlsDialer.DialContext(ctx, "tcp", address) - if err != nil { - return nil, trace.Wrap(err) - } - return tlsConn, nil } + conn, err := dialer.DialContext(ctx, network, address) if err != nil { return nil, trace.Wrap(err) @@ -167,40 +116,8 @@ func (d directDial) DialTimeout(ctx context.Context, network, address string, ti type proxyDial struct { // proxyHost is the HTTPS proxy address. proxyURL *url.URL - // insecure is whether to skip certificate validation. - insecure bool - // tlsRoutingEnabled indicates that proxy is running in TLSRouting mode. - tlsRoutingEnabled bool - // tlsConfig is the TLS config to use. - tlsConfig *tls.Config -} - -// getTLSConfig configures the dialers TLS config for a specified address. -func (d proxyDial) getTLSConfig(addr *utils.NetAddr) (*tls.Config, error) { - if d.tlsConfig == nil { - return nil, trace.BadParameter("TLS config was nil") - } - tlsConfig := d.tlsConfig.Clone() - tlsConfig.ServerName = addr.Host() - tlsConfig.InsecureSkipVerify = d.insecure - return tlsConfig, nil -} - -// getTLSConfigForProxy configures the dialer's TLS config for the HTTPS proxy -// address. If the proxy is HTTP, a nil error and nil config are returned. -func (d proxyDial) getTLSConfigForProxy() (*tls.Config, error) { - if d.proxyURL.Scheme != "https" { - return nil, nil - } - netAddr, err := utils.ParseAddr(d.proxyURL.String()) - if err != nil { - return nil, trace.Wrap(err) - } - tlsConfig, err := d.getTLSConfig(netAddr) - if err != nil { - return nil, trace.Wrap(err) - } - return tlsConfig, nil + // alpnDialer is the dialer used for TLS routing. + alpnDialer client.ContextDialer } // DialTimeout acts like Dial but takes a timeout. @@ -211,63 +128,38 @@ func (d proxyDial) DialTimeout(ctx context.Context, network, address string, tim defer cancel() ctx = timeoutCtx } - - tlsConfig, err := d.getTLSConfigForProxy() - if err != nil { - return nil, trace.Wrap(err) + if d.alpnDialer != nil { + tlsConn, err := d.alpnDialer.DialContext(ctx, network, address) + return tlsConn, trace.Wrap(err) } - conn, err := apiclient.DialProxy(ctx, d.proxyURL, address, apiclient.WithTLSConfig(tlsConfig)) + conn, err := apiclient.DialProxy(ctx, d.proxyURL, address) if err != nil { return nil, trace.Wrap(err) } - if d.tlsRoutingEnabled { - address, err := utils.ParseAddr(address) - if err != nil { - return nil, trace.Wrap(err) - } - conf, err := d.getTLSConfig(address) - if err != nil { - return nil, trace.Wrap(err) - } - tlsConn := tls.Client(conn, conf) - if err = tlsConn.HandshakeContext(ctx); err != nil { - conn.Close() - return nil, trace.Wrap(err) - } - conn = tlsConn - } return conn, nil } // Dial first connects to a proxy, then uses the connection to establish a new // SSH connection. func (d proxyDial) Dial(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) { - tlsConfig, err := d.getTLSConfigForProxy() - if err != nil { - return nil, trace.Wrap(err) - } // Build a proxy connection first. - pconn, err := apiclient.DialProxy(ctx, d.proxyURL, addr, apiclient.WithTLSConfig(tlsConfig)) + var pconn net.Conn + var err error + if d.alpnDialer != nil { + pconn, err = d.alpnDialer.DialContext(ctx, network, addr) + } else { + pconn, err = apiclient.DialProxy(ctx, d.proxyURL, addr) + } if err != nil { return nil, trace.Wrap(err) } + if config.Timeout > 0 { if err := pconn.SetReadDeadline(time.Now().Add(config.Timeout)); err != nil { return nil, trace.Wrap(err) } } - if d.tlsRoutingEnabled { - address, err := utils.ParseAddr(addr) - if err != nil { - return nil, trace.Wrap(err) - } - conf, err := d.getTLSConfig(address) - if err != nil { - return nil, trace.Wrap(err) - } - pconn = tls.Client(pconn, conf) - } // Do the same as ssh.Dial but pass in proxy connection. c, chans, reqs, err := tracessh.NewClientConn(ctx, pconn, addr, config) @@ -285,27 +177,17 @@ func (d proxyDial) Dial(ctx context.Context, network string, addr string, config type dialerOptions struct { // insecureSkipTLSVerify is whether to skip certificate validation. insecureSkipTLSVerify bool - // tlsRoutingEnabled indicates that proxy is running in TLSRouting mode. - tlsRoutingEnabled bool - // tlsConfig is the TLS config to use for TLS routing. - tlsConfig *tls.Config + // alpnDialer is the dialer used for TLS routing. + alpnDialer client.ContextDialer } // DialerOptionFunc allows setting options as functional arguments to DialerFromEnvironment type DialerOptionFunc func(options *dialerOptions) // WithALPNDialer creates a dialer that allows to Teleport running in single-port mode. -func WithALPNDialer(tlsConfig *tls.Config) DialerOptionFunc { - return func(options *dialerOptions) { - options.tlsRoutingEnabled = true - options.tlsConfig = tlsConfig - } -} - -// WithInsecureSkipTLSVerify skips the certs verifications. -func WithInsecureSkipTLSVerify(insecure bool) DialerOptionFunc { +func WithALPNDialer(alpnDialer client.ContextDialer) DialerOptionFunc { return func(options *dialerOptions) { - options.insecureSkipTLSVerify = insecure + options.alpnDialer = alpnDialer } } @@ -327,17 +209,13 @@ func DialerFromEnvironment(addr string, opts ...DialerOptionFunc) Dialer { if proxyURL == nil { log.Debugf("No proxy set in environment, returning direct dialer.") return directDial{ - tlsConfig: options.tlsConfig, - tlsRoutingEnabled: options.tlsRoutingEnabled, - insecure: options.insecureSkipTLSVerify, + alpnDialer: options.alpnDialer, } } log.Debugf("Found proxy %q in environment, returning proxy dialer.", proxyURL) return proxyDial{ - proxyURL: proxyURL, - insecure: options.insecureSkipTLSVerify, - tlsRoutingEnabled: options.tlsRoutingEnabled, - tlsConfig: options.tlsConfig, + proxyURL: proxyURL, + alpnDialer: options.alpnDialer, } } From 9ace6ca989bd013252aff2b6d65ccc46cf93b51c Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Wed, 29 Mar 2023 08:39:14 -0400 Subject: [PATCH 07/27] reverse tunnel --- lib/auth/authclient/authclient.go | 1 + lib/client/api.go | 7 +-- lib/reversetunnel/agentpool.go | 81 +++++++++++++++++++++++++------ lib/reversetunnel/transport.go | 27 +++++++++-- lib/service/connect.go | 1 + lib/utils/proxy/proxy.go | 4 +- 6 files changed, 94 insertions(+), 27 deletions(-) diff --git a/lib/auth/authclient/authclient.go b/lib/auth/authclient/authclient.go index 02a4319d981cf..1379ac32f93f3 100644 --- a/lib/auth/authclient/authclient.go +++ b/lib/auth/authclient/authclient.go @@ -95,6 +95,7 @@ func Connect(ctx context.Context, cfg *Config) (auth.ClientI, error) { ClientConfig: cfg.SSH, Log: cfg.Log, InsecureSkipTLSVerify: cfg.TLS.InsecureSkipVerify, + ClusterCAs: cfg.TLS.RootCAs, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/client/api.go b/lib/client/api.go index 512d570ab1f06..335c61cf2356d 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2900,14 +2900,11 @@ func makeProxySSHClientWithTLSWrapper(ctx context.Context, tc *TeleportClient, s alpnConfig := client.ALPNDialerConfig{ TLSConfig: tlsConfig, - ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired, + ALPNConnUpgradeRequired: tc.isALPNConnUpgradeRequired(proxyAddr, tlsConfig.InsecureSkipVerify), DialTimeout: sshConfig.Timeout, } - if proxyAddr != tc.WebProxyAddr { - alpnConfig.ALPNConnUpgradeRequired = client.IsALPNConnUpgradeRequired(proxyAddr, tlsConfig.InsecureSkipVerify) - } - dialer := proxy.DialerFromEnvironment(tc.Config.WebProxyAddr, proxy.WithALPNDialer(client.NewALPNDialer(alpnConfig))) + dialer := proxy.DialerFromEnvironment(tc.Config.WebProxyAddr, proxy.WithALPNDialer(alpnConfig)) return dialer.Dial(ctx, "tcp", proxyAddr, sshConfig) } diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index 637981361aa3d..49def2948e20f 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -31,6 +31,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" @@ -323,7 +324,7 @@ func (p *AgentPool) updateRuntimeConfig(ctx context.Context) error { return trace.Wrap(err) } - p.runtimeConfig.update(netConfig) + p.runtimeConfig.update(ctx, netConfig, p.Resolver) p.log.Debugf("Runtime config: tunnel_strategy: %v connection_count: %v", p.runtimeConfig.tunnelStrategyType, p.runtimeConfig.connectionCount) return nil @@ -478,20 +479,13 @@ func (p *AgentPool) newAgent(ctx context.Context, tracker *track.Tracker, lease p.log.WithError(err).Debugf("Failed to update remote config.") } - options := []proxy.DialerOptionFunc{proxy.WithInsecureSkipTLSVerify(lib.IsInsecureDevMode())} + options := []proxy.DialerOptionFunc{} if p.runtimeConfig.useALPNRouting() { - tlsConfig := &tls.Config{ - NextProtos: []string{string(alpncommon.ProtocolReverseTunnel)}, - } - - if p.runtimeConfig.useReverseTunnelV2() { - tlsConfig.NextProtos = []string{ - string(alpncommon.ProtocolReverseTunnelV2), - string(alpncommon.ProtocolReverseTunnel), - } + alpnDialerConfig, err := p.makeALPNDialerConfig() + if err != nil { + return nil, trace.Wrap(err) } - - options = append(options, proxy.WithALPNDialer(tlsConfig)) + options = append(options, proxy.WithALPNDialer(alpnDialerConfig)) } dialer := &agentDialer{ @@ -505,7 +499,7 @@ func (p *AgentPool) newAgent(ctx context.Context, tracker *track.Tracker, lease agent, err := newAgent(agentConfig{ addr: *addr, - keepAlive: p.runtimeConfig.keepAliveInterval, + keepAlive: p.runtimeConfig.getKeepAliveInterval(), sshDialer: dialer, transporter: p, versionGetter: p, @@ -524,6 +518,35 @@ func (p *AgentPool) newAgent(ctx context.Context, tracker *track.Tracker, lease return agent, nil } +func (p *AgentPool) makeALPNDialerConfig() (client.ALPNDialerConfig, error) { + tlsConfig := &tls.Config{ + NextProtos: []string{string(alpncommon.ProtocolReverseTunnel)}, + InsecureSkipVerify: lib.IsInsecureDevMode(), + } + + if p.runtimeConfig.useReverseTunnelV2() { + tlsConfig.NextProtos = []string{ + string(alpncommon.ProtocolReverseTunnelV2), + string(alpncommon.ProtocolReverseTunnel), + } + } + + config := client.ALPNDialerConfig{ + TLSConfig: tlsConfig, + ALPNConnUpgradeRequired: p.runtimeConfig.useALPNConnUpgrade(), + KeepAlivePeriod: p.runtimeConfig.getKeepAliveInterval(), + } + + if config.ALPNConnUpgradeRequired { + rootCAs, _, err := auth.ClientCertPool(p.AccessPoint, p.Cluster, types.HostCA) + if err != nil { + return client.ALPNDialerConfig{}, trace.Wrap(err) + } + tlsConfig.RootCAs = rootCAs + } + return config, nil +} + // Wait blocks until the pool context is stopped. func (p *AgentPool) Wait() { if p == nil { @@ -587,6 +610,9 @@ type agentPoolRuntimeConfig struct { // isRemoteCluster forces the agent pool to connect to all proxies // regardless of the configured tunnel strategy. isRemoteCluster bool + // tlsRoutingConnUpgradeRequired indicates that ALPN connection upgrades + // are required for making TLS routing requests. + tlsRoutingConnUpgradeRequired bool // remoteTLSRoutingEnabled caches a remote clusters tls routing setting. This helps prevent // proxy endpoint stagnation where an even numbers of proxies are hidden behind a round robin @@ -653,6 +679,18 @@ func (c *agentPoolRuntimeConfig) useALPNRouting() bool { return c.proxyListenerMode == types.ProxyListenerMode_Multiplex } +func (c *agentPoolRuntimeConfig) useALPNConnUpgrade() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.tlsRoutingConnUpgradeRequired +} + +func (c *agentPoolRuntimeConfig) getKeepAliveInterval() time.Duration { + c.mu.RLock() + defer c.mu.RUnlock() + return c.keepAliveInterval +} + func (c *agentPoolRuntimeConfig) updateRemote(ctx context.Context, addr *utils.NetAddr) error { c.updateRemoteMu.Lock() defer c.updateRemoteMu.Unlock() @@ -705,13 +743,17 @@ func (c *agentPoolRuntimeConfig) updateRemote(ctx context.Context, addr *utils.N c.lastRemotePing = &now c.remoteTLSRoutingEnabled = tlsRoutingEnabled + if c.remoteTLSRoutingEnabled { + c.tlsRoutingConnUpgradeRequired = client.IsALPNConnUpgradeRequired(addr.Addr, lib.IsInsecureDevMode()) + } return nil } -func (c *agentPoolRuntimeConfig) update(netConfig types.ClusterNetworkingConfig) { +func (c *agentPoolRuntimeConfig) update(ctx context.Context, netConfig types.ClusterNetworkingConfig, resolver Resolver) { c.mu.Lock() defer c.mu.Unlock() + oldProxyListenerMode := c.proxyListenerMode c.keepAliveInterval = netConfig.GetKeepAliveInterval() c.proxyListenerMode = netConfig.GetProxyListenerMode() @@ -730,6 +772,15 @@ func (c *agentPoolRuntimeConfig) update(netConfig types.ClusterNetworkingConfig) if c.connectionCount <= 0 { c.connectionCount = defaultAgentConnectionCount } + + if c.proxyListenerMode == types.ProxyListenerMode_Multiplex && oldProxyListenerMode != c.proxyListenerMode { + addr, _, err := resolver(ctx) + if err == nil { + c.tlsRoutingConnUpgradeRequired = client.IsALPNConnUpgradeRequired(addr.Addr, lib.IsInsecureDevMode()) + } else { + logrus.WithError(err).Warnf("Faield to resolve addr.") + } + } } // Make sure ServerHandlerToListener implements both interfaces. diff --git a/lib/reversetunnel/transport.go b/lib/reversetunnel/transport.go index bb30531047dfe..5c2ed996e03b7 100644 --- a/lib/reversetunnel/transport.go +++ b/lib/reversetunnel/transport.go @@ -19,6 +19,7 @@ package reversetunnel import ( "context" "crypto/tls" + "crypto/x509" "encoding/json" "fmt" "io" @@ -31,6 +32,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/sshutils" @@ -62,12 +64,17 @@ type TunnelAuthDialerConfig struct { Log logrus.FieldLogger // InsecureSkipTLSVerify is whether to skip certificate validation. InsecureSkipTLSVerify bool + // ClusterCAs contains cluster CAs. + ClusterCAs *x509.CertPool } func (c *TunnelAuthDialerConfig) CheckAndSetDefaults() error { if c.Resolver == nil { return trace.BadParameter("missing tunnel address resolver") } + if c.ClusterCAs == nil { + return trace.BadParameter("missing cluster CAs") + } return nil } @@ -80,9 +87,7 @@ type TunnelAuthDialer struct { // DialContext dials auth server via SSH tunnel func (t *TunnelAuthDialer) DialContext(ctx context.Context, _, _ string) (net.Conn, error) { // Connect to the reverse tunnel server. - opts := []proxy.DialerOptionFunc{ - proxy.WithInsecureSkipTLSVerify(t.InsecureSkipTLSVerify), - } + opts := []proxy.DialerOptionFunc{} addr, mode, err := t.Resolver(ctx) if err != nil { @@ -91,8 +96,20 @@ func (t *TunnelAuthDialer) DialContext(ctx context.Context, _, _ string) (net.Co } if mode == types.ProxyListenerMode_Multiplex { - opts = append(opts, proxy.WithALPNDialer(&tls.Config{ - NextProtos: []string{string(alpncommon.ProtocolReverseTunnel)}, + tlsConfig := &tls.Config{ + NextProtos: []string{string(alpncommon.ProtocolReverseTunnel)}, + InsecureSkipVerify: t.InsecureSkipTLSVerify, + } + + alpnConnUpgradeRequired := client.IsALPNConnUpgradeRequired(addr.Addr, t.InsecureSkipTLSVerify) + if alpnConnUpgradeRequired { + tlsConfig.RootCAs = t.ClusterCAs + } + + opts = append(opts, proxy.WithALPNDialer(client.ALPNDialerConfig{ + TLSConfig: tlsConfig, + ALPNConnUpgradeRequired: alpnConnUpgradeRequired, + DialTimeout: t.ClientConfig.Timeout, })) } diff --git a/lib/service/connect.go b/lib/service/connect.go index 09be003fa9895..ea936754c3bcf 100644 --- a/lib/service/connect.go +++ b/lib/service/connect.go @@ -1166,6 +1166,7 @@ func (process *TeleportProcess) newClientThroughTunnel(authServers []utils.NetAd ClientConfig: sshConfig, Log: process.log, InsecureSkipTLSVerify: lib.IsInsecureDevMode(), + ClusterCAs: tlsConfig.RootCAs, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/utils/proxy/proxy.go b/lib/utils/proxy/proxy.go index ed4d5a3d691d3..734b62041d1b7 100644 --- a/lib/utils/proxy/proxy.go +++ b/lib/utils/proxy/proxy.go @@ -185,9 +185,9 @@ type dialerOptions struct { type DialerOptionFunc func(options *dialerOptions) // WithALPNDialer creates a dialer that allows to Teleport running in single-port mode. -func WithALPNDialer(alpnDialer client.ContextDialer) DialerOptionFunc { +func WithALPNDialer(alpnDialerConfig client.ALPNDialerConfig) DialerOptionFunc { return func(options *dialerOptions) { - options.alpnDialer = alpnDialer + options.alpnDialer = client.NewALPNDialer(alpnDialerConfig) } } From 1b4f4c5ee94ea2083aa325f0d0b4098a618ff0e1 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Wed, 29 Mar 2023 11:32:42 -0400 Subject: [PATCH 08/27] fix auth connect --- api/client/alpn.go | 4 +++- api/client/client.go | 33 ++++++++++++++++++++++++++++++++- api/client/contextdialer.go | 5 +++++ lib/client/client.go | 4 ++-- 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/api/client/alpn.go b/api/client/alpn.go index 500410c53a361..3ce81bec55012 100644 --- a/api/client/alpn.go +++ b/api/client/alpn.go @@ -23,6 +23,8 @@ import ( "time" "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/client/webclient" ) // ALPNDialerConfig is the config for ALPNDialer. @@ -61,7 +63,7 @@ func (d *ALPNDialer) getTLSConfig(addr string) (*tls.Config, error) { } tlsConfig := d.cfg.TLSConfig.Clone() - host, _, err := net.SplitHostPort(addr) + host, _, err := webclient.ParseHostPort(addr) if err != nil { return nil, trace.Wrap(err) } diff --git a/api/client/client.go b/api/client/client.go index 99d6f8566ad94..b65af513f45fe 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -44,6 +44,7 @@ import ( "github.com/gravitational/teleport/api/breaker" "github.com/gravitational/teleport/api/client/okta" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/defaults" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" @@ -329,7 +330,8 @@ type ( // authConnect connects to the Teleport Auth Server directly. func authConnect(ctx context.Context, params connectParams) (*Client, error) { dialer := NewDialer(ctx, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, WithTLSConfig(params.tlsConfig)) - if params.cfg.IsALPNConnUpgradeRequired(params.addr, params.cfg.InsecureAddressDiscovery) { + + if authConnectShouldALPNConnUpgrade(ctx, params) { dialer = newALPNConnUpgradeDialer(dialer, &tls.Config{ InsecureSkipVerify: params.cfg.InsecureAddressDiscovery, }) @@ -342,6 +344,29 @@ func authConnect(ctx context.Context, params connectParams) (*Client, error) { return clt, nil } +func authConnectShouldALPNConnUpgrade(ctx context.Context, params connectParams) bool { + if !authConnectShouldUseTLSRouting(ctx, params) { + return false + } + return params.cfg.IsALPNConnUpgradeRequired(params.addr, params.cfg.InsecureAddressDiscovery) +} + +func authConnectShouldUseTLSRouting(ctx context.Context, params connectParams) bool { + if params.cfg.WebProxyAddr != "" && params.cfg.WebProxyAddr == params.addr { + return true + } + resp, err := webclient.Find(&webclient.Config{ + Context: ctx, + ProxyAddr: params.addr, + Insecure: params.cfg.InsecureAddressDiscovery, + }) + if err != nil { + // HTTP ping call failed. This is likely an auth address. + return false + } + return resp.Proxy.TLSRoutingEnabled +} + // tunnelConnect connects to the Teleport Auth Server through the proxy's reverse tunnel. func tunnelConnect(ctx context.Context, params connectParams) (*Client, error) { if params.sshConfig == nil { @@ -501,6 +526,8 @@ func (c *Client) waitForConnectionReady(ctx context.Context) error { type Config struct { // Addrs is a list of teleport auth/proxy server addresses to dial. Addrs []string + // WebProxyAddr is the Teleport Proxy web address. + WebProxyAddr string // Credentials are a list of credentials to use when attempting // to connect to the server. Credentials []Credentials @@ -572,6 +599,10 @@ func (c *Config) CheckAndSetDefaults() error { c.IsALPNConnUpgradeRequired = IsALPNConnUpgradeRequired } + if c.WebProxyAddr != "" { + c.Addrs = utils.Deduplicate(append(c.Addrs, c.WebProxyAddr)) + } + return nil } diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index 82e6ce497a485..667b1f6693f3e 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -157,6 +157,10 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) Conte return nil, trace.Wrap(err) } + host, _, err := webclient.ParseHostPort(tunnelAddr) + if err != nil { + return nil, trace.Wrap(err) + } tlsConn, err := DialALPN(ctx, tunnelAddr, ALPNDialerConfig{ ALPNConnUpgradeRequired: params.cfg.IsALPNConnUpgradeRequired(tunnelAddr, insecure), DialTimeout: params.cfg.DialTimeout, @@ -164,6 +168,7 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) Conte TLSConfig: &tls.Config{ NextProtos: []string{constants.ALPNSNIProtocolReverseTunnel}, InsecureSkipVerify: insecure, + ServerName: host, }, }) if err != nil { diff --git a/lib/client/client.go b/lib/client/client.go index eba9a81642a87..43b7c248f97d5 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -1101,8 +1101,8 @@ func (proxy *ProxyClient) ConnectToAuthServiceThroughALPNSNIProxy(ctx context.Co tlsConfig.InsecureSkipVerify = proxy.teleportClient.InsecureSkipVerify clt, err := auth.NewClient(client.Config{ - Context: ctx, - Addrs: []string{proxyAddr}, + Context: ctx, + WebProxyAddr: proxyAddr, Credentials: []client.Credentials{ client.LoadTLS(tlsConfig), }, From c688918607afbab46189368066b0f7f32fbd8121 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Wed, 29 Mar 2023 16:14:40 -0400 Subject: [PATCH 09/27] move ping --- api/client/alpn.go | 22 +-- api/utils/pingconn/pingconn.go | 132 ++++++++++++++++++ .../utils/pingconn/pingconn_test.go | 13 +- lib/reversetunnel/agentpool.go | 6 + lib/reversetunnel/transport.go | 5 +- lib/srv/alpnproxy/common/protocols.go | 1 + lib/srv/alpnproxy/conn.go | 112 --------------- lib/srv/alpnproxy/local_proxy.go | 10 +- lib/srv/alpnproxy/proxy.go | 3 +- lib/srv/alpnproxy/proxy_test.go | 3 +- 10 files changed, 170 insertions(+), 137 deletions(-) create mode 100644 api/utils/pingconn/pingconn.go rename lib/srv/alpnproxy/conn_test.go => api/utils/pingconn/pingconn_test.go (96%) diff --git a/api/client/alpn.go b/api/client/alpn.go index 3ce81bec55012..24e96fcc89c53 100644 --- a/api/client/alpn.go +++ b/api/client/alpn.go @@ -20,11 +20,14 @@ import ( "context" "crypto/tls" "net" + "strings" "time" "github.com/gravitational/trace" + "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/client/webclient" + "github.com/gravitational/teleport/api/utils/pingconn" ) // ALPNDialerConfig is the config for ALPNDialer. @@ -96,19 +99,18 @@ func (d *ALPNDialer) DialContext(ctx context.Context, network, addr string) (net defer tlsConn.Close() return nil, trace.Wrap(err) } + + if strings.HasSuffix(tlsConn.ConnectionState().NegotiatedProtocol, "-ping") { + logrus.Debug("Using ping connection") + return pingconn.New(tlsConn), nil + } return tlsConn, nil } -// DialALPN a helper to dial using an ALPNDialer and returns a tls.Conn if +// TODO remove +// DialALPN a helper to dial using an ALPNDialer and returns a net.Conn if // successful. -func DialALPN(ctx context.Context, addr string, cfg ALPNDialerConfig) (*tls.Conn, error) { +func DialALPN(ctx context.Context, addr string, cfg ALPNDialerConfig) (net.Conn, error) { conn, err := NewALPNDialer(cfg).DialContext(ctx, "tcp", addr) - if err != nil { - return nil, trace.Wrap(err) - } - tlsConn, ok := conn.(*tls.Conn) - if !ok { - return nil, trace.BadParameter("failed to convert to tls.Conn") - } - return tlsConn, nil + return conn, trace.Wrap(err) } diff --git a/api/utils/pingconn/pingconn.go b/api/utils/pingconn/pingconn.go new file mode 100644 index 0000000000000..1578a1110b90c --- /dev/null +++ b/api/utils/pingconn/pingconn.go @@ -0,0 +1,132 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pingconn + +import ( + "crypto/tls" + "encoding/binary" + "math" + "sync" + + "github.com/gravitational/trace" +) + +// New returns a ping connection wrapping the provided net.Conn. +func New(conn *tls.Conn) *PingConn { + return &PingConn{Conn: conn} +} + +// PingConn wraps a *tls.Conn and add ping capabilities to it, including the +// `WritePing` function and `Read` (which excludes ping packets). +// +// When using this connection, the packets written will contain an initial data: +// the packet size. When reading, this information is taken into account, but it +// is not returned to the caller. +// +// Ping messages have a packet size of zero and are produced only when +// `WritePing` is called. On `Read`, any Ping packet is discarded. +type PingConn struct { + //net.Conn + *tls.Conn + + muRead sync.Mutex + muWrite sync.Mutex + + // currentSize size of bytes of the current packet. + currentSize uint32 +} + +// Read reads content from the underlaying connection, discarding any ping +// messages it finds. +func (c *PingConn) Read(p []byte) (int, error) { + c.muRead.Lock() + defer c.muRead.Unlock() + + err := c.discardPingReads() + if err != nil { + return 0, err + } + + // Check if the current size is larger than the provided buffer. + readSize := c.currentSize + if c.currentSize > uint32(len(p)) { + readSize = uint32(len(p)) + } + + n, err := c.Conn.Read(p[:readSize]) + c.currentSize -= uint32(n) + + return n, err +} + +// WritePing writes the ping packet to the connection. +func (c *PingConn) WritePing() error { + c.muWrite.Lock() + defer c.muWrite.Unlock() + + return binary.Write(c.Conn, binary.BigEndian, uint32(0)) +} + +// discardPingReads reads from the wrapped net.Conn until it encounters a +// non-ping packet. +func (c *PingConn) discardPingReads() error { + for c.currentSize == 0 { + err := binary.Read(c.Conn, binary.BigEndian, &c.currentSize) + if err != nil { + return err + } + } + + return nil +} + +// Write writes provided content to the underlying connection with proper +// protocol fields. +func (c *PingConn) Write(p []byte) (int, error) { + c.muWrite.Lock() + defer c.muWrite.Unlock() + + // Avoid overflow when casting data length. It is only present to avoid + // panicking if the size cannot be cast. Callers should handle packet length + // limits, such as protocol implementations and audits. + if uint64(len(p)) > math.MaxUint32 { + return 0, trace.BadParameter("invalid content size, max size permitted is %d", uint64(math.MaxUint32)) + } + + size := uint32(len(p)) + if size == 0 { + return 0, nil + } + + // Write packet size. + if err := binary.Write(c.Conn, binary.BigEndian, size); err != nil { + return 0, trace.Wrap(err) + } + + // Iterate until everything is written. + var written int + for written < len(p) { + n, err := c.Conn.Write(p) + written += n + + if err != nil { + return written, trace.Wrap(err) + } + } + + return written, nil +} diff --git a/lib/srv/alpnproxy/conn_test.go b/api/utils/pingconn/pingconn_test.go similarity index 96% rename from lib/srv/alpnproxy/conn_test.go rename to api/utils/pingconn/pingconn_test.go index d91175d628a22..69a6a32cc5c47 100644 --- a/lib/srv/alpnproxy/conn_test.go +++ b/api/utils/pingconn/pingconn_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package alpnproxy +package pingconn import ( "bytes" @@ -26,6 +26,8 @@ import ( "time" "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/fixtures" ) func TestPingConnection(t *testing.T) { @@ -273,7 +275,7 @@ func makePingConn(t *testing.T) (*PingConn, *PingConn) { writer, reader := net.Pipe() tlsWriter, tlsReader := makeTLSConn(t, writer, reader) - return NewPingConn(tlsWriter), NewPingConn(tlsReader) + return New(tlsWriter), New(tlsReader) } // makeBufferedPingConn creates connections to have asynchronous writes. @@ -321,7 +323,7 @@ func makeBufferedPingConn(t *testing.T) (*PingConn, *PingConn) { } tlsConnA, tlsConnB := makeTLSConn(t, connSlice[0], connSlice[1]) - return NewPingConn(tlsConnA), NewPingConn(tlsConnB) + return New(tlsConnA), New(tlsConnB) } // makeTLSConn take two connections (client and server) and wrap them into TLS @@ -334,10 +336,13 @@ func makeTLSConn(t *testing.T, server, client net.Conn) (*tls.Conn, *tls.Conn) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() + cert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM)) + require.NoError(t, err) + // Server go func() { tlsConn := tls.Server(server, &tls.Config{ - Certificates: []tls.Certificate{mustGenCertSignedWithCA(t, mustGenSelfSignedCert(t))}, + Certificates: []tls.Certificate{cert}, }) tlsConnChan <- struct { *tls.Conn diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index 49def2948e20f..a16545272fad1 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -39,6 +39,7 @@ import ( "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth" + libdefaults "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/reversetunnel/track" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" @@ -688,6 +689,11 @@ func (c *agentPoolRuntimeConfig) useALPNConnUpgrade() bool { func (c *agentPoolRuntimeConfig) getKeepAliveInterval() time.Duration { c.mu.RLock() defer c.mu.RUnlock() + + // When behind a load balancer, use a shorter ping. + if c.tlsRoutingConnUpgradeRequired { + return utils.MinTTL(libdefaults.ProxyPingInterval, c.keepAliveInterval) + } return c.keepAliveInterval } diff --git a/lib/reversetunnel/transport.go b/lib/reversetunnel/transport.go index 5c2ed996e03b7..89ccb78198c1c 100644 --- a/lib/reversetunnel/transport.go +++ b/lib/reversetunnel/transport.go @@ -97,7 +97,10 @@ func (t *TunnelAuthDialer) DialContext(ctx context.Context, _, _ string) (net.Co if mode == types.ProxyListenerMode_Multiplex { tlsConfig := &tls.Config{ - NextProtos: []string{string(alpncommon.ProtocolReverseTunnel)}, + NextProtos: []string{ + string(alpncommon.ProtocolWithPing(alpncommon.ProtocolReverseTunnel)), + string(alpncommon.ProtocolReverseTunnel), + }, InsecureSkipVerify: t.InsecureSkipTLSVerify, } diff --git a/lib/srv/alpnproxy/common/protocols.go b/lib/srv/alpnproxy/common/protocols.go index 97224a3b9e4e8..73d61aaecba3c 100644 --- a/lib/srv/alpnproxy/common/protocols.go +++ b/lib/srv/alpnproxy/common/protocols.go @@ -201,6 +201,7 @@ var DatabaseProtocols = []Protocol{ var ProtocolsWithPingSupport = append( DatabaseProtocols, ProtocolTCP, + ProtocolReverseTunnel, ) // WithPingProtocols adds Ping protocols to the list for each protocol that diff --git a/lib/srv/alpnproxy/conn.go b/lib/srv/alpnproxy/conn.go index ef566dbb9fd81..c8bed17e08112 100644 --- a/lib/srv/alpnproxy/conn.go +++ b/lib/srv/alpnproxy/conn.go @@ -17,16 +17,10 @@ limitations under the License. package alpnproxy import ( - "crypto/tls" - "encoding/binary" "io" - "math" "net" - "sync" "time" - "github.com/gravitational/trace" - "github.com/gravitational/teleport/lib/utils" ) @@ -103,109 +97,3 @@ func (conn readOnlyConn) RemoteAddr() net.Addr { return &utils.Net func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } - -// NewPingConn returns a ping connection wrapping the provided net.Conn. -func NewPingConn(conn *tls.Conn) *PingConn { - return &PingConn{Conn: conn} -} - -// PingConn wraps a *tls.Conn and add ping capabilities to it, including the -// `WritePing` function and `Read` (which excludes ping packets). -// -// When using this connection, the packets written will contain an initial data: -// the packet size. When reading, this information is taken into account, but it -// is not returned to the caller. -// -// Ping messages have a packet size of zero and are produced only when -// `WritePing` is called. On `Read`, any Ping packet is discarded. -type PingConn struct { - //net.Conn - *tls.Conn - - muRead sync.Mutex - muWrite sync.Mutex - - // currentSize size of bytes of the current packet. - currentSize uint32 -} - -// Read reads content from the underlaying connection, discarding any ping -// messages it finds. -func (c *PingConn) Read(p []byte) (int, error) { - c.muRead.Lock() - defer c.muRead.Unlock() - - err := c.discardPingReads() - if err != nil { - return 0, err - } - - // Check if the current size is larger than the provided buffer. - readSize := c.currentSize - if c.currentSize > uint32(len(p)) { - readSize = uint32(len(p)) - } - - n, err := c.Conn.Read(p[:readSize]) - c.currentSize -= uint32(n) - - return n, err -} - -// WritePing writes the ping packet to the connection. -func (c *PingConn) WritePing() error { - c.muWrite.Lock() - defer c.muWrite.Unlock() - - return binary.Write(c.Conn, binary.BigEndian, uint32(0)) -} - -// discardPingReads reads from the wrapped net.Conn until it encounters a -// non-ping packet. -func (c *PingConn) discardPingReads() error { - for c.currentSize == 0 { - err := binary.Read(c.Conn, binary.BigEndian, &c.currentSize) - if err != nil { - return err - } - } - - return nil -} - -// Write writes provided content to the underlying connection with proper -// protocol fields. -func (c *PingConn) Write(p []byte) (int, error) { - c.muWrite.Lock() - defer c.muWrite.Unlock() - - // Avoid overflow when casting data length. It is only present to avoid - // panicking if the size cannot be cast. Callers should handle packet length - // limits, such as protocol implementations and audits. - if uint64(len(p)) > math.MaxUint32 { - return 0, trace.BadParameter("invalid content size, max size permitted is %d", uint64(math.MaxUint32)) - } - - size := uint32(len(p)) - if size == 0 { - return 0, nil - } - - // Write packet size. - if err := binary.Write(c.Conn, binary.BigEndian, size); err != nil { - return 0, trace.Wrap(err) - } - - // Iterate until everything is written. - var written int - for written < len(p) { - n, err := c.Conn.Write(p) - written += n - - if err != nil { - return written, trace.Wrap(err) - } - } - - return written, nil -} diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index 1fd3e6f42a76f..0966f2bd7eef3 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -231,17 +231,11 @@ func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamC return trace.Wrap(err) } - tlsConn, err := client.DialALPN(ctx, l.cfg.RemoteProxyAddr, l.getALPNDialerConfig(certs)) + upstreamConn, err := client.DialALPN(ctx, l.cfg.RemoteProxyAddr, l.getALPNDialerConfig(certs)) if err != nil { return trace.Wrap(err) } - defer tlsConn.Close() - - var upstreamConn net.Conn = tlsConn - if common.IsPingProtocol(common.Protocol(tlsConn.ConnectionState().NegotiatedProtocol)) { - l.cfg.Log.Debug("Using ping connection") - upstreamConn = NewPingConn(tlsConn) - } + defer upstreamConn.Close() return trace.Wrap(utils.ProxyConn(ctx, downstreamConn, upstreamConn)) } diff --git a/lib/srv/alpnproxy/proxy.go b/lib/srv/alpnproxy/proxy.go index 4ac5613fb538d..2c0595cf7f8cf 100644 --- a/lib/srv/alpnproxy/proxy.go +++ b/lib/srv/alpnproxy/proxy.go @@ -33,6 +33,7 @@ import ( "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/utils/pingconn" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" @@ -410,7 +411,7 @@ func (p *Proxy) handleConn(ctx context.Context, clientConn net.Conn, defaultOver // handlePingConnection starts the server ping routine and returns `pingConn`. func (p *Proxy) handlePingConnection(ctx context.Context, conn *tls.Conn) net.Conn { - pingConn := NewPingConn(conn) + pingConn := pingconn.New(conn) // Start ping routine. It will continuously send pings in a defined // interval. diff --git a/lib/srv/alpnproxy/proxy_test.go b/lib/srv/alpnproxy/proxy_test.go index 6c81f62696de1..7e7e9451c8132 100644 --- a/lib/srv/alpnproxy/proxy_test.go +++ b/lib/srv/alpnproxy/proxy_test.go @@ -32,6 +32,7 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/utils/pingconn" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/srv/db/dbutils" "github.com/gravitational/teleport/lib/tlsca" @@ -191,7 +192,7 @@ func TestProxyTLSDatabaseHandler(t *testing.T) { }) require.NoError(t, err) - conn := NewPingConn(baseConn) + conn := pingconn.New(baseConn) tlsConn := tls.Client(conn, &tls.Config{ Certificates: []tls.Certificate{ clientCert, From 1928f65dc3e1ed0e985cb068bf36e1bda843bb34 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Thu, 30 Mar 2023 08:59:45 -0400 Subject: [PATCH 10/27] add ssh support --- tool/tsh/proxy.go | 54 +++++++++++++++++++++-------------------------- 1 file changed, 24 insertions(+), 30 deletions(-) diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index 0e61bb9225298..0d4634be65c09 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -233,50 +233,44 @@ func dialSSHProxy(ctx context.Context, tc *libclient.TeleportClient, sp sshProxy remoteProxyAddr := net.JoinHostPort(sp.proxyHost, sp.proxyPort) httpsProxy := apiutils.GetProxyURL(remoteProxyAddr) - pool, err := tc.LocalAgent().ClientCertPool(sp.clusterName) - if err != nil { - return nil, trace.Wrap(err) + if sp.tlsRouting { + conn, err := dialSSHThroughALPNSNIProxy(ctx, tc, sp) + return conn, trace.Wrap(err) } // If HTTPS_PROXY is configured, we need to open a TCP connection via // the specified HTTPS Proxy, otherwise, we can just open a plain TCP // connection. - var tcpConn net.Conn if httpsProxy != nil { httpProxyTLSConfig := &tls.Config{ - RootCAs: pool, InsecureSkipVerify: tc.InsecureSkipVerify, - ServerName: httpsProxy.Hostname(), - } - tcpConn, err = client.DialProxy(ctx, httpsProxy, remoteProxyAddr, client.WithTLSConfig(httpProxyTLSConfig)) - if err != nil { - return nil, trace.Wrap(err) - } - } else { - tcpConn, err = (&net.Dialer{}).DialContext(ctx, "tcp", remoteProxyAddr) - if err != nil { - return nil, trace.Wrap(err) } + tcpConn, err := client.DialProxy(ctx, httpsProxy, remoteProxyAddr, client.WithTLSConfig(httpProxyTLSConfig)) + return tcpConn, trace.Wrap(err) } + tcpConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteProxyAddr) + return tcpConn, err +} - // If TLS routing is not enabled, just return the TCP connection - if !sp.tlsRouting { - return tcpConn, nil - } - - // Otherwise, we need to upgrade the TCP connection to a TLS connection. - tlsConfig := &tls.Config{ - RootCAs: pool, - NextProtos: []string{string(alpncommon.ProtocolProxySSH)}, - InsecureSkipVerify: tc.InsecureSkipVerify, - ServerName: sp.proxyHost, - } - tlsConn := tls.Client(tcpConn, tlsConfig) - if err := tlsConn.HandshakeContext(ctx); err != nil { - tlsConn.Close() +func dialSSHThroughALPNSNIProxy(ctx context.Context, tc *libclient.TeleportClient) (net.Conn, error) { + pool, err := tc.LocalAgent().ClientCertPool(sp.clusterName) + if err != nil { return nil, trace.Wrap(err) } + remoteProxyAddr := net.JoinHostPort(sp.proxyHost, sp.proxyPort) + conn, err := client.DialALPN(ctx, remoteProxyAddr, client.ALPNDialerConfig{ + TLSConfig: &tls.Config{ + RootCAs: pool, + NextProtos: []string{ + string(alpncommon.ProtocolWithPing(alpncommon.ProtocolProxySSH)), + string(alpncommon.ProtocolProxySSH), + }, + InsecureSkipVerify: tc.InsecureSkipVerify, + ServerName: sp.proxyHost, + }, + ALPNConnUpgradeRequired: client.IsALPNConnUpgradeRequired(remoteProxyAddr, tc.InsecureSkipVerify), + }) return tlsConn, nil } From b6d6c8a29013f75a0bb92b453cebca62cb870897 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Thu, 30 Mar 2023 14:16:07 -0400 Subject: [PATCH 11/27] add HTTP client support --- api/client/alpn.go | 23 +++++----- api/client/client.go | 38 +++++++--------- api/client/contextdialer.go | 47 ++++++++++++++++++-- api/client/proxy.go | 10 ++++- api/constants/constants.go | 2 + lib/auth/clt.go | 11 ++++- lib/client/api.go | 12 +++-- lib/client/client.go | 2 +- lib/srv/alpnproxy/common/protocols.go | 1 + lib/srv/alpnproxy/proxy.go | 5 +++ tool/tsh/proxy.go | 63 +++++++++++---------------- 11 files changed, 133 insertions(+), 81 deletions(-) diff --git a/api/client/alpn.go b/api/client/alpn.go index 24e96fcc89c53..0414fb54be113 100644 --- a/api/client/alpn.go +++ b/api/client/alpn.go @@ -27,6 +27,7 @@ import ( "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/client/webclient" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/utils/pingconn" ) @@ -45,7 +46,8 @@ type ALPNDialerConfig struct { // ALPNDialer is a ContextDialer that dials a connection to the Proxy Service // with ALPN and SNI configured in the provided TLSConfig. An ALPN connection // upgrade is also performed at the initial connection, if an upgrade is -// required. +// required. If the negotiated protocol is a ping protocol, it will return the +// de-multiplexed connection without the ping. type ALPNDialer struct { cfg ALPNDialerConfig } @@ -81,14 +83,15 @@ func (d *ALPNDialer) DialContext(ctx context.Context, network, addr string) (net return nil, trace.Wrap(err) } - // TODO support system proxy. - dialer := NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout) - if d.cfg.ALPNConnUpgradeRequired { - dialer = newALPNConnUpgradeDialer(dialer, &tls.Config{ - InsecureSkipVerify: tlsConfig.InsecureSkipVerify, - }) + tlsConfigForDialer := &tls.Config{ + InsecureSkipVerify: tlsConfig.InsecureSkipVerify, } + dialer := NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout, + WithTLSConfig(tlsConfigForDialer), + WithALPNConnUpgrade(d.cfg.ALPNConnUpgradeRequired), + ) + conn, err := dialer.DialContext(ctx, network, addr) if err != nil { return nil, trace.Wrap(err) @@ -100,16 +103,14 @@ func (d *ALPNDialer) DialContext(ctx context.Context, network, addr string) (net return nil, trace.Wrap(err) } - if strings.HasSuffix(tlsConn.ConnectionState().NegotiatedProtocol, "-ping") { + if strings.HasSuffix(tlsConn.ConnectionState().NegotiatedProtocol, constants.ALPNSNIProtocolPingSuffix) { logrus.Debug("Using ping connection") return pingconn.New(tlsConn), nil } return tlsConn, nil } -// TODO remove -// DialALPN a helper to dial using an ALPNDialer and returns a net.Conn if -// successful. +// DialALPN a helper to dial using an ALPNDialer. func DialALPN(ctx context.Context, addr string, cfg ALPNDialerConfig) (net.Conn, error) { conn, err := NewALPNDialer(cfg).DialContext(ctx, "tcp", addr) return conn, trace.Wrap(err) diff --git a/api/client/client.go b/api/client/client.go index b65af513f45fe..43aaaf6dfb8b4 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -329,13 +329,12 @@ type ( // authConnect connects to the Teleport Auth Server directly. func authConnect(ctx context.Context, params connectParams) (*Client, error) { - dialer := NewDialer(ctx, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, WithTLSConfig(params.tlsConfig)) - - if authConnectShouldALPNConnUpgrade(ctx, params) { - dialer = newALPNConnUpgradeDialer(dialer, &tls.Config{ - InsecureSkipVerify: params.cfg.InsecureAddressDiscovery, - }) - } + dialer := NewDialer(ctx, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, + WithTLSConfig(&tls.Config{ + InsecureSkipVerify: params.tlsConfig.InsecureSkipVerify, + }), + WithALPNConnUpgrade(IsWebProxyAndConnUpgradeRequired(ctx, params.addr, ¶ms.cfg)), + ) clt := newClient(params.cfg, dialer, params.tlsConfig) if err := clt.dialGRPC(ctx, params.addr); err != nil { @@ -344,27 +343,20 @@ func authConnect(ctx context.Context, params connectParams) (*Client, error) { return clt, nil } -func authConnectShouldALPNConnUpgrade(ctx context.Context, params connectParams) bool { - if !authConnectShouldUseTLSRouting(ctx, params) { - return false - } - return params.cfg.IsALPNConnUpgradeRequired(params.addr, params.cfg.InsecureAddressDiscovery) +// TODO +func IsWebProxyAndConnUpgradeRequired(ctx context.Context, targetAddr string, cfg *Config) bool { + return isWebProxy(ctx, targetAddr, cfg) && cfg.IsALPNConnUpgradeRequired(targetAddr, cfg.InsecureAddressDiscovery) } - -func authConnectShouldUseTLSRouting(ctx context.Context, params connectParams) bool { - if params.cfg.WebProxyAddr != "" && params.cfg.WebProxyAddr == params.addr { +func isWebProxy(ctx context.Context, targetAddr string, cfg *Config) bool { + if cfg.WebProxyAddr != "" && cfg.WebProxyAddr == targetAddr { return true } - resp, err := webclient.Find(&webclient.Config{ + _, err := webclient.Find(&webclient.Config{ Context: ctx, - ProxyAddr: params.addr, - Insecure: params.cfg.InsecureAddressDiscovery, + ProxyAddr: targetAddr, + Insecure: cfg.InsecureAddressDiscovery, }) - if err != nil { - // HTTP ping call failed. This is likely an auth address. - return false - } - return resp.Proxy.TLSRoutingEnabled + return err == nil } // tunnelConnect connects to the Teleport Auth Server through the proxy's reverse tunnel. diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index 667b1f6693f3e..1714b8674e352 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -20,6 +20,7 @@ import ( "context" "crypto/tls" "net" + "net/url" "time" "github.com/gravitational/trace" @@ -56,6 +57,12 @@ func newDirectDialer(keepAlivePeriod, dialTimeout time.Duration) *net.Dialer { } } +func newProxyURLDialer(proxyURL *url.URL, dialer *net.Dialer, opts ...DialProxyOption) ContextDialer { + return ContextDialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { + return DialProxyWithDialer(ctx, proxyURL, addr, dialer, opts...) + }) +} + // tracedDialer ensures that the provided ContextDialerFunc is given a context // which contains tracing information. In the event that a grpc dial occurs without // a grpc.WithBlock dialing option, the context provided to the dial function will @@ -79,15 +86,45 @@ func tracedDialer(ctx context.Context, fn ContextDialerFunc) ContextDialerFunc { // NewDialer makes a new dialer that connects to an Auth server either directly or via an HTTP proxy, depending // on the environment. func NewDialer(ctx context.Context, keepAlivePeriod, dialTimeout time.Duration, opts ...DialProxyOption) ContextDialer { + var cfg dialProxyConfig + for _, opt := range opts { + opt(&cfg) + } + return tracedDialer(ctx, func(ctx context.Context, network, addr string) (net.Conn, error) { - dialer := newDirectDialer(keepAlivePeriod, dialTimeout) + netDialer := newDirectDialer(keepAlivePeriod, dialTimeout) + + // Base direct dialer. + var dialer ContextDialer = netDialer + + // Wrap with proxy URL dialer if proxy URL is detected if proxyURL := utils.GetProxyURL(addr); proxyURL != nil { - return DialProxyWithDialer(ctx, proxyURL, addr, dialer, opts...) + dialer = newProxyURLDialer(proxyURL, netDialer, opts...) + } + + // Wrap with alpnConnUpgradeDialer if upgrade is required. + if cfg.alpnConnUpgradeRequired { + dialer = newALPNConnUpgradeDialer(dialer, cfg.tlsConfig) } + + // Dial. return dialer.DialContext(ctx, network, addr) }) } +func isAddrWebProxy(ctx context.Context, targetAddr, knownWebProxyAddr string, insecure bool) bool { + if knownWebProxyAddr != "" && targetAddr == knownWebProxyAddr { + return true + } + + _, err := webclient.Find(&webclient.Config{ + Context: ctx, + ProxyAddr: targetAddr, + Insecure: insecure, + }) + return err == nil +} + // NewProxyDialer makes a dialer to connect to an Auth server through the SSH reverse tunnel on the proxy. // The dialer will ping the web client to discover the tunnel proxy address on each dial. func NewProxyDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool, opts ...DialProxyOption) ContextDialer { @@ -161,12 +198,16 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) Conte if err != nil { return nil, trace.Wrap(err) } + tlsConn, err := DialALPN(ctx, tunnelAddr, ALPNDialerConfig{ ALPNConnUpgradeRequired: params.cfg.IsALPNConnUpgradeRequired(tunnelAddr, insecure), DialTimeout: params.cfg.DialTimeout, KeepAlivePeriod: params.cfg.KeepAlivePeriod, TLSConfig: &tls.Config{ - NextProtos: []string{constants.ALPNSNIProtocolReverseTunnel}, + NextProtos: []string{ + constants.ALPNSNIProtocolReverseTunnel + constants.ALPNSNIProtocolPingSuffix, + constants.ALPNSNIProtocolReverseTunnel, + }, InsecureSkipVerify: insecure, ServerName: host, }, diff --git a/api/client/proxy.go b/api/client/proxy.go index 2c222f0b0dee6..c6bd125667f86 100644 --- a/api/client/proxy.go +++ b/api/client/proxy.go @@ -30,7 +30,8 @@ import ( ) type dialProxyConfig struct { - tlsConfig *tls.Config + tlsConfig *tls.Config + alpnConnUpgradeRequired bool } // DialProxyOption allows setting options as functional arguments to DialProxy. @@ -44,6 +45,13 @@ func WithTLSConfig(tlsConfig *tls.Config) DialProxyOption { } } +// TODO +func WithALPNConnUpgrade(alpnConnUpgradeRequired bool) DialProxyOption { + return func(cfg *dialProxyConfig) { + cfg.alpnConnUpgradeRequired = alpnConnUpgradeRequired + } +} + // DialProxy creates a connection to a server via an HTTP or SOCKS5 Proxy. func DialProxy(ctx context.Context, proxyURL *url.URL, addr string, opts ...DialProxyOption) (net.Conn, error) { return DialProxyWithDialer(ctx, proxyURL, addr, &net.Dialer{}, opts...) diff --git a/api/constants/constants.go b/api/constants/constants.go index 47b3b6c513c29..2f06e321c0fa5 100644 --- a/api/constants/constants.go +++ b/api/constants/constants.go @@ -319,6 +319,8 @@ const ( ALPNSNIAuthProtocol = "teleport-auth@" // ALPNSNIProtocolReverseTunnel is TLS ALPN protocol value used to indicate Proxy reversetunnel protocol. ALPNSNIProtocolReverseTunnel = "teleport-reversetunnel" + // ALPNSNIProtocolPingSuffix is TLS ALPN suffix used to wrap connections with Ping. + ALPNSNIProtocolPingSuffix = "-ping" ) const ( diff --git a/lib/auth/clt.go b/lib/auth/clt.go index 529408218027e..fd18675798b8a 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -18,6 +18,7 @@ package auth import ( "context" + "crypto/tls" "net" "net/url" "time" @@ -99,9 +100,17 @@ func NewClient(cfg client.Config, params ...roundtrip.ClientParam) (*Client, err if len(cfg.Addrs) == 0 { return nil, trace.BadParameter("no addresses to dial") } - contextDialer := client.NewDialer(cfg.Context, cfg.KeepAlivePeriod, cfg.DialTimeout, client.WithTLSConfig(httpTLS)) httpDialer = client.ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) { for _, addr := range cfg.Addrs { + contextDialerOpts := []client.DialProxyOption{ + client.WithTLSConfig(&tls.Config{ + InsecureSkipVerify: httpTLS.InsecureSkipVerify, + }), + client.WithALPNConnUpgrade(client.IsWebProxyAndConnUpgradeRequired(ctx, addr, &cfg)), + } + + contextDialer := client.NewDialer(cfg.Context, cfg.KeepAlivePeriod, cfg.DialTimeout, contextDialerOpts...) + conn, err = contextDialer.DialContext(ctx, network, addr) if err == nil { return conn, nil diff --git a/lib/client/api.go b/lib/client/api.go index 22bc63d73e0cc..a49e5025c4d64 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2899,11 +2899,14 @@ func makeProxySSHClientWithTLSWrapper(ctx context.Context, tc *TeleportClient, s return nil, trace.Wrap(err) } - tlsConfig.NextProtos = []string{string(alpncommon.ProtocolProxySSH)} + tlsConfig.NextProtos = []string{ + string(alpncommon.ProtocolWithPing(alpncommon.ProtocolProxySSH)), + string(alpncommon.ProtocolProxySSH), + } alpnConfig := client.ALPNDialerConfig{ TLSConfig: tlsConfig, - ALPNConnUpgradeRequired: tc.isALPNConnUpgradeRequired(proxyAddr, tlsConfig.InsecureSkipVerify), + ALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequired(proxyAddr, tlsConfig.InsecureSkipVerify), DialTimeout: sshConfig.Timeout, } @@ -4458,7 +4461,7 @@ func (tc *TeleportClient) NewKubernetesServiceClient(ctx context.Context, cluste Credentials: []client.Credentials{ client.LoadTLS(tlsConfig), }, - IsALPNConnUpgradeRequired: tc.isALPNConnUpgradeRequired, + IsALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequired, }) if err != nil { return nil, trace.Wrap(err) @@ -4466,7 +4469,8 @@ func (tc *TeleportClient) NewKubernetesServiceClient(ctx context.Context, cluste return kubeproto.NewKubeServiceClient(clt.GetConnection()), nil } -func (tc *TeleportClient) isALPNConnUpgradeRequired(addr string, insecure bool) bool { +// IsALPNConnUpgradeRequired returns true if connection upgrade is required for provided addr. +func (tc *TeleportClient) IsALPNConnUpgradeRequired(addr string, insecure bool) bool { // Use cached value. if addr == tc.WebProxyAddr { return tc.TLSRoutingConnUpgradeRequired diff --git a/lib/client/client.go b/lib/client/client.go index 43b7c248f97d5..811b7711c95c2 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -1108,7 +1108,7 @@ func (proxy *ProxyClient) ConnectToAuthServiceThroughALPNSNIProxy(ctx context.Co }, ALPNSNIAuthDialClusterName: clusterName, CircuitBreakerConfig: breaker.NoopBreakerConfig(), - IsALPNConnUpgradeRequired: proxy.teleportClient.isALPNConnUpgradeRequired, + IsALPNConnUpgradeRequired: proxy.teleportClient.IsALPNConnUpgradeRequired, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/alpnproxy/common/protocols.go b/lib/srv/alpnproxy/common/protocols.go index 73d61aaecba3c..c69c07e146705 100644 --- a/lib/srv/alpnproxy/common/protocols.go +++ b/lib/srv/alpnproxy/common/protocols.go @@ -202,6 +202,7 @@ var ProtocolsWithPingSupport = append( DatabaseProtocols, ProtocolTCP, ProtocolReverseTunnel, + ProtocolProxySSH, ) // WithPingProtocols adds Ping protocols to the list for each protocol that diff --git a/lib/srv/alpnproxy/proxy.go b/lib/srv/alpnproxy/proxy.go index 2c0595cf7f8cf..73df93271966b 100644 --- a/lib/srv/alpnproxy/proxy.go +++ b/lib/srv/alpnproxy/proxy.go @@ -208,6 +208,11 @@ func (h *HandlerDecs) CheckAndSetDefaults() error { if h.ForwardTLS && h.TLSConfig != nil { return trace.BadParameter("the ForwardTLS flag and TLSConfig can't be used at the same time") } + + if h.TLSConfig != nil && len(h.TLSConfig.NextProtos) == 0 { + h.TLSConfig = h.TLSConfig.Clone() + h.TLSConfig.NextProtos = common.ProtocolsToString(common.SupportedProtocols) + } return nil } diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index 0d4634be65c09..82c98ae52a51a 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -38,9 +38,9 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" - apiutils "github.com/gravitational/teleport/api/utils" libclient "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/db/dbcmd" "github.com/gravitational/teleport/lib/defaults" @@ -231,47 +231,36 @@ func dialSSHProxy(ctx context.Context, tc *libclient.TeleportClient, sp sshProxy // if sp.tlsRouting is true, remoteProxyAddr is the ALPN listener port. // if it is false, then remoteProxyAddr is the SSH proxy port. remoteProxyAddr := net.JoinHostPort(sp.proxyHost, sp.proxyPort) - httpsProxy := apiutils.GetProxyURL(remoteProxyAddr) - if sp.tlsRouting { - conn, err := dialSSHThroughALPNSNIProxy(ctx, tc, sp) - return conn, trace.Wrap(err) - } - - // If HTTPS_PROXY is configured, we need to open a TCP connection via - // the specified HTTPS Proxy, otherwise, we can just open a plain TCP - // connection. - if httpsProxy != nil { - httpProxyTLSConfig := &tls.Config{ - InsecureSkipVerify: tc.InsecureSkipVerify, + var dialer client.ContextDialer + switch { + case sp.tlsRouting: + pool, err := tc.LocalAgent().ClientCertPool(sp.clusterName) + if err != nil { + return nil, trace.Wrap(err) } - tcpConn, err := client.DialProxy(ctx, httpsProxy, remoteProxyAddr, client.WithTLSConfig(httpProxyTLSConfig)) - return tcpConn, trace.Wrap(err) - } - tcpConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteProxyAddr) - return tcpConn, err -} - -func dialSSHThroughALPNSNIProxy(ctx context.Context, tc *libclient.TeleportClient) (net.Conn, error) { - pool, err := tc.LocalAgent().ClientCertPool(sp.clusterName) - if err != nil { - return nil, trace.Wrap(err) - } - remoteProxyAddr := net.JoinHostPort(sp.proxyHost, sp.proxyPort) - conn, err := client.DialALPN(ctx, remoteProxyAddr, client.ALPNDialerConfig{ - TLSConfig: &tls.Config{ - RootCAs: pool, - NextProtos: []string{ - string(alpncommon.ProtocolWithPing(alpncommon.ProtocolProxySSH)), - string(alpncommon.ProtocolProxySSH), + dialer = client.NewALPNDialer(client.ALPNDialerConfig{ + TLSConfig: &tls.Config{ + RootCAs: pool, + NextProtos: []string{ + string(alpncommon.ProtocolWithPing(alpncommon.ProtocolProxySSH)), + string(alpncommon.ProtocolProxySSH), + }, + InsecureSkipVerify: tc.InsecureSkipVerify, + ServerName: sp.proxyHost, }, + ALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequired(remoteProxyAddr, tc.InsecureSkipVerify), + }) + + default: + dialer = client.NewDialer(ctx, apidefaults.DefaultIOTimeout, apidefaults.DefaultIdleTimeout, client.WithTLSConfig(&tls.Config{ InsecureSkipVerify: tc.InsecureSkipVerify, - ServerName: sp.proxyHost, - }, - ALPNConnUpgradeRequired: client.IsALPNConnUpgradeRequired(remoteProxyAddr, tc.InsecureSkipVerify), - }) - return tlsConn, nil + })) + } + + conn, err := dialer.DialContext(ctx, "tcp", remoteProxyAddr) + return conn, trace.Wrap(err) } func proxySubsystemName(userHost, cluster string) string { From 1d5760aaf69b6e4655beaa0e03a7c7f1e7787271 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Thu, 30 Mar 2023 14:57:02 -0400 Subject: [PATCH 12/27] Move ALPN dialer, ALPN conn upgrade, Ping conn to api --- .../alpnproxy/dialer.go => api/client/alpn.go | 12 +- .../client/alpn_conn_upgrade.go | 18 ++- .../client/alpn_conn_upgrade_test.go | 29 ++-- api/constants/constants.go | 12 ++ api/fixtures/fixtures.go | 64 +++++++++ api/utils/pingconn/pingconn.go | 132 ++++++++++++++++++ .../utils/pingconn/pingconn_test.go | 13 +- constants.go | 12 -- lib/client/api.go | 5 +- lib/fixtures/fixtures.go | 50 +------ lib/srv/alpnproxy/conn.go | 112 --------------- lib/srv/alpnproxy/local_proxy.go | 12 +- lib/srv/alpnproxy/local_proxy_config_opt.go | 3 +- lib/srv/alpnproxy/proxy.go | 3 +- lib/srv/alpnproxy/proxy_test.go | 3 +- lib/web/conn_upgrade.go | 8 +- 16 files changed, 260 insertions(+), 228 deletions(-) rename lib/srv/alpnproxy/dialer.go => api/client/alpn.go (85%) rename lib/srv/alpnproxy/conn_upgrade.go => api/client/alpn_conn_upgrade.go (92%) rename lib/srv/alpnproxy/conn_upgrade_test.go => api/client/alpn_conn_upgrade_test.go (85%) create mode 100644 api/fixtures/fixtures.go create mode 100644 api/utils/pingconn/pingconn.go rename lib/srv/alpnproxy/conn_test.go => api/utils/pingconn/pingconn_test.go (96%) diff --git a/lib/srv/alpnproxy/dialer.go b/api/client/alpn.go similarity index 85% rename from lib/srv/alpnproxy/dialer.go rename to api/client/alpn.go index f4b3e183bb97a..15ac2451194cc 100644 --- a/lib/srv/alpnproxy/dialer.go +++ b/api/client/alpn.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package alpnproxy +package client import ( "context" @@ -23,16 +23,8 @@ import ( "time" "github.com/gravitational/trace" - - apiclient "github.com/gravitational/teleport/api/client" ) -// ContextDialer represents network dialer interface that uses context -type ContextDialer interface { - // DialContext is a function that dials the specified address - DialContext(ctx context.Context, network, addr string) (net.Conn, error) -} - // ALPNDialerConfig is the config for ALPNDialer. type ALPNDialerConfig struct { // KeepAlivePeriod defines period between keep alives. @@ -66,7 +58,7 @@ func (d ALPNDialer) DialContext(ctx context.Context, network, addr string) (net. return nil, trace.BadParameter("missing TLS config") } - dialer := apiclient.NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout, apiclient.WithTLSConfig(d.cfg.TLSConfig)) + dialer := NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout, WithTLSConfig(d.cfg.TLSConfig)) if d.cfg.ALPNConnUpgradeRequired { dialer = newALPNConnUpgradeDialer(dialer, &tls.Config{ InsecureSkipVerify: d.cfg.TLSConfig.InsecureSkipVerify, diff --git a/lib/srv/alpnproxy/conn_upgrade.go b/api/client/alpn_conn_upgrade.go similarity index 92% rename from lib/srv/alpnproxy/conn_upgrade.go rename to api/client/alpn_conn_upgrade.go index 1425ae70ee274..c82f91dd7ff63 100644 --- a/lib/srv/alpnproxy/conn_upgrade.go +++ b/api/client/alpn_conn_upgrade.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package alpnproxy +package client import ( "bufio" @@ -30,11 +30,9 @@ import ( "github.com/gravitational/trace" "github.com/sirupsen/logrus" - "github.com/gravitational/teleport" - apiclient "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/utils" - "github.com/gravitational/teleport/lib/srv/alpnproxy/common" ) // IsALPNConnUpgradeRequired returns true if a tunnel is required through a HTTP @@ -57,7 +55,7 @@ func IsALPNConnUpgradeRequired(addr string, insecure bool) bool { Timeout: defaults.DefaultIOTimeout, } tlsConfig := &tls.Config{ - NextProtos: []string{string(common.ProtocolReverseTunnel)}, + NextProtos: []string{string(constants.ALPNSNIProtocolReverseTunnel)}, InsecureSkipVerify: insecure, } testConn, err := tls.DialWithDialer(netDialer, "tcp", addr, tlsConfig) @@ -145,12 +143,12 @@ func isALPNConnUpgradeRequiredByEnv(addr, envValue string) bool { // alpnConnUpgradeDialer makes an "HTTP" upgrade call to the Proxy Service then // tunnels the connection with this connection upgrade. type alpnConnUpgradeDialer struct { - dialer apiclient.ContextDialer + dialer ContextDialer tlsConfig *tls.Config } // newALPNConnUpgradeDialer creates a new alpnConnUpgradeDialer. -func newALPNConnUpgradeDialer(dialer apiclient.ContextDialer, tlsConfig *tls.Config) ContextDialer { +func newALPNConnUpgradeDialer(dialer ContextDialer, tlsConfig *tls.Config) ContextDialer { return &alpnConnUpgradeDialer{ dialer: dialer, tlsConfig: tlsConfig, @@ -187,7 +185,7 @@ func (d alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr st err = upgradeConnThroughWebAPI(tlsConn, url.URL{ Host: addr, Scheme: "https", - Path: teleport.WebAPIConnUpgrade, + Path: constants.WebAPIConnUpgrade, }) if err != nil { defer tlsConn.Close() @@ -203,7 +201,7 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL) error { } // For now, only "alpn" is supported. - req.Header.Add(teleport.WebAPIConnUpgradeHeader, teleport.WebAPIConnUpgradeTypeALPN) + req.Header.Add(constants.WebAPIConnUpgradeHeader, constants.WebAPIConnUpgradeTypeALPN) // Send the request and check if upgrade is successful. if err = req.Write(conn); err != nil { @@ -219,7 +217,7 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL) error { if http.StatusNotFound == resp.StatusCode { return trace.NotImplemented( "connection upgrade call to %q failed with status code %v. Please upgrade the server and try again.", - teleport.WebAPIConnUpgrade, + constants.WebAPIConnUpgrade, resp.StatusCode, ) } diff --git a/lib/srv/alpnproxy/conn_upgrade_test.go b/api/client/alpn_conn_upgrade_test.go similarity index 85% rename from lib/srv/alpnproxy/conn_upgrade_test.go rename to api/client/alpn_conn_upgrade_test.go index c01d01eb4d131..4c02b3ebfebb6 100644 --- a/lib/srv/alpnproxy/conn_upgrade_test.go +++ b/api/client/alpn_conn_upgrade_test.go @@ -14,13 +14,12 @@ See the License for the specific language governing permissions and limitations under the License. */ -package alpnproxy +package client import ( "context" "crypto/tls" "crypto/x509" - "crypto/x509/pkix" "errors" "net" "net/http" @@ -31,11 +30,8 @@ import ( "github.com/stretchr/testify/require" - "github.com/gravitational/teleport" - apiclient "github.com/gravitational/teleport/api/client" - "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/srv/alpnproxy/common" - "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/fixtures" ) func TestIsALPNConnUpgradeRequired(t *testing.T) { @@ -55,7 +51,7 @@ func TestIsALPNConnUpgradeRequired(t *testing.T) { }, { name: "upgrade not required (proto negotiated)", - serverProtos: []string{string(common.ProtocolReverseTunnel)}, + serverProtos: []string{string(constants.ALPNSNIProtocolReverseTunnel)}, insecure: true, expectedResult: false, }, @@ -67,7 +63,7 @@ func TestIsALPNConnUpgradeRequired(t *testing.T) { }, { name: "upgrade not required (other handshake error)", - serverProtos: []string{string(common.ProtocolReverseTunnel)}, + serverProtos: []string{string(constants.ALPNSNIProtocolReverseTunnel)}, insecure: false, // to cause handshake error expectedResult: false, }, @@ -138,7 +134,7 @@ func TestALPNConnUpgradeDialer(t *testing.T) { pool.AddCert(server.Certificate()) tlsConfig := &tls.Config{RootCAs: pool} - preDialer := apiclient.NewDialer(ctx, 0, 5*time.Second, apiclient.WithTLSConfig(tlsConfig)) + preDialer := NewDialer(ctx, 0, 5*time.Second) dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig) conn, err := dialer.DialContext(ctx, "tcp", addr.Host) require.NoError(t, err) @@ -158,7 +154,7 @@ func TestALPNConnUpgradeDialer(t *testing.T) { require.NoError(t, err) tlsConfig := &tls.Config{InsecureSkipVerify: true} - preDialer := apiclient.NewDialer(ctx, 0, 5*time.Second, apiclient.WithTLSConfig(tlsConfig)) + preDialer := NewDialer(ctx, 0, 5*time.Second) dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig) _, err = dialer.DialContext(ctx, "tcp", addr.Host) require.Error(t, err) @@ -207,12 +203,7 @@ func mustStartMockALPNServer(t *testing.T, supportedProtos []string) *mockALPNSe listener.Close() }) - caKey, caCert, err := tlsca.GenerateSelfSignedCA(pkix.Name{ - CommonName: "localhost", - }, []string{"localhost"}, defaults.CATTL) - require.NoError(t, err) - - cert, err := tls.X509KeyPair(caCert, caKey) + cert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM)) require.NoError(t, err) m := &mockALPNServer{ @@ -228,8 +219,8 @@ func mustStartMockALPNServer(t *testing.T, supportedProtos []string) *mockALPNSe // upgrade request and sends back some data inside the tunnel. func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, teleport.WebAPIConnUpgrade, r.URL.Path) - require.Equal(t, upgradeType, r.Header.Get(teleport.WebAPIConnUpgradeHeader)) + require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path) + require.Equal(t, upgradeType, r.Header.Get(constants.WebAPIConnUpgradeHeader)) hj, ok := w.(http.Hijacker) require.True(t, ok) diff --git a/api/constants/constants.go b/api/constants/constants.go index 50c409cd6aca1..47b3b6c513c29 100644 --- a/api/constants/constants.go +++ b/api/constants/constants.go @@ -402,3 +402,15 @@ const ( // TimeoutGetClusterAlerts is the timeout for grabbing cluster alerts from tctl and tsh TimeoutGetClusterAlerts = time.Millisecond * 500 ) + +const ( + // WebAPIConnUpgrade is the HTTP web API to make the connection upgrade + // call. + WebAPIConnUpgrade = "/webapi/connectionupgrade" + // WebAPIConnUpgradeHeader is the header used to indicate the requested + // connection upgrade types in the connection upgrade API. + WebAPIConnUpgradeHeader = "Upgrade" + // WebAPIConnUpgradeTypeALPN is a connection upgrade type that specifies + // the upgraded connection should be handled by the ALPN handler. + WebAPIConnUpgradeTypeALPN = "alpn" +) diff --git a/api/fixtures/fixtures.go b/api/fixtures/fixtures.go new file mode 100644 index 0000000000000..573e4d3c9f651 --- /dev/null +++ b/api/fixtures/fixtures.go @@ -0,0 +1,64 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fixtures + +const ( + TLSCACertPEM = `-----BEGIN CERTIFICATE----- +MIIDKjCCAhKgAwIBAgIQJtJDJZZBkg/afM8d2ZJCTjANBgkqhkiG9w0BAQsFADBA +MRUwEwYDVQQKEwxUZWxlcG9ydCBPU1MxJzAlBgNVBAMTHnRlbGVwb3J0LmxvY2Fs +aG9zdC5sb2NhbGRvbWFpbjAeFw0xNzA1MDkxOTQwMzZaFw0yNzA1MDcxOTQwMzZa +MEAxFTATBgNVBAoTDFRlbGVwb3J0IE9TUzEnMCUGA1UEAxMedGVsZXBvcnQubG9j +YWxob3N0LmxvY2FsZG9tYWluMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAuKFLaf2iII/xDR+m2Yj6PnUEa+qzqwxsdLUjnunFZaAXG+hZm4Ml80SCiBgI +gTHQlJyLIkTtuRoH5aeMyz1ERUCtii4ZsTqDrjjUybxP4r+4HVX6m34s6hwEr8Fi +fts9pMp4iS3tQguRc28gPdDo/T6VrJTVYUfUUsNDRtIrlB5O9igqqLnuaY9eqGi4 +PUx0G0wRYJpRywoj8G0IkpfQTiX+CAC7dt5ws7ZrnGqCNBLGi5bGsaMmptVbsSEp +1TenntF54V1iR49IV5JqDhm1S0HmkleoJzKdc+6sP/xNepz9PJzuF9d9NubTLWgB +sK28YItcmWHdHXD/ODxVaehRjwIDAQABoyAwHjAOBgNVHQ8BAf8EBAMCB4AwDAYD +VR0TAQH/BAIwADANBgkqhkiG9w0BAQsFAAOCAQEAAVU6sNBdj76saHwOxGSdnEqQ +o2tMuR3msSM4F6wFK2UkKepsD7CYIf/PzNSNUqA5JIEUVeMqGyiHuAbU4C655nT1 +IyJX1D/+r73sSp5jbIpQm2xoQGZnj6g/Kltw8OSOAw+DsMF/PLVqoWJp07u6ew/m +NxWsJKcZ5k+q4eMxci9mKRHHqsquWKXzQlURMNFI+mGaFwrKM4dmzaR0BEc+ilSx +QqUvQ74smsLK+zhNikmgjlGC5ob9g8XkhVAkJMAh2rb9onDNiRl68iAgczP88mXu +vN/o98dypzsPxXmw6tkDqIRPUAUbh465rlY5sKMmRgXi2rUfl/QV5nbozUo/HQ== +-----END CERTIFICATE-----` + TLSCAKeyPEM = `-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAuKFLaf2iII/xDR+m2Yj6PnUEa+qzqwxsdLUjnunFZaAXG+hZ +m4Ml80SCiBgIgTHQlJyLIkTtuRoH5aeMyz1ERUCtii4ZsTqDrjjUybxP4r+4HVX6 +m34s6hwEr8Fifts9pMp4iS3tQguRc28gPdDo/T6VrJTVYUfUUsNDRtIrlB5O9igq +qLnuaY9eqGi4PUx0G0wRYJpRywoj8G0IkpfQTiX+CAC7dt5ws7ZrnGqCNBLGi5bG +saMmptVbsSEp1TenntF54V1iR49IV5JqDhm1S0HmkleoJzKdc+6sP/xNepz9PJzu +F9d9NubTLWgBsK28YItcmWHdHXD/ODxVaehRjwIDAQABAoIBABy4orWrShRMsA/9 +k4QVpfAfXf+3tBlwxlJld1QaQ6XqgI3L2FyzyyyLxM6NBo2qhSsJKy+6j0yTOxVD +ukhHkJ5BUH3FbCPA2Yk5uAhl7ft1HZwaqvCTcUM99pCswbjAPFetU5DrfxQeHpNZ +fyd+ny/+E2SUhpkqhmIVlBqpSTQyOywbiEvZ6ZiFmncdHhXaCy3YZsylrKUGPzsJ +jfU2iOE167eTOIjPStsaoCPv9jLSyy2OvuNNudS+Y1qkFz8ZGvPp+HB+Iig+AlAE +7KMzNrIW7PlHTDgUly1cRCl3+84yE2mJ97+hHiEy//HIwVDUpI529i2hMYM/u4qz +Wso/2tkCgYEA2FdE4bmCrZiA9eS8qobwGLE1+MJME4YwfJkynZUHHX93xORPQ66e +WYpN7/xbMvBDa8LZZYVTNVtZ/SkEUaTb5NQW2zXKoIutk1PFBb8NbA0m8Ss/mOJA +d5nUYGr987O9fRh1yP9TksBshHB/5A8U2UG8MFFCNvJTZDPRkuSlMiUCgYEA2nnb +hAJrhY7PaF6jdfimGvvponkUiEbWLppg7/SjgPg+QgqIwuLybryXyOAp+TEnNzgU +ujAjhNtIiyB/B13TDxOgUgWUWPbPvUAWGEvwI9h+RLie1umGHd48G1NR76fwqSf1 +y7z3YRnq8vCdz8ywB3o5GO6SH6QkMJBIxfIMlKMCgYA55akOi7oYQT8KD4waSwCI +ayyZhU4cz4W8Yrd0CsUbtNhVvhAked/w8J2JA01Y5Yn1lfDeRX8OQYNkyAxa2Tbs +F4KCafPvYVIzonCQ6B9sclygoEVl4e8E0wtOPnP2O30TtG8ZOpOgK5UfIIhpfUvE +FN6LQ8PntpRwtZl5qW04bQKBgGnHhFxHG64fthZPdA9jY3E/NSCgRSuyOHN59aNY +rG1+RA6PsSXC4iRxlYAB4PCxNs6KjaaUNi5WSaprAnYbnFv5Ya802l20qmJ0C/6Z +jdydLo2xYd6mVHRTrICCd/J0OpZ8LYsGpDPUa6hSjeYVscj9CXYj1IYTYB5PTZzh +k+vHAoGBAJyA+RtBF5m64/TqhZFcesTtnpWaRhQ50xXnNVF3W1eKGPtdTDKOaENA +LJxgC1GdoEz2ilXW802H9QrdKf9GPqxwi2TVzfO6pzWkdZcmbItu+QCCFz+co+r8 ++ki49FmlfbR5YVPN+8X40aLQB4xDkCHwRwTkrigzWQhIOv8NAhDA +-----END RSA PRIVATE KEY-----` +) diff --git a/api/utils/pingconn/pingconn.go b/api/utils/pingconn/pingconn.go new file mode 100644 index 0000000000000..1578a1110b90c --- /dev/null +++ b/api/utils/pingconn/pingconn.go @@ -0,0 +1,132 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pingconn + +import ( + "crypto/tls" + "encoding/binary" + "math" + "sync" + + "github.com/gravitational/trace" +) + +// New returns a ping connection wrapping the provided net.Conn. +func New(conn *tls.Conn) *PingConn { + return &PingConn{Conn: conn} +} + +// PingConn wraps a *tls.Conn and add ping capabilities to it, including the +// `WritePing` function and `Read` (which excludes ping packets). +// +// When using this connection, the packets written will contain an initial data: +// the packet size. When reading, this information is taken into account, but it +// is not returned to the caller. +// +// Ping messages have a packet size of zero and are produced only when +// `WritePing` is called. On `Read`, any Ping packet is discarded. +type PingConn struct { + //net.Conn + *tls.Conn + + muRead sync.Mutex + muWrite sync.Mutex + + // currentSize size of bytes of the current packet. + currentSize uint32 +} + +// Read reads content from the underlaying connection, discarding any ping +// messages it finds. +func (c *PingConn) Read(p []byte) (int, error) { + c.muRead.Lock() + defer c.muRead.Unlock() + + err := c.discardPingReads() + if err != nil { + return 0, err + } + + // Check if the current size is larger than the provided buffer. + readSize := c.currentSize + if c.currentSize > uint32(len(p)) { + readSize = uint32(len(p)) + } + + n, err := c.Conn.Read(p[:readSize]) + c.currentSize -= uint32(n) + + return n, err +} + +// WritePing writes the ping packet to the connection. +func (c *PingConn) WritePing() error { + c.muWrite.Lock() + defer c.muWrite.Unlock() + + return binary.Write(c.Conn, binary.BigEndian, uint32(0)) +} + +// discardPingReads reads from the wrapped net.Conn until it encounters a +// non-ping packet. +func (c *PingConn) discardPingReads() error { + for c.currentSize == 0 { + err := binary.Read(c.Conn, binary.BigEndian, &c.currentSize) + if err != nil { + return err + } + } + + return nil +} + +// Write writes provided content to the underlying connection with proper +// protocol fields. +func (c *PingConn) Write(p []byte) (int, error) { + c.muWrite.Lock() + defer c.muWrite.Unlock() + + // Avoid overflow when casting data length. It is only present to avoid + // panicking if the size cannot be cast. Callers should handle packet length + // limits, such as protocol implementations and audits. + if uint64(len(p)) > math.MaxUint32 { + return 0, trace.BadParameter("invalid content size, max size permitted is %d", uint64(math.MaxUint32)) + } + + size := uint32(len(p)) + if size == 0 { + return 0, nil + } + + // Write packet size. + if err := binary.Write(c.Conn, binary.BigEndian, size); err != nil { + return 0, trace.Wrap(err) + } + + // Iterate until everything is written. + var written int + for written < len(p) { + n, err := c.Conn.Write(p) + written += n + + if err != nil { + return written, trace.Wrap(err) + } + } + + return written, nil +} diff --git a/lib/srv/alpnproxy/conn_test.go b/api/utils/pingconn/pingconn_test.go similarity index 96% rename from lib/srv/alpnproxy/conn_test.go rename to api/utils/pingconn/pingconn_test.go index d91175d628a22..69a6a32cc5c47 100644 --- a/lib/srv/alpnproxy/conn_test.go +++ b/api/utils/pingconn/pingconn_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package alpnproxy +package pingconn import ( "bytes" @@ -26,6 +26,8 @@ import ( "time" "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/fixtures" ) func TestPingConnection(t *testing.T) { @@ -273,7 +275,7 @@ func makePingConn(t *testing.T) (*PingConn, *PingConn) { writer, reader := net.Pipe() tlsWriter, tlsReader := makeTLSConn(t, writer, reader) - return NewPingConn(tlsWriter), NewPingConn(tlsReader) + return New(tlsWriter), New(tlsReader) } // makeBufferedPingConn creates connections to have asynchronous writes. @@ -321,7 +323,7 @@ func makeBufferedPingConn(t *testing.T) (*PingConn, *PingConn) { } tlsConnA, tlsConnB := makeTLSConn(t, connSlice[0], connSlice[1]) - return NewPingConn(tlsConnA), NewPingConn(tlsConnB) + return New(tlsConnA), New(tlsConnB) } // makeTLSConn take two connections (client and server) and wrap them into TLS @@ -334,10 +336,13 @@ func makeTLSConn(t *testing.T, server, client net.Conn) (*tls.Conn, *tls.Conn) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() + cert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM)) + require.NoError(t, err) + // Server go func() { tlsConn := tls.Server(server, &tls.Config{ - Certificates: []tls.Certificate{mustGenCertSignedWithCA(t, mustGenSelfSignedCert(t))}, + Certificates: []tls.Certificate{cert}, }) tlsConnChan <- struct { *tls.Conn diff --git a/constants.go b/constants.go index 86640f05a9c9f..18d8b708a737d 100644 --- a/constants.go +++ b/constants.go @@ -799,15 +799,3 @@ const UserSingleUseCertTTL = time.Minute // StandardHTTPSPort is the default port used for the https URI scheme, // cf. RFC 7230 ยง 2.7.2. const StandardHTTPSPort = 443 - -const ( - // WebAPIConnUpgrade is the HTTP web API to make the connection upgrade - // call. - WebAPIConnUpgrade = "/webapi/connectionupgrade" - // WebAPIConnUpgradeHeader is the header used to indicate the requested - // connection upgrade types in the connection upgrade API. - WebAPIConnUpgradeHeader = "Upgrade" - // WebAPIConnUpgradeTypeALPN is a connection upgrade type that specifies - // the upgraded connection should be handled by the ALPN handler. - WebAPIConnUpgradeTypeALPN = "alpn" -) diff --git a/lib/client/api.go b/lib/client/api.go index a4af7ff464b1e..34cd6b542433c 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -74,7 +74,6 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/shell" - "github.com/gravitational/teleport/lib/srv/alpnproxy" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/sshutils/scp" @@ -622,7 +621,7 @@ func (c *Config) LoadProfile(ps ProfileStore, proxyAddr string) error { log.Warnf("Unable to parse dynamic port forwarding in user profile: %v.", err) } - if required, ok := alpnproxy.OverwriteALPNConnUpgradeRequirementByEnv(c.WebProxyAddr); ok { + if required, ok := client.OverwriteALPNConnUpgradeRequirementByEnv(c.WebProxyAddr); ok { c.TLSRoutingConnUpgradeRequired = required } log.Infof("ALPN connection upgrade required for %q: %v.", c.WebProxyAddr, c.TLSRoutingConnUpgradeRequired) @@ -3083,7 +3082,7 @@ func (tc *TeleportClient) Login(ctx context.Context) (*Key, error) { } // Perform the ALPN test once at login. - tc.TLSRoutingConnUpgradeRequired = alpnproxy.IsALPNConnUpgradeRequired(tc.WebProxyAddr, tc.InsecureSkipVerify) + tc.TLSRoutingConnUpgradeRequired = client.IsALPNConnUpgradeRequired(tc.WebProxyAddr, tc.InsecureSkipVerify) // Get the SSHLoginFunc that matches client and cluster settings. sshLoginFunc, err := tc.getSSHLoginFunc(pr) diff --git a/lib/fixtures/fixtures.go b/lib/fixtures/fixtures.go index 9d6ce39dc1a4b..cf86dd07d1959 100644 --- a/lib/fixtures/fixtures.go +++ b/lib/fixtures/fixtures.go @@ -19,6 +19,8 @@ import ( "testing" "github.com/gravitational/trace" + + apifixtures "github.com/gravitational/teleport/api/fixtures" ) // AssertNotFound expects not found error @@ -149,52 +151,8 @@ spec: pHM7WKwFyW1dvEDax3BGj9/cbKvpvcwRurn:oasis:names:tc:SAML:1.1:nameid-format:emailAddressurn:oasis:names:tc:SAML:1.1:nameid-format:unspecified` const ( - TLSCACertPEM = `-----BEGIN CERTIFICATE----- -MIIDKjCCAhKgAwIBAgIQJtJDJZZBkg/afM8d2ZJCTjANBgkqhkiG9w0BAQsFADBA -MRUwEwYDVQQKEwxUZWxlcG9ydCBPU1MxJzAlBgNVBAMTHnRlbGVwb3J0LmxvY2Fs -aG9zdC5sb2NhbGRvbWFpbjAeFw0xNzA1MDkxOTQwMzZaFw0yNzA1MDcxOTQwMzZa -MEAxFTATBgNVBAoTDFRlbGVwb3J0IE9TUzEnMCUGA1UEAxMedGVsZXBvcnQubG9j -YWxob3N0LmxvY2FsZG9tYWluMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC -AQEAuKFLaf2iII/xDR+m2Yj6PnUEa+qzqwxsdLUjnunFZaAXG+hZm4Ml80SCiBgI -gTHQlJyLIkTtuRoH5aeMyz1ERUCtii4ZsTqDrjjUybxP4r+4HVX6m34s6hwEr8Fi -fts9pMp4iS3tQguRc28gPdDo/T6VrJTVYUfUUsNDRtIrlB5O9igqqLnuaY9eqGi4 -PUx0G0wRYJpRywoj8G0IkpfQTiX+CAC7dt5ws7ZrnGqCNBLGi5bGsaMmptVbsSEp -1TenntF54V1iR49IV5JqDhm1S0HmkleoJzKdc+6sP/xNepz9PJzuF9d9NubTLWgB -sK28YItcmWHdHXD/ODxVaehRjwIDAQABoyAwHjAOBgNVHQ8BAf8EBAMCB4AwDAYD -VR0TAQH/BAIwADANBgkqhkiG9w0BAQsFAAOCAQEAAVU6sNBdj76saHwOxGSdnEqQ -o2tMuR3msSM4F6wFK2UkKepsD7CYIf/PzNSNUqA5JIEUVeMqGyiHuAbU4C655nT1 -IyJX1D/+r73sSp5jbIpQm2xoQGZnj6g/Kltw8OSOAw+DsMF/PLVqoWJp07u6ew/m -NxWsJKcZ5k+q4eMxci9mKRHHqsquWKXzQlURMNFI+mGaFwrKM4dmzaR0BEc+ilSx -QqUvQ74smsLK+zhNikmgjlGC5ob9g8XkhVAkJMAh2rb9onDNiRl68iAgczP88mXu -vN/o98dypzsPxXmw6tkDqIRPUAUbh465rlY5sKMmRgXi2rUfl/QV5nbozUo/HQ== ------END CERTIFICATE-----` - TLSCAKeyPEM = `-----BEGIN RSA PRIVATE KEY----- -MIIEowIBAAKCAQEAuKFLaf2iII/xDR+m2Yj6PnUEa+qzqwxsdLUjnunFZaAXG+hZ -m4Ml80SCiBgIgTHQlJyLIkTtuRoH5aeMyz1ERUCtii4ZsTqDrjjUybxP4r+4HVX6 -m34s6hwEr8Fifts9pMp4iS3tQguRc28gPdDo/T6VrJTVYUfUUsNDRtIrlB5O9igq -qLnuaY9eqGi4PUx0G0wRYJpRywoj8G0IkpfQTiX+CAC7dt5ws7ZrnGqCNBLGi5bG -saMmptVbsSEp1TenntF54V1iR49IV5JqDhm1S0HmkleoJzKdc+6sP/xNepz9PJzu -F9d9NubTLWgBsK28YItcmWHdHXD/ODxVaehRjwIDAQABAoIBABy4orWrShRMsA/9 -k4QVpfAfXf+3tBlwxlJld1QaQ6XqgI3L2FyzyyyLxM6NBo2qhSsJKy+6j0yTOxVD -ukhHkJ5BUH3FbCPA2Yk5uAhl7ft1HZwaqvCTcUM99pCswbjAPFetU5DrfxQeHpNZ -fyd+ny/+E2SUhpkqhmIVlBqpSTQyOywbiEvZ6ZiFmncdHhXaCy3YZsylrKUGPzsJ -jfU2iOE167eTOIjPStsaoCPv9jLSyy2OvuNNudS+Y1qkFz8ZGvPp+HB+Iig+AlAE -7KMzNrIW7PlHTDgUly1cRCl3+84yE2mJ97+hHiEy//HIwVDUpI529i2hMYM/u4qz -Wso/2tkCgYEA2FdE4bmCrZiA9eS8qobwGLE1+MJME4YwfJkynZUHHX93xORPQ66e -WYpN7/xbMvBDa8LZZYVTNVtZ/SkEUaTb5NQW2zXKoIutk1PFBb8NbA0m8Ss/mOJA -d5nUYGr987O9fRh1yP9TksBshHB/5A8U2UG8MFFCNvJTZDPRkuSlMiUCgYEA2nnb -hAJrhY7PaF6jdfimGvvponkUiEbWLppg7/SjgPg+QgqIwuLybryXyOAp+TEnNzgU -ujAjhNtIiyB/B13TDxOgUgWUWPbPvUAWGEvwI9h+RLie1umGHd48G1NR76fwqSf1 -y7z3YRnq8vCdz8ywB3o5GO6SH6QkMJBIxfIMlKMCgYA55akOi7oYQT8KD4waSwCI -ayyZhU4cz4W8Yrd0CsUbtNhVvhAked/w8J2JA01Y5Yn1lfDeRX8OQYNkyAxa2Tbs -F4KCafPvYVIzonCQ6B9sclygoEVl4e8E0wtOPnP2O30TtG8ZOpOgK5UfIIhpfUvE -FN6LQ8PntpRwtZl5qW04bQKBgGnHhFxHG64fthZPdA9jY3E/NSCgRSuyOHN59aNY -rG1+RA6PsSXC4iRxlYAB4PCxNs6KjaaUNi5WSaprAnYbnFv5Ya802l20qmJ0C/6Z -jdydLo2xYd6mVHRTrICCd/J0OpZ8LYsGpDPUa6hSjeYVscj9CXYj1IYTYB5PTZzh -k+vHAoGBAJyA+RtBF5m64/TqhZFcesTtnpWaRhQ50xXnNVF3W1eKGPtdTDKOaENA -LJxgC1GdoEz2ilXW802H9QrdKf9GPqxwi2TVzfO6pzWkdZcmbItu+QCCFz+co+r8 -+ki49FmlfbR5YVPN+8X40aLQB4xDkCHwRwTkrigzWQhIOv8NAhDA ------END RSA PRIVATE KEY-----` + TLSCACertPEM = apifixtures.TLSCACertPEM + TLSCAKeyPEM = apifixtures.TLSCAKeyPEM // Backwards-compatibility alias for teleport.e SigningCertPEM = TLSCACertPEM ) diff --git a/lib/srv/alpnproxy/conn.go b/lib/srv/alpnproxy/conn.go index ef566dbb9fd81..c8bed17e08112 100644 --- a/lib/srv/alpnproxy/conn.go +++ b/lib/srv/alpnproxy/conn.go @@ -17,16 +17,10 @@ limitations under the License. package alpnproxy import ( - "crypto/tls" - "encoding/binary" "io" - "math" "net" - "sync" "time" - "github.com/gravitational/trace" - "github.com/gravitational/teleport/lib/utils" ) @@ -103,109 +97,3 @@ func (conn readOnlyConn) RemoteAddr() net.Addr { return &utils.Net func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } - -// NewPingConn returns a ping connection wrapping the provided net.Conn. -func NewPingConn(conn *tls.Conn) *PingConn { - return &PingConn{Conn: conn} -} - -// PingConn wraps a *tls.Conn and add ping capabilities to it, including the -// `WritePing` function and `Read` (which excludes ping packets). -// -// When using this connection, the packets written will contain an initial data: -// the packet size. When reading, this information is taken into account, but it -// is not returned to the caller. -// -// Ping messages have a packet size of zero and are produced only when -// `WritePing` is called. On `Read`, any Ping packet is discarded. -type PingConn struct { - //net.Conn - *tls.Conn - - muRead sync.Mutex - muWrite sync.Mutex - - // currentSize size of bytes of the current packet. - currentSize uint32 -} - -// Read reads content from the underlaying connection, discarding any ping -// messages it finds. -func (c *PingConn) Read(p []byte) (int, error) { - c.muRead.Lock() - defer c.muRead.Unlock() - - err := c.discardPingReads() - if err != nil { - return 0, err - } - - // Check if the current size is larger than the provided buffer. - readSize := c.currentSize - if c.currentSize > uint32(len(p)) { - readSize = uint32(len(p)) - } - - n, err := c.Conn.Read(p[:readSize]) - c.currentSize -= uint32(n) - - return n, err -} - -// WritePing writes the ping packet to the connection. -func (c *PingConn) WritePing() error { - c.muWrite.Lock() - defer c.muWrite.Unlock() - - return binary.Write(c.Conn, binary.BigEndian, uint32(0)) -} - -// discardPingReads reads from the wrapped net.Conn until it encounters a -// non-ping packet. -func (c *PingConn) discardPingReads() error { - for c.currentSize == 0 { - err := binary.Read(c.Conn, binary.BigEndian, &c.currentSize) - if err != nil { - return err - } - } - - return nil -} - -// Write writes provided content to the underlying connection with proper -// protocol fields. -func (c *PingConn) Write(p []byte) (int, error) { - c.muWrite.Lock() - defer c.muWrite.Unlock() - - // Avoid overflow when casting data length. It is only present to avoid - // panicking if the size cannot be cast. Callers should handle packet length - // limits, such as protocol implementations and audits. - if uint64(len(p)) > math.MaxUint32 { - return 0, trace.BadParameter("invalid content size, max size permitted is %d", uint64(math.MaxUint32)) - } - - size := uint32(len(p)) - if size == 0 { - return 0, nil - } - - // Write packet size. - if err := binary.Write(c.Conn, binary.BigEndian, size); err != nil { - return 0, trace.Wrap(err) - } - - // Iterate until everything is written. - var written int - for written < len(p) { - n, err := c.Conn.Write(p) - written += n - - if err != nil { - return written, trace.Wrap(err) - } - } - - return written, nil -} diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index 103a45af5e60d..d914002faf7d9 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -34,6 +34,8 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/exp/slices" + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/utils/pingconn" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" commonApp "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/tlsca" @@ -218,7 +220,7 @@ func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamC return trace.Wrap(err) } - tlsConn, err := DialALPN(ctx, l.cfg.RemoteProxyAddr, l.getALPNDialerConfig(certs)) + tlsConn, err := client.DialALPN(ctx, l.cfg.RemoteProxyAddr, l.getALPNDialerConfig(certs)) if err != nil { return trace.Wrap(err) } @@ -227,7 +229,7 @@ func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamC var upstreamConn net.Conn = tlsConn if common.IsPingProtocol(common.Protocol(tlsConn.ConnectionState().NegotiatedProtocol)) { l.cfg.Log.Debug("Using ping connection") - upstreamConn = NewPingConn(tlsConn) + upstreamConn = pingconn.New(tlsConn) } return trace.Wrap(utils.ProxyConn(ctx, downstreamConn, upstreamConn)) @@ -243,8 +245,8 @@ func (l *LocalProxy) Close() error { return nil } -func (l *LocalProxy) getALPNDialerConfig(certs []tls.Certificate) ALPNDialerConfig { - return ALPNDialerConfig{ +func (l *LocalProxy) getALPNDialerConfig(certs []tls.Certificate) client.ALPNDialerConfig { + return client.ALPNDialerConfig{ ALPNConnUpgradeRequired: l.cfg.ALPNConnUpgradeRequired, TLSConfig: &tls.Config{ NextProtos: common.ProtocolsToString(l.cfg.Protocols), @@ -285,7 +287,7 @@ func (l *LocalProxy) makeHTTPReverseProxy(certs []tls.Certificate) *httputil.Rev http.Error(w, http.StatusText(code), code) }, Transport: &http.Transport{ - DialTLSContext: NewALPNDialer(l.getALPNDialerConfig(certs)).DialContext, + DialTLSContext: client.NewALPNDialer(l.getALPNDialerConfig(certs)).DialContext, }, } } diff --git a/lib/srv/alpnproxy/local_proxy_config_opt.go b/lib/srv/alpnproxy/local_proxy_config_opt.go index 20b5e61e23dde..f2af0b5ed81ae 100644 --- a/lib/srv/alpnproxy/local_proxy_config_opt.go +++ b/lib/srv/alpnproxy/local_proxy_config_opt.go @@ -24,6 +24,7 @@ import ( "github.com/gravitational/trace" + "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" ) @@ -41,7 +42,7 @@ type GetClusterCACertPoolFunc func(ctx context.Context) (*x509.CertPool, error) // already been set. func WithALPNConnUpgradeTest(ctx context.Context, getClusterCertPool GetClusterCACertPoolFunc) LocalProxyConfigOpt { return func(config *LocalProxyConfig) error { - config.ALPNConnUpgradeRequired = IsALPNConnUpgradeRequired(config.RemoteProxyAddr, config.InsecureSkipVerify) + config.ALPNConnUpgradeRequired = client.IsALPNConnUpgradeRequired(config.RemoteProxyAddr, config.InsecureSkipVerify) return trace.Wrap(WithClusterCAsIfConnUpgrade(ctx, getClusterCertPool)(config)) } } diff --git a/lib/srv/alpnproxy/proxy.go b/lib/srv/alpnproxy/proxy.go index 4ac5613fb538d..2c0595cf7f8cf 100644 --- a/lib/srv/alpnproxy/proxy.go +++ b/lib/srv/alpnproxy/proxy.go @@ -33,6 +33,7 @@ import ( "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/utils/pingconn" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" @@ -410,7 +411,7 @@ func (p *Proxy) handleConn(ctx context.Context, clientConn net.Conn, defaultOver // handlePingConnection starts the server ping routine and returns `pingConn`. func (p *Proxy) handlePingConnection(ctx context.Context, conn *tls.Conn) net.Conn { - pingConn := NewPingConn(conn) + pingConn := pingconn.New(conn) // Start ping routine. It will continuously send pings in a defined // interval. diff --git a/lib/srv/alpnproxy/proxy_test.go b/lib/srv/alpnproxy/proxy_test.go index 6c81f62696de1..7e7e9451c8132 100644 --- a/lib/srv/alpnproxy/proxy_test.go +++ b/lib/srv/alpnproxy/proxy_test.go @@ -32,6 +32,7 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/utils/pingconn" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/srv/db/dbutils" "github.com/gravitational/teleport/lib/tlsca" @@ -191,7 +192,7 @@ func TestProxyTLSDatabaseHandler(t *testing.T) { }) require.NoError(t, err) - conn := NewPingConn(baseConn) + conn := pingconn.New(baseConn) tlsConn := tls.Client(conn, &tls.Config{ Certificates: []tls.Certificate{ clientCert, diff --git a/lib/web/conn_upgrade.go b/lib/web/conn_upgrade.go index 9f1638a74a2d2..cdfe3dae0556a 100644 --- a/lib/web/conn_upgrade.go +++ b/lib/web/conn_upgrade.go @@ -25,17 +25,17 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/lib/utils" ) // selectConnectionUpgrade selects the requested upgrade type and returns the // corresponding handler. func (h *Handler) selectConnectionUpgrade(r *http.Request) (string, ConnectionHandler, error) { - upgrades := r.Header.Values(teleport.WebAPIConnUpgradeHeader) + upgrades := r.Header.Values(constants.WebAPIConnUpgradeHeader) for _, upgradeType := range upgrades { switch upgradeType { - case teleport.WebAPIConnUpgradeTypeALPN: + case constants.WebAPIConnUpgradeTypeALPN: return upgradeType, h.upgradeALPN, nil } } @@ -90,7 +90,7 @@ func (h *Handler) upgradeALPN(ctx context.Context, conn net.Conn) error { func writeUpgradeResponse(w io.Writer, upgradeType string) error { header := make(http.Header) - header.Add(teleport.WebAPIConnUpgradeHeader, upgradeType) + header.Add(constants.WebAPIConnUpgradeHeader, upgradeType) response := &http.Response{ Status: http.StatusText(http.StatusSwitchingProtocols), StatusCode: http.StatusSwitchingProtocols, From a5de0d9e2a747299beab7f06b2b6d1c6f56fd813 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Thu, 30 Mar 2023 16:01:45 -0400 Subject: [PATCH 13/27] beatify --- api/client/alpn.go | 62 +++++++++++++---- api/client/client.go | 20 +++--- api/client/contextdialer.go | 37 ++++------ api/client/proxy.go | 11 ++- api/utils/slices.go | 4 +- lib/auth/clt.go | 12 +--- lib/client/api.go | 12 +--- lib/reversetunnel/agentpool.go | 80 +++++++--------------- lib/reversetunnel/transport.go | 25 +++---- lib/service/service_test.go | 4 ++ lib/srv/alpnproxy/common/protocols.go | 6 ++ lib/srv/alpnproxy/common/protocols_test.go | 24 ++++++- lib/utils/proxy/proxy.go | 45 +++++++----- tool/tsh/proxy.go | 11 +-- 14 files changed, 185 insertions(+), 168 deletions(-) diff --git a/api/client/alpn.go b/api/client/alpn.go index 0414fb54be113..7e848d3a0a7b5 100644 --- a/api/client/alpn.go +++ b/api/client/alpn.go @@ -19,6 +19,7 @@ package client import ( "context" "crypto/tls" + "crypto/x509" "net" "strings" "time" @@ -31,6 +32,17 @@ import ( "github.com/gravitational/teleport/api/utils/pingconn" ) +// GetClusterCAsFunc is a function to fetch cluster CAs. +type GetClusterCAsFunc func(ctx context.Context) (*x509.CertPool, error) + +// ClusterCAsFromCertPool returns a GetClusterCAsFunc with provided static cert +// pool. +func ClusterCAsFromCertPool(cas *x509.CertPool) GetClusterCAsFunc { + return func(_ context.Context) (*x509.CertPool, error) { + return cas, nil + } +} + // ALPNDialerConfig is the config for ALPNDialer. type ALPNDialerConfig struct { // KeepAlivePeriod defines period between keep alives. @@ -41,13 +53,17 @@ type ALPNDialerConfig struct { TLSConfig *tls.Config // ALPNConnUpgradeRequired specifies if ALPN connection upgrade is required. ALPNConnUpgradeRequired bool + // GetClusterCAs is an optional callback function to fetch cluster + // CAs when connection upgrade is required. If not provided, it's assumed + // the proper CAs are already present in TLSConfig. + GetClusterCAs GetClusterCAsFunc } // ALPNDialer is a ContextDialer that dials a connection to the Proxy Service // with ALPN and SNI configured in the provided TLSConfig. An ALPN connection // upgrade is also performed at the initial connection, if an upgrade is -// required. If the negotiated protocol is a ping protocol, it will return the -// de-multiplexed connection without the ping. +// required. If the negotiated protocol is a Ping protocol, it will return the +// de-multiplexed connection without the Ping. type ALPNDialer struct { cfg ALPNDialerConfig } @@ -59,36 +75,54 @@ func NewALPNDialer(cfg ALPNDialerConfig) ContextDialer { } } -func (d *ALPNDialer) getTLSConfig(addr string) (*tls.Config, error) { +func (d *ALPNDialer) shouldUpdateTLSConfig() bool { + switch { + case d.cfg.TLSConfig.ServerName == "": + return true + case d.cfg.ALPNConnUpgradeRequired && d.cfg.TLSConfig.RootCAs == nil && d.cfg.GetClusterCAs != nil: + return true + default: + return false + } +} + +func (d *ALPNDialer) getTLSConfig(ctx context.Context, addr string) (*tls.Config, error) { if d.cfg.TLSConfig == nil { return nil, trace.BadParameter("missing TLS config") } - if d.cfg.TLSConfig.ServerName != "" { + + if !d.shouldUpdateTLSConfig() { return d.cfg.TLSConfig, nil } tlsConfig := d.cfg.TLSConfig.Clone() - host, _, err := webclient.ParseHostPort(addr) - if err != nil { - return nil, trace.Wrap(err) + if d.cfg.ALPNConnUpgradeRequired && d.cfg.TLSConfig.RootCAs == nil && d.cfg.GetClusterCAs != nil { + rootCAs, err := d.cfg.GetClusterCAs(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + tlsConfig.RootCAs = rootCAs + } + + if d.cfg.TLSConfig.ServerName == "" { + host, _, err := webclient.ParseHostPort(addr) + if err != nil { + return nil, trace.Wrap(err) + } + tlsConfig.ServerName = host } - tlsConfig.ServerName = host return tlsConfig, nil } // DialContext implements ContextDialer. func (d *ALPNDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - tlsConfig, err := d.getTLSConfig(addr) + tlsConfig, err := d.getTLSConfig(ctx, addr) if err != nil { return nil, trace.Wrap(err) } - tlsConfigForDialer := &tls.Config{ - InsecureSkipVerify: tlsConfig.InsecureSkipVerify, - } - dialer := NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout, - WithTLSConfig(tlsConfigForDialer), + WithInsecureSkipVerify(d.cfg.TLSConfig.InsecureSkipVerify), WithALPNConnUpgrade(d.cfg.ALPNConnUpgradeRequired), ) diff --git a/api/client/client.go b/api/client/client.go index 43aaaf6dfb8b4..bbfabdc89efba 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -330,9 +330,7 @@ type ( // authConnect connects to the Teleport Auth Server directly. func authConnect(ctx context.Context, params connectParams) (*Client, error) { dialer := NewDialer(ctx, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, - WithTLSConfig(&tls.Config{ - InsecureSkipVerify: params.tlsConfig.InsecureSkipVerify, - }), + WithInsecureSkipVerify(params.cfg.InsecureAddressDiscovery), WithALPNConnUpgrade(IsWebProxyAndConnUpgradeRequired(ctx, params.addr, ¶ms.cfg)), ) @@ -343,10 +341,16 @@ func authConnect(ctx context.Context, params connectParams) (*Client, error) { return clt, nil } -// TODO +// IsWebProxyAndConnUpgradeRequired returns if targetAddr is a Teleport web +// proxy address and ALPN connection upgrade is required for it. If no cluster +// name is provided for ALPN, assume dialer is not trying to connect Auth +// through Proxy using TLS routing. func IsWebProxyAndConnUpgradeRequired(ctx context.Context, targetAddr string, cfg *Config) bool { - return isWebProxy(ctx, targetAddr, cfg) && cfg.IsALPNConnUpgradeRequired(targetAddr, cfg.InsecureAddressDiscovery) + return cfg.ALPNSNIAuthDialClusterName != "" && + isWebProxy(ctx, targetAddr, cfg) && + cfg.IsALPNConnUpgradeRequired(targetAddr, cfg.InsecureAddressDiscovery) } + func isWebProxy(ctx context.Context, targetAddr string, cfg *Config) bool { if cfg.WebProxyAddr != "" && cfg.WebProxyAddr == targetAddr { return true @@ -518,7 +522,9 @@ func (c *Client) waitForConnectionReady(ctx context.Context) error { type Config struct { // Addrs is a list of teleport auth/proxy server addresses to dial. Addrs []string - // WebProxyAddr is the Teleport Proxy web address. + // WebProxyAddr is the Teleport Proxy web address. If not provided, extra + // webapi pings may be required to find out if Addrs are web proxy + // addresses. WebProxyAddr string // Credentials are a list of credentials to use when attempting // to connect to the server. @@ -590,11 +596,9 @@ func (c *Config) CheckAndSetDefaults() error { if c.IsALPNConnUpgradeRequired == nil { c.IsALPNConnUpgradeRequired = IsALPNConnUpgradeRequired } - if c.WebProxyAddr != "" { c.Addrs = utils.Deduplicate(append(c.Addrs, c.WebProxyAddr)) } - return nil } diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index 1714b8674e352..9099fa3ee6487 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -19,6 +19,7 @@ package client import ( "context" "crypto/tls" + "crypto/x509" "net" "net/url" "time" @@ -112,19 +113,6 @@ func NewDialer(ctx context.Context, keepAlivePeriod, dialTimeout time.Duration, }) } -func isAddrWebProxy(ctx context.Context, targetAddr, knownWebProxyAddr string, insecure bool) bool { - if knownWebProxyAddr != "" && targetAddr == knownWebProxyAddr { - return true - } - - _, err := webclient.Find(&webclient.Config{ - Context: ctx, - ProxyAddr: targetAddr, - Insecure: insecure, - }) - return err == nil -} - // NewProxyDialer makes a dialer to connect to an Auth server through the SSH reverse tunnel on the proxy. // The dialer will ping the web client to discover the tunnel proxy address on each dial. func NewProxyDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool, opts ...DialProxyOption) ContextDialer { @@ -176,11 +164,7 @@ func newTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Dur func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) ContextDialer { return ContextDialerFunc(func(ctx context.Context, network, addr string) (conn net.Conn, err error) { insecure := params.cfg.InsecureAddressDiscovery - resp, err := webclient.Find(&webclient.Config{ - Context: ctx, - ProxyAddr: params.addr, - Insecure: insecure, - }) + resp, err := webclient.Find(&webclient.Config{Context: ctx, ProxyAddr: params.addr, Insecure: insecure}) if err != nil { return nil, trace.Wrap(err) } @@ -198,11 +182,9 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) Conte if err != nil { return nil, trace.Wrap(err) } - tlsConn, err := DialALPN(ctx, tunnelAddr, ALPNDialerConfig{ - ALPNConnUpgradeRequired: params.cfg.IsALPNConnUpgradeRequired(tunnelAddr, insecure), - DialTimeout: params.cfg.DialTimeout, - KeepAlivePeriod: params.cfg.KeepAlivePeriod, + DialTimeout: params.cfg.DialTimeout, + KeepAlivePeriod: params.cfg.KeepAlivePeriod, TLSConfig: &tls.Config{ NextProtos: []string{ constants.ALPNSNIProtocolReverseTunnel + constants.ALPNSNIProtocolPingSuffix, @@ -211,6 +193,17 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) Conte InsecureSkipVerify: insecure, ServerName: host, }, + ALPNConnUpgradeRequired: params.cfg.IsALPNConnUpgradeRequired(tunnelAddr, insecure), + GetClusterCAs: func(_ context.Context) (*x509.CertPool, error) { + if len(params.cfg.Credentials) == 0 { + return nil, trace.BadParameter("no credentials") + } + tlsConfig, err := params.cfg.Credentials[0].TLSConfig() + if err != nil { + return nil, trace.Wrap(err) + } + return tlsConfig.RootCAs, nil + }, }) if err != nil { return nil, trace.Wrap(err) diff --git a/api/client/proxy.go b/api/client/proxy.go index c6bd125667f86..b546301760279 100644 --- a/api/client/proxy.go +++ b/api/client/proxy.go @@ -45,7 +45,16 @@ func WithTLSConfig(tlsConfig *tls.Config) DialProxyOption { } } -// TODO +// WithInsecureSkipVerify specifies if dialing insecure when using an HTTPS proxy. +func WithInsecureSkipVerify(insecure bool) DialProxyOption { + return func(cfg *dialProxyConfig) { + cfg.tlsConfig = &tls.Config{ + InsecureSkipVerify: insecure, + } + } +} + +// WithALPNConnUpgrade specifies if ALPN connection upgrade is required. func WithALPNConnUpgrade(alpnConnUpgradeRequired bool) DialProxyOption { return func(cfg *dialProxyConfig) { cfg.alpnConnUpgradeRequired = alpnConnUpgradeRequired diff --git a/api/utils/slices.go b/api/utils/slices.go index ef0bdc08cf2b4..61dc124200718 100644 --- a/api/utils/slices.go +++ b/api/utils/slices.go @@ -47,7 +47,7 @@ func JoinStrings[T ~string](elems []T, sep string) T { // Deduplicate deduplicates list of comparable values. func Deduplicate[T comparable](in []T) []T { - if len(in) == 0 { + if len(in) <= 1 { return in } out := make([]T, 0, len(in)) @@ -63,7 +63,7 @@ func Deduplicate[T comparable](in []T) []T { // DeduplicateAny deduplicates list of any values with compare function. func DeduplicateAny[T any](in []T, compare func(T, T) bool) []T { - if len(in) == 0 { + if len(in) <= 1 { return in } out := make([]T, 0, len(in)) diff --git a/lib/auth/clt.go b/lib/auth/clt.go index fd18675798b8a..e98dcc2bbe06b 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -18,7 +18,6 @@ package auth import ( "context" - "crypto/tls" "net" "net/url" "time" @@ -102,15 +101,10 @@ func NewClient(cfg client.Config, params ...roundtrip.ClientParam) (*Client, err } httpDialer = client.ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) { for _, addr := range cfg.Addrs { - contextDialerOpts := []client.DialProxyOption{ - client.WithTLSConfig(&tls.Config{ - InsecureSkipVerify: httpTLS.InsecureSkipVerify, - }), + contextDialer := client.NewDialer(cfg.Context, cfg.KeepAlivePeriod, cfg.DialTimeout, + client.WithInsecureSkipVerify(httpTLS.InsecureSkipVerify), client.WithALPNConnUpgrade(client.IsWebProxyAndConnUpgradeRequired(ctx, addr, &cfg)), - } - - contextDialer := client.NewDialer(cfg.Context, cfg.KeepAlivePeriod, cfg.DialTimeout, contextDialerOpts...) - + ) conn, err = contextDialer.DialContext(ctx, network, addr) if err == nil { return conn, nil diff --git a/lib/client/api.go b/lib/client/api.go index f39b549766d13..813ec81ff0618 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2900,18 +2900,12 @@ func makeProxySSHClientWithTLSWrapper(ctx context.Context, tc *TeleportClient, s return nil, trace.Wrap(err) } - tlsConfig.NextProtos = []string{ - string(alpncommon.ProtocolWithPing(alpncommon.ProtocolProxySSH)), - string(alpncommon.ProtocolProxySSH), - } - - alpnConfig := client.ALPNDialerConfig{ + tlsConfig.NextProtos = alpncommon.NextProtosWithPing(alpncommon.ProtocolProxySSH) + dialer := proxy.DialerFromEnvironment(tc.Config.WebProxyAddr, proxy.WithALPNDialer(client.ALPNDialerConfig{ TLSConfig: tlsConfig, ALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequired(proxyAddr, tlsConfig.InsecureSkipVerify), DialTimeout: sshConfig.Timeout, - } - - dialer := proxy.DialerFromEnvironment(tc.Config.WebProxyAddr, proxy.WithALPNDialer(alpnConfig)) + })) return dialer.Dial(ctx, "tcp", proxyAddr, sshConfig) } diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index a16545272fad1..ea89ee182ad95 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -19,6 +19,7 @@ package reversetunnel import ( "context" "crypto/tls" + "crypto/x509" "errors" "io" "net" @@ -39,7 +40,6 @@ import ( "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth" - libdefaults "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/reversetunnel/track" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" @@ -480,13 +480,9 @@ func (p *AgentPool) newAgent(ctx context.Context, tracker *track.Tracker, lease p.log.WithError(err).Debugf("Failed to update remote config.") } - options := []proxy.DialerOptionFunc{} + options := []proxy.DialerOptionFunc{proxy.WithInsecureSkipTLSVerify(lib.IsInsecureDevMode())} if p.runtimeConfig.useALPNRouting() { - alpnDialerConfig, err := p.makeALPNDialerConfig() - if err != nil { - return nil, trace.Wrap(err) - } - options = append(options, proxy.WithALPNDialer(alpnDialerConfig)) + options = append(options, proxy.WithALPNDialer(p.runtimeConfig.alpnDialerConfig(p.getClusterCAs))) } dialer := &agentDialer{ @@ -500,7 +496,7 @@ func (p *AgentPool) newAgent(ctx context.Context, tracker *track.Tracker, lease agent, err := newAgent(agentConfig{ addr: *addr, - keepAlive: p.runtimeConfig.getKeepAliveInterval(), + keepAlive: p.runtimeConfig.keepAliveInterval, sshDialer: dialer, transporter: p, versionGetter: p, @@ -519,33 +515,9 @@ func (p *AgentPool) newAgent(ctx context.Context, tracker *track.Tracker, lease return agent, nil } -func (p *AgentPool) makeALPNDialerConfig() (client.ALPNDialerConfig, error) { - tlsConfig := &tls.Config{ - NextProtos: []string{string(alpncommon.ProtocolReverseTunnel)}, - InsecureSkipVerify: lib.IsInsecureDevMode(), - } - - if p.runtimeConfig.useReverseTunnelV2() { - tlsConfig.NextProtos = []string{ - string(alpncommon.ProtocolReverseTunnelV2), - string(alpncommon.ProtocolReverseTunnel), - } - } - - config := client.ALPNDialerConfig{ - TLSConfig: tlsConfig, - ALPNConnUpgradeRequired: p.runtimeConfig.useALPNConnUpgrade(), - KeepAlivePeriod: p.runtimeConfig.getKeepAliveInterval(), - } - - if config.ALPNConnUpgradeRequired { - rootCAs, _, err := auth.ClientCertPool(p.AccessPoint, p.Cluster, types.HostCA) - if err != nil { - return client.ALPNDialerConfig{}, trace.Wrap(err) - } - tlsConfig.RootCAs = rootCAs - } - return config, nil +func (p *AgentPool) getClusterCAs(_ context.Context) (*x509.CertPool, error) { + clusterCAs, _, err := auth.ClientCertPool(p.AccessPoint, p.Cluster, types.HostCA) + return clusterCAs, trace.Wrap(err) } // Wait blocks until the pool context is stopped. @@ -659,14 +631,25 @@ func (c *agentPoolRuntimeConfig) restrictConnectionCount() bool { return c.tunnelStrategyType == types.ProxyPeering } -// useReverseTunnelV2 returns true if reverse tunnel should be used. -func (c *agentPoolRuntimeConfig) useReverseTunnelV2() bool { +// alpnDialerConfig creates a config for ALPN dialer. +func (c *agentPoolRuntimeConfig) alpnDialerConfig(getClusterCAs client.GetClusterCAsFunc) client.ALPNDialerConfig { c.mu.RLock() defer c.mu.RUnlock() - if c.isRemoteCluster { - return false + + nextProtos := alpncommon.NextProtosWithPing(alpncommon.ProtocolReverseTunnel) + if !c.isRemoteCluster && c.tunnelStrategyType == types.ProxyPeering { + nextProtos = append([]string{string(alpncommon.ProtocolReverseTunnelV2)}, nextProtos...) + } + + return client.ALPNDialerConfig{ + TLSConfig: &tls.Config{ + NextProtos: nextProtos, + InsecureSkipVerify: lib.IsInsecureDevMode(), + }, + KeepAlivePeriod: c.keepAliveInterval, + ALPNConnUpgradeRequired: c.tlsRoutingConnUpgradeRequired, + GetClusterCAs: getClusterCAs, } - return c.tunnelStrategyType == types.ProxyPeering } // useALPNRouting returns true agents should connect using alpn routing. @@ -680,23 +663,6 @@ func (c *agentPoolRuntimeConfig) useALPNRouting() bool { return c.proxyListenerMode == types.ProxyListenerMode_Multiplex } -func (c *agentPoolRuntimeConfig) useALPNConnUpgrade() bool { - c.mu.RLock() - defer c.mu.RUnlock() - return c.tlsRoutingConnUpgradeRequired -} - -func (c *agentPoolRuntimeConfig) getKeepAliveInterval() time.Duration { - c.mu.RLock() - defer c.mu.RUnlock() - - // When behind a load balancer, use a shorter ping. - if c.tlsRoutingConnUpgradeRequired { - return utils.MinTTL(libdefaults.ProxyPingInterval, c.keepAliveInterval) - } - return c.keepAliveInterval -} - func (c *agentPoolRuntimeConfig) updateRemote(ctx context.Context, addr *utils.NetAddr) error { c.updateRemoteMu.Lock() defer c.updateRemoteMu.Unlock() diff --git a/lib/reversetunnel/transport.go b/lib/reversetunnel/transport.go index 89ccb78198c1c..5f5db7dbb28fd 100644 --- a/lib/reversetunnel/transport.go +++ b/lib/reversetunnel/transport.go @@ -87,7 +87,9 @@ type TunnelAuthDialer struct { // DialContext dials auth server via SSH tunnel func (t *TunnelAuthDialer) DialContext(ctx context.Context, _, _ string) (net.Conn, error) { // Connect to the reverse tunnel server. - opts := []proxy.DialerOptionFunc{} + opts := []proxy.DialerOptionFunc{ + proxy.WithInsecureSkipTLSVerify(t.InsecureSkipTLSVerify), + } addr, mode, err := t.Resolver(ctx) if err != nil { @@ -96,23 +98,14 @@ func (t *TunnelAuthDialer) DialContext(ctx context.Context, _, _ string) (net.Co } if mode == types.ProxyListenerMode_Multiplex { - tlsConfig := &tls.Config{ - NextProtos: []string{ - string(alpncommon.ProtocolWithPing(alpncommon.ProtocolReverseTunnel)), - string(alpncommon.ProtocolReverseTunnel), - }, - InsecureSkipVerify: t.InsecureSkipTLSVerify, - } - - alpnConnUpgradeRequired := client.IsALPNConnUpgradeRequired(addr.Addr, t.InsecureSkipTLSVerify) - if alpnConnUpgradeRequired { - tlsConfig.RootCAs = t.ClusterCAs - } - opts = append(opts, proxy.WithALPNDialer(client.ALPNDialerConfig{ - TLSConfig: tlsConfig, - ALPNConnUpgradeRequired: alpnConnUpgradeRequired, + TLSConfig: &tls.Config{ + NextProtos: alpncommon.NextProtosWithPing(alpncommon.ProtocolReverseTunnel), + InsecureSkipVerify: t.InsecureSkipTLSVerify, + }, DialTimeout: t.ClientConfig.Timeout, + ALPNConnUpgradeRequired: client.IsALPNConnUpgradeRequired(addr.Addr, t.InsecureSkipTLSVerify), + GetClusterCAs: client.ClusterCAsFromCertPool(t.ClusterCAs), })) } diff --git a/lib/service/service_test.go b/lib/service/service_test.go index e093055a0dd81..70b444d511320 100644 --- a/lib/service/service_test.go +++ b/lib/service/service_test.go @@ -485,6 +485,8 @@ func TestSetupProxyTLSConfig(t *testing.T) { "http/1.1", "h2", "acme-tls/1", + "teleport-proxy-ssh-ping", + "teleport-reversetunnel-ping", "teleport-tcp-ping", "teleport-postgres-ping", "teleport-mysql-ping", @@ -517,6 +519,8 @@ func TestSetupProxyTLSConfig(t *testing.T) { name: "ACME disabled", acmeEnabled: false, wantNextProtos: []string{ + "teleport-proxy-ssh-ping", + "teleport-reversetunnel-ping", "teleport-tcp-ping", "teleport-postgres-ping", "teleport-mysql-ping", diff --git a/lib/srv/alpnproxy/common/protocols.go b/lib/srv/alpnproxy/common/protocols.go index c69c07e146705..2f2f93291a85f 100644 --- a/lib/srv/alpnproxy/common/protocols.go +++ b/lib/srv/alpnproxy/common/protocols.go @@ -136,6 +136,12 @@ func ProtocolsToString(protocols []Protocol) []string { return out } +// NextProtosWithPing adds Ping protocols to provided list of ALPN protocols +// then converts them to a list of strings for tls.Config.NextProtos. +func NextProtosWithPing(protocols ...Protocol) []string { + return ProtocolsToString(WithPingProtocols(protocols)) +} + // ToALPNProtocol maps provided database protocol to ALPN protocol. func ToALPNProtocol(dbProtocol string) (Protocol, error) { switch dbProtocol { diff --git a/lib/srv/alpnproxy/common/protocols_test.go b/lib/srv/alpnproxy/common/protocols_test.go index 2268008eaa3e4..5873b05fdf520 100644 --- a/lib/srv/alpnproxy/common/protocols_test.go +++ b/lib/srv/alpnproxy/common/protocols_test.go @@ -27,16 +27,16 @@ func TestWithPingProtocols(t *testing.T) { []Protocol{ "teleport-tcp-ping", "teleport-redis-ping", - "teleport-reversetunnel", + "teleport-auth@", "teleport-tcp", "teleport-redis", "h2", }, WithPingProtocols([]Protocol{ - ProtocolReverseTunnel, + ProtocolAuth, ProtocolTCP, ProtocolRedisDB, - ProtocolReverseTunnel, + ProtocolAuth, ProtocolHTTP2, }), ) @@ -48,3 +48,21 @@ func TestIsDBTLSProtocol(t *testing.T) { require.False(t, IsDBTLSProtocol("teleport-tcp")) require.False(t, IsDBTLSProtocol("")) } + +func BenchmarkNextProtosWithPing(b *testing.B) { + b.Run("one with ping support", func(b *testing.B) { + for n := 0; n < b.N; n++ { + NextProtosWithPing(ProtocolReverseTunnel) + } + }) + b.Run("one without ping support", func(b *testing.B) { + for n := 0; n < b.N; n++ { + NextProtosWithPing(ProtocolHTTP) + } + }) + b.Run("five", func(b *testing.B) { + for n := 0; n < b.N; n++ { + NextProtosWithPing(ProtocolAuth, ProtocolTCP, ProtocolRedisDB, ProtocolAuth, ProtocolHTTP2) + } + }) +} diff --git a/lib/utils/proxy/proxy.go b/lib/utils/proxy/proxy.go index 734b62041d1b7..441b78394188e 100644 --- a/lib/utils/proxy/proxy.go +++ b/lib/utils/proxy/proxy.go @@ -27,7 +27,6 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/client" apiclient "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/observability/tracing" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" @@ -78,7 +77,7 @@ type Dialer interface { type directDial struct { // alpnDialer is the dialer used for TLS routing. - alpnDialer client.ContextDialer + alpnDialer apiclient.ContextDialer } // Dial calls ssh.Dial directly. @@ -99,11 +98,13 @@ func (d directDial) Dial(ctx context.Context, network string, addr string, confi // DialTimeout acts like Dial but takes a timeout. func (d directDial) DialTimeout(ctx context.Context, network, address string, timeout time.Duration) (net.Conn, error) { - dialer := d.alpnDialer - if dialer == nil { - dialer = &net.Dialer{ - Timeout: timeout, - } + if d.alpnDialer != nil { + conn, err := d.alpnDialer.DialContext(ctx, network, address) + return conn, trace.Wrap(err) + } + + dialer := &net.Dialer{ + Timeout: timeout, } conn, err := dialer.DialContext(ctx, network, address) @@ -116,8 +117,10 @@ func (d directDial) DialTimeout(ctx context.Context, network, address string, ti type proxyDial struct { // proxyHost is the HTTPS proxy address. proxyURL *url.URL + // insecure is whether to skip certificate validation. + insecure bool // alpnDialer is the dialer used for TLS routing. - alpnDialer client.ContextDialer + alpnDialer apiclient.ContextDialer } // DialTimeout acts like Dial but takes a timeout. @@ -128,12 +131,14 @@ func (d proxyDial) DialTimeout(ctx context.Context, network, address string, tim defer cancel() ctx = timeoutCtx } + + // ALPN dialer handles proxy URL internally. if d.alpnDialer != nil { tlsConn, err := d.alpnDialer.DialContext(ctx, network, address) return tlsConn, trace.Wrap(err) } - conn, err := apiclient.DialProxy(ctx, d.proxyURL, address) + conn, err := apiclient.DialProxy(ctx, d.proxyURL, address, apiclient.WithInsecureSkipVerify(d.insecure)) if err != nil { return nil, trace.Wrap(err) } @@ -144,13 +149,7 @@ func (d proxyDial) DialTimeout(ctx context.Context, network, address string, tim // SSH connection. func (d proxyDial) Dial(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) { // Build a proxy connection first. - var pconn net.Conn - var err error - if d.alpnDialer != nil { - pconn, err = d.alpnDialer.DialContext(ctx, network, addr) - } else { - pconn, err = apiclient.DialProxy(ctx, d.proxyURL, addr) - } + pconn, err := d.DialTimeout(ctx, network, addr, config.Timeout) if err != nil { return nil, trace.Wrap(err) } @@ -178,16 +177,23 @@ type dialerOptions struct { // insecureSkipTLSVerify is whether to skip certificate validation. insecureSkipTLSVerify bool // alpnDialer is the dialer used for TLS routing. - alpnDialer client.ContextDialer + alpnDialer apiclient.ContextDialer } // DialerOptionFunc allows setting options as functional arguments to DialerFromEnvironment type DialerOptionFunc func(options *dialerOptions) // WithALPNDialer creates a dialer that allows to Teleport running in single-port mode. -func WithALPNDialer(alpnDialerConfig client.ALPNDialerConfig) DialerOptionFunc { +func WithALPNDialer(alpnDialerConfig apiclient.ALPNDialerConfig) DialerOptionFunc { + return func(options *dialerOptions) { + options.alpnDialer = apiclient.NewALPNDialer(alpnDialerConfig) + } +} + +// WithInsecureSkipTLSVerify skips the certs verifications. +func WithInsecureSkipTLSVerify(insecure bool) DialerOptionFunc { return func(options *dialerOptions) { - options.alpnDialer = client.NewALPNDialer(alpnDialerConfig) + options.insecureSkipTLSVerify = insecure } } @@ -215,6 +221,7 @@ func DialerFromEnvironment(addr string, opts ...DialerOptionFunc) Dialer { log.Debugf("Found proxy %q in environment, returning proxy dialer.", proxyURL) return proxyDial{ proxyURL: proxyURL, + insecure: options.insecureSkipTLSVerify, alpnDialer: options.alpnDialer, } } diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index 82c98ae52a51a..a1ca2cd4312c2 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -242,11 +242,8 @@ func dialSSHProxy(ctx context.Context, tc *libclient.TeleportClient, sp sshProxy dialer = client.NewALPNDialer(client.ALPNDialerConfig{ TLSConfig: &tls.Config{ - RootCAs: pool, - NextProtos: []string{ - string(alpncommon.ProtocolWithPing(alpncommon.ProtocolProxySSH)), - string(alpncommon.ProtocolProxySSH), - }, + RootCAs: pool, + NextProtos: alpncommon.NextProtosWithPing(alpncommon.ProtocolProxySSH), InsecureSkipVerify: tc.InsecureSkipVerify, ServerName: sp.proxyHost, }, @@ -254,9 +251,7 @@ func dialSSHProxy(ctx context.Context, tc *libclient.TeleportClient, sp sshProxy }) default: - dialer = client.NewDialer(ctx, apidefaults.DefaultIOTimeout, apidefaults.DefaultIdleTimeout, client.WithTLSConfig(&tls.Config{ - InsecureSkipVerify: tc.InsecureSkipVerify, - })) + dialer = client.NewDialer(ctx, apidefaults.DefaultIOTimeout, apidefaults.DefaultIdleTimeout, client.WithInsecureSkipVerify(tc.InsecureSkipVerify)) } conn, err := dialer.DialContext(ctx, "tcp", remoteProxyAddr) From 29f493aee24bfb6a2d7aa835392dfa2fea3b9741 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Sat, 1 Apr 2023 16:53:38 -0400 Subject: [PATCH 14/27] add test --- api/client/alpn_test.go | 115 +++++++++++++++++++++ api/client/client.go | 6 +- api/client/contextdialer.go | 3 - integration/helpers/instance.go | 53 +++++----- integration/proxy/proxy_helpers.go | 47 +++++++-- integration/proxy/proxy_test.go | 104 +++++++++++++++++-- lib/reversetunnel/agentpool.go | 1 + lib/service/service.go | 4 +- lib/srv/alpnproxy/common/protocols.go | 7 ++ lib/srv/alpnproxy/common/protocols_test.go | 11 ++ lib/srv/alpnproxy/proxy.go | 4 - 11 files changed, 299 insertions(+), 56 deletions(-) create mode 100644 api/client/alpn_test.go diff --git a/api/client/alpn_test.go b/api/client/alpn_test.go new file mode 100644 index 0000000000000..92c8ddaedc1b7 --- /dev/null +++ b/api/client/alpn_test.go @@ -0,0 +1,115 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "context" + "crypto/tls" + "crypto/x509" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +func TestALPNDialer_getTLSConfig(t *testing.T) { + t.Parallel() + cas := x509.NewCertPool() + + tests := []struct { + name string + input ALPNDialerConfig + wantTLSConfig *tls.Config + wantError bool + }{ + { + name: "missing tls config", + input: ALPNDialerConfig{}, + wantError: true, + }, + { + name: "no update", + input: ALPNDialerConfig{ + TLSConfig: &tls.Config{ + ServerName: "example.com", + }, + }, + wantTLSConfig: &tls.Config{ + ServerName: "example.com", + }, + }, + { + name: "no update when upgrade required", + input: ALPNDialerConfig{ + TLSConfig: &tls.Config{ + ServerName: "example.com", + RootCAs: cas, + }, + ALPNConnUpgradeRequired: true, + }, + wantTLSConfig: &tls.Config{ + ServerName: "example.com", + RootCAs: cas, + }, + }, + { + name: "name updated", + input: ALPNDialerConfig{ + TLSConfig: &tls.Config{}, + }, + wantTLSConfig: &tls.Config{ + ServerName: "example.com", + }, + }, + { + name: "get cas failed", + input: ALPNDialerConfig{ + TLSConfig: &tls.Config{}, + ALPNConnUpgradeRequired: true, + GetClusterCAs: func(_ context.Context) (*x509.CertPool, error) { + return nil, trace.AccessDenied("fail it") + }, + }, + wantError: true, + }, + { + name: "cas updated", + input: ALPNDialerConfig{ + TLSConfig: &tls.Config{}, + ALPNConnUpgradeRequired: true, + GetClusterCAs: ClusterCAsFromCertPool(cas), + }, + wantTLSConfig: &tls.Config{ + ServerName: "example.com", + RootCAs: cas, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dialer := NewALPNDialer(test.input).(*ALPNDialer) + tlsConfig, err := dialer.getTLSConfig(context.Background(), "example.com:443") + if test.wantError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, test.wantTLSConfig, tlsConfig) + } + }) + } +} diff --git a/api/client/client.go b/api/client/client.go index bbfabdc89efba..79f1901767c1d 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -327,7 +327,7 @@ type ( } ) -// authConnect connects to the Teleport Auth Server directly. +// authConnect connects to the Teleport Auth Server directly or through Proxy. func authConnect(ctx context.Context, params connectParams) (*Client, error) { dialer := NewDialer(ctx, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, WithInsecureSkipVerify(params.cfg.InsecureAddressDiscovery), @@ -352,8 +352,8 @@ func IsWebProxyAndConnUpgradeRequired(ctx context.Context, targetAddr string, cf } func isWebProxy(ctx context.Context, targetAddr string, cfg *Config) bool { - if cfg.WebProxyAddr != "" && cfg.WebProxyAddr == targetAddr { - return true + if cfg.WebProxyAddr != "" { + return cfg.WebProxyAddr == targetAddr } _, err := webclient.Find(&webclient.Config{ Context: ctx, diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index 9099fa3ee6487..0265b71be2a7b 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -195,9 +195,6 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) Conte }, ALPNConnUpgradeRequired: params.cfg.IsALPNConnUpgradeRequired(tunnelAddr, insecure), GetClusterCAs: func(_ context.Context) (*x509.CertPool, error) { - if len(params.cfg.Credentials) == 0 { - return nil, trace.BadParameter("no credentials") - } tlsConfig, err := params.cfg.Credentials[0].TLSConfig() if err != nil { return nil, trace.Wrap(err) diff --git a/integration/helpers/instance.go b/integration/helpers/instance.go index 7ab15f6d1128f..2d9a36361a8cb 100644 --- a/integration/helpers/instance.go +++ b/integration/helpers/instance.go @@ -1255,6 +1255,8 @@ type ClientConfig struct { Stderr io.Writer // Stdout overrides standard output for the session Stdout io.Writer + // ALBAddr is the address to a local server that simulates a layer 7 load balancer. + ALBAddr string } // NewClientWithCreds creates client with credentials @@ -1280,12 +1282,16 @@ func (i *TeleInstance) NewUnauthenticatedClient(cfg ClientConfig) (tc *client.Te var webProxyAddr string var sshProxyAddr string - if cfg.Proxy == nil { - webProxyAddr = i.Web - sshProxyAddr = i.SSHProxy - } else { + switch { + case cfg.Proxy != nil: webProxyAddr = cfg.Proxy.WebAddr sshProxyAddr = cfg.Proxy.SSHAddr + case cfg.ALBAddr != "": + webProxyAddr = cfg.ALBAddr + sshProxyAddr = cfg.ALBAddr + default: + webProxyAddr = i.Web + sshProxyAddr = i.SSHProxy } fwdAgentMode := client.ForwardAgentNo @@ -1294,25 +1300,26 @@ func (i *TeleInstance) NewUnauthenticatedClient(cfg ClientConfig) (tc *client.Te } cconf := &client.Config{ - Username: cfg.Login, - Host: cfg.Host, - HostPort: cfg.Port, - HostLogin: cfg.Login, - InsecureSkipVerify: true, - KeysDir: keyDir, - SiteName: cfg.Cluster, - ForwardAgent: fwdAgentMode, - Labels: cfg.Labels, - WebProxyAddr: webProxyAddr, - SSHProxyAddr: sshProxyAddr, - InteractiveCommand: cfg.Interactive, - TLSRoutingEnabled: i.IsSinglePortSetup, - Tracer: tracing.NoopProvider().Tracer("test"), - EnableEscapeSequences: cfg.EnableEscapeSequences, - Stderr: cfg.Stderr, - Stdin: cfg.Stdin, - Stdout: cfg.Stdout, - NonInteractive: true, + Username: cfg.Login, + Host: cfg.Host, + HostPort: cfg.Port, + HostLogin: cfg.Login, + InsecureSkipVerify: true, + KeysDir: keyDir, + SiteName: cfg.Cluster, + ForwardAgent: fwdAgentMode, + Labels: cfg.Labels, + WebProxyAddr: webProxyAddr, + SSHProxyAddr: sshProxyAddr, + InteractiveCommand: cfg.Interactive, + TLSRoutingEnabled: i.IsSinglePortSetup, + TLSRoutingConnUpgradeRequired: cfg.ALBAddr != "", + Tracer: tracing.NoopProvider().Tracer("test"), + EnableEscapeSequences: cfg.EnableEscapeSequences, + Stderr: cfg.Stderr, + Stdin: cfg.Stdin, + Stdout: cfg.Stdout, + NonInteractive: true, } // JumpHost turns on jump host mode diff --git a/integration/proxy/proxy_helpers.go b/integration/proxy/proxy_helpers.go index dbe210ba3e086..e6c753ec6d71e 100644 --- a/integration/proxy/proxy_helpers.go +++ b/integration/proxy/proxy_helpers.go @@ -33,6 +33,7 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jackc/pgconn" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" v1 "k8s.io/api/core/v1" @@ -398,6 +399,25 @@ func withTrustedCluster() proxySuiteOptionsFunc { } } +func withTrustedClusterBehindALB() proxySuiteOptionsFunc { + return func(options *suiteOptions) { + originalSetup := options.updateRoleMappingFunc + + options.updateRoleMappingFunc = func(t *testing.T, suite *Suite) { + t.Helper() + + if originalSetup != nil { + originalSetup(t, suite) + } + require.NotNil(t, options.trustedCluster) + + albProxy := mustStartMockALBProxy(t, suite.root.Config.Proxy.WebAddr.Addr) + options.trustedCluster.SetProxyAddress(albProxy.Addr().String()) + options.trustedCluster.SetReverseTunnelAddress(albProxy.Addr().String()) + } + } +} + func mustRunPostgresQuery(t *testing.T, client *pgconn.PgConn) { result, err := client.Exec(context.Background(), "select 1").ReadAll() require.NoError(t, err) @@ -474,7 +494,7 @@ func mustCreateKubeConfigFile(t *testing.T, config clientcmdapi.Config) string { func mustCreateListener(t *testing.T) net.Listener { t.Helper() - listener, err := net.Listen("tcp", ":0") + listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) t.Cleanup(func() { @@ -594,7 +614,7 @@ type mockAWSALBProxy struct { cert tls.Certificate } -func (m *mockAWSALBProxy) serve(ctx context.Context, t *testing.T) { +func (m *mockAWSALBProxy) serve(ctx context.Context) { for { select { case <-ctx.Done(): @@ -604,26 +624,33 @@ func (m *mockAWSALBProxy) serve(ctx context.Context, t *testing.T) { conn, err := m.Accept() if err != nil { - if utils.IsOKNetworkError(err) { - return - } - require.NoError(t, err) + logrus.WithError(err).Debugf("Failed to accept conn.") return } go func() { + defer conn.Close() + // Handshake with incoming client and drops ALPN. downstreamConn := tls.Server(conn, &tls.Config{ Certificates: []tls.Certificate{m.cert}, }) - require.NoError(t, downstreamConn.HandshakeContext(ctx)) + + // api.Client may try different connection methods. Just close the + // connection when something goes wrong. + if err := downstreamConn.HandshakeContext(ctx); err != nil { + logrus.WithError(err).Debugf("Failed to handshake.") + return + } // Make a connection to the proxy server with ALPN protos. upstreamConn, err := tls.Dial("tcp", m.proxyAddr, &tls.Config{ InsecureSkipVerify: true, }) - require.NoError(t, err) - + if err != nil { + logrus.WithError(err).Debugf("Failed to dial upstream.") + return + } utils.ProxyConn(ctx, downstreamConn, upstreamConn) }() } @@ -640,7 +667,7 @@ func mustStartMockALBProxy(t *testing.T, proxyAddr string) *mockAWSALBProxy { Listener: mustCreateListener(t), cert: mustCreateSelfSignedCert(t), } - go m.serve(ctx, t) + go m.serve(ctx) return m } diff --git a/integration/proxy/proxy_test.go b/integration/proxy/proxy_test.go index b4c82311c1174..b6ed57715f100 100644 --- a/integration/proxy/proxy_test.go +++ b/integration/proxy/proxy_test.go @@ -21,6 +21,7 @@ import ( "context" "crypto/x509" "crypto/x509/pkix" + "fmt" "net" "net/http" "net/http/httptest" @@ -69,6 +70,7 @@ func TestALPNSNIProxyMultiCluster(t *testing.T) { secondClusterPortSetup helpers.InstanceListenerSetupFunc disableALPNListenerOnRoot bool disableALPNListenerOnLeaf bool + testALPNConnUpgrade bool }{ { name: "StandardAndOnePortSetupMasterALPNDisabled", @@ -85,17 +87,20 @@ func TestALPNSNIProxyMultiCluster(t *testing.T) { name: "TwoClusterOnePortSetup", mainClusterPortSetup: helpers.SingleProxyPortSetup, secondClusterPortSetup: helpers.SingleProxyPortSetup, + testALPNConnUpgrade: true, }, { name: "OnePortAndStandardListenerSetupLeafALPNDisabled", mainClusterPortSetup: helpers.SingleProxyPortSetup, secondClusterPortSetup: helpers.StandardListenerSetup, disableALPNListenerOnLeaf: true, + testALPNConnUpgrade: true, }, { name: "OnePortAndStandardListenerSetup", mainClusterPortSetup: helpers.SingleProxyPortSetup, secondClusterPortSetup: helpers.StandardListenerSetup, + testALPNConnUpgrade: true, }, } @@ -132,6 +137,31 @@ func TestALPNSNIProxyMultiCluster(t *testing.T) { Host: helpers.Loopback, Port: helpers.Port(t, suite.leaf.SSH), }) + + if tc.testALPNConnUpgrade { + t.Run("ALPN conn upgrade", func(t *testing.T) { + // Make a mock ALB which points to the Teleport Proxy Service. + albProxy := mustStartMockALBProxy(t, suite.root.Config.Proxy.WebAddr.Addr) + + // Run command in root. + suite.mustConnectToClusterAndRunSSHCommand(t, helpers.ClientConfig{ + Login: username, + Cluster: suite.root.Secrets.SiteName, + Host: helpers.Loopback, + Port: helpers.Port(t, suite.root.SSH), + ALBAddr: albProxy.Addr().String(), + }) + + // Run command in leaf. + suite.mustConnectToClusterAndRunSSHCommand(t, helpers.ClientConfig{ + Login: username, + Cluster: suite.leaf.Secrets.SiteName, + Host: helpers.Loopback, + Port: helpers.Port(t, suite.leaf.SSH), + ALBAddr: albProxy.Addr().String(), + }) + }) + } }) } } @@ -144,6 +174,7 @@ func TestALPNSNIProxyTrustedClusterNode(t *testing.T) { secondClusterListenerSetup helpers.InstanceListenerSetupFunc disableALPNListenerOnRoot bool disableALPNListenerOnLeaf bool + extraSuiteOptions []proxySuiteOptionsFunc }{ { name: "StandardAndOnePortSetupMasterALPNDisabled", @@ -172,6 +203,12 @@ func TestALPNSNIProxyTrustedClusterNode(t *testing.T) { mainClusterListenerSetup: helpers.SingleProxyPortSetup, secondClusterListenerSetup: helpers.StandardListenerSetup, }, + { + name: "TrustedClusterBehindALB", + mainClusterListenerSetup: helpers.SingleProxyPortSetup, + secondClusterListenerSetup: helpers.SingleProxyPortSetup, + extraSuiteOptions: []proxySuiteOptionsFunc{withTrustedClusterBehindALB()}, + }, } for _, tc := range testCase { t.Run(tc.name, func(t *testing.T) { @@ -180,7 +217,7 @@ func TestALPNSNIProxyTrustedClusterNode(t *testing.T) { username := helpers.MustGetCurrentUser(t).Username - suite := newSuite(t, + opts := []proxySuiteOptionsFunc{ withRootClusterConfig(rootClusterStandardConfig(t)), withLeafClusterConfig(leafClusterStandardConfig(t)), withRootClusterListeners(tc.mainClusterListenerSetup), @@ -189,7 +226,8 @@ func TestALPNSNIProxyTrustedClusterNode(t *testing.T) { withLeafClusterRoles(newRole(t, "auxdevs", username)), withRootAndLeafTrustedClusterReset(), withTrustedCluster(), - ) + } + suite := newSuite(t, append(opts, tc.extraSuiteOptions...)...) nodeHostname := "clusterauxnode" suite.addNodeToLeafCluster(t, "clusterauxnode") @@ -1172,16 +1210,60 @@ func TestALPNProxyAuthClientConnectWithUserIdentity(t *testing.T) { identity := client.LoadIdentityFile(identityFilePath) require.NoError(t, err) - tc, err := client.New(context.Background(), client.Config{ - Addrs: []string{rc.Web}, - Credentials: []client.Credentials{identity}, - InsecureAddressDiscovery: true, - }) - require.NoError(t, err) + // Make a mock ALB which points to the Teleport Proxy Service. Then + // client can point to this ALB instead. + albProxy := mustStartMockALBProxy(t, rc.Web) - resp, err := tc.Ping(context.Background()) - require.NoError(t, err) - require.Equal(t, rc.Secrets.SiteName, resp.ClusterName) + connectTypes := []struct { + name string + clientConfig client.Config + }{ + { + name: "sync", + clientConfig: client.Config{ + Credentials: []client.Credentials{identity}, + InsecureAddressDiscovery: true, + }, + }, + { + name: "background", + clientConfig: client.Config{ + Credentials: []client.Credentials{identity}, + InsecureAddressDiscovery: true, + DialInBackground: true, + ALPNSNIAuthDialClusterName: cfg.ClusterName, + }, + }, + } + + targets := []struct { + name string + addr string + }{ + { + name: "web", + addr: rc.Web, + }, + { + name: "alb", + addr: albProxy.Addr().String(), + }, + } + + for _, target := range targets { + for _, connectType := range connectTypes { + t.Run(fmt.Sprintf("%v/%v", target.name, connectType.name), func(t *testing.T) { + connectType.clientConfig.Addrs = []string{target.addr} + + tc, err := client.New(context.Background(), connectType.clientConfig) + require.NoError(t, err) + + resp, err := tc.Ping(context.Background()) + require.NoError(t, err) + require.Equal(t, rc.Secrets.SiteName, resp.ClusterName) + }) + } + } } // TestALPNProxyDialProxySSHWithoutInsecureMode tests dialing to the localhost with teleport-proxy-ssh diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index ea89ee182ad95..bcf9254765c7d 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -717,6 +717,7 @@ func (c *agentPoolRuntimeConfig) updateRemote(ctx context.Context, addr *utils.N c.remoteTLSRoutingEnabled = tlsRoutingEnabled if c.remoteTLSRoutingEnabled { c.tlsRoutingConnUpgradeRequired = client.IsALPNConnUpgradeRequired(addr.Addr, lib.IsInsecureDevMode()) + logrus.Debugf("ALPN upgrade required for remote %v: %v", addr.Addr, c.tlsRoutingConnUpgradeRequired) } return nil } diff --git a/lib/service/service.go b/lib/service/service.go index 3c44e3fb129bb..b9b6ca48cd387 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4390,7 +4390,7 @@ func (process *TeleportProcess) setupProxyTLSConfig(conn *Connector, tsrv revers func setupTLSConfigALPNProtocols(tlsConfig *tls.Config) { // Go 1.17 introduced strict ALPN https://golang.org/doc/go1.17#ALPN If a client protocol is not recognized // the TLS handshake will fail. - tlsConfig.NextProtos = apiutils.Deduplicate(append(tlsConfig.NextProtos, alpncommon.ProtocolsToString(alpncommon.SupportedProtocols)...)) + alpncommon.AddNextProtos(tlsConfig, alpncommon.SupportedProtocols...) } func setupTLSConfigClientCAsForCluster(tlsConfig *tls.Config, accessPoint auth.ReadProxyAccessPoint, clusterName string) { @@ -4491,7 +4491,7 @@ func setupALPNRouter(listeners *proxyListeners, serverTLSConfig *tls.Config, cfg router.Add(alpnproxy.HandlerDecs{ MatchFunc: alpnproxy.MatchByProtocol(alpncommon.ProtocolProxySSH), Handler: sshProxyListener.HandleConnection, - TLSConfig: serverTLSConfig, + TLSConfig: alpncommon.AddNextProtos(serverTLSConfig.Clone(), alpncommon.ProtocolProxySSH), }) listeners.ssh = sshProxyListener diff --git a/lib/srv/alpnproxy/common/protocols.go b/lib/srv/alpnproxy/common/protocols.go index 2f2f93291a85f..4003fb59c3da9 100644 --- a/lib/srv/alpnproxy/common/protocols.go +++ b/lib/srv/alpnproxy/common/protocols.go @@ -17,6 +17,7 @@ limitations under the License. package common import ( + "crypto/tls" "strings" "github.com/gravitational/trace" @@ -142,6 +143,12 @@ func NextProtosWithPing(protocols ...Protocol) []string { return ProtocolsToString(WithPingProtocols(protocols)) } +// AddNextProtos adds ALPN protocols to the provided tls.Config. +func AddNextProtos(config *tls.Config, protocols ...Protocol) *tls.Config { + config.NextProtos = utils.Deduplicate(append(config.NextProtos, NextProtosWithPing(protocols...)...)) + return config +} + // ToALPNProtocol maps provided database protocol to ALPN protocol. func ToALPNProtocol(dbProtocol string) (Protocol, error) { switch dbProtocol { diff --git a/lib/srv/alpnproxy/common/protocols_test.go b/lib/srv/alpnproxy/common/protocols_test.go index 5873b05fdf520..614db9f14399a 100644 --- a/lib/srv/alpnproxy/common/protocols_test.go +++ b/lib/srv/alpnproxy/common/protocols_test.go @@ -17,6 +17,7 @@ limitations under the License. package common import ( + "crypto/tls" "testing" "github.com/stretchr/testify/require" @@ -49,6 +50,16 @@ func TestIsDBTLSProtocol(t *testing.T) { require.False(t, IsDBTLSProtocol("")) } +func TestAddNextProtos(t *testing.T) { + input := &tls.Config{ + NextProtos: []string{"proto1", "proto2"}, + } + want := &tls.Config{ + NextProtos: []string{"proto1", "proto2", "teleport-proxy-ssh-ping", "teleport-proxy-ssh"}, + } + require.Equal(t, want, AddNextProtos(input, ProtocolProxySSH)) +} + func BenchmarkNextProtosWithPing(b *testing.B) { b.Run("one with ping support", func(b *testing.B) { for n := 0; n < b.N; n++ { diff --git a/lib/srv/alpnproxy/proxy.go b/lib/srv/alpnproxy/proxy.go index 73df93271966b..1f81184d4f148 100644 --- a/lib/srv/alpnproxy/proxy.go +++ b/lib/srv/alpnproxy/proxy.go @@ -209,10 +209,6 @@ func (h *HandlerDecs) CheckAndSetDefaults() error { return trace.BadParameter("the ForwardTLS flag and TLSConfig can't be used at the same time") } - if h.TLSConfig != nil && len(h.TLSConfig.NextProtos) == 0 { - h.TLSConfig = h.TLSConfig.Clone() - h.TLSConfig.NextProtos = common.ProtocolsToString(common.SupportedProtocols) - } return nil } From 7a8a30952a5dd46a7604e33b12b475fcc5d518c0 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Sun, 2 Apr 2023 14:27:57 -0400 Subject: [PATCH 15/27] beautify round 2 --- api/client/alpn.go | 43 +++++++++++++++++------------- api/client/client.go | 12 ++++----- api/client/contextdialer.go | 40 +++++++++++++++++++++------ api/client/proxy.go | 23 ++-------------- integration/proxy/proxy_helpers.go | 3 +++ integration/proxy/proxy_test.go | 4 +-- lib/client/api.go | 8 +++--- 7 files changed, 75 insertions(+), 58 deletions(-) diff --git a/api/client/alpn.go b/api/client/alpn.go index 7e848d3a0a7b5..696996de866e6 100644 --- a/api/client/alpn.go +++ b/api/client/alpn.go @@ -76,40 +76,36 @@ func NewALPNDialer(cfg ALPNDialerConfig) ContextDialer { } func (d *ALPNDialer) shouldUpdateTLSConfig() bool { - switch { - case d.cfg.TLSConfig.ServerName == "": - return true - case d.cfg.ALPNConnUpgradeRequired && d.cfg.TLSConfig.RootCAs == nil && d.cfg.GetClusterCAs != nil: - return true - default: - return false - } + return d.shouldUpdateServerName() || d.shouldGetClusterCAs() +} +func (d *ALPNDialer) shouldUpdateServerName() bool { + return d.cfg.TLSConfig.ServerName == "" +} +func (d *ALPNDialer) shouldGetClusterCAs() bool { + return d.cfg.ALPNConnUpgradeRequired && d.cfg.TLSConfig.RootCAs == nil && d.cfg.GetClusterCAs != nil } func (d *ALPNDialer) getTLSConfig(ctx context.Context, addr string) (*tls.Config, error) { if d.cfg.TLSConfig == nil { return nil, trace.BadParameter("missing TLS config") } - if !d.shouldUpdateTLSConfig() { return d.cfg.TLSConfig, nil } + var err error tlsConfig := d.cfg.TLSConfig.Clone() - if d.cfg.ALPNConnUpgradeRequired && d.cfg.TLSConfig.RootCAs == nil && d.cfg.GetClusterCAs != nil { - rootCAs, err := d.cfg.GetClusterCAs(ctx) + if d.shouldGetClusterCAs() { + tlsConfig.RootCAs, err = d.cfg.GetClusterCAs(ctx) if err != nil { return nil, trace.Wrap(err) } - tlsConfig.RootCAs = rootCAs } - - if d.cfg.TLSConfig.ServerName == "" { - host, _, err := webclient.ParseHostPort(addr) + if d.shouldUpdateServerName() { + tlsConfig.ServerName, _, err = webclient.ParseHostPort(addr) if err != nil { return nil, trace.Wrap(err) } - tlsConfig.ServerName = host } return tlsConfig, nil } @@ -137,8 +133,8 @@ func (d *ALPNDialer) DialContext(ctx context.Context, network, addr string) (net return nil, trace.Wrap(err) } - if strings.HasSuffix(tlsConn.ConnectionState().NegotiatedProtocol, constants.ALPNSNIProtocolPingSuffix) { - logrus.Debug("Using ping connection") + if IsALPNPingProtocol(tlsConn.ConnectionState().NegotiatedProtocol) { + logrus.Debugf("Using ping connection for protocol %v.", tlsConn.ConnectionState().NegotiatedProtocol) return pingconn.New(tlsConn), nil } return tlsConn, nil @@ -149,3 +145,14 @@ func DialALPN(ctx context.Context, addr string, cfg ALPNDialerConfig) (net.Conn, conn, err := NewALPNDialer(cfg).DialContext(ctx, "tcp", addr) return conn, trace.Wrap(err) } + +// ALPNSNIProtocolPingSuffix receives an ALPN protocol and returns it with the +// Ping protocol suffix. +func ALPNProtocolWithPing(protocol string) string { + return protocol + constants.ALPNSNIProtocolPingSuffix +} + +// IsALPNPingProtocol checks if the provided protocol is suffixed with Ping. +func IsALPNPingProtocol(protocol string) bool { + return strings.HasSuffix(protocol, constants.ALPNSNIProtocolPingSuffix) +} diff --git a/api/client/client.go b/api/client/client.go index 79f1901767c1d..26141514c375c 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -341,10 +341,10 @@ func authConnect(ctx context.Context, params connectParams) (*Client, error) { return clt, nil } -// IsWebProxyAndConnUpgradeRequired returns if targetAddr is a Teleport web -// proxy address and ALPN connection upgrade is required for it. If no cluster -// name is provided for ALPN, assume dialer is not trying to connect Auth -// through Proxy using TLS routing. +// IsWebProxyAndConnUpgradeRequired returns true if targetAddr is a Teleport +// web proxy address and ALPN connection upgrade is required for it. If no +// cluster name is provided for ALPN, assume dialer is not trying to connect +// Auth through Proxy using TLS Routing. func IsWebProxyAndConnUpgradeRequired(ctx context.Context, targetAddr string, cfg *Config) bool { return cfg.ALPNSNIAuthDialClusterName != "" && isWebProxy(ctx, targetAddr, cfg) && @@ -368,7 +368,7 @@ func tunnelConnect(ctx context.Context, params connectParams) (*Client, error) { if params.sshConfig == nil { return nil, trace.BadParameter("must provide ssh client config") } - dialer := newTunnelDialer(*params.sshConfig, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, WithTLSConfig(params.tlsConfig)) + dialer := newTunnelDialer(*params.sshConfig, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, WithInsecureSkipVerify(params.cfg.InsecureAddressDiscovery)) clt := newClient(params.cfg, dialer, params.tlsConfig) if err := clt.dialGRPC(ctx, params.addr); err != nil { return nil, trace.Wrap(err, "failed to connect to addr %v as a reverse tunnel proxy", params.addr) @@ -381,7 +381,7 @@ func proxyConnect(ctx context.Context, params connectParams) (*Client, error) { if params.sshConfig == nil { return nil, trace.BadParameter("must provide ssh client config") } - dialer := NewProxyDialer(*params.sshConfig, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, params.addr, params.cfg.InsecureAddressDiscovery, WithTLSConfig(params.tlsConfig)) + dialer := NewProxyDialer(*params.sshConfig, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, params.addr, params.cfg.InsecureAddressDiscovery, WithInsecureSkipVerify(params.cfg.InsecureAddressDiscovery)) clt := newClient(params.cfg, dialer, params.tlsConfig) if err := clt.dialGRPC(ctx, params.addr); err != nil { return nil, trace.Wrap(err, "failed to connect to addr %v as a web proxy", params.addr) diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index 0265b71be2a7b..7d6caa8d172e2 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -36,6 +36,30 @@ import ( "github.com/gravitational/teleport/api/utils/sshutils" ) +type dialConfig struct { + tlsConfig *tls.Config + alpnConnUpgradeRequired bool +} + +// WithInsecureSkipVerify specifies if dialing insecure when using an HTTPS proxy. +func WithInsecureSkipVerify(insecure bool) DialOption { + return func(cfg *dialProxyConfig) { + cfg.tlsConfig = &tls.Config{ + InsecureSkipVerify: insecure, + } + } +} + +// WithALPNConnUpgrade specifies if ALPN connection upgrade is required. +func WithALPNConnUpgrade(alpnConnUpgradeRequired bool) DialOption { + return func(cfg *dialProxyConfig) { + cfg.alpnConnUpgradeRequired = alpnConnUpgradeRequired + } +} + +// DialOption allows setting options as functional arguments to api.NewDialer. +type DialOption func(cfg *dialConfig) + // ContextDialer represents network dialer interface that uses context type ContextDialer interface { // DialContext is a function that dials the specified address @@ -86,8 +110,8 @@ func tracedDialer(ctx context.Context, fn ContextDialerFunc) ContextDialerFunc { // NewDialer makes a new dialer that connects to an Auth server either directly or via an HTTP proxy, depending // on the environment. -func NewDialer(ctx context.Context, keepAlivePeriod, dialTimeout time.Duration, opts ...DialProxyOption) ContextDialer { - var cfg dialProxyConfig +func NewDialer(ctx context.Context, keepAlivePeriod, dialTimeout time.Duration, opts ...DialOption) ContextDialer { + var cfg dialConfig for _, opt := range opts { opt(&cfg) } @@ -98,12 +122,12 @@ func NewDialer(ctx context.Context, keepAlivePeriod, dialTimeout time.Duration, // Base direct dialer. var dialer ContextDialer = netDialer - // Wrap with proxy URL dialer if proxy URL is detected + // Wrap with proxy URL dialer if proxy URL is detected. if proxyURL := utils.GetProxyURL(addr); proxyURL != nil { dialer = newProxyURLDialer(proxyURL, netDialer, opts...) } - // Wrap with alpnConnUpgradeDialer if upgrade is required. + // Wrap with alpnConnUpgradeDialer if upgrade is required for TLS Routing. if cfg.alpnConnUpgradeRequired { dialer = newALPNConnUpgradeDialer(dialer, cfg.tlsConfig) } @@ -162,7 +186,7 @@ func newTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Dur // newTLSRoutingTunnelDialer makes a reverse tunnel TLS Routing dialer to connect to an Auth server // through the SSH reverse tunnel on the proxy. func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) ContextDialer { - return ContextDialerFunc(func(ctx context.Context, network, addr string) (conn net.Conn, err error) { + return ContextDialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { insecure := params.cfg.InsecureAddressDiscovery resp, err := webclient.Find(&webclient.Config{Context: ctx, ProxyAddr: params.addr, Insecure: insecure}) if err != nil { @@ -182,12 +206,12 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) Conte if err != nil { return nil, trace.Wrap(err) } - tlsConn, err := DialALPN(ctx, tunnelAddr, ALPNDialerConfig{ + conn, err := DialALPN(ctx, tunnelAddr, ALPNDialerConfig{ DialTimeout: params.cfg.DialTimeout, KeepAlivePeriod: params.cfg.KeepAlivePeriod, TLSConfig: &tls.Config{ NextProtos: []string{ - constants.ALPNSNIProtocolReverseTunnel + constants.ALPNSNIProtocolPingSuffix, + ALPNProtocolWithPing(constants.ALPNSNIProtocolReverseTunnel), constants.ALPNSNIProtocolReverseTunnel, }, InsecureSkipVerify: insecure, @@ -206,7 +230,7 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) Conte return nil, trace.Wrap(err) } - sconn, err := sshConnect(ctx, tlsConn, ssh, params.cfg.DialTimeout, tunnelAddr) + sconn, err := sshConnect(ctx, conn, ssh, params.cfg.DialTimeout, tunnelAddr) if err != nil { return nil, trace.Wrap(err) } diff --git a/api/client/proxy.go b/api/client/proxy.go index b546301760279..500217fd9941e 100644 --- a/api/client/proxy.go +++ b/api/client/proxy.go @@ -29,13 +29,10 @@ import ( "golang.org/x/net/proxy" ) -type dialProxyConfig struct { - tlsConfig *tls.Config - alpnConnUpgradeRequired bool -} +type dialProxyConfig = dialConfig // DialProxyOption allows setting options as functional arguments to DialProxy. -type DialProxyOption func(cfg *dialProxyConfig) +type DialProxyOption = DialOption // WithTLSConfig provides the dialer with the TLS config to use when using an // HTTPS proxy. @@ -45,22 +42,6 @@ func WithTLSConfig(tlsConfig *tls.Config) DialProxyOption { } } -// WithInsecureSkipVerify specifies if dialing insecure when using an HTTPS proxy. -func WithInsecureSkipVerify(insecure bool) DialProxyOption { - return func(cfg *dialProxyConfig) { - cfg.tlsConfig = &tls.Config{ - InsecureSkipVerify: insecure, - } - } -} - -// WithALPNConnUpgrade specifies if ALPN connection upgrade is required. -func WithALPNConnUpgrade(alpnConnUpgradeRequired bool) DialProxyOption { - return func(cfg *dialProxyConfig) { - cfg.alpnConnUpgradeRequired = alpnConnUpgradeRequired - } -} - // DialProxy creates a connection to a server via an HTTP or SOCKS5 Proxy. func DialProxy(ctx context.Context, proxyURL *url.URL, addr string, opts ...DialProxyOption) (net.Conn, error) { return DialProxyWithDialer(ctx, proxyURL, addr, &net.Dialer{}, opts...) diff --git a/integration/proxy/proxy_helpers.go b/integration/proxy/proxy_helpers.go index e6c753ec6d71e..7e1dff01fbd0e 100644 --- a/integration/proxy/proxy_helpers.go +++ b/integration/proxy/proxy_helpers.go @@ -399,6 +399,9 @@ func withTrustedCluster() proxySuiteOptionsFunc { } } +// withTrustedClusterBehindALB creates a local server that simulates a layer 7 +// LB and puts it infront of the root cluster when the leaf connects through +// the reverse tunnel. func withTrustedClusterBehindALB() proxySuiteOptionsFunc { return func(options *suiteOptions) { originalSetup := options.updateRoleMappingFunc diff --git a/integration/proxy/proxy_test.go b/integration/proxy/proxy_test.go index b6ed57715f100..0c3fb02c5abd6 100644 --- a/integration/proxy/proxy_test.go +++ b/integration/proxy/proxy_test.go @@ -143,7 +143,7 @@ func TestALPNSNIProxyMultiCluster(t *testing.T) { // Make a mock ALB which points to the Teleport Proxy Service. albProxy := mustStartMockALBProxy(t, suite.root.Config.Proxy.WebAddr.Addr) - // Run command in root. + // Run command in root through ALB address. suite.mustConnectToClusterAndRunSSHCommand(t, helpers.ClientConfig{ Login: username, Cluster: suite.root.Secrets.SiteName, @@ -152,7 +152,7 @@ func TestALPNSNIProxyMultiCluster(t *testing.T) { ALBAddr: albProxy.Addr().String(), }) - // Run command in leaf. + // Run command in leaf through ALB address. suite.mustConnectToClusterAndRunSSHCommand(t, helpers.ClientConfig{ Login: username, Cluster: suite.leaf.Secrets.SiteName, diff --git a/lib/client/api.go b/lib/client/api.go index 813ec81ff0618..8848e57718dbb 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -4451,12 +4451,13 @@ func (tc *TeleportClient) NewKubernetesServiceClient(ctx context.Context, cluste tlsConfig.NextProtos = []string{string(alpncommon.ProtocolProxyGRPCSecure), http2.NextProtoTLS} clt, err := client.New(ctx, client.Config{ - Addrs: []string{tc.Config.WebProxyAddr}, + WebProxyAddr: tc.Config.WebProxyAddr, DialInBackground: false, Credentials: []client.Credentials{ client.LoadTLS(tlsConfig), }, - IsALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequired, + ALPNSNIAuthDialClusterName: clusterName, + IsALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequired, }) if err != nil { return nil, trace.Wrap(err) @@ -4464,7 +4465,8 @@ func (tc *TeleportClient) NewKubernetesServiceClient(ctx context.Context, cluste return kubeproto.NewKubeServiceClient(clt.GetConnection()), nil } -// IsALPNConnUpgradeRequired returns true if connection upgrade is required for provided addr. +// IsALPNConnUpgradeRequired returns true if connection upgrade is required for +// provided addr. func (tc *TeleportClient) IsALPNConnUpgradeRequired(addr string, insecure bool) bool { // Use cached value. if addr == tc.WebProxyAddr { From 511b9d068a1a2779a15185b99193241cb657c85a Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Sun, 2 Apr 2023 15:08:39 -0400 Subject: [PATCH 16/27] fix timeout --- tool/tsh/proxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index a1ca2cd4312c2..07679f1ec7921 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -251,7 +251,7 @@ func dialSSHProxy(ctx context.Context, tc *libclient.TeleportClient, sp sshProxy }) default: - dialer = client.NewDialer(ctx, apidefaults.DefaultIOTimeout, apidefaults.DefaultIdleTimeout, client.WithInsecureSkipVerify(tc.InsecureSkipVerify)) + dialer = client.NewDialer(ctx, apidefaults.DefaultIdleTimeout, apidefaults.DefaultIOTimeout, client.WithInsecureSkipVerify(tc.InsecureSkipVerify)) } conn, err := dialer.DialContext(ctx, "tcp", remoteProxyAddr) From ff3c8699a830eafe364180f6ba80c711533c4e32 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Thu, 6 Apr 2023 12:22:15 -0400 Subject: [PATCH 17/27] Implement alpn-ping upgrade for reversetunnel and ssh --- api/client/alpn.go | 33 ++++++-- api/client/alpn_conn_upgrade.go | 34 ++++++--- api/client/alpn_conn_upgrade_test.go | 104 +++++++++++++++++--------- api/client/client.go | 21 +----- api/client/contextdialer.go | 24 ++++-- api/client/proxy/client.go | 28 ++++++- api/constants/constants.go | 10 +++ lib/client/api.go | 8 +- lib/reversetunnel/agentpool.go | 6 +- lib/reversetunnel/transport.go | 2 +- lib/service/service.go | 2 +- lib/service/service_test.go | 4 - lib/srv/alpnproxy/common/protocols.go | 2 - lib/web/conn_upgrade.go | 41 ++++++++++ lib/web/conn_upgrade_test.go | 59 ++++++++++----- tool/tsh/proxy.go | 2 +- 16 files changed, 267 insertions(+), 113 deletions(-) diff --git a/api/client/alpn.go b/api/client/alpn.go index 83e9f1099d52e..91eb8d0e8e794 100644 --- a/api/client/alpn.go +++ b/api/client/alpn.go @@ -81,6 +81,7 @@ func (d *ALPNDialer) shouldUpdateTLSConfig() bool { func (d *ALPNDialer) shouldUpdateServerName() bool { return d.cfg.TLSConfig.ServerName == "" } + func (d *ALPNDialer) shouldGetClusterCAs() bool { return d.cfg.ALPNConnUpgradeRequired && d.cfg.TLSConfig.RootCAs == nil && d.cfg.GetClusterCAs != nil } @@ -95,12 +96,20 @@ func (d *ALPNDialer) getTLSConfig(ctx context.Context, addr string) (*tls.Config var err error tlsConfig := d.cfg.TLSConfig.Clone() + + // When Teleport Proxy is behind a L7 load balancer, the load balancer + // usually terminates TLS with public certs, and the Proxy is usually in + // private subnets with self-signed web certs. During the connection + // upgrade flow for TLS Routing, instead of serving these self-signed web + // certs, the TLS Routing handler at the Proxy server will present the + // Cluster CAs so clients here can still verify the server. if d.shouldGetClusterCAs() { tlsConfig.RootCAs, err = d.cfg.GetClusterCAs(ctx) if err != nil { return nil, trace.Wrap(err) } } + if d.shouldUpdateServerName() { tlsConfig.ServerName, _, err = webclient.ParseHostPort(addr) if err != nil { @@ -120,6 +129,7 @@ func (d *ALPNDialer) DialContext(ctx context.Context, network, addr string) (net dialer := NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout, WithInsecureSkipVerify(d.cfg.TLSConfig.InsecureSkipVerify), WithALPNConnUpgrade(d.cfg.ALPNConnUpgradeRequired), + WithALPNConnUpgradePing(shouldALPNConnUpgradeWithPing(tlsConfig)), ) conn, err := dialer.DialContext(ctx, network, addr) @@ -146,13 +156,24 @@ func DialALPN(ctx context.Context, addr string, cfg ALPNDialerConfig) (net.Conn, return conn, trace.Wrap(err) } -// ALPNSNIProtocolPingSuffix receives an ALPN protocol and returns it with the -// Ping protocol suffix. -func ALPNProtocolWithPing(protocol string) string { - return protocol + constants.ALPNSNIProtocolPingSuffix -} - // IsALPNPingProtocol checks if the provided protocol is suffixed with Ping. func IsALPNPingProtocol(protocol string) bool { return strings.HasSuffix(protocol, constants.ALPNSNIProtocolPingSuffix) } + +// shouldALPNConnUpgradeWithPing returns true if Ping wrapper is required +// during connection upgrade. +func shouldALPNConnUpgradeWithPing(config *tls.Config) bool { + for _, proto := range config.NextProtos { + switch proto { + // Server usually sends SSH keepalives or HTTP2 pings every five + // minutes for reverse tunnel and SSH connections. Load balancers + // usually have a shorter idle timeout. Thus wrapping the connection + // with Ping protocol at the connection upgrade layer to keepalive. + case constants.ALPNSNIProtocolReverseTunnel, + constants.ALPNSNIProtocolSSH: + return true + } + } + return false +} diff --git a/api/client/alpn_conn_upgrade.go b/api/client/alpn_conn_upgrade.go index c82f91dd7ff63..9bf33ce4df224 100644 --- a/api/client/alpn_conn_upgrade.go +++ b/api/client/alpn_conn_upgrade.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/api/utils/pingconn" ) // IsALPNConnUpgradeRequired returns true if a tunnel is required through a HTTP @@ -145,13 +146,15 @@ func isALPNConnUpgradeRequiredByEnv(addr, envValue string) bool { type alpnConnUpgradeDialer struct { dialer ContextDialer tlsConfig *tls.Config + withPing bool } // newALPNConnUpgradeDialer creates a new alpnConnUpgradeDialer. -func newALPNConnUpgradeDialer(dialer ContextDialer, tlsConfig *tls.Config) ContextDialer { +func newALPNConnUpgradeDialer(dialer ContextDialer, tlsConfig *tls.Config, withPing bool) ContextDialer { return &alpnConnUpgradeDialer{ dialer: dialer, tlsConfig: tlsConfig, + withPing: withPing, } } @@ -182,26 +185,34 @@ func (d alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr st tlsConn := tls.Client(conn, cfg) - err = upgradeConnThroughWebAPI(tlsConn, url.URL{ + upgradeURL := url.URL{ Host: addr, Scheme: "https", Path: constants.WebAPIConnUpgrade, - }) - if err != nil { - defer tlsConn.Close() - return nil, trace.Wrap(err) } - return tlsConn, nil + + switch { + case d.withPing: + if err := upgradeConnThroughWebAPI(tlsConn, upgradeURL, constants.WebAPIConnUpgradeTypeALPNPing); err != nil { + return nil, trace.NewAggregate(err, tlsConn.Close()) + } + return pingconn.New(tlsConn), nil + + default: + if err := upgradeConnThroughWebAPI(tlsConn, upgradeURL, constants.WebAPIConnUpgradeTypeALPN); err != nil { + return nil, trace.NewAggregate(err, tlsConn.Close()) + } + return tlsConn, nil + } } -func upgradeConnThroughWebAPI(conn net.Conn, api url.URL) error { +func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, upgradeType string) error { req, err := http.NewRequest(http.MethodGet, api.String(), nil) if err != nil { return trace.Wrap(err) } - // For now, only "alpn" is supported. - req.Header.Add(constants.WebAPIConnUpgradeHeader, constants.WebAPIConnUpgradeTypeALPN) + req.Header.Add(constants.WebAPIConnUpgradeHeader, upgradeType) // Send the request and check if upgrade is successful. if err = req.Write(conn); err != nil { @@ -216,8 +227,9 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL) error { if http.StatusSwitchingProtocols != resp.StatusCode { if http.StatusNotFound == resp.StatusCode { return trace.NotImplemented( - "connection upgrade call to %q failed with status code %v. Please upgrade the server and try again.", + "connection upgrade call to %q with upgrade type %v failed with status code %v. Please upgrade the server and try again.", constants.WebAPIConnUpgrade, + upgradeType, resp.StatusCode, ) } diff --git a/api/client/alpn_conn_upgrade_test.go b/api/client/alpn_conn_upgrade_test.go index 4c02b3ebfebb6..a52a412134a8a 100644 --- a/api/client/alpn_conn_upgrade_test.go +++ b/api/client/alpn_conn_upgrade_test.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/fixtures" + "github.com/gravitational/teleport/api/utils/pingconn" ) func TestIsALPNConnUpgradeRequired(t *testing.T) { @@ -123,42 +124,58 @@ func TestIsALPNConnUpgradeRequiredByEnv(t *testing.T) { func TestALPNConnUpgradeDialer(t *testing.T) { t.Parallel() - t.Run("connection upgraded", func(t *testing.T) { - ctx := context.Background() - - server := httptest.NewTLSServer(mockConnUpgradeHandler(t, "alpn", []byte("hello"))) - t.Cleanup(server.Close) - addr, err := url.Parse(server.URL) - require.NoError(t, err) - pool := x509.NewCertPool() - pool.AddCert(server.Certificate()) - - tlsConfig := &tls.Config{RootCAs: pool} - preDialer := NewDialer(ctx, 0, 5*time.Second) - dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig) - conn, err := dialer.DialContext(ctx, "tcp", addr.Host) - require.NoError(t, err) - - data := make([]byte, 100) - n, err := conn.Read(data) - require.NoError(t, err) - require.Equal(t, string(data[:n]), "hello") - }) - - t.Run("connection upgrade API not found", func(t *testing.T) { - ctx := context.Background() - - server := httptest.NewTLSServer(http.NotFoundHandler()) - t.Cleanup(server.Close) - addr, err := url.Parse(server.URL) - require.NoError(t, err) + tests := []struct { + name string + serverHandler http.Handler + withPing bool + wantError bool + }{ + { + name: "connection upgrade", + serverHandler: mockConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")), + }, + { + name: "connection upgrade with ping", + serverHandler: mockConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")), + withPing: true, + }, + { + name: "connection upgrade API not found", + serverHandler: http.NotFoundHandler(), + wantError: true, + }, + } - tlsConfig := &tls.Config{InsecureSkipVerify: true} - preDialer := NewDialer(ctx, 0, 5*time.Second) - dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig) - _, err = dialer.DialContext(ctx, "tcp", addr.Host) - require.Error(t, err) - }) + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() + + server := httptest.NewTLSServer(test.serverHandler) + t.Cleanup(server.Close) + addr, err := url.Parse(server.URL) + require.NoError(t, err) + pool := x509.NewCertPool() + pool.AddCert(server.Certificate()) + + tlsConfig := &tls.Config{RootCAs: pool} + preDialer := newDirectDialer(0, 5*time.Second) + dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig, test.withPing) + conn, err := dialer.DialContext(ctx, "tcp", addr.Host) + if test.wantError { + require.Error(t, err) + return + } + require.NoError(t, err) + defer conn.Close() + + data := make([]byte, 100) + n, err := conn.Read(data) + require.NoError(t, err) + require.Equal(t, string(data[:n]), "hello") + }) + } } type mockALPNServer struct { @@ -218,6 +235,8 @@ func mustStartMockALPNServer(t *testing.T, supportedProtos []string) *mockALPNSe // mockConnUpgradeHandler mocks the server side implementation to handle an // upgrade request and sends back some data inside the tunnel. func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler { + t.Helper() + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path) require.Equal(t, upgradeType, r.Header.Get(constants.WebAPIConnUpgradeHeader)) @@ -238,7 +257,18 @@ func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http require.NoError(t, response.Write(conn)) // Upgraded. - _, err = conn.Write(write) - require.NoError(t, err) + switch upgradeType { + case constants.WebAPIConnUpgradeTypeALPNPing: + // Wrap conn with Ping and write some pings. + pingConn := pingconn.New(conn) + pingConn.WritePing() + _, err = pingConn.Write(write) + require.NoError(t, err) + pingConn.WritePing() + + default: + _, err = conn.Write(write) + require.NoError(t, err) + } }) } diff --git a/api/client/client.go b/api/client/client.go index 2333ad7dcab50..fd8b374d12f90 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -44,7 +44,6 @@ import ( "github.com/gravitational/teleport/api/breaker" "github.com/gravitational/teleport/api/client/okta" "github.com/gravitational/teleport/api/client/proto" - "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/defaults" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" @@ -345,22 +344,10 @@ func authConnect(ctx context.Context, params connectParams) (*Client, error) { // Auth through Proxy using TLS Routing. func IsWebProxyAndConnUpgradeRequired(ctx context.Context, targetAddr string, cfg *Config) bool { return cfg.ALPNSNIAuthDialClusterName != "" && - isWebProxy(ctx, targetAddr, cfg) && + cfg.WebProxyAddr == targetAddr && cfg.IsALPNConnUpgradeRequired(targetAddr, cfg.InsecureAddressDiscovery) } -func isWebProxy(ctx context.Context, targetAddr string, cfg *Config) bool { - if cfg.WebProxyAddr != "" { - return cfg.WebProxyAddr == targetAddr - } - _, err := webclient.Find(&webclient.Config{ - Context: ctx, - ProxyAddr: targetAddr, - Insecure: cfg.InsecureAddressDiscovery, - }) - return err == nil -} - // tunnelConnect connects to the Teleport Auth Server through the proxy's reverse tunnel. func tunnelConnect(ctx context.Context, params connectParams) (*Client, error) { if params.sshConfig == nil { @@ -520,9 +507,7 @@ func (c *Client) waitForConnectionReady(ctx context.Context) error { type Config struct { // Addrs is a list of teleport auth/proxy server addresses to dial. Addrs []string - // WebProxyAddr is the Teleport Proxy web address. If not provided, extra - // webapi pings may be required to find out if Addrs are web proxy - // addresses. + // WebProxyAddr is the Teleport Proxy web address. WebProxyAddr string // Credentials are a list of credentials to use when attempting // to connect to the server. @@ -556,7 +541,7 @@ type Config struct { // Context is the base context to use for dialing. If not provided context.Background is used Context context.Context // IsALPNConnUpgradeRequired is a callback function to check whether - // connection upgrade is required for TLS routing. + // connection upgrade is required for TLS Routing. IsALPNConnUpgradeRequired func(addr string, insecure bool) bool } diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index 7d6caa8d172e2..7c3b976a42f27 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -39,6 +39,7 @@ import ( type dialConfig struct { tlsConfig *tls.Config alpnConnUpgradeRequired bool + alpnConnUpgradeWithPing bool } // WithInsecureSkipVerify specifies if dialing insecure when using an HTTPS proxy. @@ -57,6 +58,13 @@ func WithALPNConnUpgrade(alpnConnUpgradeRequired bool) DialOption { } } +// WithALPNConnUpgradePing specifies if Ping is required during ALPN connection upgrade. +func WithALPNConnUpgradePing(alpnConnUpgradeWithPing bool) DialOption { + return func(cfg *dialProxyConfig) { + cfg.alpnConnUpgradeWithPing = alpnConnUpgradeWithPing + } +} + // DialOption allows setting options as functional arguments to api.NewDialer. type DialOption func(cfg *dialConfig) @@ -129,7 +137,7 @@ func NewDialer(ctx context.Context, keepAlivePeriod, dialTimeout time.Duration, // Wrap with alpnConnUpgradeDialer if upgrade is required for TLS Routing. if cfg.alpnConnUpgradeRequired { - dialer = newALPNConnUpgradeDialer(dialer, cfg.tlsConfig) + dialer = newALPNConnUpgradeDialer(dialer, cfg.tlsConfig, cfg.alpnConnUpgradeWithPing) } // Dial. @@ -161,6 +169,15 @@ func NewProxyDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Dura }) } +// GRPCContextDialer converts a ContextDialer to a function used for +// grpc.WithContextDialer. +func GRPCContextDialer(dialer ContextDialer) func(context.Context, string) (net.Conn, error) { + return func(ctx context.Context, addr string) (net.Conn, error) { + conn, err := dialer.DialContext(ctx, "tcp", addr) + return conn, trace.Wrap(err) + } +} + // newTunnelDialer makes a dialer to connect to an Auth server through the SSH reverse tunnel on the proxy. func newTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, opts ...DialProxyOption) ContextDialer { dialer := newDirectDialer(keepAlivePeriod, dialTimeout) @@ -210,10 +227,7 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) Conte DialTimeout: params.cfg.DialTimeout, KeepAlivePeriod: params.cfg.KeepAlivePeriod, TLSConfig: &tls.Config{ - NextProtos: []string{ - ALPNProtocolWithPing(constants.ALPNSNIProtocolReverseTunnel), - constants.ALPNSNIProtocolReverseTunnel, - }, + NextProtos: []string{constants.ALPNSNIProtocolReverseTunnel}, InsecureSkipVerify: insecure, ServerName: host, }, diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index b0299c8fa97bf..b20194c0916ae 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -25,6 +25,7 @@ import ( "time" "github.com/gravitational/trace" + "github.com/sirupsen/logrus" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" @@ -82,6 +83,11 @@ type ClientConfig struct { DialTimeout time.Duration // DialOpts define options for dialing the client connection. DialOpts []grpc.DialOption + // IsALPNConnUpgradeRequired is a callback function to check whether + // connection upgrade is required for TLS Routing. + IsALPNConnUpgradeRequired func(addr string, insecure bool) bool + // InsecureSkipVerify is an option to skip HTTPS cert check + InsecureSkipVerify bool // The below items are intended to be used by tests to connect without mTLS. // The gRPC transport credentials to use when establishing the connection to proxy. @@ -108,6 +114,9 @@ func (c *ClientConfig) CheckAndSetDefaults() error { if c.DialTimeout <= 0 { c.DialTimeout = defaults.DefaultIOTimeout } + if c.IsALPNConnUpgradeRequired == nil { + c.IsALPNConnUpgradeRequired = client.IsALPNConnUpgradeRequired + } if c.TLSConfig != nil { c.clientCreds = func() client.Credentials { @@ -209,6 +218,7 @@ func NewClient(ctx context.Context, cfg ClientConfig) (*Client, error) { if err == nil { return clt, nil } + logrus.Debugf("==== falling back") } clt, sshErr := newSSHClient(ctx, &cfg) @@ -298,6 +308,7 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error dialCtx, cfg.ProxySSHAddress, append(cfg.DialOpts, + grpc.WithContextDialer(newDialerForGRPCClient(ctx, cfg)), grpc.WithTransportCredentials(&clusterCredentials{TransportCredentials: cfg.creds(), clusterName: c}), grpc.WithChainUnaryInterceptor( append(cfg.UnaryInterceptors, @@ -336,6 +347,20 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error }, nil } +func newDialerForGRPCClient(ctx context.Context, cfg *ClientConfig) func(context.Context, string) (net.Conn, error) { + dialOpts := []client.DialOption{ + client.WithInsecureSkipVerify(cfg.InsecureSkipVerify), + } + + if cfg.TLSRoutingEnabled { + dialOpts = append(dialOpts, + client.WithALPNConnUpgrade(cfg.IsALPNConnUpgradeRequired(cfg.ProxyWebAddress, cfg.InsecureSkipVerify)), + client.WithALPNConnUpgradePing(true), + ) + } + return client.GRPCContextDialer(client.NewDialer(ctx, defaults.DefaultIdleTimeout, cfg.DialTimeout, dialOpts...)) +} + // teleportAuthority is the extension set by the server // which contains the name of the cluster it is in. const teleportAuthority = "x-teleport-authority" @@ -441,10 +466,11 @@ func (c *Client) ClientConfig(ctx context.Context, cluster string) client.Config case c.cfg.TLSRoutingEnabled: return client.Config{ Context: ctx, - Addrs: []string{c.cfg.ProxyWebAddress}, + WebProxyAddr: c.cfg.ProxyWebAddress, Credentials: []client.Credentials{c.cfg.clientCreds()}, ALPNSNIAuthDialClusterName: cluster, CircuitBreakerConfig: breaker.NoopBreakerConfig(), + IsALPNConnUpgradeRequired: c.cfg.IsALPNConnUpgradeRequired, } case c.sshClient != nil: return client.Config{ diff --git a/api/constants/constants.go b/api/constants/constants.go index 2f06e321c0fa5..25f13d5e5dba7 100644 --- a/api/constants/constants.go +++ b/api/constants/constants.go @@ -319,6 +319,8 @@ const ( ALPNSNIAuthProtocol = "teleport-auth@" // ALPNSNIProtocolReverseTunnel is TLS ALPN protocol value used to indicate Proxy reversetunnel protocol. ALPNSNIProtocolReverseTunnel = "teleport-reversetunnel" + // ALPNSNIProtocolSSH is the TLS ALPN protocol value used to indicate Proxy SSH protocol. + ALPNSNIProtocolSSH = "teleport-proxy-ssh" // ALPNSNIProtocolPingSuffix is TLS ALPN suffix used to wrap connections with Ping. ALPNSNIProtocolPingSuffix = "-ping" ) @@ -415,4 +417,12 @@ const ( // WebAPIConnUpgradeTypeALPN is a connection upgrade type that specifies // the upgraded connection should be handled by the ALPN handler. WebAPIConnUpgradeTypeALPN = "alpn" + // WebAPIConnUpgradeTypeALPNPing is a connection upgrade type that + // specifies the upgraded connection should be handled by the ALPN handler + // wrapped with the Ping protocol. + // + // This should be used when the tunnelled TLS Routing protocol cannot keep + // long-lived connections alive as L7 LB usually ignores TCP keepalives and + // has very short idle timeouts. + WebAPIConnUpgradeTypeALPNPing = "alpn-ping" ) diff --git a/lib/client/api.go b/lib/client/api.go index 24868b1d41b2e..8bf4c277d817f 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2789,7 +2789,9 @@ func (tc *TeleportClient) ConnectToCluster(ctx context.Context) (*ClusterClient, clt, err := makeProxySSHClient(ctx, tc, config) return clt, trace.Wrap(err) }), - SSHConfig: cfg.ClientConfig, + SSHConfig: cfg.ClientConfig, + IsALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequired, + InsecureSkipVerify: tc.InsecureSkipVerify, }) if err != nil { return nil, trace.Wrap(err) @@ -4587,13 +4589,11 @@ func (tc *TeleportClient) NewKubernetesServiceClient(ctx context.Context, cluste tlsConfig.NextProtos = []string{string(alpncommon.ProtocolProxyGRPCSecure), http2.NextProtoTLS} clt, err := client.New(ctx, client.Config{ - WebProxyAddr: tc.Config.WebProxyAddr, + Addrs: []string{tc.Config.WebProxyAddr}, DialInBackground: false, Credentials: []client.Credentials{ client.LoadTLS(tlsConfig), }, - ALPNSNIAuthDialClusterName: clusterName, - IsALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequired, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index bcf9254765c7d..f2f56ff3965cb 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -636,14 +636,14 @@ func (c *agentPoolRuntimeConfig) alpnDialerConfig(getClusterCAs client.GetCluste c.mu.RLock() defer c.mu.RUnlock() - nextProtos := alpncommon.NextProtosWithPing(alpncommon.ProtocolReverseTunnel) + protocols := []alpncommon.Protocol{alpncommon.ProtocolReverseTunnel} if !c.isRemoteCluster && c.tunnelStrategyType == types.ProxyPeering { - nextProtos = append([]string{string(alpncommon.ProtocolReverseTunnelV2)}, nextProtos...) + protocols = []alpncommon.Protocol{alpncommon.ProtocolReverseTunnelV2, alpncommon.ProtocolReverseTunnel} } return client.ALPNDialerConfig{ TLSConfig: &tls.Config{ - NextProtos: nextProtos, + NextProtos: alpncommon.ProtocolsToString(protocols), InsecureSkipVerify: lib.IsInsecureDevMode(), }, KeepAlivePeriod: c.keepAliveInterval, diff --git a/lib/reversetunnel/transport.go b/lib/reversetunnel/transport.go index 5f5db7dbb28fd..a4ea3927c171b 100644 --- a/lib/reversetunnel/transport.go +++ b/lib/reversetunnel/transport.go @@ -100,7 +100,7 @@ func (t *TunnelAuthDialer) DialContext(ctx context.Context, _, _ string) (net.Co if mode == types.ProxyListenerMode_Multiplex { opts = append(opts, proxy.WithALPNDialer(client.ALPNDialerConfig{ TLSConfig: &tls.Config{ - NextProtos: alpncommon.NextProtosWithPing(alpncommon.ProtocolReverseTunnel), + NextProtos: []string{string(alpncommon.ProtocolReverseTunnel)}, InsecureSkipVerify: t.InsecureSkipTLSVerify, }, DialTimeout: t.ClientConfig.Timeout, diff --git a/lib/service/service.go b/lib/service/service.go index 6f9e5d8f02036..389cee3b11dcc 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4559,7 +4559,7 @@ func setupALPNRouter(listeners *proxyListeners, serverTLSConfig *tls.Config, cfg router.Add(alpnproxy.HandlerDecs{ MatchFunc: alpnproxy.MatchByProtocol(alpncommon.ProtocolProxySSH), Handler: sshProxyListener.HandleConnection, - TLSConfig: alpncommon.AddNextProtos(serverTLSConfig.Clone(), alpncommon.ProtocolProxySSH), + TLSConfig: serverTLSConfig.Clone(), }) listeners.ssh = sshProxyListener diff --git a/lib/service/service_test.go b/lib/service/service_test.go index 5fccc80f6a6de..e3ba427252ef3 100644 --- a/lib/service/service_test.go +++ b/lib/service/service_test.go @@ -486,8 +486,6 @@ func TestSetupProxyTLSConfig(t *testing.T) { "http/1.1", "h2", "acme-tls/1", - "teleport-proxy-ssh-ping", - "teleport-reversetunnel-ping", "teleport-tcp-ping", "teleport-postgres-ping", "teleport-mysql-ping", @@ -522,8 +520,6 @@ func TestSetupProxyTLSConfig(t *testing.T) { name: "ACME disabled", acmeEnabled: false, wantNextProtos: []string{ - "teleport-proxy-ssh-ping", - "teleport-reversetunnel-ping", "teleport-tcp-ping", "teleport-postgres-ping", "teleport-mysql-ping", diff --git a/lib/srv/alpnproxy/common/protocols.go b/lib/srv/alpnproxy/common/protocols.go index 55cb7ea1ef8f3..bbf358653843e 100644 --- a/lib/srv/alpnproxy/common/protocols.go +++ b/lib/srv/alpnproxy/common/protocols.go @@ -221,8 +221,6 @@ var DatabaseProtocols = []Protocol{ var ProtocolsWithPingSupport = append( DatabaseProtocols, ProtocolTCP, - ProtocolReverseTunnel, - ProtocolProxySSH, ) // WithPingProtocols adds Ping protocols to the list for each protocol that diff --git a/lib/web/conn_upgrade.go b/lib/web/conn_upgrade.go index cdfe3dae0556a..9ecf98948dfb3 100644 --- a/lib/web/conn_upgrade.go +++ b/lib/web/conn_upgrade.go @@ -21,11 +21,14 @@ import ( "io" "net" "net/http" + "time" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/utils/pingconn" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/utils" ) @@ -35,6 +38,8 @@ func (h *Handler) selectConnectionUpgrade(r *http.Request) (string, ConnectionHa upgrades := r.Header.Values(constants.WebAPIConnUpgradeHeader) for _, upgradeType := range upgrades { switch upgradeType { + case constants.WebAPIConnUpgradeTypeALPNPing: + return upgradeType, h.upgradeALPNWithPing, nil case constants.WebAPIConnUpgradeTypeALPN: return upgradeType, h.upgradeALPN, nil } @@ -88,6 +93,42 @@ func (h *Handler) upgradeALPN(ctx context.Context, conn net.Conn) error { return h.cfg.ALPNHandler(ctx, waitConn) } +func (h *Handler) upgradeALPNWithPing(ctx context.Context, conn net.Conn) error { + h.log.Errorf("=== with ping") + if h.cfg.ALPNHandler == nil { + return trace.BadParameter("missing ALPNHandler") + } + + pingConn := pingconn.New(conn) + + // Cancel ping background goroutine when connection is closed. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go h.startPing(ctx, pingConn) + + return h.upgradeALPN(ctx, pingConn) +} + +func (h *Handler) startPing(ctx context.Context, pingConn *pingconn.PingConn) { + ticker := time.NewTicker(defaults.ProxyPingInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + h.log.Errorf("=== stopped") + return + case <-ticker.C: + err := pingConn.WritePing() + if err != nil { + if !utils.IsOKNetworkError(err) { + h.log.WithError(err).Warn("Failed to write ping message") + } + return + } + } + } +} + func writeUpgradeResponse(w io.Writer, upgradeType string) error { header := make(http.Header) header.Add(constants.WebAPIConnUpgradeHeader, upgradeType) diff --git a/lib/web/conn_upgrade_test.go b/lib/web/conn_upgrade_test.go index c7234e56af516..3705d5e23138c 100644 --- a/lib/web/conn_upgrade_test.go +++ b/lib/web/conn_upgrade_test.go @@ -29,6 +29,9 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/utils/pingconn" ) func TestWriteUpgradeResponse(t *testing.T) { @@ -86,35 +89,53 @@ func TestHandlerConnectionUpgrade(t *testing.T) { defer serverConn.Close() defer clientConn.Close() - r, err := http.NewRequest("GET", "http://localhost/webapi/connectionupgrade", nil) - require.NoError(t, err) - r.Header.Add("Upgrade", "alpn") - - go func() { - // serverConn will be hijacked. - w := newResponseWriterHijacker(nil, serverConn) - _, err := h.connectionUpgrade(w, r, nil) - require.NoError(t, err) - }() + sendConnUpgradeRequest(t, h, constants.WebAPIConnUpgradeTypeALPN, serverConn, clientConn) - // Verify clientConn receives http.StatusSwitchingProtocols. - clientConnReader := bufio.NewReader(clientConn) - resp, err := http.ReadResponse(clientConnReader, r) + // Verify clientConn receives data sent by Config.ALPNHandler. + receive, err := bufio.NewReader(clientConn).ReadString(byte('@')) require.NoError(t, err) + require.Equal(t, expectedPayload, receive) + }) - // Always drain/close the body. - io.Copy(io.Discard, resp.Body) - _ = resp.Body.Close() + t.Run("upgraded to ALPN with Ping", func(t *testing.T) { + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() - require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + sendConnUpgradeRequest(t, h, constants.WebAPIConnUpgradeTypeALPNPing, serverConn, clientConn) - // Verify clientConn receives data sent by Config.ALPNHandler. - receive, err := clientConnReader.ReadString(byte('@')) + // Verify ping-wrapped clientConn receives data sent by Config.ALPNHandler. + receive, err := bufio.NewReader(pingconn.New(clientConn)).ReadString(byte('@')) require.NoError(t, err) require.Equal(t, expectedPayload, receive) }) } +func sendConnUpgradeRequest(t *testing.T, h *Handler, upgradeType string, serverConn, clientConn net.Conn) { + t.Helper() + + r, err := http.NewRequest("GET", "http://localhost/webapi/connectionupgrade", nil) + require.NoError(t, err) + r.Header.Add("Upgrade", upgradeType) + + go func() { + // serverConn will be hijacked. + w := newResponseWriterHijacker(nil, serverConn) + _, err := h.connectionUpgrade(w, r, nil) + require.NoError(t, err) + }() + + // Verify clientConn receives http.StatusSwitchingProtocols. + resp, err := http.ReadResponse(bufio.NewReader(clientConn), r) + require.NoError(t, err) + + // Always drain/close the body. + io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) +} + // responseWriterHijacker is a mock http.ResponseWriter that also serves a // net.Conn for http.Hijacker. type responseWriterHijacker struct { diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index ae5e4e92eb4b7..0eff8383d1154 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -246,7 +246,7 @@ func dialSSHProxy(ctx context.Context, tc *libclient.TeleportClient, sp sshProxy dialer = client.NewALPNDialer(client.ALPNDialerConfig{ TLSConfig: &tls.Config{ RootCAs: pool, - NextProtos: alpncommon.NextProtosWithPing(alpncommon.ProtocolProxySSH), + NextProtos: []string{string(alpncommon.ProtocolProxySSH)}, InsecureSkipVerify: tc.InsecureSkipVerify, ServerName: sp.proxyHost, }, From 922fc8bfb4046f9617bed2ed007e57a53a7c40ca Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Thu, 6 Apr 2023 15:04:11 -0400 Subject: [PATCH 18/27] clean up --- api/client/proxy/client.go | 2 -- api/constants/constants.go | 2 +- integration/proxy/proxy_test.go | 2 +- lib/client/api.go | 2 +- lib/service/service.go | 4 +-- lib/srv/alpnproxy/common/protocols.go | 13 -------- lib/srv/alpnproxy/common/protocols_test.go | 35 ++-------------------- lib/srv/alpnproxy/proxy.go | 1 - lib/web/conn_upgrade.go | 3 +- 9 files changed, 9 insertions(+), 55 deletions(-) diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index b20194c0916ae..08c0d9cc179f4 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -25,7 +25,6 @@ import ( "time" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" @@ -218,7 +217,6 @@ func NewClient(ctx context.Context, cfg ClientConfig) (*Client, error) { if err == nil { return clt, nil } - logrus.Debugf("==== falling back") } clt, sshErr := newSSHClient(ctx, &cfg) diff --git a/api/constants/constants.go b/api/constants/constants.go index 25f13d5e5dba7..1d58d56ba69ad 100644 --- a/api/constants/constants.go +++ b/api/constants/constants.go @@ -421,7 +421,7 @@ const ( // specifies the upgraded connection should be handled by the ALPN handler // wrapped with the Ping protocol. // - // This should be used when the tunnelled TLS Routing protocol cannot keep + // This should be used when the tunneled TLS Routing protocol cannot keep // long-lived connections alive as L7 LB usually ignores TCP keepalives and // has very short idle timeouts. WebAPIConnUpgradeTypeALPNPing = "alpn-ping" diff --git a/integration/proxy/proxy_test.go b/integration/proxy/proxy_test.go index 0c3fb02c5abd6..c4fc82925aa5e 100644 --- a/integration/proxy/proxy_test.go +++ b/integration/proxy/proxy_test.go @@ -1253,7 +1253,7 @@ func TestALPNProxyAuthClientConnectWithUserIdentity(t *testing.T) { for _, target := range targets { for _, connectType := range connectTypes { t.Run(fmt.Sprintf("%v/%v", target.name, connectType.name), func(t *testing.T) { - connectType.clientConfig.Addrs = []string{target.addr} + connectType.clientConfig.WebProxyAddr = target.addr tc, err := client.New(context.Background(), connectType.clientConfig) require.NoError(t, err) diff --git a/lib/client/api.go b/lib/client/api.go index 8bf4c277d817f..7a6fd351d61f8 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -3038,7 +3038,7 @@ func makeProxySSHClientWithTLSWrapper(ctx context.Context, tc *TeleportClient, s return nil, trace.Wrap(err) } - tlsConfig.NextProtos = alpncommon.NextProtosWithPing(alpncommon.ProtocolProxySSH) + tlsConfig.NextProtos = []string{string(alpncommon.ProtocolProxySSH)} dialer := proxy.DialerFromEnvironment(tc.Config.WebProxyAddr, proxy.WithALPNDialer(client.ALPNDialerConfig{ TLSConfig: tlsConfig, ALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequired(proxyAddr, tlsConfig.InsecureSkipVerify), diff --git a/lib/service/service.go b/lib/service/service.go index 389cee3b11dcc..5a0c1a6e2a760 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4458,7 +4458,7 @@ func (process *TeleportProcess) setupProxyTLSConfig(conn *Connector, tsrv revers func setupTLSConfigALPNProtocols(tlsConfig *tls.Config) { // Go 1.17 introduced strict ALPN https://golang.org/doc/go1.17#ALPN If a client protocol is not recognized // the TLS handshake will fail. - alpncommon.AddNextProtos(tlsConfig, alpncommon.SupportedProtocols...) + tlsConfig.NextProtos = apiutils.Deduplicate(append(tlsConfig.NextProtos, alpncommon.ProtocolsToString(alpncommon.SupportedProtocols)...)) } func setupTLSConfigClientCAsForCluster(tlsConfig *tls.Config, accessPoint auth.ReadProxyAccessPoint, clusterName string) { @@ -4559,7 +4559,7 @@ func setupALPNRouter(listeners *proxyListeners, serverTLSConfig *tls.Config, cfg router.Add(alpnproxy.HandlerDecs{ MatchFunc: alpnproxy.MatchByProtocol(alpncommon.ProtocolProxySSH), Handler: sshProxyListener.HandleConnection, - TLSConfig: serverTLSConfig.Clone(), + TLSConfig: serverTLSConfig, }) listeners.ssh = sshProxyListener diff --git a/lib/srv/alpnproxy/common/protocols.go b/lib/srv/alpnproxy/common/protocols.go index bbf358653843e..40d3c91ad35b5 100644 --- a/lib/srv/alpnproxy/common/protocols.go +++ b/lib/srv/alpnproxy/common/protocols.go @@ -17,7 +17,6 @@ limitations under the License. package common import ( - "crypto/tls" "strings" "github.com/gravitational/trace" @@ -140,18 +139,6 @@ func ProtocolsToString(protocols []Protocol) []string { return out } -// NextProtosWithPing adds Ping protocols to provided list of ALPN protocols -// then converts them to a list of strings for tls.Config.NextProtos. -func NextProtosWithPing(protocols ...Protocol) []string { - return ProtocolsToString(WithPingProtocols(protocols)) -} - -// AddNextProtos adds ALPN protocols to the provided tls.Config. -func AddNextProtos(config *tls.Config, protocols ...Protocol) *tls.Config { - config.NextProtos = utils.Deduplicate(append(config.NextProtos, NextProtosWithPing(protocols...)...)) - return config -} - // ToALPNProtocol maps provided database protocol to ALPN protocol. func ToALPNProtocol(dbProtocol string) (Protocol, error) { switch dbProtocol { diff --git a/lib/srv/alpnproxy/common/protocols_test.go b/lib/srv/alpnproxy/common/protocols_test.go index 614db9f14399a..2268008eaa3e4 100644 --- a/lib/srv/alpnproxy/common/protocols_test.go +++ b/lib/srv/alpnproxy/common/protocols_test.go @@ -17,7 +17,6 @@ limitations under the License. package common import ( - "crypto/tls" "testing" "github.com/stretchr/testify/require" @@ -28,16 +27,16 @@ func TestWithPingProtocols(t *testing.T) { []Protocol{ "teleport-tcp-ping", "teleport-redis-ping", - "teleport-auth@", + "teleport-reversetunnel", "teleport-tcp", "teleport-redis", "h2", }, WithPingProtocols([]Protocol{ - ProtocolAuth, + ProtocolReverseTunnel, ProtocolTCP, ProtocolRedisDB, - ProtocolAuth, + ProtocolReverseTunnel, ProtocolHTTP2, }), ) @@ -49,31 +48,3 @@ func TestIsDBTLSProtocol(t *testing.T) { require.False(t, IsDBTLSProtocol("teleport-tcp")) require.False(t, IsDBTLSProtocol("")) } - -func TestAddNextProtos(t *testing.T) { - input := &tls.Config{ - NextProtos: []string{"proto1", "proto2"}, - } - want := &tls.Config{ - NextProtos: []string{"proto1", "proto2", "teleport-proxy-ssh-ping", "teleport-proxy-ssh"}, - } - require.Equal(t, want, AddNextProtos(input, ProtocolProxySSH)) -} - -func BenchmarkNextProtosWithPing(b *testing.B) { - b.Run("one with ping support", func(b *testing.B) { - for n := 0; n < b.N; n++ { - NextProtosWithPing(ProtocolReverseTunnel) - } - }) - b.Run("one without ping support", func(b *testing.B) { - for n := 0; n < b.N; n++ { - NextProtosWithPing(ProtocolHTTP) - } - }) - b.Run("five", func(b *testing.B) { - for n := 0; n < b.N; n++ { - NextProtosWithPing(ProtocolAuth, ProtocolTCP, ProtocolRedisDB, ProtocolAuth, ProtocolHTTP2) - } - }) -} diff --git a/lib/srv/alpnproxy/proxy.go b/lib/srv/alpnproxy/proxy.go index 3ce9efdd8f578..6cf22111841f9 100644 --- a/lib/srv/alpnproxy/proxy.go +++ b/lib/srv/alpnproxy/proxy.go @@ -208,7 +208,6 @@ func (h *HandlerDecs) CheckAndSetDefaults() error { if h.ForwardTLS && h.TLSConfig != nil { return trace.BadParameter("the ForwardTLS flag and TLSConfig can't be used at the same time") } - return nil } diff --git a/lib/web/conn_upgrade.go b/lib/web/conn_upgrade.go index 9ecf98948dfb3..8324aabdd5f2b 100644 --- a/lib/web/conn_upgrade.go +++ b/lib/web/conn_upgrade.go @@ -94,7 +94,6 @@ func (h *Handler) upgradeALPN(ctx context.Context, conn net.Conn) error { } func (h *Handler) upgradeALPNWithPing(ctx context.Context, conn net.Conn) error { - h.log.Errorf("=== with ping") if h.cfg.ALPNHandler == nil { return trace.BadParameter("missing ALPNHandler") } @@ -105,6 +104,7 @@ func (h *Handler) upgradeALPNWithPing(ctx context.Context, conn net.Conn) error ctx, cancel := context.WithCancel(ctx) defer cancel() go h.startPing(ctx, pingConn) + time.Sleep(time.Millisecond * 100) return h.upgradeALPN(ctx, pingConn) } @@ -115,7 +115,6 @@ func (h *Handler) startPing(ctx context.Context, pingConn *pingconn.PingConn) { for { select { case <-ctx.Done(): - h.log.Errorf("=== stopped") return case <-ticker.C: err := pingConn.WritePing() From 5c576c0fba0a375c97e3b3450d34ca1e61c87468 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Fri, 7 Apr 2023 12:44:07 -0400 Subject: [PATCH 19/27] fix proxy test --- api/client/proxy/client.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index 08c0d9cc179f4..262dee0260a17 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -305,7 +305,7 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error conn, err := grpc.DialContext( dialCtx, cfg.ProxySSHAddress, - append(cfg.DialOpts, + append([]grpc.DialOption{ grpc.WithContextDialer(newDialerForGRPCClient(ctx, cfg)), grpc.WithTransportCredentials(&clusterCredentials{TransportCredentials: cfg.creds(), clusterName: c}), grpc.WithChainUnaryInterceptor( @@ -320,7 +320,7 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error metadata.StreamClientInterceptor, )..., ), - )..., + }, cfg.DialOpts...)..., ) if err != nil { return nil, trace.Wrap(err) From a3e47355a0f3344aba5037bcdf7f0c63d309af68 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Wed, 12 Apr 2023 12:46:07 -0400 Subject: [PATCH 20/27] minor refactor --- api/client/alpn_conn_upgrade.go | 42 +++++++++++++++++---------------- lib/web/conn_upgrade.go | 1 - 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/api/client/alpn_conn_upgrade.go b/api/client/alpn_conn_upgrade.go index 9bf33ce4df224..c96dec445b0cf 100644 --- a/api/client/alpn_conn_upgrade.go +++ b/api/client/alpn_conn_upgrade.go @@ -159,7 +159,7 @@ func newALPNConnUpgradeDialer(dialer ContextDialer, tlsConfig *tls.Config, withP } // DialContext implements ContextDialer -func (d alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { +func (d *alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { logrus.Debugf("ALPN connection upgrade for %v.", addr) conn, err := d.dialer.DialContext(ctx, network, addr) @@ -184,56 +184,58 @@ func (d alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr st } tlsConn := tls.Client(conn, cfg) - upgradeURL := url.URL{ Host: addr, Scheme: "https", Path: constants.WebAPIConnUpgrade, } - switch { - case d.withPing: - if err := upgradeConnThroughWebAPI(tlsConn, upgradeURL, constants.WebAPIConnUpgradeTypeALPNPing); err != nil { - return nil, trace.NewAggregate(err, tlsConn.Close()) - } - return pingconn.New(tlsConn), nil + conn, err = upgradeConnThroughWebAPI(tlsConn, upgradeURL, d.upgradeType()) + if err != nil { + return nil, trace.NewAggregate(tlsConn.Close(), err) + } + return conn, nil +} - default: - if err := upgradeConnThroughWebAPI(tlsConn, upgradeURL, constants.WebAPIConnUpgradeTypeALPN); err != nil { - return nil, trace.NewAggregate(err, tlsConn.Close()) - } - return tlsConn, nil +func (d *alpnConnUpgradeDialer) upgradeType() string { + if d.withPing { + return constants.WebAPIConnUpgradeTypeALPNPing } + return constants.WebAPIConnUpgradeTypeALPN } -func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, upgradeType string) error { +func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, upgradeType string) (net.Conn, error) { req, err := http.NewRequest(http.MethodGet, api.String(), nil) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } req.Header.Add(constants.WebAPIConnUpgradeHeader, upgradeType) // Send the request and check if upgrade is successful. if err = req.Write(conn); err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } resp, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } defer resp.Body.Close() if http.StatusSwitchingProtocols != resp.StatusCode { if http.StatusNotFound == resp.StatusCode { - return trace.NotImplemented( + return nil, trace.NotImplemented( "connection upgrade call to %q with upgrade type %v failed with status code %v. Please upgrade the server and try again.", constants.WebAPIConnUpgrade, upgradeType, resp.StatusCode, ) } - return trace.BadParameter("failed to switch Protocols %v", resp.StatusCode) + return nil, trace.BadParameter("failed to switch Protocols %v", resp.StatusCode) + } + + if upgradeType == constants.WebAPIConnUpgradeTypeALPNPing { + return pingconn.New(conn), nil } - return nil + return conn, nil } diff --git a/lib/web/conn_upgrade.go b/lib/web/conn_upgrade.go index 8324aabdd5f2b..a2c85034c0f7f 100644 --- a/lib/web/conn_upgrade.go +++ b/lib/web/conn_upgrade.go @@ -104,7 +104,6 @@ func (h *Handler) upgradeALPNWithPing(ctx context.Context, conn net.Conn) error ctx, cancel := context.WithCancel(ctx) defer cancel() go h.startPing(ctx, pingConn) - time.Sleep(time.Millisecond * 100) return h.upgradeALPN(ctx, pingConn) } From e619b945894d1488d121b52cb7d4cc27c22c15f0 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Thu, 13 Apr 2023 11:01:57 -0400 Subject: [PATCH 21/27] remove WebProxyAddr --- api/client/client.go | 35 +++++++------------ api/client/contextdialer.go | 18 ++++++++-- api/client/proxy/client.go | 22 ++++++------ integration/proxy/proxy_test.go | 60 ++++++++++++++++++--------------- lib/auth/clt.go | 2 +- lib/client/api.go | 19 ++++++----- lib/client/client.go | 10 +++--- tool/tsh/proxy.go | 2 +- 8 files changed, 87 insertions(+), 81 deletions(-) diff --git a/api/client/client.go b/api/client/client.go index fd8b374d12f90..97f335feeefa1 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -326,11 +326,18 @@ type ( // authConnect connects to the Teleport Auth Server directly or through Proxy. func authConnect(ctx context.Context, params connectParams) (*Client, error) { - dialer := NewDialer(ctx, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, + opts := []DialOption{ WithInsecureSkipVerify(params.cfg.InsecureAddressDiscovery), - WithALPNConnUpgrade(IsWebProxyAndConnUpgradeRequired(ctx, params.addr, ¶ms.cfg)), - ) + } + + if params.cfg.IsALPNConnUpgradeRequiredFunc != nil { + opts = append(opts, + WithALPNConnUpgrade(params.cfg.IsALPNConnUpgradeRequiredFunc(params.addr, params.cfg.InsecureAddressDiscovery)), + WithALPNConnUpgradePing(true), // Use ping for long-lived connections. + ) + } + dialer := NewDialer(ctx, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, opts...) clt := newClient(params.cfg, dialer, params.tlsConfig) if err := clt.dialGRPC(ctx, params.addr); err != nil { return nil, trace.Wrap(err, "failed to connect to addr %v as an auth server", params.addr) @@ -338,16 +345,6 @@ func authConnect(ctx context.Context, params connectParams) (*Client, error) { return clt, nil } -// IsWebProxyAndConnUpgradeRequired returns true if targetAddr is a Teleport -// web proxy address and ALPN connection upgrade is required for it. If no -// cluster name is provided for ALPN, assume dialer is not trying to connect -// Auth through Proxy using TLS Routing. -func IsWebProxyAndConnUpgradeRequired(ctx context.Context, targetAddr string, cfg *Config) bool { - return cfg.ALPNSNIAuthDialClusterName != "" && - cfg.WebProxyAddr == targetAddr && - cfg.IsALPNConnUpgradeRequired(targetAddr, cfg.InsecureAddressDiscovery) -} - // tunnelConnect connects to the Teleport Auth Server through the proxy's reverse tunnel. func tunnelConnect(ctx context.Context, params connectParams) (*Client, error) { if params.sshConfig == nil { @@ -507,8 +504,6 @@ func (c *Client) waitForConnectionReady(ctx context.Context) error { type Config struct { // Addrs is a list of teleport auth/proxy server addresses to dial. Addrs []string - // WebProxyAddr is the Teleport Proxy web address. - WebProxyAddr string // Credentials are a list of credentials to use when attempting // to connect to the server. Credentials []Credentials @@ -540,9 +535,9 @@ type Config struct { CircuitBreakerConfig breaker.Config // Context is the base context to use for dialing. If not provided context.Background is used Context context.Context - // IsALPNConnUpgradeRequired is a callback function to check whether + // IsALPNConnUpgradeRequiredFunc is a callback function to check whether // connection upgrade is required for TLS Routing. - IsALPNConnUpgradeRequired func(addr string, insecure bool) bool + IsALPNConnUpgradeRequiredFunc func(addr string, insecure bool) bool } // CheckAndSetDefaults checks and sets default config values. @@ -576,12 +571,6 @@ func (c *Config) CheckAndSetDefaults() error { if !c.DialInBackground { c.DialOpts = append(c.DialOpts, grpc.WithBlock()) } - if c.IsALPNConnUpgradeRequired == nil { - c.IsALPNConnUpgradeRequired = IsALPNConnUpgradeRequired - } - if c.WebProxyAddr != "" { - c.Addrs = utils.Deduplicate(append(c.Addrs, c.WebProxyAddr)) - } return nil } diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index 7c3b976a42f27..34de5d0cd8098 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -37,8 +37,13 @@ import ( ) type dialConfig struct { - tlsConfig *tls.Config + tlsConfig *tls.Config + // alpnConnUpgradeRequired specifies if ALPN connection upgrade is + // required. alpnConnUpgradeRequired bool + // alpnConnUpgradeWithPing specifies if Ping is required during ALPN + // connection upgrade. This is only effective when alpnConnUpgradeRequired + // is true. alpnConnUpgradeWithPing bool } @@ -58,7 +63,8 @@ func WithALPNConnUpgrade(alpnConnUpgradeRequired bool) DialOption { } } -// WithALPNConnUpgradePing specifies if Ping is required during ALPN connection upgrade. +// WithALPNConnUpgradePing specifies if Ping is required during ALPN connection +// upgrade. This is only effective when alpnConnUpgradeRequired is true. func WithALPNConnUpgradePing(alpnConnUpgradeWithPing bool) DialOption { return func(cfg *dialProxyConfig) { cfg.alpnConnUpgradeWithPing = alpnConnUpgradeWithPing @@ -223,6 +229,12 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) Conte if err != nil { return nil, trace.Wrap(err) } + + isALPNConnUpgradeRequiredFunc := params.cfg.IsALPNConnUpgradeRequiredFunc + if isALPNConnUpgradeRequiredFunc == nil { + isALPNConnUpgradeRequiredFunc = IsALPNConnUpgradeRequired + } + conn, err := DialALPN(ctx, tunnelAddr, ALPNDialerConfig{ DialTimeout: params.cfg.DialTimeout, KeepAlivePeriod: params.cfg.KeepAlivePeriod, @@ -231,7 +243,7 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) Conte InsecureSkipVerify: insecure, ServerName: host, }, - ALPNConnUpgradeRequired: params.cfg.IsALPNConnUpgradeRequired(tunnelAddr, insecure), + ALPNConnUpgradeRequired: isALPNConnUpgradeRequiredFunc(tunnelAddr, insecure), GetClusterCAs: func(_ context.Context) (*x509.CertPool, error) { tlsConfig, err := params.cfg.Credentials[0].TLSConfig() if err != nil { diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index 262dee0260a17..10b79b322eb94 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -82,9 +82,9 @@ type ClientConfig struct { DialTimeout time.Duration // DialOpts define options for dialing the client connection. DialOpts []grpc.DialOption - // IsALPNConnUpgradeRequired is a callback function to check whether + // IsALPNConnUpgradeRequiredFunc is a callback function to check whether // connection upgrade is required for TLS Routing. - IsALPNConnUpgradeRequired func(addr string, insecure bool) bool + IsALPNConnUpgradeRequiredFunc func(addr string, insecure bool) bool // InsecureSkipVerify is an option to skip HTTPS cert check InsecureSkipVerify bool @@ -113,8 +113,8 @@ func (c *ClientConfig) CheckAndSetDefaults() error { if c.DialTimeout <= 0 { c.DialTimeout = defaults.DefaultIOTimeout } - if c.IsALPNConnUpgradeRequired == nil { - c.IsALPNConnUpgradeRequired = client.IsALPNConnUpgradeRequired + if c.IsALPNConnUpgradeRequiredFunc == nil { + c.IsALPNConnUpgradeRequiredFunc = client.IsALPNConnUpgradeRequired } if c.TLSConfig != nil { @@ -352,7 +352,7 @@ func newDialerForGRPCClient(ctx context.Context, cfg *ClientConfig) func(context if cfg.TLSRoutingEnabled { dialOpts = append(dialOpts, - client.WithALPNConnUpgrade(cfg.IsALPNConnUpgradeRequired(cfg.ProxyWebAddress, cfg.InsecureSkipVerify)), + client.WithALPNConnUpgrade(cfg.IsALPNConnUpgradeRequiredFunc(cfg.ProxyWebAddress, cfg.InsecureSkipVerify)), client.WithALPNConnUpgradePing(true), ) } @@ -463,12 +463,12 @@ func (c *Client) ClientConfig(ctx context.Context, cluster string) client.Config switch { case c.cfg.TLSRoutingEnabled: return client.Config{ - Context: ctx, - WebProxyAddr: c.cfg.ProxyWebAddress, - Credentials: []client.Credentials{c.cfg.clientCreds()}, - ALPNSNIAuthDialClusterName: cluster, - CircuitBreakerConfig: breaker.NoopBreakerConfig(), - IsALPNConnUpgradeRequired: c.cfg.IsALPNConnUpgradeRequired, + Context: ctx, + Addrs: []string{c.cfg.ProxyWebAddress}, + Credentials: []client.Credentials{c.cfg.clientCreds()}, + ALPNSNIAuthDialClusterName: cluster, + CircuitBreakerConfig: breaker.NoopBreakerConfig(), + IsALPNConnUpgradeRequiredFunc: c.cfg.IsALPNConnUpgradeRequiredFunc, } case c.sshClient != nil: return client.Config{ diff --git a/integration/proxy/proxy_test.go b/integration/proxy/proxy_test.go index c4fc82925aa5e..5899718d745a1 100644 --- a/integration/proxy/proxy_test.go +++ b/integration/proxy/proxy_test.go @@ -21,7 +21,6 @@ import ( "context" "crypto/x509" "crypto/x509/pkix" - "fmt" "net" "net/http" "net/http/httptest" @@ -1214,55 +1213,60 @@ func TestALPNProxyAuthClientConnectWithUserIdentity(t *testing.T) { // client can point to this ALB instead. albProxy := mustStartMockALBProxy(t, rc.Web) - connectTypes := []struct { + tests := []struct { name string clientConfig client.Config }{ { - name: "sync", + name: "sync connect to Proxy", clientConfig: client.Config{ + Addrs: []string{rc.Web}, Credentials: []client.Credentials{identity}, InsecureAddressDiscovery: true, }, }, { - name: "background", + name: "sync connect to Proxy behind ALB", clientConfig: client.Config{ + Addrs: []string{albProxy.Addr().String()}, + Credentials: []client.Credentials{identity}, + InsecureAddressDiscovery: true, + }, + }, + { + name: "background connect to Proxy", + clientConfig: client.Config{ + Addrs: []string{rc.Web}, Credentials: []client.Credentials{identity}, InsecureAddressDiscovery: true, DialInBackground: true, ALPNSNIAuthDialClusterName: cfg.ClusterName, }, }, - } - - targets := []struct { - name string - addr string - }{ { - name: "web", - addr: rc.Web, - }, - { - name: "alb", - addr: albProxy.Addr().String(), + name: "background connect to Proxy behind ALB", + clientConfig: client.Config{ + Addrs: []string{albProxy.Addr().String()}, + Credentials: []client.Credentials{identity}, + InsecureAddressDiscovery: true, + DialInBackground: true, + ALPNSNIAuthDialClusterName: cfg.ClusterName, + IsALPNConnUpgradeRequiredFunc: func(addr string, insecure bool) bool { + return addr == albProxy.Addr().String() + }, + }, }, } - for _, target := range targets { - for _, connectType := range connectTypes { - t.Run(fmt.Sprintf("%v/%v", target.name, connectType.name), func(t *testing.T) { - connectType.clientConfig.WebProxyAddr = target.addr - - tc, err := client.New(context.Background(), connectType.clientConfig) - require.NoError(t, err) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tc, err := client.New(context.Background(), test.clientConfig) + require.NoError(t, err) - resp, err := tc.Ping(context.Background()) - require.NoError(t, err) - require.Equal(t, rc.Secrets.SiteName, resp.ClusterName) - }) - } + resp, err := tc.Ping(context.Background()) + require.NoError(t, err) + require.Equal(t, rc.Secrets.SiteName, resp.ClusterName) + }) } } diff --git a/lib/auth/clt.go b/lib/auth/clt.go index e98dcc2bbe06b..3caecedc9f71b 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -103,7 +103,7 @@ func NewClient(cfg client.Config, params ...roundtrip.ClientParam) (*Client, err for _, addr := range cfg.Addrs { contextDialer := client.NewDialer(cfg.Context, cfg.KeepAlivePeriod, cfg.DialTimeout, client.WithInsecureSkipVerify(httpTLS.InsecureSkipVerify), - client.WithALPNConnUpgrade(client.IsWebProxyAndConnUpgradeRequired(ctx, addr, &cfg)), + client.WithALPNConnUpgrade(cfg.IsALPNConnUpgradeRequiredFunc != nil && cfg.IsALPNConnUpgradeRequiredFunc(addr, httpTLS.InsecureSkipVerify)), ) conn, err = contextDialer.DialContext(ctx, network, addr) if err == nil { diff --git a/lib/client/api.go b/lib/client/api.go index 06f7a893a892f..59087244ef15c 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2790,9 +2790,9 @@ func (tc *TeleportClient) ConnectToCluster(ctx context.Context) (*ClusterClient, clt, err := makeProxySSHClient(ctx, tc, config) return clt, trace.Wrap(err) }), - SSHConfig: cfg.ClientConfig, - IsALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequired, - InsecureSkipVerify: tc.InsecureSkipVerify, + SSHConfig: cfg.ClientConfig, + IsALPNConnUpgradeRequiredFunc: tc.IsALPNConnUpgradeRequiredForWebProxy, + InsecureSkipVerify: tc.InsecureSkipVerify, }) if err != nil { return nil, trace.Wrap(err) @@ -3043,7 +3043,7 @@ func makeProxySSHClientWithTLSWrapper(ctx context.Context, tc *TeleportClient, s tlsConfig.NextProtos = []string{string(alpncommon.ProtocolProxySSH)} dialer := proxy.DialerFromEnvironment(tc.Config.WebProxyAddr, proxy.WithALPNDialer(client.ALPNDialerConfig{ TLSConfig: tlsConfig, - ALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequired(proxyAddr, tlsConfig.InsecureSkipVerify), + ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired, DialTimeout: sshConfig.Timeout, })) return dialer.Dial(ctx, "tcp", proxyAddr, sshConfig) @@ -4617,15 +4617,16 @@ func (tc *TeleportClient) NewKubernetesServiceClient(ctx context.Context, cluste return kubeproto.NewKubeServiceClient(clt.GetConnection()), nil } -// IsALPNConnUpgradeRequired returns true if connection upgrade is required for -// provided addr. -func (tc *TeleportClient) IsALPNConnUpgradeRequired(addr string, insecure bool) bool { +// IsALPNConnUpgradeRequiredForWebProxy returns true if connection upgrade is +// required for provided addr. The provided address must be a web proxy +// address. +func (tc *TeleportClient) IsALPNConnUpgradeRequiredForWebProxy(proxyAddr string, insecure bool) bool { // Use cached value. - if addr == tc.WebProxyAddr { + if proxyAddr == tc.WebProxyAddr { return tc.TLSRoutingConnUpgradeRequired } // Do a test for other addresses. - return client.IsALPNConnUpgradeRequired(addr, insecure) + return client.IsALPNConnUpgradeRequired(proxyAddr, insecure) } // RootClusterCACertPool returns a *x509.CertPool with the root cluster CA. diff --git a/lib/client/client.go b/lib/client/client.go index 24760eb391786..1e5dad9a8edea 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -1158,14 +1158,14 @@ func (proxy *ProxyClient) ConnectToAuthServiceThroughALPNSNIProxy(ctx context.Co tlsConfig.InsecureSkipVerify = proxy.teleportClient.InsecureSkipVerify clt, err := auth.NewClient(client.Config{ - Context: ctx, - WebProxyAddr: proxyAddr, + Context: ctx, + Addrs: []string{proxyAddr}, Credentials: []client.Credentials{ client.LoadTLS(tlsConfig), }, - ALPNSNIAuthDialClusterName: clusterName, - CircuitBreakerConfig: breaker.NoopBreakerConfig(), - IsALPNConnUpgradeRequired: proxy.teleportClient.IsALPNConnUpgradeRequired, + ALPNSNIAuthDialClusterName: clusterName, + CircuitBreakerConfig: breaker.NoopBreakerConfig(), + IsALPNConnUpgradeRequiredFunc: proxy.teleportClient.IsALPNConnUpgradeRequiredForWebProxy, }) if err != nil { return nil, trace.Wrap(err) diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index 0eff8383d1154..abf2c9d2127cd 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -250,7 +250,7 @@ func dialSSHProxy(ctx context.Context, tc *libclient.TeleportClient, sp sshProxy InsecureSkipVerify: tc.InsecureSkipVerify, ServerName: sp.proxyHost, }, - ALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequired(remoteProxyAddr, tc.InsecureSkipVerify), + ALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequiredForWebProxy(remoteProxyAddr, tc.InsecureSkipVerify), }) default: From 9ada770b19208356d937054d53d1109b6547bc84 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Thu, 13 Apr 2023 14:48:49 -0400 Subject: [PATCH 22/27] require IsALPNConnUpgradeRequiredFunc --- api/client/client.go | 44 +++++++++++++++++++++++++++----------- api/client/proxy/client.go | 4 ++-- lib/client/client.go | 1 + 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/api/client/client.go b/api/client/client.go index 97f335feeefa1..de302b9820b59 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -150,28 +150,48 @@ func newClient(cfg Config, dialer ContextDialer, tlsConfig *tls.Config) *Client } // connectInBackground connects the client to the server in the background. -// The client will use the first credentials and the given dialer. If -// no dialer is given, the first address will be used. This address must -// be an auth server address. +// +// The client will use the first credentials and the given dialer. +// +// If no dialer is given, the first address will be used. If no +// ALPNSNIAuthDialClusterName is given, this address must be an auth server +// address. +// +// If ALPNSNIAuthDialClusterName is given, the address is expected to be a web +// proxy address and the client will connect auth through the web proxy server +// using TLS Routing. cfg.IsALPNConnUpgradeRequiredFunc must also be provided +// in this case assuming the caller has the context on whether ALPN connection +// upgrade is required. func connectInBackground(ctx context.Context, cfg Config) (*Client, error) { tlsConfig, err := cfg.Credentials[0].TLSConfig() if err != nil { return nil, trace.Wrap(err) } + + // Dialer connect. if cfg.Dialer != nil { - return dialerConnect(ctx, connectParams{ + client, err := dialerConnect(ctx, connectParams{ cfg: cfg, tlsConfig: tlsConfig, dialer: cfg.Dialer, }) - } else if len(cfg.Addrs) != 0 { - return authConnect(ctx, connectParams{ - cfg: cfg, - tlsConfig: tlsConfig, - addr: cfg.Addrs[0], - }) + return client, trace.Wrap(err) + } + + // Auth connect. + if len(cfg.Addrs) == 0 { + return nil, trace.BadParameter("must provide Dialer or Addrs in config") } - return nil, trace.BadParameter("must provide Dialer or Addrs in config") + if cfg.ALPNSNIAuthDialClusterName != "" && cfg.IsALPNConnUpgradeRequiredFunc == nil { + return nil, trace.BadParameter("must provide IsALPNConnUpgradeRequiredFunc for authConnect using TLS Routing") + } + + client, err := authConnect(ctx, connectParams{ + cfg: cfg, + tlsConfig: tlsConfig, + addr: cfg.Addrs[0], + }) + return client, trace.Wrap(err) } // connect connects the client to the server using the Credentials and @@ -333,7 +353,7 @@ func authConnect(ctx context.Context, params connectParams) (*Client, error) { if params.cfg.IsALPNConnUpgradeRequiredFunc != nil { opts = append(opts, WithALPNConnUpgrade(params.cfg.IsALPNConnUpgradeRequiredFunc(params.addr, params.cfg.InsecureAddressDiscovery)), - WithALPNConnUpgradePing(true), // Use ping for long-lived connections. + WithALPNConnUpgradePing(true), // Use Ping protocol for long-lived connections. ) } diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index 10b79b322eb94..afaa56ec522d6 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -113,8 +113,8 @@ func (c *ClientConfig) CheckAndSetDefaults() error { if c.DialTimeout <= 0 { c.DialTimeout = defaults.DefaultIOTimeout } - if c.IsALPNConnUpgradeRequiredFunc == nil { - c.IsALPNConnUpgradeRequiredFunc = client.IsALPNConnUpgradeRequired + if c.TLSRoutingEnabled && c.IsALPNConnUpgradeRequiredFunc == nil { + return trace.BadParameter("missing parameter IsALPNConnUpgradeRequiredFunc when TLS Routing is enabled") } if c.TLSConfig != nil { diff --git a/lib/client/client.go b/lib/client/client.go index 1e5dad9a8edea..7d5d4a4fe8de5 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -1271,6 +1271,7 @@ func (proxy *ProxyClient) NewTracingClient(ctx context.Context, clusterName stri case proxy.teleportClient.TLSRoutingEnabled: clientConfig.Addrs = []string{proxy.teleportClient.WebProxyAddr} clientConfig.ALPNSNIAuthDialClusterName = clusterName + clientConfig.IsALPNConnUpgradeRequiredFunc = proxy.teleportClient.IsALPNConnUpgradeRequiredForWebProxy default: clientConfig.Dialer = client.ContextDialerFunc(func(ctx context.Context, network, _ string) (net.Conn, error) { return proxy.dialAuthServer(ctx, clusterName) From 4e0c056023b25d6c169e0058c725526e040d1c27 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Thu, 13 Apr 2023 17:00:48 -0400 Subject: [PATCH 23/27] add tlsRoutingWithConnUpgradeConnect --- api/client/client.go | 70 ++++++++++++++++----------------- api/client/contextdialer.go | 67 +++++++++++++++++++++++++------ api/client/proxy/client.go | 36 ++++++----------- integration/proxy/proxy_test.go | 4 +- lib/auth/clt.go | 2 +- lib/client/api.go | 12 +++--- lib/client/client.go | 8 ++-- tool/tsh/proxy.go | 2 +- 8 files changed, 113 insertions(+), 88 deletions(-) diff --git a/api/client/client.go b/api/client/client.go index de302b9820b59..3fe6400b26b74 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -159,39 +159,26 @@ func newClient(cfg Config, dialer ContextDialer, tlsConfig *tls.Config) *Client // // If ALPNSNIAuthDialClusterName is given, the address is expected to be a web // proxy address and the client will connect auth through the web proxy server -// using TLS Routing. cfg.IsALPNConnUpgradeRequiredFunc must also be provided -// in this case assuming the caller has the context on whether ALPN connection -// upgrade is required. +// using TLS Routing. func connectInBackground(ctx context.Context, cfg Config) (*Client, error) { tlsConfig, err := cfg.Credentials[0].TLSConfig() if err != nil { return nil, trace.Wrap(err) } - - // Dialer connect. if cfg.Dialer != nil { - client, err := dialerConnect(ctx, connectParams{ + return dialerConnect(ctx, connectParams{ cfg: cfg, tlsConfig: tlsConfig, dialer: cfg.Dialer, }) - return client, trace.Wrap(err) - } - - // Auth connect. - if len(cfg.Addrs) == 0 { - return nil, trace.BadParameter("must provide Dialer or Addrs in config") - } - if cfg.ALPNSNIAuthDialClusterName != "" && cfg.IsALPNConnUpgradeRequiredFunc == nil { - return nil, trace.BadParameter("must provide IsALPNConnUpgradeRequiredFunc for authConnect using TLS Routing") + } else if len(cfg.Addrs) != 0 { + return authConnect(ctx, connectParams{ + cfg: cfg, + tlsConfig: tlsConfig, + addr: cfg.Addrs[0], + }) } - - client, err := authConnect(ctx, connectParams{ - cfg: cfg, - tlsConfig: tlsConfig, - addr: cfg.Addrs[0], - }) - return client, trace.Wrap(err) + return nil, trace.BadParameter("must provide Dialer or Addrs in config") } // connect connects the client to the server using the Credentials and @@ -278,7 +265,7 @@ func connect(ctx context.Context, cfg Config) (*Client, error) { addr: addr, }) if sshConfig != nil { - for _, cf := range []connectFunc{proxyConnect, tunnelConnect, tlsRoutingConnect} { + for _, cf := range []connectFunc{proxyConnect, tunnelConnect, tlsRoutingConnect, tlsRoutingWithConnUpgradeConnect} { syncConnect(ctx, cf, connectParams{ cfg: cfg, tlsConfig: tlsConfig, @@ -346,18 +333,12 @@ type ( // authConnect connects to the Teleport Auth Server directly or through Proxy. func authConnect(ctx context.Context, params connectParams) (*Client, error) { - opts := []DialOption{ + dialer := NewDialer(ctx, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, WithInsecureSkipVerify(params.cfg.InsecureAddressDiscovery), - } - - if params.cfg.IsALPNConnUpgradeRequiredFunc != nil { - opts = append(opts, - WithALPNConnUpgrade(params.cfg.IsALPNConnUpgradeRequiredFunc(params.addr, params.cfg.InsecureAddressDiscovery)), - WithALPNConnUpgradePing(true), // Use Ping protocol for long-lived connections. - ) - } + WithALPNConnUpgrade(params.cfg.ALPNConnUpgradeRequired), + WithALPNConnUpgradePing(true), // Use Ping protocol for long-lived connections. + ) - dialer := NewDialer(ctx, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, opts...) clt := newClient(params.cfg, dialer, params.tlsConfig) if err := clt.dialGRPC(ctx, params.addr); err != nil { return nil, trace.Wrap(err, "failed to connect to addr %v as an auth server", params.addr) @@ -396,7 +377,7 @@ func tlsRoutingConnect(ctx context.Context, params connectParams) (*Client, erro if params.sshConfig == nil { return nil, trace.BadParameter("must provide ssh client config") } - dialer := newTLSRoutingTunnelDialer(*params.sshConfig, params) + dialer := newTLSRoutingTunnelDialer(*params.sshConfig, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, params.addr, params.cfg.InsecureAddressDiscovery) clt := newClient(params.cfg, dialer, params.tlsConfig) if err := clt.dialGRPC(ctx, params.addr); err != nil { return nil, trace.Wrap(err, "failed to connect to addr %v with TLS Routing dialer", params.addr) @@ -404,6 +385,20 @@ func tlsRoutingConnect(ctx context.Context, params connectParams) (*Client, erro return clt, nil } +// tlsRoutingWithConnUpgradeConnect connects to the Teleport Auth Server +// through the proxy using TLS Routing with ALPN connection upgrade. +func tlsRoutingWithConnUpgradeConnect(ctx context.Context, params connectParams) (*Client, error) { + if params.sshConfig == nil { + return nil, trace.BadParameter("must provide ssh client config") + } + dialer := newTLSRoutingWithConnUpgradeDialer(*params.sshConfig, params) + clt := newClient(params.cfg, dialer, params.tlsConfig) + if err := clt.dialGRPC(ctx, params.addr); err != nil { + return nil, trace.Wrap(err, "failed to connect to addr %v with TLS Routing with ALPN connection upgrade dialer", params.addr) + } + return clt, nil +} + // dialerConnect connects to the Teleport Auth Server through a custom dialer. // The dialer must provide the address in a custom ContextDialerFunc function. func dialerConnect(ctx context.Context, params connectParams) (*Client, error) { @@ -555,9 +550,10 @@ type Config struct { CircuitBreakerConfig breaker.Config // Context is the base context to use for dialing. If not provided context.Background is used Context context.Context - // IsALPNConnUpgradeRequiredFunc is a callback function to check whether - // connection upgrade is required for TLS Routing. - IsALPNConnUpgradeRequiredFunc func(addr string, insecure bool) bool + // ALPNConnUpgradeRequired indicates that ALPN connection upgrades are + // required for making TLS routing requests. Only used in auth background + // dial. + ALPNConnUpgradeRequired bool } // CheckAndSetDefaults checks and sets default config values. diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index 34de5d0cd8098..5b129cbb7feab 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -206,36 +206,78 @@ func newTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Dur }) } +type isALPNConnUpgradeRequiredFunc func(string, bool) bool + // newTLSRoutingTunnelDialer makes a reverse tunnel TLS Routing dialer to connect to an Auth server // through the SSH reverse tunnel on the proxy. -func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) ContextDialer { - return ContextDialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { - insecure := params.cfg.InsecureAddressDiscovery - resp, err := webclient.Find(&webclient.Config{Context: ctx, ProxyAddr: params.addr, Insecure: insecure}) +func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool) ContextDialer { + return ContextDialerFunc(func(ctx context.Context, network, addr string) (conn net.Conn, err error) { + resp, err := webclient.Find(&webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure}) if err != nil { return nil, trace.Wrap(err) } - if !resp.Proxy.TLSRoutingEnabled { return nil, trace.NotImplemented("TLS routing is not enabled") } - tunnelAddr, err := resp.Proxy.TunnelAddr() if err != nil { return nil, trace.Wrap(err) } + dialer := &net.Dialer{ + Timeout: dialTimeout, + KeepAlive: keepAlivePeriod, + } + conn, err = dialer.DialContext(ctx, network, tunnelAddr) + if err != nil { + return nil, trace.Wrap(err) + } + host, _, err := webclient.ParseHostPort(tunnelAddr) if err != nil { return nil, trace.Wrap(err) } - isALPNConnUpgradeRequiredFunc := params.cfg.IsALPNConnUpgradeRequiredFunc - if isALPNConnUpgradeRequiredFunc == nil { - isALPNConnUpgradeRequiredFunc = IsALPNConnUpgradeRequired + tlsConn := tls.Client(conn, &tls.Config{ + NextProtos: []string{constants.ALPNSNIProtocolReverseTunnel}, + InsecureSkipVerify: insecure, + ServerName: host, + }) + if err := tlsConn.Handshake(); err != nil { + return nil, trace.Wrap(err) + } + + sconn, err := sshConnect(ctx, tlsConn, ssh, dialTimeout, tunnelAddr) + if err != nil { + return nil, trace.Wrap(err) + } + + return sconn, nil + }) +} + +// newTLSRoutingWithConnUpgradeDialer makes a reverse tunnel TLS Routing dialer +// through the web proxy with ALPN connection upgrade. +func newTLSRoutingWithConnUpgradeDialer(ssh ssh.ClientConfig, params connectParams) ContextDialer { + return ContextDialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { + insecure := params.cfg.InsecureAddressDiscovery + resp, err := webclient.Find(&webclient.Config{ + Context: ctx, + ProxyAddr: params.addr, + Insecure: insecure, + }) + if err != nil { + return nil, trace.Wrap(err) + } + if !resp.Proxy.TLSRoutingEnabled { + return nil, trace.NotImplemented("TLS routing is not enabled") } - conn, err := DialALPN(ctx, tunnelAddr, ALPNDialerConfig{ + host, _, err := webclient.ParseHostPort(params.addr) + if err != nil { + return nil, trace.Wrap(err) + } + conn, err := DialALPN(ctx, params.addr, ALPNDialerConfig{ DialTimeout: params.cfg.DialTimeout, KeepAlivePeriod: params.cfg.KeepAlivePeriod, TLSConfig: &tls.Config{ @@ -243,7 +285,7 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) Conte InsecureSkipVerify: insecure, ServerName: host, }, - ALPNConnUpgradeRequired: isALPNConnUpgradeRequiredFunc(tunnelAddr, insecure), + ALPNConnUpgradeRequired: IsALPNConnUpgradeRequired(params.addr, insecure), GetClusterCAs: func(_ context.Context) (*x509.CertPool, error) { tlsConfig, err := params.cfg.Credentials[0].TLSConfig() if err != nil { @@ -256,11 +298,10 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, params connectParams) Conte return nil, trace.Wrap(err) } - sconn, err := sshConnect(ctx, conn, ssh, params.cfg.DialTimeout, tunnelAddr) + sconn, err := sshConnect(ctx, conn, ssh, params.cfg.DialTimeout, params.addr) if err != nil { return nil, trace.Wrap(err) } - return sconn, nil }) } diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index afaa56ec522d6..e0f619fb451a0 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -82,9 +82,9 @@ type ClientConfig struct { DialTimeout time.Duration // DialOpts define options for dialing the client connection. DialOpts []grpc.DialOption - // IsALPNConnUpgradeRequiredFunc is a callback function to check whether - // connection upgrade is required for TLS Routing. - IsALPNConnUpgradeRequiredFunc func(addr string, insecure bool) bool + // ALPNConnUpgradeRequired indicates that ALPN connection upgrades are + // required for making TLS routing requests. + ALPNConnUpgradeRequired bool // InsecureSkipVerify is an option to skip HTTPS cert check InsecureSkipVerify bool @@ -113,10 +113,6 @@ func (c *ClientConfig) CheckAndSetDefaults() error { if c.DialTimeout <= 0 { c.DialTimeout = defaults.DefaultIOTimeout } - if c.TLSRoutingEnabled && c.IsALPNConnUpgradeRequiredFunc == nil { - return trace.BadParameter("missing parameter IsALPNConnUpgradeRequiredFunc when TLS Routing is enabled") - } - if c.TLSConfig != nil { c.clientCreds = func() client.Credentials { return client.LoadTLS(c.TLSConfig.Clone()) @@ -346,17 +342,11 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error } func newDialerForGRPCClient(ctx context.Context, cfg *ClientConfig) func(context.Context, string) (net.Conn, error) { - dialOpts := []client.DialOption{ + return client.GRPCContextDialer(client.NewDialer(ctx, defaults.DefaultIdleTimeout, cfg.DialTimeout, client.WithInsecureSkipVerify(cfg.InsecureSkipVerify), - } - - if cfg.TLSRoutingEnabled { - dialOpts = append(dialOpts, - client.WithALPNConnUpgrade(cfg.IsALPNConnUpgradeRequiredFunc(cfg.ProxyWebAddress, cfg.InsecureSkipVerify)), - client.WithALPNConnUpgradePing(true), - ) - } - return client.GRPCContextDialer(client.NewDialer(ctx, defaults.DefaultIdleTimeout, cfg.DialTimeout, dialOpts...)) + client.WithALPNConnUpgrade(cfg.ALPNConnUpgradeRequired), + client.WithALPNConnUpgradePing(true), // Use Ping protocol for long-lived connections. + )) } // teleportAuthority is the extension set by the server @@ -463,12 +453,12 @@ func (c *Client) ClientConfig(ctx context.Context, cluster string) client.Config switch { case c.cfg.TLSRoutingEnabled: return client.Config{ - Context: ctx, - Addrs: []string{c.cfg.ProxyWebAddress}, - Credentials: []client.Credentials{c.cfg.clientCreds()}, - ALPNSNIAuthDialClusterName: cluster, - CircuitBreakerConfig: breaker.NoopBreakerConfig(), - IsALPNConnUpgradeRequiredFunc: c.cfg.IsALPNConnUpgradeRequiredFunc, + Context: ctx, + Addrs: []string{c.cfg.ProxyWebAddress}, + Credentials: []client.Credentials{c.cfg.clientCreds()}, + ALPNSNIAuthDialClusterName: cluster, + CircuitBreakerConfig: breaker.NoopBreakerConfig(), + ALPNConnUpgradeRequired: c.cfg.ALPNConnUpgradeRequired, } case c.sshClient != nil: return client.Config{ diff --git a/integration/proxy/proxy_test.go b/integration/proxy/proxy_test.go index 5899718d745a1..710cd2c36b893 100644 --- a/integration/proxy/proxy_test.go +++ b/integration/proxy/proxy_test.go @@ -1251,9 +1251,7 @@ func TestALPNProxyAuthClientConnectWithUserIdentity(t *testing.T) { InsecureAddressDiscovery: true, DialInBackground: true, ALPNSNIAuthDialClusterName: cfg.ClusterName, - IsALPNConnUpgradeRequiredFunc: func(addr string, insecure bool) bool { - return addr == albProxy.Addr().String() - }, + ALPNConnUpgradeRequired: true, }, }, } diff --git a/lib/auth/clt.go b/lib/auth/clt.go index 3caecedc9f71b..5b194ec882fd3 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -103,7 +103,7 @@ func NewClient(cfg client.Config, params ...roundtrip.ClientParam) (*Client, err for _, addr := range cfg.Addrs { contextDialer := client.NewDialer(cfg.Context, cfg.KeepAlivePeriod, cfg.DialTimeout, client.WithInsecureSkipVerify(httpTLS.InsecureSkipVerify), - client.WithALPNConnUpgrade(cfg.IsALPNConnUpgradeRequiredFunc != nil && cfg.IsALPNConnUpgradeRequiredFunc(addr, httpTLS.InsecureSkipVerify)), + client.WithALPNConnUpgrade(cfg.ALPNConnUpgradeRequired), ) conn, err = contextDialer.DialContext(ctx, network, addr) if err == nil { diff --git a/lib/client/api.go b/lib/client/api.go index 59087244ef15c..8b591ced4cd37 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2790,9 +2790,9 @@ func (tc *TeleportClient) ConnectToCluster(ctx context.Context) (*ClusterClient, clt, err := makeProxySSHClient(ctx, tc, config) return clt, trace.Wrap(err) }), - SSHConfig: cfg.ClientConfig, - IsALPNConnUpgradeRequiredFunc: tc.IsALPNConnUpgradeRequiredForWebProxy, - InsecureSkipVerify: tc.InsecureSkipVerify, + SSHConfig: cfg.ClientConfig, + ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired, + InsecureSkipVerify: tc.InsecureSkipVerify, }) if err != nil { return nil, trace.Wrap(err) @@ -4620,13 +4620,13 @@ func (tc *TeleportClient) NewKubernetesServiceClient(ctx context.Context, cluste // IsALPNConnUpgradeRequiredForWebProxy returns true if connection upgrade is // required for provided addr. The provided address must be a web proxy // address. -func (tc *TeleportClient) IsALPNConnUpgradeRequiredForWebProxy(proxyAddr string, insecure bool) bool { +func (tc *TeleportClient) IsALPNConnUpgradeRequiredForWebProxy(proxyAddr string) bool { // Use cached value. if proxyAddr == tc.WebProxyAddr { return tc.TLSRoutingConnUpgradeRequired } - // Do a test for other addresses. - return client.IsALPNConnUpgradeRequired(proxyAddr, insecure) + // Do a test for other proxy addresses. + return client.IsALPNConnUpgradeRequired(proxyAddr, tc.InsecureSkipVerify) } // RootClusterCACertPool returns a *x509.CertPool with the root cluster CA. diff --git a/lib/client/client.go b/lib/client/client.go index 7d5d4a4fe8de5..c0e2dec6b0a86 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -1163,9 +1163,9 @@ func (proxy *ProxyClient) ConnectToAuthServiceThroughALPNSNIProxy(ctx context.Co Credentials: []client.Credentials{ client.LoadTLS(tlsConfig), }, - ALPNSNIAuthDialClusterName: clusterName, - CircuitBreakerConfig: breaker.NoopBreakerConfig(), - IsALPNConnUpgradeRequiredFunc: proxy.teleportClient.IsALPNConnUpgradeRequiredForWebProxy, + ALPNSNIAuthDialClusterName: clusterName, + CircuitBreakerConfig: breaker.NoopBreakerConfig(), + ALPNConnUpgradeRequired: proxy.teleportClient.IsALPNConnUpgradeRequiredForWebProxy(proxyAddr), }) if err != nil { return nil, trace.Wrap(err) @@ -1271,7 +1271,7 @@ func (proxy *ProxyClient) NewTracingClient(ctx context.Context, clusterName stri case proxy.teleportClient.TLSRoutingEnabled: clientConfig.Addrs = []string{proxy.teleportClient.WebProxyAddr} clientConfig.ALPNSNIAuthDialClusterName = clusterName - clientConfig.IsALPNConnUpgradeRequiredFunc = proxy.teleportClient.IsALPNConnUpgradeRequiredForWebProxy + clientConfig.ALPNConnUpgradeRequired = proxy.teleportClient.TLSRoutingConnUpgradeRequired default: clientConfig.Dialer = client.ContextDialerFunc(func(ctx context.Context, network, _ string) (net.Conn, error) { return proxy.dialAuthServer(ctx, clusterName) diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index abf2c9d2127cd..23e368265def5 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -250,7 +250,7 @@ func dialSSHProxy(ctx context.Context, tc *libclient.TeleportClient, sp sshProxy InsecureSkipVerify: tc.InsecureSkipVerify, ServerName: sp.proxyHost, }, - ALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequiredForWebProxy(remoteProxyAddr, tc.InsecureSkipVerify), + ALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequiredForWebProxy(remoteProxyAddr), }) default: From a3d664f3f51b0b1c134d9efe6a98093e95d4d5dc Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Thu, 13 Apr 2023 18:08:37 -0400 Subject: [PATCH 24/27] fix lint --- api/client/contextdialer.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index 5b129cbb7feab..000f00fa37db0 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -206,8 +206,6 @@ func newTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Dur }) } -type isALPNConnUpgradeRequiredFunc func(string, bool) bool - // newTLSRoutingTunnelDialer makes a reverse tunnel TLS Routing dialer to connect to an Auth server // through the SSH reverse tunnel on the proxy. func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool) ContextDialer { From e6b482514f67e279b12bedaaa08c3cb8a05be369 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Fri, 14 Apr 2023 09:10:45 -0400 Subject: [PATCH 25/27] simplify --- api/client/alpn.go | 41 ++++++++++++++++++-------------- api/client/client.go | 11 +++++++-- api/client/contextdialer.go | 2 ++ lib/client/client.go | 4 ++++ lib/reversetunnel/agentpool.go | 10 +++++++- lib/srv/alpnproxy/local_proxy.go | 11 +++++++-- 6 files changed, 56 insertions(+), 23 deletions(-) diff --git a/api/client/alpn.go b/api/client/alpn.go index 91eb8d0e8e794..91a8d968b64d6 100644 --- a/api/client/alpn.go +++ b/api/client/alpn.go @@ -25,11 +25,9 @@ import ( "time" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" - "github.com/gravitational/teleport/api/utils/pingconn" ) // GetClusterCAsFunc is a function to fetch cluster CAs. @@ -78,10 +76,22 @@ func NewALPNDialer(cfg ALPNDialerConfig) ContextDialer { func (d *ALPNDialer) shouldUpdateTLSConfig() bool { return d.shouldUpdateServerName() || d.shouldGetClusterCAs() } + +// shouldUpdateServerName returns true if ServerName is not in the provided TLS +// config. It will default to the host of the dialing address. func (d *ALPNDialer) shouldUpdateServerName() bool { return d.cfg.TLSConfig.ServerName == "" } +// shouldGetClusterCAs returns true if RootCAs of the provided TLS config needs +// to be set to the Teleport cluster CAs. +// +// When Teleport Proxy is behind a L7 load balancer, the load balancer +// usually terminates TLS with public certs, and the Proxy is usually in +// private subnets with self-signed web certs. During the connection +// upgrade flow for TLS Routing, instead of serving these self-signed web +// certs, the TLS Routing handler at the Proxy server will present the +// Cluster CAs so clients here can still verify the server. func (d *ALPNDialer) shouldGetClusterCAs() bool { return d.cfg.ALPNConnUpgradeRequired && d.cfg.TLSConfig.RootCAs == nil && d.cfg.GetClusterCAs != nil } @@ -96,20 +106,12 @@ func (d *ALPNDialer) getTLSConfig(ctx context.Context, addr string) (*tls.Config var err error tlsConfig := d.cfg.TLSConfig.Clone() - - // When Teleport Proxy is behind a L7 load balancer, the load balancer - // usually terminates TLS with public certs, and the Proxy is usually in - // private subnets with self-signed web certs. During the connection - // upgrade flow for TLS Routing, instead of serving these self-signed web - // certs, the TLS Routing handler at the Proxy server will present the - // Cluster CAs so clients here can still verify the server. if d.shouldGetClusterCAs() { tlsConfig.RootCAs, err = d.cfg.GetClusterCAs(ctx) if err != nil { return nil, trace.Wrap(err) } } - if d.shouldUpdateServerName() { tlsConfig.ServerName, _, err = webclient.ParseHostPort(addr) if err != nil { @@ -142,18 +144,21 @@ func (d *ALPNDialer) DialContext(ctx context.Context, network, addr string) (net defer tlsConn.Close() return nil, trace.Wrap(err) } - - if IsALPNPingProtocol(tlsConn.ConnectionState().NegotiatedProtocol) { - logrus.Debugf("Using ping connection for protocol %v.", tlsConn.ConnectionState().NegotiatedProtocol) - return pingconn.NewTLS(tlsConn), nil - } return tlsConn, nil } -// DialALPN a helper to dial using an ALPNDialer. -func DialALPN(ctx context.Context, addr string, cfg ALPNDialerConfig) (net.Conn, error) { +// DialALPN a helper to dial using an ALPNDialer and returns a tls.Conn if +// successful. +func DialALPN(ctx context.Context, addr string, cfg ALPNDialerConfig) (*tls.Conn, error) { conn, err := NewALPNDialer(cfg).DialContext(ctx, "tcp", addr) - return conn, trace.Wrap(err) + if err != nil { + return nil, trace.Wrap(err) + } + tlsConn, ok := conn.(*tls.Conn) + if !ok { + return nil, trace.BadParameter("failed to convert to tls.Conn") + } + return tlsConn, nil } // IsALPNPingProtocol checks if the provided protocol is suffixed with Ping. diff --git a/api/client/client.go b/api/client/client.go index 3fe6400b26b74..03b2f89d19ca5 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -551,8 +551,15 @@ type Config struct { // Context is the base context to use for dialing. If not provided context.Background is used Context context.Context // ALPNConnUpgradeRequired indicates that ALPN connection upgrades are - // required for making TLS routing requests. Only used in auth background - // dial. + // required for making TLS Routing requests. + // + // In DialInBackground mode without a Dialer, a valid value must be + // provided as it's assumed that the caller knows the context if connection + // upgrades are required for TLS Routing. + // + // In default mode, this value is optional as some of the connect methods + // will perform necessary tests to decide if connection upgrade is + // required. ALPNConnUpgradeRequired bool } diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index 000f00fa37db0..693ac9ebd44d3 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -214,9 +214,11 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeou if err != nil { return nil, trace.Wrap(err) } + if !resp.Proxy.TLSRoutingEnabled { return nil, trace.NotImplemented("TLS routing is not enabled") } + tunnelAddr, err := resp.Proxy.TunnelAddr() if err != nil { return nil, trace.Wrap(err) diff --git a/lib/client/client.go b/lib/client/client.go index c0e2dec6b0a86..8933a5f335819 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/trace" "github.com/moby/term" + "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/propagation" oteltrace "go.opentelemetry.io/otel/trace" @@ -1255,6 +1256,7 @@ func (proxy *ProxyClient) ConnectToCluster(ctx context.Context, clusterName stri // It returns a connected and authenticated tracing.Client that will export spans // to the auth server, where they will be forwarded onto the configured exporter. func (proxy *ProxyClient) NewTracingClient(ctx context.Context, clusterName string) (*tracing.Client, error) { + logrus.Debugf("=== new tracing client") tlsConfig, err := proxy.loadTLS(clusterName) if err != nil { return nil, trace.Wrap(err) @@ -1269,6 +1271,7 @@ func (proxy *ProxyClient) NewTracingClient(ctx context.Context, clusterName stri switch { case proxy.teleportClient.TLSRoutingEnabled: + logrus.Debugf("=== new tracing client tls") clientConfig.Addrs = []string{proxy.teleportClient.WebProxyAddr} clientConfig.ALPNSNIAuthDialClusterName = clusterName clientConfig.ALPNConnUpgradeRequired = proxy.teleportClient.TLSRoutingConnUpgradeRequired @@ -1279,6 +1282,7 @@ func (proxy *ProxyClient) NewTracingClient(ctx context.Context, clusterName stri } clt, err := client.NewTracingClient(ctx, clientConfig) + logrus.Debugf("=== new tracing client err %v", err) return clt, trace.Wrap(err) } diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index f2f56ff3965cb..b8f5be2f5b5fd 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -631,13 +631,21 @@ func (c *agentPoolRuntimeConfig) restrictConnectionCount() bool { return c.tunnelStrategyType == types.ProxyPeering } +// useReverseTunnelV2Locked returns true if reverse tunnel should be used. +func (c *agentPoolRuntimeConfig) useReverseTunnelV2Locked() bool { + if c.isRemoteCluster { + return false + } + return c.tunnelStrategyType == types.ProxyPeering +} + // alpnDialerConfig creates a config for ALPN dialer. func (c *agentPoolRuntimeConfig) alpnDialerConfig(getClusterCAs client.GetClusterCAsFunc) client.ALPNDialerConfig { c.mu.RLock() defer c.mu.RUnlock() protocols := []alpncommon.Protocol{alpncommon.ProtocolReverseTunnel} - if !c.isRemoteCluster && c.tunnelStrategyType == types.ProxyPeering { + if c.useReverseTunnelV2Locked() { protocols = []alpncommon.Protocol{alpncommon.ProtocolReverseTunnelV2, alpncommon.ProtocolReverseTunnel} } diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index be519eb337af3..813ae59e7c0a8 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -35,6 +35,7 @@ import ( "golang.org/x/exp/slices" "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/utils/pingconn" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" commonApp "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/tlsca" @@ -219,11 +220,17 @@ func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamC return trace.Wrap(err) } - upstreamConn, err := client.DialALPN(ctx, l.cfg.RemoteProxyAddr, l.getALPNDialerConfig(certs)) + tlsConn, err := client.DialALPN(ctx, l.cfg.RemoteProxyAddr, l.getALPNDialerConfig(certs)) if err != nil { return trace.Wrap(err) } - defer upstreamConn.Close() + defer tlsConn.Close() + + var upstreamConn net.Conn = tlsConn + if common.IsPingProtocol(common.Protocol(tlsConn.ConnectionState().NegotiatedProtocol)) { + l.cfg.Log.Debug("Using ping connection") + upstreamConn = pingconn.NewTLS(tlsConn) + } return trace.Wrap(utils.ProxyConn(ctx, downstreamConn, upstreamConn)) } From fd185ad7ef69932f56bf3bfa4dc3e46c2bc8b53f Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Fri, 14 Apr 2023 10:44:51 -0400 Subject: [PATCH 26/27] remove debug log and change unknown upgrade type to 404 --- lib/client/client.go | 4 ---- lib/web/conn_upgrade.go | 2 +- lib/web/conn_upgrade_test.go | 2 +- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/lib/client/client.go b/lib/client/client.go index 8933a5f335819..c0e2dec6b0a86 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -33,7 +33,6 @@ import ( "github.com/gravitational/trace" "github.com/moby/term" - "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/propagation" oteltrace "go.opentelemetry.io/otel/trace" @@ -1256,7 +1255,6 @@ func (proxy *ProxyClient) ConnectToCluster(ctx context.Context, clusterName stri // It returns a connected and authenticated tracing.Client that will export spans // to the auth server, where they will be forwarded onto the configured exporter. func (proxy *ProxyClient) NewTracingClient(ctx context.Context, clusterName string) (*tracing.Client, error) { - logrus.Debugf("=== new tracing client") tlsConfig, err := proxy.loadTLS(clusterName) if err != nil { return nil, trace.Wrap(err) @@ -1271,7 +1269,6 @@ func (proxy *ProxyClient) NewTracingClient(ctx context.Context, clusterName stri switch { case proxy.teleportClient.TLSRoutingEnabled: - logrus.Debugf("=== new tracing client tls") clientConfig.Addrs = []string{proxy.teleportClient.WebProxyAddr} clientConfig.ALPNSNIAuthDialClusterName = clusterName clientConfig.ALPNConnUpgradeRequired = proxy.teleportClient.TLSRoutingConnUpgradeRequired @@ -1282,7 +1279,6 @@ func (proxy *ProxyClient) NewTracingClient(ctx context.Context, clusterName stri } clt, err := client.NewTracingClient(ctx, clientConfig) - logrus.Debugf("=== new tracing client err %v", err) return clt, trace.Wrap(err) } diff --git a/lib/web/conn_upgrade.go b/lib/web/conn_upgrade.go index a2c85034c0f7f..fdb03332b041a 100644 --- a/lib/web/conn_upgrade.go +++ b/lib/web/conn_upgrade.go @@ -45,7 +45,7 @@ func (h *Handler) selectConnectionUpgrade(r *http.Request) (string, ConnectionHa } } - return "", nil, trace.BadParameter("unsupported upgrade types: %v", upgrades) + return "", nil, trace.NotFound("unsupported upgrade types: %v", upgrades) } // connectionUpgrade handles connection upgrades. diff --git a/lib/web/conn_upgrade_test.go b/lib/web/conn_upgrade_test.go index 3705d5e23138c..76b566beeeace 100644 --- a/lib/web/conn_upgrade_test.go +++ b/lib/web/conn_upgrade_test.go @@ -81,7 +81,7 @@ func TestHandlerConnectionUpgrade(t *testing.T) { r.Header.Add("Upgrade", "unsupported-protocol") _, err = h.connectionUpgrade(httptest.NewRecorder(), r, nil) - require.True(t, trace.IsBadParameter(err)) + require.True(t, trace.IsNotFound(err)) }) t.Run("upgraded to ALPN", func(t *testing.T) { From 9165f121022caf427d40852a9842833e737024c8 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Fri, 14 Apr 2023 11:04:50 -0400 Subject: [PATCH 27/27] Force new proxy client to use web proxy when TLS routing is enabled --- api/client/proxy/client.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index e0f619fb451a0..678e4a441264c 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -297,10 +297,16 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error dialCtx, cancel := context.WithTimeout(ctx, cfg.DialTimeout) defer cancel() + // Dial web proxy with TLS Routing. + addr := cfg.ProxySSHAddress + if cfg.TLSRoutingEnabled { + addr = cfg.ProxyWebAddress + } + c := &clusterName{} conn, err := grpc.DialContext( dialCtx, - cfg.ProxySSHAddress, + addr, append([]grpc.DialOption{ grpc.WithContextDialer(newDialerForGRPCClient(ctx, cfg)), grpc.WithTransportCredentials(&clusterCredentials{TransportCredentials: cfg.creds(), clusterName: c}),