Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 78 additions & 56 deletions lib/auth/authclient/authclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,32 @@ type Config struct {
func Connect(ctx context.Context, cfg *Config) (auth.ClientI, error) {
cfg.Log.Debugf("Connecting to: %v.", cfg.AuthServers)

directClient, err := connectViaAuthDirect(cfg)
if err == nil {
return directClient, nil
}
directErr := trace.Wrap(err, "failed direct dial to auth server: %v", err)

// If it fails, we now want to try tunneling to the auth server through a
// proxy, we can only do this with SSH credentials.
if cfg.SSH == nil {
return nil, trace.Wrap(directErr)
}
proxyTunnelClient, err := connectViaProxyTunnel(ctx, cfg)
if err == nil {
return proxyTunnelClient, nil
}
proxyTunnelErr := trace.Wrap(err, "failed dial to auth server through reverse tunnel: %v", err)

return nil, trace.NewAggregate(
directErr,
proxyTunnelErr,
)
}

func connectViaAuthDirect(cfg *Config) (auth.ClientI, error) {
// Try connecting to the auth server directly over TLS.
client, err := auth.NewClient(apiclient.Config{
directDialClient, err := auth.NewClient(apiclient.Config{
Addrs: utils.NetAddrsToStrings(cfg.AuthServers),
Credentials: []apiclient.Credentials{
apiclient.LoadTLS(cfg.TLS),
Expand All @@ -67,64 +91,62 @@ func Connect(ctx context.Context, cfg *Config) (auth.ClientI, error) {
DialTimeout: cfg.DialTimeout,
})
if err != nil {
return nil, trace.Wrap(err, "failed direct dial to auth server: %v", err)
return nil, trace.Wrap(err)
}

// Check connectivity by calling something on the client.
_, err = client.GetClusterName()
if _, err := directDialClient.GetClusterName(); err != nil {
// This client didn't work for us, so we close it.
_ = directDialClient.Close()
return nil, trace.Wrap(err)
}
return directDialClient, nil
}

func connectViaProxyTunnel(ctx context.Context, cfg *Config) (auth.ClientI, error) {
// If direct dial failed, we may have a proxy address in
// cfg.AuthServers. Try connecting to the reverse tunnel
// endpoint and make a client over that.
//
// TODO(nic): this logic should be implemented once and reused in IoT
// nodes.
resolver := reversetunnel.WebClientResolver(&webclient.Config{
Context: ctx,
ProxyAddr: cfg.AuthServers[0].String(),
Insecure: cfg.TLS.InsecureSkipVerify,
Timeout: cfg.DialTimeout,
})

resolver, err := reversetunnel.CachingResolver(ctx, resolver, nil /* clock */)
if err != nil {
directDialErr := trace.Wrap(err, "failed direct dial to auth server: %v", err)
if cfg.SSH == nil {
// No identity file was provided, don't try dialing via a reverse
// tunnel on the proxy.
return nil, trace.Wrap(directDialErr)
}

// If direct dial failed, we may have a proxy address in
// cfg.AuthServers. Try connecting to the reverse tunnel
// endpoint and make a client over that.
//
// TODO(nic): this logic should be implemented once and reused in IoT
// nodes.

resolver := reversetunnel.WebClientResolver(&webclient.Config{
Context: ctx,
ProxyAddr: cfg.AuthServers[0].String(),
Insecure: cfg.TLS.InsecureSkipVerify,
Timeout: cfg.DialTimeout,
})

resolver, err = reversetunnel.CachingResolver(ctx, resolver, nil /* clock */)
if err != nil {
return nil, trace.Wrap(err)
}

// reversetunnel.TunnelAuthDialer will take care of creating a net.Conn
// within an SSH tunnel.
dialer, err := reversetunnel.NewTunnelAuthDialer(reversetunnel.TunnelAuthDialerConfig{
Resolver: resolver,
ClientConfig: cfg.SSH,
Log: cfg.Log,
InsecureSkipTLSVerify: cfg.TLS.InsecureSkipVerify,
})
if err != nil {
return nil, trace.Wrap(err)
}
client, err = auth.NewClient(apiclient.Config{
Dialer: dialer,
Credentials: []apiclient.Credentials{
apiclient.LoadTLS(cfg.TLS),
},
})
if err != nil {
tunnelClientErr := trace.Wrap(err, "failed dial to auth server through reverse tunnel: %v", err)
return nil, trace.NewAggregate(directDialErr, tunnelClientErr)
}
// Check connectivity by calling something on the client.
if _, err := client.GetClusterName(); err != nil {
tunnelClientErr := trace.Wrap(err, "failed dial to auth server through reverse tunnel: %v", err)
return nil, trace.NewAggregate(directDialErr, tunnelClientErr)
}
return nil, trace.Wrap(err)
}

// reversetunnel.TunnelAuthDialer will take care of creating a net.Conn
// within an SSH tunnel.
dialer, err := reversetunnel.NewTunnelAuthDialer(reversetunnel.TunnelAuthDialerConfig{
Resolver: resolver,
ClientConfig: cfg.SSH,
Log: cfg.Log,
InsecureSkipTLSVerify: cfg.TLS.InsecureSkipVerify,
})
if err != nil {
return nil, trace.Wrap(err)
}
tunnelClient, err := auth.NewClient(apiclient.Config{
Dialer: dialer,
Credentials: []apiclient.Credentials{
apiclient.LoadTLS(cfg.TLS),
},
})
if err != nil {
return nil, trace.Wrap(err)
}
// Check connectivity by calling something on the client.
if _, err := tunnelClient.GetClusterName(); err != nil {
// This client didn't work for us, so we close it.
_ = tunnelClient.Close()
return nil, trace.Wrap(err)
}
return client, nil
return tunnelClient, nil
}