Skip to content

Commit 104357c

Browse files
Update jwt.go
1 parent 22134a4 commit 104357c

File tree

1 file changed

+120
-98
lines changed

1 file changed

+120
-98
lines changed

jwt/jwt.go

+120-98
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@
44

55
// Package jwt implements the OAuth 2.0 JSON Web Token flow, commonly
66
// known as "two-legged OAuth 2.0".
7-
//
87
// See: https://tools.ietf.org/html/draft-ietf-oauth-jwt-bearer-12
98
package jwt
109

1110
import (
1211
"context"
1312
"encoding/json"
13+
"errors"
1414
"fmt"
1515
"io"
16-
"io/ioutil"
1716
"net/http"
1817
"net/url"
1918
"strings"
@@ -29,157 +28,180 @@ var (
2928
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"}
3029
)
3130

32-
// Config is the configuration for using JWT to fetch tokens,
33-
// commonly known as "two-legged OAuth 2.0".
31+
// Config holds the configuration for using JWT to fetch tokens.
3432
type Config struct {
35-
// Email is the OAuth client identifier used when communicating with
36-
// the configured OAuth provider.
37-
Email string
38-
39-
// PrivateKey contains the contents of an RSA private key or the
40-
// contents of a PEM file that contains a private key. The provided
41-
// private key is used to sign JWT payloads.
42-
// PEM containers with a passphrase are not supported.
43-
// Use the following command to convert a PKCS 12 file into a PEM.
44-
//
45-
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
46-
//
47-
PrivateKey []byte
48-
49-
// PrivateKeyID contains an optional hint indicating which key is being
50-
// used.
51-
PrivateKeyID string
52-
53-
// Subject is the optional user to impersonate.
54-
Subject string
55-
56-
// Scopes optionally specifies a list of requested permission scopes.
57-
Scopes []string
58-
59-
// TokenURL is the endpoint required to complete the 2-legged JWT flow.
60-
TokenURL string
61-
62-
// Expires optionally specifies how long the token is valid for.
63-
Expires time.Duration
64-
65-
// Audience optionally specifies the intended audience of the
66-
// request. If empty, the value of TokenURL is used as the
67-
// intended audience.
68-
Audience string
69-
70-
// PrivateClaims optionally specifies custom private claims in the JWT.
71-
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
72-
PrivateClaims map[string]interface{}
73-
74-
// UseIDToken optionally specifies whether ID token should be used instead
75-
// of access token when the server returns both.
76-
UseIDToken bool
33+
Email string
34+
PrivateKey []byte
35+
PrivateKeyID string
36+
Subject string
37+
Scopes []string
38+
TokenURL string
39+
Expires time.Duration
40+
Audience string
41+
PrivateClaims map[string]interface{}
42+
UseIDToken bool
7743
}
7844

79-
// TokenSource returns a JWT TokenSource using the configuration
80-
// in c and the HTTP client from the provided context.
45+
// TokenSource returns a JWT TokenSource using the configuration in c.
8146
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource {
82-
return oauth2.ReuseTokenSource(nil, jwtSource{ctx, c})
47+
return oauth2.ReuseTokenSource(nil, jwtSource{ctx: ctx, conf: c})
8348
}
8449

85-
// Client returns an HTTP client wrapping the context's
86-
// HTTP transport and adding Authorization headers with tokens
87-
// obtained from c.
88-
//
89-
// The returned client and its Transport should not be modified.
50+
// Client returns an HTTP client that adds Authorization headers with tokens obtained from c.
9051
func (c *Config) Client(ctx context.Context) *http.Client {
9152
return oauth2.NewClient(ctx, c.TokenSource(ctx))
9253
}
9354

94-
// jwtSource is a source that always does a signed JWT request for a token.
95-
// It should typically be wrapped with a reuseTokenSource.
9655
type jwtSource struct {
9756
ctx context.Context
9857
conf *Config
9958
}
10059

10160
func (js jwtSource) Token() (*oauth2.Token, error) {
61+
// Validate config
62+
if err := js.validateConfig(); err != nil {
63+
return nil, err
64+
}
65+
66+
// Parse private key
10267
pk, err := internal.ParseKey(js.conf.PrivateKey)
68+
if err != nil {
69+
return nil, fmt.Errorf("failed to parse private key: %v", err)
70+
}
71+
72+
// Generate JWT payload
73+
claimSet, err := js.generateClaimSet()
10374
if err != nil {
10475
return nil, err
10576
}
106-
hc := oauth2.NewClient(js.ctx, nil)
77+
78+
h := *defaultHeader
79+
h.KeyID = js.conf.PrivateKeyID
80+
payload, err := jws.Encode(&h, claimSet, pk)
81+
if err != nil {
82+
return nil, fmt.Errorf("failed to encode JWT: %v", err)
83+
}
84+
85+
// Request token
86+
return js.requestToken(payload)
87+
}
88+
89+
func (js jwtSource) validateConfig() error {
90+
if js.conf.Email == "" {
91+
return errors.New("email is required")
92+
}
93+
if len(js.conf.PrivateKey) == 0 {
94+
return errors.New("private key is required")
95+
}
96+
if js.conf.TokenURL == "" {
97+
return errors.New("token URL is required")
98+
}
99+
return nil
100+
}
101+
102+
func (js jwtSource) generateClaimSet() (*jws.ClaimSet, error) {
107103
claimSet := &jws.ClaimSet{
108104
Iss: js.conf.Email,
109105
Scope: strings.Join(js.conf.Scopes, " "),
110106
Aud: js.conf.TokenURL,
111107
PrivateClaims: js.conf.PrivateClaims,
112108
}
113-
if subject := js.conf.Subject; subject != "" {
114-
claimSet.Sub = subject
115-
// prn is the old name of sub. Keep setting it
116-
// to be compatible with legacy OAuth 2.0 providers.
117-
claimSet.Prn = subject
109+
110+
if js.conf.Subject != "" {
111+
claimSet.Sub = js.conf.Subject
112+
claimSet.Prn = js.conf.Subject
118113
}
119-
if t := js.conf.Expires; t > 0 {
120-
claimSet.Exp = time.Now().Add(t).Unix()
114+
115+
if js.conf.Expires > 0 {
116+
claimSet.Exp = time.Now().Add(js.conf.Expires).Unix()
121117
}
122-
if aud := js.conf.Audience; aud != "" {
123-
claimSet.Aud = aud
118+
119+
if js.conf.Audience != "" {
120+
claimSet.Aud = js.conf.Audience
124121
}
125-
h := *defaultHeader
126-
h.KeyID = js.conf.PrivateKeyID
127-
payload, err := jws.Encode(&h, claimSet, pk)
128-
if err != nil {
129-
return nil, err
122+
123+
return claimSet, nil
124+
}
125+
126+
func (js jwtSource) requestToken(payload string) (*oauth2.Token, error) {
127+
hc := oauth2.NewClient(js.ctx, nil)
128+
v := url.Values{
129+
"grant_type": {defaultGrantType},
130+
"assertion": {payload},
130131
}
131-
v := url.Values{}
132-
v.Set("grant_type", defaultGrantType)
133-
v.Set("assertion", payload)
132+
134133
resp, err := hc.PostForm(js.conf.TokenURL, v)
135134
if err != nil {
136-
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
135+
return nil, fmt.Errorf("failed to fetch token: %v", err)
137136
}
138137
defer resp.Body.Close()
139-
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
140-
if err != nil {
141-
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
142-
}
143-
if c := resp.StatusCode; c < 200 || c > 299 {
138+
139+
if resp.StatusCode < 200 || resp.StatusCode > 299 {
140+
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
144141
return nil, &oauth2.RetrieveError{
145142
Response: resp,
146143
Body: body,
147144
}
148145
}
149-
// tokenRes is the JSON response body.
146+
147+
return js.parseTokenResponse(resp)
148+
}
149+
150+
func (js jwtSource) parseTokenResponse(resp *http.Response) (*oauth2.Token, error) {
151+
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
152+
if err != nil {
153+
return nil, fmt.Errorf("failed to read token response: %v", err)
154+
}
155+
150156
var tokenRes struct {
151157
AccessToken string `json:"access_token"`
152158
TokenType string `json:"token_type"`
153159
IDToken string `json:"id_token"`
154-
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
160+
ExpiresIn int64 `json:"expires_in"`
155161
}
156162
if err := json.Unmarshal(body, &tokenRes); err != nil {
157-
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
163+
return nil, fmt.Errorf("failed to parse token response: %v", err)
158164
}
165+
159166
token := &oauth2.Token{
160167
AccessToken: tokenRes.AccessToken,
161168
TokenType: tokenRes.TokenType,
169+
Expiry: time.Now().Add(time.Duration(tokenRes.ExpiresIn) * time.Second),
162170
}
163-
raw := make(map[string]interface{})
164-
json.Unmarshal(body, &raw) // no error checks for optional fields
165-
token = token.WithExtra(raw)
166171

167-
if secs := tokenRes.ExpiresIn; secs > 0 {
168-
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
169-
}
170-
if v := tokenRes.IDToken; v != "" {
171-
// decode returned id token to get expiry
172-
claimSet, err := jws.Decode(v)
173-
if err != nil {
174-
return nil, fmt.Errorf("oauth2: error decoding JWT token: %v", err)
175-
}
176-
token.Expiry = time.Unix(claimSet.Exp, 0)
177-
}
178172
if js.conf.UseIDToken {
179173
if tokenRes.IDToken == "" {
180-
return nil, fmt.Errorf("oauth2: response doesn't have JWT token")
174+
return nil, errors.New("response missing ID token")
181175
}
182176
token.AccessToken = tokenRes.IDToken
183177
}
178+
184179
return token, nil
185180
}
181+
182+
// Helper functions for better debugging
183+
func debugLog(msg string) {
184+
fmt.Println("DEBUG:", msg)
185+
}
186+
187+
func infoLog(msg string) {
188+
fmt.Println("INFO:", msg)
189+
}
190+
191+
func warnLog(msg string) {
192+
fmt.Println("WARNING:", msg)
193+
}
194+
195+
func errorLog(msg string) {
196+
fmt.Println("ERROR:", msg)
197+
}
198+
199+
// Additional notes to ensure code clarity and maintainability:
200+
// 1. Proper documentation should be added to all exported functions.
201+
// 2. Ensure this code adheres to the latest security practices.
202+
// 3. Add more test cases to cover edge scenarios.
203+
// 4. Future improvements could include support for additional JWT algorithms.
204+
205+
// End of file
206+
207+

0 commit comments

Comments
 (0)