Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 23 additions & 36 deletions oidc/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@ package oidc
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"

jose "github.com/go-jose/go-jose/v4"
Expand Down Expand Up @@ -145,18 +143,6 @@ func (p *Provider) newVerifier(keySet KeySet, config *Config) *IDTokenVerifier {
return NewVerifier(p.issuer, keySet, config)
}

func parseJWT(p string) ([]byte, error) {
parts := strings.Split(p, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
}
return payload, nil
}

func contains(sli []string, ele string) bool {
for _, s := range sli {
if s == ele {
Expand Down Expand Up @@ -221,10 +207,32 @@ func resolveDistributedClaim(ctx context.Context, verifier *IDTokenVerifier, src
func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDToken, error) {
// Throw out tokens with invalid claims before trying to verify the token. This lets
// us do cheap checks before possibly re-syncing keys.
payload, err := parseJWT(rawIDToken)
var supportedSigAlgs []jose.SignatureAlgorithm
for _, alg := range v.config.SupportedSigningAlgs {
supportedSigAlgs = append(supportedSigAlgs, jose.SignatureAlgorithm(alg))
}
if len(supportedSigAlgs) == 0 {
// If no algorithms were specified by both the config and discovery, default
// to the one mandatory algorithm "RS256".
supportedSigAlgs = []jose.SignatureAlgorithm{jose.RS256}
}
if v.config.InsecureSkipSignatureCheck {
supportedSigAlgs = append(supportedSigAlgs, jose.SignatureAlgorithm("none"))
}
jws, err := jose.ParseSigned(rawIDToken, supportedSigAlgs)
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
}

switch len(jws.Signatures) {
case 0:
return nil, fmt.Errorf("oidc: id token not signed")
case 1:
default:
return nil, fmt.Errorf("oidc: multiple signatures on id token not supported")
}

payload := jws.UnsafePayloadWithoutVerification()
var token idToken
if err := json.Unmarshal(payload, &token); err != nil {
return nil, fmt.Errorf("oidc: failed to unmarshal claims: %v", err)
Expand Down Expand Up @@ -310,27 +318,6 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok
return t, nil
}

var supportedSigAlgs []jose.SignatureAlgorithm
for _, alg := range v.config.SupportedSigningAlgs {
supportedSigAlgs = append(supportedSigAlgs, jose.SignatureAlgorithm(alg))
}
if len(supportedSigAlgs) == 0 {
// If no algorithms were specified by both the config and discovery, default
// to the one mandatory algorithm "RS256".
supportedSigAlgs = []jose.SignatureAlgorithm{jose.RS256}
}
jws, err := jose.ParseSigned(rawIDToken, supportedSigAlgs)
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
}

switch len(jws.Signatures) {
case 0:
return nil, fmt.Errorf("oidc: id token not signed")
case 1:
default:
return nil, fmt.Errorf("oidc: multiple signatures on id token not supported")
}
sig := jws.Signatures[0]
t.sigAlgorithm = sig.Header.Algorithm

Expand Down
3 changes: 2 additions & 1 deletion oidc/verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,9 +580,10 @@ func (v verificationTest) runGetToken(t *testing.T) (*IDToken, error) {
if v.signKey != nil {
token = v.signKey.sign(t, []byte(v.idToken))
} else {
token = base64.RawURLEncoding.EncodeToString([]byte(`{alg: "none"}`))
token = base64.RawURLEncoding.EncodeToString([]byte(`{"alg": "none"}`))
token += "."
token += base64.RawURLEncoding.EncodeToString([]byte(v.idToken))
token += "."
}

ctx, cancel := context.WithCancel(context.Background())
Expand Down