Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions oidc/jwks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
Expand Down Expand Up @@ -151,11 +152,8 @@ func TestKeyVerifyContextCanceled(t *testing.T) {
t.Fatal(err)
}

ch := make(chan struct{})
defer close(ch)

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-ch
io.WriteString(w, "{}")
}))
defer s.Close()

Expand Down
99 changes: 41 additions & 58 deletions oidc/verify.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
package oidc

import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"

jose "github.com/go-jose/go-jose/v4"
Expand Down Expand Up @@ -145,18 +141,6 @@ func (p *Provider) newVerifier(keySet KeySet, config *Config) *IDTokenVerifier {
return NewVerifier(p.issuer, keySet, config)
}

func parseJWT(p string) ([]byte, error) {
parts := strings.Split(p, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
}
return payload, nil
}

func contains(sli []string, ele string) bool {
for _, s := range sli {
if s == ele {
Expand Down Expand Up @@ -219,12 +203,49 @@ func resolveDistributedClaim(ctx context.Context, verifier *IDTokenVerifier, src
//
// token, err := verifier.Verify(ctx, rawIDToken)
func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDToken, error) {
// Throw out tokens with invalid claims before trying to verify the token. This lets
// us do cheap checks before possibly re-syncing keys.
payload, err := parseJWT(rawIDToken)
var supportedSigAlgs []jose.SignatureAlgorithm
for _, alg := range v.config.SupportedSigningAlgs {
supportedSigAlgs = append(supportedSigAlgs, jose.SignatureAlgorithm(alg))
}
if len(supportedSigAlgs) == 0 {
// If no algorithms were specified by both the config and discovery, default
// to the one mandatory algorithm "RS256".
supportedSigAlgs = []jose.SignatureAlgorithm{jose.RS256}
}
if v.config.InsecureSkipSignatureCheck {
// "none" is a required value to even parse a JWT with the "none" algorithm
// using go-jose.
supportedSigAlgs = append(supportedSigAlgs, "none")
}

// Parse and verify the signature first. This at least forces the user to have
// a valid, signed ID token before we do any other processing.
jws, err := jose.ParseSigned(rawIDToken, supportedSigAlgs)
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
}
switch len(jws.Signatures) {
case 0:
return nil, fmt.Errorf("oidc: id token not signed")
case 1:
default:
return nil, fmt.Errorf("oidc: multiple signatures on id token not supported")
}
sig := jws.Signatures[0]

var payload []byte
if v.config.InsecureSkipSignatureCheck {
// Yolo mode.
payload = jws.UnsafePayloadWithoutVerification()
} else {
// The JWT is attached here for the happy path to avoid the verifier from
// having to parse the JWT twice.
ctx = context.WithValue(ctx, parsedJWTKey, jws)
payload, err = v.keySet.VerifySignature(ctx, rawIDToken)
if err != nil {
return nil, fmt.Errorf("failed to verify signature: %v", err)
}
}
var token idToken
if err := json.Unmarshal(payload, &token); err != nil {
return nil, fmt.Errorf("oidc: failed to unmarshal claims: %v", err)
Expand Down Expand Up @@ -254,6 +275,7 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok
AccessTokenHash: token.AtHash,
claims: payload,
distributedClaims: distributedClaims,
sigAlgorithm: sig.Header.Algorithm,
}

// Check issuer.
Expand Down Expand Up @@ -306,45 +328,6 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok
}
}

if v.config.InsecureSkipSignatureCheck {
return t, nil
}

var supportedSigAlgs []jose.SignatureAlgorithm
for _, alg := range v.config.SupportedSigningAlgs {
supportedSigAlgs = append(supportedSigAlgs, jose.SignatureAlgorithm(alg))
}
if len(supportedSigAlgs) == 0 {
// If no algorithms were specified by both the config and discovery, default
// to the one mandatory algorithm "RS256".
supportedSigAlgs = []jose.SignatureAlgorithm{jose.RS256}
}
jws, err := jose.ParseSigned(rawIDToken, supportedSigAlgs)
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
}

switch len(jws.Signatures) {
case 0:
return nil, fmt.Errorf("oidc: id token not signed")
case 1:
default:
return nil, fmt.Errorf("oidc: multiple signatures on id token not supported")
}
sig := jws.Signatures[0]
t.sigAlgorithm = sig.Header.Algorithm

ctx = context.WithValue(ctx, parsedJWTKey, jws)
gotPayload, err := v.keySet.VerifySignature(ctx, rawIDToken)
if err != nil {
return nil, fmt.Errorf("failed to verify signature: %v", err)
}

// Ensure that the payload returned by the square actually matches the payload parsed earlier.
if !bytes.Equal(gotPayload, payload) {
return nil, errors.New("oidc: internal error, payload parsed did not match previous payload")
}

return t, nil
}

Expand Down
7 changes: 6 additions & 1 deletion oidc/verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,9 +580,14 @@ func (v verificationTest) runGetToken(t *testing.T) (*IDToken, error) {
if v.signKey != nil {
token = v.signKey.sign(t, []byte(v.idToken))
} else {
token = base64.RawURLEncoding.EncodeToString([]byte(`{alg: "none"}`))
// "none" still uses a second "." character, but "...MUST use the empty octet
// sequence as its JWS Signature value."
//
// https://datatracker.ietf.org/doc/html/rfc7518#section-3.6
token = base64.RawURLEncoding.EncodeToString([]byte(`{"alg": "none"}`))
token += "."
token += base64.RawURLEncoding.EncodeToString([]byte(v.idToken))
token += "."
}

ctx, cancel := context.WithCancel(context.Background())
Expand Down