diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index 395f762d39a58..555a5d6760c23 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -79,7 +79,8 @@ type LocalProxyConfig struct { // LocalProxyMiddleware provides callback functions for LocalProxy. type LocalProxyMiddleware interface { // OnNewConnection is a callback triggered when a new downstream connection is - // accepted by the local proxy. + // accepted by the local proxy. If an error is returned, the connection will be closed + // by the local proxy. OnNewConnection(ctx context.Context, lp *LocalProxy, conn net.Conn) error // OnStart is a callback triggered when the local proxy starts. OnStart(ctx context.Context, lp *LocalProxy) error @@ -149,7 +150,10 @@ func (l *LocalProxy) Start(ctx context.Context) error { if l.cfg.Middleware != nil { if err := l.cfg.Middleware.OnNewConnection(ctx, l, conn); err != nil { - log.WithError(err).Errorf("Middleware failed to handle new connection.") + log.WithError(err).Error("Middleware failed to handle client connection.") + if err := conn.Close(); err != nil && !utils.IsUseOfClosedNetworkError(err) { + log.WithError(err).Debug("Failed to close client connection.") + } continue } } diff --git a/lib/srv/alpnproxy/local_proxy_test.go b/lib/srv/alpnproxy/local_proxy_test.go index c3b38c70f8e85..460f641dda73a 100644 --- a/lib/srv/alpnproxy/local_proxy_test.go +++ b/lib/srv/alpnproxy/local_proxy_test.go @@ -19,6 +19,7 @@ package alpnproxy import ( "bytes" "context" + "io" "net" "net/http" "net/http/httptest" @@ -32,6 +33,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" v4 "github.com/aws/aws-sdk-go/aws/signer/v4" "github.com/aws/aws-sdk-go/service/s3" + "github.com/gravitational/trace" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -246,6 +248,51 @@ func TestMiddleware(t *testing.T) { m.waitForCounts(t, 1, 1) } +type mockMiddlewareConnUnauth struct { +} + +func (m *mockMiddlewareConnUnauth) OnNewConnection(_ context.Context, _ *LocalProxy, _ net.Conn) error { + return trace.AccessDenied("access denied.") +} + +func (m *mockMiddlewareConnUnauth) OnStart(_ context.Context, _ *LocalProxy) error { + return nil +} + +var _ LocalProxyMiddleware = (*mockMiddlewareConnUnauth)(nil) + +func TestLocalProxyClosesConnOnError(t *testing.T) { + hs := httptest.NewTLSServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {})) + lp, err := NewLocalProxy(LocalProxyConfig{ + Listener: mustCreateLocalListener(t), + RemoteProxyAddr: hs.Listener.Addr().String(), + Protocols: []common.Protocol{common.ProtocolHTTP}, + ParentContext: context.Background(), + InsecureSkipVerify: true, + Middleware: &mockMiddlewareConnUnauth{}, + }) + require.NoError(t, err) + t.Cleanup(func() { + err := lp.Close() + require.NoError(t, err) + hs.Close() + }) + go func() { + assert.NoError(t, lp.Start(context.Background())) + }() + + conn, err := net.Dial("tcp", lp.GetAddr()) + require.NoError(t, err) + + // set a read deadline so that if the connection is not closed, + // this test will fail quickly instead of hanging. + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + buf := make([]byte, 512) + _, err = conn.Read(buf) + require.Error(t, err) + require.ErrorIs(t, err, io.EOF) +} + func createAWSAccessProxySuite(t *testing.T, cred *credentials.Credentials) *LocalProxy { hs := httptest.NewTLSServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {}))