Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiple fixes for go-getter #359

Merged
merged 18 commits into from
May 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package getter

import (
"context"
"errors"
"fmt"
"io/ioutil"
"os"
Expand All @@ -13,6 +14,9 @@ import (
safetemp "github.com/hashicorp/go-safetemp"
)

// ErrSymlinkCopy means that a copy of a symlink was encountered on a request with DisableSymlinks enabled.
var ErrSymlinkCopy = errors.New("copying of symlinks has been disabled")

// Client is a client for downloading things.
//
// Top-level functions such as Get are shortcuts for interacting with a client.
Expand Down Expand Up @@ -76,6 +80,9 @@ type Client struct {
// This is identical to tls.Config.InsecureSkipVerify.
Insecure bool

// Disable symlinks
DisableSymlinks bool

Options []ClientOption
}

Expand Down Expand Up @@ -123,6 +130,17 @@ func (c *Client) Get() error {
dst := c.Dst
src, subDir := SourceDirSubdir(src)
if subDir != "" {
// Check if the subdirectory is attempting to traverse updwards, outside of
// the cloned repository path.
subDir := filepath.Clean(subDir)
if containsDotDot(subDir) {
return fmt.Errorf("subdirectory component contain path traversal out of the repository")
}
// Prevent absolute paths, remove a leading path separator from the subdirectory
if subDir[0] == os.PathSeparator {
subDir = subDir[1:]
}

td, tdcloser, err := safetemp.Dir("", "getter")
if err != nil {
return err
Expand Down Expand Up @@ -230,6 +248,10 @@ func (c *Client) Get() error {
filename = v
}

if containsDotDot(filename) {
return fmt.Errorf("filename query parameter contain path traversal")
}

dst = filepath.Join(dst, filename)
}
}
Expand Down Expand Up @@ -318,7 +340,7 @@ func (c *Client) Get() error {
return err
}

return copyDir(c.Ctx, realDst, subDir, false, c.umask())
return copyDir(c.Ctx, realDst, subDir, false, c.DisableSymlinks, c.umask())
}

return nil
Expand Down
68 changes: 61 additions & 7 deletions client_option.go
Original file line number Diff line number Diff line change
@@ -1,46 +1,100 @@
package getter

import "context"
import (
"context"
"os"
)

// A ClientOption allows to configure a client
// ClientOption is used to configure a client.
type ClientOption func(*Client) error

// Configure configures a client with options.
// Configure applies all of the given client options, along with any default
// behavior including context, decompressors, detectors, and getters used by
// the client.
func (c *Client) Configure(opts ...ClientOption) error {
// If the context has not been configured use the background context.
if c.Ctx == nil {
c.Ctx = context.Background()
}

// Store the options used to configure this client.
c.Options = opts

// Apply all of the client options.
for _, opt := range opts {
err := opt(c)
if err != nil {
return err
}
}
// Default decompressor values

// If the client was not configured with any Decompressors, Detectors,
// or Getters, use the default values for each.
if c.Decompressors == nil {
c.Decompressors = Decompressors
}
// Default detector values
if c.Detectors == nil {
c.Detectors = Detectors
}
// Default getter values
if c.Getters == nil {
c.Getters = Getters
}

// Set the client for each getter, so the top-level client can know
// the getter-specific client functions or progress tracking.
for _, getter := range c.Getters {
getter.SetClient(c)
}

return nil
}

// WithContext allows to pass a context to operation
// in order to be able to cancel a download in progress.
func WithContext(ctx context.Context) func(*Client) error {
func WithContext(ctx context.Context) ClientOption {
return func(c *Client) error {
c.Ctx = ctx
return nil
}
}

// WithDecompressors specifies which Decompressor are available.
func WithDecompressors(decompressors map[string]Decompressor) ClientOption {
return func(c *Client) error {
c.Decompressors = decompressors
return nil
}
}

// WithDecompressors specifies which compressors are available.
func WithDetectors(detectors []Detector) ClientOption {
return func(c *Client) error {
c.Detectors = detectors
return nil
}
}

// WithGetters specifies which getters are available.
func WithGetters(getters map[string]Getter) ClientOption {
return func(c *Client) error {
c.Getters = getters
return nil
}
}

// WithMode specifies which client mode the getters should operate in.
func WithMode(mode ClientMode) ClientOption {
return func(c *Client) error {
c.Mode = mode
return nil
}
}

// WithUmask specifies how to mask file permissions when storing local
// files or decompressing an archive.
func WithUmask(mode os.FileMode) ClientOption {
return func(c *Client) error {
c.Umask = mode
return nil
}
}
24 changes: 21 additions & 3 deletions copy_dir.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package getter

import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
Expand All @@ -16,8 +17,11 @@ func mode(mode, umask os.FileMode) os.FileMode {
// should already exist.
//
// If ignoreDot is set to true, then dot-prefixed files/folders are ignored.
func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask os.FileMode) error {
src, err := filepath.EvalSymlinks(src)
func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, disableSymlinks bool, umask os.FileMode) error {
// We can safely evaluate the symlinks here, even if disabled, because they
// will be checked before actual use in walkFn and copyFile
var err error
src, err = filepath.EvalSymlinks(src)
if err != nil {
return err
}
Expand All @@ -26,6 +30,20 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask
if err != nil {
return err
}

if disableSymlinks {
fileInfo, err := os.Lstat(path)
if err != nil {
return fmt.Errorf("failed to check copy file source for symlinks: %w", err)
}
if fileInfo.Mode()&os.ModeSymlink == os.ModeSymlink {
return ErrSymlinkCopy
}
// if info.Mode()&os.ModeSymlink == os.ModeSymlink {
// return ErrSymlinkCopy
// }
}

if path == src {
return nil
}
Expand Down Expand Up @@ -59,7 +77,7 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask
}

// If we have a file, copy the contents.
_, err = copyFile(ctx, dstPath, path, info.Mode(), umask)
_, err = copyFile(ctx, dstPath, path, disableSymlinks, info.Mode(), umask)
return err
}

Expand Down
13 changes: 12 additions & 1 deletion get_file_copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package getter

import (
"context"
"fmt"
"io"
"os"
)
Expand Down Expand Up @@ -49,7 +50,17 @@ func copyReader(dst string, src io.Reader, fmode, umask os.FileMode) error {
}

// copyFile copies a file in chunks from src path to dst path, using umask to create the dst file
func copyFile(ctx context.Context, dst, src string, fmode, umask os.FileMode) (int64, error) {
func copyFile(ctx context.Context, dst, src string, disableSymlinks bool, fmode, umask os.FileMode) (int64, error) {
if disableSymlinks {
fileInfo, err := os.Lstat(src)
if err != nil {
return 0, fmt.Errorf("failed to check copy file source for symlinks: %w", err)
}
if fileInfo.Mode()&os.ModeSymlink == os.ModeSymlink {
return 0, ErrSymlinkCopy
}
}

srcF, err := os.Open(src)
if err != nil {
return 0, err
Expand Down
8 changes: 7 additions & 1 deletion get_file_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,13 @@ func (g *FileGetter) GetFile(dst string, u *url.URL) error {
return os.Symlink(path, dst)
}

var disableSymlinks bool

if g.client != nil && g.client.DisableSymlinks {
disableSymlinks = true
}

// Copy
_, err = copyFile(ctx, dst, path, fi.Mode(), g.client.umask())
_, err = copyFile(ctx, dst, path, disableSymlinks, fi.Mode(), g.client.umask())
return err
}
8 changes: 7 additions & 1 deletion get_file_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,14 @@ func (g *FileGetter) GetFile(dst string, u *url.URL) error {
}
}

var disableSymlinks bool

if g.client != nil && g.client.DisableSymlinks {
disableSymlinks = true
}

// Copy
_, err = copyFile(ctx, dst, path, 0666, g.client.umask())
_, err = copyFile(ctx, dst, path, disableSymlinks, 0666, g.client.umask())
return err
}

Expand Down
28 changes: 26 additions & 2 deletions get_gcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package getter
import (
"context"
"fmt"
"golang.org/x/oauth2"
"google.golang.org/api/option"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"

"golang.org/x/oauth2"
"google.golang.org/api/option"

"cloud.google.com/go/storage"
"google.golang.org/api/iterator"
Expand All @@ -19,11 +21,21 @@ import (
// a GCS bucket.
type GCSGetter struct {
getter

// Timeout sets a deadline which all GCS operations should
// complete within. Zero value means no timeout.
Timeout time.Duration
}

func (g *GCSGetter) ClientMode(u *url.URL) (ClientMode, error) {
ctx := g.Context()

if g.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, g.Timeout)
defer cancel()
}

// Parse URL
bucket, object, _, err := g.parseURL(u)
if err != nil {
Expand Down Expand Up @@ -61,6 +73,12 @@ func (g *GCSGetter) ClientMode(u *url.URL) (ClientMode, error) {
func (g *GCSGetter) Get(dst string, u *url.URL) error {
ctx := g.Context()

if g.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, g.Timeout)
defer cancel()
}

// Parse URL
bucket, object, _, err := g.parseURL(u)
if err != nil {
Expand Down Expand Up @@ -120,6 +138,12 @@ func (g *GCSGetter) Get(dst string, u *url.URL) error {
func (g *GCSGetter) GetFile(dst string, u *url.URL) error {
ctx := g.Context()

if g.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, g.Timeout)
defer cancel()
}

// Parse URL
bucket, object, fragment, err := g.parseURL(u)
if err != nil {
Expand Down
Loading