Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ require (
golang.org/x/crypto v0.0.0-20160711182412-2c99acdd1e9b
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3
golang.org/x/net v0.0.0-20170413175226-5602c733f70a
golang.org/x/oauth2 v0.0.0-20160718223228-08c8d727d239
golang.org/x/oauth2 v0.0.0-20180619213508-088f8e1d436e
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f // indirect
golang.org/x/sys v0.0.0-20151211033651-833a04a10549 // indirect
golang.org/x/text v0.0.0-20170401064109-f4b4367115ec // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ golang.org/x/net v0.0.0-20170413175226-5602c733f70a h1:U+RBxJXt1cn83eNU5KfO0ABG2
golang.org/x/net v0.0.0-20170413175226-5602c733f70a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20160718223228-08c8d727d239 h1:zW4VTIvN4l/liomF2DkpwzM8vz+Xlp9lO06+Z32c91U=
golang.org/x/oauth2 v0.0.0-20160718223228-08c8d727d239/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20180619213508-088f8e1d436e h1:WQv7JqEW/0BfkoG4CTH/MQrEguuNmtDPy9ekfAQmF38=
golang.org/x/oauth2 v0.0.0-20180619213508-088f8e1d436e/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20151211033651-833a04a10549 h1:imXIGlmpdV8HlMP9DTrSVaxjoffgGbwFZdJl0Ous5dc=
Expand Down
47 changes: 38 additions & 9 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package server

import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -566,15 +568,17 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
switch responseType {
case responseTypeCode:
code = storage.AuthCode{
ID: storage.NewID(),
ClientID: authReq.ClientID,
ConnectorID: authReq.ConnectorID,
Nonce: authReq.Nonce,
Scopes: authReq.Scopes,
Claims: authReq.Claims,
Expiry: s.now().Add(time.Minute * 30),
RedirectURI: authReq.RedirectURI,
ConnectorData: authReq.ConnectorData,
ID: storage.NewID(),
ClientID: authReq.ClientID,
ConnectorID: authReq.ConnectorID,
Nonce: authReq.Nonce,
Scopes: authReq.Scopes,
Claims: authReq.Claims,
Expiry: s.now().Add(time.Minute * 30),
RedirectURI: authReq.RedirectURI,
ConnectorData: authReq.ConnectorData,
CodeChallenge: authReq.CodeChallenge,
CodeChallengeMethod: authReq.CodeChallengeMethod,
}
if err := s.storage.CreateAuthCode(code); err != nil {
s.logger.Errorf("Failed to create auth code: %v", err)
Expand Down Expand Up @@ -699,6 +703,7 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) {
code := r.PostFormValue("code")
redirectURI := r.PostFormValue("redirect_uri")
verifier := r.PostFormValue("code_verifier")

authCode, err := s.storage.GetAuthCode(code)
if err != nil || s.now().After(authCode.Expiry) || authCode.ClientID != client.ID {
Expand All @@ -716,6 +721,30 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
return
}

// Check for code challenge validity
if authCode.CodeChallenge != "" {
if verifier == "" {
s.tokenErrHelper(w, errInvalidRequest, "Require code_verifier", http.StatusUnauthorized)
return
}
var challenge string
switch authCode.CodeChallengeMethod {
case codeChallengeSHA256:
sum := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(sum[:])
// default to plain: insecure
case codeChallengePlain:
challenge = verifier
default:
s.tokenErrHelper(w, errServerError, "Unsupported code challenge method", http.StatusInternalServerError)
return
}
if challenge != authCode.CodeChallenge {
s.tokenErrHelper(w, errInvalidGrant, "code_verifier doesn't have the same hash as code_challenge", http.StatusUnauthorized)
return
}
}

accessToken := storage.NewID()
idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID)
if err != nil {
Expand Down
7 changes: 7 additions & 0 deletions server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ const (
responseTypeIDToken = "id_token" // ID Token in url fragment
)

const (
codeChallengeSHA256 = "S256"
codeChallengePlain = "plain"
)

func parseScopes(scopes []string) connector.Scopes {
var s connector.Scopes
for _, scope := range scopes {
Expand Down Expand Up @@ -503,6 +508,8 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
Scopes: scopes,
RedirectURI: redirectURI,
ResponseTypes: responseTypes,
CodeChallenge: q.Get("code_challenge"),
CodeChallengeMethod: q.Get("code_challenge_method"),
}, nil
}

Expand Down
50 changes: 48 additions & 2 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package server
import (
"context"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
Expand Down Expand Up @@ -187,6 +189,10 @@ func TestOAuth2CodeFlow(t *testing.T) {
scopes []string
// handleToken provides the OAuth2 token response for the integration test.
handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token) error
// If specified this code challenge will be used during the test case.
codeChallenge string
codeChallengeMethod string
codeVerifier string
}{
{
name: "verify ID Token",
Expand Down Expand Up @@ -412,6 +418,35 @@ func TestOAuth2CodeFlow(t *testing.T) {
return nil
},
},
{
name: "code challenge plain",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
_, err := config.TokenSource(ctx, token).Token()
if err != nil {
return fmt.Errorf("failed to get token: %v", err)
}
return nil
},
codeChallenge: "test-code-challenge",
codeChallengeMethod: codeChallengePlain,
codeVerifier: "test-code-challenge",
},
{
name: "code challenge sha256",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
_, err := config.TokenSource(ctx, token).Token()
if err != nil {
return fmt.Errorf("failed to get token: %v", err)
}
return nil
},
codeChallenge: func() string {
sum := sha256.Sum256([]byte("test-code-challenge"))
return base64.RawURLEncoding.EncodeToString(sum[:])
}(),
codeChallengeMethod: codeChallengeSHA256,
codeVerifier: "test-code-challenge",
},
}

for _, tc := range tests {
Expand Down Expand Up @@ -453,8 +488,15 @@ func TestOAuth2CodeFlow(t *testing.T) {
var oauth2Config *oauth2.Config
oauth2Client := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/callback" {
options := make([]oauth2.AuthCodeOption, 0, 2)
if tc.codeChallenge != "" {
options = append(options, oauth2.SetAuthURLParam("code_challenge", tc.codeChallenge))
}
if tc.codeChallengeMethod != "" {
options = append(options, oauth2.SetAuthURLParam("code_challenge_method", tc.codeChallengeMethod))
}
// User is visiting app first time. Redirect to dex.
http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusSeeOther)
http.Redirect(w, r, oauth2Config.AuthCodeURL(state, options...), http.StatusSeeOther)
return
}

Expand All @@ -475,7 +517,11 @@ func TestOAuth2CodeFlow(t *testing.T) {
// Grab code, exchange for token.
if code := q.Get("code"); code != "" {
gotCode = true
token, err := oauth2Config.Exchange(ctx, code)
options := make([]oauth2.AuthCodeOption, 0, 1)
if tc.codeVerifier != "" {
options = append(options, oauth2.SetAuthURLParam("code_verifier", tc.codeVerifier))
}
token, err := oauth2Config.Exchange(ctx, code, options...)
if err != nil {
t.Errorf("failed to exchange code for token: %v", err)
return
Expand Down
30 changes: 21 additions & 9 deletions storage/etcd/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,24 @@ type AuthCode struct {
Claims Claims `json:"claims,omitempty"`

Expiry time.Time `json:"expiry"`

CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"`
}

func fromStorageAuthCode(a storage.AuthCode) AuthCode {
return AuthCode{
ID: a.ID,
ClientID: a.ClientID,
RedirectURI: a.RedirectURI,
ConnectorID: a.ConnectorID,
ConnectorData: a.ConnectorData,
Nonce: a.Nonce,
Scopes: a.Scopes,
Claims: fromStorageClaims(a.Claims),
Expiry: a.Expiry,
ID: a.ID,
ClientID: a.ClientID,
RedirectURI: a.RedirectURI,
ConnectorID: a.ConnectorID,
ConnectorData: a.ConnectorData,
Nonce: a.Nonce,
Scopes: a.Scopes,
Claims: fromStorageClaims(a.Claims),
Expiry: a.Expiry,
CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod,
}
}

Expand All @@ -58,6 +63,9 @@ type AuthRequest struct {

ConnectorID string `json:"connector_id"`
ConnectorData []byte `json:"connector_data"`

CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"`
}

func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest {
Expand All @@ -75,6 +83,8 @@ func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest {
Claims: fromStorageClaims(a.Claims),
ConnectorID: a.ConnectorID,
ConnectorData: a.ConnectorData,
CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod,
}
}

Expand All @@ -93,6 +103,8 @@ func toStorageAuthRequest(a AuthRequest) storage.AuthRequest {
ConnectorData: a.ConnectorData,
Expiry: a.Expiry,
Claims: toStorageClaims(a.Claims),
CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod,
}
}

Expand Down
48 changes: 31 additions & 17 deletions storage/kubernetes/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ type AuthRequest struct {
ConnectorData []byte `json:"connectorData,omitempty"`

Expiry time.Time `json:"expiry"`

CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"`
}

// AuthRequestList is a list of AuthRequests.
Expand All @@ -364,6 +367,8 @@ func toStorageAuthRequest(req AuthRequest) storage.AuthRequest {
ConnectorData: req.ConnectorData,
Expiry: req.Expiry,
Claims: toStorageClaims(req.Claims),
CodeChallenge: req.CodeChallenge,
CodeChallengeMethod: req.CodeChallengeMethod,
}
return a
}
Expand All @@ -390,6 +395,8 @@ func (cli *client) fromStorageAuthRequest(a storage.AuthRequest) AuthRequest {
ConnectorData: a.ConnectorData,
Expiry: a.Expiry,
Claims: fromStorageClaims(a.Claims),
CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod,
}
return req
}
Expand Down Expand Up @@ -463,6 +470,9 @@ type AuthCode struct {
ConnectorData []byte `json:"connectorData,omitempty"`

Expiry time.Time `json:"expiry"`

CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"`
}

// AuthCodeList is a list of AuthCodes.
Expand All @@ -482,28 +492,32 @@ func (cli *client) fromStorageAuthCode(a storage.AuthCode) AuthCode {
Name: a.ID,
Namespace: cli.namespace,
},
ClientID: a.ClientID,
RedirectURI: a.RedirectURI,
ConnectorID: a.ConnectorID,
ConnectorData: a.ConnectorData,
Nonce: a.Nonce,
Scopes: a.Scopes,
Claims: fromStorageClaims(a.Claims),
Expiry: a.Expiry,
ClientID: a.ClientID,
RedirectURI: a.RedirectURI,
ConnectorID: a.ConnectorID,
ConnectorData: a.ConnectorData,
Nonce: a.Nonce,
Scopes: a.Scopes,
Claims: fromStorageClaims(a.Claims),
Expiry: a.Expiry,
CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod,
}
}

func toStorageAuthCode(a AuthCode) storage.AuthCode {
return storage.AuthCode{
ID: a.ObjectMeta.Name,
ClientID: a.ClientID,
RedirectURI: a.RedirectURI,
ConnectorID: a.ConnectorID,
ConnectorData: a.ConnectorData,
Nonce: a.Nonce,
Scopes: a.Scopes,
Claims: toStorageClaims(a.Claims),
Expiry: a.Expiry,
ID: a.ObjectMeta.Name,
ClientID: a.ClientID,
RedirectURI: a.RedirectURI,
ConnectorID: a.ConnectorID,
ConnectorData: a.ConnectorData,
Nonce: a.Nonce,
Scopes: a.Scopes,
Claims: toStorageClaims(a.Claims),
Expiry: a.Expiry,
CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod,
}
}

Expand Down
Loading