diff --git a/lib/web/app/transport.go b/lib/web/app/transport.go index 04d29f4cd6d28..556032704031e 100644 --- a/lib/web/app/transport.go +++ b/lib/web/app/transport.go @@ -233,6 +233,10 @@ func (t *transport) rewriteRequest(r *http.Request) error { // tunnel subsystem. func (t *transport) DialContext(ctx context.Context, _, _ string) (conn net.Conn, err error) { t.mu.Lock() + if len(t.c.servers) == 0 { + defer t.mu.Unlock() + return nil, trace.ConnectionProblem(nil, "no application servers remaining to connect") + } servers := make([]types.AppServer, len(t.c.servers)) copy(servers, t.c.servers) t.mu.Unlock() @@ -256,7 +260,11 @@ func (t *transport) DialContext(ctx context.Context, _, _ string) (conn net.Conn // eliminate any servers from the head of the list that were unreachable t.mu.Lock() - t.c.servers = t.c.servers[i:] + if i < len(servers) { + t.c.servers = t.c.servers[i:] + } else { + t.c.servers = nil + } t.mu.Unlock() if conn != nil || err != nil { diff --git a/lib/web/app/transport_test.go b/lib/web/app/transport_test.go index 73dd7c2a654e7..47bbf92ef49dc 100644 --- a/lib/web/app/transport_test.go +++ b/lib/web/app/transport_test.go @@ -19,17 +19,22 @@ package app import ( + "context" "crypto/x509/pkix" + "errors" "fmt" + "net" "net/http" "testing" + "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -199,3 +204,56 @@ func Test_transport_rewriteRedirect(t *testing.T) { }) } } + +type fakeTunnel struct { + reversetunnelclient.Tunnel + + fakeSite *reversetunnelclient.FakeRemoteSite + err error +} + +func (f fakeTunnel) GetSite(domainName string) (reversetunnelclient.RemoteSite, error) { + return f.fakeSite, f.err +} + +func TestTransport_DialContextNoServersAvailable(t *testing.T) { + tp := transport{ + c: &transportConfig{ + proxyClient: fakeTunnel{ + err: trace.ConnectionProblem(errors.New(reversetunnelclient.NoApplicationTunnel), ""), + }, + identity: &tlsca.Identity{}, + servers: []types.AppServer{ + &types.AppServerV3{}, + &types.AppServerV3{}, + &types.AppServerV3{}, + }, + log: utils.NewLoggerForTests(), + }, + } + + ctx := context.Background() + type dialRes struct { + conn net.Conn + err error + } + + count := len(tp.c.servers) + 1 + resC := make(chan dialRes, count) + + for i := 0; i < count; i++ { + go func() { + conn, err := tp.DialContext(ctx, "", "") + resC <- dialRes{ + conn: conn, + err: err, + } + }() + } + + for i := 0; i < count; i++ { + res := <-resC + require.Error(t, res.err) + require.Nil(t, res.conn) + } +}