Skip to content

Commit

Permalink
Multiple fixes for go-getter (#359)
Browse files Browse the repository at this point in the history
* Multiple fixes for go-getter

Co-authored-by: Kent 'picat' Gruber <[email protected]>
  • Loading branch information
eastebry and picatz authored May 18, 2022
1 parent 4553965 commit a2ebce9
Show file tree
Hide file tree
Showing 16 changed files with 1,169 additions and 97 deletions.
24 changes: 23 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package getter

import (
"context"
"errors"
"fmt"
"io/ioutil"
"os"
Expand All @@ -13,6 +14,9 @@ import (
safetemp "github.com/hashicorp/go-safetemp"
)

// ErrSymlinkCopy means that a copy of a symlink was encountered on a request with DisableSymlinks enabled.
var ErrSymlinkCopy = errors.New("copying of symlinks has been disabled")

// Client is a client for downloading things.
//
// Top-level functions such as Get are shortcuts for interacting with a client.
Expand Down Expand Up @@ -76,6 +80,9 @@ type Client struct {
// This is identical to tls.Config.InsecureSkipVerify.
Insecure bool

// Disable symlinks
DisableSymlinks bool

Options []ClientOption
}

Expand Down Expand Up @@ -123,6 +130,17 @@ func (c *Client) Get() error {
dst := c.Dst
src, subDir := SourceDirSubdir(src)
if subDir != "" {
// Check if the subdirectory is attempting to traverse updwards, outside of
// the cloned repository path.
subDir := filepath.Clean(subDir)
if containsDotDot(subDir) {
return fmt.Errorf("subdirectory component contain path traversal out of the repository")
}
// Prevent absolute paths, remove a leading path separator from the subdirectory
if subDir[0] == os.PathSeparator {
subDir = subDir[1:]
}

td, tdcloser, err := safetemp.Dir("", "getter")
if err != nil {
return err
Expand Down Expand Up @@ -230,6 +248,10 @@ func (c *Client) Get() error {
filename = v
}

if containsDotDot(filename) {
return fmt.Errorf("filename query parameter contain path traversal")
}

dst = filepath.Join(dst, filename)
}
}
Expand Down Expand Up @@ -318,7 +340,7 @@ func (c *Client) Get() error {
return err
}

return copyDir(c.Ctx, realDst, subDir, false, c.umask())
return copyDir(c.Ctx, realDst, subDir, false, c.DisableSymlinks, c.umask())
}

return nil
Expand Down
68 changes: 61 additions & 7 deletions client_option.go
Original file line number Diff line number Diff line change
@@ -1,46 +1,100 @@
package getter

import "context"
import (
"context"
"os"
)

// A ClientOption allows to configure a client
// ClientOption is used to configure a client.
type ClientOption func(*Client) error

// Configure configures a client with options.
// Configure applies all of the given client options, along with any default
// behavior including context, decompressors, detectors, and getters used by
// the client.
func (c *Client) Configure(opts ...ClientOption) error {
// If the context has not been configured use the background context.
if c.Ctx == nil {
c.Ctx = context.Background()
}

// Store the options used to configure this client.
c.Options = opts

// Apply all of the client options.
for _, opt := range opts {
err := opt(c)
if err != nil {
return err
}
}
// Default decompressor values

// If the client was not configured with any Decompressors, Detectors,
// or Getters, use the default values for each.
if c.Decompressors == nil {
c.Decompressors = Decompressors
}
// Default detector values
if c.Detectors == nil {
c.Detectors = Detectors
}
// Default getter values
if c.Getters == nil {
c.Getters = Getters
}

// Set the client for each getter, so the top-level client can know
// the getter-specific client functions or progress tracking.
for _, getter := range c.Getters {
getter.SetClient(c)
}

return nil
}

// WithContext allows to pass a context to operation
// in order to be able to cancel a download in progress.
func WithContext(ctx context.Context) func(*Client) error {
func WithContext(ctx context.Context) ClientOption {
return func(c *Client) error {
c.Ctx = ctx
return nil
}
}

// WithDecompressors specifies which Decompressor are available.
func WithDecompressors(decompressors map[string]Decompressor) ClientOption {
return func(c *Client) error {
c.Decompressors = decompressors
return nil
}
}

// WithDecompressors specifies which compressors are available.
func WithDetectors(detectors []Detector) ClientOption {
return func(c *Client) error {
c.Detectors = detectors
return nil
}
}

// WithGetters specifies which getters are available.
func WithGetters(getters map[string]Getter) ClientOption {
return func(c *Client) error {
c.Getters = getters
return nil
}
}

// WithMode specifies which client mode the getters should operate in.
func WithMode(mode ClientMode) ClientOption {
return func(c *Client) error {
c.Mode = mode
return nil
}
}

// WithUmask specifies how to mask file permissions when storing local
// files or decompressing an archive.
func WithUmask(mode os.FileMode) ClientOption {
return func(c *Client) error {
c.Umask = mode
return nil
}
}
24 changes: 21 additions & 3 deletions copy_dir.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package getter

import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
Expand All @@ -16,8 +17,11 @@ func mode(mode, umask os.FileMode) os.FileMode {
// should already exist.
//
// If ignoreDot is set to true, then dot-prefixed files/folders are ignored.
func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask os.FileMode) error {
src, err := filepath.EvalSymlinks(src)
func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, disableSymlinks bool, umask os.FileMode) error {
// We can safely evaluate the symlinks here, even if disabled, because they
// will be checked before actual use in walkFn and copyFile
var err error
src, err = filepath.EvalSymlinks(src)
if err != nil {
return err
}
Expand All @@ -26,6 +30,20 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask
if err != nil {
return err
}

if disableSymlinks {
fileInfo, err := os.Lstat(path)
if err != nil {
return fmt.Errorf("failed to check copy file source for symlinks: %w", err)
}
if fileInfo.Mode()&os.ModeSymlink == os.ModeSymlink {
return ErrSymlinkCopy
}
// if info.Mode()&os.ModeSymlink == os.ModeSymlink {
// return ErrSymlinkCopy
// }
}

if path == src {
return nil
}
Expand Down Expand Up @@ -59,7 +77,7 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask
}

// If we have a file, copy the contents.
_, err = copyFile(ctx, dstPath, path, info.Mode(), umask)
_, err = copyFile(ctx, dstPath, path, disableSymlinks, info.Mode(), umask)
return err
}

Expand Down
13 changes: 12 additions & 1 deletion get_file_copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package getter

import (
"context"
"fmt"
"io"
"os"
)
Expand Down Expand Up @@ -49,7 +50,17 @@ func copyReader(dst string, src io.Reader, fmode, umask os.FileMode) error {
}

// copyFile copies a file in chunks from src path to dst path, using umask to create the dst file
func copyFile(ctx context.Context, dst, src string, fmode, umask os.FileMode) (int64, error) {
func copyFile(ctx context.Context, dst, src string, disableSymlinks bool, fmode, umask os.FileMode) (int64, error) {
if disableSymlinks {
fileInfo, err := os.Lstat(src)
if err != nil {
return 0, fmt.Errorf("failed to check copy file source for symlinks: %w", err)
}
if fileInfo.Mode()&os.ModeSymlink == os.ModeSymlink {
return 0, ErrSymlinkCopy
}
}

srcF, err := os.Open(src)
if err != nil {
return 0, err
Expand Down
8 changes: 7 additions & 1 deletion get_file_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,13 @@ func (g *FileGetter) GetFile(dst string, u *url.URL) error {
return os.Symlink(path, dst)
}

var disableSymlinks bool

if g.client != nil && g.client.DisableSymlinks {
disableSymlinks = true
}

// Copy
_, err = copyFile(ctx, dst, path, fi.Mode(), g.client.umask())
_, err = copyFile(ctx, dst, path, disableSymlinks, fi.Mode(), g.client.umask())
return err
}
8 changes: 7 additions & 1 deletion get_file_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,14 @@ func (g *FileGetter) GetFile(dst string, u *url.URL) error {
}
}

var disableSymlinks bool

if g.client != nil && g.client.DisableSymlinks {
disableSymlinks = true
}

// Copy
_, err = copyFile(ctx, dst, path, 0666, g.client.umask())
_, err = copyFile(ctx, dst, path, disableSymlinks, 0666, g.client.umask())
return err
}

Expand Down
28 changes: 26 additions & 2 deletions get_gcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package getter
import (
"context"
"fmt"
"golang.org/x/oauth2"
"google.golang.org/api/option"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"

"golang.org/x/oauth2"
"google.golang.org/api/option"

"cloud.google.com/go/storage"
"google.golang.org/api/iterator"
Expand All @@ -19,11 +21,21 @@ import (
// a GCS bucket.
type GCSGetter struct {
getter

// Timeout sets a deadline which all GCS operations should
// complete within. Zero value means no timeout.
Timeout time.Duration
}

func (g *GCSGetter) ClientMode(u *url.URL) (ClientMode, error) {
ctx := g.Context()

if g.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, g.Timeout)
defer cancel()
}

// Parse URL
bucket, object, _, err := g.parseURL(u)
if err != nil {
Expand Down Expand Up @@ -61,6 +73,12 @@ func (g *GCSGetter) ClientMode(u *url.URL) (ClientMode, error) {
func (g *GCSGetter) Get(dst string, u *url.URL) error {
ctx := g.Context()

if g.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, g.Timeout)
defer cancel()
}

// Parse URL
bucket, object, _, err := g.parseURL(u)
if err != nil {
Expand Down Expand Up @@ -120,6 +138,12 @@ func (g *GCSGetter) Get(dst string, u *url.URL) error {
func (g *GCSGetter) GetFile(dst string, u *url.URL) error {
ctx := g.Context()

if g.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, g.Timeout)
defer cancel()
}

// Parse URL
bucket, object, fragment, err := g.parseURL(u)
if err != nil {
Expand Down
Loading

0 comments on commit a2ebce9

Please sign in to comment.