From b27e88e32620dfd61474dae420ee151ed12e6b69 Mon Sep 17 00:00:00 2001 From: Andrew Burke <31974658+atburke@users.noreply.github.com> Date: Wed, 8 Nov 2023 09:01:05 -0800 Subject: [PATCH] Deflake HTTP_PROXY tests (#33614) This change rewrites a few HTTP_PROXY tests to be less flaky. --- api/client/webclient/webclient_test.go | 83 ++++++++----- api/testhelpers/proxy.go | 157 +++++++++++++++++++++++++ integration/helpers/helpers.go | 23 ---- integration/integration_test.go | 3 +- integration/proxy/proxy_test.go | 9 +- lib/client/https_client_test.go | 90 +++++++------- tool/tsh/resolve_default_addr_test.go | 47 +++----- 7 files changed, 278 insertions(+), 134 deletions(-) create mode 100644 api/testhelpers/proxy.go diff --git a/api/client/webclient/webclient_test.go b/api/client/webclient/webclient_test.go index 8acf0ca690eff..567ff44b4f410 100644 --- a/api/client/webclient/webclient_test.go +++ b/api/client/webclient/webclient_test.go @@ -31,6 +31,7 @@ import ( "golang.org/x/exp/slices" "github.com/gravitational/teleport/api/defaults" + apihelpers "github.com/gravitational/teleport/api/testhelpers" ) func newPingHandler(path string) http.Handler { @@ -324,36 +325,62 @@ func TestParse(t *testing.T) { } } -func TestNewWebClientRespectHTTPProxy(t *testing.T) { - t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999") - client, err := newWebClient(&Config{ - Context: context.Background(), - ProxyAddr: "localhost:3080", - }) - require.NoError(t, err) - //nolint:bodyclose // resp should be nil, so there will be no body to close. - resp, err := client.Get("https://fakedomain.example.com") - // Client should try to proxy through nonexistent server at localhost. - require.Error(t, err, "GET unexpectedly succeeded: %+v", resp) - require.Contains(t, err.Error(), "proxyconnect") - require.Contains(t, err.Error(), "lookup fakeproxy.example.com") - require.Contains(t, err.Error(), "no such host") -} +func TestNewWebClientHTTPProxy(t *testing.T) { + proxyHandler := &apihelpers.ProxyHandler{} + proxyServer := httptest.NewServer(proxyHandler) + t.Cleanup(proxyServer.Close) -func TestNewWebClientNoProxy(t *testing.T) { - t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999") - t.Setenv("NO_PROXY", "fakedomain.example.com") - client, err := newWebClient(&Config{ - Context: context.Background(), - ProxyAddr: "localhost:3080", - }) + localIP, err := apihelpers.GetLocalIP() require.NoError(t, err) - //nolint:bodyclose // resp should be nil, so there will be no body to close. - resp, err := client.Get("https://fakedomain.example.com") - require.Error(t, err, "GET unexpectedly succeeded: %+v", resp) - require.NotContains(t, err.Error(), "proxyconnect") - require.Contains(t, err.Error(), "lookup fakedomain.example.com") - require.Contains(t, err.Error(), "no such host") + server := apihelpers.MakeTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("hello")) + }), apihelpers.WithTestServerAddress(localIP)) + _, serverPort, err := net.SplitHostPort(server.Listener.Addr().String()) + require.NoError(t, err) + serverAddr := net.JoinHostPort(localIP, serverPort) + tests := []struct { + name string + env map[string]string + expectedProxyCount int + }{ + { + name: "use http proxy", + env: map[string]string{ + "HTTPS_PROXY": proxyServer.URL, + }, + expectedProxyCount: 1, + }, + { + name: "ignore proxy when no_proxy is set", + env: map[string]string{ + "HTTPS_PROXY": proxyServer.URL, + "NO_PROXY": "*", + }, + expectedProxyCount: 0, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Cleanup(proxyHandler.Reset) + for k, v := range tc.env { + t.Setenv(k, v) + } + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + client, err := newWebClient(&Config{ + Context: ctx, + ProxyAddr: "localhost:3080", // addr doesn't matter, it won't be used + Insecure: true, + }) + require.NoError(t, err) + + resp, err := client.Get("https://" + serverAddr) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + require.Equal(t, tc.expectedProxyCount, proxyHandler.Count()) + }) + } } func TestSSHProxyHostPort(t *testing.T) { diff --git a/api/testhelpers/proxy.go b/api/testhelpers/proxy.go new file mode 100644 index 0000000000000..3a39d561758a9 --- /dev/null +++ b/api/testhelpers/proxy.go @@ -0,0 +1,157 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testhelpers + +import ( + "io" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +// ProxyHandler is a http.Handler that implements a simple HTTP proxy server. +type ProxyHandler struct { + sync.Mutex + count int +} + +// ServeHTTP only accepts the CONNECT verb and will tunnel your connection to +// the specified host. Also tracks the number of connections that it proxies for +// debugging purposes. +func (p *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Validate http connect parameters. + if r.Method != http.MethodConnect { + trace.WriteError(w, trace.BadParameter("%v not supported", r.Method)) + return + } + if r.Host == "" { + trace.WriteError(w, trace.BadParameter("host not set")) + return + } + + // Dial to the target host, this is done before hijacking the connection to + // ensure the target host is accessible. + dialer := net.Dialer{} + dconn, err := dialer.DialContext(r.Context(), "tcp", r.Host) + if err != nil { + trace.WriteError(w, err) + return + } + defer dconn.Close() + + // Once the client receives 200 OK, the rest of the data will no longer be + // http, but whatever protocol is being tunneled. + w.WriteHeader(http.StatusOK) + + // Hijack request so we can get underlying connection. + hj, ok := w.(http.Hijacker) + if !ok { + trace.WriteError(w, trace.AccessDenied("unable to hijack connection")) + return + } + sconn, buf, err := hj.Hijack() + if err != nil { + trace.WriteError(w, err) + return + } + defer sconn.Close() + + // Success, we're proxying data now. + p.Lock() + p.count++ + p.Unlock() + + // Copy from src to dst and dst to src. + errc := make(chan error, 2) + replicate := func(dst io.Writer, src io.Reader) { + _, err := io.Copy(dst, src) + errc <- err + } + go replicate(sconn, dconn) + go replicate(dconn, io.MultiReader(buf, sconn)) + + // Wait until done, error, or 10 second. + select { + case <-time.After(10 * time.Second): + case <-errc: + } +} + +// Count returns the number of requests that have been proxied. +func (p *ProxyHandler) Count() int { + p.Lock() + defer p.Unlock() + return p.count +} + +// Reset sets the counter for proxied requests to zero. +func (p *ProxyHandler) Reset() { + p.Lock() + defer p.Unlock() + p.count = 0 +} + +// GetLocalIP gets the non-loopback IP address of this host. +func GetLocalIP() (string, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "", trace.Wrap(err) + } + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + default: + continue + } + if !ip.IsLoopback() && ip.IsPrivate() { + return ip.String(), nil + } + } + return "", trace.NotFound("No non-loopback local IP address found") +} + +type TestServerOption func(*testing.T, *httptest.Server) + +func WithTestServerAddress(ip string) TestServerOption { + return func(t *testing.T, srv *httptest.Server) { + // Replace the test server's address. + _, originalPort, err := net.SplitHostPort(srv.Listener.Addr().String()) + require.NoError(t, err) + require.NoError(t, srv.Listener.Close()) + l, err := net.Listen("tcp", net.JoinHostPort(ip, originalPort)) + require.NoError(t, err) + srv.Listener = l + } +} + +func MakeTestServer(t *testing.T, h http.Handler, opts ...TestServerOption) *httptest.Server { + svr := httptest.NewUnstartedServer(h) + for _, opt := range opts { + opt(t, svr) + } + svr.StartTLS() + t.Cleanup(svr.Close) + return svr +} diff --git a/integration/helpers/helpers.go b/integration/helpers/helpers.go index 81f35ad2280ed..3376abddabd37 100644 --- a/integration/helpers/helpers.go +++ b/integration/helpers/helpers.go @@ -177,29 +177,6 @@ func CloseAgent(teleAgent *teleagent.AgentServer, socketDirPath string) error { return nil } -// GetLocalIP gets the non-loopback IP address of this host. -func GetLocalIP() (string, error) { - addrs, err := net.InterfaceAddrs() - if err != nil { - return "", trace.Wrap(err) - } - for _, addr := range addrs { - var ip net.IP - switch v := addr.(type) { - case *net.IPNet: - ip = v.IP - case *net.IPAddr: - ip = v.IP - default: - continue - } - if !ip.IsLoopback() && ip.IsPrivate() { - return ip.String(), nil - } - } - return "", trace.NotFound("No non-loopback local IP address found") -} - func MustCreateUserIdentityFile(t *testing.T, tc *TeleInstance, username string, ttl time.Duration) string { key, err := libclient.GenerateRSAKey() require.NoError(t, err) diff --git a/integration/integration_test.go b/integration/integration_test.go index 8f3cd06df1dc1..337f7e48921f8 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -62,6 +62,7 @@ import ( apidefaults "github.com/gravitational/teleport/api/defaults" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/profile" + apihelpers "github.com/gravitational/teleport/api/testhelpers" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" apiutils "github.com/gravitational/teleport/api/utils" @@ -2368,7 +2369,7 @@ func testTwoClustersProxy(t *testing.T, suite *integrationTestSuite) { // httpproxy doesn't allow proxying when the target is localhost, so use // this address instead. - addr, err := helpers.GetLocalIP() + addr, err := apihelpers.GetLocalIP() require.NoError(t, err) a := suite.newNamedTeleportInstance(t, "site-A", WithNodeName(addr), diff --git a/integration/proxy/proxy_test.go b/integration/proxy/proxy_test.go index 34f86ea8b155a..bdb7ea67fb2ca 100644 --- a/integration/proxy/proxy_test.go +++ b/integration/proxy/proxy_test.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/teleport/api/breaker" "github.com/gravitational/teleport/api/client" + apihelpers "github.com/gravitational/teleport/api/testhelpers" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/integration/appaccess" dbhelpers "github.com/gravitational/teleport/integration/db" @@ -228,7 +229,7 @@ func TestALPNSNIHTTPSProxy(t *testing.T) { // We need to use the non-loopback address for our Teleport cluster, as the // Go HTTP library will recognize requests to the loopback address and // refuse to use the HTTP proxy, which will invalidate the test. - addr, err := helpers.GetLocalIP() + addr, err := apihelpers.GetLocalIP() require.NoError(t, err) suite := newSuite(t, @@ -269,7 +270,7 @@ func TestMultiPortHTTPSProxy(t *testing.T) { // We need to use the non-loopback address for our Teleport cluster, as the // Go HTTP library will recognize requests to the loopback address and // refuse to use the HTTP proxy, which will invalidate the test. - addr, err := helpers.GetLocalIP() + addr, err := apihelpers.GetLocalIP() require.NoError(t, err) suite := newSuite(t, @@ -1211,7 +1212,7 @@ func TestALPNProxyHTTPProxyNoProxyDial(t *testing.T) { // We need to use the non-loopback address for our Teleport cluster, as the // Go HTTP library will recognize requests to the loopback address and // refuse to use the HTTP proxy, which will invalidate the test. - addr, err := helpers.GetLocalIP() + addr, err := apihelpers.GetLocalIP() require.NoError(t, err) instanceCfg := helpers.InstanceConfig{ @@ -1290,7 +1291,7 @@ func TestALPNProxyHTTPProxyBasicAuthDial(t *testing.T) { // We need to use the non-loopback address for our Teleport cluster, as the // Go HTTP library will recognize requests to the loopback address and // refuse to use the HTTP proxy, which will invalidate the test. - rcAddr, err := helpers.GetLocalIP() + rcAddr, err := apihelpers.GetLocalIP() require.NoError(t, err) log.Info("Creating Teleport instance...") diff --git a/lib/client/https_client_test.go b/lib/client/https_client_test.go index c37ad5207610a..37da8addc0681 100644 --- a/lib/client/https_client_test.go +++ b/lib/client/https_client_test.go @@ -17,55 +17,55 @@ limitations under the License. package client import ( + "net/http" + "net/url" "testing" "github.com/stretchr/testify/require" ) -func TestNewInsecureWebClientHTTPProxy(t *testing.T) { - t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999") - client := NewInsecureWebClient() - //nolint:bodyclose // resp should be nil, so there will be no body to close. - resp, err := client.Get("https://fakedomain.example.com") - // Client should try to proxy through nonexistent server at localhost. - require.Error(t, err, "GET unexpectedly succeeded: %+v", resp) - require.Contains(t, err.Error(), "proxyconnect") - require.Contains(t, err.Error(), "lookup fakeproxy.example.com") - require.Contains(t, err.Error(), "no such host") -} - -func TestNewInsecureWebClientNoProxy(t *testing.T) { - t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999") - t.Setenv("NO_PROXY", "fakedomain.example.com") - client := NewInsecureWebClient() - //nolint:bodyclose // resp should be nil, so there will be no body to close. - resp, err := client.Get("https://fakedomain.example.com") - require.Error(t, err, "GET unexpectedly succeeded: %+v", resp) - require.NotContains(t, err.Error(), "proxyconnect") - require.Contains(t, err.Error(), "lookup fakedomain.example.com") - require.Contains(t, err.Error(), "no such host") -} - -func TestNewSecureWebClientHTTPProxy(t *testing.T) { - t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999") - client := newClient(false, nil, nil) - //nolint:bodyclose // resp should be nil, so there will be no body to close. - resp, err := client.Get("https://fakedomain.example.com") - // Client should try to proxy through nonexistent server at localhost. - require.Error(t, err, "GET unexpectedly succeeded: %+v", resp) - require.Contains(t, err.Error(), "proxyconnect") - require.Contains(t, err.Error(), "lookup fakeproxy.example.com") - require.Contains(t, err.Error(), "no such host") -} +func TestHTTPTransportProxy(t *testing.T) { + proxyURL := "proxy.example.com" + target := "target.example.com" + tests := []struct { + name string + env map[string]string + expectedProxyURL string + }{ + { + name: "use http proxy", + env: map[string]string{ + "HTTPS_PROXY": proxyURL, + }, + expectedProxyURL: "http://" + proxyURL, + }, + { + name: "ignore proxy when no_proxy is set", + env: map[string]string{ + "HTTPS_PROXY": proxyURL, + "NO_PROXY": target, + }, + expectedProxyURL: "", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + for k, v := range tc.env { + t.Setenv(k, v) + } -func TestNewSecureWebClientNoProxy(t *testing.T) { - t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999") - t.Setenv("NO_PROXY", "fakedomain.example.com") - client := newClient(false, nil, nil) - //nolint:bodyclose // resp should be nil, so there will be no body to close. - resp, err := client.Get("https://fakedomain.example.com") - require.Error(t, err, "GET unexpectedly succeeded: %+v", resp) - require.NotContains(t, err.Error(), "proxyconnect") - require.Contains(t, err.Error(), "lookup fakedomain.example.com") - require.Contains(t, err.Error(), "no such host") + inputURL, err := url.Parse("https://" + target) + require.NoError(t, err) + outputURL, err := httpTransport(false, nil).Proxy(&http.Request{ + URL: inputURL, + }) + require.NoError(t, err) + if tc.expectedProxyURL != "" { + require.NotNil(t, outputURL) + require.Equal(t, tc.expectedProxyURL, outputURL.String()) + } else { + require.Nil(t, outputURL) + } + }) + } } diff --git a/tool/tsh/resolve_default_addr_test.go b/tool/tsh/resolve_default_addr_test.go index 8d66374d006cb..6f7cdab528deb 100644 --- a/tool/tsh/resolve_default_addr_test.go +++ b/tool/tsh/resolve_default_addr_test.go @@ -20,7 +20,6 @@ import ( "context" "fmt" "io" - "net" "net/http" "net/http/httptest" "net/url" @@ -31,6 +30,7 @@ import ( "github.com/gravitational/trace" "github.com/stretchr/testify/require" + apihelpers "github.com/gravitational/teleport/api/testhelpers" "github.com/gravitational/teleport/integration/helpers" ) @@ -77,18 +77,6 @@ func mustGetCandidatePorts(servers []*httptest.Server) []int { return result } -type testServerOption func(*httptest.Server) - -func makeTestServer(t *testing.T, h http.Handler, opts ...testServerOption) *httptest.Server { - svr := httptest.NewUnstartedServer(h) - for _, opt := range opts { - opt(svr) - } - svr.StartTLS() - t.Cleanup(func() { svr.Close() }) - return svr -} - func TestResolveDefaultAddr(t *testing.T) { t.Parallel() @@ -105,13 +93,13 @@ func TestResolveDefaultAddr(t *testing.T) { if i == magicServerIndex { handler = respondingHandler } - servers[i] = makeTestServer(t, handler) + servers[i] = apihelpers.MakeTestServer(t, handler) } // NB: We need to defer this channel close such that it happens *before* // the httpstest server shutdowns, or the blocking requests will never // finish and we will deadlock. - defer close(doneCh) + t.Cleanup(func() { close(doneCh) }) ports := mustGetCandidatePorts(servers) expectedAddr := fmt.Sprintf("127.0.0.1:%d", ports[magicServerIndex]) @@ -140,7 +128,7 @@ func TestResolveDefaultAddrSingleCandidate(t *testing.T) { servers := make([]*httptest.Server, 1) for i := 0; i < len(servers); i++ { - servers[i] = makeTestServer(t, respondingHandler) + servers[i] = apihelpers.MakeTestServer(t, respondingHandler) } ports := mustGetCandidatePorts(servers) @@ -162,13 +150,13 @@ func TestResolveDefaultAddrTimeout(t *testing.T) { servers := make([]*httptest.Server, 5) for i := 0; i < 5; i++ { - servers[i] = makeTestServer(t, blockingHandler) + servers[i] = apihelpers.MakeTestServer(t, blockingHandler) } // NB: We need to defer this channel close such that it happens *before* // the httpstest server shutdowns, or the blocking requests will never // finish and we will deadlock. - defer close(doneCh) + t.Cleanup(func() { close(doneCh) }) ports := mustGetCandidatePorts(servers) @@ -188,7 +176,7 @@ func TestResolveNonOKResponseIsAnError(t *testing.T) { // Given a single candidate server configured to respond with a non-OK status // code servers := []*httptest.Server{ - makeTestServer(t, newRespondingHandlerWithStatus(http.StatusTeapot)), + apihelpers.MakeTestServer(t, newRespondingHandlerWithStatus(http.StatusTeapot)), } ports := mustGetCandidatePorts(servers) @@ -229,7 +217,7 @@ func TestResolveUndeliveredBodyDoesNotBlockForever(t *testing.T) { testLog.Debug("Exiting handler") }) - servers := []*httptest.Server{makeTestServer(t, handler)} + servers := []*httptest.Server{apihelpers.MakeTestServer(t, handler)} ports := mustGetCandidatePorts(servers) // When I attempt to resolve a default address @@ -246,15 +234,15 @@ func TestResolveDefaultAddrTimeoutBeforeAllRacersLaunched(t *testing.T) { blockingHandler, doneCh := newWaitForeverHandler() - servers := make([]*httptest.Server, 1000) + servers := make([]*httptest.Server, 100) for i := 0; i < len(servers); i++ { - servers[i] = makeTestServer(t, blockingHandler) + servers[i] = apihelpers.MakeTestServer(t, blockingHandler) } // NB: We need to defer this channel close such that it happens *before* // the httpstest server shutdowns, or the blocking requests will never // finish and we will deadlock. - defer close(doneCh) + t.Cleanup(func() { close(doneCh) }) ports := mustGetCandidatePorts(servers) @@ -275,19 +263,12 @@ func TestResolveDefaultAddrHTTPProxy(t *testing.T) { t.Cleanup(proxyServer.Close) // Go won't proxy to localhost, so use this address instead. - localIP, err := helpers.GetLocalIP() + localIP, err := apihelpers.GetLocalIP() require.NoError(t, err) - var serverAddr net.Addr respondingHandler := newRespondingHandler() - server := makeTestServer(t, respondingHandler, func(srv *httptest.Server) { - // Replace the test server's address. - l, err := net.Listen("tcp", localIP+":0") - require.NoError(t, err) - require.NoError(t, srv.Listener.Close()) - srv.Listener = l - serverAddr = l.Addr() - }) + server := apihelpers.MakeTestServer(t, respondingHandler, apihelpers.WithTestServerAddress(localIP)) + serverAddr := server.Listener.Addr() ports := mustGetCandidatePorts([]*httptest.Server{server})