Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package server

import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand All @@ -23,6 +25,11 @@ import (
"github.com/dexidp/dex/storage"
)

const (
CodeChallengeMethodPlain = "plain"
CodeChallengeMethodS256 = "S256"
)

// newHealthChecker returns the healthz handler. The handler runs until the
// provided context is canceled.
func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
Expand Down Expand Up @@ -633,6 +640,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
Expiry: s.now().Add(time.Minute * 30),
RedirectURI: authReq.RedirectURI,
ConnectorData: authReq.ConnectorData,
CodeChallenge: authReq.CodeChallenge,
}
if err := s.storage.CreateAuthCode(code); err != nil {
s.logger.Errorf("Failed to create auth code: %v", err)
Expand Down Expand Up @@ -761,6 +769,19 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
}
}

func (s *Server) calculateCodeChallenge(codeVerifier, codeChallengeMethod string) (string, error) {
switch codeChallengeMethod {
case CodeChallengeMethodPlain:
return codeVerifier, nil
case CodeChallengeMethodS256:
shaSum := sha256.Sum256([]byte(codeVerifier))
return base64.RawURLEncoding.EncodeToString(shaSum[:]), nil
default:
s.logger.Errorf("unknown challenge method (%v)", codeChallengeMethod)
return "", fmt.Errorf("unknown challenge method (%v)", codeChallengeMethod)
}
}

// handle an access token request https://tools.ietf.org/html/rfc6749#section-4.1.3
func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) {
code := r.PostFormValue("code")
Expand All @@ -777,6 +798,22 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
return
}

// RFC 7636 (PKCE)
codeChallengeFromStorage := authCode.CodeChallenge.CodeChallenge
if codeChallengeFromStorage != "" {
providedCodeVerifier := r.PostFormValue("code_verifier")
calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, authCode.CodeChallenge.CodeChallengeMethod)
if err != nil {
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

if codeChallengeFromStorage != calculatedCodeChallenge {
s.tokenErrHelper(w, errInvalidRequest, "invalid code_verifier.", http.StatusBadRequest)
Copy link
Contributor

@deric deric Jun 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to RFC 7636, errInvalidGrant should be returned.

code_verifier == code_challenge.
If the values are equal, the token endpoint MUST continue processing
as normal (as defined by OAuth 2.0 [RFC6749]). If the values are not equal, an error response indicating "invalid_grant" as described in Section 5.2 of [RFC6749] MUST be returned.

return
}
}

if authCode.RedirectURI != redirectURI {
s.tokenErrHelper(w, errInvalidRequest, "redirect_uri did not match URI from initial request.", http.StatusBadRequest)
return
Expand Down
16 changes: 16 additions & 0 deletions server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,18 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
scopes := strings.Fields(q.Get("scope"))
responseTypes := strings.Fields(q.Get("response_type"))

codeChallenge := q.Get("code_challenge")
codeChallengeMethod := q.Get("code_challenge_method")

if codeChallengeMethod == "" {
codeChallengeMethod = CodeChallengeMethodPlain
}

if codeChallengeMethod != CodeChallengeMethodS256 && codeChallengeMethod != CodeChallengeMethodPlain {
description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod)
return nil, &authErr{"", "", errInvalidRequest, description}
}

client, err := s.storage.GetClient(clientID)
if err != nil {
if err == storage.ErrNotFound {
Expand Down Expand Up @@ -525,6 +537,10 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
RedirectURI: redirectURI,
ResponseTypes: responseTypes,
ConnectorID: connectorID,
CodeChallenge: storage.CodeChallenge{
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
},
}, nil
}

Expand Down
57 changes: 44 additions & 13 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,25 +228,33 @@ func TestOAuth2CodeFlow(t *testing.T) {

oidcConfig := &oidc.Config{SkipClientIDCheck: true}

basicIDTokenVerify := func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
idToken, ok := token.Extra("id_token").(string)
if !ok {
return fmt.Errorf("no id token found")
}
if _, err := p.Verifier(oidcConfig).Verify(ctx, idToken); err != nil {
return fmt.Errorf("failed to verify id token: %v", err)
}
return nil
}

tests := []struct {
name string
// If specified these set of scopes will be used during the test case.
scopes []string
// handleToken provides the OAuth2 token response for the integration test.
handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token) error

// extra parameters to pass when requesting auth_code
authCodeOptions []oauth2.AuthCodeOption

// extra parameters to pass when retrieving id token
retreiveTokenOptions []oauth2.AuthCodeOption
}{
{
name: "verify ID Token",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
idToken, ok := token.Extra("id_token").(string)
if !ok {
return fmt.Errorf("no id token found")
}
if _, err := p.Verifier(oidcConfig).Verify(ctx, idToken); err != nil {
return fmt.Errorf("failed to verify id token: %v", err)
}
return nil
},
name: "verify ID Token",
handleToken: basicIDTokenVerify,
},
{
name: "fetch userinfo",
Expand Down Expand Up @@ -472,6 +480,29 @@ func TestOAuth2CodeFlow(t *testing.T) {
return nil
},
},
{
// This test ensures that PKCE work in "plain" mode (no code_challenge_method specified)
name: "PKCE with plain",
authCodeOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", "challenge123"),
},
retreiveTokenOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", "challenge123"),
},
handleToken: basicIDTokenVerify,
},
{
// This test ensures that PKCE work in "S256" mode
name: "PKCE with S256",
authCodeOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
},
retreiveTokenOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", "challenge123"),
},
handleToken: basicIDTokenVerify,
},
}

for _, tc := range tests {
Expand Down Expand Up @@ -514,7 +545,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
oauth2Client := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/callback" {
// User is visiting app first time. Redirect to dex.
http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusSeeOther)
http.Redirect(w, r, oauth2Config.AuthCodeURL(state, tc.authCodeOptions...), http.StatusSeeOther)
return
}

Expand All @@ -535,7 +566,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
// Grab code, exchange for token.
if code := q.Get("code"); code != "" {
gotCode = true
token, err := oauth2Config.Exchange(ctx, code)
token, err := oauth2Config.Exchange(ctx, code, tc.retreiveTokenOptions...)
if err != nil {
t.Errorf("failed to exchange code for token: %v", err)
return
Expand Down
10 changes: 10 additions & 0 deletions storage/conformance/conformance.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ func mustBeErrAlreadyExists(t *testing.T, kind string, err error) {
}

func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
codeChallenge := storage.CodeChallenge{
CodeChallenge: "code_challenge_test",
CodeChallengeMethod: "plain",
}

a1 := storage.AuthRequest{
ID: storage.NewID(),
ClientID: "client1",
Expand All @@ -99,6 +104,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
EmailVerified: true,
Groups: []string{"a", "b"},
},
CodeChallenge: codeChallenge,
}

identity := storage.Claims{Email: "foobar"}
Expand Down Expand Up @@ -153,6 +159,10 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
t.Fatalf("update failed, wanted identity=%#v got %#v", identity, got.Claims)
}

if !reflect.DeepEqual(got.CodeChallenge, codeChallenge) {
t.Fatalf("storage does not support PKCE, wanted challenge=%#v got %#v", codeChallenge, got.CodeChallenge)
}

if err := s.DeleteAuthRequest(a1.ID); err != nil {
t.Fatalf("failed to delete auth request: %v", err)
}
Expand Down
32 changes: 23 additions & 9 deletions storage/etcd/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,24 @@ type AuthCode struct {
Claims Claims `json:"claims,omitempty"`

Expiry time.Time `json:"expiry"`

CodeChallenge string `json:"code_challenge,omitempty"`
CodeChallengeMethod string `json:"code_challenge_method,omitempty"`
}

func fromStorageAuthCode(a storage.AuthCode) AuthCode {
return AuthCode{
ID: a.ID,
ClientID: a.ClientID,
RedirectURI: a.RedirectURI,
ConnectorID: a.ConnectorID,
ConnectorData: a.ConnectorData,
Nonce: a.Nonce,
Scopes: a.Scopes,
Claims: fromStorageClaims(a.Claims),
Expiry: a.Expiry,
ID: a.ID,
ClientID: a.ClientID,
RedirectURI: a.RedirectURI,
ConnectorID: a.ConnectorID,
ConnectorData: a.ConnectorData,
Nonce: a.Nonce,
Scopes: a.Scopes,
Claims: fromStorageClaims(a.Claims),
Expiry: a.Expiry,
CodeChallenge: a.CodeChallenge.CodeChallenge,
CodeChallengeMethod: a.CodeChallenge.CodeChallengeMethod,
}
}

Expand All @@ -58,6 +63,9 @@ type AuthRequest struct {

ConnectorID string `json:"connector_id"`
ConnectorData []byte `json:"connector_data"`

CodeChallenge string `json:"code_challenge,omitempty"`
CodeChallengeMethod string `json:"code_challenge_method,omitempty"`
}

func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest {
Expand All @@ -75,6 +83,8 @@ func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest {
Claims: fromStorageClaims(a.Claims),
ConnectorID: a.ConnectorID,
ConnectorData: a.ConnectorData,
CodeChallenge: a.CodeChallenge.CodeChallenge,
CodeChallengeMethod: a.CodeChallenge.CodeChallengeMethod,
}
}

Expand All @@ -93,6 +103,10 @@ func toStorageAuthRequest(a AuthRequest) storage.AuthRequest {
ConnectorData: a.ConnectorData,
Expiry: a.Expiry,
Claims: toStorageClaims(a.Claims),
CodeChallenge: storage.CodeChallenge{
CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod,
},
}
}

Expand Down
34 changes: 26 additions & 8 deletions storage/kubernetes/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ type AuthRequest struct {
ConnectorData []byte `json:"connectorData,omitempty"`

Expiry time.Time `json:"expiry"`

CodeChallenge string `json:"code_challenge,omitempty"`
CodeChallengeMethod string `json:"code_challenge_method,omitempty"`
}

// AuthRequestList is a list of AuthRequests.
Expand All @@ -293,6 +296,10 @@ func toStorageAuthRequest(req AuthRequest) storage.AuthRequest {
ConnectorData: req.ConnectorData,
Expiry: req.Expiry,
Claims: toStorageClaims(req.Claims),
CodeChallenge: storage.CodeChallenge{
CodeChallenge: req.CodeChallenge,
CodeChallengeMethod: req.CodeChallengeMethod,
},
}
return a
}
Expand All @@ -319,6 +326,8 @@ func (cli *client) fromStorageAuthRequest(a storage.AuthRequest) AuthRequest {
ConnectorData: a.ConnectorData,
Expiry: a.Expiry,
Claims: fromStorageClaims(a.Claims),
CodeChallenge: a.CodeChallenge.CodeChallenge,
CodeChallengeMethod: a.CodeChallenge.CodeChallengeMethod,
}
return req
}
Expand Down Expand Up @@ -392,6 +401,9 @@ type AuthCode struct {
ConnectorData []byte `json:"connectorData,omitempty"`

Expiry time.Time `json:"expiry"`

CodeChallenge string `json:"code_challenge,omitempty"`
CodeChallengeMethod string `json:"code_challenge_method,omitempty"`
}

// AuthCodeList is a list of AuthCodes.
Expand All @@ -411,14 +423,16 @@ func (cli *client) fromStorageAuthCode(a storage.AuthCode) AuthCode {
Name: a.ID,
Namespace: cli.namespace,
},
ClientID: a.ClientID,
RedirectURI: a.RedirectURI,
ConnectorID: a.ConnectorID,
ConnectorData: a.ConnectorData,
Nonce: a.Nonce,
Scopes: a.Scopes,
Claims: fromStorageClaims(a.Claims),
Expiry: a.Expiry,
ClientID: a.ClientID,
RedirectURI: a.RedirectURI,
ConnectorID: a.ConnectorID,
ConnectorData: a.ConnectorData,
Nonce: a.Nonce,
Scopes: a.Scopes,
Claims: fromStorageClaims(a.Claims),
Expiry: a.Expiry,
CodeChallenge: a.CodeChallenge.CodeChallenge,
CodeChallengeMethod: a.CodeChallenge.CodeChallengeMethod,
}
}

Expand All @@ -433,6 +447,10 @@ func toStorageAuthCode(a AuthCode) storage.AuthCode {
Scopes: a.Scopes,
Claims: toStorageClaims(a.Claims),
Expiry: a.Expiry,
CodeChallenge: storage.CodeChallenge{
CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod,
},
}
}

Expand Down
Loading