Skip to content

Commit 209fd8b

Browse files
committed
tls cert reloader
Signed-off-by: Sean Liao <[email protected]>
1 parent fda87ac commit 209fd8b

File tree

3 files changed

+138
-29
lines changed

3 files changed

+138
-29
lines changed

Diff for: cmd/dex/serve.go

+134-29
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,18 @@ import (
1010
"net/http"
1111
"net/http/pprof"
1212
"os"
13+
"os/signal"
14+
"path/filepath"
1315
"runtime"
1416
"strings"
17+
"sync/atomic"
1518
"syscall"
1619
"time"
1720

1821
gosundheit "github.com/AppsFlyer/go-sundheit"
1922
"github.com/AppsFlyer/go-sundheit/checks"
2023
gosundheithttp "github.com/AppsFlyer/go-sundheit/http"
24+
"github.com/fsnotify/fsnotify"
2125
"github.com/ghodss/yaml"
2226
grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
2327
"github.com/oklog/run"
@@ -142,41 +146,26 @@ func runServe(options serveOptions) error {
142146
}
143147

144148
if c.GRPC.TLSCert != "" {
145-
// Parse certificates from certificate file and key file for server.
146-
cert, err := tls.LoadX509KeyPair(c.GRPC.TLSCert, c.GRPC.TLSKey)
147-
if err != nil {
148-
return fmt.Errorf("invalid config: error parsing gRPC certificate file: %v", err)
149-
}
150-
151-
tlsConfig := tls.Config{
152-
Certificates: []tls.Certificate{cert},
149+
baseTLSConfig := tls.Config{
153150
MinVersion: tls.VersionTLS12,
154151
CipherSuites: allowedTLSCiphers,
155152
PreferServerCipherSuites: true,
156153
}
157154

158-
if c.GRPC.TLSClientCA != "" {
159-
// Parse certificates from client CA file to a new CertPool.
160-
cPool := x509.NewCertPool()
161-
clientCert, err := os.ReadFile(c.GRPC.TLSClientCA)
162-
if err != nil {
163-
return fmt.Errorf("invalid config: reading from client CA file: %v", err)
164-
}
165-
if !cPool.AppendCertsFromPEM(clientCert) {
166-
return errors.New("invalid config: failed to parse client CA")
167-
}
168-
169-
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
170-
tlsConfig.ClientCAs = cPool
155+
tlsConfig, err := newTLSReloader(logger, c.GRPC.TLSCert, c.GRPC.TLSKey, c.GRPC.TLSClientCA, baseTLSConfig)
156+
if err != nil {
157+
return fmt.Errorf("invalid config: get gRPC TLS: %v", err)
158+
}
171159

160+
if c.GRPC.TLSClientCA != "" {
172161
// Only add metrics if client auth is enabled
173162
grpcOptions = append(grpcOptions,
174163
grpc.StreamInterceptor(grpcMetrics.StreamServerInterceptor()),
175164
grpc.UnaryInterceptor(grpcMetrics.UnaryServerInterceptor()),
176165
)
177166
}
178167

179-
grpcOptions = append(grpcOptions, grpc.Creds(credentials.NewTLS(&tlsConfig)))
168+
grpcOptions = append(grpcOptions, grpc.Creds(credentials.NewTLS(tlsConfig)))
180169
}
181170

182171
s, err := c.Storage.Config.Open(logger)
@@ -431,18 +420,25 @@ func runServe(options serveOptions) error {
431420
return fmt.Errorf("listening (%s) on %s: %v", name, c.Web.HTTPS, err)
432421
}
433422

423+
baseTLSConfig := tls.Config{
424+
MinVersion: tls.VersionTLS12,
425+
CipherSuites: allowedTLSCiphers,
426+
PreferServerCipherSuites: true,
427+
}
428+
429+
tlsConfig, err := newTLSReloader(logger, c.Web.TLSCert, c.Web.TLSKey, "", baseTLSConfig)
430+
if err != nil {
431+
return fmt.Errorf("invalid config: get HTTP TLS: %v", err)
432+
}
433+
434434
server := &http.Server{
435-
Handler: serv,
436-
TLSConfig: &tls.Config{
437-
CipherSuites: allowedTLSCiphers,
438-
PreferServerCipherSuites: true,
439-
MinVersion: tls.VersionTLS12,
440-
},
435+
Handler: serv,
436+
TLSConfig: tlsConfig,
441437
}
442438
defer server.Close()
443439

444440
group.Add(func() error {
445-
return server.ServeTLS(l, c.Web.TLSCert, c.Web.TLSKey)
441+
return server.ServeTLS(l, "", "")
446442
}, func(err error) {
447443
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
448444
defer cancel()
@@ -563,3 +559,112 @@ func pprofHandler(router *http.ServeMux) {
563559
router.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
564560
router.HandleFunc("/debug/pprof/trace", pprof.Trace)
565561
}
562+
563+
// newTLSReloader returns a [tls.Config] with GetCertificate or GetConfigForClient set
564+
// to reload certificates from the given paths on SIGHUP or on file creates (atomic update via rename).
565+
func newTLSReloader(logger log.Logger, certFile, keyFile, caFile string, baseConfig tls.Config) (*tls.Config, error) {
566+
// trigger reload on channel
567+
sigc := make(chan os.Signal)
568+
signal.Notify(sigc, syscall.SIGHUP)
569+
570+
// files to watch
571+
watchFiles := []string{certFile, keyFile}
572+
if caFile != "" {
573+
watchFiles = append(watchFiles, caFile)
574+
}
575+
watchDirs := make(map[string]struct{}) // dedupe dirs
576+
for i, f := range watchFiles {
577+
dir := filepath.Dir(f)
578+
if !strings.HasPrefix(f, dir) {
579+
// normalize name to have ./ prefix if only a local path was provided
580+
// can't pass "" to watcher.Add
581+
watchFiles[i] = dir + string(filepath.Separator) + f
582+
}
583+
watchDirs[dir] = struct{}{}
584+
}
585+
// trigger reload on file change
586+
watcher, err := fsnotify.NewWatcher()
587+
if err != nil {
588+
return nil, fmt.Errorf("create watcher for TLS reloader: %v", err)
589+
}
590+
// recommended by fsnotify: watch the dir to handle renames
591+
for dir := range watchDirs {
592+
logger.Debugf("watching dir: %v", dir)
593+
err := watcher.Add(dir)
594+
if err != nil {
595+
return nil, fmt.Errorf("watch dir for TLS reloader: %v", err)
596+
}
597+
}
598+
599+
// load once outside the goroutine so we can return an error on misconfig
600+
initialConfig, err := loadTLSConfig(certFile, keyFile, caFile, baseConfig)
601+
if err != nil {
602+
return nil, fmt.Errorf("load TLS config: %v", err)
603+
}
604+
605+
// stored version of current tls config
606+
ptr := &atomic.Pointer[tls.Config]{}
607+
ptr.Store(initialConfig)
608+
609+
// start background worker to reload certs
610+
go func() {
611+
loop:
612+
for {
613+
select {
614+
case sig := <-sigc:
615+
logger.Debug("reloading cert from signal: %v", sig)
616+
case evt := <-watcher.Events:
617+
var found bool
618+
for _, f := range watchFiles {
619+
if evt.Name == f {
620+
found = true
621+
}
622+
}
623+
if !found || !evt.Has(fsnotify.Create) {
624+
continue loop
625+
}
626+
logger.Debug("reloading cert from fsnotify: %v %v", evt.Name, evt.Op.String())
627+
case err := <-watcher.Errors:
628+
logger.Errorf("TLS reloader watch: %v", err)
629+
}
630+
631+
loaded, err := loadTLSConfig(certFile, keyFile, caFile, baseConfig)
632+
if err != nil {
633+
logger.Errorf("reload TLS config: %v", err)
634+
}
635+
ptr.Store(loaded)
636+
}
637+
}()
638+
639+
conf := &tls.Config{}
640+
if caFile != "" {
641+
conf.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) { return ptr.Load(), nil }
642+
} else {
643+
conf.GetCertificate = func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { return &ptr.Load().Certificates[0], nil }
644+
}
645+
return conf, nil
646+
}
647+
648+
// loadTLSConfig loads the given file paths into a [tls.Config]
649+
func loadTLSConfig(certFile, keyFile, caFile string, baseConfig tls.Config) (*tls.Config, error) {
650+
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
651+
if err != nil {
652+
return nil, fmt.Errorf("loading TLS keypair: %v", err)
653+
}
654+
loadedConfig := baseConfig // copy
655+
loadedConfig.Certificates = []tls.Certificate{cert}
656+
if caFile != "" {
657+
cPool := x509.NewCertPool()
658+
clientCert, err := os.ReadFile(caFile)
659+
if err != nil {
660+
return nil, fmt.Errorf("reading from client CA file: %v", err)
661+
}
662+
if !cPool.AppendCertsFromPEM(clientCert) {
663+
return nil, errors.New("failed to parse client CA")
664+
}
665+
666+
loadedConfig.ClientAuth = tls.RequireAndVerifyClientCert
667+
loadedConfig.ClientCAs = cPool
668+
}
669+
return &loadedConfig, nil
670+
}

Diff for: go.mod

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ require (
1111
github.com/coreos/go-oidc/v3 v3.5.0
1212
github.com/dexidp/dex/api/v2 v2.1.0
1313
github.com/felixge/httpsnoop v1.0.3
14+
github.com/fsnotify/fsnotify v1.6.0
1415
github.com/ghodss/yaml v1.0.0
1516
github.com/go-ldap/ldap/v3 v3.4.4
1617
github.com/go-sql-driver/mysql v1.7.0

Diff for: go.sum

+3
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBd
5757
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
5858
github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
5959
github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
60+
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
61+
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
6062
github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
6163
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
6264
github.com/go-asn1-ber/asn1-ber v1.5.4 h1:vXT6d/FNDiELJnLb6hGNa309LMsrCoYFvpwHDF0+Y1A=
@@ -286,6 +288,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
286288
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
287289
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
288290
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
291+
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
289292
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
290293
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
291294
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=

0 commit comments

Comments
 (0)