Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package server

import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand All @@ -23,6 +25,11 @@ import (
"github.com/dexidp/dex/storage"
)

const (
CodeChallengeMethodPlain = "plain"
CodeChallengeMethodS256 = "S256"
)

// newHealthChecker returns the healthz handler. The handler runs until the
// provided context is canceled.
func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
Expand Down Expand Up @@ -633,6 +640,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
Expiry: s.now().Add(time.Minute * 30),
RedirectURI: authReq.RedirectURI,
ConnectorData: authReq.ConnectorData,
CodeChallenge: authReq.CodeChallenge,
}
if err := s.storage.CreateAuthCode(code); err != nil {
s.logger.Errorf("Failed to create auth code: %v", err)
Expand Down Expand Up @@ -745,12 +753,17 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
}
return
}
if client.Secret != clientSecret {

grantType := r.PostFormValue("grant_type")
codeVerifier := r.PostFormValue("code_verifier")

if grantType == grantTypeAuthorizationCode && codeVerifier != "" {
// RFC 7636 (PKCE) if code_verifier is received, use PKCE and not the client_secret
} else if client.Secret != clientSecret {
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
return
}

grantType := r.PostFormValue("grant_type")
switch grantType {
case grantTypeAuthorizationCode:
s.handleAuthCode(w, r, client)
Expand All @@ -763,6 +776,19 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
}
}

func (s *Server) calculateCodeChallenge(codeVerifier, codeChallengeMethod string) (string, error) {
switch codeChallengeMethod {
case CodeChallengeMethodPlain:
return codeVerifier, nil
case CodeChallengeMethodS256:
shaSum := sha256.Sum256([]byte(codeVerifier))
return base64.RawURLEncoding.EncodeToString(shaSum[:]), nil
default:
s.logger.Errorf("unknown challenge method (%v)", codeChallengeMethod)
return "", fmt.Errorf("unknown challenge method (%v)", codeChallengeMethod)
}
}

// handle an access token request https://tools.ietf.org/html/rfc6749#section-4.1.3
func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) {
code := r.PostFormValue("code")
Expand All @@ -779,6 +805,32 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
return
}

// RFC 7636 (PKCE)
codeChallengeFromStorage := authCode.CodeChallenge.CodeChallenge
providedCodeVerifier := r.PostFormValue("code_verifier")

if providedCodeVerifier != "" && codeChallengeFromStorage != "" {
providedCodeVerifier := r.PostFormValue("code_verifier")
calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, authCode.CodeChallenge.CodeChallengeMethod)
if err != nil {
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

if codeChallengeFromStorage != calculatedCodeChallenge {
s.tokenErrHelper(w, errInvalidGrant, "Invalid code_verifier.", http.StatusBadRequest)
return
}
} else if providedCodeVerifier != "" {
// Received no code_challenge on /auth, but a code_verifier on /token
s.tokenErrHelper(w, errInvalidRequest, "No PKCE flow started. Cannot check code_verifier.", http.StatusBadRequest)
return
} else if codeChallengeFromStorage != "" {
// Received PKCE request on /auth, but no code_verifier on /token
s.tokenErrHelper(w, errInvalidGrant, "Expecting parameter code_verifier in PKCE flow.", http.StatusBadRequest)
return
}

if authCode.RedirectURI != redirectURI {
s.tokenErrHelper(w, errInvalidRequest, "redirect_uri did not match URI from initial request.", http.StatusBadRequest)
return
Expand Down
16 changes: 16 additions & 0 deletions server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,18 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
scopes := strings.Fields(q.Get("scope"))
responseTypes := strings.Fields(q.Get("response_type"))

codeChallenge := q.Get("code_challenge")
codeChallengeMethod := q.Get("code_challenge_method")

if codeChallengeMethod == "" {
codeChallengeMethod = CodeChallengeMethodPlain
}

if codeChallengeMethod != CodeChallengeMethodS256 && codeChallengeMethod != CodeChallengeMethodPlain {
description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod)
return nil, &authErr{"", "", errInvalidRequest, description}
}

client, err := s.storage.GetClient(clientID)
if err != nil {
if err == storage.ErrNotFound {
Expand Down Expand Up @@ -526,6 +538,10 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
RedirectURI: redirectURI,
ResponseTypes: responseTypes,
ConnectorID: connectorID,
CodeChallenge: storage.CodeChallenge{
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
},
}, nil
}

Expand Down
190 changes: 165 additions & 25 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
}
s.URL = config.Issuer

connector := storage.Connector{
connector1 := storage.Connector{
ID: "mock",
Type: "mockCallback",
Name: "Mock",
ResourceVersion: "1",
}
if err := config.Storage.CreateConnector(connector); err != nil {
if err := config.Storage.CreateConnector(connector1); err != nil {
t.Fatalf("create connector: %v", err)
}

Expand Down Expand Up @@ -137,7 +137,7 @@ func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateCo
}
s.URL = config.Issuer

connector := storage.Connector{
connector1 := storage.Connector{
ID: "mock",
Type: "mockCallback",
Name: "Mock",
Expand All @@ -149,7 +149,7 @@ func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateCo
Name: "Mock",
ResourceVersion: "1",
}
if err := config.Storage.CreateConnector(connector); err != nil {
if err := config.Storage.CreateConnector(connector1); err != nil {
t.Fatalf("create connector: %v", err)
}
if err := config.Storage.CreateConnector(connector2); err != nil {
Expand Down Expand Up @@ -203,6 +203,37 @@ func TestDiscovery(t *testing.T) {
}
}

// Defines an expected error by HTTP Status Code and
// the OAuth2 error int the response json
type ErrorResponse struct {
Error string
StatusCode int
}

// https://tools.ietf.org/html/rfc6749#section-5.2
type OAuth2ErrorResponse struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorURI string `json:"error_uri"`
}

type TestDefinition struct {
name string
// If specified these set of scopes will be used during the test case.
scopes []string
// handleToken provides the OAuth2 token response for the integration test.
handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token) error

// extra parameters to pass when requesting auth_code
authCodeOptions []oauth2.AuthCodeOption

// extra parameters to pass when retrieving id token
retrieveTokenOptions []oauth2.AuthCodeOption

// define an error response, when the test expects an error on the token endpoint
tokenError ErrorResponse
}

// TestOAuth2CodeFlow runs integration tests against a test server. The tests stand up a server
// which requires no interaction to login, logs in through a test client, then passes the client
// and returned token to the test.
Expand All @@ -228,25 +259,21 @@ func TestOAuth2CodeFlow(t *testing.T) {

oidcConfig := &oidc.Config{SkipClientIDCheck: true}

tests := []struct {
name string
// If specified these set of scopes will be used during the test case.
scopes []string
// handleToken provides the OAuth2 token response for the integration test.
handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token) error
}{
basicIDTokenVerify := func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
idToken, ok := token.Extra("id_token").(string)
if !ok {
return fmt.Errorf("no id token found")
}
if _, err := p.Verifier(oidcConfig).Verify(ctx, idToken); err != nil {
return fmt.Errorf("failed to verify id token: %v", err)
}
return nil
}

tests := []TestDefinition{
{
name: "verify ID Token",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
idToken, ok := token.Extra("id_token").(string)
if !ok {
return fmt.Errorf("no id token found")
}
if _, err := p.Verifier(oidcConfig).Verify(ctx, idToken); err != nil {
return fmt.Errorf("failed to verify id token: %v", err)
}
return nil
},
name: "verify ID Token",
handleToken: basicIDTokenVerify,
},
{
name: "fetch userinfo",
Expand Down Expand Up @@ -404,7 +431,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
v.Add("client_secret", clientSecret)
v.Add("grant_type", "refresh_token")
v.Add("refresh_token", token.RefreshToken)
// Request a scope that wasn't requestd initially.
// Request a scope that wasn't requested initially.
v.Add("scope", "oidc email profile")
resp, err := http.PostForm(p.Endpoint().TokenURL, v)
if err != nil {
Expand Down Expand Up @@ -472,6 +499,91 @@ func TestOAuth2CodeFlow(t *testing.T) {
return nil
},
},
{
// This test ensures that PKCE work in "plain" mode (no code_challenge_method specified)
name: "PKCE with plain",
authCodeOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", "challenge123"),
},
retrieveTokenOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", "challenge123"),
},
handleToken: basicIDTokenVerify,
},
{
// This test ensures that PKCE work in "S256" mode
name: "PKCE with S256",
authCodeOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
},
retrieveTokenOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", "challenge123"),
},
handleToken: basicIDTokenVerify,
},
{
// This test ensures that PKCE does fail with wrong code_verifier in "plain" mode
name: "PKCE with plain and wrong code_verifier",
authCodeOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", "challenge123"),
},
retrieveTokenOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", "challenge124"),
},
handleToken: basicIDTokenVerify,
tokenError: ErrorResponse{
Error: errInvalidGrant,
StatusCode: http.StatusBadRequest,
},
},
{
// This test ensures that PKCE fail with wrong code_verifier in "S256" mode
name: "PKCE with S256 and wrong code_verifier",
authCodeOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
},
retrieveTokenOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", "challenge124"),
},
handleToken: basicIDTokenVerify,
tokenError: ErrorResponse{
Error: errInvalidGrant,
StatusCode: http.StatusBadRequest,
},
},
{
// Ensure that when no PKCE flow was started on /auth
// we cannot switch to PKCE on /token
name: "No PKCE flow started",
authCodeOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
},
retrieveTokenOptions: []oauth2.AuthCodeOption{},
handleToken: basicIDTokenVerify,
tokenError: ErrorResponse{
Error: errInvalidGrant,
StatusCode: http.StatusBadRequest,
},
},
{
// Ensure that, when PKCE flow started on /auth
// we stay in PKCE flow on /token
name: "No PKCE flow started",
authCodeOptions: []oauth2.AuthCodeOption{
// No PKCE call on /auth
},
retrieveTokenOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", "challenge123"),
},
handleToken: basicIDTokenVerify,
tokenError: ErrorResponse{
Error: errInvalidRequest,
StatusCode: http.StatusBadRequest,
},
},
}

for _, tc := range tests {
Expand Down Expand Up @@ -514,7 +626,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
oauth2Client := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/callback" {
// User is visiting app first time. Redirect to dex.
http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusSeeOther)
http.Redirect(w, r, oauth2Config.AuthCodeURL(state, tc.authCodeOptions...), http.StatusSeeOther)
return
}

Expand All @@ -535,7 +647,11 @@ func TestOAuth2CodeFlow(t *testing.T) {
// Grab code, exchange for token.
if code := q.Get("code"); code != "" {
gotCode = true
token, err := oauth2Config.Exchange(ctx, code)
token, err := oauth2Config.Exchange(ctx, code, tc.retrieveTokenOptions...)
if tc.tokenError.StatusCode != 0 {
checkErrorResponse(err, t, tc)
return
}
if err != nil {
t.Errorf("failed to exchange code for token: %v", err)
return
Expand Down Expand Up @@ -602,6 +718,30 @@ func TestOAuth2CodeFlow(t *testing.T) {
}
}

func checkErrorResponse(err error, t *testing.T, tc TestDefinition) {
if err == nil {
t.Errorf("%s: DANGEROUS! got a token when we should not get one!", tc.name)
return
}
if rErr, ok := err.(*oauth2.RetrieveError); ok {
if rErr.Response.StatusCode != tc.tokenError.StatusCode {
t.Errorf("%s: got wrong StatusCode from server %d. expected %d",
tc.name, rErr.Response.StatusCode, tc.tokenError.StatusCode)
}
details := new(OAuth2ErrorResponse)
if err := json.Unmarshal(rErr.Body, details); err != nil {
t.Errorf("%s: could not parse return json: %s", tc.name, err)
return
}
if tc.tokenError.Error != "" && details.Error != tc.tokenError.Error {
t.Errorf("%s: got wrong Error in response: %s (%s). expected %s",
tc.name, details.Error, details.ErrorDescription, tc.tokenError.Error)
}
} else {
t.Errorf("%s: unexpedted error type: %s. expected *oauth2.RetrieveError", tc.name, reflect.TypeOf(err))
}
}

func TestOAuth2ImplicitFlow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down
Loading