Skip to content

Commit a28b9d2

Browse files
authored
feat(api): handle RSA key rollover (#6154)
* feat(api): handler RSA key rollover Signed-off-by: francois samin <[email protected]>
1 parent 156c3cf commit a28b9d2

File tree

5 files changed

+89
-24
lines changed

5 files changed

+89
-24
lines changed

engine/api/api.go

+19-8
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,13 @@ type Configuration struct {
9393
InsecureSkipVerifyTLS bool `toml:"insecureSkipVerifyTLS" json:"insecureSkipVerifyTLS" default:"false"`
9494
} `toml:"internalServiceMesh" json:"internalServiceMesh"`
9595
Auth struct {
96-
TokenDefaultDuration int64 `toml:"tokenDefaultDuration" default:"30" comment:"The default duration of a token (in days)" json:"tokenDefaultDuration"`
97-
TokenOverlapDefaultDuration string `toml:"tokenOverlapDefaultDuration" default:"24h" comment:"The default overlap duration when a token is regen" json:"tokenOverlapDefaultDuration"`
98-
DefaultGroup string `toml:"defaultGroup" default:"" comment:"The default group is the group in which every new user will be granted at signup" json:"defaultGroup"`
99-
DisableAddUserInDefaultGroup bool `toml:"disableAddUserInDefaultGroup" default:"false" comment:"If false, user are automatically added in the default group" json:"disableAddUserInDefaultGroup"`
100-
RSAPrivateKey string `toml:"rsaPrivateKey" default:"" comment:"The RSA Private Key used to sign and verify the JWT Tokens issued by the API \nThis is mandatory." json:"-"`
101-
AllowedOrganizations sdk.StringSlice `toml:"allowedOrganizations" default:"" comment:"The list of allowed organizations for CDS users, let empty to authorize all organizations." json:"allowedOrganizations"`
96+
TokenDefaultDuration int64 `toml:"tokenDefaultDuration" default:"30" comment:"The default duration of a token (in days)" json:"tokenDefaultDuration"`
97+
TokenOverlapDefaultDuration string `toml:"tokenOverlapDefaultDuration" default:"24h" comment:"The default overlap duration when a token is regen" json:"tokenOverlapDefaultDuration"`
98+
DefaultGroup string `toml:"defaultGroup" default:"" comment:"The default group is the group in which every new user will be granted at signup" json:"defaultGroup"`
99+
DisableAddUserInDefaultGroup bool `toml:"disableAddUserInDefaultGroup" default:"false" comment:"If false, user are automatically added in the default group" json:"disableAddUserInDefaultGroup"`
100+
RSAPrivateKey string `toml:"rsaPrivateKey" default:"" comment:"The RSA Private Key used to sign and verify the JWT Tokens issued by the API \nThis is mandatory." json:"-"`
101+
RSAPrivateKeys []authentication.KeyConfig `toml:"rsaPrivateKeys" default:"" comment:"RSA Private Keys used to sign and verify the JWT Tokens issued by the API \nThis is mandatory." json:"-" mapstructure:"rsaPrivateKeys"`
102+
AllowedOrganizations sdk.StringSlice `toml:"allowedOrganizations" default:"" comment:"The list of allowed organizations for CDS users, let empty to authorize all organizations." json:"allowedOrganizations"`
102103
LDAP struct {
103104
Enabled bool `toml:"enabled" default:"false" json:"enabled"`
104105
SignupDisabled bool `toml:"signupDisabled" default:"false" json:"signupDisabled"`
@@ -363,7 +364,7 @@ func (a *API) CheckConfiguration(config interface{}) error {
363364
return fmt.Errorf("You can't specify just defaultArch without defaultOS in your configuration and vice versa")
364365
}
365366

366-
if aConfig.Auth.RSAPrivateKey == "" {
367+
if aConfig.Auth.RSAPrivateKey == "" && len(aConfig.Auth.RSAPrivateKeys) == 0 {
367368
return errors.New("invalid given authentication rsa private key")
368369
}
369370

@@ -420,7 +421,17 @@ func (a *API) Serve(ctx context.Context) error {
420421
}
421422

422423
// Initialize the jwt layer
423-
if err := authentication.Init(a.ServiceName, []byte(a.Config.Auth.RSAPrivateKey)); err != nil {
424+
var RSAKeyConfigs []authentication.KeyConfig
425+
if a.Config.Auth.RSAPrivateKey != "" {
426+
RSAKeyConfigs = append(RSAKeyConfigs, authentication.KeyConfig{
427+
Key: a.Config.Auth.RSAPrivateKey,
428+
Timestamp: 0,
429+
})
430+
}
431+
if len(a.Config.Auth.RSAPrivateKeys) > 0 {
432+
RSAKeyConfigs = append(RSAKeyConfigs, a.Config.Auth.RSAPrivateKeys...)
433+
}
434+
if err := authentication.Init(ctx, a.ServiceName, RSAKeyConfigs); err != nil {
424435
return sdk.WrapError(err, "unable to initialize the JWT Layer")
425436
}
426437

+47-10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package authentication
22

33
import (
4+
"context"
45
"crypto/rsa"
6+
"sort"
57
"time"
68

79
jwt "github.com/golang-jwt/jwt"
@@ -10,24 +12,39 @@ import (
1012
)
1113

1214
var (
13-
signer *authentication.Signer
15+
signers []authentication.Signer
1416
)
1517

18+
type KeyConfig struct {
19+
Timestamp int64 `toml:"timestamp" mapstructure:"timestamp"`
20+
Key string `toml:"key" mapstructure:"key"`
21+
}
22+
1623
// Init the package by passing the signing key
17-
func Init(issuer string, k []byte) error {
18-
s, err := authentication.NewSigner(issuer, k)
19-
if err != nil {
20-
return err
24+
func Init(ctx context.Context, issuer string, keys []KeyConfig) error {
25+
// sort the keys to set the most recent signer at the end
26+
sort.Slice(keys, func(i, j int) bool {
27+
return keys[i].Timestamp < keys[j].Timestamp
28+
})
29+
30+
signers = make([]authentication.Signer, len(keys))
31+
32+
for i := range keys {
33+
s, err := authentication.NewSigner(issuer, []byte(keys[i].Key))
34+
if err != nil {
35+
return err
36+
}
37+
signers[i] = s
2138
}
22-
signer = &s
39+
2340
return nil
2441
}
2542

2643
func getSigner() authentication.Signer {
27-
if signer == nil {
44+
if len(signers) == 0 {
2845
panic("signer is not set")
2946
}
30-
return *signer
47+
return signers[len(signers)-1] // return the most recent signer
3148
}
3249

3350
func GetIssuerName() string {
@@ -43,13 +60,33 @@ func SignJWT(jwtToken *jwt.Token) (string, error) {
4360
}
4461

4562
func VerifyJWT(token *jwt.Token) (interface{}, error) {
46-
return getSigner().VerifyJWT(token)
63+
var lastError error
64+
// Check with the most recent signer first
65+
for i := len(signers) - 1; i >= 0; i-- {
66+
s := signers[i]
67+
res, err := s.VerifyJWT(token)
68+
if err == nil && res != nil {
69+
return res, nil
70+
}
71+
lastError = err
72+
}
73+
return nil, lastError
4774
}
4875

4976
func SignJWS(content interface{}, now time.Time, duration time.Duration) (string, error) {
5077
return getSigner().SignJWS(content, now, duration)
5178
}
5279

5380
func VerifyJWS(signature string, content interface{}) error {
54-
return getSigner().VerifyJWS(signature, content)
81+
var lastError error
82+
// Check with the most recent signer first
83+
for i := len(signers) - 1; i >= 0; i-- {
84+
s := signers[i]
85+
err := s.VerifyJWS(signature, content)
86+
if err == nil {
87+
return nil
88+
}
89+
lastError = err
90+
}
91+
return lastError
5592
}

engine/api/test/test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package test
22

33
import (
4+
"context"
45
"testing"
56

67
"github.com/rockbears/log"
@@ -26,7 +27,7 @@ func SetupPGWithFactory(t *testing.T, bootstrapFunc ...test.Bootstrapf) (*test.F
2627
db, factory, cache, cancel := test.SetupPGToCancel(t, gorpmapping.Mapper, sdk.TypeAPI, bootstrapFunc...)
2728
t.Cleanup(cancel)
2829

29-
err := authentication.Init("cds-api-test", test.SigningKey)
30+
err := authentication.Init(context.TODO(), "cds-api-test", []authentication.KeyConfig{{Key: string(test.SigningKey)}})
3031
require.NoError(t, err, "unable to init authentication layer")
3132

3233
return db, factory, cache

engine/config.go

+20-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"bytes"
5+
"context"
56
"fmt"
67
"io"
78
"os"
@@ -306,12 +307,17 @@ func configSetStartupData(conf *Configuration) (string, error) {
306307
validityPediod := sdk.NewAuthConsumerValidityPeriod(time.Now(), 0)
307308
startupCfg := api.StartupConfig{IAT: validityPediod.Latest().IssuedAt.Unix()}
308309

309-
if err := authentication.Init("cds-api", apiPrivateKeyPEM); err != nil {
310+
if err := authentication.Init(context.TODO(), "cds-api", []authentication.KeyConfig{{Key: string(apiPrivateKeyPEM)}}); err != nil {
310311
return "", err
311312
}
312313

313314
if conf.API != nil {
314-
conf.API.Auth.RSAPrivateKey = string(apiPrivateKeyPEM)
315+
conf.API.Auth.RSAPrivateKeys = []authentication.KeyConfig{
316+
{
317+
Timestamp: time.Now().Unix(),
318+
Key: string(apiPrivateKeyPEM),
319+
},
320+
}
315321

316322
key, _ := keyloader.GenerateKey("hmac", gorpmapper.KeySignIdentifier, false, time.Now())
317323
conf.API.Database.SignatureKey = database.RollingKeyConfig{Cipher: "hmac"}
@@ -658,13 +664,23 @@ func getInitTokenFromExistingConfiguration(conf Configuration) (string, error) {
658664
if conf.API == nil {
659665
return "", fmt.Errorf("cannot load configuration")
660666
}
661-
apiPrivateKeyPEM := []byte(conf.API.Auth.RSAPrivateKey)
662667

663668
now := time.Now()
664669
globalIAT := now.Unix()
665670
startupCfg := api.StartupConfig{}
666671

667-
if err := authentication.Init("cds-api", apiPrivateKeyPEM); err != nil {
672+
var RSAKeyConfigs []authentication.KeyConfig
673+
if conf.API.Auth.RSAPrivateKey != "" {
674+
RSAKeyConfigs = append(RSAKeyConfigs, authentication.KeyConfig{
675+
Key: conf.API.Auth.RSAPrivateKey,
676+
Timestamp: 0,
677+
})
678+
}
679+
if len(conf.API.Auth.RSAPrivateKeys) > 0 {
680+
RSAKeyConfigs = append(RSAKeyConfigs, conf.API.Auth.RSAPrivateKeys...)
681+
}
682+
683+
if err := authentication.Init(context.TODO(), "cds-api", RSAKeyConfigs); err != nil {
668684
return "", err
669685
}
670686

engine/hatchery/hatchery_helper_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func InitMock(t *testing.T, url string) {
7575
privKeyPEM, _ := jws.ExportPrivateKey(privKey)
7676
pubKey, _ := jws.ExportPublicKey(privKey)
7777

78-
require.NoError(t, authentication.Init("cds-api-test", privKeyPEM))
78+
require.NoError(t, authentication.Init(context.TODO(), "cds-api-test", []authentication.KeyConfig{{Key: string(privKeyPEM)}}))
7979
id := sdk.UUID()
8080
consumerID := sdk.UUID()
8181
hatcheryAuthenticationToken, _ := authentication.NewSessionJWT(&sdk.AuthSession{

0 commit comments

Comments
 (0)