Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always use auth_servers when proxying an auth connection #13310

Merged
merged 9 commits into from
Jun 10, 2022
33 changes: 13 additions & 20 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import (
"golang.org/x/crypto/ssh"
)

func newlocalSite(srv *server, domainName string, client auth.ClientI, peerClient *proxy.Client) (*localSite, error) {
func newlocalSite(srv *server, domainName string, authServers []string, client auth.ClientI, peerClient *proxy.Client) (*localSite, error) {
err := utils.RegisterPrometheusCollectors(localClusterCollectors...)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -68,6 +68,7 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI, peerClien
accessPoint: accessPoint,
certificateCache: certificateCache,
domainName: domainName,
authServers: authServers,
remoteConns: make(map[connKey][]*remoteConn),
clock: srv.Clock,
log: log.WithFields(log.Fields{
Expand All @@ -91,9 +92,10 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI, peerClien
//
// it implements RemoteSite interface
type localSite struct {
log log.FieldLogger
domainName string
srv *server
log log.FieldLogger
domainName string
authServers []string
srv *server

// client provides access to the Auth Server API of the local cluster.
client auth.ClientI
Expand Down Expand Up @@ -164,27 +166,18 @@ func (s *localSite) GetLastConnected() time.Time {
return s.clock.Now()
}

func (s *localSite) DialAuthServer() (conn net.Conn, err error) {
// get list of local auth servers
authServers, err := s.client.GetAuthServers()
if err != nil {
return nil, trace.Wrap(err)
}

if len(authServers) < 1 {
func (s *localSite) DialAuthServer() (net.Conn, error) {
if len(s.authServers) == 0 {
return nil, trace.ConnectionProblem(nil, "no auth servers available")
}

// try and dial to one of them, as soon as we are successful, return the net.Conn
for _, authServer := range authServers {
espadolini marked this conversation as resolved.
Show resolved Hide resolved
conn, err = net.DialTimeout("tcp", authServer.GetAddr(), apidefaults.DefaultDialTimeout)
if err == nil {
return conn, nil
}
addr := utils.ChooseRandomString(s.authServers)
conn, err := net.DialTimeout("tcp", addr, apidefaults.DefaultDialTimeout)
if err != nil {
return nil, trace.ConnectionProblem(err, "unable to connect to auth server")
}

// return the last error
return nil, trace.ConnectionProblem(err, "unable to connect to auth server")
return conn, nil
}

func (s *localSite) Dial(params DialParams) (net.Conn, error) {
Expand Down
2 changes: 1 addition & 1 deletion lib/reversetunnel/localsite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestLocalSiteOverlap(t *testing.T) {
},
}

site, err := newlocalSite(srv, "clustername", &mockLocalSiteClient{}, nil)
site, err := newlocalSite(srv, "clustername", nil /* authServers */, &mockLocalSiteClient{}, nil /* peerClient */)
require.NoError(t, err)

nodeID := uuid.NewString()
Expand Down
24 changes: 5 additions & 19 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,6 @@ type server struct {
offlineThreshold time.Duration
}

// DirectCluster is used to access cluster directly
type DirectCluster struct {
// Name is a cluster name
Name string
// Client is a client to the cluster
Client auth.ClientI
}

// Config is a reverse tunnel server configuration
type Config struct {
// ID is the ID of this server proxy
Expand All @@ -151,8 +143,6 @@ type Config struct {
// NewCachingAccessPoint returns new caching access points
// per remote cluster
NewCachingAccessPoint auth.NewRemoteProxyCachingAccessPoint
// DirectClusters is a list of clusters accessed directly
DirectClusters []DirectCluster
// Context is a signalling context
Context context.Context
// Clock is a clock used in the server, set up to
Expand Down Expand Up @@ -315,8 +305,6 @@ func NewServer(cfg Config) (Server, error) {

srv := &server{
Config: cfg,
localSites: []*localSite{},
remoteSites: []*remoteSite{},
localAuthClient: cfg.LocalAuthClient,
localAccessPoint: cfg.LocalAccessPoint,
newAccessPoint: cfg.NewCachingAccessPoint,
Expand All @@ -329,15 +317,13 @@ func NewServer(cfg Config) (Server, error) {
offlineThreshold: offlineThreshold,
}

for _, clusterInfo := range cfg.DirectClusters {
cluster, err := newlocalSite(srv, clusterInfo.Name, clusterInfo.Client, srv.PeerClient)
if err != nil {
return nil, trace.Wrap(err)
}

srv.localSites = append(srv.localSites, cluster)
localSite, err := newlocalSite(srv, cfg.ClusterName, cfg.LocalAuthAddresses, cfg.LocalAuthClient, srv.PeerClient)
if err != nil {
return nil, trace.Wrap(err)
}

srv.localSites = append(srv.localSites, localSite)

s, err := sshutils.NewServer(
teleport.ComponentReverseTunnelServer,
// TODO(klizhentas): improve interface, use struct instead of parameter list
Expand Down
53 changes: 25 additions & 28 deletions lib/reversetunnel/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,6 @@ func (p *transport) start() {
return
}

var servers []string

// Parse and extract the dial request from the client.
dreq := parseDialReq(req.Payload)
if err := dreq.CheckAndSetDefaults(); err != nil {
Expand All @@ -207,18 +205,22 @@ func (p *transport) start() {
}
p.log.Debugf("Received out-of-band proxy transport request for %v [%v].", dreq.Address, dreq.ServerID)

// directAddress will hold the address of the node to dial to, if we don't
// have a tunnel for it.
var directAddress string

// Handle special non-resolvable addresses first.
switch dreq.Address {
// Connect to an Auth Server.
case RemoteAuthServer:
if len(p.authServers) <= 0 {
if len(p.authServers) == 0 {
p.log.Errorf("connection rejected: no auth servers configured")
p.reply(req, false, []byte("no auth servers configured"))

return
}

servers = p.authServers
directAddress = utils.ChooseRandomString(p.authServers)
// Connect to the Kubernetes proxy.
case LocalKubernetes:
switch p.component {
Expand Down Expand Up @@ -252,7 +254,7 @@ func (p *transport) start() {
return
}
p.log.Debugf("Forwarding connection to %q", p.kubeDialAddr.Addr)
servers = append(servers, p.kubeDialAddr.Addr)
directAddress = p.kubeDialAddr.Addr
}

// LocalNode requests are for the single server running in the agent pool.
Expand Down Expand Up @@ -283,15 +285,16 @@ func (p *transport) start() {
// If this is a proxy and not an SSH node, try finding an inbound
// tunnel from the SSH node by dreq.ServerID. We'll need to forward
// dreq.Address as well.
fallthrough
directAddress = dreq.Address
default:
servers = append(servers, dreq.Address)
// Not a special address; could be empty.
directAddress = dreq.Address
}

// Get a connection to the target address. If a tunnel exists with matching
// search names, connection over the tunnel is returned. Otherwise a direct
// net.Dial is performed.
conn, useTunnel, err := p.getConn(servers, dreq)
conn, useTunnel, err := p.getConn(directAddress, dreq)
if err != nil {
errorMessage := fmt.Sprintf("connection rejected: %v", err)
fmt.Fprint(p.channel.Stderr(), errorMessage)
Expand Down Expand Up @@ -368,7 +371,7 @@ func (p *transport) handleChannelRequests(closeContext context.Context, useTunne
// getConn checks if the local site holds a connection to the target host,
// and if it does, attempts to dial through the tunnel. Otherwise directly
// dials to host.
func (p *transport) getConn(servers []string, r *sshutils.DialReq) (net.Conn, bool, error) {
func (p *transport) getConn(addr string, r *sshutils.DialReq) (net.Conn, bool, error) {
// This function doesn't attempt to dial if a host with one of the
// search names is not registered. It's a fast check.
p.log.Debugf("Attempting to dial through tunnel with server ID %q.", r.ServerID)
Expand All @@ -388,13 +391,13 @@ func (p *transport) getConn(servers []string, r *sshutils.DialReq) (net.Conn, bo
}

errTun := err
p.log.Debugf("Attempting to dial directly %v.", servers)
conn, err = p.directDial(servers)
p.log.Debugf("Attempting to dial directly %q.", addr)
conn, err = p.directDial(addr)
if err != nil {
return nil, false, trace.ConnectionProblem(err, "failed dialing through tunnel (%v) or directly (%v)", errTun, err)
}

p.log.Debugf("Returning direct dialed connection to %v.", servers)
p.log.Debugf("Returning direct dialed connection to %q.", addr)
return conn, false, nil
}

Expand Down Expand Up @@ -438,24 +441,18 @@ func (p *transport) reply(req *ssh.Request, ok bool, msg []byte) {
}

// directDial attempts to directly dial to the target host.
func (p *transport) directDial(servers []string) (net.Conn, error) {
if len(servers) <= 0 {
return nil, trace.BadParameter("no servers to dial")
func (p *transport) directDial(addr string) (net.Conn, error) {
if addr == "" {
return nil, trace.BadParameter("no address to dial")
}

var errors []error
for _, addr := range servers {
dialer := net.Dialer{
Timeout: apidefaults.DefaultDialTimeout,
}

conn, err := dialer.DialContext(p.closeContext, "tcp", addr)
if err == nil {
return conn, nil
}

errors = append(errors, err)
d := net.Dialer{
Timeout: apidefaults.DefaultDialTimeout,
}
conn, err := d.DialContext(p.closeContext, "tcp", addr)
if err != nil {
return nil, trace.Wrap(err)
}

return nil, trace.NewAggregate(errors...)
return conn, nil
}
40 changes: 17 additions & 23 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3085,27 +3085,21 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
NewCachingAccessPoint: process.newLocalCacheForRemoteProxy,
NewCachingAccessPointOldProxy: process.newLocalCacheForOldRemoteProxy,
Limiter: reverseTunnelLimiter,
DirectClusters: []reversetunnel.DirectCluster{
{
Name: conn.ServerIdentity.Cert.Extensions[utils.CertExtensionAuthority],
Client: conn.Client,
},
},
KeyGen: cfg.Keygen,
Ciphers: cfg.Ciphers,
KEXAlgorithms: cfg.KEXAlgorithms,
MACAlgorithms: cfg.MACAlgorithms,
DataDir: process.Config.DataDir,
PollingPeriod: process.Config.PollingPeriod,
FIPS: cfg.FIPS,
Emitter: streamEmitter,
Log: process.log,
LockWatcher: lockWatcher,
PeerClient: peerClient,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
CircuitBreakerConfig: process.Config.CircuitBreakerConfig,
LocalAuthAddresses: utils.NetAddrsToStrings(process.Config.AuthServers),
KeyGen: cfg.Keygen,
Ciphers: cfg.Ciphers,
KEXAlgorithms: cfg.KEXAlgorithms,
MACAlgorithms: cfg.MACAlgorithms,
DataDir: process.Config.DataDir,
PollingPeriod: process.Config.PollingPeriod,
FIPS: cfg.FIPS,
Emitter: streamEmitter,
Log: process.log,
LockWatcher: lockWatcher,
PeerClient: peerClient,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
CircuitBreakerConfig: process.Config.CircuitBreakerConfig,
LocalAuthAddresses: utils.NetAddrsToStrings(process.Config.AuthServers),
})
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -3297,7 +3291,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
AuthClient: conn.Client,
AccessPoint: accessPoint,
HostSigner: conn.ServerIdentity.KeySigner,
LocalCluster: conn.ServerIdentity.Cert.Extensions[utils.CertExtensionAuthority],
LocalCluster: clusterName,
KubeDialAddr: utils.DialAddrFromListenAddr(kubeDialAddr(cfg.Proxy, clusterNetworkConfig.GetProxyListenerMode())),
ReverseTunnelServer: tsrv,
FIPS: process.Config.FIPS,
Expand Down Expand Up @@ -3498,7 +3492,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {

var alpnServer *alpnproxy.Proxy
if !cfg.Proxy.DisableTLS && !cfg.Proxy.DisableALPNSNIListener && listeners.web != nil {
authDialerService := alpnproxyauth.NewAuthProxyDialerService(tsrv, accessPoint)
authDialerService := alpnproxyauth.NewAuthProxyDialerService(tsrv, clusterName, utils.NetAddrsToStrings(process.Config.AuthServers))
alpnRouter.Add(alpnproxy.HandlerDecs{
MatchFunc: alpnproxy.MatchByALPNPrefix(string(alpncommon.ProtocolAuth)),
HandlerWithConnInfo: authDialerService.HandleConnection,
Expand Down
Loading