diff --git a/lib/auth/authclient/authclient.go b/lib/auth/authclient/authclient.go index c76ca1f3e4b3a..f2731cba6f704 100644 --- a/lib/auth/authclient/authclient.go +++ b/lib/auth/authclient/authclient.go @@ -56,8 +56,32 @@ type Config struct { func Connect(ctx context.Context, cfg *Config) (auth.ClientI, error) { cfg.Log.Debugf("Connecting to: %v.", cfg.AuthServers) + directClient, err := connectViaAuthDirect(cfg) + if err == nil { + return directClient, nil + } + directErr := trace.Wrap(err, "failed direct dial to auth server: %v", err) + + // If it fails, we now want to try tunneling to the auth server through a + // proxy, we can only do this with SSH credentials. + if cfg.SSH == nil { + return nil, trace.Wrap(directErr) + } + proxyTunnelClient, err := connectViaProxyTunnel(ctx, cfg) + if err == nil { + return proxyTunnelClient, nil + } + proxyTunnelErr := trace.Wrap(err, "failed dial to auth server through reverse tunnel: %v", err) + + return nil, trace.NewAggregate( + directErr, + proxyTunnelErr, + ) +} + +func connectViaAuthDirect(cfg *Config) (auth.ClientI, error) { // Try connecting to the auth server directly over TLS. - client, err := auth.NewClient(apiclient.Config{ + directDialClient, err := auth.NewClient(apiclient.Config{ Addrs: utils.NetAddrsToStrings(cfg.AuthServers), Credentials: []apiclient.Credentials{ apiclient.LoadTLS(cfg.TLS), @@ -67,64 +91,62 @@ func Connect(ctx context.Context, cfg *Config) (auth.ClientI, error) { DialTimeout: cfg.DialTimeout, }) if err != nil { - return nil, trace.Wrap(err, "failed direct dial to auth server: %v", err) + return nil, trace.Wrap(err) } // Check connectivity by calling something on the client. - _, err = client.GetClusterName() + if _, err := directDialClient.GetClusterName(); err != nil { + // This client didn't work for us, so we close it. + _ = directDialClient.Close() + return nil, trace.Wrap(err) + } + return directDialClient, nil +} + +func connectViaProxyTunnel(ctx context.Context, cfg *Config) (auth.ClientI, error) { + // If direct dial failed, we may have a proxy address in + // cfg.AuthServers. Try connecting to the reverse tunnel + // endpoint and make a client over that. + // + // TODO(nic): this logic should be implemented once and reused in IoT + // nodes. + resolver := reversetunnel.WebClientResolver(&webclient.Config{ + Context: ctx, + ProxyAddr: cfg.AuthServers[0].String(), + Insecure: cfg.TLS.InsecureSkipVerify, + Timeout: cfg.DialTimeout, + }) + + resolver, err := reversetunnel.CachingResolver(ctx, resolver, nil /* clock */) if err != nil { - directDialErr := trace.Wrap(err, "failed direct dial to auth server: %v", err) - if cfg.SSH == nil { - // No identity file was provided, don't try dialing via a reverse - // tunnel on the proxy. - return nil, trace.Wrap(directDialErr) - } - - // If direct dial failed, we may have a proxy address in - // cfg.AuthServers. Try connecting to the reverse tunnel - // endpoint and make a client over that. - // - // TODO(nic): this logic should be implemented once and reused in IoT - // nodes. - - resolver := reversetunnel.WebClientResolver(&webclient.Config{ - Context: ctx, - ProxyAddr: cfg.AuthServers[0].String(), - Insecure: cfg.TLS.InsecureSkipVerify, - Timeout: cfg.DialTimeout, - }) - - resolver, err = reversetunnel.CachingResolver(ctx, resolver, nil /* clock */) - if err != nil { - return nil, trace.Wrap(err) - } - - // reversetunnel.TunnelAuthDialer will take care of creating a net.Conn - // within an SSH tunnel. - dialer, err := reversetunnel.NewTunnelAuthDialer(reversetunnel.TunnelAuthDialerConfig{ - Resolver: resolver, - ClientConfig: cfg.SSH, - Log: cfg.Log, - InsecureSkipTLSVerify: cfg.TLS.InsecureSkipVerify, - }) - if err != nil { - return nil, trace.Wrap(err) - } - client, err = auth.NewClient(apiclient.Config{ - Dialer: dialer, - Credentials: []apiclient.Credentials{ - apiclient.LoadTLS(cfg.TLS), - }, - }) - if err != nil { - tunnelClientErr := trace.Wrap(err, "failed dial to auth server through reverse tunnel: %v", err) - return nil, trace.NewAggregate(directDialErr, tunnelClientErr) - } - // Check connectivity by calling something on the client. - if _, err := client.GetClusterName(); err != nil { - tunnelClientErr := trace.Wrap(err, "failed dial to auth server through reverse tunnel: %v", err) - return nil, trace.NewAggregate(directDialErr, tunnelClientErr) - } + return nil, trace.Wrap(err) + } + + // reversetunnel.TunnelAuthDialer will take care of creating a net.Conn + // within an SSH tunnel. + dialer, err := reversetunnel.NewTunnelAuthDialer(reversetunnel.TunnelAuthDialerConfig{ + Resolver: resolver, + ClientConfig: cfg.SSH, + Log: cfg.Log, + InsecureSkipTLSVerify: cfg.TLS.InsecureSkipVerify, + }) + if err != nil { + return nil, trace.Wrap(err) + } + tunnelClient, err := auth.NewClient(apiclient.Config{ + Dialer: dialer, + Credentials: []apiclient.Credentials{ + apiclient.LoadTLS(cfg.TLS), + }, + }) + if err != nil { + return nil, trace.Wrap(err) + } + // Check connectivity by calling something on the client. + if _, err := tunnelClient.GetClusterName(); err != nil { + // This client didn't work for us, so we close it. + _ = tunnelClient.Close() + return nil, trace.Wrap(err) } - return client, nil + return tunnelClient, nil }