Skip to content

Commit

Permalink
Create remote site cache based on remote auth version
Browse files Browse the repository at this point in the history
The cache policy used for a remote site is determined based on
the response from a version request. However the version response
was only returning the proxy version. If the remote site was not
running the same version for both auth and proxy, then the cache
policy chosen could be invalid.

The reverse tunnel agent now pings its auth server and reports
both the auth and proxy version in response to a version request.
To maintain backward compatability the reverse tunnel server will
fallback to using the proxy version if the response does not
contain an auth version.

Fixes #12010
  • Loading branch information
rosstimothy committed Apr 20, 2022
1 parent 93364cf commit d66b9ea
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 25 deletions.
23 changes: 20 additions & 3 deletions lib/reversetunnel/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package reversetunnel
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"sync"
"time"
Expand Down Expand Up @@ -354,16 +355,32 @@ func (a *Agent) handleGlobalRequests(ctx context.Context, requestCh <-chan *ssh.

switch r.Type {
case versionRequest:
err := r.Reply(true, []byte(teleport.Version))
response := versionResponse{
ProxyVersion: teleport.Version,
}

pong, err := a.Client.Ping(ctx)
if err != nil {
log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
a.log.WithError(err).Debugf("Failed to ping auth server.")
} else {
response.AuthVersion = pong.ServerVersion
}

payload, err := json.Marshal(response)
if err != nil {
a.log.WithError(err).Debugf("Failed to marshal version response")
payload = []byte(teleport.Version)
}

if err := r.Reply(true, payload); err != nil {
a.log.WithError(err).Debugf("Failed to reply to version request")
continue
}
default:
// This handles keep-alive messages and matches the behaviour of OpenSSH.
err := r.Reply(false, nil)
if err != nil {
log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
continue
}
}
Expand Down
68 changes: 46 additions & 22 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package reversetunnel
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -639,7 +640,7 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N
// nodes it's a node dialing back.
val, ok := sconn.Permissions.Extensions[extCertRole]
if !ok {
log.Errorf("Failed to accept connection, missing %q extension", extCertRole)
s.log.Errorf("Failed to accept connection, missing %q extension", extCertRole)
s.rejectRequest(nch, ssh.ConnectionFailed, "unknown role")
return
}
Expand All @@ -665,22 +666,22 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N
s.handleNewService(role, conn, sconn, nch, types.WindowsDesktopTunnel)
// Unknown role.
default:
log.Errorf("Unsupported role attempting to connect: %v", val)
s.log.Errorf("Unsupported role attempting to connect: %v", val)
s.rejectRequest(nch, ssh.ConnectionFailed, fmt.Sprintf("unsupported role %v", val))
}
}

func (s *server) handleNewService(role types.SystemRole, conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel, connType types.TunnelType) {
cluster, rconn, err := s.upsertServiceConn(conn, sconn, connType)
if err != nil {
log.Errorf("Failed to upsert %s: %v.", role, err)
s.log.Errorf("Failed to upsert %s: %v.", role, err)
sconn.Close()
return
}

ch, req, err := nch.Accept()
if err != nil {
log.Errorf("Failed to accept on channel: %v.", err)
s.log.Errorf("Failed to accept on channel: %v.", err)
sconn.Close()
return
}
Expand All @@ -692,14 +693,14 @@ func (s *server) handleNewCluster(conn net.Conn, sshConn *ssh.ServerConn, nch ss
// add the incoming site (cluster) to the list of active connections:
site, remoteConn, err := s.upsertRemoteCluster(conn, sshConn)
if err != nil {
log.Error(trace.Wrap(err))
s.log.Error(trace.Wrap(err))
s.rejectRequest(nch, ssh.ConnectionFailed, "failed to accept incoming cluster connection")
return
}
// accept the request and start the heartbeat on it:
ch, req, err := nch.Accept()
if err != nil {
log.Error(trace.Wrap(err))
s.log.Error(trace.Wrap(err))
sshConn.Close()
return
}
Expand Down Expand Up @@ -1067,12 +1068,12 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
// (RFD 28) because older proxy servers will reject that causing the cache
// to go into a re-sync loop.
var accessPointFunc auth.NewRemoteProxyCachingAccessPoint
ok, err := isPreV8Cluster(closeContext, sconn)
ok, version, err := isPreV8Cluster(closeContext, sconn)
if err != nil {
return nil, trace.Wrap(err)
}
if ok {
log.Debugf("Pre-v8 cluster connecting, loading old cache policy.")
srv.log.Debugf("cluster %q running %q connecting, loading old cache policy.", domainName, version)
accessPointFunc = srv.Config.NewCachingAccessPointOldProxy
} else {
accessPointFunc = srv.newAccessPoint
Expand Down Expand Up @@ -1126,32 +1127,42 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
}

// isPreV8Cluster checks if the cluster is older than 8.0.0.
func isPreV8Cluster(ctx context.Context, conn ssh.Conn) (bool, error) {
version, err := sendVersionRequest(ctx, conn)
func isPreV8Cluster(ctx context.Context, conn ssh.Conn) (bool, string, error) {
response, err := sendVersionRequest(ctx, conn)
if err != nil {
return false, trace.Wrap(err)
return false, "", trace.Wrap(err)
}

// If the AuthVersion wasn't provided it means that the
// remote side is still only reporting proxy version. In
// that case the only thing we can do is fallback to using
// the ProxyVersion and hope that the Auth server is running
// the same version
version := response.AuthVersion
if version == "" {
version = response.ProxyVersion
}

remoteClusterVersion, err := semver.NewVersion(version)
if err != nil {
return false, trace.Wrap(err)
return false, "", trace.Wrap(err)
}
minClusterVersion, err := semver.NewVersion(utils.VersionBeforeAlpha("8.0.0"))
if err != nil {
return false, trace.Wrap(err)
return false, "", trace.Wrap(err)
}
// Return true if the version is older than 8.0.0
if remoteClusterVersion.LessThan(*minClusterVersion) {
return true, nil
return true, version, nil
}

return false, nil
return false, version, nil
}

// sendVersionRequest sends a request for the version remote Teleport cluster.
func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (string, error) {
func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (*versionResponse, error) {
errorCh := make(chan error, 1)
versionCh := make(chan string, 1)
versionCh := make(chan versionResponse, 1)

go func() {
ok, payload, err := sconn.SendRequest(versionRequest, true, nil)
Expand All @@ -1163,18 +1174,26 @@ func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (string, error) {
errorCh <- trace.BadParameter("no response to %v request", versionRequest)
return
}
versionCh <- string(payload)
var response versionResponse
if err := json.Unmarshal(payload, &response); err != nil {
// failure means that we are talking to an older version
// that only reports its proxy version
versionCh <- versionResponse{ProxyVersion: string(payload)}
return
}

versionCh <- response
}()

select {
case ver := <-versionCh:
return ver, nil
return &ver, nil
case err := <-errorCh:
return "", trace.Wrap(err)
return nil, trace.Wrap(err)
case <-time.After(defaults.WaitCopyTimeout):
return "", trace.BadParameter("timeout waiting for version")
return nil, trace.BadParameter("timeout waiting for version")
case <-ctx.Done():
return "", ctx.Err()
return nil, trace.Wrap(ctx.Err())
}
}

Expand All @@ -1185,3 +1204,8 @@ const (

versionRequest = "x-teleport-version"
)

type versionResponse struct {
ProxyVersion string `json:"proxy"`
AuthVersion string `json:"auth"`
}

0 comments on commit d66b9ea

Please sign in to comment.