@@ -2,10 +2,15 @@ package josev2
22
33import  (
44	"context" 
5+ 	"encoding/json" 
56	"errors" 
67	"fmt" 
8+ 	"net/http" 
9+ 	"net/url" 
10+ 	"sync" 
711	"time" 
812
13+ 	"github.com/auth0/go-jwt-middleware/internal/oidc" 
914	"gopkg.in/square/go-jose.v2" 
1015	"gopkg.in/square/go-jose.v2/jwt" 
1116)
@@ -115,7 +120,7 @@ func (v *Validator) ValidateToken(ctx context.Context, token string) (interface{
115120	// if jwt.ParseSigned did not error there will always be at least one 
116121	// header in the token 
117122	if  signatureAlgorithm  !=  ""  &&  signatureAlgorithm  !=  tok .Headers [0 ].Algorithm  {
118- 		return  nil , fmt .Errorf ("expected %q signin  algorithm but token specified %q" , signatureAlgorithm , tok .Headers [0 ].Algorithm )
123+ 		return  nil , fmt .Errorf ("expected %q signing  algorithm but token specified %q" , signatureAlgorithm , tok .Headers [0 ].Algorithm )
119124	}
120125
121126	key , err  :=  v .keyFunc (ctx )
@@ -133,7 +138,8 @@ func (v *Validator) ValidateToken(ctx context.Context, token string) (interface{
133138	}
134139
135140	userCtx  :=  & UserContext {
136- 		Claims : * claimDest [0 ].(* jwt.Claims ),
141+ 		CustomClaims : nil ,
142+ 		Claims :       * claimDest [0 ].(* jwt.Claims ),
137143	}
138144
139145	if  err  =  userCtx .Claims .ValidateWithLeeway (v .expectedClaims (), v .allowedClockSkew ); err  !=  nil  {
@@ -149,3 +155,112 @@ func (v *Validator) ValidateToken(ctx context.Context, token string) (interface{
149155
150156	return  userCtx , nil 
151157}
158+ 
159+ // JWKSProvider handles getting JWKS from the specified IssuerURL and exposes 
160+ // KeyFunc which adheres to the keyFunc signature that the Validator requires. 
161+ // Most likely you will want to use the CachingJWKSProvider as it handles 
162+ // getting and caching JWKS which can help reduce request time and potential 
163+ // rate limiting from your provider. 
164+ type  JWKSProvider  struct  {
165+ 	IssuerURL  url.URL 
166+ }
167+ 
168+ // NewJWKSProvider builds and returns a new JWKSProvider. 
169+ func  NewJWKSProvider (issuerURL  url.URL ) * JWKSProvider  {
170+ 	return  & JWKSProvider {IssuerURL : issuerURL }
171+ }
172+ 
173+ // KeyFunc adheres to the keyFunc signature that the Validator requires. While 
174+ // it returns an interface to adhere to keyFunc, as long as the error is nil 
175+ // the type will be *jose.JSONWebKeySet. 
176+ func  (p  * JWKSProvider ) KeyFunc (ctx  context.Context ) (interface {}, error ) {
177+ 	wkEndpoints , err  :=  oidc .GetWellKnownEndpointsFromIssuerURL (ctx , p .IssuerURL )
178+ 	if  err  !=  nil  {
179+ 		return  nil , err 
180+ 	}
181+ 
182+ 	u , err  :=  url .Parse (wkEndpoints .JWKSURI )
183+ 	if  err  !=  nil  {
184+ 		return  nil , fmt .Errorf ("could not parse JWKS URI from well known endpoints: %w" , err )
185+ 	}
186+ 
187+ 	req , err  :=  http .NewRequest (http .MethodGet , u .String (), nil )
188+ 	if  err  !=  nil  {
189+ 		return  nil , fmt .Errorf ("could not build request to get JWKS: %w" , err )
190+ 	}
191+ 	req  =  req .WithContext (ctx )
192+ 
193+ 	resp , err  :=  http .DefaultClient .Do (req )
194+ 	if  err  !=  nil  {
195+ 		return  nil , err 
196+ 	}
197+ 	defer  resp .Body .Close ()
198+ 
199+ 	var  jwks  jose.JSONWebKeySet 
200+ 	if  err  :=  json .NewDecoder (resp .Body ).Decode (& jwks ); err  !=  nil  {
201+ 		return  nil , fmt .Errorf ("could not decode jwks: %w" , err )
202+ 	}
203+ 
204+ 	return  & jwks , nil 
205+ }
206+ 
207+ type  cachedJWKS  struct  {
208+ 	jwks       * jose.JSONWebKeySet 
209+ 	expiresAt  time.Time 
210+ }
211+ 
212+ // CachingJWKSProvider handles getting JWKS from the specified IssuerURL and 
213+ // caching them for CacheTTL time. It exposes KeyFunc which adheres to the 
214+ // keyFunc signature that the Validator requires. 
215+ type  CachingJWKSProvider  struct  {
216+ 	IssuerURL  url.URL 
217+ 	CacheTTL   time.Duration 
218+ 
219+ 	mu     sync.Mutex 
220+ 	cache  map [string ]cachedJWKS 
221+ }
222+ 
223+ // NewCachingJWKSProvider builds and returns a new CachingJWKSProvider. If 
224+ // cacheTTL is zero then a default value of 1 minute will be used. 
225+ func  NewCachingJWKSProvider (issuerURL  url.URL , cacheTTL  time.Duration ) * CachingJWKSProvider  {
226+ 	if  cacheTTL  ==  0  {
227+ 		cacheTTL  =  1  *  time .Minute 
228+ 	}
229+ 
230+ 	return  & CachingJWKSProvider {
231+ 		IssuerURL : issuerURL ,
232+ 		CacheTTL :  cacheTTL ,
233+ 		cache :     map [string ]cachedJWKS {},
234+ 	}
235+ }
236+ 
237+ // KeyFunc adheres to the keyFunc signature that the Validator requires. While 
238+ // it returns an interface to adhere to keyFunc, as long as the error is nil 
239+ // the type will be *jose.JSONWebKeySet. 
240+ func  (c  * CachingJWKSProvider ) KeyFunc (ctx  context.Context ) (interface {}, error ) {
241+ 	issuer  :=  c .IssuerURL .Hostname ()
242+ 
243+ 	c .mu .Lock ()
244+ 	defer  func () {
245+ 		c .mu .Unlock ()
246+ 	}()
247+ 
248+ 	if  cached , ok  :=  c .cache [issuer ]; ok  {
249+ 		if  ! time .Now ().After (cached .expiresAt ) {
250+ 			return  cached .jwks , nil 
251+ 		}
252+ 	}
253+ 
254+ 	p  :=  JWKSProvider {IssuerURL : c .IssuerURL }
255+ 	jwks , err  :=  p .KeyFunc (ctx )
256+ 	if  err  !=  nil  {
257+ 		return  nil , err 
258+ 	}
259+ 
260+ 	c .cache [issuer ] =  cachedJWKS {
261+ 		jwks :      jwks .(* jose.JSONWebKeySet ),
262+ 		expiresAt : time .Now ().Add (c .CacheTTL ),
263+ 	}
264+ 
265+ 	return  jwks , nil 
266+ }
0 commit comments