@@ -10,14 +10,18 @@ import (
10
10
"net/http"
11
11
"net/http/pprof"
12
12
"os"
13
+ "os/signal"
14
+ "path/filepath"
13
15
"runtime"
14
16
"strings"
17
+ "sync/atomic"
15
18
"syscall"
16
19
"time"
17
20
18
21
gosundheit "github.com/AppsFlyer/go-sundheit"
19
22
"github.com/AppsFlyer/go-sundheit/checks"
20
23
gosundheithttp "github.com/AppsFlyer/go-sundheit/http"
24
+ "github.com/fsnotify/fsnotify"
21
25
"github.com/ghodss/yaml"
22
26
grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
23
27
"github.com/oklog/run"
@@ -142,41 +146,26 @@ func runServe(options serveOptions) error {
142
146
}
143
147
144
148
if c .GRPC .TLSCert != "" {
145
- // Parse certificates from certificate file and key file for server.
146
- cert , err := tls .LoadX509KeyPair (c .GRPC .TLSCert , c .GRPC .TLSKey )
147
- if err != nil {
148
- return fmt .Errorf ("invalid config: error parsing gRPC certificate file: %v" , err )
149
- }
150
-
151
- tlsConfig := tls.Config {
152
- Certificates : []tls.Certificate {cert },
149
+ baseTLSConfig := tls.Config {
153
150
MinVersion : tls .VersionTLS12 ,
154
151
CipherSuites : allowedTLSCiphers ,
155
152
PreferServerCipherSuites : true ,
156
153
}
157
154
158
- if c .GRPC .TLSClientCA != "" {
159
- // Parse certificates from client CA file to a new CertPool.
160
- cPool := x509 .NewCertPool ()
161
- clientCert , err := os .ReadFile (c .GRPC .TLSClientCA )
162
- if err != nil {
163
- return fmt .Errorf ("invalid config: reading from client CA file: %v" , err )
164
- }
165
- if ! cPool .AppendCertsFromPEM (clientCert ) {
166
- return errors .New ("invalid config: failed to parse client CA" )
167
- }
168
-
169
- tlsConfig .ClientAuth = tls .RequireAndVerifyClientCert
170
- tlsConfig .ClientCAs = cPool
155
+ tlsConfig , err := newTLSReloader (logger , c .GRPC .TLSCert , c .GRPC .TLSKey , c .GRPC .TLSClientCA , baseTLSConfig )
156
+ if err != nil {
157
+ return fmt .Errorf ("invalid config: get gRPC TLS: %v" , err )
158
+ }
171
159
160
+ if c .GRPC .TLSClientCA != "" {
172
161
// Only add metrics if client auth is enabled
173
162
grpcOptions = append (grpcOptions ,
174
163
grpc .StreamInterceptor (grpcMetrics .StreamServerInterceptor ()),
175
164
grpc .UnaryInterceptor (grpcMetrics .UnaryServerInterceptor ()),
176
165
)
177
166
}
178
167
179
- grpcOptions = append (grpcOptions , grpc .Creds (credentials .NewTLS (& tlsConfig )))
168
+ grpcOptions = append (grpcOptions , grpc .Creds (credentials .NewTLS (tlsConfig )))
180
169
}
181
170
182
171
s , err := c .Storage .Config .Open (logger )
@@ -431,18 +420,25 @@ func runServe(options serveOptions) error {
431
420
return fmt .Errorf ("listening (%s) on %s: %v" , name , c .Web .HTTPS , err )
432
421
}
433
422
423
+ baseTLSConfig := tls.Config {
424
+ MinVersion : tls .VersionTLS12 ,
425
+ CipherSuites : allowedTLSCiphers ,
426
+ PreferServerCipherSuites : true ,
427
+ }
428
+
429
+ tlsConfig , err := newTLSReloader (logger , c .Web .TLSCert , c .Web .TLSKey , "" , baseTLSConfig )
430
+ if err != nil {
431
+ return fmt .Errorf ("invalid config: get HTTP TLS: %v" , err )
432
+ }
433
+
434
434
server := & http.Server {
435
- Handler : serv ,
436
- TLSConfig : & tls.Config {
437
- CipherSuites : allowedTLSCiphers ,
438
- PreferServerCipherSuites : true ,
439
- MinVersion : tls .VersionTLS12 ,
440
- },
435
+ Handler : serv ,
436
+ TLSConfig : tlsConfig ,
441
437
}
442
438
defer server .Close ()
443
439
444
440
group .Add (func () error {
445
- return server .ServeTLS (l , c . Web . TLSCert , c . Web . TLSKey )
441
+ return server .ServeTLS (l , "" , "" )
446
442
}, func (err error ) {
447
443
ctx , cancel := context .WithTimeout (context .Background (), time .Minute )
448
444
defer cancel ()
@@ -563,3 +559,112 @@ func pprofHandler(router *http.ServeMux) {
563
559
router .HandleFunc ("/debug/pprof/symbol" , pprof .Symbol )
564
560
router .HandleFunc ("/debug/pprof/trace" , pprof .Trace )
565
561
}
562
+
563
+ // newTLSReloader returns a [tls.Config] with GetCertificate or GetConfigForClient set
564
+ // to reload certificates from the given paths on SIGHUP or on file creates (atomic update via rename).
565
+ func newTLSReloader (logger log.Logger , certFile , keyFile , caFile string , baseConfig tls.Config ) (* tls.Config , error ) {
566
+ // trigger reload on channel
567
+ sigc := make (chan os.Signal )
568
+ signal .Notify (sigc , syscall .SIGHUP )
569
+
570
+ // files to watch
571
+ watchFiles := []string {certFile , keyFile }
572
+ if caFile != "" {
573
+ watchFiles = append (watchFiles , caFile )
574
+ }
575
+ watchDirs := make (map [string ]struct {}) // dedupe dirs
576
+ for i , f := range watchFiles {
577
+ dir := filepath .Dir (f )
578
+ if ! strings .HasPrefix (f , dir ) {
579
+ // normalize name to have ./ prefix if only a local path was provided
580
+ // can't pass "" to watcher.Add
581
+ watchFiles [i ] = dir + string (filepath .Separator ) + f
582
+ }
583
+ watchDirs [dir ] = struct {}{}
584
+ }
585
+ // trigger reload on file change
586
+ watcher , err := fsnotify .NewWatcher ()
587
+ if err != nil {
588
+ return nil , fmt .Errorf ("create watcher for TLS reloader: %v" , err )
589
+ }
590
+ // recommended by fsnotify: watch the dir to handle renames
591
+ for dir := range watchDirs {
592
+ logger .Debugf ("watching dir: %v" , dir )
593
+ err := watcher .Add (dir )
594
+ if err != nil {
595
+ return nil , fmt .Errorf ("watch dir for TLS reloader: %v" , err )
596
+ }
597
+ }
598
+
599
+ // load once outside the goroutine so we can return an error on misconfig
600
+ initialConfig , err := loadTLSConfig (certFile , keyFile , caFile , baseConfig )
601
+ if err != nil {
602
+ return nil , fmt .Errorf ("load TLS config: %v" , err )
603
+ }
604
+
605
+ // stored version of current tls config
606
+ ptr := & atomic.Pointer [tls.Config ]{}
607
+ ptr .Store (initialConfig )
608
+
609
+ // start background worker to reload certs
610
+ go func () {
611
+ loop:
612
+ for {
613
+ select {
614
+ case sig := <- sigc :
615
+ logger .Debug ("reloading cert from signal: %v" , sig )
616
+ case evt := <- watcher .Events :
617
+ var found bool
618
+ for _ , f := range watchFiles {
619
+ if evt .Name == f {
620
+ found = true
621
+ }
622
+ }
623
+ if ! found || ! evt .Has (fsnotify .Create ) {
624
+ continue loop
625
+ }
626
+ logger .Debug ("reloading cert from fsnotify: %v %v" , evt .Name , evt .Op .String ())
627
+ case err := <- watcher .Errors :
628
+ logger .Errorf ("TLS reloader watch: %v" , err )
629
+ }
630
+
631
+ loaded , err := loadTLSConfig (certFile , keyFile , caFile , baseConfig )
632
+ if err != nil {
633
+ logger .Errorf ("reload TLS config: %v" , err )
634
+ }
635
+ ptr .Store (loaded )
636
+ }
637
+ }()
638
+
639
+ conf := & tls.Config {}
640
+ if caFile != "" {
641
+ conf .GetConfigForClient = func (chi * tls.ClientHelloInfo ) (* tls.Config , error ) { return ptr .Load (), nil }
642
+ } else {
643
+ conf .GetCertificate = func (chi * tls.ClientHelloInfo ) (* tls.Certificate , error ) { return & ptr .Load ().Certificates [0 ], nil }
644
+ }
645
+ return conf , nil
646
+ }
647
+
648
+ // loadTLSConfig loads the given file paths into a [tls.Config]
649
+ func loadTLSConfig (certFile , keyFile , caFile string , baseConfig tls.Config ) (* tls.Config , error ) {
650
+ cert , err := tls .LoadX509KeyPair (certFile , keyFile )
651
+ if err != nil {
652
+ return nil , fmt .Errorf ("loading TLS keypair: %v" , err )
653
+ }
654
+ loadedConfig := baseConfig // copy
655
+ loadedConfig .Certificates = []tls.Certificate {cert }
656
+ if caFile != "" {
657
+ cPool := x509 .NewCertPool ()
658
+ clientCert , err := os .ReadFile (caFile )
659
+ if err != nil {
660
+ return nil , fmt .Errorf ("reading from client CA file: %v" , err )
661
+ }
662
+ if ! cPool .AppendCertsFromPEM (clientCert ) {
663
+ return nil , errors .New ("failed to parse client CA" )
664
+ }
665
+
666
+ loadedConfig .ClientAuth = tls .RequireAndVerifyClientCert
667
+ loadedConfig .ClientCAs = cPool
668
+ }
669
+ return & loadedConfig , nil
670
+ }
0 commit comments