From dd3e3f4f24b57fabfe68ec864e2551b7d313b860 Mon Sep 17 00:00:00 2001 From: Andrew Burke Date: Tue, 17 Oct 2023 17:00:27 -0700 Subject: [PATCH] Deflake HTTP_PROXY tests This change rewrites a few HTTP_PROXY tests to be less flaky. --- api/client/webclient/webclient_test.go | 83 ++++++++++++------ api/testhelpers/proxy.go | 61 ++++++++++++- integration/helpers/helpers.go | 23 ----- integration/integration_test.go | 3 +- integration/proxy/proxy_test.go | 9 +- lib/client/https_client_test.go | 90 ++++++++++---------- tool/tsh/common/resolve_default_addr_test.go | 47 +++------- 7 files changed, 180 insertions(+), 136 deletions(-) diff --git a/api/client/webclient/webclient_test.go b/api/client/webclient/webclient_test.go index 8acf0ca690eff..567ff44b4f410 100644 --- a/api/client/webclient/webclient_test.go +++ b/api/client/webclient/webclient_test.go @@ -31,6 +31,7 @@ import ( "golang.org/x/exp/slices" "github.com/gravitational/teleport/api/defaults" + apihelpers "github.com/gravitational/teleport/api/testhelpers" ) func newPingHandler(path string) http.Handler { @@ -324,36 +325,62 @@ func TestParse(t *testing.T) { } } -func TestNewWebClientRespectHTTPProxy(t *testing.T) { - t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999") - client, err := newWebClient(&Config{ - Context: context.Background(), - ProxyAddr: "localhost:3080", - }) - require.NoError(t, err) - //nolint:bodyclose // resp should be nil, so there will be no body to close. - resp, err := client.Get("https://fakedomain.example.com") - // Client should try to proxy through nonexistent server at localhost. - require.Error(t, err, "GET unexpectedly succeeded: %+v", resp) - require.Contains(t, err.Error(), "proxyconnect") - require.Contains(t, err.Error(), "lookup fakeproxy.example.com") - require.Contains(t, err.Error(), "no such host") -} +func TestNewWebClientHTTPProxy(t *testing.T) { + proxyHandler := &apihelpers.ProxyHandler{} + proxyServer := httptest.NewServer(proxyHandler) + t.Cleanup(proxyServer.Close) -func TestNewWebClientNoProxy(t *testing.T) { - t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999") - t.Setenv("NO_PROXY", "fakedomain.example.com") - client, err := newWebClient(&Config{ - Context: context.Background(), - ProxyAddr: "localhost:3080", - }) + localIP, err := apihelpers.GetLocalIP() require.NoError(t, err) - //nolint:bodyclose // resp should be nil, so there will be no body to close. - resp, err := client.Get("https://fakedomain.example.com") - require.Error(t, err, "GET unexpectedly succeeded: %+v", resp) - require.NotContains(t, err.Error(), "proxyconnect") - require.Contains(t, err.Error(), "lookup fakedomain.example.com") - require.Contains(t, err.Error(), "no such host") + server := apihelpers.MakeTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("hello")) + }), apihelpers.WithTestServerAddress(localIP)) + _, serverPort, err := net.SplitHostPort(server.Listener.Addr().String()) + require.NoError(t, err) + serverAddr := net.JoinHostPort(localIP, serverPort) + tests := []struct { + name string + env map[string]string + expectedProxyCount int + }{ + { + name: "use http proxy", + env: map[string]string{ + "HTTPS_PROXY": proxyServer.URL, + }, + expectedProxyCount: 1, + }, + { + name: "ignore proxy when no_proxy is set", + env: map[string]string{ + "HTTPS_PROXY": proxyServer.URL, + "NO_PROXY": "*", + }, + expectedProxyCount: 0, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Cleanup(proxyHandler.Reset) + for k, v := range tc.env { + t.Setenv(k, v) + } + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + client, err := newWebClient(&Config{ + Context: ctx, + ProxyAddr: "localhost:3080", // addr doesn't matter, it won't be used + Insecure: true, + }) + require.NoError(t, err) + + resp, err := client.Get("https://" + serverAddr) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + require.Equal(t, tc.expectedProxyCount, proxyHandler.Count()) + }) + } } func TestSSHProxyHostPort(t *testing.T) { diff --git a/api/testhelpers/proxy.go b/api/testhelpers/proxy.go index 4c85a10daf718..3a39d561758a9 100644 --- a/api/testhelpers/proxy.go +++ b/api/testhelpers/proxy.go @@ -18,10 +18,13 @@ import ( "io" "net" "net/http" + "net/http/httptest" "sync" + "testing" "time" "github.com/gravitational/trace" + "github.com/stretchr/testify/require" ) // ProxyHandler is a http.Handler that implements a simple HTTP proxy server. @@ -64,7 +67,7 @@ func (p *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { trace.WriteError(w, trace.AccessDenied("unable to hijack connection")) return } - sconn, _, err := hj.Hijack() + sconn, buf, err := hj.Hijack() if err != nil { trace.WriteError(w, err) return @@ -83,7 +86,7 @@ func (p *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { errc <- err } go replicate(sconn, dconn) - go replicate(dconn, sconn) + go replicate(dconn, io.MultiReader(buf, sconn)) // Wait until done, error, or 10 second. select { @@ -98,3 +101,57 @@ func (p *ProxyHandler) Count() int { defer p.Unlock() return p.count } + +// Reset sets the counter for proxied requests to zero. +func (p *ProxyHandler) Reset() { + p.Lock() + defer p.Unlock() + p.count = 0 +} + +// GetLocalIP gets the non-loopback IP address of this host. +func GetLocalIP() (string, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "", trace.Wrap(err) + } + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + default: + continue + } + if !ip.IsLoopback() && ip.IsPrivate() { + return ip.String(), nil + } + } + return "", trace.NotFound("No non-loopback local IP address found") +} + +type TestServerOption func(*testing.T, *httptest.Server) + +func WithTestServerAddress(ip string) TestServerOption { + return func(t *testing.T, srv *httptest.Server) { + // Replace the test server's address. + _, originalPort, err := net.SplitHostPort(srv.Listener.Addr().String()) + require.NoError(t, err) + require.NoError(t, srv.Listener.Close()) + l, err := net.Listen("tcp", net.JoinHostPort(ip, originalPort)) + require.NoError(t, err) + srv.Listener = l + } +} + +func MakeTestServer(t *testing.T, h http.Handler, opts ...TestServerOption) *httptest.Server { + svr := httptest.NewUnstartedServer(h) + for _, opt := range opts { + opt(t, svr) + } + svr.StartTLS() + t.Cleanup(svr.Close) + return svr +} diff --git a/integration/helpers/helpers.go b/integration/helpers/helpers.go index e519acc593748..ca6e027eba278 100644 --- a/integration/helpers/helpers.go +++ b/integration/helpers/helpers.go @@ -180,29 +180,6 @@ func CloseAgent(teleAgent *teleagent.AgentServer, socketDirPath string) error { return nil } -// GetLocalIP gets the non-loopback IP address of this host. -func GetLocalIP() (string, error) { - addrs, err := net.InterfaceAddrs() - if err != nil { - return "", trace.Wrap(err) - } - for _, addr := range addrs { - var ip net.IP - switch v := addr.(type) { - case *net.IPNet: - ip = v.IP - case *net.IPAddr: - ip = v.IP - default: - continue - } - if !ip.IsLoopback() && ip.IsPrivate() { - return ip.String(), nil - } - } - return "", trace.NotFound("No non-loopback local IP address found") -} - func MustCreateUserIdentityFile(t *testing.T, tc *TeleInstance, username string, ttl time.Duration) string { key, err := client.GenerateRSAKey() require.NoError(t, err) diff --git a/integration/integration_test.go b/integration/integration_test.go index 4d20e9131357f..65ca809375096 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -67,6 +67,7 @@ import ( "github.com/gravitational/teleport/api/metadata" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/profile" + apihelpers "github.com/gravitational/teleport/api/testhelpers" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" apiutils "github.com/gravitational/teleport/api/utils" @@ -2671,7 +2672,7 @@ func testTwoClustersProxy(t *testing.T, suite *integrationTestSuite) { // httpproxy doesn't allow proxying when the target is localhost, so use // this address instead. - addr, err := helpers.GetLocalIP() + addr, err := apihelpers.GetLocalIP() require.NoError(t, err) a := suite.newNamedTeleportInstance(t, "site-A", WithNodeName(addr), diff --git a/integration/proxy/proxy_test.go b/integration/proxy/proxy_test.go index db8c534536960..ef3d3f85fc9df 100644 --- a/integration/proxy/proxy_test.go +++ b/integration/proxy/proxy_test.go @@ -42,6 +42,7 @@ import ( "github.com/gravitational/teleport/api/breaker" "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/constants" + apihelpers "github.com/gravitational/teleport/api/testhelpers" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/integration/appaccess" dbhelpers "github.com/gravitational/teleport/integration/db" @@ -266,7 +267,7 @@ func TestALPNSNIHTTPSProxy(t *testing.T) { // We need to use the non-loopback address for our Teleport cluster, as the // Go HTTP library will recognize requests to the loopback address and // refuse to use the HTTP proxy, which will invalidate the test. - addr, err := helpers.GetLocalIP() + addr, err := apihelpers.GetLocalIP() require.NoError(t, err) suite := newSuite(t, @@ -307,7 +308,7 @@ func TestMultiPortHTTPSProxy(t *testing.T) { // We need to use the non-loopback address for our Teleport cluster, as the // Go HTTP library will recognize requests to the loopback address and // refuse to use the HTTP proxy, which will invalidate the test. - addr, err := helpers.GetLocalIP() + addr, err := apihelpers.GetLocalIP() require.NoError(t, err) suite := newSuite(t, @@ -1500,7 +1501,7 @@ func TestALPNProxyHTTPProxyNoProxyDial(t *testing.T) { // We need to use the non-loopback address for our Teleport cluster, as the // Go HTTP library will recognize requests to the loopback address and // refuse to use the HTTP proxy, which will invalidate the test. - addr, err := helpers.GetLocalIP() + addr, err := apihelpers.GetLocalIP() require.NoError(t, err) instanceCfg := helpers.InstanceConfig{ @@ -1579,7 +1580,7 @@ func TestALPNProxyHTTPProxyBasicAuthDial(t *testing.T) { // We need to use the non-loopback address for our Teleport cluster, as the // Go HTTP library will recognize requests to the loopback address and // refuse to use the HTTP proxy, which will invalidate the test. - rcAddr, err := helpers.GetLocalIP() + rcAddr, err := apihelpers.GetLocalIP() require.NoError(t, err) log.Info("Creating Teleport instance...") diff --git a/lib/client/https_client_test.go b/lib/client/https_client_test.go index c37ad5207610a..37da8addc0681 100644 --- a/lib/client/https_client_test.go +++ b/lib/client/https_client_test.go @@ -17,55 +17,55 @@ limitations under the License. package client import ( + "net/http" + "net/url" "testing" "github.com/stretchr/testify/require" ) -func TestNewInsecureWebClientHTTPProxy(t *testing.T) { - t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999") - client := NewInsecureWebClient() - //nolint:bodyclose // resp should be nil, so there will be no body to close. - resp, err := client.Get("https://fakedomain.example.com") - // Client should try to proxy through nonexistent server at localhost. - require.Error(t, err, "GET unexpectedly succeeded: %+v", resp) - require.Contains(t, err.Error(), "proxyconnect") - require.Contains(t, err.Error(), "lookup fakeproxy.example.com") - require.Contains(t, err.Error(), "no such host") -} - -func TestNewInsecureWebClientNoProxy(t *testing.T) { - t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999") - t.Setenv("NO_PROXY", "fakedomain.example.com") - client := NewInsecureWebClient() - //nolint:bodyclose // resp should be nil, so there will be no body to close. - resp, err := client.Get("https://fakedomain.example.com") - require.Error(t, err, "GET unexpectedly succeeded: %+v", resp) - require.NotContains(t, err.Error(), "proxyconnect") - require.Contains(t, err.Error(), "lookup fakedomain.example.com") - require.Contains(t, err.Error(), "no such host") -} - -func TestNewSecureWebClientHTTPProxy(t *testing.T) { - t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999") - client := newClient(false, nil, nil) - //nolint:bodyclose // resp should be nil, so there will be no body to close. - resp, err := client.Get("https://fakedomain.example.com") - // Client should try to proxy through nonexistent server at localhost. - require.Error(t, err, "GET unexpectedly succeeded: %+v", resp) - require.Contains(t, err.Error(), "proxyconnect") - require.Contains(t, err.Error(), "lookup fakeproxy.example.com") - require.Contains(t, err.Error(), "no such host") -} +func TestHTTPTransportProxy(t *testing.T) { + proxyURL := "proxy.example.com" + target := "target.example.com" + tests := []struct { + name string + env map[string]string + expectedProxyURL string + }{ + { + name: "use http proxy", + env: map[string]string{ + "HTTPS_PROXY": proxyURL, + }, + expectedProxyURL: "http://" + proxyURL, + }, + { + name: "ignore proxy when no_proxy is set", + env: map[string]string{ + "HTTPS_PROXY": proxyURL, + "NO_PROXY": target, + }, + expectedProxyURL: "", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + for k, v := range tc.env { + t.Setenv(k, v) + } -func TestNewSecureWebClientNoProxy(t *testing.T) { - t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999") - t.Setenv("NO_PROXY", "fakedomain.example.com") - client := newClient(false, nil, nil) - //nolint:bodyclose // resp should be nil, so there will be no body to close. - resp, err := client.Get("https://fakedomain.example.com") - require.Error(t, err, "GET unexpectedly succeeded: %+v", resp) - require.NotContains(t, err.Error(), "proxyconnect") - require.Contains(t, err.Error(), "lookup fakedomain.example.com") - require.Contains(t, err.Error(), "no such host") + inputURL, err := url.Parse("https://" + target) + require.NoError(t, err) + outputURL, err := httpTransport(false, nil).Proxy(&http.Request{ + URL: inputURL, + }) + require.NoError(t, err) + if tc.expectedProxyURL != "" { + require.NotNil(t, outputURL) + require.Equal(t, tc.expectedProxyURL, outputURL.String()) + } else { + require.Nil(t, outputURL) + } + }) + } } diff --git a/tool/tsh/common/resolve_default_addr_test.go b/tool/tsh/common/resolve_default_addr_test.go index ed6b118c3b987..166bcaa916038 100644 --- a/tool/tsh/common/resolve_default_addr_test.go +++ b/tool/tsh/common/resolve_default_addr_test.go @@ -20,7 +20,6 @@ import ( "context" "fmt" "io" - "net" "net/http" "net/http/httptest" "net/url" @@ -31,6 +30,7 @@ import ( "github.com/gravitational/trace" "github.com/stretchr/testify/require" + apihelpers "github.com/gravitational/teleport/api/testhelpers" "github.com/gravitational/teleport/integration/helpers" ) @@ -77,18 +77,6 @@ func mustGetCandidatePorts(servers []*httptest.Server) []int { return result } -type testServerOption func(*httptest.Server) - -func makeTestServer(t *testing.T, h http.Handler, opts ...testServerOption) *httptest.Server { - svr := httptest.NewUnstartedServer(h) - for _, opt := range opts { - opt(svr) - } - svr.StartTLS() - t.Cleanup(func() { svr.Close() }) - return svr -} - func TestResolveDefaultAddr(t *testing.T) { t.Parallel() @@ -105,13 +93,13 @@ func TestResolveDefaultAddr(t *testing.T) { if i == magicServerIndex { handler = respondingHandler } - servers[i] = makeTestServer(t, handler) + servers[i] = apihelpers.MakeTestServer(t, handler) } // NB: We need to defer this channel close such that it happens *before* // the httpstest server shutdowns, or the blocking requests will never // finish and we will deadlock. - defer close(doneCh) + t.Cleanup(func() { close(doneCh) }) ports := mustGetCandidatePorts(servers) expectedAddr := fmt.Sprintf("127.0.0.1:%d", ports[magicServerIndex]) @@ -140,7 +128,7 @@ func TestResolveDefaultAddrSingleCandidate(t *testing.T) { servers := make([]*httptest.Server, 1) for i := 0; i < len(servers); i++ { - servers[i] = makeTestServer(t, respondingHandler) + servers[i] = apihelpers.MakeTestServer(t, respondingHandler) } ports := mustGetCandidatePorts(servers) @@ -162,13 +150,13 @@ func TestResolveDefaultAddrTimeout(t *testing.T) { servers := make([]*httptest.Server, 5) for i := 0; i < 5; i++ { - servers[i] = makeTestServer(t, blockingHandler) + servers[i] = apihelpers.MakeTestServer(t, blockingHandler) } // NB: We need to defer this channel close such that it happens *before* // the httpstest server shutdowns, or the blocking requests will never // finish and we will deadlock. - defer close(doneCh) + t.Cleanup(func() { close(doneCh) }) ports := mustGetCandidatePorts(servers) @@ -188,7 +176,7 @@ func TestResolveNonOKResponseIsAnError(t *testing.T) { // Given a single candidate server configured to respond with a non-OK status // code servers := []*httptest.Server{ - makeTestServer(t, newRespondingHandlerWithStatus(http.StatusTeapot)), + apihelpers.MakeTestServer(t, newRespondingHandlerWithStatus(http.StatusTeapot)), } ports := mustGetCandidatePorts(servers) @@ -229,7 +217,7 @@ func TestResolveUndeliveredBodyDoesNotBlockForever(t *testing.T) { testLog.Debug("Exiting handler") }) - servers := []*httptest.Server{makeTestServer(t, handler)} + servers := []*httptest.Server{apihelpers.MakeTestServer(t, handler)} ports := mustGetCandidatePorts(servers) // When I attempt to resolve a default address @@ -246,15 +234,15 @@ func TestResolveDefaultAddrTimeoutBeforeAllRacersLaunched(t *testing.T) { blockingHandler, doneCh := newWaitForeverHandler() - servers := make([]*httptest.Server, 1000) + servers := make([]*httptest.Server, 100) for i := 0; i < len(servers); i++ { - servers[i] = makeTestServer(t, blockingHandler) + servers[i] = apihelpers.MakeTestServer(t, blockingHandler) } // NB: We need to defer this channel close such that it happens *before* // the httpstest server shutdowns, or the blocking requests will never // finish and we will deadlock. - defer close(doneCh) + t.Cleanup(func() { close(doneCh) }) ports := mustGetCandidatePorts(servers) @@ -275,19 +263,12 @@ func TestResolveDefaultAddrHTTPProxy(t *testing.T) { t.Cleanup(proxyServer.Close) // Go won't proxy to localhost, so use this address instead. - localIP, err := helpers.GetLocalIP() + localIP, err := apihelpers.GetLocalIP() require.NoError(t, err) - var serverAddr net.Addr respondingHandler := newRespondingHandler() - server := makeTestServer(t, respondingHandler, func(srv *httptest.Server) { - // Replace the test server's address. - l, err := net.Listen("tcp", localIP+":0") - require.NoError(t, err) - require.NoError(t, srv.Listener.Close()) - srv.Listener = l - serverAddr = l.Addr() - }) + server := apihelpers.MakeTestServer(t, respondingHandler, apihelpers.WithTestServerAddress(localIP)) + serverAddr := server.Listener.Addr() ports := mustGetCandidatePorts([]*httptest.Server{server})