Skip to content

Commit

Permalink
Protect existing solutions from being overwritten by 'download' (exer…
Browse files Browse the repository at this point in the history
…cism#979)

Add an optional `--force` flag to the download command to overwrite
an existing exercise directory.
  • Loading branch information
haguro authored Feb 9, 2021
1 parent aa9dcfa commit ce8f497
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 2 deletions.
14 changes: 12 additions & 2 deletions cmd/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ func runDownload(cfg config.Config, flags *pflag.FlagSet, args []string) error {
metadata := download.payload.metadata()
dir := metadata.Exercise(usrCfg.GetString("workspace")).MetadataDir()

if _, err = os.Stat(dir); !download.forceoverwrite && err == nil {
return fmt.Errorf("directory '%s' already exists, use --force to overwrite", dir)
}

if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return err
}
Expand Down Expand Up @@ -103,7 +107,6 @@ func runDownload(cfg config.Config, flags *pflag.FlagSet, args []string) error {
continue
}

// TODO: handle collisions
path := sf.relativePath()
dir := filepath.Join(metadata.Dir, filepath.Dir(path))
if err = os.MkdirAll(dir, os.FileMode(0755)); err != nil {
Expand Down Expand Up @@ -133,7 +136,8 @@ type download struct {
token, apibaseurl, workspace string

// optional
track, team string
track, team string
forceoverwrite bool

payload *downloadPayload
}
Expand All @@ -158,6 +162,11 @@ func newDownload(flags *pflag.FlagSet, usrCfg *viper.Viper) (*download, error) {
return nil, err
}

d.forceoverwrite, err = flags.GetBool("force")
if err != nil {
return nil, err
}

d.token = usrCfg.GetString("token")
d.apibaseurl = usrCfg.GetString("apibaseurl")
d.workspace = usrCfg.GetString("workspace")
Expand Down Expand Up @@ -354,6 +363,7 @@ func setupDownloadFlags(flags *pflag.FlagSet) {
flags.StringP("track", "t", "", "the track ID")
flags.StringP("exercise", "e", "", "the exercise slug")
flags.StringP("team", "T", "", "the team slug")
flags.BoolP("force", "F", false, "overwrite existing exercise directory")
}

func init() {
Expand Down
102 changes: 102 additions & 0 deletions cmd/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,108 @@ func TestDownload(t *testing.T) {
}
}

func TestDownloadToExistingDirectory(t *testing.T) {
co := newCapturedOutput()
co.override()
defer co.reset()

testCases := []struct {
exerciseDir string
flags map[string]string
}{
{
exerciseDir: filepath.Join("bogus-track", "bogus-exercise"),
flags: map[string]string{"exercise": "bogus-exercise", "track": "bogus-track"},
},
{
exerciseDir: filepath.Join("teams", "bogus-team", "bogus-track", "bogus-exercise"),
flags: map[string]string{"exercise": "bogus-exercise", "track": "bogus-track", "team": "bogus-team"},
},
}

for _, tc := range testCases {
tmpDir, err := ioutil.TempDir("", "download-cmd")
defer os.RemoveAll(tmpDir)
assert.NoError(t, err)

err = os.MkdirAll(filepath.Join(tmpDir, tc.exerciseDir), os.FileMode(0755))
assert.NoError(t, err)

ts := fakeDownloadServer("true", "")
defer ts.Close()

v := viper.New()
v.Set("workspace", tmpDir)
v.Set("apibaseurl", ts.URL)
v.Set("token", "abc123")

cfg := config.Config{
UserViperConfig: v,
}
flags := pflag.NewFlagSet("fake", pflag.PanicOnError)
setupDownloadFlags(flags)
for name, value := range tc.flags {
flags.Set(name, value)
}

err = runDownload(cfg, flags, []string{})

if assert.Error(t, err) {
assert.Regexp(t, "directory '.+' already exists", err.Error())
}
}
}

func TestDownloadToExistingDirectoryWithForce(t *testing.T) {
co := newCapturedOutput()
co.override()
defer co.reset()

testCases := []struct {
exerciseDir string
flags map[string]string
}{
{
exerciseDir: filepath.Join("bogus-track", "bogus-exercise"),
flags: map[string]string{"exercise": "bogus-exercise", "track": "bogus-track"},
},
{
exerciseDir: filepath.Join("teams", "bogus-team", "bogus-track", "bogus-exercise"),
flags: map[string]string{"exercise": "bogus-exercise", "track": "bogus-track", "team": "bogus-team"},
},
}

for _, tc := range testCases {
tmpDir, err := ioutil.TempDir("", "download-cmd")
defer os.RemoveAll(tmpDir)
assert.NoError(t, err)

err = os.MkdirAll(filepath.Join(tmpDir, tc.exerciseDir), os.FileMode(0755))
assert.NoError(t, err)

ts := fakeDownloadServer("true", "")
defer ts.Close()

v := viper.New()
v.Set("workspace", tmpDir)
v.Set("apibaseurl", ts.URL)
v.Set("token", "abc123")

cfg := config.Config{
UserViperConfig: v,
}
flags := pflag.NewFlagSet("fake", pflag.PanicOnError)
setupDownloadFlags(flags)
for name, value := range tc.flags {
flags.Set(name, value)
}
flags.Set("force", "true")

err = runDownload(cfg, flags, []string{})
assert.NoError(t, err)
}
}

func fakeDownloadServer(requestor, teamSlug string) *httptest.Server {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
Expand Down

0 comments on commit ce8f497

Please sign in to comment.