diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index 813ae59e7c0a8..26157fab19389 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -314,6 +314,8 @@ func (l *LocalProxy) StartHTTPAccessProxy(ctx context.Context) error { // localhost. Set appropriate header to keep this information. if addr, err := utils.ParseAddr(req.Host); err == nil && !addr.IsLocal() { req.Header.Set("X-Forwarded-Host", req.Host) + } else { // ensure that there is no client provided X-Forwarded-Host + req.Header.Del("X-Forwarded-Host") } proxy, err := l.getHTTPReverseProxyForReq(req, defaultProxy) diff --git a/lib/srv/app/aws/endpoints.go b/lib/srv/app/aws/endpoints.go index e1032df3e011b..b6f4122780c25 100644 --- a/lib/srv/app/aws/endpoints.go +++ b/lib/srv/app/aws/endpoints.go @@ -53,6 +53,7 @@ import ( "github.com/gravitational/trace" awsapiutils "github.com/gravitational/teleport/api/utils/aws" + libutils "github.com/gravitational/teleport/lib/utils" awsutils "github.com/gravitational/teleport/lib/utils/aws" ) @@ -61,7 +62,8 @@ import ( // endpoint. func resolveEndpoint(r *http.Request) (*endpoints.ResolvedEndpoint, error) { // Use X-Forwarded-Host header if it is a valid AWS endpoint. - if awsapiutils.IsAWSEndpoint(r.Header.Get("X-Forwarded-Host")) { + forwardedHost, headErr := libutils.GetSingleHeader(r.Header, "X-Forwarded-Host") + if headErr == nil && awsapiutils.IsAWSEndpoint(forwardedHost) { re, err := resolveEndpointByXForwardedHost(r, awsutils.AuthorizationHeader) return re, trace.Wrap(err) } diff --git a/lib/srv/app/azure/handler.go b/lib/srv/app/azure/handler.go index 7811106a73b87..159d0d864d9e4 100644 --- a/lib/srv/app/azure/handler.go +++ b/lib/srv/app/azure/handler.go @@ -167,8 +167,10 @@ func (s *handler) formatForwardResponseError(rw http.ResponseWriter, r *http.Req // prepareForwardRequest prepares a request for forwarding, updating headers and target host. Several checks are made along the way. func (s *handler) prepareForwardRequest(r *http.Request, sessionCtx *common.SessionContext) (*http.Request, error) { - forwardedHost := r.Header.Get("X-Forwarded-Host") - if !azure.IsAzureEndpoint(forwardedHost) { + forwardedHost, err := utils.GetSingleHeader(r.Header, "X-Forwarded-Host") + if err != nil { + return nil, trace.AccessDenied(err.Error()) + } else if !azure.IsAzureEndpoint(forwardedHost) { return nil, trace.AccessDenied("%q is not an Azure endpoint", forwardedHost) } diff --git a/lib/srv/app/gcp/handler.go b/lib/srv/app/gcp/handler.go index 3ca604556ce74..4b687ca8d3917 100644 --- a/lib/srv/app/gcp/handler.go +++ b/lib/srv/app/gcp/handler.go @@ -191,8 +191,10 @@ func (s *handler) formatForwardResponseError(rw http.ResponseWriter, r *http.Req // prepareForwardRequest prepares a request for forwarding, updating headers and target host. Several checks are made along the way. func (s *handler) prepareForwardRequest(r *http.Request, sessionCtx *common.SessionContext) (*http.Request, error) { - forwardedHost := r.Header.Get("X-Forwarded-Host") - if !gcp.IsGCPEndpoint(forwardedHost) { + forwardedHost, err := utils.GetSingleHeader(r.Header, "X-Forwarded-Host") + if err != nil { + return nil, trace.AccessDenied(err.Error()) + } else if !gcp.IsGCPEndpoint(forwardedHost) { return nil, trace.AccessDenied("%q is not a GCP endpoint", forwardedHost) } diff --git a/lib/utils/http.go b/lib/utils/http.go index fae574b37e503..e0717bf170c54 100644 --- a/lib/utils/http.go +++ b/lib/utils/http.go @@ -117,6 +117,19 @@ func GetAnyHeader(header http.Header, keys ...string) string { return "" } +// GetSingleHeader will return the header value for the key if there is exactly one value present. If the header is +// missing or specified multiple times, an error will be returned. +func GetSingleHeader(headers http.Header, key string) (string, error) { + values := headers.Values(key) + if len(values) > 1 { + return "", trace.BadParameter("multiple %q headers", key) + } else if len(values) == 0 { + return "", trace.NotFound("missing %q headers", key) + } else { + return values[0], nil + } +} + // HTTPDoClient is an interface that defines the Do function of http.Client. type HTTPDoClient interface { Do(req *http.Request) (*http.Response, error) diff --git a/lib/utils/http_test.go b/lib/utils/http_test.go index df9b0b9cb1c0b..bf5d57a328468 100644 --- a/lib/utils/http_test.go +++ b/lib/utils/http_test.go @@ -20,6 +20,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/require" @@ -54,6 +55,52 @@ func TestGetAnyHeader(t *testing.T) { require.Equal(t, "b1", GetAnyHeader(header, "bbb", "aaa")) } +func TestGetSingleHeader(t *testing.T) { + t.Run("NoValue", func(t *testing.T) { + t.Parallel() + headers := make(http.Header) + + result, err := GetSingleHeader(headers, "key") + require.Empty(t, result) + require.Error(t, err) + }) + t.Run("SingleValue", func(t *testing.T) { + t.Parallel() + headers := make(http.Header) + key := "key" + value := "value" + headers.Set(key, value) + + result, err := GetSingleHeader(headers, key) + require.NoError(t, err) + require.Equal(t, value, result) + }) + t.Run("DuplicateValue", func(t *testing.T) { + t.Parallel() + headers := make(http.Header) + key := "key" + value := "value1" + headers.Add(key, value) + headers.Add(key, "value2") + + result, err := GetSingleHeader(headers, key) + require.Empty(t, result) + require.Error(t, err) + }) + t.Run("DuplicateCaseValue", func(t *testing.T) { + t.Parallel() + headers := make(http.Header) + key := "key" + value := "value1" + headers.Add(key, value) + headers.Add(strings.ToUpper(key), "value2") + + result, err := GetSingleHeader(headers, key) + require.Empty(t, result) + require.Error(t, err) + }) +} + func TestChainHTTPMiddlewares(t *testing.T) { baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("baseHandler"))