Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 55 additions & 28 deletions api/client/webclient/webclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"golang.org/x/exp/slices"

"github.com/gravitational/teleport/api/defaults"
apihelpers "github.com/gravitational/teleport/api/testhelpers"
)

func newPingHandler(path string) http.Handler {
Expand Down Expand Up @@ -324,36 +325,62 @@ func TestParse(t *testing.T) {
}
}

func TestNewWebClientRespectHTTPProxy(t *testing.T) {
t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999")
client, err := newWebClient(&Config{
Context: context.Background(),
ProxyAddr: "localhost:3080",
})
require.NoError(t, err)
//nolint:bodyclose // resp should be nil, so there will be no body to close.
resp, err := client.Get("https://fakedomain.example.com")
// Client should try to proxy through nonexistent server at localhost.
require.Error(t, err, "GET unexpectedly succeeded: %+v", resp)
require.Contains(t, err.Error(), "proxyconnect")
require.Contains(t, err.Error(), "lookup fakeproxy.example.com")
require.Contains(t, err.Error(), "no such host")
}
func TestNewWebClientHTTPProxy(t *testing.T) {
proxyHandler := &apihelpers.ProxyHandler{}
proxyServer := httptest.NewServer(proxyHandler)
t.Cleanup(proxyServer.Close)

func TestNewWebClientNoProxy(t *testing.T) {
t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999")
t.Setenv("NO_PROXY", "fakedomain.example.com")
client, err := newWebClient(&Config{
Context: context.Background(),
ProxyAddr: "localhost:3080",
})
localIP, err := apihelpers.GetLocalIP()
require.NoError(t, err)
//nolint:bodyclose // resp should be nil, so there will be no body to close.
resp, err := client.Get("https://fakedomain.example.com")
require.Error(t, err, "GET unexpectedly succeeded: %+v", resp)
require.NotContains(t, err.Error(), "proxyconnect")
require.Contains(t, err.Error(), "lookup fakedomain.example.com")
require.Contains(t, err.Error(), "no such host")
server := apihelpers.MakeTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("hello"))
}), apihelpers.WithTestServerAddress(localIP))
_, serverPort, err := net.SplitHostPort(server.Listener.Addr().String())
require.NoError(t, err)
serverAddr := net.JoinHostPort(localIP, serverPort)
tests := []struct {
name string
env map[string]string
expectedProxyCount int
}{
{
name: "use http proxy",
env: map[string]string{
"HTTPS_PROXY": proxyServer.URL,
},
expectedProxyCount: 1,
},
{
name: "ignore proxy when no_proxy is set",
env: map[string]string{
"HTTPS_PROXY": proxyServer.URL,
"NO_PROXY": "*",
},
Comment thread
atburke marked this conversation as resolved.
expectedProxyCount: 0,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid variable name hiding here, and say, use tt? I was surprised seeing the proxyHandler.Reset passed to the cleanup function and was wondering how the hell we expect to get back to between loop iterations.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you referring to t? Shadowing t in subtests is quite common and is actually the preferred approach.

If you have tt and t in scope, it's possible to mistakenly use the wrong one and (for example) fail the top-level test instead of the subtest. Shadowing t makes that impossible because the outer t is no longer accessible to the closure.

It's even written this way in the official docs for the testing package: https://pkg.go.dev/testing#hdr-Subtests_and_Sub_benchmarks

t.Cleanup(proxyHandler.Reset)
for k, v := range tc.env {
t.Setenv(k, v)
}
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
client, err := newWebClient(&Config{
Context: ctx,
ProxyAddr: "localhost:3080", // addr doesn't matter, it won't be used
Insecure: true,
})
require.NoError(t, err)

resp, err := client.Get("https://" + serverAddr)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
require.Equal(t, tc.expectedProxyCount, proxyHandler.Count())
})
}
}

func TestSSHProxyHostPort(t *testing.T) {
Expand Down
61 changes: 59 additions & 2 deletions api/testhelpers/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ import (
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)

// ProxyHandler is a http.Handler that implements a simple HTTP proxy server.
Expand Down Expand Up @@ -64,7 +67,7 @@ func (p *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
trace.WriteError(w, trace.AccessDenied("unable to hijack connection"))
return
}
sconn, _, err := hj.Hijack()
sconn, buf, err := hj.Hijack()
if err != nil {
trace.WriteError(w, err)
return
Expand All @@ -83,7 +86,7 @@ func (p *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
errc <- err
}
go replicate(sconn, dconn)
go replicate(dconn, sconn)
go replicate(dconn, io.MultiReader(buf, sconn))

// Wait until done, error, or 10 second.
select {
Expand All @@ -98,3 +101,57 @@ func (p *ProxyHandler) Count() int {
defer p.Unlock()
return p.count
}

// Reset sets the counter for proxied requests to zero.
func (p *ProxyHandler) Reset() {
p.Lock()
defer p.Unlock()
p.count = 0
}

// GetLocalIP gets the non-loopback IP address of this host.
func GetLocalIP() (string, error) {
addrs, err := net.InterfaceAddrs()
if err != nil {
return "", trace.Wrap(err)
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if !ip.IsLoopback() && ip.IsPrivate() {
return ip.String(), nil
}
}
return "", trace.NotFound("No non-loopback local IP address found")
}

type TestServerOption func(*testing.T, *httptest.Server)

func WithTestServerAddress(ip string) TestServerOption {
return func(t *testing.T, srv *httptest.Server) {
// Replace the test server's address.
_, originalPort, err := net.SplitHostPort(srv.Listener.Addr().String())
require.NoError(t, err)
require.NoError(t, srv.Listener.Close())
l, err := net.Listen("tcp", net.JoinHostPort(ip, originalPort))
require.NoError(t, err)
srv.Listener = l
}
}

func MakeTestServer(t *testing.T, h http.Handler, opts ...TestServerOption) *httptest.Server {
svr := httptest.NewUnstartedServer(h)
for _, opt := range opts {
opt(t, svr)
}
svr.StartTLS()
t.Cleanup(svr.Close)
return svr
}
23 changes: 0 additions & 23 deletions integration/helpers/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,29 +180,6 @@ func CloseAgent(teleAgent *teleagent.AgentServer, socketDirPath string) error {
return nil
}

// GetLocalIP gets the non-loopback IP address of this host.
func GetLocalIP() (string, error) {
addrs, err := net.InterfaceAddrs()
if err != nil {
return "", trace.Wrap(err)
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if !ip.IsLoopback() && ip.IsPrivate() {
return ip.String(), nil
}
}
return "", trace.NotFound("No non-loopback local IP address found")
}

func MustCreateUserIdentityFile(t *testing.T, tc *TeleInstance, username string, ttl time.Duration) string {
key, err := client.GenerateRSAKey()
require.NoError(t, err)
Expand Down
3 changes: 2 additions & 1 deletion integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ import (
"github.com/gravitational/teleport/api/metadata"
tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh"
"github.com/gravitational/teleport/api/profile"
apihelpers "github.com/gravitational/teleport/api/testhelpers"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
apiutils "github.com/gravitational/teleport/api/utils"
Expand Down Expand Up @@ -2671,7 +2672,7 @@ func testTwoClustersProxy(t *testing.T, suite *integrationTestSuite) {

// httpproxy doesn't allow proxying when the target is localhost, so use
// this address instead.
addr, err := helpers.GetLocalIP()
addr, err := apihelpers.GetLocalIP()
require.NoError(t, err)
a := suite.newNamedTeleportInstance(t, "site-A",
WithNodeName(addr),
Expand Down
9 changes: 5 additions & 4 deletions integration/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
"github.com/gravitational/teleport/api/breaker"
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/constants"
apihelpers "github.com/gravitational/teleport/api/testhelpers"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integration/appaccess"
dbhelpers "github.com/gravitational/teleport/integration/db"
Expand Down Expand Up @@ -266,7 +267,7 @@ func TestALPNSNIHTTPSProxy(t *testing.T) {
// We need to use the non-loopback address for our Teleport cluster, as the
// Go HTTP library will recognize requests to the loopback address and
// refuse to use the HTTP proxy, which will invalidate the test.
addr, err := helpers.GetLocalIP()
addr, err := apihelpers.GetLocalIP()
require.NoError(t, err)

suite := newSuite(t,
Expand Down Expand Up @@ -307,7 +308,7 @@ func TestMultiPortHTTPSProxy(t *testing.T) {
// We need to use the non-loopback address for our Teleport cluster, as the
// Go HTTP library will recognize requests to the loopback address and
// refuse to use the HTTP proxy, which will invalidate the test.
addr, err := helpers.GetLocalIP()
addr, err := apihelpers.GetLocalIP()
require.NoError(t, err)

suite := newSuite(t,
Expand Down Expand Up @@ -1500,7 +1501,7 @@ func TestALPNProxyHTTPProxyNoProxyDial(t *testing.T) {
// We need to use the non-loopback address for our Teleport cluster, as the
// Go HTTP library will recognize requests to the loopback address and
// refuse to use the HTTP proxy, which will invalidate the test.
addr, err := helpers.GetLocalIP()
addr, err := apihelpers.GetLocalIP()
require.NoError(t, err)

instanceCfg := helpers.InstanceConfig{
Expand Down Expand Up @@ -1579,7 +1580,7 @@ func TestALPNProxyHTTPProxyBasicAuthDial(t *testing.T) {
// We need to use the non-loopback address for our Teleport cluster, as the
// Go HTTP library will recognize requests to the loopback address and
// refuse to use the HTTP proxy, which will invalidate the test.
rcAddr, err := helpers.GetLocalIP()
rcAddr, err := apihelpers.GetLocalIP()
require.NoError(t, err)

log.Info("Creating Teleport instance...")
Expand Down
90 changes: 45 additions & 45 deletions lib/client/https_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,55 +17,55 @@ limitations under the License.
package client

import (
"net/http"
"net/url"
"testing"

"github.com/stretchr/testify/require"
)

func TestNewInsecureWebClientHTTPProxy(t *testing.T) {
t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999")
client := NewInsecureWebClient()
//nolint:bodyclose // resp should be nil, so there will be no body to close.
resp, err := client.Get("https://fakedomain.example.com")
// Client should try to proxy through nonexistent server at localhost.
require.Error(t, err, "GET unexpectedly succeeded: %+v", resp)
require.Contains(t, err.Error(), "proxyconnect")
require.Contains(t, err.Error(), "lookup fakeproxy.example.com")
require.Contains(t, err.Error(), "no such host")
}

func TestNewInsecureWebClientNoProxy(t *testing.T) {
t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999")
t.Setenv("NO_PROXY", "fakedomain.example.com")
client := NewInsecureWebClient()
//nolint:bodyclose // resp should be nil, so there will be no body to close.
resp, err := client.Get("https://fakedomain.example.com")
require.Error(t, err, "GET unexpectedly succeeded: %+v", resp)
require.NotContains(t, err.Error(), "proxyconnect")
require.Contains(t, err.Error(), "lookup fakedomain.example.com")
require.Contains(t, err.Error(), "no such host")
}

func TestNewSecureWebClientHTTPProxy(t *testing.T) {
t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999")
client := newClient(false, nil, nil)
//nolint:bodyclose // resp should be nil, so there will be no body to close.
resp, err := client.Get("https://fakedomain.example.com")
// Client should try to proxy through nonexistent server at localhost.
require.Error(t, err, "GET unexpectedly succeeded: %+v", resp)
require.Contains(t, err.Error(), "proxyconnect")
require.Contains(t, err.Error(), "lookup fakeproxy.example.com")
require.Contains(t, err.Error(), "no such host")
}
func TestHTTPTransportProxy(t *testing.T) {
proxyURL := "proxy.example.com"
target := "target.example.com"
tests := []struct {
name string
env map[string]string
expectedProxyURL string
}{
{
name: "use http proxy",
env: map[string]string{
"HTTPS_PROXY": proxyURL,
},
expectedProxyURL: "http://" + proxyURL,
},
{
name: "ignore proxy when no_proxy is set",
env: map[string]string{
"HTTPS_PROXY": proxyURL,
"NO_PROXY": target,
},
expectedProxyURL: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
for k, v := range tc.env {
t.Setenv(k, v)
}

func TestNewSecureWebClientNoProxy(t *testing.T) {
t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999")
t.Setenv("NO_PROXY", "fakedomain.example.com")
client := newClient(false, nil, nil)
//nolint:bodyclose // resp should be nil, so there will be no body to close.
resp, err := client.Get("https://fakedomain.example.com")
require.Error(t, err, "GET unexpectedly succeeded: %+v", resp)
require.NotContains(t, err.Error(), "proxyconnect")
require.Contains(t, err.Error(), "lookup fakedomain.example.com")
require.Contains(t, err.Error(), "no such host")
inputURL, err := url.Parse("https://" + target)
require.NoError(t, err)
outputURL, err := httpTransport(false, nil).Proxy(&http.Request{
URL: inputURL,
})
require.NoError(t, err)
if tc.expectedProxyURL != "" {
require.NotNil(t, outputURL)
require.Equal(t, tc.expectedProxyURL, outputURL.String())
} else {
require.Nil(t, outputURL)
}
})
}
}
Loading