diff --git a/archiver.go b/archiver.go index ecf315e0..b09699d6 100644 --- a/archiver.go +++ b/archiver.go @@ -72,6 +72,11 @@ type ExtensionChecker interface { CheckExt(name string) error } +// FilenameChecker validates filenames to prevent path traversal attacks +type FilenameChecker interface { + CheckPath(to, filename string) error +} + // Unarchiver is a type that can extract archive files // into a folder. type Unarchiver interface { diff --git a/rar.go b/rar.go index 8fdf223c..4677b7b5 100644 --- a/rar.go +++ b/rar.go @@ -60,6 +60,17 @@ func (*Rar) CheckExt(filename string) error { return nil } +// CheckPath ensures that the filename has not been crafted to perform path traversal attacks +func (*Rar) CheckPath(to, filename string) error { + to, _ = filepath.Abs(to) //explicit the destination folder to prevent that 'string.HasPrefix' check can be 'bypassed' when no destination folder is supplied in input + dest := filepath.Join(to, filename) + //prevent path traversal attacks + if !strings.HasPrefix(dest, to) { + return fmt.Errorf("illegal file path: %s", filename) + } + return nil +} + // Unarchive unpacks the .rar file at source to destination. // Destination will be treated as a folder name. It supports // multi-volume archives. @@ -145,10 +156,18 @@ func (r *Rar) unrarNext(to string) error { if err != nil { return err // don't wrap error; calling loop must break on io.EOF } + defer f.Close() + header, ok := f.Header.(*rardecode.FileHeader) if !ok { return fmt.Errorf("expected header to be *rardecode.FileHeader but was %T", f.Header) } + + errPath := r.CheckPath(to, header.Name) + if errPath != nil { + return fmt.Errorf("checking path traversal attempt: %v", errPath) + } + return r.unrarFile(f, filepath.Join(to, header.Name)) } @@ -404,6 +423,7 @@ var ( _ = Extractor(new(Rar)) _ = Matcher(new(Rar)) _ = ExtensionChecker(new(Rar)) + _ = FilenameChecker(new(Rar)) _ = os.FileInfo(rarFileInfo{}) ) diff --git a/tar.go b/tar.go index dbac6920..6942aedf 100644 --- a/tar.go +++ b/tar.go @@ -61,6 +61,17 @@ func (*Tar) CheckExt(filename string) error { return nil } +// CheckPath ensures that the filename has not been crafted to perform path traversal attacks +func (*Tar) CheckPath(to, filename string) error { + to, _ = filepath.Abs(to) //explicit the destination folder to prevent that 'string.HasPrefix' check can be 'bypassed' when no destination folder is supplied in input + dest := filepath.Join(to, filename) + //prevent path traversal attacks + if !strings.HasPrefix(dest, to) { + return fmt.Errorf("illegal file path: %s", filename) + } + return nil +} + // Archive creates a tarball file at destination containing // the files listed in sources. The destination must end with // ".tar". File paths can be those of regular files or @@ -211,10 +222,17 @@ func (t *Tar) untarNext(destination string) error { if err != nil { return err // don't wrap error; calling loop must break on io.EOF } + defer f.Close() + header, ok := f.Header.(*tar.Header) if !ok { return fmt.Errorf("expected header to be *tar.Header but was %T", f.Header) } + + errPath := t.CheckPath(destination, header.Name) + if errPath != nil { + return fmt.Errorf("checking path traversal attempt: %v", errPath) + } return t.untarFile(f, destination, header) } @@ -614,6 +632,7 @@ var ( _ = Extractor(new(Tar)) _ = Matcher(new(Tar)) _ = ExtensionChecker(new(Tar)) + _ = FilenameChecker(new(Tar)) ) // DefaultTar is a default instance that is conveniently ready to use. diff --git a/zip.go b/zip.go index 0fa08b7d..2b6f03c2 100644 --- a/zip.go +++ b/zip.go @@ -117,6 +117,17 @@ func registerDecompressor(zr *zip.Reader) { }) } +// CheckPath ensures the file extension matches the format. +func (*Zip) CheckPath(to, filename string) error { + to, _ = filepath.Abs(to) //explicit the destination folder to prevent that 'string.HasPrefix' check can be 'bypassed' when no destination folder is supplied in input + dest := filepath.Join(to, filename) + //prevent path traversal attacks + if !strings.HasPrefix(dest, to) { + return fmt.Errorf("illegal file path: %s", filename) + } + return nil +} + // Archive creates a .zip file at destination containing // the files listed in sources. The destination must end // with ".zip". File paths can be those of regular files @@ -231,6 +242,11 @@ func (z *Zip) extractNext(to string) error { return err // don't wrap error; calling loop must break on io.EOF } defer f.Close() + + errPath := z.CheckPath(to, f.Header.(zip.FileHeader).Name) + if errPath != nil { + return fmt.Errorf("checking path traversal attempt: %v", errPath) + } return z.extractFile(f, to) } @@ -629,6 +645,7 @@ var ( _ = Extractor(new(Zip)) _ = Matcher(new(Zip)) _ = ExtensionChecker(new(Zip)) + _ = FilenameChecker(new(Zip)) ) // compressedFormats is a (non-exhaustive) set of lowercased