Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions lib/httplib/httplib.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"net/url"
"regexp"
"strconv"
"strings"

"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"
Expand Down Expand Up @@ -249,14 +250,24 @@ func RewritePaths(next http.Handler, rewrites ...RewritePair) http.Handler {
})
}

// SafeRedirect performs a relative redirect to the URI part of the provided redirect URL
func SafeRedirect(w http.ResponseWriter, r *http.Request, redirectURL string) error {
// OriginLocalRedirectURI will take an incoming URL including optionally the host and scheme and return the URI
// associated with the URL. Additionally, it will ensure that the URI does not include any techniques potentially
// used to redirect to a different origin.
func OriginLocalRedirectURI(redirectURL string) (string, error) {
parsedURL, err := url.Parse(redirectURL)
if err != nil {
return trace.Wrap(err)
return "", trace.Wrap(err)
} else if parsedURL.IsAbs() && (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") {
return "", trace.BadParameter("Invalid scheme: %s", parsedURL.Scheme)
}
http.Redirect(w, r, parsedURL.RequestURI(), http.StatusFound)
return nil

resultURI := parsedURL.RequestURI()
if strings.HasPrefix(resultURI, "//") {
return "", trace.BadParameter("Invalid double slash redirect")
} else if strings.Contains(resultURI, "@") {
return "", trace.BadParameter("Basic Auth not allowed in redirect")
}
return resultURI, nil
}

// ResponseStatusRecorder is an http.ResponseWriter that records the response status code.
Expand Down
68 changes: 68 additions & 0 deletions lib/httplib/httplib_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,71 @@ func TestSetRedirectPageContentSecurityPolicy(t *testing.T) {
require.Contains(t, actualCsp, expectedCspSubString)
}
}

func TestOriginLocalRedirectURI(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
input string
expected string
errCheck require.ErrorAssertionFunc
}{
{
name: "empty",
input: "",
expected: "/",
errCheck: require.NoError,
},
{
name: "simple path",
input: "/foo",
expected: "/foo",
errCheck: require.NoError,
},
{
name: "host only",
input: "https://localhost",
expected: "/",
errCheck: require.NoError,
},
{
name: "host and simple path",
input: "https://localhost/bar",
expected: "/bar",
errCheck: require.NoError,
},
{
name: "double slash redirect with host",
input: "https://localhost//goteleport.com/",
expected: "",
errCheck: require.Error,
},
{
name: "basic auth redirect with host",
input: "https://localhost/@goteleport.com/",
expected: "",
errCheck: require.Error,
},
{
name: "ftp scheme",
input: "ftp://localhost",
expected: "",
errCheck: require.Error,
},
{
name: "invalid url",
input: "https://foo com",
expected: "",
errCheck: require.Error,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result, err := OriginLocalRedirectURI(tc.input)
require.Equal(t, tc.expected, result)
tc.errCheck(t, err)
})
}
}
5 changes: 2 additions & 3 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4027,12 +4027,11 @@ func SSOSetWebSessionAndRedirectURL(w http.ResponseWriter, r *http.Request, resp
return trace.Wrap(err)
}

parsedURL, err := url.Parse(response.ClientRedirectURL)
parsedRedirectURL, err := httplib.OriginLocalRedirectURI(response.ClientRedirectURL)
if err != nil {
return trace.Wrap(err)
}

response.ClientRedirectURL = parsedURL.RequestURI()
response.ClientRedirectURL = parsedRedirectURL

return nil
}
Expand Down