diff --git a/client.go b/client.go index 5817d7384..05f2f5f65 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package getter import ( "context" + "errors" "fmt" "io/ioutil" "os" @@ -13,6 +14,9 @@ import ( safetemp "github.com/hashicorp/go-safetemp" ) +// ErrSymlinkCopy means that a copy of a symlink was encountered on a request with DisableSymlinks enabled. +var ErrSymlinkCopy = errors.New("copying of symlinks has been disabled") + // Client is a client for downloading things. // // Top-level functions such as Get are shortcuts for interacting with a client. @@ -76,6 +80,9 @@ type Client struct { // This is identical to tls.Config.InsecureSkipVerify. Insecure bool + // DisableSymlinks rejects copies of symlinks with ErrSymlinkCopy when true. + DisableSymlinks bool + Options []ClientOption } @@ -123,6 +130,17 @@ func (c *Client) Get() error { dst := c.Dst src, subDir := SourceDirSubdir(src) if subDir != "" { + // Check if the subdirectory is attempting to traverse upwards, outside of + // the cloned repository path. + subDir := filepath.Clean(subDir) + if containsDotDot(subDir) { + return fmt.Errorf("subdirectory component contain path traversal out of the repository") + } + // Prevent absolute paths, remove a leading path separator from the subdirectory + if subDir[0] == os.PathSeparator { + subDir = subDir[1:] + } + td, tdcloser, err := safetemp.Dir("", "getter") if err != nil { return err @@ -230,6 +248,10 @@ func (c *Client) Get() error { filename = v } + if containsDotDot(filename) { + return fmt.Errorf("filename query parameter contain path traversal") + } + dst = filepath.Join(dst, filename) } } @@ -318,7 +340,7 @@ func (c *Client) Get() error { return err } - return copyDir(c.Ctx, realDst, subDir, false, c.umask()) + return copyDir(c.Ctx, realDst, subDir, false, c.DisableSymlinks, c.umask()) } return nil diff --git a/client_option.go b/client_option.go index c1ee413b0..b16413753 100644 --- a/client_option.go +++ b/client_option.go @@ -1,46 +1,100 @@ package getter -import "context" +import ( + "context" + 
"os" +) -// A ClientOption allows to configure a client +// ClientOption is used to configure a client. type ClientOption func(*Client) error -// Configure configures a client with options. +// Configure applies all of the given client options, along with any default +// behavior including context, decompressors, detectors, and getters used by +// the client. func (c *Client) Configure(opts ...ClientOption) error { + // If the context has not been configured use the background context. if c.Ctx == nil { c.Ctx = context.Background() } + + // Store the options used to configure this client. c.Options = opts + + // Apply all of the client options. for _, opt := range opts { err := opt(c) if err != nil { return err } } - // Default decompressor values + + // If the client was not configured with any Decompressors, Detectors, + // or Getters, use the default values for each. if c.Decompressors == nil { c.Decompressors = Decompressors } - // Default detector values if c.Detectors == nil { c.Detectors = Detectors } - // Default getter values if c.Getters == nil { c.Getters = Getters } + // Set the client for each getter, so the top-level client can know + // the getter-specific client functions or progress tracking. for _, getter := range c.Getters { getter.SetClient(c) } + return nil } // WithContext allows to pass a context to operation // in order to be able to cancel a download in progress. -func WithContext(ctx context.Context) func(*Client) error { +func WithContext(ctx context.Context) ClientOption { return func(c *Client) error { c.Ctx = ctx return nil } } + +// WithDecompressors specifies which decompressors are available. +func WithDecompressors(decompressors map[string]Decompressor) ClientOption { + return func(c *Client) error { + c.Decompressors = decompressors + return nil + } +} + +// WithDetectors specifies which detectors are available. 
+func WithDetectors(detectors []Detector) ClientOption { + return func(c *Client) error { + c.Detectors = detectors + return nil + } +} + +// WithGetters specifies which getters are available. +func WithGetters(getters map[string]Getter) ClientOption { + return func(c *Client) error { + c.Getters = getters + return nil + } +} + +// WithMode specifies which client mode the getters should operate in. +func WithMode(mode ClientMode) ClientOption { + return func(c *Client) error { + c.Mode = mode + return nil + } +} + +// WithUmask specifies how to mask file permissions when storing local +// files or decompressing an archive. +func WithUmask(mode os.FileMode) ClientOption { + return func(c *Client) error { + c.Umask = mode + return nil + } +} diff --git a/copy_dir.go b/copy_dir.go index a629306b7..646c283db 100644 --- a/copy_dir.go +++ b/copy_dir.go @@ -2,6 +2,7 @@ package getter import ( "context" + "fmt" "os" "path/filepath" "strings" @@ -16,8 +17,11 @@ func mode(mode, umask os.FileMode) os.FileMode { // should already exist. // // If ignoreDot is set to true, then dot-prefixed files/folders are ignored. 
-func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask os.FileMode) error { - src, err := filepath.EvalSymlinks(src) +func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, disableSymlinks bool, umask os.FileMode) error { + // We can safely evaluate the symlinks here, even if disabled, because they + // will be checked before actual use in walkFn and copyFile + var err error + src, err = filepath.EvalSymlinks(src) if err != nil { return err } @@ -26,6 +30,20 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask if err != nil { return err } + + if disableSymlinks { + fileInfo, err := os.Lstat(path) + if err != nil { + return fmt.Errorf("failed to check copy file source for symlinks: %w", err) + } + if fileInfo.Mode()&os.ModeSymlink == os.ModeSymlink { + return ErrSymlinkCopy + } + // if info.Mode()&os.ModeSymlink == os.ModeSymlink { + // return ErrSymlinkCopy + // } + } + if path == src { return nil } @@ -59,7 +77,7 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask } // If we have a file, copy the contents. 
- _, err = copyFile(ctx, dstPath, path, info.Mode(), umask) + _, err = copyFile(ctx, dstPath, path, disableSymlinks, info.Mode(), umask) return err } diff --git a/get_file_copy.go b/get_file_copy.go index 29abbd1aa..6eeda23ca 100644 --- a/get_file_copy.go +++ b/get_file_copy.go @@ -2,6 +2,7 @@ package getter import ( "context" + "fmt" "io" "os" ) @@ -49,7 +50,17 @@ func copyReader(dst string, src io.Reader, fmode, umask os.FileMode) error { } // copyFile copies a file in chunks from src path to dst path, using umask to create the dst file -func copyFile(ctx context.Context, dst, src string, fmode, umask os.FileMode) (int64, error) { +func copyFile(ctx context.Context, dst, src string, disableSymlinks bool, fmode, umask os.FileMode) (int64, error) { + if disableSymlinks { + fileInfo, err := os.Lstat(src) + if err != nil { + return 0, fmt.Errorf("failed to check copy file source for symlinks: %w", err) + } + if fileInfo.Mode()&os.ModeSymlink == os.ModeSymlink { + return 0, ErrSymlinkCopy + } + } + srcF, err := os.Open(src) if err != nil { return 0, err diff --git a/get_file_unix.go b/get_file_unix.go index 40ebc5af2..a14a38263 100644 --- a/get_file_unix.go +++ b/get_file_unix.go @@ -87,7 +87,13 @@ func (g *FileGetter) GetFile(dst string, u *url.URL) error { return os.Symlink(path, dst) } + var disableSymlinks bool + + if g.client != nil && g.client.DisableSymlinks { + disableSymlinks = true + } + // Copy - _, err = copyFile(ctx, dst, path, fi.Mode(), g.client.umask()) + _, err = copyFile(ctx, dst, path, disableSymlinks, fi.Mode(), g.client.umask()) return err } diff --git a/get_file_windows.go b/get_file_windows.go index 909d5b006..31146f575 100644 --- a/get_file_windows.go +++ b/get_file_windows.go @@ -111,8 +111,14 @@ func (g *FileGetter) GetFile(dst string, u *url.URL) error { } } + var disableSymlinks bool + + if g.client != nil && g.client.DisableSymlinks { + disableSymlinks = true + } + // Copy - _, err = copyFile(ctx, dst, path, 0666, g.client.umask()) + _, 
err = copyFile(ctx, dst, path, disableSymlinks, 0666, g.client.umask()) return err } diff --git a/get_gcs.go b/get_gcs.go index 678f9e685..abf2f1d4f 100644 --- a/get_gcs.go +++ b/get_gcs.go @@ -3,13 +3,15 @@ package getter import ( "context" "fmt" - "golang.org/x/oauth2" - "google.golang.org/api/option" "net/url" "os" "path/filepath" "strconv" "strings" + "time" + + "golang.org/x/oauth2" + "google.golang.org/api/option" "cloud.google.com/go/storage" "google.golang.org/api/iterator" @@ -19,11 +21,21 @@ import ( // a GCS bucket. type GCSGetter struct { getter + + // Timeout sets a deadline which all GCS operations should + // complete within. Zero value means no timeout. + Timeout time.Duration } func (g *GCSGetter) ClientMode(u *url.URL) (ClientMode, error) { ctx := g.Context() + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, _, err := g.parseURL(u) if err != nil { @@ -61,6 +73,12 @@ func (g *GCSGetter) ClientMode(u *url.URL) (ClientMode, error) { func (g *GCSGetter) Get(dst string, u *url.URL) error { ctx := g.Context() + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, _, err := g.parseURL(u) if err != nil { @@ -120,6 +138,12 @@ func (g *GCSGetter) Get(dst string, u *url.URL) error { func (g *GCSGetter) GetFile(dst string, u *url.URL) error { ctx := g.Context() + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, fragment, err := g.parseURL(u) if err != nil { diff --git a/get_git.go b/get_git.go index 119fa6a3b..db89edef8 100644 --- a/get_git.go +++ b/get_git.go @@ -14,6 +14,7 @@ import ( "runtime" "strconv" "strings" + "time" urlhelper "github.com/hashicorp/go-getter/helper/url" safetemp "github.com/hashicorp/go-safetemp" @@ -24,6 +25,10 @@ import ( 
// a git repository. type GitGetter struct { getter + + // Timeout sets a deadline which all git CLI operations should + // complete within. Zero value means no timeout. + Timeout time.Duration } var defaultBranchRegexp = regexp.MustCompile(`\s->\sorigin/(.*)`) @@ -35,6 +40,13 @@ func (g *GitGetter) ClientMode(_ *url.URL) (ClientMode, error) { func (g *GitGetter) Get(dst string, u *url.URL) error { ctx := g.Context() + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + if _, err := exec.LookPath("git"); err != nil { return fmt.Errorf("git must be available and on the PATH") } @@ -76,7 +88,7 @@ func (g *GitGetter) Get(dst string, u *url.URL) error { var sshKeyFile string if sshKey != "" { // Check that the git version is sufficiently new. - if err := checkGitVersion("2.3"); err != nil { + if err := checkGitVersion(ctx, "2.3"); err != nil { return fmt.Errorf("Error using ssh key: %v", err) } @@ -123,7 +135,7 @@ func (g *GitGetter) Get(dst string, u *url.URL) error { // Next: check out the proper tag/branch if it is specified, and checkout if ref != "" { - if err := g.checkout(dst, ref); err != nil { + if err := g.checkout(ctx, dst, ref); err != nil { return err } } @@ -161,8 +173,8 @@ func (g *GitGetter) GetFile(dst string, u *url.URL) error { return fg.GetFile(dst, u) } -func (g *GitGetter) checkout(dst string, ref string) error { - cmd := exec.Command("git", "checkout", ref) +func (g *GitGetter) checkout(ctx context.Context, dst string, ref string) error { + cmd := exec.CommandContext(ctx, "git", "checkout", ref) cmd.Dir = dst return getRunCommand(cmd) } @@ -182,7 +194,7 @@ func (g *GitGetter) clone(ctx context.Context, dst, sshKeyFile string, u *url.UR originalRef := ref // we handle an unspecified ref differently than explicitly selecting the default branch below if ref == "" { - ref = findRemoteDefaultBranch(u) + ref = findRemoteDefaultBranch(ctx, u) } if depth > 0 { args = append(args, 
"--depth", strconv.Itoa(depth)) @@ -211,7 +223,7 @@ func (g *GitGetter) clone(ctx context.Context, dst, sshKeyFile string, u *url.UR // If we didn't add --depth and --branch above then we will now be // on the remote repository's default branch, rather than the selected // ref, so we'll need to fix that before we return. - return g.checkout(dst, originalRef) + return g.checkout(ctx, dst, originalRef) } return nil } @@ -226,18 +238,18 @@ func (g *GitGetter) update(ctx context.Context, dst, sshKeyFile, ref string, dep // Not a branch, switch to default branch. This will also catch // non-existent branches, in which case we want to switch to default // and then checkout the proper branch later. - ref = findDefaultBranch(dst) + ref = findDefaultBranch(ctx, dst) } // We have to be on a branch to pull - if err := g.checkout(dst, ref); err != nil { + if err := g.checkout(ctx, dst, ref); err != nil { return err } if depth > 0 { - cmd = exec.Command("git", "pull", "--depth", strconv.Itoa(depth), "--ff-only") + cmd = exec.CommandContext(ctx, "git", "pull", "--depth", strconv.Itoa(depth), "--ff-only") } else { - cmd = exec.Command("git", "pull", "--ff-only") + cmd = exec.CommandContext(ctx, "git", "pull", "--ff-only") } cmd.Dir = dst @@ -260,9 +272,9 @@ func (g *GitGetter) fetchSubmodules(ctx context.Context, dst, sshKeyFile string, // findDefaultBranch checks the repo's origin remote for its default branch // (generally "master"). "master" is returned if an origin default branch // can't be determined. 
-func findDefaultBranch(dst string) string { +func findDefaultBranch(ctx context.Context, dst string) string { var stdoutbuf bytes.Buffer - cmd := exec.Command("git", "branch", "-r", "--points-at", "refs/remotes/origin/HEAD") + cmd := exec.CommandContext(ctx, "git", "branch", "-r", "--points-at", "refs/remotes/origin/HEAD") cmd.Dir = dst cmd.Stdout = &stdoutbuf err := cmd.Run() @@ -275,9 +287,9 @@ func findDefaultBranch(dst string) string { // findRemoteDefaultBranch checks the remote repo's HEAD symref to return the remote repo's // default branch. "master" is returned if no HEAD symref exists. -func findRemoteDefaultBranch(u *url.URL) string { +func findRemoteDefaultBranch(ctx context.Context, u *url.URL) string { var stdoutbuf bytes.Buffer - cmd := exec.Command("git", "ls-remote", "--symref", u.String(), "HEAD") + cmd := exec.CommandContext(ctx, "git", "ls-remote", "--symref", u.String(), "HEAD") cmd.Stdout = &stdoutbuf err := cmd.Run() matches := lsRemoteSymRefRegexp.FindStringSubmatch(stdoutbuf.String()) @@ -326,13 +338,13 @@ func setupGitEnv(cmd *exec.Cmd, sshKeyFile string) { // checkGitVersion is used to check the version of git installed on the system // against a known minimum version. Returns an error if the installed version // is older than the given minimum. 
-func checkGitVersion(min string) error { +func checkGitVersion(ctx context.Context, min string) error { want, err := version.NewVersion(min) if err != nil { return err } - out, err := exec.Command("git", "version").Output() + out, err := exec.CommandContext(ctx, "git", "version").Output() if err != nil { return err } diff --git a/get_git_test.go b/get_git_test.go index 8cea9077e..df6ad0390 100644 --- a/get_git_test.go +++ b/get_git_test.go @@ -2,7 +2,10 @@ package getter import ( "bytes" + "context" "encoding/base64" + "errors" + "fmt" "io/ioutil" "net/url" "os" @@ -436,12 +439,12 @@ func TestGitGetter_gitVersion(t *testing.T) { os.Setenv("PATH", dir) // Asking for a higher version throws an error - if err := checkGitVersion("2.3"); err == nil { + if err := checkGitVersion(context.Background(), "2.3"); err == nil { t.Fatal("expect git version error") } // Passes when version is satisfied - if err := checkGitVersion("1.9"); err != nil { + if err := checkGitVersion(context.Background(), "1.9"); err != nil { t.Fatal(err) } } @@ -693,6 +696,120 @@ func TestGitGetter_setupGitEnvWithExisting_sshKey(t *testing.T) { } } +func TestGitGetter_subdirectory_symlink(t *testing.T) { + if !testHasGit { + t.Skip("git not found, skipping") + } + + g := new(GitGetter) + dst := tempDir(t) + + target, err := ioutil.TempFile("", "link-target") + if err != nil { + t.Fatal(err) + } + defer os.Remove(target.Name()) + + repo := testGitRepo(t, "repo-with-symlink") + innerDir := filepath.Join(repo.dir, "this-directory-contains-a-symlink") + if err := os.Mkdir(innerDir, 0700); err != nil { + t.Fatal(err) + } + path := filepath.Join(innerDir, "this-is-a-symlink") + if err := os.Symlink(target.Name(), path); err != nil { + t.Fatal(err) + } + + repo.git("add", path) + repo.git("commit", "-m", "Adding "+path) + + u, err := url.Parse(fmt.Sprintf("git::%s//this-directory-contains-a-symlink", repo.url.String())) + if err != nil { + t.Fatal(err) + } + + client := &Client{ + Src: u.String(), + Dst: 
dst, + Pwd: ".", + Mode: ClientModeDir, + DisableSymlinks: true, + Detectors: []Detector{ + new(GitDetector), + }, + Getters: map[string]Getter{ + "git": g, + }, + } + + err = client.Get() + + if runtime.GOOS == "windows" { + // Windows doesn't handle symlinks as one might expect with git. + // + // https://github.com/git-for-windows/git/wiki/Symbolic-Links + filepath.Walk(dst, func(path string, info os.FileInfo, err error) error { + if strings.Contains(path, "this-is-a-symlink") { + if info.Mode()&os.ModeSymlink == os.ModeSymlink { + // If you see this test fail in the future, you've probably enabled + // symlinks within git on your Windows system. Our CI/CD system does + // not do this, so this is the only way we can make this test + // make any sense. + t.Fatalf("windows git should not have cloned a symlink") + } + } + return nil + }) + } else { + // We can rely on POSIX compliant systems running git to do the right thing. + if err == nil { + t.Fatalf("expected client get to fail") + } + if !errors.Is(err, ErrSymlinkCopy) { + t.Fatalf("unexpected error: %v", err) + } + } + +} + +func TestGitGetter_subdirectory(t *testing.T) { + if !testHasGit { + t.Skip("git not found, skipping") + } + + g := new(GitGetter) + dst := tempDir(t) + + repo := testGitRepo(t, "empty-repo") + u, err := url.Parse(fmt.Sprintf("git::%s//../../../../../../etc/passwd", repo.url.String())) + if err != nil { + t.Fatal(err) + } + + client := &Client{ + Src: u.String(), + Dst: dst, + Pwd: ".", + + Mode: ClientModeDir, + + Detectors: []Detector{ + new(GitDetector), + }, + Getters: map[string]Getter{ + "git": g, + }, + } + + err = client.Get() + if err == nil { + t.Fatalf("expected client get to fail") + } + if !strings.Contains(err.Error(), "subdirectory component contain path traversal out of the repository") { + t.Fatalf("unexpected error: %v", err) + } +} + // gitRepo is a helper struct which controls a single temp git repo. 
type gitRepo struct { t *testing.T diff --git a/get_hg.go b/get_hg.go index 290649c91..afa3bde81 100644 --- a/get_hg.go +++ b/get_hg.go @@ -8,6 +8,7 @@ import ( "os/exec" "path/filepath" "runtime" + "time" urlhelper "github.com/hashicorp/go-getter/helper/url" safetemp "github.com/hashicorp/go-safetemp" @@ -17,6 +18,10 @@ import ( // a Mercurial repository. type HgGetter struct { getter + + // Timeout sets a deadline which all hg CLI operations should + // complete within. Zero value means no timeout. + Timeout time.Duration } func (g *HgGetter) ClientMode(_ *url.URL) (ClientMode, error) { @@ -25,6 +30,13 @@ func (g *HgGetter) ClientMode(_ *url.URL) (ClientMode, error) { func (g *HgGetter) Get(dst string, u *url.URL) error { ctx := g.Context() + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + if _, err := exec.LookPath("hg"); err != nil { return fmt.Errorf("hg must be available and on the PATH") } @@ -53,12 +65,12 @@ func (g *HgGetter) Get(dst string, u *url.URL) error { return err } if err != nil { - if err := g.clone(dst, newURL); err != nil { + if err := g.clone(ctx, dst, newURL); err != nil { return err } } - if err := g.pull(dst, newURL); err != nil { + if err := g.pull(ctx, dst, newURL); err != nil { return err } @@ -101,13 +113,13 @@ func (g *HgGetter) GetFile(dst string, u *url.URL) error { return fg.GetFile(dst, u) } -func (g *HgGetter) clone(dst string, u *url.URL) error { - cmd := exec.Command("hg", "clone", "-U", u.String(), dst) +func (g *HgGetter) clone(ctx context.Context, dst string, u *url.URL) error { + cmd := exec.CommandContext(ctx, "hg", "clone", "-U", "--", u.String(), dst) return getRunCommand(cmd) } -func (g *HgGetter) pull(dst string, u *url.URL) error { - cmd := exec.Command("hg", "pull") +func (g *HgGetter) pull(ctx context.Context, dst string, u *url.URL) error { + cmd := exec.CommandContext(ctx, "hg", "pull") cmd.Dir = dst return getRunCommand(cmd) } @@ 
-115,7 +127,7 @@ func (g *HgGetter) pull(dst string, u *url.URL) error { func (g *HgGetter) update(ctx context.Context, dst string, u *url.URL, rev string) error { args := []string{"update"} if rev != "" { - args = append(args, rev) + args = append(args, "--", rev) } cmd := exec.CommandContext(ctx, "hg", args...) diff --git a/get_hg_test.go b/get_hg_test.go index ee1657945..7ed446691 100644 --- a/get_hg_test.go +++ b/get_hg_test.go @@ -1,9 +1,11 @@ package getter import ( + "net/url" "os" "os/exec" "path/filepath" + "strings" "testing" ) @@ -97,3 +99,45 @@ func TestHgGetter_GetFile(t *testing.T) { } assertContents(t, dst, "Hello\n") } + +func TestHgGetter_HgArgumentsNotAllowed(t *testing.T) { + if !testHasHg { + t.Log("hg not found, skipping") + t.Skip() + } + + g := new(HgGetter) + + // If arguments are allowed in the destination, this Get call will fail + dst := "--config=alias.clone=!false" + defer os.RemoveAll(dst) + err := g.Get(dst, testModuleURL("basic-hg")) + if err != nil { + t.Fatalf("Expected no err, got: %s", err) + } + + dst = tempDir(t) + // Test arguments passed into the `rev` parameter + // This clone call will fail regardless, but an exit code of 1 indicates + // that the `false` command executed + // We are expecting an hg parse error + err = g.Get(dst, testModuleURL("basic-hg?rev=--config=alias.update=!false")) + if err != nil { + if !strings.Contains(err.Error(), "hg: parse error") { + t.Fatalf("Expected no err, got: %s", err) + } + } + + dst = tempDir(t) + // Test arguments passed in the repository URL + // This Get call will fail regardless, but it should fail + // because the repository can't be found. 
+ // Other failures indicate that hg interpreted the argument passed in the URL + err = g.Get(dst, &url.URL{Path: "--config=alias.clone=false"}) + if err != nil { + if !strings.Contains(err.Error(), "repository --config=alias.clone=false not found") { + t.Fatalf("Expected no err, got: %s", err) + } + } + +} diff --git a/get_http.go b/get_http.go index f5a39c5cb..36d38c469 100644 --- a/get_http.go +++ b/get_http.go @@ -11,6 +11,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/hashicorp/go-cleanhttp" safetemp "github.com/hashicorp/go-safetemp" @@ -28,7 +29,9 @@ import ( // wish. The response must be a 2xx. // // First, a header is looked for "X-Terraform-Get" which should contain -// a source URL to download. +// a source URL to download. This source must use one of the configured +// protocols and getters for the client, or "http"/"https" if using +// the HttpGetter directly. // // If the header is not present, then a meta tag is searched for named // "terraform-get" and the content should be a source URL. @@ -52,6 +55,36 @@ type HttpGetter struct { // and as such it needs to be initialized before use, via something like // make(http.Header). Header http.Header + + // DoNotCheckHeadFirst configures the client to NOT check if the server + // supports HEAD requests. + DoNotCheckHeadFirst bool + + // HeadFirstTimeout configures the client to enforce a timeout when + // the server supports HEAD requests. + // + // The zero value means no timeout. + HeadFirstTimeout time.Duration + + // ReadTimeout configures the client to enforce a timeout when + // making a request to an HTTP server and reading its response body. + // + // The zero value means no timeout. + ReadTimeout time.Duration + + // MaxBytes limits the number of bytes that will be read from an HTTP + // response body returned from a server. The zero value means no limit. 
+ MaxBytes int64 + + // XTerraformGetLimit configures how many times the client will follow + // the "X-Terraform-Get" header value. + // + // The zero value means no limit. + XTerraformGetLimit int + + // XTerraformGetDisabled disables the client's usage of the "X-Terraform-Get" + // header value. + XTerraformGetDisabled bool } func (g *HttpGetter) ClientMode(u *url.URL) (ClientMode, error) { @@ -61,8 +94,115 @@ func (g *HttpGetter) ClientMode(u *url.URL) (ClientMode, error) { return ClientModeFile, nil } +type contextKey int + +const ( + xTerraformGetDisable contextKey = 0 + xTerraformGetLimit contextKey = 1 + xTerraformGetLimitCurrentValue contextKey = 2 + httpClientValue contextKey = 3 + httpMaxBytesValue contextKey = 4 +) + +func xTerraformGetDisabled(ctx context.Context) bool { + value, ok := ctx.Value(xTerraformGetDisable).(bool) + if !ok { + return false + } + return value +} + +func xTerraformGetLimitCurrentValueFromContext(ctx context.Context) int { + value, ok := ctx.Value(xTerraformGetLimitCurrentValue).(int) + if !ok { + return 1 + } + return value +} + +func xTerraformGetLimiConfiguredtFromContext(ctx context.Context) int { + value, ok := ctx.Value(xTerraformGetLimit).(int) + if !ok { + return 0 + } + return value +} + +func httpClientFromContext(ctx context.Context) *http.Client { + value, ok := ctx.Value(httpClientValue).(*http.Client) + if !ok { + return nil + } + return value +} + +func httpMaxBytesFromContext(ctx context.Context) int64 { + value, ok := ctx.Value(httpMaxBytesValue).(int64) + if !ok { + return 0 // no limit + } + return value +} + +type limitedWrappedReaderCloser struct { + underlying io.Reader + closeFn func() error +} + +func (l *limitedWrappedReaderCloser) Read(p []byte) (n int, err error) { + return l.underlying.Read(p) +} + +func (l *limitedWrappedReaderCloser) Close() (err error) { + return l.closeFn() +} + +func newLimitedWrappedReaderCloser(r io.ReadCloser, limit int64) io.ReadCloser { + return &limitedWrappedReaderCloser{ 
+ underlying: io.LimitReader(r, limit), + closeFn: r.Close, + } +} + func (g *HttpGetter) Get(dst string, u *url.URL) error { ctx := g.Context() + + // Optionally disable any X-Terraform-Get redirects. This is recommended for usage of + // this client outside of Terraform. This feature is likely not required if the + // source server can provide normal HTTP redirects. + if g.XTerraformGetDisabled { + ctx = context.WithValue(ctx, xTerraformGetDisable, g.XTerraformGetDisabled) + } + + // Optionally enforce a limit on X-Terraform-Get redirects. We check this for every + // invocation of this function, because the value is not passed down to subsequent + // client Get function invocations. + if g.XTerraformGetLimit > 0 { + ctx = context.WithValue(ctx, xTerraformGetLimit, g.XTerraformGetLimit) + } + + // If there was a limit on X-Terraform-Get redirects, check the current count value. + // + // If the value is greater than the limit, return an error. Otherwise, increment the value, + // and include it in the context to be passed along in all the subsequent client + // Get function invocations. + if limit := xTerraformGetLimiConfiguredtFromContext(ctx); limit > 0 { + currentValue := xTerraformGetLimitCurrentValueFromContext(ctx) + + if currentValue > limit { + return fmt.Errorf("too many X-Terraform-Get redirects: %d", currentValue) + } + + currentValue++ + + ctx = context.WithValue(ctx, xTerraformGetLimitCurrentValue, currentValue) + } + + // Optionally enforce a maximum HTTP response body size. + if g.MaxBytes > 0 { + ctx = context.WithValue(ctx, httpMaxBytesValue, g.MaxBytes) + } + // Copy the URL so we can modify it var newU url.URL = *u u = &newU @@ -74,22 +214,40 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error { } } + // If the HTTP client is nil, check if there is one available in the context, + // otherwise create one using cleanhttp's default transport. 
if g.Client == nil { - g.Client = httpClient - if g.client != nil && g.client.Insecure { - insecureTransport := cleanhttp.DefaultTransport() - insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - g.Client.Transport = insecureTransport + if client := httpClientFromContext(ctx); client != nil { + g.Client = client + } else { + client := httpClient + if g.client != nil && g.client.Insecure { + insecureTransport := cleanhttp.DefaultTransport() + insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + client.Transport = insecureTransport + } + g.Client = client } } + // Pass along the configured HTTP client in the context for usage with the X-Terraform-Get feature. + ctx = context.WithValue(ctx, httpClientValue, g.Client) + // Add terraform-get to the parameter. q := u.Query() q.Add("terraform-get", "1") u.RawQuery = q.Encode() + readCtx := ctx + + if g.ReadTimeout > 0 { + var cancel context.CancelFunc + readCtx, cancel = context.WithTimeout(ctx, g.ReadTimeout) + defer cancel() + } + // Get the URL - req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + req, err := http.NewRequestWithContext(readCtx, "GET", u.String(), nil) if err != nil { return err } @@ -102,18 +260,28 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error { if err != nil { return err } - defer resp.Body.Close() + + body := resp.Body + + if maxBytes := httpMaxBytesFromContext(ctx); maxBytes > 0 { + body = newLimitedWrappedReaderCloser(body, maxBytes) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { return fmt.Errorf("bad response code: %d", resp.StatusCode) } - // Extract the source URL + if disabled := xTerraformGetDisabled(ctx); disabled { + return nil + } + + // Extract the source URL, var source string if v := resp.Header.Get("X-Terraform-Get"); v != "" { source = v } else { - source, err = g.parseMeta(resp.Body) + source, err = g.parseMeta(readCtx, body) if err != nil { return err } @@ -127,9 +295,43 @@ func (g *HttpGetter) 
Get(dst string, u *url.URL) error { source, subDir := SourceDirSubdir(source) if subDir == "" { var opts []ClientOption + + // Check if the protocol was switched to one which was not configured. + // + // Otherwise, all default getters are allowed. + if g.client != nil && g.client.Getters != nil { + protocol := strings.Split(source, ":")[0] + _, allowed := g.client.Getters[protocol] + if !allowed { + return fmt.Errorf("no getter available for X-Terraform-Get source protocol: %q", protocol) + } + } + + // Add any getter client options. if g.client != nil { opts = g.client.Options } + + // If the client is nil, we know we're using the HttpGetter directly. In this case, + // we don't know exactly which protocols are configured, but we can make a good guess. + // + // This prevents all default getters from being allowed when only using the + // HttpGetter directly. To enable protocol switching, a client "wrapper" must + // be used. + if g.client == nil { + opts = append(opts, WithGetters(map[string]Getter{ + "http": g, + "https": g, + })) + } + + // Ensure we pass along the context we constructed in this function. + // + // This is especially important to enforce a limit on X-Terraform-Get redirects + // which could be set up, if configured, at the top of this function. + opts = append(opts, WithContext(ctx)) + + // Note: this allows the protocol to be switched to another configured getter. return Get(dst, source, opts...) } @@ -145,6 +347,12 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error { // appended. func (g *HttpGetter) GetFile(dst string, src *url.URL) error { ctx := g.Context() + + // Optionally enforce a maximum HTTP response body size. 
+ if g.MaxBytes > 0 { + ctx = context.WithValue(ctx, httpMaxBytesValue, g.MaxBytes) + } + if g.Netrc { // Add auth from netrc if we can if err := addAuthFromNetrc(src); err != nil { @@ -171,31 +379,45 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error { } } - var currentFileSize int64 + var ( + currentFileSize int64 + req *http.Request + ) - // We first make a HEAD request so we can check - // if the server supports range queries. If the server/URL doesn't - // support HEAD requests, we just fall back to GET. - req, err := http.NewRequestWithContext(ctx, "HEAD", src.String(), nil) - if err != nil { - return err - } - if g.Header != nil { - req.Header = g.Header.Clone() - } - headResp, err := g.Client.Do(req) - if err == nil { - headResp.Body.Close() - if headResp.StatusCode == 200 { - // If the HEAD request succeeded, then attempt to set the range - // query if we can. - if headResp.Header.Get("Accept-Ranges") == "bytes" && headResp.ContentLength >= 0 { - if fi, err := f.Stat(); err == nil { - if _, err = f.Seek(0, io.SeekEnd); err == nil { - currentFileSize = fi.Size() - if currentFileSize >= headResp.ContentLength { - // file already present - return nil + if !g.DoNotCheckHeadFirst { + headCtx := ctx + + if g.HeadFirstTimeout > 0 { + var cancel context.CancelFunc + + headCtx, cancel = context.WithTimeout(ctx, g.HeadFirstTimeout) + defer cancel() + } + + // We first make a HEAD request so we can check + // if the server supports range queries. If the server/URL doesn't + // support HEAD requests, we just fall back to GET. + req, err = http.NewRequestWithContext(headCtx, "HEAD", src.String(), nil) + if err != nil { + return err + } + if g.Header != nil { + req.Header = g.Header.Clone() + } + headResp, err := g.Client.Do(req) + if err == nil { + headResp.Body.Close() + if headResp.StatusCode == 200 { + // If the HEAD request succeeded, then attempt to set the range + // query if we can. 
+ if headResp.Header.Get("Accept-Ranges") == "bytes" && headResp.ContentLength >= 0 { + if fi, err := f.Stat(); err == nil { + if _, err = f.Seek(0, io.SeekEnd); err == nil { + currentFileSize = fi.Size() + if currentFileSize >= headResp.ContentLength { + // file already present + return nil + } } } } @@ -203,7 +425,15 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error { } } - req, err = http.NewRequestWithContext(ctx, "GET", src.String(), nil) + readCtx := ctx + + if g.ReadTimeout > 0 { + var cancel context.CancelFunc + readCtx, cancel = context.WithTimeout(ctx, g.ReadTimeout) + defer cancel() + } + + req, err = http.NewRequestWithContext(readCtx, "GET", src.String(), nil) if err != nil { return err } @@ -228,6 +458,10 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error { body := resp.Body + if maxBytes := httpMaxBytesFromContext(ctx); maxBytes > 0 { + body = newLimitedWrappedReaderCloser(body, maxBytes) + } + if g.client != nil && g.client.ProgressListener != nil { // track download fn := filepath.Base(src.EscapedPath()) @@ -235,7 +469,7 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error { } defer body.Close() - n, err := Copy(ctx, f, body) + n, err := Copy(readCtx, f, body) if err == nil && n < resp.ContentLength { err = io.ErrShortWrite } @@ -284,18 +518,28 @@ func (g *HttpGetter) getSubdir(ctx context.Context, dst, source, subDir string) return err } - return copyDir(ctx, dst, sourcePath, false, g.client.umask()) + var disableSymlinks bool + + if g.client != nil && g.client.DisableSymlinks { + disableSymlinks = true + } + + return copyDir(ctx, dst, sourcePath, false, disableSymlinks, g.client.umask()) } // parseMeta looks for the first meta tag in the given reader that // will give us the source URL. 
-func (g *HttpGetter) parseMeta(r io.Reader) (string, error) { +func (g *HttpGetter) parseMeta(ctx context.Context, r io.Reader) (string, error) { d := xml.NewDecoder(r) d.CharsetReader = charsetReader d.Strict = false var err error var t xml.Token for { + if ctx.Err() != nil { + return "", fmt.Errorf("context error while parsing meta tag: %w", ctx.Err()) + } + t, err = d.Token() if err != nil { if err == io.EOF { diff --git a/get_http_test.go b/get_http_test.go index e7f59539f..fa51648e3 100644 --- a/get_http_test.go +++ b/get_http_test.go @@ -9,12 +9,15 @@ import ( "io/ioutil" "net" "net/http" + "net/http/httputil" "net/url" "os" "path/filepath" "strconv" "strings" "testing" + + "github.com/hashicorp/go-cleanhttp" ) func TestHttpGetter_impl(t *testing.T) { @@ -34,8 +37,27 @@ func TestHttpGetter_header(t *testing.T) { u.Host = ln.Addr().String() u.Path = "/header" - // Get it! - if err := g.Get(dst, &u); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(dst, &u) + + if !strings.Contains(err.Error(), "download not supported for scheme 'file'") { + t.Fatalf("unexpected error: %v", err) + } + + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: map[string]Getter{ + "http": g, + "file": new(FileGetter), + }, + Src: u.String(), + Dst: dst, + Mode: ClientModeDir, + } + + err = c.Get() + + if err != nil { t.Fatalf("err: %s", err) } @@ -44,6 +66,7 @@ func TestHttpGetter_header(t *testing.T) { if _, err := os.Stat(mainPath); err != nil { t.Fatalf("err: %s", err) } + } func TestHttpGetter_requestHeader(t *testing.T) { @@ -87,8 +110,27 @@ func TestHttpGetter_meta(t *testing.T) { u.Host = ln.Addr().String() u.Path = "/meta" - // Get it! - if err := g.Get(dst, &u); err != nil { + // Get it, which should error because it uses the file protocol. 
+ err := g.Get(dst, &u) + + if !strings.Contains(err.Error(), "download not supported for scheme 'file'") { + t.Fatalf("unexpected error: %v", err) + } + + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: map[string]Getter{ + "http": g, + "file": new(FileGetter), + }, + Src: u.String(), + Dst: dst, + Mode: ClientModeDir, + } + + err = c.Get() + + if err != nil { t.Fatalf("err: %s", err) } @@ -330,14 +372,27 @@ func TestHttpGetter_auth(t *testing.T) { u.Path = "/meta-auth" u.User = url.UserPassword("foo", "bar") - // Get it! - if err := g.Get(dst, &u); err != nil { - t.Fatalf("err: %s", err) + // Get it, which should error because it uses the file protocol. + err := g.Get(dst, &u) + + if !strings.Contains(err.Error(), "download not supported for scheme 'file'") { + t.Fatalf("unexpected error: %v", err) } - // Verify the main file exists - mainPath := filepath.Join(dst, "main.tf") - if _, err := os.Stat(mainPath); err != nil { + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: map[string]Getter{ + "http": g, + "file": new(FileGetter), + }, + Src: u.String(), + Dst: dst, + Mode: ClientModeDir, + } + + err = c.Get() + + if err != nil { t.Fatalf("err: %s", err) } } @@ -360,8 +415,27 @@ func TestHttpGetter_authNetrc(t *testing.T) { defer closer() defer tempEnv(t, "NETRC", path)() - // Get it! - if err := g.Get(dst, &u); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(dst, &u) + + if !strings.Contains(err.Error(), "download not supported for scheme 'file'") { + t.Fatalf("unexpected error: %v", err) + } + + // But, using a wrapper client with a file getter will work. 
+ c := &Client{ + Getters: map[string]Getter{ + "http": g, + "file": new(FileGetter), + }, + Src: u.String(), + Dst: dst, + Mode: ClientModeDir, + } + + err = c.Get() + + if err != nil { t.Fatalf("err: %s", err) } @@ -399,8 +473,27 @@ func TestHttpGetter_cleanhttp(t *testing.T) { u.Host = ln.Addr().String() u.Path = "/header" - // Get it! - if err := g.Get(dst, &u); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(dst, &u) + + if !strings.Contains(err.Error(), "download not supported for scheme 'file'") { + t.Fatalf("unexpected error: %v", err) + } + + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: map[string]Getter{ + "http": g, + "file": new(FileGetter), + }, + Src: u.String(), + Dst: dst, + Mode: ClientModeDir, + } + + err = c.Get() + + if err != nil { t.Fatalf("err: %s", err) } } @@ -441,6 +534,326 @@ func TestHttpGetter__RespectsContextCanceled(t *testing.T) { } } +func TestHttpGetter__XTerraformGetLimit(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetLoop(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/loop" + dst := tempDir(t) + + g := new(HttpGetter) + g.XTerraformGetLimit = 10 + g.client = &Client{ + Ctx: ctx, + } + g.Client = &http.Client{} + + err := g.Get(dst, &u) + if !strings.Contains(err.Error(), "too many X-Terraform-Get redirects") { + t.Fatalf("too many X-Terraform-Get redirects, got: %v", err) + } +} + +func TestHttpGetter__XTerraformGetDisabled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetLoop(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/loop" + dst := tempDir(t) + + g := new(HttpGetter) + g.XTerraformGetDisabled = true + g.client = &Client{ + Ctx: ctx, + } + g.Client = &http.Client{} + + err := g.Get(dst, &u) + if err 
!= nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestHttpGetter__XTerraformGetProxyBypass(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetProxyBypass(t) + + proxyLn := testHttpServerProxy(t, ln.Addr().String()) + + t.Logf("starting malicious server on: %v", ln.Addr().String()) + t.Logf("starting proxy on: %v", proxyLn.Addr().String()) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/start" + dst := tempDir(t) + + proxy, err := url.Parse(fmt.Sprintf("http://%s/", proxyLn.Addr().String())) + if err != nil { + t.Fatalf("failed to parse proxy URL: %v", err) + } + + transport := cleanhttp.DefaultTransport() + transport.Proxy = http.ProxyURL(proxy) + + httpGetter := new(HttpGetter) + httpGetter.XTerraformGetLimit = 10 + httpGetter.Client = &http.Client{ + Transport: transport, + } + + client := &Client{ + Ctx: ctx, + Getters: map[string]Getter{ + "http": httpGetter, + }, + } + + client.Src = u.String() + client.Dst = dst + + err = client.Get() + if err != nil { + t.Logf("client get error: %v", err) + } +} + +func TestHttpGetter__XTerraformGetConfiguredGettersBypass(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetConfiguredGettersBypass(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/start" + dst := tempDir(t) + + rt := hookableHTTPRoundTripper{ + before: func(req *http.Request) { + t.Logf("making request") + }, + RoundTripper: http.DefaultTransport, + } + + httpGetter := new(HttpGetter) + httpGetter.XTerraformGetLimit = 10 + httpGetter.Client = &http.Client{ + Transport: &rt, + } + + client := &Client{ + Ctx: ctx, + Mode: ClientModeDir, + Getters: map[string]Getter{ + "http": httpGetter, + }, + } + + t.Logf("%v", u.String()) + + client.Src = u.String() + client.Dst = dst + + err := client.Get() + if err != nil { + if 
!strings.Contains(err.Error(), "no getter available for X-Terraform-Get source protocol") { + t.Fatalf("expected no getter available for X-Terraform-Get source protocol, got: %v", err) + } + } +} + +func TestHttpGetter__endless_body(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithEndlessBody(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/" + dst := tempDir(t) + + httpGetter := new(HttpGetter) + httpGetter.MaxBytes = 10 + httpGetter.DoNotCheckHeadFirst = true + + client := &Client{ + Ctx: ctx, + Mode: ClientModeFile, + // Mode: ClientModeDir, + Getters: map[string]Getter{ + "http": httpGetter, + }, + } + + t.Logf("%v", u.String()) + + client.Src = u.String() + client.Dst = dst + + err := client.Get() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestHttpGetter_subdirLink(t *testing.T) { + ln := testHttpServerSubDir(t) + defer ln.Close() + + httpGetter := new(HttpGetter) + dst, err := ioutil.TempDir("", "tf") + if err != nil { + t.Fatalf("err: %s", err) + } + + t.Logf("dst: %q", dst) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/regular-subdir//meta-subdir" + + t.Logf("url: %q", u.String()) + + client := &Client{ + Src: u.String(), + Dst: dst, + Mode: ClientModeAny, + Getters: map[string]Getter{ + "http": httpGetter, + }, + } + + err = client.Get() + if err != nil { + t.Fatalf("get err: %v", err) + } +} + +func testHttpServerWithXTerraformGetLoop(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + header := fmt.Sprintf("http://%v:%v", ln.Addr().String(), "/loop") + + mux := http.NewServeMux() + mux.HandleFunc("/loop", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Terraform-Get", header) + t.Logf("serving loop") + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + 
return ln +} + +func testHttpServerWithXTerraformGetProxyBypass(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + header := fmt.Sprintf("http://%v/bypass", ln.Addr().String()) + + mux := http.NewServeMux() + mux.HandleFunc("/start/start", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Terraform-Get", header) + t.Logf("serving start") + }) + + mux.HandleFunc("/bypass", func(w http.ResponseWriter, r *http.Request) { + t.Fail() + t.Logf("bypassed proxy") + }) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + t.Logf("serving HTTP server path: %v", r.URL.Path) + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + +func testHttpServerWithXTerraformGetConfiguredGettersBypass(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + header := fmt.Sprintf("git::http://%v/some/repository.git", ln.Addr().String()) + + mux := http.NewServeMux() + mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Terraform-Get", header) + t.Logf("serving start") + }) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + t.Logf("serving git HTTP server path: %v", r.URL.Path) + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + +func testHttpServerProxy(t *testing.T, upstreamHost string) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + mux := http.NewServeMux() + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + t.Logf("serving proxy: %v: %#+v", r.URL.Path, r.Header) + // create the reverse proxy + proxy := httputil.NewSingleHostReverseProxy(r.URL) + // Note that ServeHttp is non blocking & uses a go routine under the hood + 
proxy.ServeHTTP(w, r) + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + func testHttpServer(t *testing.T) net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -504,6 +917,29 @@ func testHttpHandlerMetaAuth(w http.ResponseWriter, r *http.Request) { w.Write([]byte(fmt.Sprintf(testHttpMetaStr, testModuleURL("basic").String()))) } +func testHttpServerWithEndlessBody(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + for { + w.Write([]byte(".\n")) + } + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + func testHttpHandlerMetaSubdir(w http.ResponseWriter, r *http.Request) { w.Write([]byte(fmt.Sprintf(testHttpMetaStr, testModuleURL("basic//subdir").String()))) } @@ -548,6 +984,29 @@ func testHttpHandlerNoRange(w http.ResponseWriter, r *http.Request) { } } +func testHttpServerSubDir(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + t.Logf("serving: %v: %v: %#+[1]v", r.Method, r.URL.String(), r.Header) + } + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + const testHttpMetaStr = ` diff --git a/get_s3.go b/get_s3.go index ec864428e..7e0d853ba 100644 --- a/get_s3.go +++ b/get_s3.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" @@ -20,10 +21,22 @@ import ( // a S3 bucket. 
type S3Getter struct { getter + + // Timeout sets a deadline which all S3 operations should + // complete within. Zero value means no timeout. + Timeout time.Duration } func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) { // Parse URL + ctx := g.Context() + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + region, bucket, path, _, creds, err := g.parseUrl(u) if err != nil { return 0, err @@ -40,7 +53,7 @@ func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) { Bucket: aws.String(bucket), Prefix: aws.String(path), } - resp, err := client.ListObjects(req) + resp, err := client.ListObjectsWithContext(ctx, req) if err != nil { return 0, err } @@ -65,6 +78,12 @@ func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) { func (g *S3Getter) Get(dst string, u *url.URL) error { ctx := g.Context() + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL region, bucket, path, _, creds, err := g.parseUrl(u) if err != nil { @@ -106,7 +125,7 @@ func (g *S3Getter) Get(dst string, u *url.URL) error { req.Marker = aws.String(lastMarker) } - resp, err := client.ListObjects(req) + resp, err := client.ListObjectsWithContext(ctx, req) if err != nil { return err } @@ -141,6 +160,13 @@ func (g *S3Getter) Get(dst string, u *url.URL) error { func (g *S3Getter) GetFile(dst string, u *url.URL) error { ctx := g.Context() + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + region, bucket, path, version, creds, err := g.parseUrl(u) if err != nil { return err @@ -163,7 +189,7 @@ func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, ke req.VersionId = aws.String(version) } - resp, err := client.GetObject(req) + resp, err := client.GetObjectWithContext(ctx, req) if err != nil { return err } diff --git a/get_test.go 
b/get_test.go index a575b826e..34f9e87b9 100644 --- a/get_test.go +++ b/get_test.go @@ -492,6 +492,21 @@ func TestGetFile_filename(t *testing.T) { } } +func TestGetFile_filename_path_traversal(t *testing.T) { + dst := tempDir(t) + u := testModule("basic-file/foo.txt") + + u += "?filename=../../../../../../../../../../../../../tmp/bar.txt" + + err := GetAny(dst, u) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "filename query parameter contain path traversal") { + t.Fatalf("unexpected err: %s", err) + } +} + func TestGetFile_checksumSkip(t *testing.T) { dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) diff --git a/source.go b/source.go index dab6d400c..48ac9234e 100644 --- a/source.go +++ b/source.go @@ -58,7 +58,9 @@ func SourceDirSubdir(src string) (string, string) { // // The returned path is the full absolute path. func SubdirGlob(dst, subDir string) (string, error) { - matches, err := filepath.Glob(filepath.Join(dst, subDir)) + pattern := filepath.Join(dst, subDir) + + matches, err := filepath.Glob(pattern) if err != nil { return "", err }