diff --git a/lib/web/device_trust.go b/lib/web/device_trust.go index dce54528bdcc7..e3f66eb471875 100644 --- a/lib/web/device_trust.go +++ b/lib/web/device_trust.go @@ -82,30 +82,38 @@ func (h *Handler) deviceWebConfirm(w http.ResponseWriter, r *http.Request, _ htt // Always redirect back to the dashboard, regardless of outcome. app.SetRedirectPageHeaders(w.Header(), "" /* nonce */) - redirectTo, err := h.getRedirectPath(unsafeRedirectURI) + redirectTo, err := h.getRedirectURL(r.Host, unsafeRedirectURI) if err != nil { h.log. WithError(err). WithField("redirect_uri", unsafeRedirectURI). Debug("Unable to parse redirectURI") + http.Error(w, http.StatusText(trace.ErrorToCode(err)), trace.ErrorToCode(err)) + return nil, nil } http.Redirect(w, r, redirectTo, http.StatusSeeOther) return nil, nil } -// getRedirectPath tries to parse the given redirectURI. It will always return a redirect url -// even if the parse fails (in case of failture, the returned string is "/web") -func (h *Handler) getRedirectPath(redirectURI string) (string, error) { - const basePath = "/web" - - if redirectURI == "" { +// getRedirectPath tries to parse the given unsafeRedirectURI. +// It returns a full URL if the unsafeRedirectURI points to SAML IdP SSO endpoint. +// In any other case, as long as the redirect URL is parsable, it returns +// a path ensuring its prefixed with "/web". +func (h *Handler) getRedirectURL(host, unsafeRedirectURI string) (string, error) { + const ( + basePath = "/web" + samlSPInitiatedSSOPath = "/enterprise/saml-idp/sso" + samlIDPInitiatedSSOPath = "/enterprise/saml-idp/login" + ) + + if unsafeRedirectURI == "" { return basePath, nil } - parsedURL, err := url.Parse(redirectURI) + parsedURL, err := url.Parse(unsafeRedirectURI) if err != nil { - return basePath, trace.Wrap(err) + return basePath, trace.BadParameter("invalid redirect URL") } cleanPath := path.Clean(parsedURL.Path) @@ -116,6 +124,25 @@ func (h *Handler) getRedirectPath(redirectURI string) (string, error) { cleanPath = "/" + cleanPath } + // IDP initiated SSO path format: "/enterprise/saml-idp/login/" + isIdpInitiatedSSOPath := strings.HasPrefix(cleanPath, samlIDPInitiatedSSOPath) && len(strings.Split(cleanPath, "/")) == 5 + if cleanPath == samlSPInitiatedSSOPath || isIdpInitiatedSSOPath { + if parsedURL.Host != host { + return "", trace.BadParameter("host mismatch") + } + path := samlSPInitiatedSSOPath + if isIdpInitiatedSSOPath { + path = cleanPath + } + ensuredURL := &url.URL{ + Scheme: "https", + Host: host, + Path: path, + RawQuery: parsedURL.RawQuery, + } + return ensuredURL.String(), nil + } + // Prepend "/web" only if it's not already present if !strings.HasPrefix(cleanPath, basePath) { return path.Join(basePath, cleanPath), nil diff --git a/lib/web/device_trust_test.go b/lib/web/device_trust_test.go index dcce9415fdef8..7a6f21a9887f6 100644 --- a/lib/web/device_trust_test.go +++ b/lib/web/device_trust_test.go @@ -18,6 +18,7 @@ package web import ( "context" + "fmt" "io" "net/http" "net/url" @@ -34,70 +35,102 @@ import ( ) func TestHandler_DeviceWebConfirm(t *testing.T) { - t.Parallel() + ctx := context.Background() + fakeDevices := &fakeDevicesClient{} + wPack := newWebPack( + t, + 1, /* numProxies */ + withDevicesClientOverride(fakeDevices), + ) + proxy := wPack.proxies[0] + + aPack := proxy.authPack(t, "llama", nil /* roles */) + webClient := aPack.clt tests := []struct { name string redirectURI string expectedRedirectTo string + redirectsToFullURL bool + statusCode int }{ { name: "no redirect_uri", redirectURI: "", expectedRedirectTo: "/web", + statusCode: http.StatusSeeOther, }, { name: "with redirect_uri", redirectURI: "https://example.com/web/custom/path", expectedRedirectTo: "/web/custom/path", + statusCode: http.StatusSeeOther, }, { name: "with app access redirect_uri", redirectURI: "https://example.com/web/launch/myapp.example.com", expectedRedirectTo: "/web/launch/myapp.example.com", + statusCode: http.StatusSeeOther, }, { - name: "with invalid redirect_uri", - redirectURI: "://invalid", - expectedRedirectTo: "/web", + name: "with invalid redirect_uri", + redirectURI: "://invalid", + statusCode: http.StatusBadRequest, }, { name: "with external redirect_uri", redirectURI: "https://example.com/path", expectedRedirectTo: "/web/path", + statusCode: http.StatusSeeOther, }, { name: "with empty path redirect_uri", redirectURI: "https://example.com", expectedRedirectTo: "/web", + statusCode: http.StatusSeeOther, }, { name: "with relative path", redirectURI: "/custom/path", expectedRedirectTo: "/web/custom/path", + statusCode: http.StatusSeeOther, }, { name: "with web prefix already", redirectURI: "/web/existing/path", expectedRedirectTo: "/web/existing/path", + statusCode: http.StatusSeeOther, + }, + { + name: "saml idp service provider initiated sso endpoint", + redirectURI: fmt.Sprintf("https://%s/enterprise/saml-idp/sso?SAMLRequest=example-authn-request", proxy.webURL.Host), + expectedRedirectTo: fmt.Sprintf("https://%s/enterprise/saml-idp/sso?SAMLRequest=example-authn-request", proxy.webURL.Host), + redirectsToFullURL: true, + statusCode: http.StatusSeeOther, + }, + { + name: "saml idp identity provider initiated sso endpoint", + redirectURI: fmt.Sprintf("https://%s/enterprise/saml-idp/login/example-app", proxy.webURL.Host), + expectedRedirectTo: fmt.Sprintf("https://%s/enterprise/saml-idp/login/example-app", proxy.webURL.Host), + redirectsToFullURL: true, + statusCode: http.StatusSeeOther, + }, + { + name: "saml idp sso endpoint with redirect_uri pointing to a different host", + redirectURI: "https://example.com/enterprise/saml-idp/sso?SAMLRequest=example-authn-request", + redirectsToFullURL: true, + statusCode: http.StatusBadRequest, + }, + { + name: "saml idp sso endpoint with redirect_uri pointing to a malformed URL", + redirectURI: "https://%s.//example.com/enterprise/saml-idp/sso?SAMLRequest=example-authn-request", + redirectsToFullURL: true, + statusCode: http.StatusBadRequest, }, } for _, test := range tests { - fakeDevices := &fakeDevicesClient{} - wPack := newWebPack( - t, - 1, /* numProxies */ - withDevicesClientOverride(fakeDevices), - ) - proxy := wPack.proxies[0] - aPack := proxy.authPack(t, "llama", nil /* roles */) - webClient := aPack.clt - t.Run(test.name, func(t *testing.T) { - t.Parallel() - ctx := context.Background() - query := make(url.Values) query.Set("id", "my-token-id") query.Set("token", "my-token-token") @@ -109,11 +142,12 @@ func TestHandler_DeviceWebConfirm(t *testing.T) { var actualRedirectTo string httpClient := webClient.HTTPClient() httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { - if !redirected { - redirected = true - actualRedirectTo = req.URL.Path + redirected = true + actualRedirectTo = req.URL.Path + if test.redirectsToFullURL { + actualRedirectTo = req.URL.String() } - return nil + return http.ErrUseLastResponse } req, err := http.NewRequestWithContext(ctx, "GET", webClient.Endpoint("webapi", "devices", "webconfirm"), nil) @@ -125,9 +159,11 @@ func TestHandler_DeviceWebConfirm(t *testing.T) { io.Copy(io.Discard, resp.Body) resp.Body.Close() - assert.True(t, redirected, "GET /webapi/devices/webconfirm didn't cause a redirect") - assert.Equal(t, http.StatusOK, resp.StatusCode, "GET /webapi/devices/webconfirm code mismatch") - assert.Equal(t, test.expectedRedirectTo, actualRedirectTo, "Redirect destination mismatch") + assert.Equal(t, test.statusCode, resp.StatusCode, "GET /webapi/devices/webconfirm status code mismatch") + if test.expectedRedirectTo != "" { + assert.True(t, redirected, "GET /webapi/devices/webconfirm didn't cause a redirect") + assert.Equal(t, test.expectedRedirectTo, actualRedirectTo, "Redirect destination mismatch") + } got := fakeDevices.resetConfirmRequests() want := []*devicepb.ConfirmDeviceWebAuthenticationRequest{ diff --git a/web/packages/shared/redirects/processRedirectUri.test.ts b/web/packages/shared/redirects/processRedirectUri.test.ts index 18c377388030c..b0ebce9c92dcb 100644 --- a/web/packages/shared/redirects/processRedirectUri.test.ts +++ b/web/packages/shared/redirects/processRedirectUri.test.ts @@ -71,6 +71,23 @@ describe('processRedirectURI', () => { input: '/web/existing/path', expected: '/web/existing/path', }, + { + name: 'saml idp service provider initiated SSO URL', + input: + 'https://example.com/enterprise/saml-idp/sso?SAMLRequest=example-authn-request', + expected: '/enterprise/saml-idp/sso?SAMLRequest=example-authn-request', + }, + { + name: 'malformed URL', + input: + 'https://example.//attacker.com/enterprise/saml-idp/sso?SAMLRequest=example-authn-request', + expected: '/web//attacker.com/enterprise/saml-idp/sso', + }, + { + name: 'saml idp identity provider initiated SSO URL', + input: 'https://example.com/enterprise/saml-idp/login/example-app', + expected: '/enterprise/saml-idp/login/example-app', + }, ]; test.each(tests)('$name', ({ input, expected }) => { diff --git a/web/packages/shared/redirects/processRedirectUri.ts b/web/packages/shared/redirects/processRedirectUri.ts index c0abd0e9ecbd3..e602564b2c00e 100644 --- a/web/packages/shared/redirects/processRedirectUri.ts +++ b/web/packages/shared/redirects/processRedirectUri.ts @@ -17,6 +17,8 @@ */ const BASE_PATH = '/web'; +const SAML_SP_INITIATED_SSO_PATH = '/enterprise/saml-idp/sso'; +const SAML_IDP_INITIATED_SSO_PATH = '/enterprise/saml-idp/login'; /** * Processes a redirect URI to ensure it's valid and follows the expected format. @@ -33,12 +35,11 @@ const BASE_PATH = '/web'; * @returns A processed URI string that always starts with the basePath, unless it's an invalid input. * * @example - * processRedirectURI('/web', null) // returns '/web' - * processRedirectURI('/web', 'https://example.com/path') // returns '/web/path' - * processRedirectURI('/web', '/custom/path') // returns '/web/custom/path' - * processRedirectURI('/web', '/web/existing/path') // returns '/web/existing/path' - * processRedirectURI('/web', 'invalid://url') // returns '/web' - * processRedirectURI('/app', 'https://example.com/path') // returns '/app/path' + * processRedirectURI(null) // returns '/web' + * processRedirectURI('https://example.com/path') // returns '/web/path' + * processRedirectURI('/custom/path') // returns '/web/custom/path' + * processRedirectURI('/web/existing/path') // returns '/web/existing/path' + * processRedirectURI('invalid://url') // returns '/web' */ export function processRedirectUri(redirectUri: string | null): string { // should be equal to cfg.routes.root @@ -53,6 +54,13 @@ export function processRedirectUri(redirectUri: string | null): string { return path; } + if ( + path.startsWith(SAML_IDP_INITIATED_SSO_PATH) || + path.startsWith(SAML_SP_INITIATED_SSO_PATH) + ) { + return path + url.search; + } + if (path === '/') { return BASE_PATH; }