Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/srv/alpnproxy/local_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ func (l *LocalProxy) StartHTTPAccessProxy(ctx context.Context) error {
// localhost. Set appropriate header to keep this information.
if addr, err := utils.ParseAddr(req.Host); err == nil && !addr.IsLocal() {
req.Header.Set("X-Forwarded-Host", req.Host)
} else { // ensure that there is no client provided X-Forwarded-Host
req.Header.Del("X-Forwarded-Host")
}

proxy, err := l.getHTTPReverseProxyForReq(req, defaultProxy)
Expand Down
4 changes: 3 additions & 1 deletion lib/srv/app/aws/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import (
"github.com/gravitational/trace"

awsapiutils "github.com/gravitational/teleport/api/utils/aws"
libutils "github.com/gravitational/teleport/lib/utils"
awsutils "github.com/gravitational/teleport/lib/utils/aws"
)

Expand All @@ -61,7 +62,8 @@ import (
// endpoint.
func resolveEndpoint(r *http.Request) (*endpoints.ResolvedEndpoint, error) {
// Use X-Forwarded-Host header if it is a valid AWS endpoint.
if awsapiutils.IsAWSEndpoint(r.Header.Get("X-Forwarded-Host")) {
forwardedHost, headErr := libutils.GetSingleHeader(r.Header, "X-Forwarded-Host")
if headErr == nil && awsapiutils.IsAWSEndpoint(forwardedHost) {
re, err := resolveEndpointByXForwardedHost(r, awsutils.AuthorizationHeader)
return re, trace.Wrap(err)
}
Expand Down
6 changes: 4 additions & 2 deletions lib/srv/app/azure/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,10 @@ func (s *handler) formatForwardResponseError(rw http.ResponseWriter, r *http.Req

// prepareForwardRequest prepares a request for forwarding, updating headers and target host. Several checks are made along the way.
func (s *handler) prepareForwardRequest(r *http.Request, sessionCtx *common.SessionContext) (*http.Request, error) {
forwardedHost := r.Header.Get("X-Forwarded-Host")
if !azure.IsAzureEndpoint(forwardedHost) {
forwardedHost, err := utils.GetSingleHeader(r.Header, "X-Forwarded-Host")
if err != nil {
return nil, trace.AccessDenied(err.Error())
} else if !azure.IsAzureEndpoint(forwardedHost) {
return nil, trace.AccessDenied("%q is not an Azure endpoint", forwardedHost)
}

Expand Down
6 changes: 4 additions & 2 deletions lib/srv/app/gcp/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,10 @@ func (s *handler) formatForwardResponseError(rw http.ResponseWriter, r *http.Req

// prepareForwardRequest prepares a request for forwarding, updating headers and target host. Several checks are made along the way.
func (s *handler) prepareForwardRequest(r *http.Request, sessionCtx *common.SessionContext) (*http.Request, error) {
forwardedHost := r.Header.Get("X-Forwarded-Host")
if !gcp.IsGCPEndpoint(forwardedHost) {
forwardedHost, err := utils.GetSingleHeader(r.Header, "X-Forwarded-Host")
if err != nil {
return nil, trace.AccessDenied(err.Error())
} else if !gcp.IsGCPEndpoint(forwardedHost) {
return nil, trace.AccessDenied("%q is not a GCP endpoint", forwardedHost)
}

Expand Down
13 changes: 13 additions & 0 deletions lib/utils/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,19 @@ func GetAnyHeader(header http.Header, keys ...string) string {
return ""
}

// GetSingleHeader will return the header value for the key if there is exactly one value present. If the header is
// missing or specified multiple times, an error will be returned.
func GetSingleHeader(headers http.Header, key string) (string, error) {
values := headers.Values(key)
if len(values) > 1 {
return "", trace.BadParameter("multiple %q headers", key)
} else if len(values) == 0 {
return "", trace.NotFound("missing %q headers", key)
} else {
return values[0], nil
}
}

// HTTPDoClient is an interface that defines the Do function of http.Client.
type HTTPDoClient interface {
Do(req *http.Request) (*http.Response, error)
Expand Down
47 changes: 47 additions & 0 deletions lib/utils/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -54,6 +55,52 @@ func TestGetAnyHeader(t *testing.T) {
require.Equal(t, "b1", GetAnyHeader(header, "bbb", "aaa"))
}

func TestGetSingleHeader(t *testing.T) {
t.Run("NoValue", func(t *testing.T) {
t.Parallel()
headers := make(http.Header)

result, err := GetSingleHeader(headers, "key")
require.Empty(t, result)
require.Error(t, err)
})
t.Run("SingleValue", func(t *testing.T) {
t.Parallel()
headers := make(http.Header)
key := "key"
value := "value"
headers.Set(key, value)

result, err := GetSingleHeader(headers, key)
require.NoError(t, err)
require.Equal(t, value, result)
})
t.Run("DuplicateValue", func(t *testing.T) {
t.Parallel()
headers := make(http.Header)
key := "key"
value := "value1"
headers.Add(key, value)
headers.Add(key, "value2")

result, err := GetSingleHeader(headers, key)
require.Empty(t, result)
require.Error(t, err)
})
t.Run("DuplicateCaseValue", func(t *testing.T) {
t.Parallel()
headers := make(http.Header)
key := "key"
value := "value1"
headers.Add(key, value)
headers.Add(strings.ToUpper(key), "value2")

result, err := GetSingleHeader(headers, key)
require.Empty(t, result)
require.Error(t, err)
})
}

func TestChainHTTPMiddlewares(t *testing.T) {
baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("baseHandler"))
Expand Down