Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions lib/web/device_trust.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,30 +82,38 @@ func (h *Handler) deviceWebConfirm(w http.ResponseWriter, r *http.Request, _ htt
// Always redirect back to the dashboard, regardless of outcome.
app.SetRedirectPageHeaders(w.Header(), "" /* nonce */)

redirectTo, err := h.getRedirectPath(unsafeRedirectURI)
redirectTo, err := h.getRedirectURL(r.Host, unsafeRedirectURI)
if err != nil {
h.log.
WithError(err).
WithField("redirect_uri", unsafeRedirectURI).
Debug("Unable to parse redirectURI")
http.Error(w, http.StatusText(trace.ErrorToCode(err)), trace.ErrorToCode(err))
return nil, nil
}
http.Redirect(w, r, redirectTo, http.StatusSeeOther)

return nil, nil
}

// getRedirectPath tries to parse the given redirectURI. It will always return a redirect url
// even if the parse fails (in case of failture, the returned string is "/web")
func (h *Handler) getRedirectPath(redirectURI string) (string, error) {
const basePath = "/web"

if redirectURI == "" {
// getRedirectPath tries to parse the given unsafeRedirectURI.
// It returns a full URL if the unsafeRedirectURI points to SAML IdP SSO endpoint.
// In any other case, as long as the redirect URL is parsable, it returns
// a path ensuring its prefixed with "/web".
func (h *Handler) getRedirectURL(host, unsafeRedirectURI string) (string, error) {
const (
basePath = "/web"
samlSPInitiatedSSOPath = "/enterprise/saml-idp/sso"
samlIDPInitiatedSSOPath = "/enterprise/saml-idp/login"
)

if unsafeRedirectURI == "" {
return basePath, nil
}

parsedURL, err := url.Parse(redirectURI)
parsedURL, err := url.Parse(unsafeRedirectURI)
if err != nil {
return basePath, trace.Wrap(err)
return basePath, trace.BadParameter("invalid redirect URL")
}

cleanPath := path.Clean(parsedURL.Path)
Expand All @@ -116,6 +124,25 @@ func (h *Handler) getRedirectPath(redirectURI string) (string, error) {
cleanPath = "/" + cleanPath
}

// IDP initiated SSO path format: "/enterprise/saml-idp/login/<service provider name>"
isIdpInitiatedSSOPath := strings.HasPrefix(cleanPath, samlIDPInitiatedSSOPath) && len(strings.Split(cleanPath, "/")) == 5
if cleanPath == samlSPInitiatedSSOPath || isIdpInitiatedSSOPath {
if parsedURL.Host != host {
return "", trace.BadParameter("host mismatch")
}
path := samlSPInitiatedSSOPath
if isIdpInitiatedSSOPath {
path = cleanPath
}
ensuredURL := &url.URL{
Scheme: "https",
Host: host,
Path: path,
RawQuery: parsedURL.RawQuery,
}
return ensuredURL.String(), nil
}

// Prepend "/web" only if it's not already present
if !strings.HasPrefix(cleanPath, basePath) {
return path.Join(basePath, cleanPath), nil
Expand Down
84 changes: 60 additions & 24 deletions lib/web/device_trust_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package web

import (
"context"
"fmt"
"io"
"net/http"
"net/url"
Expand All @@ -34,70 +35,102 @@ import (
)

func TestHandler_DeviceWebConfirm(t *testing.T) {
t.Parallel()
ctx := context.Background()
fakeDevices := &fakeDevicesClient{}
wPack := newWebPack(
t,
1, /* numProxies */
withDevicesClientOverride(fakeDevices),
)
proxy := wPack.proxies[0]

aPack := proxy.authPack(t, "llama", nil /* roles */)
webClient := aPack.clt

tests := []struct {
name string
redirectURI string
expectedRedirectTo string
redirectsToFullURL bool
statusCode int
}{
{
name: "no redirect_uri",
redirectURI: "",
expectedRedirectTo: "/web",
statusCode: http.StatusSeeOther,
},
{
name: "with redirect_uri",
redirectURI: "https://example.com/web/custom/path",
expectedRedirectTo: "/web/custom/path",
statusCode: http.StatusSeeOther,
},
{
name: "with app access redirect_uri",
redirectURI: "https://example.com/web/launch/myapp.example.com",
expectedRedirectTo: "/web/launch/myapp.example.com",
statusCode: http.StatusSeeOther,
},
{
name: "with invalid redirect_uri",
redirectURI: "://invalid",
expectedRedirectTo: "/web",
name: "with invalid redirect_uri",
redirectURI: "://invalid",
statusCode: http.StatusBadRequest,
},
{
name: "with external redirect_uri",
redirectURI: "https://example.com/path",
expectedRedirectTo: "/web/path",
statusCode: http.StatusSeeOther,
},
{
name: "with empty path redirect_uri",
redirectURI: "https://example.com",
expectedRedirectTo: "/web",
statusCode: http.StatusSeeOther,
},
{
name: "with relative path",
redirectURI: "/custom/path",
expectedRedirectTo: "/web/custom/path",
statusCode: http.StatusSeeOther,
},
{
name: "with web prefix already",
redirectURI: "/web/existing/path",
expectedRedirectTo: "/web/existing/path",
statusCode: http.StatusSeeOther,
},
{
name: "saml idp service provider initiated sso endpoint",
redirectURI: fmt.Sprintf("https://%s/enterprise/saml-idp/sso?SAMLRequest=example-authn-request", proxy.webURL.Host),
expectedRedirectTo: fmt.Sprintf("https://%s/enterprise/saml-idp/sso?SAMLRequest=example-authn-request", proxy.webURL.Host),
redirectsToFullURL: true,
statusCode: http.StatusSeeOther,
},
{
name: "saml idp identity provider initiated sso endpoint",
redirectURI: fmt.Sprintf("https://%s/enterprise/saml-idp/login/example-app", proxy.webURL.Host),
expectedRedirectTo: fmt.Sprintf("https://%s/enterprise/saml-idp/login/example-app", proxy.webURL.Host),
redirectsToFullURL: true,
statusCode: http.StatusSeeOther,
},
{
name: "saml idp sso endpoint with redirect_uri pointing to a different host",
redirectURI: "https://example.com/enterprise/saml-idp/sso?SAMLRequest=example-authn-request",
redirectsToFullURL: true,
statusCode: http.StatusBadRequest,
},
{
name: "saml idp sso endpoint with redirect_uri pointing to a malformed URL",
redirectURI: "https://%s.//example.com/enterprise/saml-idp/sso?SAMLRequest=example-authn-request",
redirectsToFullURL: true,
statusCode: http.StatusBadRequest,
},
}

for _, test := range tests {
fakeDevices := &fakeDevicesClient{}
wPack := newWebPack(
t,
1, /* numProxies */
withDevicesClientOverride(fakeDevices),
)
proxy := wPack.proxies[0]
aPack := proxy.authPack(t, "llama", nil /* roles */)
webClient := aPack.clt

t.Run(test.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()

query := make(url.Values)
query.Set("id", "my-token-id")
query.Set("token", "my-token-token")
Expand All @@ -109,11 +142,12 @@ func TestHandler_DeviceWebConfirm(t *testing.T) {
var actualRedirectTo string
httpClient := webClient.HTTPClient()
httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if !redirected {
redirected = true
actualRedirectTo = req.URL.Path
redirected = true
actualRedirectTo = req.URL.Path
if test.redirectsToFullURL {
actualRedirectTo = req.URL.String()
}
return nil
return http.ErrUseLastResponse
}

req, err := http.NewRequestWithContext(ctx, "GET", webClient.Endpoint("webapi", "devices", "webconfirm"), nil)
Expand All @@ -125,9 +159,11 @@ func TestHandler_DeviceWebConfirm(t *testing.T) {
io.Copy(io.Discard, resp.Body)
resp.Body.Close()

assert.True(t, redirected, "GET /webapi/devices/webconfirm didn't cause a redirect")
assert.Equal(t, http.StatusOK, resp.StatusCode, "GET /webapi/devices/webconfirm code mismatch")
assert.Equal(t, test.expectedRedirectTo, actualRedirectTo, "Redirect destination mismatch")
assert.Equal(t, test.statusCode, resp.StatusCode, "GET /webapi/devices/webconfirm status code mismatch")
if test.expectedRedirectTo != "" {
assert.True(t, redirected, "GET /webapi/devices/webconfirm didn't cause a redirect")
assert.Equal(t, test.expectedRedirectTo, actualRedirectTo, "Redirect destination mismatch")
}

got := fakeDevices.resetConfirmRequests()
want := []*devicepb.ConfirmDeviceWebAuthenticationRequest{
Expand Down
17 changes: 17 additions & 0 deletions web/packages/shared/redirects/processRedirectUri.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,23 @@ describe('processRedirectURI', () => {
input: '/web/existing/path',
expected: '/web/existing/path',
},
{
name: 'saml idp service provider initiated SSO URL',
input:
'https://example.com/enterprise/saml-idp/sso?SAMLRequest=example-authn-request',
expected: '/enterprise/saml-idp/sso?SAMLRequest=example-authn-request',
},
{
name: 'malformed URL',
input:
'https://example.//attacker.com/enterprise/saml-idp/sso?SAMLRequest=example-authn-request',
expected: '/web//attacker.com/enterprise/saml-idp/sso',
},
{
name: 'saml idp identity provider initiated SSO URL',
input: 'https://example.com/enterprise/saml-idp/login/example-app',
expected: '/enterprise/saml-idp/login/example-app',
},
];

test.each(tests)('$name', ({ input, expected }) => {
Expand Down
20 changes: 14 additions & 6 deletions web/packages/shared/redirects/processRedirectUri.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/

const BASE_PATH = '/web';
const SAML_SP_INITIATED_SSO_PATH = '/enterprise/saml-idp/sso';
const SAML_IDP_INITIATED_SSO_PATH = '/enterprise/saml-idp/login';

/**
* Processes a redirect URI to ensure it's valid and follows the expected format.
Expand All @@ -33,12 +35,11 @@ const BASE_PATH = '/web';
* @returns A processed URI string that always starts with the basePath, unless it's an invalid input.
*
* @example
* processRedirectURI('/web', null) // returns '/web'
* processRedirectURI('/web', 'https://example.com/path') // returns '/web/path'
* processRedirectURI('/web', '/custom/path') // returns '/web/custom/path'
* processRedirectURI('/web', '/web/existing/path') // returns '/web/existing/path'
* processRedirectURI('/web', 'invalid://url') // returns '/web'
* processRedirectURI('/app', 'https://example.com/path') // returns '/app/path'
* processRedirectURI(null) // returns '/web'
* processRedirectURI('https://example.com/path') // returns '/web/path'
* processRedirectURI('/custom/path') // returns '/web/custom/path'
* processRedirectURI('/web/existing/path') // returns '/web/existing/path'
* processRedirectURI('invalid://url') // returns '/web'
*/
export function processRedirectUri(redirectUri: string | null): string {
// should be equal to cfg.routes.root
Expand All @@ -53,6 +54,13 @@ export function processRedirectUri(redirectUri: string | null): string {
return path;
}

if (
path.startsWith(SAML_IDP_INITIATED_SSO_PATH) ||
path.startsWith(SAML_SP_INITIATED_SSO_PATH)
) {
return path + url.search;
}

if (path === '/') {
return BASE_PATH;
}
Expand Down
Loading