diff --git a/api/client/proxy/proxy.go b/api/client/proxy/proxy.go index 7b95b3ba8a4df..c1e3da4322c46 100644 --- a/api/client/proxy/proxy.go +++ b/api/client/proxy/proxy.go @@ -71,32 +71,39 @@ func parse(addr string) (*url.URL, error) { return addrURL, nil } -// HTTPFallbackRoundTripper is a wrapper for http.Transport that downgrades requests -// to plain HTTP when using a plain HTTP proxy at localhost. -type HTTPFallbackRoundTripper struct { +// HTTPRoundTripper is a wrapper for http.Transport that +// - adds extra HTTP headers to all requests, and +// - downgrades requests to plain HTTP when proxy is at localhost and the wrapped http.Transport has TLSClientConfig.InsecureSkipVerify set to true. +type HTTPRoundTripper struct { *http.Transport + // extraHeaders is a map of extra HTTP headers to be included in requests. + extraHeaders map[string]string + // isProxyHTTPLocalhost indicates that the HTTP_PROXY is at "http://localhost" isProxyHTTPLocalhost bool } -// NewHTTPFallbackRoundTripper creates a new initialized HTTP fallback roundtripper. -func NewHTTPFallbackRoundTripper(transport *http.Transport, insecure bool) *HTTPFallbackRoundTripper { +// NewHTTPRoundTripper creates a new initialized HTTP roundtripper. +func NewHTTPRoundTripper(transport *http.Transport, extraHeaders map[string]string) *HTTPRoundTripper { proxyConfig := httpproxy.FromEnvironment() - rt := HTTPFallbackRoundTripper{ + return &HTTPRoundTripper{ Transport: transport, + extraHeaders: extraHeaders, isProxyHTTPLocalhost: strings.HasPrefix(proxyConfig.HTTPProxy, "http://localhost"), } - if rt.TLSClientConfig != nil { - rt.TLSClientConfig.InsecureSkipVerify = insecure - } - return &rt } // RoundTrip executes a single HTTP transaction. Part of the RoundTripper interface. -func (rt *HTTPFallbackRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - tlsConfig := rt.Transport.TLSClientConfig +func (rt *HTTPRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // Add extra HTTP headers. + for header, v := range rt.extraHeaders { + req.Header.Add(header, v) + } + // Use plain HTTP if proxying via http://localhost in insecure mode. + tlsConfig := rt.Transport.TLSClientConfig if rt.isProxyHTTPLocalhost && tlsConfig != nil && tlsConfig.InsecureSkipVerify { req.URL.Scheme = "http" } + return rt.Transport.RoundTrip(req) } diff --git a/api/client/proxy/proxy_test.go b/api/client/proxy/proxy_test.go index 08f05bfbe9f3a..939b530f91865 100644 --- a/api/client/proxy/proxy_test.go +++ b/api/client/proxy/proxy_test.go @@ -19,7 +19,9 @@ package proxy import ( "crypto/tls" "fmt" + "net" "net/http" + "net/http/httptest" "net/url" "strings" "testing" @@ -182,12 +184,14 @@ func buildProxyAddr(addr, user, pass string) (string, error) { func TestProxyAwareRoundTripper(t *testing.T) { t.Setenv("HTTP_PROXY", "http://localhost:8888") transport := &http.Transport{ - TLSClientConfig: &tls.Config{}, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, Proxy: func(req *http.Request) (*url.URL, error) { return httpproxy.FromEnvironment().ProxyFunc()(req.URL) }, } - rt := NewHTTPFallbackRoundTripper(transport, true) + rt := NewHTTPRoundTripper(transport, nil) req, err := http.NewRequest(http.MethodGet, "https://localhost:9999", nil) require.NoError(t, err) // Don't care about response, only if the scheme changed. @@ -197,6 +201,190 @@ func TestProxyAwareRoundTripper(t *testing.T) { require.Equal(t, "http", req.URL.Scheme) } +// TestHttpRoundTripperDowngrade tests that the round tripper downgrades https requests to http +// when HTTP_PROXY is set to "http://localhost:*" (i.e. there's an http proxy running on localhost). +func TestHttpRoundTripperDowngrade(t *testing.T) { + testCases := []struct { + desc string + setHTTPProxy bool + shouldHitProxy bool + }{ + { + desc: "hits http proxy if insecure and localhost http proxy is set", + setHTTPProxy: true, + shouldHitProxy: true, + }, + { + desc: "does not hit http proxy if insecure and localhost http proxy is not set", + setHTTPProxy: false, + shouldHitProxy: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + newHandler := func(runningAtProxy bool, wasHit *bool) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + *wasHit = true + if tc.shouldHitProxy { + // If the request should hit the proxy, then: + // - this handler is running at the proxy, and + // - the scheme should be http. + require.True(t, runningAtProxy) + require.Equal(t, "http", r.URL.Scheme) + } + w.WriteHeader(http.StatusOK) + } + } + + // Start localhost http proxy. + runningAtProxy := true + loopback := true + https := false + httpProxyWasHit := false + httpProxy, err := newServer(newHandler(runningAtProxy, &httpProxyWasHit), loopback, https) + require.NoError(t, err) + defer httpProxy.Close() + + // Start non-localhost https server. + runningAtProxy = false + loopback = false + https = true + httpsSrvWasHit := false + httpsSrv, err := newServer(newHandler(runningAtProxy, &httpsSrvWasHit), loopback, https) + require.NoError(t, err) + defer httpsSrv.Close() + + if tc.setHTTPProxy { + // url.Parse won't correctly parse an absolute URL without a scheme. + u, err := url.Parse("http://" + httpProxy.Listener.Addr().String()) + require.NoError(t, err) + _, port, err := net.SplitHostPort(u.Host) + require.NoError(t, err) + + // Set HTTP_PROXY to "http://localhost:*". + t.Setenv("HTTP_PROXY", fmt.Sprintf("http://localhost:%s", port)) + } + + clt := newClient(t, nil) + + // Perform any request. + // Set addr to the https server. If HTTP_PROXY was set above, + // the http proxy should be hit regardless. + addr := httpsSrv.Listener.Addr().String() + request(t, clt, addr) + + // Validate that the correct server was hit. + require.Equal(t, tc.shouldHitProxy, httpProxyWasHit) + require.Equal(t, !tc.shouldHitProxy, httpsSrvWasHit) + }) + } +} + +// TestHttpRoundTripperExtraHeaders tests that the round tripper adds the extra headers set. +func TestHttpRoundTripperExtraHeaders(t *testing.T) { + testCases := []struct { + desc string + extraHeaders map[string]string + expectHeaders func(*testing.T, http.Header) + }{ + { + desc: "extra headers are added", + extraHeaders: map[string]string{ + "header1": "value1", + "header2": "value2", + }, + expectHeaders: func(t *testing.T, headers http.Header) { + require.Equal(t, []string{"value1"}, headers.Values("header1")) + require.Equal(t, []string{"value2"}, headers.Values("header2")) + }, + }, + { + desc: "extra headers do not overwrite existing headers", + extraHeaders: map[string]string{ + "header1": "value1", + "Content-Type": "value2", + }, + expectHeaders: func(t *testing.T, headers http.Header) { + require.Equal(t, []string{"value1"}, headers.Values("header1")) + require.Equal(t, []string{"application/json", "value2"}, headers.Values("Content-Type")) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { + tc.expectHeaders(t, r.Header) + w.WriteHeader(http.StatusOK) + } + + // Start localhost https server. + loopback := true + tls := true + httpsSrv, err := newServer(handler, loopback, tls) + require.NoError(t, err) + defer httpsSrv.Close() + + clt := newClient(t, tc.extraHeaders) + + // Perform any request. + // Set the address to the localhost https server. + addr := httpsSrv.Listener.Addr().String() + request(t, clt, addr) + }) + } +} + +// newServer starts a new server that: +// - runs TLS if `https` +// - uses a loopback listener if `loopback` +func newServer(handler http.HandlerFunc, loopback bool, https bool) (*httptest.Server, error) { + srv := httptest.NewUnstartedServer(handler) + + if !loopback { + // Replace the test-supplied loopback listener with the first available + // non-loopback address. + srv.Listener.Close() + l, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + return nil, err + } + srv.Listener = l + } + + if https { + srv.StartTLS() + } else { + srv.Start() + } + return srv, nil +} + +// newClient creates a new https roundtrip client. +func newClient(t *testing.T, extraHeaders map[string]string) *http.Client { + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + // Setting insecure ensures that https requests succeed. + InsecureSkipVerify: true, + }, + Proxy: func(req *http.Request) (*url.URL, error) { + return httpproxy.FromEnvironment().ProxyFunc()(req.URL) + }, + } + return &http.Client{ + Transport: NewHTTPRoundTripper(transport, extraHeaders), + } +} + +// request perform a POST request. +func request(t *testing.T, clt *http.Client, addr string) { + url := "https://" + addr + "/v1/content" + resp, err := clt.Post(url, "application/json", nil) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) +} + func TestParse(t *testing.T) { successTests := []struct { name, addr, scheme, host, path string diff --git a/api/client/webclient/webclient.go b/api/client/webclient/webclient.go index 771390111a5fb..505c8648b3372 100644 --- a/api/client/webclient/webclient.go +++ b/api/client/webclient/webclient.go @@ -103,7 +103,7 @@ func newWebClient(cfg *Config) (*http.Client, error) { } return &http.Client{ Transport: otelhttp.NewTransport( - proxy.NewHTTPFallbackRoundTripper(&transport, cfg.Insecure), + proxy.NewHTTPRoundTripper(&transport, nil), otelhttp.WithSpanNameFormatter(tracing.HTTPTransportFormatter), ), Timeout: cfg.Timeout, diff --git a/lib/client/api.go b/lib/client/api.go index b2e4495db876b..5002195388abf 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2787,7 +2787,12 @@ func makeProxySSHClient(ctx context.Context, tc *TeleportClient, sshConfig *ssh. if len(tc.JumpHosts) > 0 { sshProxyAddr = tc.JumpHosts[0].Addr.Addr // Check if JumpHost address is a proxy web address. - resp, err := webclient.Find(&webclient.Config{Context: ctx, ProxyAddr: sshProxyAddr, Insecure: tc.InsecureSkipVerify}) + resp, err := webclient.Find(&webclient.Config{ + Context: ctx, + ProxyAddr: sshProxyAddr, + Insecure: tc.InsecureSkipVerify, + ExtraHeaders: tc.ExtraProxyHeaders, + }) // If JumpHost address is a proxy web port and proxy supports TLSRouting dial proxy with TLSWrapper. if err == nil && resp.Proxy.TLSRoutingEnabled { log.Infof("Connecting to proxy=%v login=%q using TLS Routing JumpHost", sshProxyAddr, sshConfig.User) @@ -3284,6 +3289,7 @@ func (tc *TeleportClient) newSSHLogin(priv *keys.PrivateKey) (SSHLogin, error) { RouteToCluster: tc.SiteName, KubernetesCluster: tc.KubernetesCluster, AttestationStatement: attestationStatement, + ExtraHeaders: tc.ExtraProxyHeaders, }, nil } diff --git a/lib/client/https_client.go b/lib/client/https_client.go index 4e7f4c8e43e08..5c2f0ebf0bddb 100644 --- a/lib/client/https_client.go +++ b/lib/client/https_client.go @@ -37,39 +37,30 @@ import ( ) func NewInsecureWebClient() *http.Client { - // Because Teleport clients can't be configured (yet), they take the default - // list of cipher suites from Go. - tlsConfig := utils.TLSConfig(nil) - transport := http.Transport{ - TLSClientConfig: tlsConfig, - Proxy: func(req *http.Request) (*url.URL, error) { - return httpproxy.FromEnvironment().ProxyFunc()(req.URL) - }, - } + return newClient(true, nil, nil) +} + +func newClient(insecure bool, pool *x509.CertPool, extraHeaders map[string]string) *http.Client { return &http.Client{ Transport: otelhttp.NewTransport( - apiproxy.NewHTTPFallbackRoundTripper(&transport, true /* insecure */), + apiproxy.NewHTTPRoundTripper(httpTransport(insecure, pool), extraHeaders), otelhttp.WithSpanNameFormatter(tracing.HTTPTransportFormatter), ), } } -func newClientWithPool(pool *x509.CertPool) *http.Client { +func httpTransport(insecure bool, pool *x509.CertPool) *http.Transport { // Because Teleport clients can't be configured (yet), they take the default // list of cipher suites from Go. tlsConfig := utils.TLSConfig(nil) + tlsConfig.InsecureSkipVerify = insecure tlsConfig.RootCAs = pool - return &http.Client{ - Transport: otelhttp.NewTransport( - &http.Transport{ - TLSClientConfig: tlsConfig, - Proxy: func(req *http.Request) (*url.URL, error) { - return httpproxy.FromEnvironment().ProxyFunc()(req.URL) - }, - }, - otelhttp.WithSpanNameFormatter(tracing.HTTPTransportFormatter), - ), + return &http.Transport{ + TLSClientConfig: tlsConfig, + Proxy: func(req *http.Request) (*url.URL, error) { + return httpproxy.FromEnvironment().ProxyFunc()(req.URL) + }, } } diff --git a/lib/client/https_client_test.go b/lib/client/https_client_test.go index 7482df7fd0ea8..c37ad5207610a 100644 --- a/lib/client/https_client_test.go +++ b/lib/client/https_client_test.go @@ -46,9 +46,9 @@ func TestNewInsecureWebClientNoProxy(t *testing.T) { require.Contains(t, err.Error(), "no such host") } -func TestNewClientWithPoolHTTPProxy(t *testing.T) { +func TestNewSecureWebClientHTTPProxy(t *testing.T) { t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999") - client := newClientWithPool(nil) + client := newClient(false, nil, nil) //nolint:bodyclose // resp should be nil, so there will be no body to close. resp, err := client.Get("https://fakedomain.example.com") // Client should try to proxy through nonexistent server at localhost. @@ -58,10 +58,10 @@ func TestNewClientWithPoolHTTPProxy(t *testing.T) { require.Contains(t, err.Error(), "no such host") } -func TestNewClientWithPoolNoProxy(t *testing.T) { +func TestNewSecureWebClientNoProxy(t *testing.T) { t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999") t.Setenv("NO_PROXY", "fakedomain.example.com") - client := newClientWithPool(nil) + client := newClient(false, nil, nil) //nolint:bodyclose // resp should be nil, so there will be no body to close. resp, err := client.Get("https://fakedomain.example.com") require.Error(t, err, "GET unexpectedly succeeded: %+v", resp) diff --git a/lib/client/redirect.go b/lib/client/redirect.go index 419d69ff0bb2a..cb92e5206e24a 100644 --- a/lib/client/redirect.go +++ b/lib/client/redirect.go @@ -91,7 +91,7 @@ type RedirectorConfig struct { // NewRedirector returns new local web server redirector func NewRedirector(ctx context.Context, login SSHLoginSSO, config *RedirectorConfig) (*Redirector, error) { - clt, proxyURL, err := initClient(login.ProxyAddr, login.Insecure, login.Pool) + clt, proxyURL, err := initClient(login.ProxyAddr, login.Insecure, login.Pool, login.ExtraHeaders) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/client/weblogin.go b/lib/client/weblogin.go index 51687be3e5173..aa347516d1af5 100644 --- a/lib/client/weblogin.go +++ b/lib/client/weblogin.go @@ -172,7 +172,7 @@ type SSHLogin struct { TTL time.Duration // Insecure turns off verification for x509 target proxy Insecure bool - // Pool is x509 cert pool to use for server certifcate verification + // Pool is x509 cert pool to use for server certificate verification Pool *x509.CertPool // Compatibility sets compatibility mode for SSH certificates Compatibility string @@ -184,6 +184,8 @@ type SSHLogin struct { KubernetesCluster string // AttestationStatement is an attestation statement. AttestationStatement *keys.AttestationStatement + // ExtraHeaders is a map of extra HTTP headers to be included in requests. + ExtraHeaders map[string]string } // SSHLoginSSO contains SSH login parameters for SSO login. @@ -250,13 +252,13 @@ type SSHLoginPasswordless struct { } // initClient creates a new client to the HTTPS web proxy. -func initClient(proxyAddr string, insecure bool, pool *x509.CertPool) (*WebClient, *url.URL, error) { +func initClient(proxyAddr string, insecure bool, pool *x509.CertPool, extraHeaders map[string]string) (*WebClient, *url.URL, error) { log := logrus.WithFields(logrus.Fields{ trace.Component: teleport.ComponentClient, }) - log.Debugf("HTTPS client init(proxyAddr=%v, insecure=%v)", proxyAddr, insecure) + log.Debugf("HTTPS client init(proxyAddr=%v, insecure=%v, extraHeaders=%v)", proxyAddr, insecure, extraHeaders) - // validate proxyAddr: + // validate proxy address host, port, err := net.SplitHostPort(proxyAddr) if err != nil || host == "" || port == "" { if err != nil { @@ -270,18 +272,13 @@ func initClient(proxyAddr string, insecure bool, pool *x509.CertPool) (*WebClien return nil, nil, trace.BadParameter("'%v' is not a valid proxy address", proxyAddr) } - var opts []roundtrip.ClientParam - if insecure { - // Skip https cert verification, print a warning that this is insecure. + // Skipping https cert verification, print a warning that this is insecure. fmt.Fprintf(os.Stderr, "WARNING: You are using insecure connection to Teleport proxy %v\n", proxyAddr) - opts = append(opts, roundtrip.HTTPClient(NewInsecureWebClient())) - } else if pool != nil { - // use custom set of trusted CAs - opts = append(opts, roundtrip.HTTPClient(newClientWithPool(pool))) } - clt, err := NewWebClient(proxyAddr, opts...) + opt := roundtrip.HTTPClient(newClient(insecure, pool, extraHeaders)) + clt, err := NewWebClient(proxyAddr, opt) if err != nil { return nil, nil, trace.Wrap(err) } @@ -360,7 +357,7 @@ func SSHAgentSSOLogin(ctx context.Context, login SSHLoginSSO, config *Redirector // SSHAgentLogin is used by tsh to fetch local user credentials. func SSHAgentLogin(ctx context.Context, login SSHLoginDirect) (*auth.SSHLoginResponse, error) { - clt, _, err := initClient(login.ProxyAddr, login.Insecure, login.Pool) + clt, _, err := initClient(login.ProxyAddr, login.Insecure, login.Pool, login.ExtraHeaders) if err != nil { return nil, trace.Wrap(err) } @@ -395,7 +392,7 @@ func SSHAgentLogin(ctx context.Context, login SSHLoginDirect) (*auth.SSHLoginRes // // Returns the SSH certificate if authn is successful or an error. func SSHAgentPasswordlessLogin(ctx context.Context, login SSHLoginPasswordless) (*auth.SSHLoginResponse, error) { - webClient, webURL, err := initClient(login.ProxyAddr, login.Insecure, login.Pool) + webClient, webURL, err := initClient(login.ProxyAddr, login.Insecure, login.Pool, login.ExtraHeaders) if err != nil { return nil, trace.Wrap(err) } @@ -466,7 +463,7 @@ func SSHAgentPasswordlessLogin(ctx context.Context, login SSHLoginPasswordless) // prompt the user to provide 2nd factor and pass the response to the proxy. // If the authentication succeeds, we will get a temporary certificate back. func SSHAgentMFALogin(ctx context.Context, login SSHLoginMFA) (*auth.SSHLoginResponse, error) { - clt, _, err := initClient(login.ProxyAddr, login.Insecure, login.Pool) + clt, _, err := initClient(login.ProxyAddr, login.Insecure, login.Pool, login.ExtraHeaders) if err != nil { return nil, trace.Wrap(err) } @@ -534,7 +531,7 @@ func SSHAgentMFALogin(ctx context.Context, login SSHLoginMFA) (*auth.SSHLoginRes // HostCredentials is used to fetch host credentials for a node. func HostCredentials(ctx context.Context, proxyAddr string, insecure bool, req types.RegisterUsingTokenRequest) (*proto.Certs, error) { - clt, _, err := initClient(proxyAddr, insecure, nil) + clt, _, err := initClient(proxyAddr, insecure, nil, nil) if err != nil { return nil, trace.Wrap(err) } @@ -554,7 +551,7 @@ func HostCredentials(ctx context.Context, proxyAddr string, insecure bool, req t // GetWebConfig is used by teleterm to fetch webconfig.js from proxies func GetWebConfig(ctx context.Context, proxyAddr string, insecure bool) (*webclient.WebConfig, error) { - clt, _, err := initClient(proxyAddr, insecure, nil) + clt, _, err := initClient(proxyAddr, insecure, nil, nil) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/client/weblogin_test.go b/lib/client/weblogin_test.go index a562966a06e58..071ebcc69f022 100644 --- a/lib/client/weblogin_test.go +++ b/lib/client/weblogin_test.go @@ -36,69 +36,82 @@ import ( "github.com/gravitational/teleport/lib/client" ) -func TestPlainHttpFallback(t *testing.T) { +// TestHostCredentialsHttpFallback tests that HostCredentials requests (/v1/webapi/host/credentials/) +// fall back to HTTP only if the address is a loopback and the insecure mode was set. +func TestHostCredentialsHttpFallback(t *testing.T) { testCases := []struct { - desc string - path string - handler http.HandlerFunc - actionUnderTest func(ctx context.Context, addr string, insecure bool) error + desc string + loopback bool + insecure bool + fallback bool }{ { - desc: "HostCredentials", - path: "/v1/webapi/host/credentials", - handler: func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI != "/v1/webapi/host/credentials" { - w.WriteHeader(http.StatusNotFound) - return - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(proto.Certs{}) - }, - actionUnderTest: func(ctx context.Context, addr string, insecure bool) error { - _, err := client.HostCredentials(ctx, addr, insecure, types.RegisterUsingTokenRequest{}) - return err - }, + desc: "falls back to http if loopback and insecure", + loopback: true, + insecure: true, + fallback: true, + }, + { + desc: "does not fall back to http if loopback and secure", + loopback: true, + insecure: false, + fallback: false, + }, + { + desc: "does not fall back to http if non-loopback and insecure", + loopback: false, + insecure: true, + fallback: false, }, } - for _, testCase := range testCases { - t.Run(testCase.desc, func(t *testing.T) { - ctx := context.Background() - - t.Run("Allowed on insecure & loopback", func(t *testing.T) { - httpSvr := httptest.NewServer(testCase.handler) - defer httpSvr.Close() - - err := testCase.actionUnderTest(ctx, httpSvr.Listener.Addr().String(), true /* insecure */) - require.NoError(t, err) - }) - - t.Run("Denied on secure", func(t *testing.T) { - httpSvr := httptest.NewServer(testCase.handler) - defer httpSvr.Close() - - err := testCase.actionUnderTest(ctx, httpSvr.Listener.Addr().String(), false /* secure */) - require.Error(t, err) - }) + for _, tc := range testCases { + // Start an http server (not https) so that the request only succeeds + // if the fallback occurs. + var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { + if r.RequestURI != "/v1/webapi/host/credentials" { + w.WriteHeader(http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(proto.Certs{}) + } + httpSvr, err := newServer(handler, tc.loopback) + require.NoError(t, err) + defer httpSvr.Close() + + // Send the HostCredentials request. + ctx := context.Background() + _, err = client.HostCredentials(ctx, httpSvr.Listener.Addr().String(), tc.insecure, types.RegisterUsingTokenRequest{}) + + // If it should fallback, then no error should occur + // as the request will hit the running http server. + if tc.fallback { + require.NoError(t, err) + } else { + require.Error(t, err) + } + } +} - t.Run("Denied on non-loopback", func(t *testing.T) { - nonLoopbackSvr := httptest.NewUnstartedServer(testCase.handler) +// newServer starts a new server that uses a loopback listener if `loopback`. +func newServer(handler http.HandlerFunc, loopback bool) (*httptest.Server, error) { + srv := httptest.NewUnstartedServer(handler) - // replace the test-supplied loopback listener with the first available - // non-loopback address - nonLoopbackSvr.Listener.Close() - l, err := net.Listen("tcp", "0.0.0.0:0") - require.NoError(t, err) - nonLoopbackSvr.Listener = l - nonLoopbackSvr.Start() - defer nonLoopbackSvr.Close() - - err = testCase.actionUnderTest(ctx, nonLoopbackSvr.Listener.Addr().String(), true /* insecure */) - require.Error(t, err) - }) - }) + if !loopback { + // Replace the test-supplied loopback listener with the first available + // non-loopback address. + srv.Listener.Close() + l, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + return nil, err + } + srv.Listener = l } + + srv.Start() + return srv, nil } func TestSSHAgentPasswordlessLogin(t *testing.T) {