From e385c9bf4670fe765d48082a9e461f78afe8910c Mon Sep 17 00:00:00 2001 From: David Date: Mon, 11 Apr 2022 08:54:18 -0600 Subject: [PATCH 1/5] Prevent blocking forever when transport channel fails to open --- api/utils/sshutils/conn.go | 4 +- api/utils/sshutils/conn_test.go | 156 ++++++++++++++++++++++++++++++++ 2 files changed, 159 insertions(+), 1 deletion(-) create mode 100644 api/utils/sshutils/conn_test.go diff --git a/api/utils/sshutils/conn.go b/api/utils/sshutils/conn.go index bee97420e4a1a..a4abec1fbf8ec 100644 --- a/api/utils/sshutils/conn.go +++ b/api/utils/sshutils/conn.go @@ -59,10 +59,12 @@ func ConnectProxyTransport(sconn ssh.Conn, req *DialReq, exclusive bool) (*ChCon channel, discard, err := sconn.OpenChannel(constants.ChanTransport, nil) if err != nil { - ssh.DiscardRequests(discard) return nil, false, trace.Wrap(err) } + // DiscardRequests will return when the channel or underlying connection is closed. + go ssh.DiscardRequests(discard) + // Send a special SSH out-of-band request called "teleport-transport" // the agent on the other side will create a new TCP/IP connection to // 'addr' on its network and will start proxying that connection over diff --git a/api/utils/sshutils/conn_test.go b/api/utils/sshutils/conn_test.go new file mode 100644 index 0000000000000..eb61f9be5fa9b --- /dev/null +++ b/api/utils/sshutils/conn_test.go @@ -0,0 +1,156 @@ +package sshutils + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "net" + "sync" + "testing" + "time" + + "github.com/gravitational/teleport/api/constants" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +type server struct { + listener net.Listener + config *ssh.ServerConfig + handler func(*ssh.ServerConn) + t *testing.T + mu sync.RWMutex + closed bool + + cSigner ssh.Signer + hSigner ssh.Signer +} + +func (s *server) Run() { + for { + conn, err := s.listener.Accept() + + s.mu.RLock() + if s.closed { + s.mu.RUnlock() + return + } + s.mu.RUnlock() + + require.NoError(s.t, err) + + go func() { + defer conn.Close() + sconn, _, _, err := ssh.NewServerConn(conn, s.config) + require.NoError(s.t, err) + s.handler(sconn) + }() + } +} + +func (s *server) Stop() error { + s.mu.Lock() + defer s.mu.Unlock() + s.closed = true + return s.listener.Close() +} + +func generateSigner(t *testing.T) ssh.Signer { + private, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + block := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(private), + } + + privatePEM := pem.EncodeToMemory(block) + signer, err := ssh.ParsePrivateKey(privatePEM) + require.NoError(t, err) + + return signer +} + +func (s *server) GetClient() (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request) { + conn, err := net.Dial("tcp", s.listener.Addr().String()) + require.NoError(s.t, err) + + sconn, nc, r, err := ssh.NewClientConn(conn, "", &ssh.ClientConfig{ + Auth: []ssh.AuthMethod{ssh.PublicKeys(s.cSigner)}, + HostKeyCallback: ssh.FixedHostKey(s.hSigner.PublicKey()), + }) + require.NoError(s.t, err) + + return sconn, nc, r +} + +func newServer(t *testing.T, handler func(*ssh.ServerConn)) *server { + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + + cSigner := generateSigner(t) + hSigner := generateSigner(t) + + config := &ssh.ServerConfig{ + NoClientAuth: true, + } + config.AddHostKey(hSigner) + + return &server{ + listener: listener, + config: config, + handler: handler, + t: t, + cSigner: cSigner, + hSigner: hSigner, + } +} + +// TestTransportError ensures ConnectProxyTransport does not block forever +// when an error occurs while opening the transport channel. +func TestTransportError(t *testing.T) { + errC := make(chan error) + + server := newServer(t, func(sconn *ssh.ServerConn) { + _, _, err := ConnectProxyTransport(sconn, &DialReq{ + Address: "test", ServerID: "test", + }, false) + errC <- err + }) + + go server.Run() + defer server.Stop() + + sconn, nc, _ := server.GetClient() + defer sconn.Close() + channel := <-nc + require.Equal(t, channel.ChannelType(), constants.ChanTransport) + + sconn.Close() + err := timeoutErrC(t, errC, time.Second*5) + require.Error(t, err) + + sconn, nc, _ = server.GetClient() + defer sconn.Close() + channel = <-nc + require.Equal(t, channel.ChannelType(), constants.ChanTransport) + + err = channel.Reject(ssh.ConnectionFailed, "test reject") + require.NoError(t, err) + + err = timeoutErrC(t, errC, time.Second*5) + require.Error(t, err) +} + +func timeoutErrC(t *testing.T, errC <-chan error, d time.Duration) error { + timeout := time.NewTimer(d) + select { + case err := <-errC: + return err + case <-timeout.C: + require.FailNow(t, "failed to receive on err channel in time") + } + + return nil +} From 091d25541e77fd2872b05445fc25531d0e0f6d4d Mon Sep 17 00:00:00 2001 From: David Date: Mon, 11 Apr 2022 10:46:52 -0600 Subject: [PATCH 2/5] Add license to conn_test.go --- api/utils/sshutils/conn_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/api/utils/sshutils/conn_test.go b/api/utils/sshutils/conn_test.go index eb61f9be5fa9b..7a7535d8d3c3e 100644 --- a/api/utils/sshutils/conn_test.go +++ b/api/utils/sshutils/conn_test.go @@ -1,3 +1,19 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package sshutils import ( From cbcdbc5c4f6e5a5e634ff768d2c856f63a7e6d78 Mon Sep 17 00:00:00 2001 From: David Date: Thu, 14 Apr 2022 14:58:52 -0600 Subject: [PATCH 3/5] Fix tests use assert over require in goroutines` --- api/utils/sshutils/conn_test.go | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/api/utils/sshutils/conn_test.go b/api/utils/sshutils/conn_test.go index 7a7535d8d3c3e..5cd1abd84f8aa 100644 --- a/api/utils/sshutils/conn_test.go +++ b/api/utils/sshutils/conn_test.go @@ -27,6 +27,7 @@ import ( "time" "github.com/gravitational/teleport/api/constants" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" ) @@ -54,12 +55,12 @@ func (s *server) Run() { } s.mu.RUnlock() - require.NoError(s.t, err) + assert.NoError(s.t, err) go func() { defer conn.Close() sconn, _, _, err := ssh.NewServerConn(conn, s.config) - require.NoError(s.t, err) + assert.NoError(s.t, err) s.handler(sconn) }() } @@ -136,19 +137,19 @@ func TestTransportError(t *testing.T) { }) go server.Run() - defer server.Stop() + t.Cleanup(func() { require.NoError(t, server.Stop()) }) - sconn, nc, _ := server.GetClient() - defer sconn.Close() + sconn1, nc, _ := server.GetClient() + t.Cleanup(func() { require.Error(t, sconn1.Close()) }) channel := <-nc require.Equal(t, channel.ChannelType(), constants.ChanTransport) - sconn.Close() + sconn1.Close() err := timeoutErrC(t, errC, time.Second*5) require.Error(t, err) - sconn, nc, _ = server.GetClient() - defer sconn.Close() + sconn2, nc, _ := server.GetClient() + t.Cleanup(func() { require.NoError(t, sconn2.Close()) }) channel = <-nc require.Equal(t, channel.ChannelType(), constants.ChanTransport) From a59074c1320f721943e2acd3ddc04117825547de Mon Sep 17 00:00:00 2001 From: David Date: Thu, 14 Apr 2022 17:11:06 -0600 Subject: [PATCH 4/5] Remove use of assert and require in goroutines --- api/utils/sshutils/conn_test.go | 54 ++++++++++++++++----------------- 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/api/utils/sshutils/conn_test.go b/api/utils/sshutils/conn_test.go index 5cd1abd84f8aa..7b73e8d7662c2 100644 --- a/api/utils/sshutils/conn_test.go +++ b/api/utils/sshutils/conn_test.go @@ -22,12 +22,10 @@ import ( "crypto/x509" "encoding/pem" "net" - "sync" "testing" "time" "github.com/gravitational/teleport/api/constants" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" ) @@ -36,40 +34,32 @@ type server struct { listener net.Listener config *ssh.ServerConfig handler func(*ssh.ServerConn) - t *testing.T - mu sync.RWMutex - closed bool cSigner ssh.Signer hSigner ssh.Signer } -func (s *server) Run() { +func (s *server) Run(errC chan error) { for { conn, err := s.listener.Accept() - - s.mu.RLock() - if s.closed { - s.mu.RUnlock() + if err != nil { + errC <- err return } - s.mu.RUnlock() - - assert.NoError(s.t, err) go func() { defer conn.Close() sconn, _, _, err := ssh.NewServerConn(conn, s.config) - assert.NoError(s.t, err) + if err != nil { + errC <- err + return + } s.handler(sconn) }() } } func (s *server) Stop() error { - s.mu.Lock() - defer s.mu.Unlock() - s.closed = true return s.listener.Close() } @@ -89,15 +79,15 @@ func generateSigner(t *testing.T) ssh.Signer { return signer } -func (s *server) GetClient() (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request) { +func (s *server) GetClient(t *testing.T) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request) { conn, err := net.Dial("tcp", s.listener.Addr().String()) - require.NoError(s.t, err) + require.NoError(t, err) sconn, nc, r, err := ssh.NewClientConn(conn, "", &ssh.ClientConfig{ Auth: []ssh.AuthMethod{ssh.PublicKeys(s.cSigner)}, HostKeyCallback: ssh.FixedHostKey(s.hSigner.PublicKey()), }) - require.NoError(s.t, err) + require.NoError(t, err) return sconn, nc, r } @@ -118,7 +108,6 @@ func newServer(t *testing.T, handler func(*ssh.ServerConn)) *server { listener: listener, config: config, handler: handler, - t: t, cSigner: cSigner, hSigner: hSigner, } @@ -127,37 +116,46 @@ func newServer(t *testing.T, handler func(*ssh.ServerConn)) *server { // TestTransportError ensures ConnectProxyTransport does not block forever // when an error occurs while opening the transport channel. func TestTransportError(t *testing.T) { - errC := make(chan error) + handlerErrC := make(chan error, 1) + serverErrC := make(chan error, 1) server := newServer(t, func(sconn *ssh.ServerConn) { _, _, err := ConnectProxyTransport(sconn, &DialReq{ Address: "test", ServerID: "test", }, false) - errC <- err + handlerErrC <- err }) - go server.Run() + go server.Run(serverErrC) t.Cleanup(func() { require.NoError(t, server.Stop()) }) - sconn1, nc, _ := server.GetClient() + sconn1, nc, _ := server.GetClient(t) t.Cleanup(func() { require.Error(t, sconn1.Close()) }) + channel := <-nc require.Equal(t, channel.ChannelType(), constants.ChanTransport) sconn1.Close() - err := timeoutErrC(t, errC, time.Second*5) + err := timeoutErrC(t, handlerErrC, time.Second*5) require.Error(t, err) - sconn2, nc, _ := server.GetClient() + sconn2, nc, _ := server.GetClient(t) t.Cleanup(func() { require.NoError(t, sconn2.Close()) }) + channel = <-nc require.Equal(t, channel.ChannelType(), constants.ChanTransport) err = channel.Reject(ssh.ConnectionFailed, "test reject") require.NoError(t, err) - err = timeoutErrC(t, errC, time.Second*5) + err = timeoutErrC(t, handlerErrC, time.Second*5) require.Error(t, err) + + select { + case err = <-serverErrC: + require.FailNow(t, err.Error()) + default: + } } func timeoutErrC(t *testing.T, errC <-chan error, d time.Duration) error { From 1d114354b4d51d31f6e98fec20680e12ecbddc26 Mon Sep 17 00:00:00 2001 From: David Date: Mon, 18 Apr 2022 14:33:24 -0600 Subject: [PATCH 5/5] Close connections in main goroutine to prevent flaky test --- api/utils/sshutils/conn_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/api/utils/sshutils/conn_test.go b/api/utils/sshutils/conn_test.go index 7b73e8d7662c2..a644d34c53d1c 100644 --- a/api/utils/sshutils/conn_test.go +++ b/api/utils/sshutils/conn_test.go @@ -48,7 +48,6 @@ func (s *server) Run(errC chan error) { } go func() { - defer conn.Close() sconn, _, _, err := ssh.NewServerConn(conn, s.config) if err != nil { errC <- err