diff --git a/go.mod b/go.mod index b3dd8ae863d..5831367b582 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/dustin/go-humanize v1.0.1 github.com/fatih/color v1.16.0 github.com/fatih/structs v1.1.0 + github.com/fsnotify/fsnotify v1.7.0 github.com/golang/mock v1.6.0 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/uuid v1.4.0 diff --git a/go.sum b/go.sum index 65d59e9421f..65b92dbf91f 100644 --- a/go.sum +++ b/go.sum @@ -81,6 +81,8 @@ github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= diff --git a/internal/pkg/cli/file/filetest/watchertest.go b/internal/pkg/cli/file/filetest/watchertest.go new file mode 100644 index 00000000000..28775d1830a --- /dev/null +++ b/internal/pkg/cli/file/filetest/watchertest.go @@ -0,0 +1,38 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package filetest + +import "github.com/fsnotify/fsnotify" + +// Double is a test double for file.RecursiveWatcher +type Double struct { + EventsFn func() <-chan fsnotify.Event + ErrorsFn func() <-chan error +} + +// Add is a no-op for Double. +func (d *Double) Add(string) error { + return nil +} + +// Close is a no-op for Double. +func (d *Double) Close() error { + return nil +} + +// Events calls the stubbed function. +func (d *Double) Events() <-chan fsnotify.Event { + if d.EventsFn == nil { + return nil + } + return d.EventsFn() +} + +// Errors calls the stubbed function. +func (d *Double) Errors() <-chan error { + if d.ErrorsFn == nil { + return nil + } + return d.ErrorsFn() +} diff --git a/internal/pkg/cli/file/hidden.go b/internal/pkg/cli/file/hidden.go new file mode 100644 index 00000000000..181a4ae012a --- /dev/null +++ b/internal/pkg/cli/file/hidden.go @@ -0,0 +1,13 @@ +//go:build !windows + +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package file + +import "path/filepath" + +// IsHiddenFile returns true if the file is hidden on non-windows. The filename must be non-empty. +func IsHiddenFile(filename string) (bool, error) { + return filepath.Base(filename)[0] == '.', nil +} diff --git a/internal/pkg/cli/file/hidden_windows.go b/internal/pkg/cli/file/hidden_windows.go new file mode 100644 index 00000000000..b8bec85eee4 --- /dev/null +++ b/internal/pkg/cli/file/hidden_windows.go @@ -0,0 +1,21 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package file + +import ( + "syscall" +) + +// IsHiddenFile returns true if the file is hidden on windows. +func IsHiddenFile(filename string) (bool, error) { + pointer, err := syscall.UTF16PtrFromString(filename) + if err != nil { + return false, err + } + attributes, err := syscall.GetFileAttributes(pointer) + if err != nil { + return false, err + } + return attributes&syscall.FILE_ATTRIBUTE_HIDDEN != 0, nil +} diff --git a/internal/pkg/cli/file/watch.go b/internal/pkg/cli/file/watch.go new file mode 100644 index 00000000000..e4cb9782fb5 --- /dev/null +++ b/internal/pkg/cli/file/watch.go @@ -0,0 +1,100 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package file + +import ( + "io/fs" + "path/filepath" + + "github.com/fsnotify/fsnotify" +) + +// RecursiveWatcher wraps an fsnotify Watcher to recursively watch all files in a directory. +type RecursiveWatcher struct { + fsnotifyWatcher *fsnotify.Watcher + done chan struct{} + closed bool + events chan fsnotify.Event + errors chan error +} + +// NewRecursiveWatcher returns a RecursiveWatcher which notifies when changes are made to files inside a recursive directory tree. +func NewRecursiveWatcher(buffer uint) (*RecursiveWatcher, error) { + watcher, err := fsnotify.NewBufferedWatcher(buffer) + if err != nil { + return nil, err + } + + rw := &RecursiveWatcher{ + events: make(chan fsnotify.Event, buffer), + errors: make(chan error), + fsnotifyWatcher: watcher, + done: make(chan struct{}), + closed: false, + } + + go rw.start() + + return rw, nil +} + +// Add recursively adds a directory tree to the list of watched files. +func (rw *RecursiveWatcher) Add(path string) error { + if rw.closed { + return fsnotify.ErrClosed + } + return filepath.WalkDir(path, func(p string, d fs.DirEntry, err error) error { + if err != nil { + // swallow error from WalkDir, don't attempt to add to watcher. + return nil + } + if d.IsDir() { + return rw.fsnotifyWatcher.Add(p) + } + return nil + }) +} + +// Events returns the events channel. +func (rw *RecursiveWatcher) Events() <-chan fsnotify.Event { + return rw.events +} + +// Errors returns the errors channel. +func (rw *RecursiveWatcher) Errors() <-chan error { + return rw.errors +} + +// Close closes the RecursiveWatcher. +func (rw *RecursiveWatcher) Close() error { + if rw.closed { + return nil + } + rw.closed = true + close(rw.done) + return rw.fsnotifyWatcher.Close() +} + +func (rw *RecursiveWatcher) start() { + for { + select { + case <-rw.done: + close(rw.events) + close(rw.errors) + return + case event := <-rw.fsnotifyWatcher.Events: + // handle recursive watch + switch event.Op { + case fsnotify.Create: + if err := rw.Add(event.Name); err != nil { + rw.errors <- err + } + } + + rw.events <- event + case err := <-rw.fsnotifyWatcher.Errors: + rw.errors <- err + } + } +} diff --git a/internal/pkg/cli/file/watch_integration_test.go b/internal/pkg/cli/file/watch_integration_test.go new file mode 100644 index 00000000000..4d55e468d5a --- /dev/null +++ b/internal/pkg/cli/file/watch_integration_test.go @@ -0,0 +1,141 @@ +//go:build integration || localintegration + +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package file_test + +import ( + "fmt" + "io/fs" + "os" + "testing" + "time" + + "github.com/aws/copilot-cli/internal/pkg/cli/file" + "github.com/fsnotify/fsnotify" + "github.com/stretchr/testify/require" +) + +func TestRecursiveWatcher(t *testing.T) { + var ( + watcher *file.RecursiveWatcher + tmp string + eventsExpected []fsnotify.Event + eventsActual []fsnotify.Event + ) + + tmp = os.TempDir() + eventsActual = make([]fsnotify.Event, 0) + eventsExpected = []fsnotify.Event{ + { + Name: fmt.Sprintf("%s/watch/subdir/testfile", tmp), + Op: fsnotify.Create, + }, + { + Name: fmt.Sprintf("%s/watch/subdir/testfile", tmp), + Op: fsnotify.Chmod, + }, + { + Name: fmt.Sprintf("%s/watch/subdir/testfile", tmp), + Op: fsnotify.Write, + }, + { + Name: fmt.Sprintf("%s/watch/subdir/testfile", tmp), + Op: fsnotify.Write, + }, + { + Name: fmt.Sprintf("%s/watch/subdir", tmp), + Op: fsnotify.Rename, + }, + { + Name: fmt.Sprintf("%s/watch/subdir2", tmp), + Op: fsnotify.Create, + }, + { + Name: fmt.Sprintf("%s/watch/subdir", tmp), + Op: fsnotify.Rename, + }, + { + Name: fmt.Sprintf("%s/watch/subdir2/testfile", tmp), + Op: fsnotify.Rename, + }, + { + Name: fmt.Sprintf("%s/watch/subdir2/testfile2", tmp), + Op: fsnotify.Create, + }, + { + Name: fmt.Sprintf("%s/watch/subdir2/testfile2", tmp), + Op: fsnotify.Remove, + }, + } + + t.Run("Setup Watcher", func(t *testing.T) { + err := os.MkdirAll(fmt.Sprintf("%s/watch/subdir", tmp), 0755) + require.NoError(t, err) + + watcher, err = file.NewRecursiveWatcher(uint(len(eventsExpected))) + require.NoError(t, err) + }) + + t.Run("Watch", func(t *testing.T) { + // SETUP + err := watcher.Add(fmt.Sprintf("%s/watch", tmp)) + require.NoError(t, err) + + eventsCh := watcher.Events() + errorsCh := watcher.Errors() + + expectEvents := func(t *testing.T, n int) []fsnotify.Event { + receivedEvents := []fsnotify.Event{} + for i := 0; i < n; i++ { + select { + case e := <-eventsCh: + receivedEvents = append(receivedEvents, e) + case <-time.After(time.Second): + } + } + return receivedEvents + } + + // WATCH + file, err := os.Create(fmt.Sprintf("%s/watch/subdir/testfile", tmp)) + require.NoError(t, err) + eventsActual = append(eventsActual, expectEvents(t, 1)...) + + err = os.Chmod(fmt.Sprintf("%s/watch/subdir/testfile", tmp), 0755) + require.NoError(t, err) + eventsActual = append(eventsActual, expectEvents(t, 1)...) + + err = os.WriteFile(fmt.Sprintf("%s/watch/subdir/testfile", tmp), []byte("write to file"), fs.ModeAppend) + require.NoError(t, err) + eventsActual = append(eventsActual, expectEvents(t, 2)...) + + err = file.Close() + require.NoError(t, err) + + err = os.Rename(fmt.Sprintf("%s/watch/subdir", tmp), fmt.Sprintf("%s/watch/subdir2", tmp)) + require.NoError(t, err) + eventsActual = append(eventsActual, expectEvents(t, 3)...) + + err = os.Rename(fmt.Sprintf("%s/watch/subdir2/testfile", tmp), fmt.Sprintf("%s/watch/subdir2/testfile2", tmp)) + require.NoError(t, err) + eventsActual = append(eventsActual, expectEvents(t, 2)...) + + err = os.Remove(fmt.Sprintf("%s/watch/subdir2/testfile2", tmp)) + require.NoError(t, err) + eventsActual = append(eventsActual, expectEvents(t, 1)...) + + // CLOSE + err = watcher.Close() + require.NoError(t, err) + require.Empty(t, errorsCh) + + require.Equal(t, eventsExpected, eventsActual) + }) + + t.Run("Clean", func(t *testing.T) { + err := os.RemoveAll(fmt.Sprintf("%s/watch", tmp)) + require.NoError(t, err) + }) +} diff --git a/internal/pkg/cli/flag.go b/internal/pkg/cli/flag.go index 5161be9e27f..126c91d9ab6 100644 --- a/internal/pkg/cli/flag.go +++ b/internal/pkg/cli/flag.go @@ -71,6 +71,7 @@ const ( envVarOverrideFlag = "env-var-override" proxyFlag = "proxy" proxyNetworkFlag = "proxy-network" + watchFlag = "watch" // Flags for CI/CD. githubURLFlag = "github-url" @@ -324,6 +325,7 @@ Format: [container]:KEY=VALUE. Omit container name to apply to all containers.` Example: --port-override 5000:80 binds localhost:5000 to the service's port 80.` proxyFlagDescription = `Optional. Proxy outbound requests to your environment's VPC.` proxyNetworkFlagDescription = `Optional. Set the IP Network used by --proxy.` + watchFlagDescription = `Optional. Watch changes to local files and restart containers when updated.` svcManifestFlagDescription = `Optional. Name of the environment in which the service was deployed; output the manifest file used for that deployment.` diff --git a/internal/pkg/cli/run_local.go b/internal/pkg/cli/run_local.go index 6f7b6d1bcd7..eddbc9746ab 100644 --- a/internal/pkg/cli/run_local.go +++ b/internal/pkg/cli/run_local.go @@ -10,11 +10,13 @@ import ( "net" "os" "os/signal" + "path/filepath" "slices" "strconv" "strings" "sync" "syscall" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/arn" @@ -30,6 +32,7 @@ import ( "github.com/aws/copilot-cli/internal/pkg/aws/sessions" "github.com/aws/copilot-cli/internal/pkg/aws/ssm" clideploy "github.com/aws/copilot-cli/internal/pkg/cli/deploy" + "github.com/aws/copilot-cli/internal/pkg/cli/file" "github.com/aws/copilot-cli/internal/pkg/cli/group" "github.com/aws/copilot-cli/internal/pkg/config" "github.com/aws/copilot-cli/internal/pkg/deploy" @@ -49,6 +52,7 @@ import ( "github.com/aws/copilot-cli/internal/pkg/term/selector" "github.com/aws/copilot-cli/internal/pkg/term/syncbuffer" "github.com/aws/copilot-cli/internal/pkg/workspace" + "github.com/fsnotify/fsnotify" "github.com/spf13/afero" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" @@ -68,12 +72,20 @@ type hostFinder interface { Hosts(context.Context) ([]orchestrator.Host, error) } +type recursiveWatcher interface { + Add(path string) error + Close() error + Events() <-chan fsnotify.Event + Errors() <-chan error +} + type runLocalVars struct { wkldName string wkldType string appName string envName string envOverrides map[string]string + watch bool portOverrides portOverrides proxy bool proxyNetwork net.IPNet @@ -82,24 +94,26 @@ type runLocalVars struct { type runLocalOpts struct { runLocalVars - sel deploySelector - ecsClient ecsClient - ssm secretGetter - secretsManager secretGetter - sessProvider sessionProvider - sess *session.Session - envManagerSess *session.Session - targetEnv *config.Environment - targetApp *config.Application - store store - ws wsWlDirReader - cmd execRunner - dockerEngine dockerEngineRunner - repository repositoryService - prog progress - orchestrator containerOrchestrator - hostFinder hostFinder - envChecker versionCompatibilityChecker + sel deploySelector + ecsClient ecsClient + ssm secretGetter + secretsManager secretGetter + sessProvider sessionProvider + sess *session.Session + envManagerSess *session.Session + targetEnv *config.Environment + targetApp *config.Application + store store + ws wsWlDirReader + cmd execRunner + dockerEngine dockerEngineRunner + repository repositoryService + prog progress + orchestrator containerOrchestrator + hostFinder hostFinder + envChecker versionCompatibilityChecker + debounceTime time.Duration + newRecursiveWatcher func() (recursiveWatcher, error) buildContainerImages func(mft manifest.DynamicWorkload) (map[string]string, error) configureClients func() error @@ -223,6 +237,10 @@ func newRunLocalOpts(vars runLocalVars) (*runLocalOpts, error) { } return containerURIs, nil } + o.debounceTime = 5 * time.Second + o.newRecursiveWatcher = func() (recursiveWatcher, error) { + return file.NewRecursiveWatcher(0) + } return o, nil } @@ -285,9 +303,9 @@ func (o *runLocalOpts) Execute() error { ctx := context.Background() - task, err := o.getTask(ctx) + task, err := o.prepareTask(ctx) if err != nil { - return fmt.Errorf("get task: %w", err) + return err } var hosts []orchestrator.Host @@ -308,35 +326,6 @@ func (o *runLocalOpts) Execute() error { } } - mft, _, err := workloadManifest(&workloadManifestInput{ - name: o.wkldName, - appName: o.appName, - envName: o.envName, - ws: o.ws, - interpolator: o.newInterpolator(o.appName, o.envName), - unmarshal: o.unmarshal, - sess: o.envManagerSess, - }) - if err != nil { - return err - } - - containerURIs, err := o.buildContainerImages(mft) - if err != nil { - return fmt.Errorf("build images: %w", err) - } - - // replace built images with the local built URI - for name, uri := range containerURIs { - ctr, ok := task.Containers[name] - if !ok { - return fmt.Errorf("built an image for %q, which doesn't exist in the task", name) - } - - ctr.ImageURI = uri - task.Containers[name] = ctr - } - sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) @@ -347,20 +336,42 @@ func (o *runLocalOpts) Execute() error { } o.orchestrator.RunTask(task, runTaskOpts...) + var watchCh <-chan interface{} + var watchErrCh <-chan error + stopCh := make(chan struct{}) + if o.watch { + watchCh, watchErrCh, err = o.watchLocalFiles(stopCh) + if err != nil { + return fmt.Errorf("setup watch: %s", err) + } + } + for { select { case err, ok := <-errCh: // we loop until errCh closes, since Start() // closes errCh when the orchestrator is completely done. if !ok { + close(stopCh) return nil } - fmt.Printf("error: %s\n", err) + log.Errorf("error: %s\n", err) o.orchestrator.Stop() case <-sigCh: signal.Stop(sigCh) o.orchestrator.Stop() + case <-watchErrCh: + log.Errorf("watch: %s\n", err) + o.orchestrator.Stop() + case <-watchCh: + task, err = o.prepareTask(ctx) + if err != nil { + log.Errorf("rerun task: %s\n", err) + o.orchestrator.Stop() + break + } + o.orchestrator.RunTask(task) } } } @@ -455,6 +466,120 @@ func (o *runLocalOpts) getTask(ctx context.Context) (orchestrator.Task, error) { return task, nil } +func (o *runLocalOpts) prepareTask(ctx context.Context) (orchestrator.Task, error) { + task, err := o.getTask(ctx) + if err != nil { + return orchestrator.Task{}, fmt.Errorf("get task: %w", err) + } + + mft, _, err := workloadManifest(&workloadManifestInput{ + name: o.wkldName, + appName: o.appName, + envName: o.envName, + ws: o.ws, + interpolator: o.newInterpolator(o.appName, o.envName), + unmarshal: o.unmarshal, + sess: o.envManagerSess, + }) + if err != nil { + return orchestrator.Task{}, err + } + + containerURIs, err := o.buildContainerImages(mft) + if err != nil { + return orchestrator.Task{}, fmt.Errorf("build images: %w", err) + } + + // replace built images with the local built URI + for name, uri := range containerURIs { + ctr, ok := task.Containers[name] + if !ok { + return orchestrator.Task{}, fmt.Errorf("built an image for %q, which doesn't exist in the task", name) + } + + ctr.ImageURI = uri + task.Containers[name] = ctr + } + + return task, nil +} + +func (o *runLocalOpts) watchLocalFiles(stopCh <-chan struct{}) (<-chan interface{}, <-chan error, error) { + workspacePath := o.ws.Path() + + watchCh := make(chan interface{}) + watchErrCh := make(chan error) + + watcher, err := o.newRecursiveWatcher() + if err != nil { + return nil, nil, fmt.Errorf("file: %w", err) + } + + if err = watcher.Add(workspacePath); err != nil { + return nil, nil, err + } + + watcherEvents := watcher.Events() + watcherErrors := watcher.Errors() + + debounceTimer := time.NewTimer(o.debounceTime) + if !debounceTimer.Stop() { + // flush the timer in case stop is called after the timer finishes + <-debounceTimer.C + } + + go func() { + for { + select { + case <-stopCh: + watcher.Close() + return + case err, ok := <-watcherErrors: + watchErrCh <- err + if !ok { + watcher.Close() + return + } + case event, ok := <-watcherEvents: + if !ok { + watcher.Close() + return + } + + // skip chmod events + if event.Has(fsnotify.Chmod) { + break + } + + // check if any subdirectories within copilot directory are hidden + isHidden := false + parent := workspacePath + suffix, _ := strings.CutPrefix(event.Name, parent+"/") + // fsnotify events are always of form /a/b/c, don't use filepath.Split as that's OS dependent + for _, child := range strings.Split(suffix, "/") { + parent = filepath.Join(parent, child) + subdirHidden, err := file.IsHiddenFile(child) + if err != nil { + break + } + if subdirHidden { + isHidden = true + } + } + + // TODO(Aiden): implement dockerignore blacklist for update + if !isHidden { + debounceTimer.Reset(o.debounceTime) + } + case <-debounceTimer.C: + watchCh <- nil + } + } + }() + + return watchCh, watchErrCh, nil +} + func sessionEnvVars(ctx context.Context, sess *session.Session) (map[string]string, error) { creds, err := sess.Config.Credentials.GetWithContext(ctx) if err != nil { @@ -722,6 +847,7 @@ func BuildRunLocalCmd() *cobra.Command { cmd.Flags().StringVarP(&vars.wkldName, nameFlag, nameFlagShort, "", workloadFlagDescription) cmd.Flags().StringVarP(&vars.envName, envFlag, envFlagShort, "", envFlagDescription) cmd.Flags().StringVarP(&vars.appName, appFlag, appFlagShort, tryReadingAppName(), appFlagDescription) + cmd.Flags().BoolVar(&vars.watch, watchFlag, false, watchFlagDescription) cmd.Flags().Var(&vars.portOverrides, portOverrideFlag, portOverridesFlagDescription) cmd.Flags().StringToStringVar(&vars.envOverrides, envVarOverrideFlag, nil, envVarOverrideFlagDescription) cmd.Flags().BoolVar(&vars.proxy, proxyFlag, false, proxyFlagDescription) diff --git a/internal/pkg/cli/run_local_test.go b/internal/pkg/cli/run_local_test.go index 112383a5977..edc73a2e8f1 100644 --- a/internal/pkg/cli/run_local_test.go +++ b/internal/pkg/cli/run_local_test.go @@ -15,6 +15,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" sdkecs "github.com/aws/aws-sdk-go/service/ecs" awsecs "github.com/aws/copilot-cli/internal/pkg/aws/ecs" + "github.com/aws/copilot-cli/internal/pkg/cli/file/filetest" "github.com/aws/copilot-cli/internal/pkg/cli/mocks" "github.com/aws/copilot-cli/internal/pkg/config" "github.com/aws/copilot-cli/internal/pkg/docker/orchestrator" @@ -22,6 +23,7 @@ import ( "github.com/aws/copilot-cli/internal/pkg/ecs" "github.com/aws/copilot-cli/internal/pkg/manifest" "github.com/aws/copilot-cli/internal/pkg/term/selector" + "github.com/fsnotify/fsnotify" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" ) @@ -207,6 +209,7 @@ type runLocalExecuteMocks struct { secretsManager *mocks.MocksecretGetter prog *mocks.Mockprogress orchestrator *orchestratortest.Double + watcher *filetest.Double hostFinder *hostFinderDouble envChecker *mocks.MockversionCompatibilityChecker } @@ -312,6 +315,58 @@ func TestRunLocalOpts_Execute(t *testing.T) { }, }, } + alteredTaskDef := &awsecs.TaskDefinition{ + ContainerDefinitions: []*sdkecs.ContainerDefinition{ + { + Name: aws.String("foo"), + Environment: []*sdkecs.KeyValuePair{ + { + Name: aws.String("FOO_VAR"), + Value: aws.String("foo-value"), + }, + }, + Secrets: []*sdkecs.Secret{ + { + Name: aws.String("SHARED_SECRET"), + ValueFrom: aws.String("mysecret"), + }, + }, + PortMappings: []*sdkecs.PortMapping{ + { + HostPort: aws.Int64(80), + ContainerPort: aws.Int64(8081), + }, + { + HostPort: aws.Int64(9999), + }, + }, + }, + { + Name: aws.String("bar"), + Environment: []*sdkecs.KeyValuePair{ + { + Name: aws.String("BAR_VAR"), + Value: aws.String("bar-value"), + }, + }, + Secrets: []*sdkecs.Secret{ + { + Name: aws.String("SHARED_SECRET"), + ValueFrom: aws.String("mysecret"), + }, + }, + PortMappings: []*sdkecs.PortMapping{ + { + HostPort: aws.Int64(10000), + }, + { + HostPort: aws.Int64(77), + ContainerPort: aws.Int64(7777), + }, + }, + }, + }, + } expectedTask := orchestrator.Task{ Containers: map[string]orchestrator.ContainerDefinition{ "foo": { @@ -363,6 +418,7 @@ func TestRunLocalOpts_Execute(t *testing.T) { inputWkldName string inputEnvOverrides map[string]string inputPortOverrides []string + inputWatch bool inputProxy bool buildImagesError error @@ -393,7 +449,7 @@ func TestRunLocalOpts_Execute(t *testing.T) { }, wantedError: errors.New(`get task: get env vars: parse env overrides: "bad:OVERRIDE" targets invalid container`), }, - "error getting env version": { + "error reading workload manifest": { inputAppName: testAppName, inputWkldName: testWkldName, inputEnvName: testEnvName, @@ -401,11 +457,11 @@ func TestRunLocalOpts_Execute(t *testing.T) { setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) - m.envChecker.EXPECT().Version().Return("", fmt.Errorf("some error")) + m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return(nil, errors.New("some error")) }, - wantedError: errors.New(`retrieve version of environment stack "testEnv" in application "testApp": some error`), + wantedError: errors.New(`read manifest file for testWkld: some error`), }, - "error due to old env version": { + "error interpolating workload manifest": { inputAppName: testAppName, inputWkldName: testWkldName, inputEnvName: testEnvName, @@ -413,60 +469,68 @@ func TestRunLocalOpts_Execute(t *testing.T) { setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) - m.envChecker.EXPECT().Version().Return("v1.31.0", nil) + m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil) + m.interpolator.EXPECT().Interpolate("").Return("", errors.New("some error")) }, - wantedError: errors.New(`environment "testEnv" is on version "v1.31.0" which does not support the "run local --proxy" feature`), + wantedError: errors.New(`interpolate environment variables for testWkld manifest: some error`), }, - "error getting hosts to proxy to": { - inputAppName: testAppName, - inputWkldName: testWkldName, - inputEnvName: testEnvName, - inputProxy: true, + "error building container images": { + inputAppName: testAppName, + inputWkldName: testWkldName, + inputEnvName: testEnvName, + buildImagesError: errors.New("some error"), setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) - m.envChecker.EXPECT().Version().Return("v1.32.0", nil) - m.hostFinder.HostsFn = func(ctx context.Context) ([]orchestrator.Host, error) { - return nil, fmt.Errorf("some error") - } + m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil) + m.interpolator.EXPECT().Interpolate("").Return("", nil) }, - wantedError: errors.New(`find hosts to connect to: some error`), + wantedError: errors.New(`build images: some error`), }, - "error reading workload manifest": { + "error getting env version": { inputAppName: testAppName, inputWkldName: testWkldName, inputEnvName: testEnvName, + inputProxy: true, setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) - m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return(nil, errors.New("some error")) + m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil) + m.interpolator.EXPECT().Interpolate("").Return("", nil) + m.envChecker.EXPECT().Version().Return("", fmt.Errorf("some error")) }, - wantedError: errors.New(`read manifest file for testWkld: some error`), + wantedError: errors.New(`retrieve version of environment stack "testEnv" in application "testApp": some error`), }, - "error interpolating workload manifest": { + "error due to old env version": { inputAppName: testAppName, inputWkldName: testWkldName, inputEnvName: testEnvName, + inputProxy: true, setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil) - m.interpolator.EXPECT().Interpolate("").Return("", errors.New("some error")) + m.interpolator.EXPECT().Interpolate("").Return("", nil) + m.envChecker.EXPECT().Version().Return("v1.31.0", nil) }, - wantedError: errors.New(`interpolate environment variables for testWkld manifest: some error`), + wantedError: errors.New(`environment "testEnv" is on version "v1.31.0" which does not support the "run local --proxy" feature`), }, - "error building container images": { - inputAppName: testAppName, - inputWkldName: testWkldName, - inputEnvName: testEnvName, - buildImagesError: errors.New("some error"), + "error getting hosts to proxy to": { + inputAppName: testAppName, + inputWkldName: testWkldName, + inputEnvName: testEnvName, + inputProxy: true, setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil) m.interpolator.EXPECT().Interpolate("").Return("", nil) + m.envChecker.EXPECT().Version().Return("v1.32.0", nil) + m.hostFinder.HostsFn = func(ctx context.Context) ([]orchestrator.Host, error) { + return nil, fmt.Errorf("some error") + } }, - wantedError: errors.New(`build images: some error`), + wantedError: errors.New(`find hosts to connect to: some error`), }, "error, proxy, describe service": { inputAppName: testAppName, @@ -476,6 +540,8 @@ func TestRunLocalOpts_Execute(t *testing.T) { setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) + m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil) + m.interpolator.EXPECT().Interpolate("").Return("", nil) m.envChecker.EXPECT().Version().Return("v1.32.0", nil) m.hostFinder.HostsFn = func(ctx context.Context) ([]orchestrator.Host, error) { return []orchestrator.Host{ @@ -497,6 +563,8 @@ func TestRunLocalOpts_Execute(t *testing.T) { setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) + m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil) + m.interpolator.EXPECT().Interpolate("").Return("", nil) m.envChecker.EXPECT().Version().Return("v1.32.0", nil) m.hostFinder.HostsFn = func(ctx context.Context) ([]orchestrator.Host, error) { return []orchestrator.Host{ @@ -524,6 +592,8 @@ func TestRunLocalOpts_Execute(t *testing.T) { setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) + m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil) + m.interpolator.EXPECT().Interpolate("").Return("", nil) m.envChecker.EXPECT().Version().Return("v1.32.0", nil) m.hostFinder.HostsFn = func(ctx context.Context) ([]orchestrator.Host, error) { return []orchestrator.Host{ @@ -551,6 +621,8 @@ func TestRunLocalOpts_Execute(t *testing.T) { setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) + m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil) + m.interpolator.EXPECT().Interpolate("").Return("", nil) m.envChecker.EXPECT().Version().Return("v1.32.0", nil) m.hostFinder.HostsFn = func(ctx context.Context) ([]orchestrator.Host, error) { return []orchestrator.Host{ @@ -688,6 +760,176 @@ func TestRunLocalOpts_Execute(t *testing.T) { } }, }, + "watch flag receives hidden file update, doesn't restart": { + inputAppName: testAppName, + inputWkldName: testWkldName, + inputEnvName: testEnvName, + inputWatch: true, + setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { + m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) + m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) + m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil) + m.interpolator.EXPECT().Interpolate("").Return("", nil) + m.ws.EXPECT().Path().Return("") + + eventCh := make(chan fsnotify.Event, 1) + m.watcher.EventsFn = func() <-chan fsnotify.Event { + eventCh <- fsnotify.Event{ + Name: ".hiddensubdir/mockFilename", + Op: fsnotify.Write, + } + return eventCh + } + + watcherErrCh := make(chan error, 1) + m.watcher.ErrorsFn = func() <-chan error { + return watcherErrCh + } + + errCh := make(chan error, 1) + m.orchestrator.StartFn = func() <-chan error { + return errCh + } + + m.orchestrator.RunTaskFn = func(task orchestrator.Task, opts ...orchestrator.RunTaskOption) { + syscall.Kill(syscall.Getpid(), syscall.SIGINT) + } + + m.orchestrator.StopFn = func() { + close(errCh) + } + }, + }, + "watch flag restarts, error for pause container definition update": { + inputAppName: testAppName, + inputWkldName: testWkldName, + inputEnvName: testEnvName, + inputWatch: true, + setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { + m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) + m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil).Times(2) + m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil).Times(2) + m.interpolator.EXPECT().Interpolate("").Return("", nil).Times(2) + m.ws.EXPECT().Path().Return("") + m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(alteredTaskDef, nil) + + eventCh := make(chan fsnotify.Event, 1) + m.watcher.EventsFn = func() <-chan fsnotify.Event { + eventCh <- fsnotify.Event{ + Name: "mockFilename", + Op: fsnotify.Write, + } + return eventCh + } + + watcherErrCh := make(chan error, 1) + m.watcher.ErrorsFn = func() <-chan error { + return watcherErrCh + } + + errCh := make(chan error, 1) + m.orchestrator.StartFn = func() <-chan error { + return errCh + } + + count := 1 + m.orchestrator.RunTaskFn = func(task orchestrator.Task, opts ...orchestrator.RunTaskOption) { + switch count { + case 1: + require.Equal(t, expectedTask, task) + case 2: + require.NotEqual(t, expectedTask, task) + errCh <- errors.New("new task requires recreating pause container") + } + count++ + } + + m.orchestrator.StopFn = func() { + close(errCh) + } + }, + }, + "watcher error succesfully stops all goroutines": { + inputAppName: testAppName, + inputWkldName: testWkldName, + inputEnvName: testEnvName, + inputWatch: true, + setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { + m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) + m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) + m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil) + m.interpolator.EXPECT().Interpolate("").Return("", nil) + m.ws.EXPECT().Path().Return("") + + eventCh := make(chan fsnotify.Event, 1) + m.watcher.EventsFn = func() <-chan fsnotify.Event { + return eventCh + } + + watcherErrCh := make(chan error, 1) + m.watcher.ErrorsFn = func() <-chan error { + watcherErrCh <- errors.New("some error") + return watcherErrCh + } + + errCh := make(chan error, 1) + m.orchestrator.StartFn = func() <-chan error { + return errCh + } + + m.orchestrator.RunTaskFn = func(task orchestrator.Task, opts ...orchestrator.RunTaskOption) { + require.Equal(t, expectedTask, task) + } + + m.orchestrator.StopFn = func() { + close(errCh) + } + }, + }, + "watch flag restarts and finishes successfully": { + inputAppName: testAppName, + inputWkldName: testWkldName, + inputEnvName: testEnvName, + inputWatch: true, + setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { + m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil).Times(2) + m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil).Times(2) + m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil).Times(2) + m.interpolator.EXPECT().Interpolate("").Return("", nil).Times(2) + m.ws.EXPECT().Path().Return("") + + eventCh := make(chan fsnotify.Event, 1) + m.watcher.EventsFn = func() <-chan fsnotify.Event { + eventCh <- fsnotify.Event{ + Name: "mockFilename", + Op: fsnotify.Write, + } + return eventCh + } + + watcherErrCh := make(chan error, 1) + m.watcher.ErrorsFn = func() <-chan error { + return watcherErrCh + } + + errCh := make(chan error, 1) + m.orchestrator.StartFn = func() <-chan error { + return errCh + } + runCount := 1 + m.orchestrator.RunTaskFn = func(task orchestrator.Task, opts ...orchestrator.RunTaskOption) { + require.Equal(t, expectedTask, task) + if runCount > 1 { + syscall.Kill(syscall.Getpid(), syscall.SIGINT) + } + runCount++ + } + + m.orchestrator.StopFn = func() { + close(errCh) + } + }, + }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { @@ -706,6 +948,7 @@ func TestRunLocalOpts_Execute(t *testing.T) { repository: mocks.NewMockrepositoryService(ctrl), prog: mocks.NewMockprogress(ctrl), orchestrator: &orchestratortest.Double{}, + watcher: &filetest.Double{}, hostFinder: &hostFinderDouble{}, envChecker: mocks.NewMockversionCompatibilityChecker(ctrl), } @@ -716,6 +959,7 @@ func TestRunLocalOpts_Execute(t *testing.T) { wkldName: tc.inputWkldName, envName: tc.inputEnvName, envOverrides: tc.inputEnvOverrides, + watch: tc.inputWatch, portOverrides: portOverrides{ { host: "777", @@ -759,6 +1003,10 @@ func TestRunLocalOpts_Execute(t *testing.T) { orchestrator: m.orchestrator, hostFinder: m.hostFinder, envChecker: m.envChecker, + debounceTime: 0, // disable debounce during testing + newRecursiveWatcher: func() (recursiveWatcher, error) { + return m.watcher, nil + }, } // WHEN err := opts.Execute() diff --git a/internal/pkg/docker/orchestrator/orchestrator.go b/internal/pkg/docker/orchestrator/orchestrator.go index 53dd4ff8543..155bc57cce6 100644 --- a/internal/pkg/docker/orchestrator/orchestrator.go +++ b/internal/pkg/docker/orchestrator/orchestrator.go @@ -215,6 +215,10 @@ func (a *runTaskAction) Do(o *Orchestrator) error { if err := o.stopTask(ctx, o.curTask); err != nil { return fmt.Errorf("stop existing task: %w", err) } + + // ensure that containers are fully stopped after o.stopTask finishes blocking + // TODO(Aiden): Implement a container ID system or use `docker ps` to ensure containers are stopped + time.Sleep(1 * time.Second) } for name, ctr := range a.task.Containers {