From 8aedaf9635de9a84561563d6623ea312c197b07a Mon Sep 17 00:00:00 2001 From: Anton Miniailo Date: Tue, 30 Aug 2022 19:02:13 -0400 Subject: [PATCH] Add serialization of writes to `known_hosts` file. --- go.mod | 3 +- go.sum | 5 +- lib/client/keystore.go | 35 ++++++++++- lib/client/keystore_test.go | 120 ++++++++++++++++++++++++------------ 4 files changed, 120 insertions(+), 43 deletions(-) diff --git a/go.mod b/go.mod index 2bdfb93fa5b53..7787e2938414d 100644 --- a/go.mod +++ b/go.mod @@ -108,7 +108,7 @@ require ( golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f - golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 + golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 golang.org/x/text v0.3.7 golang.org/x/tools v0.1.6-0.20210820212750-d4cc65f0b2ff @@ -160,6 +160,7 @@ require ( github.com/go-asn1-ber/asn1-ber v1.5.1 // indirect github.com/go-logr/logr v1.2.3 // indirect github.com/go-stack/stack v1.8.0 // indirect + github.com/gofrs/flock v0.8.1 github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188 // indirect diff --git a/go.sum b/go.sum index e1890c8602de6..f72d2a501558e 100644 --- a/go.sum +++ b/go.sum @@ -373,6 +373,8 @@ github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGt github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= +github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= @@ -1380,8 +1382,9 @@ golang.org/x/sys v0.0.0-20210908233432-aa78b53d3365/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 h1:XfKQ4OlFl8okEOr5UvAqFRVj8pY/4yfcXrddB8qAbU0= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY= +golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= diff --git a/lib/client/keystore.go b/lib/client/keystore.go index 1e042fac35d25..0edc9d35cd081 100644 --- a/lib/client/keystore.go +++ b/lib/client/keystore.go @@ -18,6 +18,7 @@ package client import ( "bufio" + "context" "encoding/pem" "fmt" "io" @@ -26,9 +27,12 @@ import ( "os" "path/filepath" "strings" + "time" "golang.org/x/crypto/ssh" + "github.com/gofrs/flock" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/api/utils/keypaths" @@ -536,7 +540,7 @@ func (fs *fsLocalNonSessionKeyStore) sshCAsPath(idx KeyIndex) string { return keypaths.SSHCAsPath(fs.KeyDir, idx.ProxyHost, idx.Username) } -// appCertPath returns the TLS certificate path for the given KeyIndex and app name. +// appCertPath returns the TLS certificate path for the given KeyIndex and app name. func (fs *fsLocalNonSessionKeyStore) appCertPath(idx KeyIndex, appname string) string { return keypaths.AppCertPath(fs.KeyDir, idx.ProxyHost, idx.Username, idx.ClusterName, appname) } @@ -551,8 +555,29 @@ func (fs *fsLocalNonSessionKeyStore) kubeCertPath(idx KeyIndex, kubename string) return keypaths.KubeCertPath(fs.KeyDir, idx.ProxyHost, idx.Username, idx.ClusterName, kubename) } +// acquireFileLock is trying to lock the file, until it's successful or timeout is exceeded. +// File will be created if it doesn't exist. +func acquireFileLock(filePath string, timeout time.Duration) (func() error, error) { + fileLock := flock.New(filePath) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + if _, err := fileLock.TryLockContext(ctx, 10*time.Millisecond); err != nil { + return nil, err + } + + return fileLock.Unlock, nil +} + // AddKnownHostKeys adds a new entry to `known_hosts` file. func (fs *fsLocalNonSessionKeyStore) AddKnownHostKeys(hostname, proxyHost string, hostKeys []ssh.PublicKey) (retErr error) { + // We're trying to serialize our writes to the 'known_hosts' file to avoid corruption, since there + // are cases when multiple tsh instances will try to write to it. + unlock, err := acquireFileLock(fs.knownHostsPath(), 5*time.Second) + if err != nil { + return trace.WrapWithMessage(err, "could not acquire lock for the `known_hosts` file") + } + defer utils.StoreErrorOf(unlock, &retErr) + fp, err := os.OpenFile(fs.knownHostsPath(), os.O_CREATE|os.O_RDWR, 0640) if err != nil { return trace.ConvertSystemError(err) @@ -642,7 +667,13 @@ func matchesWildcard(hostname, pattern string) bool { } // GetKnownHostKeys returns all known public keys from `known_hosts`. -func (fs *fsLocalNonSessionKeyStore) GetKnownHostKeys(hostname string) ([]ssh.PublicKey, error) { +func (fs *fsLocalNonSessionKeyStore) GetKnownHostKeys(hostname string) (keys []ssh.PublicKey, retErr error) { + unlock, err := acquireFileLock(fs.knownHostsPath(), 5*time.Second) + if err != nil { + return nil, trace.WrapWithMessage(err, "could not acquire lock for the `known_hosts` file") + } + defer utils.StoreErrorOf(unlock, &retErr) + bytes, err := ioutil.ReadFile(fs.knownHostsPath()) if err != nil { if os.IsNotExist(err) { diff --git a/lib/client/keystore_test.go b/lib/client/keystore_test.go index 0a9fe29197b80..4a5d2c05801df 100644 --- a/lib/client/keystore_test.go +++ b/lib/client/keystore_test.go @@ -24,9 +24,12 @@ import ( "io/ioutil" "os" "path/filepath" + "sync" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keypaths" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" @@ -154,52 +157,91 @@ func TestDeleteAll(t *testing.T) { } func TestKnownHosts(t *testing.T) { - s, cleanup := newTest(t) - defer cleanup() + t.Parallel() + t.Run("can successfully write/read keys", func(t *testing.T) { + s, cleanup := newTest(t) + t.Cleanup(cleanup) - err := os.MkdirAll(s.store.KeyDir, 0777) - require.NoError(t, err) - pub, _, _, _, err := ssh.ParseAuthorizedKey(CAPub) - require.NoError(t, err) + err := os.MkdirAll(s.store.KeyDir, 0777) + require.NoError(t, err) + pub, _, _, _, err := ssh.ParseAuthorizedKey(CAPub) + require.NoError(t, err) - _, p2, _ := s.keygen.GenerateKeyPair("") - pub2, _, _, _, _ := ssh.ParseAuthorizedKey(p2) + _, p2, _ := s.keygen.GenerateKeyPair("") + pub2, _, _, _, _ := ssh.ParseAuthorizedKey(p2) - err = s.store.AddKnownHostKeys("example.com", "proxy.example.com", []ssh.PublicKey{pub}) - require.NoError(t, err) - err = s.store.AddKnownHostKeys("example.com", "proxy.example.com", []ssh.PublicKey{pub2}) - require.NoError(t, err) - err = s.store.AddKnownHostKeys("example.org", "proxy.example.org", []ssh.PublicKey{pub2}) - require.NoError(t, err) + err = s.store.AddKnownHostKeys("example.com", "proxy.example.com", []ssh.PublicKey{pub}) + require.NoError(t, err) + err = s.store.AddKnownHostKeys("example.com", "proxy.example.com", []ssh.PublicKey{pub2}) + require.NoError(t, err) + err = s.store.AddKnownHostKeys("example.org", "proxy.example.org", []ssh.PublicKey{pub2}) + require.NoError(t, err) - keys, err := s.store.GetKnownHostKeys("") - require.NoError(t, err) - require.Len(t, keys, 3) - require.Equal(t, keys, []ssh.PublicKey{pub, pub2, pub2}) + keys, err := s.store.GetKnownHostKeys("") + require.NoError(t, err) + require.Len(t, keys, 3) + require.Equal(t, keys, []ssh.PublicKey{pub, pub2, pub2}) - // check against dupes: - before, _ := s.store.GetKnownHostKeys("") - err = s.store.AddKnownHostKeys("example.org", "proxy.example.org", []ssh.PublicKey{pub2}) - require.NoError(t, err) - err = s.store.AddKnownHostKeys("example.org", "proxy.example.org", []ssh.PublicKey{pub2}) - require.NoError(t, err) - after, _ := s.store.GetKnownHostKeys("") - require.Equal(t, len(before), len(after)) + // check against dupes: + before, _ := s.store.GetKnownHostKeys("") + err = s.store.AddKnownHostKeys("example.org", "proxy.example.org", []ssh.PublicKey{pub2}) + require.NoError(t, err) + err = s.store.AddKnownHostKeys("example.org", "proxy.example.org", []ssh.PublicKey{pub2}) + require.NoError(t, err) + after, _ := s.store.GetKnownHostKeys("") + require.Equal(t, len(before), len(after)) + + // check by hostname: + keys, _ = s.store.GetKnownHostKeys("badhost") + require.Equal(t, len(keys), 0) + keys, _ = s.store.GetKnownHostKeys("example.org") + require.Equal(t, len(keys), 1) + require.True(t, apisshutils.KeysEqual(keys[0], pub2)) + + // check for proxy and wildcard as well: + keys, _ = s.store.GetKnownHostKeys("proxy.example.org") + require.Equal(t, 1, len(keys)) + require.True(t, apisshutils.KeysEqual(keys[0], pub2)) + keys, _ = s.store.GetKnownHostKeys("*.example.org") + require.Equal(t, 1, len(keys)) + require.True(t, apisshutils.KeysEqual(keys[0], pub2)) + }) + t.Run("can write keys in parallel without corrupting content of the file", func(t *testing.T) { + s, cleanup := newTest(t) + t.Cleanup(cleanup) - // check by hostname: - keys, _ = s.store.GetKnownHostKeys("badhost") - require.Equal(t, len(keys), 0) - keys, _ = s.store.GetKnownHostKeys("example.org") - require.Equal(t, len(keys), 1) - require.True(t, apisshutils.KeysEqual(keys[0], pub2)) + err := os.MkdirAll(s.store.KeyDir, 0777) + require.NoError(t, err) + pub, _, _, _, err := ssh.ParseAuthorizedKey(CAPub) + require.NoError(t, err) + + err = s.store.AddKnownHostKeys("example1.com", "proxy.example1.com", []ssh.PublicKey{pub}) + require.NoError(t, err) - // check for proxy and wildcard as well: - keys, _ = s.store.GetKnownHostKeys("proxy.example.org") - require.Equal(t, 1, len(keys)) - require.True(t, apisshutils.KeysEqual(keys[0], pub2)) - keys, _ = s.store.GetKnownHostKeys("*.example.org") - require.Equal(t, 1, len(keys)) - require.True(t, apisshutils.KeysEqual(keys[0], pub2)) + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + _, p2, _ := s.keygen.GenerateKeyPair("") + tmpPub, _, _, _, _ := ssh.ParseAuthorizedKey(p2) + + err := s.store.AddKnownHostKeys("example2.com", "proxy.example2.com", []ssh.PublicKey{tmpPub}) + assert.NoError(t, err) + + _, err = s.store.GetKnownHostKeys("") + assert.NoError(t, err) + + wg.Done() + }() + } + + wg.Wait() + + keys, err := s.store.GetKnownHostKeys("") + require.NoError(t, err) + require.NotNil(t, keys) + require.True(t, len(keys) > 1) + }) } // TestCheckKey makes sure Teleport clients can load non-RSA algorithms in