Skip to content

Commit

Permalink
Check if TLS certificate and key file have been modified (#345)
Browse files Browse the repository at this point in the history
* Check hash of cert and key file

Signed-off-by: Levi Harrison <[email protected]>

Signed-off-by: Simon Pasquier <[email protected]>
  • Loading branch information
LeviHarrison authored Nov 3, 2022
1 parent 54e041d commit 1c0fa3e
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 22 deletions.
89 changes: 69 additions & 20 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
return newRT(tlsConfig)
}

return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, newRT)
return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, cfg.TLSConfig.CertFile, cfg.TLSConfig.KeyFile, newRT)
}

type authorizationCredentialsRoundTripper struct {
Expand Down Expand Up @@ -709,7 +709,7 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
if len(rt.config.TLSConfig.CAFile) == 0 {
t, _ = tlsTransport(tlsConfig)
} else {
t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, tlsTransport)
t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, rt.config.TLSConfig.CertFile, rt.config.TLSConfig.KeyFile, tlsTransport)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -838,12 +838,39 @@ func (c *TLSConfig) SetDirectory(dir string) {
c.KeyFile = JoinDir(dir, c.KeyFile)
}

// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain TLSConfig
return unmarshal((*plain)(c))
}

// readCertAndKey reads the cert and key files from the disk.
func readCertAndKey(certFile, keyFile string) ([]byte, []byte, error) {
certData, err := ioutil.ReadFile(certFile)
if err != nil {
return nil, nil, err
}

keyData, err := ioutil.ReadFile(keyFile)
if err != nil {
return nil, nil, err
}

return certData, keyData, nil
}

// getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate.
func (c *TLSConfig) getClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
func (c *TLSConfig) getClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
certData, keyData, err := readCertAndKey(c.CertFile, c.KeyFile)
if err != nil {
return nil, fmt.Errorf("unable to read specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err)
}

cert, err := tls.X509KeyPair(certData, keyData)
if err != nil {
return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err)
}

return &cert, nil
}

Expand All @@ -869,23 +896,30 @@ func updateRootCA(cfg *tls.Config, b []byte) bool {
// tlsRoundTripper is a RoundTripper that updates automatically its TLS
// configuration whenever the content of the CA file changes.
type tlsRoundTripper struct {
caFile string
caFile string
certFile string
keyFile string

// newRT returns a new RoundTripper.
newRT func(*tls.Config) (http.RoundTripper, error)

mtx sync.RWMutex
rt http.RoundTripper
hashCAFile []byte
tlsConfig *tls.Config
mtx sync.RWMutex
rt http.RoundTripper
hashCAFile []byte
hashCertFile []byte
hashKeyFile []byte
tlsConfig *tls.Config
}

func NewTLSRoundTripper(
cfg *tls.Config,
caFile string,
caFile, certFile, keyFile string,
newRT func(*tls.Config) (http.RoundTripper, error),
) (http.RoundTripper, error) {
t := &tlsRoundTripper{
caFile: caFile,
certFile: certFile,
keyFile: keyFile,
newRT: newRT,
tlsConfig: cfg,
}
Expand All @@ -895,33 +929,44 @@ func NewTLSRoundTripper(
return nil, err
}
t.rt = rt
_, t.hashCAFile, err = t.getCAWithHash()
_, t.hashCAFile, t.hashCertFile, t.hashKeyFile, err = t.getTLSFilesWithHash()
if err != nil {
return nil, err
}

return t, nil
}

func (t *tlsRoundTripper) getCAWithHash() ([]byte, []byte, error) {
b, err := readCAFile(t.caFile)
func (t *tlsRoundTripper) getTLSFilesWithHash() ([]byte, []byte, []byte, []byte, error) {
b1, err := readCAFile(t.caFile)
if err != nil {
return nil, nil, err
return nil, nil, nil, nil, err
}
h1 := sha256.Sum256(b1)

var h2, h3 [32]byte
if t.certFile != "" {
b2, b3, err := readCertAndKey(t.certFile, t.keyFile)
if err != nil {
return nil, nil, nil, nil, err
}
h2, h3 = sha256.Sum256(b2), sha256.Sum256(b3)
}
h := sha256.Sum256(b)
return b, h[:], nil

return b1, h1[:], h2[:], h3[:], nil
}

// RoundTrip implements the http.RoundTrip interface.
func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
b, h, err := t.getCAWithHash()
caData, caHash, certHash, keyHash, err := t.getTLSFilesWithHash()
if err != nil {
return nil, err
}

t.mtx.RLock()
equal := bytes.Equal(h[:], t.hashCAFile)
equal := bytes.Equal(caHash[:], t.hashCAFile) &&
bytes.Equal(certHash[:], t.hashCertFile) &&
bytes.Equal(keyHash[:], t.hashKeyFile)
rt := t.rt
t.mtx.RUnlock()
if equal {
Expand All @@ -930,8 +975,10 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
}

// Create a new RoundTripper.
// The cert and key files are read separately by the client
// using GetClientCertificate.
tlsConfig := t.tlsConfig.Clone()
if !updateRootCA(tlsConfig, b) {
if !updateRootCA(tlsConfig, caData) {
return nil, fmt.Errorf("unable to use specified CA cert %s", t.caFile)
}
rt, err = t.newRT(tlsConfig)
Expand All @@ -942,7 +989,9 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {

t.mtx.Lock()
t.rt = rt
t.hashCAFile = h[:]
t.hashCAFile = caHash[:]
t.hashCertFile = certHash[:]
t.hashKeyFile = keyHash[:]
t.mtx.Unlock()

return rt.RoundTrip(req)
Expand Down
117 changes: 115 additions & 2 deletions config/http_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,15 +718,15 @@ func TestTLSConfigInvalidCA(t *testing.T) {
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: false},
errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", MissingCert, ClientKeyNoPassPath),
errorMessage: fmt.Sprintf("unable to read specified client cert (%s) & key (%s):", MissingCert, ClientKeyNoPassPath),
}, {
configTLSConfig: TLSConfig{
CAFile: "",
CertFile: ClientCertificatePath,
KeyFile: MissingKey,
ServerName: "",
InsecureSkipVerify: false},
errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", ClientCertificatePath, MissingKey),
errorMessage: fmt.Sprintf("unable to read specified client cert (%s) & key (%s):", ClientCertificatePath, MissingKey),
},
}

Expand Down Expand Up @@ -1532,3 +1532,116 @@ func TestOAuth2Proxy(t *testing.T) {
t.Errorf("Error loading OAuth2 client config: %v", err)
}
}

func TestModifyTLSCertificates(t *testing.T) {
bs := getCertificateBlobs(t)

tmpDir, err := ioutil.TempDir("", "modifytlscertificates")
if err != nil {
t.Fatal("Failed to create tmp dir", err)
}
defer os.RemoveAll(tmpDir)
ca, cert, key := filepath.Join(tmpDir, "ca"), filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key")

handler := func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, ExpectedMessage)
}
testServer, err := newTestServer(handler)
if err != nil {
t.Fatal(err.Error())
}
defer testServer.Close()

tests := []struct {
ca string
cert string
key string

errMsg string

modification func()
}{
{
ca: ClientCertificatePath,
cert: ClientCertificatePath,
key: ClientKeyNoPassPath,

errMsg: "certificate signed by unknown authority",

modification: func() { writeCertificate(bs, TLSCAChainPath, ca) },
},
{
ca: TLSCAChainPath,
cert: WrongClientCertPath,
key: ClientKeyNoPassPath,

errMsg: "private key does not match public key",

modification: func() { writeCertificate(bs, ClientCertificatePath, cert) },
},
{
ca: TLSCAChainPath,
cert: ClientCertificatePath,
key: WrongClientCertPath,

errMsg: "found a certificate rather than a key in the PEM for the private key",

modification: func() { writeCertificate(bs, ClientKeyNoPassPath, key) },
},
}

cfg := HTTPClientConfig{
TLSConfig: TLSConfig{
CAFile: ca,
CertFile: cert,
KeyFile: key,
InsecureSkipVerify: false},
}

var c *http.Client
for i, tc := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
writeCertificate(bs, tc.ca, ca)
writeCertificate(bs, tc.cert, cert)
writeCertificate(bs, tc.key, key)
if c == nil {
c, err = NewClientFromConfig(cfg, "test")
if err != nil {
t.Fatalf("Error creating HTTP Client: %v", err)
}
}

req, err := http.NewRequest(http.MethodGet, testServer.URL, nil)
if err != nil {
t.Fatalf("Error creating HTTP request: %v", err)
}

r, err := c.Do(req)
if err == nil {
r.Body.Close()
t.Fatalf("Could connect to the test server.")
}
if !strings.Contains(err.Error(), tc.errMsg) {
t.Fatalf("Expected error message to contain %q, got %q", tc.errMsg, err)
}

tc.modification()

r, err = c.Do(req)
if err != nil {
t.Fatalf("Expected no error, got %q", err)
}

b, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
t.Errorf("Can't read the server response body")
}

got := strings.TrimSpace(string(b))
if ExpectedMessage != got {
t.Errorf("The expected message %q differs from the obtained message %q", ExpectedMessage, got)
}
})
}
}

0 comments on commit 1c0fa3e

Please sign in to comment.