From f10e47f7281a41a5fa94c2ee74372796869810b4 Mon Sep 17 00:00:00 2001 From: Sean Liao Date: Sat, 20 May 2023 23:26:38 +0100 Subject: [PATCH] tls cert reloader Signed-off-by: Sean Liao --- cmd/dex/serve.go | 164 ++++++++++++++++++++++++++++++++++++++--------- go.mod | 1 + go.sum | 3 + 3 files changed, 139 insertions(+), 29 deletions(-) diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index c8fb95eb16..d7746440c1 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -10,14 +10,18 @@ import ( "net/http" "net/http/pprof" "os" + "os/signal" + "path/filepath" "runtime" "strings" + "sync/atomic" "syscall" "time" gosundheit "github.com/AppsFlyer/go-sundheit" "github.com/AppsFlyer/go-sundheit/checks" gosundheithttp "github.com/AppsFlyer/go-sundheit/http" + "github.com/fsnotify/fsnotify" "github.com/ghodss/yaml" grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "github.com/oklog/run" @@ -142,33 +146,18 @@ func runServe(options serveOptions) error { } if c.GRPC.TLSCert != "" { - // Parse certificates from certificate file and key file for server. - cert, err := tls.LoadX509KeyPair(c.GRPC.TLSCert, c.GRPC.TLSKey) - if err != nil { - return fmt.Errorf("invalid config: error parsing gRPC certificate file: %v", err) - } - - tlsConfig := tls.Config{ - Certificates: []tls.Certificate{cert}, + baseTLSConfig := &tls.Config{ MinVersion: tls.VersionTLS12, CipherSuites: allowedTLSCiphers, PreferServerCipherSuites: true, } - if c.GRPC.TLSClientCA != "" { - // Parse certificates from client CA file to a new CertPool. - cPool := x509.NewCertPool() - clientCert, err := os.ReadFile(c.GRPC.TLSClientCA) - if err != nil { - return fmt.Errorf("invalid config: reading from client CA file: %v", err) - } - if !cPool.AppendCertsFromPEM(clientCert) { - return errors.New("invalid config: failed to parse client CA") - } - - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - tlsConfig.ClientCAs = cPool + tlsConfig, err := newTLSReloader(logger, c.GRPC.TLSCert, c.GRPC.TLSKey, c.GRPC.TLSClientCA, baseTLSConfig) + if err != nil { + return fmt.Errorf("invalid config: get gRPC TLS: %v", err) + } + if c.GRPC.TLSClientCA != "" { // Only add metrics if client auth is enabled grpcOptions = append(grpcOptions, grpc.StreamInterceptor(grpcMetrics.StreamServerInterceptor()), @@ -176,7 +165,7 @@ func runServe(options serveOptions) error { ) } - grpcOptions = append(grpcOptions, grpc.Creds(credentials.NewTLS(&tlsConfig))) + grpcOptions = append(grpcOptions, grpc.Creds(credentials.NewTLS(tlsConfig))) } s, err := c.Storage.Config.Open(logger) @@ -431,18 +420,25 @@ func runServe(options serveOptions) error { return fmt.Errorf("listening (%s) on %s: %v", name, c.Web.HTTPS, err) } + baseTLSConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + CipherSuites: allowedTLSCiphers, + PreferServerCipherSuites: true, + } + + tlsConfig, err := newTLSReloader(logger, c.Web.TLSCert, c.Web.TLSKey, "", baseTLSConfig) + if err != nil { + return fmt.Errorf("invalid config: get HTTP TLS: %v", err) + } + server := &http.Server{ - Handler: serv, - TLSConfig: &tls.Config{ - CipherSuites: allowedTLSCiphers, - PreferServerCipherSuites: true, - MinVersion: tls.VersionTLS12, - }, + Handler: serv, + TLSConfig: tlsConfig, } defer server.Close() group.Add(func() error { - return server.ServeTLS(l, c.Web.TLSCert, c.Web.TLSKey) + return server.ServeTLS(l, "", "") }, func(err error) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() @@ -563,3 +559,113 @@ func pprofHandler(router *http.ServeMux) { router.HandleFunc("/debug/pprof/symbol", pprof.Symbol) router.HandleFunc("/debug/pprof/trace", pprof.Trace) } + +// newTLSReloader returns a [tls.Config] with GetCertificate or GetConfigForClient set +// to reload certificates from the given paths on SIGHUP or on file creates (atomic update via rename). +func newTLSReloader(logger log.Logger, certFile, keyFile, caFile string, baseConfig *tls.Config) (*tls.Config, error) { + // trigger reload on channel + sigc := make(chan os.Signal, 1) + signal.Notify(sigc, syscall.SIGHUP) + + // files to watch + watchFiles := map[string]struct{}{ + certFile: {}, + keyFile: {}, + } + if caFile != "" { + watchFiles[caFile] = struct{}{} + } + watchDirs := make(map[string]struct{}) // dedupe dirs + for f := range watchFiles { + dir := filepath.Dir(f) + if !strings.HasPrefix(f, dir) { + // normalize name to have ./ prefix if only a local path was provided + // can't pass "" to watcher.Add + watchFiles[dir+string(filepath.Separator)+f] = struct{}{} + } + watchDirs[dir] = struct{}{} + } + // trigger reload on file change + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, fmt.Errorf("create watcher for TLS reloader: %v", err) + } + // recommended by fsnotify: watch the dir to handle renames + // https://pkg.go.dev/github.com/fsnotify/fsnotify#hdr-Watching_files + for dir := range watchDirs { + logger.Debugf("watching dir: %v", dir) + err := watcher.Add(dir) + if err != nil { + return nil, fmt.Errorf("watch dir for TLS reloader: %v", err) + } + } + + // load once outside the goroutine so we can return an error on misconfig + initialConfig, err := loadTLSConfig(certFile, keyFile, caFile, baseConfig) + if err != nil { + return nil, fmt.Errorf("load TLS config: %v", err) + } + + // stored version of current tls config + ptr := &atomic.Pointer[tls.Config]{} + ptr.Store(initialConfig) + + // start background worker to reload certs + go func() { + loop: + for { + select { + case sig := <-sigc: + logger.Debug("reloading cert from signal: %v", sig) + case evt := <-watcher.Events: + _, ok := watchFiles[evt.Name] + if !ok || !evt.Has(fsnotify.Create) { + continue loop + } + logger.Debug("reloading cert from fsnotify: %v %v", evt.Name, evt.Op.String()) + case err := <-watcher.Errors: + logger.Errorf("TLS reloader watch: %v", err) + } + + loaded, err := loadTLSConfig(certFile, keyFile, caFile, baseConfig) + if err != nil { + logger.Errorf("reload TLS config: %v", err) + } + ptr.Store(loaded) + } + }() + + conf := &tls.Config{ + // net/http only uses Certificates or GetCertificate + GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { return &ptr.Load().Certificates[0], nil }, + } + if caFile != "" { + // grpc will use this via tls.Server for mTLS + conf.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) { return ptr.Load(), nil } + } + return conf, nil +} + +// loadTLSConfig loads the given file paths into a [tls.Config] +func loadTLSConfig(certFile, keyFile, caFile string, baseConfig *tls.Config) (*tls.Config, error) { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, fmt.Errorf("loading TLS keypair: %v", err) + } + loadedConfig := baseConfig.Clone() // copy + loadedConfig.Certificates = []tls.Certificate{cert} + if caFile != "" { + cPool := x509.NewCertPool() + clientCert, err := os.ReadFile(caFile) + if err != nil { + return nil, fmt.Errorf("reading from client CA file: %v", err) + } + if !cPool.AppendCertsFromPEM(clientCert) { + return nil, errors.New("failed to parse client CA") + } + + loadedConfig.ClientAuth = tls.RequireAndVerifyClientCert + loadedConfig.ClientCAs = cPool + } + return loadedConfig, nil +} diff --git a/go.mod b/go.mod index efa89aae99..7cb1039c43 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/coreos/go-oidc/v3 v3.5.0 github.com/dexidp/dex/api/v2 v2.1.0 github.com/felixge/httpsnoop v1.0.3 + github.com/fsnotify/fsnotify v1.6.0 github.com/ghodss/yaml v1.0.0 github.com/go-ldap/ldap/v3 v3.4.4 github.com/go-sql-driver/mysql v1.7.0 diff --git a/go.sum b/go.sum index e4bee9b4ae..596367c9aa 100644 --- a/go.sum +++ b/go.sum @@ -57,6 +57,8 @@ github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBd github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= +github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= +github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-asn1-ber/asn1-ber v1.5.4 h1:vXT6d/FNDiELJnLb6hGNa309LMsrCoYFvpwHDF0+Y1A= @@ -286,6 +288,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=