Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 2 additions & 17 deletions lib/client/keystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ import (

"golang.org/x/crypto/ssh"

"github.com/gofrs/flock"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/profile"
Expand Down Expand Up @@ -582,24 +580,11 @@ func (fs *fsLocalNonSessionKeyStore) kubeCertPath(idx KeyIndex, kubename string)
return keypaths.KubeCertPath(fs.KeyDir, idx.ProxyHost, idx.Username, idx.ClusterName, kubename)
}

// acquireFileLock keeps attempting to take an exclusive lock on filePath until
// it either succeeds or the given timeout elapses. The file is created if it
// doesn't exist. On success it returns the function that releases the lock.
func acquireFileLock(filePath string, timeout time.Duration) (func() error, error) {
	deadlineCtx, stop := context.WithTimeout(context.Background(), timeout)
	defer stop()

	lock := flock.New(filePath)
	// Retry the lock attempt every 10ms until the deadline context is done.
	_, err := lock.TryLockContext(deadlineCtx, 10*time.Millisecond)
	if err != nil {
		return nil, err
	}
	return lock.Unlock, nil
}

// AddKnownHostKeys adds a new entry to `known_hosts` file.
func (fs *fsLocalNonSessionKeyStore) AddKnownHostKeys(hostname, proxyHost string, hostKeys []ssh.PublicKey) (retErr error) {
// We're trying to serialize our writes to the 'known_hosts' file to avoid corruption, since there
// are cases when multiple tsh instances will try to write to it.
unlock, err := acquireFileLock(fs.knownHostsPath(), 5*time.Second)
unlock, err := utils.FSTryWriteLockTimeout(context.Background(), fs.knownHostsPath(), 5*time.Second)
if err != nil {
return trace.WrapWithMessage(err, "could not acquire lock for the `known_hosts` file")
}
Expand Down Expand Up @@ -695,7 +680,7 @@ func matchesWildcard(hostname, pattern string) bool {

// GetKnownHostKeys returns all known public keys from `known_hosts`.
func (fs *fsLocalNonSessionKeyStore) GetKnownHostKeys(hostname string) (keys []ssh.PublicKey, retErr error) {
unlock, err := acquireFileLock(fs.knownHostsPath(), 5*time.Second)
unlock, err := utils.FSTryReadLockTimeout(context.Background(), fs.knownHostsPath(), 5*time.Second)
if err != nil {
return nil, trace.WrapWithMessage(err, "could not acquire lock for the `known_hosts` file")
}
Expand Down
18 changes: 10 additions & 8 deletions lib/events/filesessions/fileasync.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ func NewUploader(cfg UploaderConfig) (*Uploader, error) {
// the upload that have been aborted.
//
// It marks corrupted session files to skip their processing.
//
type Uploader struct {
semaphore chan struct{}

Expand Down Expand Up @@ -241,7 +240,7 @@ func (u *Uploader) Scan(ctx context.Context) (*ScanStats, error) {
}
stats.Scanned++
if err := u.startUpload(ctx, fi.Name()); err != nil {
if trace.IsCompareFailed(err) {
if errors.Is(err, utils.ErrUnsuccessfulLockTry) {
u.log.Debugf("Scan is skipping recording %v that is locked by another process.", fi.Name())
continue
}
Expand Down Expand Up @@ -277,6 +276,7 @@ type upload struct {
sessionID session.ID
reader *events.ProtoReader
file *os.File
fileUnlockFn func() error
checkpointFile *os.File
}

Expand Down Expand Up @@ -322,7 +322,7 @@ func (u *upload) writeStatus(status apievents.StreamStatus) error {
func (u *upload) Close() error {
return trace.NewAggregate(
u.reader.Close(),
utils.FSUnlock(u.file),
u.fileUnlockFn(),
u.file.Close(),
utils.NilCloser(u.checkpointFile).Close(),
)
Expand Down Expand Up @@ -366,17 +366,19 @@ func (u *Uploader) startUpload(ctx context.Context, fileName string) error {
if err != nil {
return trace.ConvertSystemError(err)
}
if err := utils.FSTryWriteLock(sessionFile); err != nil {
unlock, err := utils.FSTryWriteLock(sessionFilePath)
if err != nil {
if e := sessionFile.Close(); e != nil {
u.log.WithError(e).Warningf("Failed to close %v.", fileName)
}
return trace.Wrap(err)
return trace.WrapWithMessage(err, "could not acquire file lock for %q", sessionFilePath)
}

upload := &upload{
sessionID: sessionID,
reader: events.NewProtoReader(sessionFile),
file: sessionFile,
sessionID: sessionID,
reader: events.NewProtoReader(sessionFile),
file: sessionFile,
fileUnlockFn: unlock,
}
upload.checkpointFile, err = os.OpenFile(u.checkpointFilePath(sessionID), os.O_RDWR|os.O_CREATE, 0600)
if err != nil {
Expand Down
7 changes: 4 additions & 3 deletions lib/events/filesessions/filestream.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,12 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload
if err != nil {
return trace.ConvertSystemError(err)
}
if err := utils.FSTryWriteLock(f); err != nil {
return trace.Wrap(err)
unlock, err := utils.FSTryWriteLock(uploadPath)
if err != nil {
return trace.WrapWithMessage(err, "could not acquire file lock for %q", uploadPath)
}
defer func() {
if err := utils.FSUnlock(f); err != nil {
if err := unlock(); err != nil {
h.WithError(err).Errorf("Failed to unlock filesystem lock.")
}
if err := f.Close(); err != nil {
Expand Down
65 changes: 65 additions & 0 deletions lib/utils/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,22 @@ limitations under the License.
package utils

import (
"context"
"errors"
"os"
"path/filepath"
"time"

"github.com/gofrs/flock"

"github.com/gravitational/teleport"
"github.com/gravitational/trace"
)

// ErrUnsuccessfulLockTry is the sentinel error used when a lock could not be acquired
// at this time (most probably because it is already held by someone else); another
// attempt might succeed. Callers can detect it with errors.Is.
var ErrUnsuccessfulLockTry = errors.New("could not acquire lock on the file at this time")
Comment thread
AntonAM marked this conversation as resolved.
Outdated

// OpenFileWithFlagsFunc defines a function used to open files providing options.
// Its signature matches that of os.OpenFile (name, flags, permission bits).
type OpenFileWithFlagsFunc func(name string, flag int, perm os.FileMode) (*os.File, error)

Expand Down Expand Up @@ -144,3 +153,59 @@ func StatDir(path string) (os.FileInfo, error) {
}
return fi, nil
}

// FSTryWriteLock makes a single, non-blocking attempt to take an exclusive
// (write) lock for filePath. When the lock is already held by someone else it
// returns a retryable error wrapping ErrUnsuccessfulLockTry. On success the
// returned unlock function releases the lock.
func FSTryWriteLock(filePath string) (unlock func() error, err error) {
	lock := flock.New(getPlatformLockFilePath(filePath))

	acquired, lockErr := lock.TryLock()
	switch {
	case lockErr != nil:
		return nil, trace.ConvertSystemError(lockErr)
	case !acquired:
		return nil, trace.Retry(ErrUnsuccessfulLockTry, "")
	}

	return unlockWrapper(lock.Unlock, lock.Path()), nil
}

// FSTryWriteLockTimeout repeatedly tries to take an exclusive (write) lock for
// filePath until the lock is acquired, the timeout expires, or ctx is done.
// On success the returned unlock function releases the lock.
func FSTryWriteLockTimeout(ctx context.Context, filePath string, timeout time.Duration) (unlock func() error, err error) {
	lockCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	lock := flock.New(getPlatformLockFilePath(filePath))
	// Retry every 10ms until lockCtx is done.
	_, lockErr := lock.TryLockContext(lockCtx, 10*time.Millisecond)
	if lockErr != nil {
		return nil, trace.ConvertSystemError(lockErr)
	}

	return unlockWrapper(lock.Unlock, lock.Path()), nil
}

// FSTryReadLock makes a single, non-blocking attempt to take a shared (read)
// lock for filePath. When the lock cannot be taken because it is held by
// someone else it returns a retryable error wrapping ErrUnsuccessfulLockTry.
// On success the returned unlock function releases the lock.
func FSTryReadLock(filePath string) (unlock func() error, err error) {
	fileLock := flock.New(getPlatformLockFilePath(filePath))
	// TryRLock takes the shared lock, not the exclusive one.
	locked, err := fileLock.TryRLock()
	if err != nil {
		return nil, trace.ConvertSystemError(err)
	}
	if !locked {
		return nil, trace.Retry(ErrUnsuccessfulLockTry, "")
	}

	return unlockWrapper(fileLock.Unlock, fileLock.Path()), nil
}

// FSTryReadLockTimeout repeatedly tries to take a shared (read) lock for
// filePath until the lock is acquired, the timeout expires, or ctx is done.
// On success the returned unlock function releases the lock.
func FSTryReadLockTimeout(ctx context.Context, filePath string, timeout time.Duration) (unlock func() error, err error) {
	lockCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	lock := flock.New(getPlatformLockFilePath(filePath))
	// Retry every 10ms until lockCtx is done.
	_, lockErr := lock.TryRLockContext(lockCtx, 10*time.Millisecond)
	if lockErr != nil {
		return nil, trace.ConvertSystemError(lockErr)
	}

	return unlockWrapper(lock.Unlock, lock.Path()), nil
}
92 changes: 92 additions & 0 deletions lib/utils/fs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
Copyright 2022 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package utils

import (
"context"
"os"
"testing"
"time"

"github.com/stretchr/testify/require"
)

// TestLocks exercises the file-lock helpers: non-blocking read/write locks,
// mutual exclusion between them, shared read locks, and the timeout variants.
func TestLocks(t *testing.T) {
	t.Parallel()

	tmpFile, err := os.CreateTemp("", "teleport-lock-test")
	// Check the error before touching tmpFile: it is nil when creation fails.
	require.NoError(t, err)
	fp := tmpFile.Name()
	t.Cleanup(func() {
		_ = os.Remove(fp)
	})
	// Only the path is needed from here on; don't leak the descriptor.
	require.NoError(t, tmpFile.Close())

	// Can take read lock
	unlock, err := FSTryReadLock(fp)
	require.NoError(t, err)

	require.NoError(t, unlock())

	// Can take write lock
	unlock, err = FSTryWriteLock(fp)
	require.NoError(t, err)

	// Can't take read lock while write lock is held.
	unlock2, err := FSTryReadLock(fp)
	require.ErrorIs(t, err, ErrUnsuccessfulLockTry)
	require.Nil(t, unlock2)

	// Can't take write lock while another write lock is held.
	unlock2, err = FSTryWriteLock(fp)
	require.ErrorIs(t, err, ErrUnsuccessfulLockTry)
	require.Nil(t, unlock2)

	require.NoError(t, unlock())

	unlock, err = FSTryReadLock(fp)
	require.NoError(t, err)

	// Can take second read lock on the same file.
	unlock2, err = FSTryReadLock(fp)
	require.NoError(t, err)

	require.NoError(t, unlock())
	require.NoError(t, unlock2())

	// Can take read lock with timeout
	unlock, err = FSTryReadLockTimeout(context.Background(), fp, time.Second)
	require.NoError(t, err)
	require.NoError(t, unlock())

	// Can take write lock with timeout
	unlock, err = FSTryWriteLockTimeout(context.Background(), fp, time.Second)
	require.NoError(t, err)

	// Fails because timeout is exceeded, since file is already locked.
	unlock2, err = FSTryWriteLockTimeout(context.Background(), fp, time.Millisecond)
	require.ErrorIs(t, err, context.DeadlineExceeded)
	require.Nil(t, unlock2)

	// Fails because context is expired while waiting for timeout.
	ctx, cancel := context.WithDeadline(context.Background(), time.Now())
	defer cancel()
	unlock2, err = FSTryWriteLockTimeout(ctx, fp, time.Hour*1000)
	require.ErrorIs(t, err, context.DeadlineExceeded)
	require.Nil(t, unlock2)

	require.NoError(t, unlock())
}
48 changes: 8 additions & 40 deletions lib/utils/fs_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,48 +19,16 @@ limitations under the License.

package utils

import (
"os"
"syscall"

"github.com/gravitational/trace"
)

// FSWriteLock grabs Flock-style filesystem lock on an open file
// in exclusive mode.
func FSWriteLock(f *os.File) error {
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil {
return trace.ConvertSystemError(err)
}
return nil
// getPlatformLockFilePath returns the path of the file to lock for the given target path.
// On non-windows we just lock the target file itself, so path is returned unchanged.
func getPlatformLockFilePath(path string) string {
return path
}

// FSTryWriteLock tries to grab write lock, returns CompareFailed
// if lock is already acquired
func FSTryWriteLock(f *os.File) error {
err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
if err != nil {
if err == syscall.EWOULDBLOCK {
return trace.CompareFailed("lock %v is acquired by another process", f.Name())
func unlockWrapper(unlockFn func() error, path string) func() error {
return func() error {
if unlockFn == nil {
return nil
}
return trace.ConvertSystemError(err)
}
return nil
}

// FSReadLock grabs Flock-style filesystem lock on an open file
// in read (shared) mode
func FSReadLock(f *os.File) error {
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_SH); err != nil {
return trace.ConvertSystemError(err)
}
return nil
}

// FSUnlock unlocks Flock-style filesystem lock
func FSUnlock(f *os.File) error {
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_UN); err != nil {
return trace.ConvertSystemError(err)
return unlockFn()
}
return nil
}
34 changes: 17 additions & 17 deletions lib/utils/fs_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,26 @@ limitations under the License.

import (
"os"

"github.com/gravitational/trace"
)

// FSWriteLock not supported on Windows.
func FSWriteLock(f *os.File) error {
return trace.BadParameter("file locking not supported on Windows")
}

// FSTryWriteLock not supported on Windows.
func FSTryWriteLock(f *os.File) error {
return trace.BadParameter("file locking not supported on Windows")
}
// On Windows we use auxiliary .lock files to acquire locks, so we can still read/write target files
// themselves. On unlock we delete the .lock file.
const lockPostfix = ".lock"

// FSReadLock not supported on Windows.
func FSReadLock(f *os.File) error {
return trace.BadParameter("file locking not supported on Windows")
// getPlatformLockFilePath returns the path of the file to lock for the given
// target path: the target path with lockPostfix appended.
func getPlatformLockFilePath(path string) string {
return path + lockPostfix
}

// FSUnlock not supported on Windows.
func FSUnlock(f *os.File) error {
return trace.BadParameter("file locking not supported on Windows")
// unlockWrapper decorates unlockFn so that, after releasing the lock, the file
// at path is removed as well. A nil unlockFn yields a function that does
// nothing and returns nil. The returned function reports the error from
// unlockFn; a failure to remove the file is deliberately ignored.
func unlockWrapper(unlockFn func() error, path string) func() error {
	release := func() error {
		if unlockFn == nil {
			return nil
		}
		unlockErr := unlockFn()

		// At this point file can be locked again, and we can get an error, so we do our best effort
		// to remove .lock file, but can't guarantee it. Last locker should be able to successfully clean it.
		_ = os.Remove(path)
		return unlockErr
	}
	return release
}