diff --git a/go/fileutil/join.go b/go/fileutil/join.go new file mode 100644 index 00000000000..3b282ad9dca --- /dev/null +++ b/go/fileutil/join.go @@ -0,0 +1,48 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package fileutil + +import ( + "errors" + "os" + "path/filepath" + "strings" +) + +var ErrInvalidJoinedPath = errors.New("invalid joined path") + +// SafePathJoin joins file paths using a rootPath and one or many other paths, +// returning a single absolute path. An error is returned if the joined path +// causes a directory traversal to a path outside of the provided rootPath. +func SafePathJoin(rootPath string, joinPaths ...string) (string, error) { + allPaths := make([]string, 0, len(joinPaths)+1) + allPaths = append(allPaths, rootPath) + allPaths = append(allPaths, joinPaths...) + p := filepath.Join(allPaths...) + absPath, err := filepath.Abs(p) + if err != nil { + return p, err + } + absRootPath, err := filepath.Abs(rootPath) + if err != nil { + return absPath, err + } + if absPath != absRootPath && !strings.HasPrefix(absPath, absRootPath+string(os.PathSeparator)) { + return absPath, ErrInvalidJoinedPath + } + return absPath, nil +} diff --git a/go/fileutil/join_test.go b/go/fileutil/join_test.go new file mode 100644 index 00000000000..6d1240fd0d8 --- /dev/null +++ b/go/fileutil/join_test.go @@ -0,0 +1,40 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package fileutil + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSafePathJoin(t *testing.T) { + rootDir := t.TempDir() + + t.Run("success", func(t *testing.T) { + path, err := SafePathJoin(rootDir, "good/path") + require.NoError(t, err) + require.True(t, filepath.IsAbs(path)) + require.Equal(t, filepath.Join(rootDir, "good/path"), path) + }) + + t.Run("dir-traversal", func(t *testing.T) { + _, err := SafePathJoin(rootDir, "../../..") + require.ErrorIs(t, err, ErrInvalidJoinedPath) + }) +} diff --git a/go/vt/mysqlctl/filebackupstorage/file.go b/go/vt/mysqlctl/filebackupstorage/file.go index a2e4175d11d..bff054692dc 100644 --- a/go/vt/mysqlctl/filebackupstorage/file.go +++ b/go/vt/mysqlctl/filebackupstorage/file.go @@ -27,12 +27,12 @@ import ( "github.com/spf13/pflag" - "vitess.io/vitess/go/os2" - "vitess.io/vitess/go/vt/mysqlctl/errors" - + "vitess.io/vitess/go/fileutil" "vitess.io/vitess/go/ioutil" + "vitess.io/vitess/go/os2" stats "vitess.io/vitess/go/vt/mysqlctl/backupstats" "vitess.io/vitess/go/vt/mysqlctl/backupstorage" + "vitess.io/vitess/go/vt/mysqlctl/errors" "vitess.io/vitess/go/vt/servenv" ) @@ -126,7 +126,10 @@ func (fbh *FileBackupHandle) ReadFile(ctx context.Context, filename string) (io. if !fbh.readOnly { return nil, fmt.Errorf("ReadFile cannot be called on read-write backup") } - p := path.Join(FileBackupStorageRoot, fbh.dir, fbh.name, filename) + p, err := fileutil.SafePathJoin(FileBackupStorageRoot, fbh.dir, fbh.name, filename) + if err != nil { + return nil, err + } f, err := os.Open(p) if err != nil { return nil, err @@ -146,9 +149,13 @@ func newFileBackupStorage(params backupstorage.Params) *FileBackupStorage { // ListBackups is part of the BackupStorage interface func (fbs *FileBackupStorage) ListBackups(ctx context.Context, dir string) ([]backupstorage.BackupHandle, error) { - // ReadDir already sorts the results - p := path.Join(FileBackupStorageRoot, dir) - fi, err := os.ReadDir(p) + // Check dir is not a directory traversal. + path, err := fileutil.SafePathJoin(FileBackupStorageRoot, dir) + if err != nil { + return nil, fmt.Errorf("failed to parse backup path %q: %w", path, err) + } + + fi, err := os.ReadDir(path) if err != nil { if os.IsNotExist(err) { return nil, nil @@ -172,14 +179,17 @@ func (fbs *FileBackupStorage) ListBackups(ctx context.Context, dir string) ([]ba // StartBackup is part of the BackupStorage interface func (fbs *FileBackupStorage) StartBackup(ctx context.Context, dir, name string) (backupstorage.BackupHandle, error) { // Make sure the directory exists. - p := path.Join(FileBackupStorageRoot, dir) - if err := os2.MkdirAll(p); err != nil { + p, err := fileutil.SafePathJoin(FileBackupStorageRoot, dir) + if err != nil { + return nil, err + } + if err = os2.MkdirAll(p); err != nil { return nil, err } // Create the subdirectory for this named backup. p = path.Join(p, name) - if err := os2.Mkdir(p); err != nil { + if err = os2.Mkdir(p); err != nil { return nil, err } @@ -188,7 +198,10 @@ func (fbs *FileBackupStorage) StartBackup(ctx context.Context, dir, name string) // RemoveBackup is part of the BackupStorage interface func (fbs *FileBackupStorage) RemoveBackup(ctx context.Context, dir, name string) error { - p := path.Join(FileBackupStorageRoot, dir, name) + p, err := fileutil.SafePathJoin(FileBackupStorageRoot, dir, name) + if err != nil { + return err + } return os.RemoveAll(p) }