Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hint support #242

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions cmd/spiffe-helper/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Config struct {
IncludeFederatedDomains bool `hcl:"include_federated_domains"`
RenewSignal string `hcl:"renew_signal"`
DaemonMode *bool `hcl:"daemon_mode"`
Hint string `hcl:"hint"`

// x509 configuration
SVIDFileName string `hcl:"svid_file_name"`
Expand Down Expand Up @@ -195,6 +196,7 @@ func NewSidecarConfig(config *Config, log logrus.FieldLogger) *sidecar.Config {
SVIDFileName: config.SVIDFileName,
SVIDKeyFileName: config.SVIDKeyFileName,
SVIDBundleFileName: config.SVIDBundleFileName,
Hint: config.Hint,
}

for _, jwtSVID := range config.JWTSVIDs {
Expand Down
19 changes: 18 additions & 1 deletion pkg/disk/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,25 @@ func WriteJWTBundleSet(jwkSet *jwtbundle.Set, dir string, jwtBundleFilename stri
}

// WriteJWTBundle write the given JWT SVID to disk
func WriteJWTSVID(jwtSVID *jwtsvid.SVID, dir, jwtSVIDFilename string, jwtSVIDFileMode fs.FileMode) error {
func WriteJWTSVID(jwtSVIDs []*jwtsvid.SVID, dir, jwtSVIDFilename string, jwtSVIDFileMode fs.FileMode, Hint string) error {
kfox1111 marked this conversation as resolved.
Show resolved Hide resolved
filePath := path.Join(dir, jwtSVIDFilename)
var jwtSVID *jwtsvid.SVID

notFound := true
kfox1111 marked this conversation as resolved.
Show resolved Hide resolved
if Hint == "" {
kfox1111 marked this conversation as resolved.
Show resolved Hide resolved
jwtSVID = jwtSVIDs[0]
} else {
for id := range jwtSVIDs {
jwtSVID := jwtSVIDs[id]
if jwtSVID.Hint == Hint {
notFound = false
break
}
}
if notFound {
return fmt.Errorf("failed to find the hinted svid")
}
}

return os.WriteFile(filePath, []byte(jwtSVID.Marshal()), jwtSVIDFileMode)
}
Expand Down
16 changes: 14 additions & 2 deletions pkg/disk/x509.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,25 @@ import (
// the Workload API, and calls writeCerts and writeKey to write to disk
// the svid, key and bundle of certificates.
// It is possible to change output setting `addIntermediatesToBundle` as true.
func WriteX509Context(x509Context *workloadapi.X509Context, addIntermediatesToBundle, includeFederatedDomains bool, certDir, svidFilename, svidKeyFilename, svidBundleFilename string, certFileMode, keyFileMode fs.FileMode) error {
func WriteX509Context(x509Context *workloadapi.X509Context, addIntermediatesToBundle, includeFederatedDomains bool, certDir, svidFilename, svidKeyFilename, svidBundleFilename string, certFileMode, keyFileMode fs.FileMode, Hint string) error {
kfox1111 marked this conversation as resolved.
Show resolved Hide resolved
svidFile := path.Join(certDir, svidFilename)
svidKeyFile := path.Join(certDir, svidKeyFilename)
svidBundleFile := path.Join(certDir, svidBundleFilename)

// There may be more than one certificate, but we're only interested in the default one
svid := x509Context.DefaultSVID()
if Hint != "" {
notFound := true
for id := range x509Context.SVIDs {
svid = x509Context.SVIDs[id]
if svid.Hint == Hint {
notFound = false
break
}
}
if notFound {
return fmt.Errorf("failed to find the hinted svid")
}
}

// Extract bundle for the default SVID
bundleSet, found := x509Context.Bundles.Get(svid.ID.TrustDomain())
Expand Down
3 changes: 3 additions & 0 deletions pkg/sidecar/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ type Config struct {

// TODO: is there a reason for this to be exposed? and inside of config?
ReloadExternalProcess func() error

// Hint: The hint to pass to the spiffe endpoint to help select SPIFFE IDs
Hint string
kfox1111 marked this conversation as resolved.
Show resolved Hide resolved
}

type JWTConfig struct {
Expand Down
36 changes: 19 additions & 17 deletions pkg/sidecar/sidecar.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (s *Sidecar) setupClients(ctx context.Context) error {
// updateCertificates Updates the certificates stored in disk and signal the Process to restart
func (s *Sidecar) updateCertificates(svidResponse *workloadapi.X509Context) {
s.config.Log.Debug("Updating X.509 certificates")
if err := disk.WriteX509Context(svidResponse, s.config.AddIntermediatesToBundle, s.config.IncludeFederatedDomains, s.config.CertDir, s.config.SVIDFileName, s.config.SVIDKeyFileName, s.config.SVIDBundleFileName, s.config.CertFileMode, s.config.KeyFileMode); err != nil {
if err := disk.WriteX509Context(svidResponse, s.config.AddIntermediatesToBundle, s.config.IncludeFederatedDomains, s.config.CertDir, s.config.SVIDFileName, s.config.SVIDKeyFileName, s.config.SVIDBundleFileName, s.config.CertFileMode, s.config.KeyFileMode, s.config.Hint); err != nil {
kfox1111 marked this conversation as resolved.
Show resolved Hide resolved
s.config.Log.WithError(err).Error("Unable to dump bundle")
return
}
Expand Down Expand Up @@ -253,20 +253,22 @@ func (s *Sidecar) checkProcessExit() {
atomic.StoreInt32(&s.processRunning, 0)
}

func (s *Sidecar) fetchJWTSVIDs(ctx context.Context, jwtAudience string, jwtExtraAudiences []string) (*jwtsvid.SVID, error) {
jwtSVID, err := s.jwtSource.FetchJWTSVID(ctx, jwtsvid.Params{Audience: jwtAudience, ExtraAudiences: jwtExtraAudiences})
func (s *Sidecar) fetchJWTSVIDs(ctx context.Context, jwtAudience string, jwtExtraAudiences []string) ([]*jwtsvid.SVID, error) {
jwtSVIDs, err := s.jwtSource.FetchJWTSVIDs(ctx, jwtsvid.Params{Audience: jwtAudience, ExtraAudiences: jwtExtraAudiences})
if err != nil {
s.config.Log.Errorf("Unable to fetch JWT SVID: %v", err)
return nil, err
}

_, err = jwtsvid.ParseAndValidate(jwtSVID.Marshal(), s.jwtSource, []string{jwtAudience})
if err != nil {
s.config.Log.Errorf("Unable to parse or validate token: %v", err)
return nil, err
for id := range jwtSVIDs {
kfox1111 marked this conversation as resolved.
Show resolved Hide resolved
jwtSVID := jwtSVIDs[id]
_, err = jwtsvid.ParseAndValidate(jwtSVID.Marshal(), s.jwtSource, []string{jwtAudience})
if err != nil {
s.config.Log.Errorf("Unable to parse or validate token: %v", err)
return nil, err
}
}

return jwtSVID, nil
return jwtSVIDs, nil
}

func createRetryIntervalFunc() func() time.Duration {
Expand All @@ -291,34 +293,34 @@ func getRefreshInterval(svid *jwtsvid.SVID) time.Duration {
return time.Until(svid.Expiry)/2 + time.Second
}

func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context, jwtAudience string, jwtExtraAudiences []string, jwtSVIDFilename string) (*jwtsvid.SVID, error) {
func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context, jwtAudience string, jwtExtraAudiences []string, jwtSVIDFilename string) ([]*jwtsvid.SVID, error) {
s.config.Log.Debug("Updating JWT SVID")

jwtSVID, err := s.fetchJWTSVIDs(ctx, jwtAudience, jwtExtraAudiences)
jwtSVIDs, err := s.fetchJWTSVIDs(ctx, jwtAudience, jwtExtraAudiences)
if err != nil {
s.config.Log.Errorf("Unable to update JWT SVID: %v", err)
return nil, err
}

if err = disk.WriteJWTSVID(jwtSVID, s.config.CertDir, jwtSVIDFilename, s.config.JWTSVIDFileMode); err != nil {
if err = disk.WriteJWTSVID(jwtSVIDs, s.config.CertDir, jwtSVIDFilename, s.config.JWTSVIDFileMode, s.config.Hint); err != nil {
s.config.Log.Errorf("Unable to update JWT SVID: %v", err)
return nil, err
}

s.config.Log.Info("JWT SVID updated")
return jwtSVID, nil
return jwtSVIDs, nil
}

func (s *Sidecar) updateJWTSVID(ctx context.Context, jwtAudience string, jwtExtraAudiences []string, jwtSVIDFilename string) {
retryInterval := createRetryIntervalFunc()
var initialInterval time.Duration
jwtSVID, err := s.performJWTSVIDUpdate(ctx, jwtAudience, jwtExtraAudiences, jwtSVIDFilename)
jwtSVIDs, err := s.performJWTSVIDUpdate(ctx, jwtAudience, jwtExtraAudiences, jwtSVIDFilename)
if err != nil {
// If the first update fails, use the retry interval
initialInterval = retryInterval()
} else {
// If the update succeeds, use the refresh interval
initialInterval = getRefreshInterval(jwtSVID)
initialInterval = getRefreshInterval(jwtSVIDs[0])
}
ticker := time.NewTicker(initialInterval)
defer ticker.Stop()
Expand All @@ -328,10 +330,10 @@ func (s *Sidecar) updateJWTSVID(ctx context.Context, jwtAudience string, jwtExtr
case <-ctx.Done():
return
case <-ticker.C:
jwtSVID, err = s.performJWTSVIDUpdate(ctx, jwtAudience, jwtExtraAudiences, jwtSVIDFilename)
jwtSVIDs, err = s.performJWTSVIDUpdate(ctx, jwtAudience, jwtExtraAudiences, jwtSVIDFilename)
if err == nil {
retryInterval = createRetryIntervalFunc()
ticker.Reset(getRefreshInterval(jwtSVID))
ticker.Reset(getRefreshInterval(jwtSVIDs[0]))
} else {
ticker.Reset(retryInterval())
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/sidecar/workloadapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (s *Sidecar) fetchAndWriteX509Context(ctx context.Context) error {
return err
}

return disk.WriteX509Context(x509Context, s.config.AddIntermediatesToBundle, s.config.IncludeFederatedDomains, s.config.CertDir, s.config.SVIDFileName, s.config.SVIDKeyFileName, s.config.SVIDBundleFileName, s.config.CertFileMode, s.config.KeyFileMode)
return disk.WriteX509Context(x509Context, s.config.AddIntermediatesToBundle, s.config.IncludeFederatedDomains, s.config.CertDir, s.config.SVIDFileName, s.config.SVIDKeyFileName, s.config.SVIDBundleFileName, s.config.CertFileMode, s.config.KeyFileMode, s.config.Hint)
}

func (s *Sidecar) fetchAndWriteJWTBundle(ctx context.Context) error {
Expand Down Expand Up @@ -71,18 +71,18 @@ func (s *Sidecar) fetchAndWriteJWTSVIDs(ctx context.Context) error {
}

func (s *Sidecar) fetchAndWriteJWTSVID(ctx context.Context, audience, jwtSVIDFilename string) error {
var jwtSVID *jwtsvid.SVID
var jwtSVIDs []*jwtsvid.SVID

// Retry PermissionDenied errors. We may get a few of these before the cert is minted
err := retry.OnError(backoff, func(err error) bool {
return status.Code(err) == codes.PermissionDenied
}, func() (err error) {
jwtSVID, err = s.jwtSource.FetchJWTSVID(ctx, jwtsvid.Params{Audience: audience})
jwtSVIDs, err = s.jwtSource.FetchJWTSVIDs(ctx, jwtsvid.Params{Audience: audience})
return err
})
if err != nil {
return err
}

return disk.WriteJWTSVID(jwtSVID, s.config.CertDir, jwtSVIDFilename, s.config.JWTSVIDFileMode)
return disk.WriteJWTSVID(jwtSVIDs, s.config.CertDir, jwtSVIDFilename, s.config.JWTSVIDFileMode, s.config.Hint)
}
Loading