diff --git a/lib/client/keystore.go b/lib/client/keystore.go index f4c58e14cc952..fa17306b29089 100644 --- a/lib/client/keystore.go +++ b/lib/client/keystore.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/teleport/api/utils/keypaths" "github.com/gravitational/teleport/api/utils/keys" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" + "github.com/gravitational/teleport/lib/utils" ) const ( @@ -235,7 +236,7 @@ func (fs *FSKeyStore) DeleteKey(idx KeyIndex) error { fs.tlsCertPath(idx), } for _, fn := range files { - if err := os.Remove(fn); err != nil { + if err := utils.RemoveSecure(fn); err != nil { return trace.ConvertSystemError(err) } } @@ -243,11 +244,11 @@ func (fs *FSKeyStore) DeleteKey(idx KeyIndex) error { // but it may not exist when upgrading from v9 -> v10 and logging into an existing cluster. // as such, deletion should be best-effort and not generate an error if it fails. if runtime.GOOS == constants.WindowsOS { - _ = os.Remove(fs.ppkFilePath(idx)) + _ = utils.RemoveSecure(fs.ppkFilePath(idx)) } // And try to delete kube credentials lockfile in case it exists - err := os.Remove(fs.kubeCredLockfilePath(idx)) + err := utils.RemoveSecure(fs.kubeCredLockfilePath(idx)) if err != nil && !errors.Is(err, iofs.ErrNotExist) { log.Debugf("Could not remove kube credentials file: %v", err) } @@ -266,7 +267,7 @@ func (fs *FSKeyStore) DeleteKey(idx KeyIndex) error { func (fs *FSKeyStore) DeleteUserCerts(idx KeyIndex, opts ...CertOption) error { for _, o := range opts { certPath := o.certPath(fs.KeyDir, idx) - if err := os.RemoveAll(certPath); err != nil { + if err := utils.RemoveAllSecure(certPath); err != nil { return trace.ConvertSystemError(err) } } @@ -289,13 +290,13 @@ func (fs *FSKeyStore) DeleteKeys() error { continue } if file.IsDir() { - err := os.RemoveAll(filepath.Join(fs.KeyDir, file.Name())) + err := utils.RemoveAllSecure(filepath.Join(fs.KeyDir, file.Name())) if err != nil { return trace.ConvertSystemError(err) } continue } - err := os.Remove(filepath.Join(fs.KeyDir, file.Name())) + err := utils.RemoveAllSecure(filepath.Join(fs.KeyDir, file.Name())) if err != nil { return trace.ConvertSystemError(err) } diff --git a/lib/utils/fs.go b/lib/utils/fs.go index 2153bd1eb79a3..8d00de71114d5 100644 --- a/lib/utils/fs.go +++ b/lib/utils/fs.go @@ -26,11 +26,11 @@ import ( "path/filepath" "runtime" "strings" + "syscall" "time" "github.com/gofrs/flock" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport" ) @@ -242,37 +242,102 @@ func FSTryReadLockTimeout(ctx context.Context, filePath string, timeout time.Dur return fileLock.Unlock, nil } +// RemoveAllSecure is similar to [os.RemoveAll] but leverages [RemoveSecure] to delete files so that they are +// overwritten. This helps guard against hardware attacks on magnetic disks. +func RemoveAllSecure(path string) error { + if path == "" { + // match behavior from os.RemoveAll + return nil + } + // Match os.RemoveAll protections in not permitting removal of "." directories + // This check comes directly from https://cs.opensource.google/go/go/+/refs/tags/go1.21.1:src/os/removeall_at.go;l=24 + if path == "." || (len(path) >= 2 && path[len(path)-1] == '.' && os.IsPathSeparator(path[len(path)-2])) { + return &os.PathError{Op: "RemoveAllSecure", Path: path, Err: syscall.EINVAL} // error type matches os.RemoveAll + } + + info, err := os.Lstat(path) + switch { + case err != nil && os.IsNotExist(err): + return nil + case err != nil: + return trace.ConvertSystemError(err) + case !info.IsDir(): + return removeSecure(path, info) + } + var removeErrors []error + files, err := os.ReadDir(path) + if err != nil { + // Don't fail fast, allow removal at end to be attempted. + removeErrors = append(removeErrors, err) + } + // It's possible for a partial file list to be returned even if an error above was returned. + for _, f := range files { + if err := RemoveAllSecure(filepath.Join(path, f.Name())); err != nil { + removeErrors = append(removeErrors, err) + } + } + if err := os.Remove(path); err != nil { + removeErrors = append(removeErrors, err) + } + switch len(removeErrors) { + case 1: + return trace.ConvertSystemError(removeErrors[0]) + case 0: + return nil + default: + return trace.NewAggregate(removeErrors...) + } +} + // RemoveSecure attempts to securely delete the file by first overwriting the file with random data three times // followed by calling os.Remove(filePath). func RemoveSecure(filePath string) error { - for i := 0; i < 3; i++ { - if err := overwriteFile(filePath); err != nil { - return trace.Wrap(err) - } + info, err := os.Lstat(filePath) + if err != nil && os.IsNotExist(err) { + return err } - return trace.ConvertSystemError(os.Remove(filePath)) + // Don't fast return on other errors, still allow removeSecure to attempt removal. + return removeSecure(filePath, info) } -func overwriteFile(filePath string) (err error) { - f, err := os.OpenFile(filePath, os.O_WRONLY, 0) - if err != nil { - return trace.ConvertSystemError(err) +func removeSecure(filePath string, fi os.FileInfo) error { + if fi.Mode().Type()&os.ModeSymlink != 0 { + return os.Remove(filePath) } - defer func() { - if closeErr := f.Close(); closeErr != nil { - if err == nil { - err = trace.ConvertSystemError(closeErr) - } else { - log.WithError(closeErr).Warningf("Failed to close %v.", f.Name()) + f, openErr := os.OpenFile(filePath, os.O_WRONLY, 0) + switch { + case os.IsNotExist(openErr): + return trace.ConvertSystemError(openErr) + case openErr != nil: + // Attempt delete anyway. + return trace.ConvertSystemError(os.Remove(filePath)) + } + defer f.Close() + + if runtime.GOOS == "windows" { + // Windows can't unlink the file before overwriting. + if f != nil { + for i := 0; i < 3; i++ { + if err := overwriteFile(f, fi); err != nil { + break + } } } - }() - - fi, err := f.Stat() - if err != nil { - return trace.ConvertSystemError(err) + return trace.ConvertSystemError(os.Remove(filePath)) + } else { + removeErr := os.Remove(filePath) + if f != nil { + for i := 0; i < 3; i++ { + if err := overwriteFile(f, fi); err != nil { + break + } + } + } + return trace.ConvertSystemError(removeErr) } +} +func overwriteFile(f *os.File, fi os.FileInfo) error { // Rounding up to 4k to hide the original file size. 4k was chosen because it's a common block size. const block = 4096 size := fi.Size() / block * block @@ -280,8 +345,16 @@ func overwriteFile(filePath string) (err error) { size += block } - _, err = io.CopyN(f, rand.Reader, size) - return trace.Wrap(err) + _, copyErr := io.CopyN(f, rand.Reader, size) + + // Attempt sync regardless of above error + syncErr := f.Sync() // sync to ensure commit to hardware + if copyErr != nil { + return trace.Wrap(copyErr) + } else if syncErr != nil { + return trace.Wrap(syncErr) + } + return nil } // RemoveFileIfExist removes file if exits. diff --git a/lib/utils/fs_test.go b/lib/utils/fs_test.go index 12359aa58adf4..7377e669bfb10 100644 --- a/lib/utils/fs_test.go +++ b/lib/utils/fs_test.go @@ -316,16 +316,49 @@ func TestOverwriteFile(t *testing.T) { fName := filepath.Join(t.TempDir(), "teleport-overwrite-file-test") require.NoError(t, os.WriteFile(fName, have, 0600)) - require.NoError(t, overwriteFile(fName)) + f, err := os.OpenFile(fName, os.O_WRONLY, 0) + require.NoError(t, err) + defer f.Close() + fi, err := os.Stat(fName) + require.NoError(t, err) + require.NoError(t, overwriteFile(f, fi)) contents, err := os.ReadFile(fName) require.NoError(t, err) require.NotContains(t, contents, have, "File contents were not overwritten") } +func TestRemoveAllSecure(t *testing.T) { + tempDir := t.TempDir() + tempFile := filepath.Join(tempDir, "teleport-remove-all-secure-test") + f, err := os.Create(tempFile) + symlink := filepath.Join(tempDir, "teleport-remove-secure-symlink") + require.NoError(t, os.Symlink(tempFile, symlink)) + require.NoError(t, err) + require.NoError(t, f.Close()) + + require.NoError(t, RemoveAllSecure("")) + require.NoError(t, RemoveAllSecure(tempDir)) + _, err = os.Stat(tempDir) + require.True(t, os.IsNotExist(err), "Directory should be removed: %v", err) +} + func TestRemoveSecure(t *testing.T) { - f, err := os.Create(filepath.Join(t.TempDir(), "teleport-remove-secure-test")) + tempFile := filepath.Join(t.TempDir(), "teleport-remove-secure-test") + f, err := os.Create(tempFile) require.NoError(t, err) require.NoError(t, f.Close()) + require.NoError(t, RemoveSecure(f.Name())) + _, err = os.Stat(tempFile) + require.True(t, os.IsNotExist(err), "File should be removed: %v", err) +} + +func TestRemoveSecure_symlink(t *testing.T) { + symlink := filepath.Join(t.TempDir(), "teleport-remove-secure-symlink") + require.NoError(t, os.Symlink("/tmp", symlink)) + + require.NoError(t, RemoveSecure(symlink)) + _, err := os.Stat(symlink) + require.True(t, os.IsNotExist(err), "Symlink should be removed: %v", err) }