diff --git a/lib/httplib/httplib.go b/lib/httplib/httplib.go index 0e7447af7a2c6..8c2b711511b7a 100644 --- a/lib/httplib/httplib.go +++ b/lib/httplib/httplib.go @@ -27,6 +27,7 @@ import ( "net/url" "regexp" "strconv" + "strings" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" @@ -249,14 +250,24 @@ func RewritePaths(next http.Handler, rewrites ...RewritePair) http.Handler { }) } -// SafeRedirect performs a relative redirect to the URI part of the provided redirect URL -func SafeRedirect(w http.ResponseWriter, r *http.Request, redirectURL string) error { +// OriginLocalRedirectURI will take an incoming URL including optionally the host and scheme and return the URI +// associated with the URL. Additionally, it will ensure that the URI does not include any techniques potentially +// used to redirect to a different origin. +func OriginLocalRedirectURI(redirectURL string) (string, error) { parsedURL, err := url.Parse(redirectURL) if err != nil { - return trace.Wrap(err) + return "", trace.Wrap(err) + } else if parsedURL.IsAbs() && (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") { + return "", trace.BadParameter("Invalid scheme: %s", parsedURL.Scheme) } - http.Redirect(w, r, parsedURL.RequestURI(), http.StatusFound) - return nil + + resultURI := parsedURL.RequestURI() + if strings.HasPrefix(resultURI, "//") { + return "", trace.BadParameter("Invalid double slash redirect") + } else if strings.Contains(resultURI, "@") { + return "", trace.BadParameter("Basic Auth not allowed in redirect") + } + return resultURI, nil } // ResponseStatusRecorder is an http.ResponseWriter that records the response status code. diff --git a/lib/httplib/httplib_test.go b/lib/httplib/httplib_test.go index 93fd3a1aaa95a..e6404e1980f24 100644 --- a/lib/httplib/httplib_test.go +++ b/lib/httplib/httplib_test.go @@ -407,3 +407,71 @@ func TestSetRedirectPageContentSecurityPolicy(t *testing.T) { require.Contains(t, actualCsp, expectedCspSubString) } } + +func TestOriginLocalRedirectURI(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + input string + expected string + errCheck require.ErrorAssertionFunc + }{ + { + name: "empty", + input: "", + expected: "/", + errCheck: require.NoError, + }, + { + name: "simple path", + input: "/foo", + expected: "/foo", + errCheck: require.NoError, + }, + { + name: "host only", + input: "https://localhost", + expected: "/", + errCheck: require.NoError, + }, + { + name: "host and simple path", + input: "https://localhost/bar", + expected: "/bar", + errCheck: require.NoError, + }, + { + name: "double slash redirect with host", + input: "https://localhost//goteleport.com/", + expected: "", + errCheck: require.Error, + }, + { + name: "basic auth redirect with host", + input: "https://localhost/@goteleport.com/", + expected: "", + errCheck: require.Error, + }, + { + name: "ftp scheme", + input: "ftp://localhost", + expected: "", + errCheck: require.Error, + }, + { + name: "invalid url", + input: "https://foo com", + expected: "", + errCheck: require.Error, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := OriginLocalRedirectURI(tc.input) + require.Equal(t, tc.expected, result) + tc.errCheck(t, err) + }) + } +} diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 9b89efbd711de..68a21b4c27ed4 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -4253,12 +4253,11 @@ func SSOSetWebSessionAndRedirectURL(w http.ResponseWriter, r *http.Request, resp return trace.Wrap(err) } - parsedURL, err := url.Parse(response.ClientRedirectURL) + parsedRedirectURL, err := httplib.OriginLocalRedirectURI(response.ClientRedirectURL) if err != nil { return trace.Wrap(err) } - - response.ClientRedirectURL = parsedURL.RequestURI() + response.ClientRedirectURL = parsedRedirectURL return nil }