Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions go/fileutil/join.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
Copyright 2025 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package fileutil

import (
"errors"
"os"
"path/filepath"
"strings"
)

var ErrInvalidJoinedPath = errors.New("invalid joined path")

// SafePathJoin joins file paths using a rootPath and one or many other paths,
// returning a single absolute path. An error is returned if the joined path
// causes a directory traversal to a path outside of the provided rootPath.
func SafePathJoin(rootPath string, joinPaths ...string) (string, error) {
allPaths := make([]string, 0, len(joinPaths)+1)
allPaths = append(allPaths, rootPath)
allPaths = append(allPaths, joinPaths...)
p := filepath.Join(allPaths...)
absPath, err := filepath.Abs(p)
if err != nil {
return p, err
}
absRootPath, err := filepath.Abs(rootPath)
if err != nil {
return absPath, err
}
if absPath != absRootPath && !strings.HasPrefix(absPath, absRootPath+string(os.PathSeparator)) {
return absPath, ErrInvalidJoinedPath
}
return absPath, nil
}
40 changes: 40 additions & 0 deletions go/fileutil/join_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
Copyright 2025 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package fileutil

import (
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
)

func TestSafePathJoin(t *testing.T) {
rootDir := t.TempDir()

t.Run("success", func(t *testing.T) {
path, err := SafePathJoin(rootDir, "good/path")
require.NoError(t, err)
require.True(t, filepath.IsAbs(path))
require.Equal(t, filepath.Join(rootDir, "good/path"), path)
})

t.Run("dir-traversal", func(t *testing.T) {
_, err := SafePathJoin(rootDir, "../../..")
require.ErrorIs(t, err, ErrInvalidJoinedPath)
})
}
35 changes: 24 additions & 11 deletions go/vt/mysqlctl/filebackupstorage/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ import (

"github.com/spf13/pflag"

"vitess.io/vitess/go/os2"
"vitess.io/vitess/go/vt/mysqlctl/errors"

"vitess.io/vitess/go/fileutil"
"vitess.io/vitess/go/ioutil"
"vitess.io/vitess/go/os2"
stats "vitess.io/vitess/go/vt/mysqlctl/backupstats"
"vitess.io/vitess/go/vt/mysqlctl/backupstorage"
"vitess.io/vitess/go/vt/mysqlctl/errors"
"vitess.io/vitess/go/vt/servenv"
)

Expand Down Expand Up @@ -126,7 +126,10 @@ func (fbh *FileBackupHandle) ReadFile(ctx context.Context, filename string) (io.
if !fbh.readOnly {
return nil, fmt.Errorf("ReadFile cannot be called on read-write backup")
}
p := path.Join(FileBackupStorageRoot, fbh.dir, fbh.name, filename)
p, err := fileutil.SafePathJoin(FileBackupStorageRoot, fbh.dir, fbh.name, filename)
if err != nil {
return nil, err
}
f, err := os.Open(p)
if err != nil {
return nil, err
Expand All @@ -146,9 +149,13 @@ func newFileBackupStorage(params backupstorage.Params) *FileBackupStorage {

// ListBackups is part of the BackupStorage interface
func (fbs *FileBackupStorage) ListBackups(ctx context.Context, dir string) ([]backupstorage.BackupHandle, error) {
// ReadDir already sorts the results
p := path.Join(FileBackupStorageRoot, dir)
fi, err := os.ReadDir(p)
// Check dir is not a directory traversal.
path, err := fileutil.SafePathJoin(FileBackupStorageRoot, dir)
if err != nil {
return nil, fmt.Errorf("failed to parse backup path %q: %w", path, err)
}

fi, err := os.ReadDir(path)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
Expand All @@ -172,14 +179,17 @@ func (fbs *FileBackupStorage) ListBackups(ctx context.Context, dir string) ([]ba
// StartBackup is part of the BackupStorage interface
func (fbs *FileBackupStorage) StartBackup(ctx context.Context, dir, name string) (backupstorage.BackupHandle, error) {
// Make sure the directory exists.
p := path.Join(FileBackupStorageRoot, dir)
if err := os2.MkdirAll(p); err != nil {
p, err := fileutil.SafePathJoin(FileBackupStorageRoot, dir)
if err != nil {
return nil, err
}
if err = os2.MkdirAll(p); err != nil {
return nil, err
}

// Create the subdirectory for this named backup.
p = path.Join(p, name)
if err := os2.Mkdir(p); err != nil {
if err = os2.Mkdir(p); err != nil {
return nil, err
}

Expand All @@ -188,7 +198,10 @@ func (fbs *FileBackupStorage) StartBackup(ctx context.Context, dir, name string)

// RemoveBackup is part of the BackupStorage interface
func (fbs *FileBackupStorage) RemoveBackup(ctx context.Context, dir, name string) error {
p := path.Join(FileBackupStorageRoot, dir, name)
p, err := fileutil.SafePathJoin(FileBackupStorageRoot, dir, name)
if err != nil {
return err
}
return os.RemoveAll(p)
}

Expand Down
Loading