Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 55 additions & 28 deletions api/client/webclient/webclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"golang.org/x/exp/slices"

"github.com/gravitational/teleport/api/defaults"
apihelpers "github.com/gravitational/teleport/api/testhelpers"
)

func newPingHandler(path string) http.Handler {
Expand Down Expand Up @@ -324,36 +325,62 @@ func TestParse(t *testing.T) {
}
}

func TestNewWebClientRespectHTTPProxy(t *testing.T) {
t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999")
client, err := newWebClient(&Config{
Context: context.Background(),
ProxyAddr: "localhost:3080",
})
require.NoError(t, err)
//nolint:bodyclose // resp should be nil, so there will be no body to close.
resp, err := client.Get("https://fakedomain.example.com")
// Client should try to proxy through nonexistent server at localhost.
require.Error(t, err, "GET unexpectedly succeeded: %+v", resp)
require.Contains(t, err.Error(), "proxyconnect")
require.Contains(t, err.Error(), "lookup fakeproxy.example.com")
require.Contains(t, err.Error(), "no such host")
}
func TestNewWebClientHTTPProxy(t *testing.T) {
proxyHandler := &apihelpers.ProxyHandler{}
proxyServer := httptest.NewServer(proxyHandler)
t.Cleanup(proxyServer.Close)

func TestNewWebClientNoProxy(t *testing.T) {
t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999")
t.Setenv("NO_PROXY", "fakedomain.example.com")
client, err := newWebClient(&Config{
Context: context.Background(),
ProxyAddr: "localhost:3080",
})
localIP, err := apihelpers.GetLocalIP()
require.NoError(t, err)
//nolint:bodyclose // resp should be nil, so there will be no body to close.
resp, err := client.Get("https://fakedomain.example.com")
require.Error(t, err, "GET unexpectedly succeeded: %+v", resp)
require.NotContains(t, err.Error(), "proxyconnect")
require.Contains(t, err.Error(), "lookup fakedomain.example.com")
require.Contains(t, err.Error(), "no such host")
server := apihelpers.MakeTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("hello"))
}), apihelpers.WithTestServerAddress(localIP))
_, serverPort, err := net.SplitHostPort(server.Listener.Addr().String())
require.NoError(t, err)
serverAddr := net.JoinHostPort(localIP, serverPort)
tests := []struct {
name string
env map[string]string
expectedProxyCount int
}{
{
name: "use http proxy",
env: map[string]string{
"HTTPS_PROXY": proxyServer.URL,
},
expectedProxyCount: 1,
},
{
name: "ignore proxy when no_proxy is set",
env: map[string]string{
"HTTPS_PROXY": proxyServer.URL,
"NO_PROXY": "*",
},
expectedProxyCount: 0,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(proxyHandler.Reset)
for k, v := range tc.env {
t.Setenv(k, v)
}
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
client, err := newWebClient(&Config{
Context: ctx,
ProxyAddr: "localhost:3080", // addr doesn't matter, it won't be used
Insecure: true,
})
require.NoError(t, err)

resp, err := client.Get("https://" + serverAddr)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
require.Equal(t, tc.expectedProxyCount, proxyHandler.Count())
})
}
}

func TestSSHProxyHostPort(t *testing.T) {
Expand Down
157 changes: 157 additions & 0 deletions api/testhelpers/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
// Copyright 2023 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package testhelpers

import (
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)

// ProxyHandler is a http.Handler that implements a simple HTTP proxy server.
type ProxyHandler struct {
sync.Mutex
count int
}

// ServeHTTP only accepts the CONNECT verb and will tunnel your connection to
// the specified host. Also tracks the number of connections that it proxies for
// debugging purposes.
func (p *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Validate http connect parameters.
if r.Method != http.MethodConnect {
trace.WriteError(w, trace.BadParameter("%v not supported", r.Method))
return
}
if r.Host == "" {
trace.WriteError(w, trace.BadParameter("host not set"))
return
}

// Dial to the target host, this is done before hijacking the connection to
// ensure the target host is accessible.
dialer := net.Dialer{}
dconn, err := dialer.DialContext(r.Context(), "tcp", r.Host)
if err != nil {
trace.WriteError(w, err)
return
}
defer dconn.Close()

// Once the client receives 200 OK, the rest of the data will no longer be
// http, but whatever protocol is being tunneled.
w.WriteHeader(http.StatusOK)

// Hijack request so we can get underlying connection.
hj, ok := w.(http.Hijacker)
if !ok {
trace.WriteError(w, trace.AccessDenied("unable to hijack connection"))
return
}
sconn, buf, err := hj.Hijack()
if err != nil {
trace.WriteError(w, err)
return
}
defer sconn.Close()

// Success, we're proxying data now.
p.Lock()
p.count++
p.Unlock()

// Copy from src to dst and dst to src.
errc := make(chan error, 2)
replicate := func(dst io.Writer, src io.Reader) {
_, err := io.Copy(dst, src)
errc <- err
}
go replicate(sconn, dconn)
go replicate(dconn, io.MultiReader(buf, sconn))

// Wait until done, error, or 10 second.
select {
case <-time.After(10 * time.Second):
case <-errc:
}
}

// Count returns the number of requests that have been proxied.
func (p *ProxyHandler) Count() int {
p.Lock()
defer p.Unlock()
return p.count
}

// Reset sets the counter for proxied requests to zero.
func (p *ProxyHandler) Reset() {
p.Lock()
defer p.Unlock()
p.count = 0
}

// GetLocalIP gets the non-loopback IP address of this host.
func GetLocalIP() (string, error) {
addrs, err := net.InterfaceAddrs()
if err != nil {
return "", trace.Wrap(err)
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if !ip.IsLoopback() && ip.IsPrivate() {
return ip.String(), nil
}
}
return "", trace.NotFound("No non-loopback local IP address found")
}

type TestServerOption func(*testing.T, *httptest.Server)

func WithTestServerAddress(ip string) TestServerOption {
return func(t *testing.T, srv *httptest.Server) {
// Replace the test server's address.
_, originalPort, err := net.SplitHostPort(srv.Listener.Addr().String())
require.NoError(t, err)
require.NoError(t, srv.Listener.Close())
l, err := net.Listen("tcp", net.JoinHostPort(ip, originalPort))
require.NoError(t, err)
srv.Listener = l
}
}

func MakeTestServer(t *testing.T, h http.Handler, opts ...TestServerOption) *httptest.Server {
svr := httptest.NewUnstartedServer(h)
for _, opt := range opts {
opt(t, svr)
}
svr.StartTLS()
t.Cleanup(svr.Close)
return svr
}
23 changes: 0 additions & 23 deletions integration/helpers/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,29 +177,6 @@ func CloseAgent(teleAgent *teleagent.AgentServer, socketDirPath string) error {
return nil
}

// GetLocalIP gets the non-loopback IP address of this host.
func GetLocalIP() (string, error) {
addrs, err := net.InterfaceAddrs()
if err != nil {
return "", trace.Wrap(err)
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if !ip.IsLoopback() && ip.IsPrivate() {
return ip.String(), nil
}
}
return "", trace.NotFound("No non-loopback local IP address found")
}

func MustCreateUserIdentityFile(t *testing.T, tc *TeleInstance, username string, ttl time.Duration) string {
key, err := libclient.GenerateRSAKey()
require.NoError(t, err)
Expand Down
3 changes: 2 additions & 1 deletion integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import (
apidefaults "github.com/gravitational/teleport/api/defaults"
tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh"
"github.com/gravitational/teleport/api/profile"
apihelpers "github.com/gravitational/teleport/api/testhelpers"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
apiutils "github.com/gravitational/teleport/api/utils"
Expand Down Expand Up @@ -2368,7 +2369,7 @@ func testTwoClustersProxy(t *testing.T, suite *integrationTestSuite) {

// httpproxy doesn't allow proxying when the target is localhost, so use
// this address instead.
addr, err := helpers.GetLocalIP()
addr, err := apihelpers.GetLocalIP()
require.NoError(t, err)
a := suite.newNamedTeleportInstance(t, "site-A",
WithNodeName(addr),
Expand Down
9 changes: 5 additions & 4 deletions integration/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (

"github.com/gravitational/teleport/api/breaker"
"github.com/gravitational/teleport/api/client"
apihelpers "github.com/gravitational/teleport/api/testhelpers"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integration/appaccess"
dbhelpers "github.com/gravitational/teleport/integration/db"
Expand Down Expand Up @@ -228,7 +229,7 @@ func TestALPNSNIHTTPSProxy(t *testing.T) {
// We need to use the non-loopback address for our Teleport cluster, as the
// Go HTTP library will recognize requests to the loopback address and
// refuse to use the HTTP proxy, which will invalidate the test.
addr, err := helpers.GetLocalIP()
addr, err := apihelpers.GetLocalIP()
require.NoError(t, err)

suite := newSuite(t,
Expand Down Expand Up @@ -269,7 +270,7 @@ func TestMultiPortHTTPSProxy(t *testing.T) {
// We need to use the non-loopback address for our Teleport cluster, as the
// Go HTTP library will recognize requests to the loopback address and
// refuse to use the HTTP proxy, which will invalidate the test.
addr, err := helpers.GetLocalIP()
addr, err := apihelpers.GetLocalIP()
require.NoError(t, err)

suite := newSuite(t,
Expand Down Expand Up @@ -1211,7 +1212,7 @@ func TestALPNProxyHTTPProxyNoProxyDial(t *testing.T) {
// We need to use the non-loopback address for our Teleport cluster, as the
// Go HTTP library will recognize requests to the loopback address and
// refuse to use the HTTP proxy, which will invalidate the test.
addr, err := helpers.GetLocalIP()
addr, err := apihelpers.GetLocalIP()
require.NoError(t, err)

instanceCfg := helpers.InstanceConfig{
Expand Down Expand Up @@ -1290,7 +1291,7 @@ func TestALPNProxyHTTPProxyBasicAuthDial(t *testing.T) {
// We need to use the non-loopback address for our Teleport cluster, as the
// Go HTTP library will recognize requests to the loopback address and
// refuse to use the HTTP proxy, which will invalidate the test.
rcAddr, err := helpers.GetLocalIP()
rcAddr, err := apihelpers.GetLocalIP()
require.NoError(t, err)

log.Info("Creating Teleport instance...")
Expand Down
Loading