From 8491d15c70d53548ff310b330962d93a40fbfd30 Mon Sep 17 00:00:00 2001 From: Danny Randall <10566468+dannyrandall@users.noreply.github.com> Date: Mon, 6 Nov 2023 21:20:49 -0800 Subject: [PATCH] chore: setup proxy connections for `run local --proxy` (#5439) By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the Apache 2.0 License. --- internal/pkg/cli/flag.go | 4 +- internal/pkg/cli/interfaces.go | 2 + internal/pkg/cli/mocks/mock_interfaces.go | 34 ++ internal/pkg/cli/run_local.go | 80 +++- internal/pkg/cli/run_local_test.go | 183 +++++++-- .../pkg/docker/dockerengine/dockerengine.go | 30 +- .../docker/dockerengine/dockerengine_test.go | 47 ++- .../dockerenginetest/dockerenginetest.go | 9 + .../pkg/docker/orchestrator/orchestrator.go | 169 +++++++- .../docker/orchestrator/orchestrator_test.go | 375 ++++++++++++++---- .../orchestratortest/orchestratortest.go | 6 +- 11 files changed, 803 insertions(+), 136 deletions(-) diff --git a/internal/pkg/cli/flag.go b/internal/pkg/cli/flag.go index 060c04dd0c8..5161be9e27f 100644 --- a/internal/pkg/cli/flag.go +++ b/internal/pkg/cli/flag.go @@ -70,6 +70,7 @@ const ( portOverrideFlag = "port-override" envVarOverrideFlag = "env-var-override" proxyFlag = "proxy" + proxyNetworkFlag = "proxy-network" // Flags for CI/CD. githubURLFlag = "github-url" @@ -321,7 +322,8 @@ Defaults to all logs. Only one of end-time / follow may be used.` Format: [container]:KEY=VALUE. Omit container name to apply to all containers.` portOverridesFlagDescription = `Optional. Override ports exposed by service. Format: :. Example: --port-override 5000:80 binds localhost:5000 to the service's port 80.` - proxyFlagDescription = `Optional. Proxy outbound requests to your environment's VPC.` + proxyFlagDescription = `Optional. Proxy outbound requests to your environment's VPC.` + proxyNetworkFlagDescription = `Optional. Set the IP Network used by --proxy.` svcManifestFlagDescription = `Optional. Name of the environment in which the service was deployed; output the manifest file used for that deployment.` diff --git a/internal/pkg/cli/interfaces.go b/internal/pkg/cli/interfaces.go index dc2a80c0dbd..2d72598b6ea 100644 --- a/internal/pkg/cli/interfaces.go +++ b/internal/pkg/cli/interfaces.go @@ -186,6 +186,7 @@ type repositoryService interface { type ecsClient interface { TaskDefinition(app, env, svc string) (*awsecs.TaskDefinition, error) ServiceConnectServices(app, env, svc string) ([]*awsecs.Service, error) + DescribeService(app, env, svc string) (*ecs.ServiceDesc, error) } type logEventsWriter interface { @@ -710,6 +711,7 @@ type dockerEngineRunner interface { Stop(context.Context, string) error Rm(string) error Build(context.Context, *dockerengine.BuildArguments, io.Writer) error + Exec(ctx context.Context, container string, out io.Writer, cmd string, args ...string) error } type workloadStackGenerator interface { diff --git a/internal/pkg/cli/mocks/mock_interfaces.go b/internal/pkg/cli/mocks/mock_interfaces.go index 92bc031cc20..1e0ce2d6fc6 100644 --- a/internal/pkg/cli/mocks/mock_interfaces.go +++ b/internal/pkg/cli/mocks/mock_interfaces.go @@ -1738,6 +1738,21 @@ func (m *MockecsClient) EXPECT() *MockecsClientMockRecorder { return m.recorder } +// DescribeService mocks base method. +func (m *MockecsClient) DescribeService(app, env, svc string) (*ecs0.ServiceDesc, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DescribeService", app, env, svc) + ret0, _ := ret[0].(*ecs0.ServiceDesc) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DescribeService indicates an expected call of DescribeService. +func (mr *MockecsClientMockRecorder) DescribeService(app, env, svc interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeService", reflect.TypeOf((*MockecsClient)(nil).DescribeService), app, env, svc) +} + // ServiceConnectServices mocks base method. func (m *MockecsClient) ServiceConnectServices(app, env, svc string) ([]*ecs.Service, error) { m.ctrl.T.Helper() @@ -7693,6 +7708,25 @@ func (mr *MockdockerEngineRunnerMockRecorder) CheckDockerEngineRunning() *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckDockerEngineRunning", reflect.TypeOf((*MockdockerEngineRunner)(nil).CheckDockerEngineRunning)) } +// Exec mocks base method. +func (m *MockdockerEngineRunner) Exec(ctx context.Context, container string, out io.Writer, cmd string, args ...string) error { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, container, out, cmd} + for _, a := range args { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockdockerEngineRunnerMockRecorder) Exec(ctx, container, out, cmd interface{}, args ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, container, out, cmd}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockdockerEngineRunner)(nil).Exec), varargs...) +} + // IsContainerRunning mocks base method. func (m *MockdockerEngineRunner) IsContainerRunning(arg0 context.Context, arg1 string) (bool, error) { m.ctrl.T.Helper() diff --git a/internal/pkg/cli/run_local.go b/internal/pkg/cli/run_local.go index 68f188946a1..6f7b6d1bcd7 100644 --- a/internal/pkg/cli/run_local.go +++ b/internal/pkg/cli/run_local.go @@ -5,7 +5,9 @@ package cli import ( "context" + "errors" "fmt" + "net" "os" "os/signal" "slices" @@ -58,12 +60,12 @@ const ( type containerOrchestrator interface { Start() <-chan error - RunTask(orchestrator.Task) + RunTask(orchestrator.Task, ...orchestrator.RunTaskOption) Stop() } type hostFinder interface { - Hosts(context.Context) ([]host, error) + Hosts(context.Context) ([]orchestrator.Host, error) } type runLocalVars struct { @@ -74,6 +76,7 @@ type runLocalVars struct { envOverrides map[string]string portOverrides portOverrides proxy bool + proxyNetwork net.IPNet } type runLocalOpts struct { @@ -287,18 +290,22 @@ func (o *runLocalOpts) Execute() error { return fmt.Errorf("get task: %w", err) } + var hosts []orchestrator.Host + var ssmTarget string if o.proxy { if err := validateMinEnvVersion(o.ws, o.envChecker, o.appName, o.envName, template.RunLocalProxyMinEnvVersion, "run local --proxy"); err != nil { return err } - hosts, err := o.hostFinder.Hosts(ctx) + hosts, err = o.hostFinder.Hosts(ctx) if err != nil { return fmt.Errorf("find hosts to connect to: %w", err) } - // TODO(dannyrandall): inject into orchestrator and use in pause container - fmt.Printf("hosts: %+v\n", hosts) + ssmTarget, err = o.getSSMTarget(ctx) + if err != nil { + return fmt.Errorf("get proxy target container: %w", err) + } } mft, _, err := workloadManifest(&workloadManifestInput{ @@ -334,7 +341,11 @@ func (o *runLocalOpts) Execute() error { signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) errCh := o.orchestrator.Start() - o.orchestrator.RunTask(task) + var runTaskOpts []orchestrator.RunTaskOption + if o.proxy { + runTaskOpts = append(runTaskOpts, orchestrator.RunTaskWithProxy(ssmTarget, o.proxyNetwork, hosts...)) + } + o.orchestrator.RunTask(task, runTaskOpts...) for { select { @@ -354,6 +365,41 @@ func (o *runLocalOpts) Execute() error { } } +// getSSMTarget returns a AWS SSM target for a running container +// that supports ECS Service Exec. +func (o *runLocalOpts) getSSMTarget(ctx context.Context) (string, error) { + svc, err := o.ecsClient.DescribeService(o.appName, o.envName, o.wkldName) + if err != nil { + return "", fmt.Errorf("describe service: %w", err) + } + + for _, task := range svc.Tasks { + // TaskArn should have the format: arn:aws:ecs:us-west-2:123456789:task/clusterName/taskName + taskARN, err := arn.Parse(aws.StringValue(task.TaskArn)) + if err != nil { + return "", fmt.Errorf("parse task arn: %w", err) + } + + split := strings.Split(taskARN.Resource, "/") + if len(split) != 3 { + return "", fmt.Errorf("task ARN in unexpected format: %q", taskARN) + } + taskName := split[2] + + for _, ctr := range task.Containers { + id := aws.StringValue(ctr.RuntimeId) + hasECSExec := slices.ContainsFunc(ctr.ManagedAgents, func(a *sdkecs.ManagedAgent) bool { + return aws.StringValue(a.Name) == "ExecuteCommandAgent" && aws.StringValue(a.LastStatus) == "RUNNING" + }) + if id != "" && hasECSExec && aws.StringValue(ctr.LastStatus) == "RUNNING" { + return fmt.Sprintf("ecs:%s_%s_%s", svc.ClusterName, taskName, aws.StringValue(ctr.RuntimeId)), nil + } + } + } + + return "", errors.New("no running tasks have running containers with ecs exec enabled") +} + func (o *runLocalOpts) getTask(ctx context.Context) (orchestrator.Task, error) { td, err := o.ecsClient.TaskDefinition(o.appName, o.envName, o.wkldName) if err != nil { @@ -617,11 +663,6 @@ func (o *runLocalOpts) getSecret(ctx context.Context, valueFrom string) (string, return getter.GetSecretValue(ctx, valueFrom) } -type host struct { - host string - port string -} - type hostDiscoverer struct { ecs ecsClient app string @@ -629,13 +670,13 @@ type hostDiscoverer struct { wkld string } -func (h *hostDiscoverer) Hosts(ctx context.Context) ([]host, error) { +func (h *hostDiscoverer) Hosts(ctx context.Context) ([]orchestrator.Host, error) { svcs, err := h.ecs.ServiceConnectServices(h.app, h.env, h.wkld) if err != nil { return nil, fmt.Errorf("get service connect services: %w", err) } - var hosts []host + var hosts []orchestrator.Host for _, svc := range svcs { // find the primary deployment with service connect enabled idx := slices.IndexFunc(svc.Deployments, func(dep *sdkecs.Deployment) bool { @@ -647,9 +688,9 @@ func (h *hostDiscoverer) Hosts(ctx context.Context) ([]host, error) { for _, sc := range svc.Deployments[idx].ServiceConnectConfiguration.Services { for _, alias := range sc.ClientAliases { - hosts = append(hosts, host{ - host: aws.StringValue(alias.DnsName), - port: strconv.Itoa(int(aws.Int64Value(alias.Port))), + hosts = append(hosts, orchestrator.Host{ + Name: aws.StringValue(alias.DnsName), + Port: strconv.Itoa(int(aws.Int64Value(alias.Port))), }) } } @@ -684,5 +725,12 @@ func BuildRunLocalCmd() *cobra.Command { cmd.Flags().Var(&vars.portOverrides, portOverrideFlag, portOverridesFlagDescription) cmd.Flags().StringToStringVar(&vars.envOverrides, envVarOverrideFlag, nil, envVarOverrideFlagDescription) cmd.Flags().BoolVar(&vars.proxy, proxyFlag, false, proxyFlagDescription) + cmd.Flags().IPNetVar(&vars.proxyNetwork, proxyNetworkFlag, net.IPNet{ + // docker uses 172.17.0.0/16 for networking by default + // so we'll default to different /16 from the 172.16.0.0/12 + // private network defined by RFC 1918. + IP: net.IPv4(172, 20, 0, 0), + Mask: net.CIDRMask(16, 32), + }, proxyNetworkFlag) return cmd } diff --git a/internal/pkg/cli/run_local_test.go b/internal/pkg/cli/run_local_test.go index 5e094a823e3..112383a5977 100644 --- a/internal/pkg/cli/run_local_test.go +++ b/internal/pkg/cli/run_local_test.go @@ -14,12 +14,12 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" sdkecs "github.com/aws/aws-sdk-go/service/ecs" - "github.com/aws/copilot-cli/internal/pkg/aws/ecs" awsecs "github.com/aws/copilot-cli/internal/pkg/aws/ecs" "github.com/aws/copilot-cli/internal/pkg/cli/mocks" "github.com/aws/copilot-cli/internal/pkg/config" "github.com/aws/copilot-cli/internal/pkg/docker/orchestrator" "github.com/aws/copilot-cli/internal/pkg/docker/orchestrator/orchestratortest" + "github.com/aws/copilot-cli/internal/pkg/ecs" "github.com/aws/copilot-cli/internal/pkg/manifest" "github.com/aws/copilot-cli/internal/pkg/term/selector" "github.com/golang/mock/gomock" @@ -224,10 +224,10 @@ func (m *mockProvider) IsExpired() bool { } type hostFinderDouble struct { - HostsFn func(context.Context) ([]host, error) + HostsFn func(context.Context) ([]orchestrator.Host, error) } -func (d *hostFinderDouble) Hosts(ctx context.Context) ([]host, error) { +func (d *hostFinderDouble) Hosts(ctx context.Context) ([]orchestrator.Host, error) { if d.HostsFn == nil { return nil, nil } @@ -260,7 +260,7 @@ func TestRunLocalOpts_Execute(t *testing.T) { "bar": "image2", } - taskDef := &ecs.TaskDefinition{ + taskDef := &awsecs.TaskDefinition{ ContainerDefinitions: []*sdkecs.ContainerDefinition{ { Name: aws.String("foo"), @@ -426,7 +426,7 @@ func TestRunLocalOpts_Execute(t *testing.T) { m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) m.envChecker.EXPECT().Version().Return("v1.32.0", nil) - m.hostFinder.HostsFn = func(ctx context.Context) ([]host, error) { + m.hostFinder.HostsFn = func(ctx context.Context) ([]orchestrator.Host, error) { return nil, fmt.Errorf("some error") } }, @@ -468,6 +468,114 @@ func TestRunLocalOpts_Execute(t *testing.T) { }, wantedError: errors.New(`build images: some error`), }, + "error, proxy, describe service": { + inputAppName: testAppName, + inputWkldName: testWkldName, + inputEnvName: testEnvName, + inputProxy: true, + setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { + m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) + m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) + m.envChecker.EXPECT().Version().Return("v1.32.0", nil) + m.hostFinder.HostsFn = func(ctx context.Context) ([]orchestrator.Host, error) { + return []orchestrator.Host{ + { + Name: "a-different-service", + Port: "80", + }, + }, nil + } + m.ecsClient.EXPECT().DescribeService(testAppName, testEnvName, testWkldName).Return(nil, errors.New("some error")) + }, + wantedError: errors.New("get proxy target container: describe service: some error"), + }, + "error, proxy, parse arn": { + inputAppName: testAppName, + inputWkldName: testWkldName, + inputEnvName: testEnvName, + inputProxy: true, + setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { + m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) + m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) + m.envChecker.EXPECT().Version().Return("v1.32.0", nil) + m.hostFinder.HostsFn = func(ctx context.Context) ([]orchestrator.Host, error) { + return []orchestrator.Host{ + { + Name: "a-different-service", + Port: "80", + }, + }, nil + } + m.ecsClient.EXPECT().DescribeService(testAppName, testEnvName, testWkldName).Return(&ecs.ServiceDesc{ + Tasks: []*awsecs.Task{ + { + TaskArn: aws.String("asdf"), + }, + }, + }, nil) + }, + wantedError: errors.New(`get proxy target container: parse task arn: arn: invalid prefix`), + }, + "error, proxy, process task": { + inputAppName: testAppName, + inputWkldName: testWkldName, + inputEnvName: testEnvName, + inputProxy: true, + setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { + m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) + m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) + m.envChecker.EXPECT().Version().Return("v1.32.0", nil) + m.hostFinder.HostsFn = func(ctx context.Context) ([]orchestrator.Host, error) { + return []orchestrator.Host{ + { + Name: "a-different-service", + Port: "80", + }, + }, nil + } + m.ecsClient.EXPECT().DescribeService(testAppName, testEnvName, testWkldName).Return(&ecs.ServiceDesc{ + Tasks: []*awsecs.Task{ + { + TaskArn: aws.String("arn:aws:ecs:us-west-2:123456789:task/asdf"), + }, + }, + }, nil) + }, + wantedError: errors.New(`get proxy target container: task ARN in unexpected format: "arn:aws:ecs:us-west-2:123456789:task/asdf"`), + }, + "error, proxy, no valid containers": { + inputAppName: testAppName, + inputWkldName: testWkldName, + inputEnvName: testEnvName, + inputProxy: true, + setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { + m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) + m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) + m.envChecker.EXPECT().Version().Return("v1.32.0", nil) + m.hostFinder.HostsFn = func(ctx context.Context) ([]orchestrator.Host, error) { + return []orchestrator.Host{ + { + Name: "a-different-service", + Port: "80", + }, + }, nil + } + m.ecsClient.EXPECT().DescribeService(testAppName, testEnvName, testWkldName).Return(&ecs.ServiceDesc{ + Tasks: []*awsecs.Task{ + { + TaskArn: aws.String("arn:aws:ecs:us-west-2:123456789:task/clusterName/taskName"), + Containers: []*sdkecs.Container{ + { + RuntimeId: aws.String("runtime-id"), + LastStatus: aws.String("RUNNING"), + }, + }, + }, + }, + }, nil) + }, + wantedError: errors.New(`get proxy target container: no running tasks have running containers with ecs exec enabled`), + }, "success, one run task call": { inputAppName: testAppName, inputWkldName: testWkldName, @@ -483,7 +591,7 @@ func TestRunLocalOpts_Execute(t *testing.T) { errCh <- errors.New("some error") return errCh } - m.orchestrator.RunTaskFn = func(task orchestrator.Task) { + m.orchestrator.RunTaskFn = func(task orchestrator.Task, opts ...orchestrator.RunTaskOption) { require.Equal(t, expectedTask, task) } m.orchestrator.StopFn = func() { @@ -501,14 +609,33 @@ func TestRunLocalOpts_Execute(t *testing.T) { m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) m.envChecker.EXPECT().Version().Return("v1.32.0", nil) - m.hostFinder.HostsFn = func(ctx context.Context) ([]host, error) { - return []host{ + m.hostFinder.HostsFn = func(ctx context.Context) ([]orchestrator.Host, error) { + return []orchestrator.Host{ { - host: "a-different-service", - port: "80", + Name: "a-different-service", + Port: "80", }, }, nil } + m.ecsClient.EXPECT().DescribeService(testAppName, testEnvName, testWkldName).Return(&ecs.ServiceDesc{ + Tasks: []*awsecs.Task{ + { + TaskArn: aws.String("arn:aws:ecs:us-west-2:123456789:task/clusterName/taskName"), + Containers: []*sdkecs.Container{ + { + RuntimeId: aws.String("runtime-id"), + LastStatus: aws.String("RUNNING"), + ManagedAgents: []*sdkecs.ManagedAgent{ + { + Name: aws.String("ExecuteCommandAgent"), + LastStatus: aws.String("RUNNING"), + }, + }, + }, + }, + }, + }, + }, nil) m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil) m.interpolator.EXPECT().Interpolate("").Return("", nil) @@ -517,7 +644,7 @@ func TestRunLocalOpts_Execute(t *testing.T) { errCh <- errors.New("some error") return errCh } - m.orchestrator.RunTaskFn = func(task orchestrator.Task) { + m.orchestrator.RunTaskFn = func(task orchestrator.Task, opts ...orchestrator.RunTaskOption) { require.Equal(t, expectedProxyTask, task) } m.orchestrator.StopFn = func() { @@ -540,7 +667,7 @@ func TestRunLocalOpts_Execute(t *testing.T) { m.orchestrator.StartFn = func() <-chan error { return errCh } - m.orchestrator.RunTaskFn = func(task orchestrator.Task) { + m.orchestrator.RunTaskFn = func(task orchestrator.Task, opts ...orchestrator.RunTaskOption) { require.Equal(t, expectedTask, task) syscall.Kill(syscall.Getpid(), syscall.SIGINT) } @@ -656,7 +783,7 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) { } tests := map[string]struct { - taskDef *ecs.TaskDefinition + taskDef *awsecs.TaskDefinition envOverrides map[string]string setupMocks func(m *runLocalExecuteMocks) credsError error @@ -666,14 +793,14 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) { wantError string }{ "invalid container in env override": { - taskDef: &ecs.TaskDefinition{}, + taskDef: &awsecs.TaskDefinition{}, envOverrides: map[string]string{ "bad:OVERRIDE": "bad", }, wantError: `parse env overrides: "bad:OVERRIDE" targets invalid container`, }, "overrides parsed and applied correctly": { - taskDef: &ecs.TaskDefinition{ + taskDef: &awsecs.TaskDefinition{ ContainerDefinitions: []*sdkecs.ContainerDefinition{ { Name: aws.String("foo"), @@ -706,7 +833,7 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) { }, }, "overrides merged with existing env vars correctly": { - taskDef: &ecs.TaskDefinition{ + taskDef: &awsecs.TaskDefinition{ ContainerDefinitions: []*sdkecs.ContainerDefinition{ { Name: aws.String("foo"), @@ -769,7 +896,7 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) { }, }, "error getting secret": { - taskDef: &ecs.TaskDefinition{ + taskDef: &awsecs.TaskDefinition{ ContainerDefinitions: []*sdkecs.ContainerDefinition{ { Name: aws.String("foo"), @@ -788,7 +915,7 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) { wantError: `get secrets: get secret "defaultSSM": some error`, }, "error getting secret if invalid arn": { - taskDef: &ecs.TaskDefinition{ + taskDef: &awsecs.TaskDefinition{ ContainerDefinitions: []*sdkecs.ContainerDefinition{ { Name: aws.String("foo"), @@ -804,7 +931,7 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) { wantError: `get secrets: get secret "arn:aws:ecs:us-west-2:123456789:service/mycluster/myservice": invalid ARN; not a SSM or Secrets Manager ARN`, }, "error if secret redefines a var": { - taskDef: &ecs.TaskDefinition{ + taskDef: &awsecs.TaskDefinition{ ContainerDefinitions: []*sdkecs.ContainerDefinition{ { Name: aws.String("foo"), @@ -826,7 +953,7 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) { wantError: `get secrets: secret names must be unique, but an environment variable "SHOULD_BE_A_VAR" already exists`, }, "correct service used based on arn": { - taskDef: &ecs.TaskDefinition{ + taskDef: &awsecs.TaskDefinition{ ContainerDefinitions: []*sdkecs.ContainerDefinition{ { Name: aws.String("foo"), @@ -864,7 +991,7 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) { }, }, "only unique secrets pulled": { - taskDef: &ecs.TaskDefinition{ + taskDef: &awsecs.TaskDefinition{ ContainerDefinitions: []*sdkecs.ContainerDefinition{ { Name: aws.String("foo"), @@ -917,7 +1044,7 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) { }, }, "secrets set via overrides not pulled": { - taskDef: &ecs.TaskDefinition{ + taskDef: &awsecs.TaskDefinition{ ContainerDefinitions: []*sdkecs.ContainerDefinition{ { Name: aws.String("foo"), @@ -974,12 +1101,12 @@ func TestRunLocalOpts_getEnvVars(t *testing.T) { }, }, "error getting creds": { - taskDef: &ecs.TaskDefinition{}, + taskDef: &awsecs.TaskDefinition{}, credsError: errors.New("some error"), wantError: `get IAM credentials: some error`, }, "region env vars set": { - taskDef: &ecs.TaskDefinition{ + taskDef: &awsecs.TaskDefinition{ ContainerDefinitions: []*sdkecs.ContainerDefinition{ { Name: aws.String("foo"), @@ -1053,7 +1180,7 @@ func TestRunLocal_HostDiscovery(t *testing.T) { tests := map[string]struct { setupMocks func(t *testing.T, m *testMocks) - wantHosts []host + wantHosts []orchestrator.Host wantError string }{ "error getting services": { @@ -1123,10 +1250,10 @@ func TestRunLocal_HostDiscovery(t *testing.T) { }, }, nil) }, - wantHosts: []host{ + wantHosts: []orchestrator.Host{ { - host: "primary", - port: "80", + Name: "primary", + Port: "80", }, }, }, diff --git a/internal/pkg/docker/dockerengine/dockerengine.go b/internal/pkg/docker/dockerengine/dockerengine.go index 784287d6851..62d4805316c 100644 --- a/internal/pkg/docker/dockerengine/dockerengine.go +++ b/internal/pkg/docker/dockerengine/dockerengine.go @@ -78,14 +78,15 @@ type BuildArguments struct { // RunOptions holds the options for running a Docker container. type RunOptions struct { - ImageURI string // Required. The image name to run. - Secrets map[string]string // Optional. Secrets to pass to the container as environment variables. - EnvVars map[string]string // Optional. Environment variables to pass to the container. - ContainerName string // Optional. The name for the container. - ContainerPorts map[string]string // Optional. Contains host and container ports. - Command []string // Optional. The command to run in the container. - ContainerNetwork string // Optional. Network mode for the container. - LogOptions RunLogOptions + ImageURI string // Required. The image name to run. + Secrets map[string]string // Optional. Secrets to pass to the container as environment variables. + EnvVars map[string]string // Optional. Environment variables to pass to the container. + ContainerName string // Optional. The name for the container. + ContainerPorts map[string]string // Optional. Contains host and container ports. + Command []string // Optional. The command to run in the container. + ContainerNetwork string // Optional. Network mode for the container. + LogOptions RunLogOptions // Optional. Configure logging for output from the container + AddLinuxCapabilities []string // Optional. Adds linux capabilities to the container. } // RunLogOptions holds the logging configuration for Run(). @@ -204,6 +205,15 @@ func (c DockerCmdClient) Login(uri, username, password string) error { return nil } +// Exec runs cmd in container with args and writes stderr/stdout to out. +func (c DockerCmdClient) Exec(ctx context.Context, container string, out io.Writer, cmd string, args ...string) error { + return c.runner.RunWithContext(ctx, "docker", append([]string{ + "exec", + container, + cmd, + }, args...), exec.Stdout(out), exec.Stderr(out)) +} + // Push pushes the images with the specified tags and ecr repository URI, and returns the image digest on success. func (c DockerCmdClient) Push(ctx context.Context, uri string, w io.Writer, tags ...string) (digest string, err error) { images := []string{} @@ -260,6 +270,10 @@ func (in *RunOptions) generateRunArguments() []string { args = append(args, "--env", fmt.Sprintf("%s=%s", key, value)) } + for _, cap := range in.AddLinuxCapabilities { + args = append(args, "--cap-add", cap) + } + args = append(args, in.ImageURI) if in.Command != nil && len(in.Command) > 0 { diff --git a/internal/pkg/docker/dockerengine/dockerengine_test.go b/internal/pkg/docker/dockerengine/dockerengine_test.go index dd655226381..ec01f8f6751 100644 --- a/internal/pkg/docker/dockerengine/dockerengine_test.go +++ b/internal/pkg/docker/dockerengine/dockerengine_test.go @@ -8,6 +8,7 @@ import ( "context" "errors" "fmt" + "io" osexec "os/exec" "path/filepath" "strings" @@ -253,7 +254,7 @@ func TestDockerCommand_Login(t *testing.T) { want error }{ - "wrap error returned from Run()": { + "wrap error returned from Login()": { setupMocks: func(controller *gomock.Controller) { mockCmd = NewMockCmd(controller) @@ -845,3 +846,47 @@ func TestDockerCommand_IsContainerRunning(t *testing.T) { }) } } + +func TestDockerCommand_Exec(t *testing.T) { + tests := map[string]struct { + setupMocks func(controller *gomock.Controller) *MockCmd + + wantErr string + }{ + "return error": { + setupMocks: func(ctrl *gomock.Controller) *MockCmd { + mockCmd := NewMockCmd(ctrl) + mockCmd.EXPECT().RunWithContext(gomock.Any(), "docker", + []string{"exec", "ctr", "sleep", "infinity"}, + gomock.Any(), gomock.Any()).Return(errors.New("some error")) + return mockCmd + }, + wantErr: "some error", + }, + "happy path": { + setupMocks: func(ctrl *gomock.Controller) *MockCmd { + mockCmd := NewMockCmd(ctrl) + mockCmd.EXPECT().RunWithContext(gomock.Any(), "docker", + []string{"exec", "ctr", "sleep", "infinity"}, + gomock.Any(), gomock.Any()).Return(nil) + return mockCmd + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + ctrl := gomock.NewController(t) + s := DockerCmdClient{ + runner: tc.setupMocks(ctrl), + } + + err := s.Exec(context.Background(), "ctr", io.Discard, "sleep", "infinity") + if tc.wantErr != "" { + require.EqualError(t, err, tc.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/internal/pkg/docker/dockerengine/dockerenginetest/dockerenginetest.go b/internal/pkg/docker/dockerengine/dockerenginetest/dockerenginetest.go index ba3ed2efe23..39da68ee741 100644 --- a/internal/pkg/docker/dockerengine/dockerenginetest/dockerenginetest.go +++ b/internal/pkg/docker/dockerengine/dockerenginetest/dockerenginetest.go @@ -16,6 +16,7 @@ type Double struct { IsContainerRunningFn func(context.Context, string) (bool, error) RunFn func(context.Context, *dockerengine.RunOptions) error BuildFn func(context.Context, *dockerengine.BuildArguments, io.Writer) error + ExecFn func(context.Context, string, io.Writer, string, ...string) error } // Stop calls the stubbed function. @@ -49,3 +50,11 @@ func (d *Double) Build(ctx context.Context, in *dockerengine.BuildArguments, w i } return d.BuildFn(ctx, in, w) } + +// Exec calls the stubbed function. +func (d *Double) Exec(ctx context.Context, container string, out io.Writer, cmd string, args ...string) error { + if d.ExecFn == nil { + return nil + } + return d.ExecFn(ctx, container, out, cmd, args...) +} diff --git a/internal/pkg/docker/orchestrator/orchestrator.go b/internal/pkg/docker/orchestrator/orchestrator.go index 6bffce4336c..53dd4ff8543 100644 --- a/internal/pkg/docker/orchestrator/orchestrator.go +++ b/internal/pkg/docker/orchestrator/orchestrator.go @@ -10,7 +10,9 @@ import ( "fmt" "io" "maps" + "net" "os" + "strconv" "sync" "sync/atomic" "time" @@ -47,6 +49,7 @@ type DockerEngine interface { IsContainerRunning(context.Context, string) (bool, error) Stop(context.Context, string) error Build(ctx context.Context, args *dockerengine.BuildArguments, w io.Writer) error + Exec(ctx context.Context, container string, out io.Writer, cmd string, args ...string) error } const ( @@ -59,6 +62,10 @@ const ( pauseCtrTag = "latest" ) +const ( + proxyPortStart = uint16(50000) +) + //go:embed Pause-Dockerfile var pauseDockerfile string @@ -114,22 +121,50 @@ func (o *Orchestrator) Start() <-chan error { } // RunTask stops the current running task and starts task. -func (o *Orchestrator) RunTask(task Task) { +func (o *Orchestrator) RunTask(task Task, opts ...RunTaskOption) { + r := &runTaskAction{ + task: task, + } + for _, opt := range opts { + opt(r) + } + // this guarantees the following: - // - if runTaskAction{} is pulled by the Orchestrator, any errors + // - if r is pulled by the Orchestrator, any errors // returned by it are reported by the Orchestrator. // - if Stop() is called _before_ the Orchestrator picks up this // action, then this action is skipped. select { case <-o.stopped: - case o.actions <- &runTaskAction{ - task: task, - }: + case o.actions <- r: } } type runTaskAction struct { task Task + + // optional vars for proxy + hosts []Host + ssmTarget string + network *net.IPNet +} + +// RunTaskOption adds optional data to RunTask. +type RunTaskOption func(*runTaskAction) + +// Host represents a service reachable via the network. +type Host struct { + Name string + Port string +} + +// RunTaskWithProxy returns a RunTaskOption that sets up a proxy connection to hosts. +func RunTaskWithProxy(ssmTarget string, network net.IPNet, hosts ...Host) RunTaskOption { + return func(r *runTaskAction) { + r.ssmTarget = ssmTarget + r.hosts = hosts + r.network = &network + } } func (a *runTaskAction) Do(o *Orchestrator) error { @@ -161,6 +196,12 @@ func (a *runTaskAction) Do(o *Orchestrator) error { if err := o.waitForContainerToStart(ctx, opts.ContainerName); err != nil { return fmt.Errorf("wait for pause container to start: %w", err) } + + if len(a.hosts) > 0 { + if err := o.setupProxyConnections(ctx, opts.ContainerName, a); err != nil { + return fmt.Errorf("setup proxy connections: %w", err) + } + } } else { // ensure no pause container changes curOpts := o.pauseRunOptions(o.curTask) @@ -185,6 +226,113 @@ func (a *runTaskAction) Do(o *Orchestrator) error { return nil } +// setupProxyConnections creates proxy connections to a.hosts in pauseContainer. +// It assumes that pauseContainer is already running. A unique proxy connection +// is created for each host (in parallel) using AWS SSM Port Forwarding through +// a.ssmTarget. Then, each connection is assigned an IP from a.network, +// starting at the bottom of the IP range. Using iptables, TCP packets destined +// for the connection's assigned IP are redirected to the connection. Finally, +// the host's name is mapped to its assigned IP in /etc/hosts. +func (o *Orchestrator) setupProxyConnections(ctx context.Context, pauseContainer string, a *runTaskAction) error { + fmt.Printf("\nSetting up proxy connections...\n") + + ports := make(map[Host]uint16) + port := proxyPortStart + for i := range a.hosts { + ports[a.hosts[i]] = port + port++ + } + + for _, host := range a.hosts { + host := host + portForHost := ports[host] + + o.wg.Add(1) + go func() { + defer o.wg.Done() + + err := o.docker.Exec(context.Background(), pauseContainer, io.Discard, "aws", "ssm", "start-session", + "--target", a.ssmTarget, + "--document-name", "AWS-StartPortForwardingSessionToRemoteHost", + "--parameters", fmt.Sprintf(`{"host":["%s"],"portNumber":["%s"],"localPortNumber":["%d"]}`, host.Name, host.Port, portForHost)) + if err != nil { + // report err as a runtime error from the pause container + if o.curTaskID.Load() != orchestratorStoppedTaskID { + o.runErrs <- fmt.Errorf("proxy to %v:%v: %w", host.Name, host.Port, err) + } + } + }() + } + + ip := a.network.IP + for host, port := range ports { + err := o.docker.Exec(ctx, pauseContainer, io.Discard, "iptables", + "--table", "nat", + "--append", "OUTPUT", + "--destination", ip.String(), + "--protocol", "tcp", + "--match", "tcp", + "--dport", host.Port, + "--jump", "REDIRECT", + "--to-ports", strconv.Itoa(int(port))) + if err != nil { + return fmt.Errorf("modify iptables: %w", err) + } + + err = o.docker.Exec(ctx, pauseContainer, io.Discard, "iptables-save") + if err != nil { + return fmt.Errorf("save iptables: %w", err) + } + + err = o.docker.Exec(ctx, pauseContainer, io.Discard, "/bin/bash", + "-c", fmt.Sprintf(`echo %s %s >> /etc/hosts`, ip.String(), host.Name)) + if err != nil { + return fmt.Errorf("update /etc/hosts: %w", err) + } + + ip, err = ipv4Increment(ip, a.network) + if err != nil { + return fmt.Errorf("increment ip: %w", err) + } + + fmt.Printf("Created connection to %v:%v\n", host.Name, host.Port) + } + + fmt.Printf("Finished setting up proxy connections\n\n") + return nil +} + +// ipv4Increment returns a copy of ip that has been incremented. +func ipv4Increment(ip net.IP, network *net.IPNet) (net.IP, error) { + // make a copy of the previous ip + cpy := make(net.IP, len(ip)) + copy(cpy, ip) + + ipv4 := cpy.To4() + + var inc func(idx int) error + inc = func(idx int) error { + if idx == -1 { + return errors.New("max ipv4 address") + } + + ipv4[idx]++ + if ipv4[idx] == 0 { // overflow occured + return inc(idx - 1) + } + return nil + } + + err := inc(len(ipv4) - 1) + if err != nil { + return nil, err + } + if !network.Contains(ipv4) { + return nil, fmt.Errorf("no more addresses in network") + } + return ipv4, nil +} + func (o *Orchestrator) buildPauseContainer(ctx context.Context) error { return o.docker.Build(ctx, &dockerengine.BuildArguments{ URI: pauseCtrURI, @@ -298,11 +446,12 @@ type ContainerDefinition struct { // among all of the containers in the task. func (o *Orchestrator) pauseRunOptions(t Task) dockerengine.RunOptions { opts := dockerengine.RunOptions{ - ImageURI: fmt.Sprintf("%s:%s", pauseCtrURI, pauseCtrTag), - ContainerName: o.containerID("pause"), - Command: []string{"sleep", "infinity"}, - ContainerPorts: make(map[string]string), - Secrets: t.PauseSecrets, + ImageURI: fmt.Sprintf("%s:%s", pauseCtrURI, pauseCtrTag), + ContainerName: o.containerID("pause"), + Command: []string{"sleep", "infinity"}, + ContainerPorts: make(map[string]string), + Secrets: t.PauseSecrets, + AddLinuxCapabilities: []string{"NET_ADMIN"}, } for _, ctr := range t.Containers { diff --git a/internal/pkg/docker/orchestrator/orchestrator_test.go b/internal/pkg/docker/orchestrator/orchestrator_test.go index 49b56dcdc94..6f3a82f717e 100644 --- a/internal/pkg/docker/orchestrator/orchestrator_test.go +++ b/internal/pkg/docker/orchestrator/orchestrator_test.go @@ -6,7 +6,10 @@ package orchestrator import ( "context" "errors" + "fmt" "io" + "net" + "strconv" "strings" "sync" "testing" @@ -22,53 +25,63 @@ func TestOrchestrator(t *testing.T) { Output: io.Discard, } } + generateHosts := func(n int) []Host { + hosts := make([]Host, n) + for i := 0; i < n; i++ { + hosts[i].Name = strconv.Itoa(i) + hosts[i].Port = strconv.Itoa(i) + } + return hosts + } + + type test func(*testing.T, *Orchestrator) tests := map[string]struct { - dockerEngine func(t *testing.T, sync chan struct{}) DockerEngine - logOptions logOptionsFunc - test func(t *testing.T, o *Orchestrator, sync chan struct{}) + logOptions logOptionsFunc + test func(t *testing.T) (test, DockerEngine) + stopAfterNErrs int errs []string }{ "stop and start": { - dockerEngine: func(t *testing.T, sync chan struct{}) DockerEngine { - return &dockerenginetest.Double{} + test: func(t *testing.T) (test, DockerEngine) { + return func(t *testing.T, o *Orchestrator) {}, &dockerenginetest.Double{} }, - test: func(t *testing.T, o *Orchestrator, sync chan struct{}) {}, }, "error if fail to build pause container": { - dockerEngine: func(t *testing.T, sync chan struct{}) DockerEngine { - return &dockerenginetest.Double{ + test: func(t *testing.T) (test, DockerEngine) { + de := &dockerenginetest.Double{ BuildFn: func(ctx context.Context, ba *dockerengine.BuildArguments, w io.Writer) error { return errors.New("some error") }, } - }, - test: func(t *testing.T, o *Orchestrator, sync chan struct{}) { - o.RunTask(Task{}) + return func(t *testing.T, o *Orchestrator) { + o.RunTask(Task{}) + }, de }, errs: []string{ `build pause container: some error`, }, }, "error if unable to check if pause container is running": { - dockerEngine: func(t *testing.T, sync chan struct{}) DockerEngine { - return &dockerenginetest.Double{ + test: func(t *testing.T) (test, DockerEngine) { + de := &dockerenginetest.Double{ IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { return false, errors.New("some error") }, } - }, - test: func(t *testing.T, o *Orchestrator, sync chan struct{}) { - o.RunTask(Task{}) + return func(t *testing.T, o *Orchestrator) { + o.RunTask(Task{}) + }, de }, errs: []string{ `wait for pause container to start: check if "prefix-pause" is running: some error`, }, }, "error stopping task": { - dockerEngine: func(t *testing.T, sync chan struct{}) DockerEngine { - return &dockerenginetest.Double{ + logOptions: noLogs, + test: func(t *testing.T) (test, DockerEngine) { + de := &dockerenginetest.Double{ IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { return true, nil }, @@ -79,16 +92,15 @@ func TestOrchestrator(t *testing.T) { return errors.New("some error") }, } - }, - logOptions: noLogs, - test: func(t *testing.T, o *Orchestrator, sync chan struct{}) { - o.RunTask(Task{ - Containers: map[string]ContainerDefinition{ - "foo": {}, - "bar": {}, - "success": {}, - }, - }) + return func(t *testing.T, o *Orchestrator) { + o.RunTask(Task{ + Containers: map[string]ContainerDefinition{ + "foo": {}, + "bar": {}, + "success": {}, + }, + }) + }, de }, errs: []string{ `stop "pause": some error`, @@ -97,41 +109,42 @@ func TestOrchestrator(t *testing.T) { }, }, "error restarting new task due to pause changes": { - dockerEngine: func(t *testing.T, sync chan struct{}) DockerEngine { - return &dockerenginetest.Double{ + logOptions: noLogs, + test: func(t *testing.T) (test, DockerEngine) { + de := &dockerenginetest.Double{ IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { return true, nil }, } - }, - logOptions: noLogs, - test: func(t *testing.T, o *Orchestrator, sync chan struct{}) { - o.RunTask(Task{ - Containers: map[string]ContainerDefinition{ - "foo": { - Ports: map[string]string{ - "8080": "80", + return func(t *testing.T, o *Orchestrator) { + o.RunTask(Task{ + Containers: map[string]ContainerDefinition{ + "foo": { + Ports: map[string]string{ + "8080": "80", + }, }, }, - }, - }) - o.RunTask(Task{ - Containers: map[string]ContainerDefinition{ - "foo": { - Ports: map[string]string{ - "10000": "80", + }) + o.RunTask(Task{ + Containers: map[string]ContainerDefinition{ + "foo": { + Ports: map[string]string{ + "10000": "80", + }, }, }, - }, - }) + }) + }, de }, errs: []string{ "new task requires recreating pause container", }, }, "success with a task": { - dockerEngine: func(t *testing.T, sync chan struct{}) DockerEngine { - return &dockerenginetest.Double{ + logOptions: noLogs, + test: func(t *testing.T) (test, DockerEngine) { + de := &dockerenginetest.Double{ IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { return true, nil }, @@ -149,47 +162,229 @@ func TestOrchestrator(t *testing.T) { return nil }, } - }, - logOptions: noLogs, - test: func(t *testing.T, o *Orchestrator, sync chan struct{}) { - o.RunTask(Task{ - PauseSecrets: map[string]string{ - "A_SECRET": "very secret", - }, - Containers: map[string]ContainerDefinition{ - "foo": { - Ports: map[string]string{ - "8080": "80", - }, + return func(t *testing.T, o *Orchestrator) { + o.RunTask(Task{ + PauseSecrets: map[string]string{ + "A_SECRET": "very secret", }, - "bar": { - Ports: map[string]string{ - "9000": "90", + Containers: map[string]ContainerDefinition{ + "foo": { + Ports: map[string]string{ + "8080": "80", + }, + }, + "bar": { + Ports: map[string]string{ + "9000": "90", + }, }, }, - }, - }) + }) + }, de }, errs: []string{}, }, + "proxy setup, connection returns error": { + logOptions: noLogs, + test: func(t *testing.T) (test, DockerEngine) { + de := &dockerenginetest.Double{ + IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { + return true, nil + }, + ExecFn: func(ctx context.Context, ctr string, w io.Writer, cmd string, args ...string) error { + if cmd == "aws" { + return errors.New("some error") + } + return nil + }, + } + return func(t *testing.T, o *Orchestrator) { + _, ipNet, err := net.ParseCIDR("172.20.0.0/16") + require.NoError(t, err) + + o.RunTask(Task{}, RunTaskWithProxy("ecs:cluster_task_ctr", *ipNet, Host{ + Name: "remote-foo", + Port: "80", + })) + }, de + }, + stopAfterNErrs: 1, + errs: []string{`proxy to remote-foo:80: some error`}, + }, + "proxy setup, ip increment error": { + logOptions: noLogs, + test: func(t *testing.T) (test, DockerEngine) { + de := &dockerenginetest.Double{ + IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { + return true, nil + }, + ExecFn: func(ctx context.Context, ctr string, w io.Writer, cmd string, args ...string) error { + if cmd == "aws" { + fmt.Fprintf(w, "Port 61972 opened for sessionId mySessionId\n") + } + return nil + }, + } + return func(t *testing.T, o *Orchestrator) { + _, ipNet, err := net.ParseCIDR("255.255.255.254/31") // 255.255.255.254 - 255.255.255.255 + require.NoError(t, err) + + o.RunTask(Task{}, RunTaskWithProxy("ecs:cluster_task_ctr", *ipNet, generateHosts(3)...)) + }, de + }, + stopAfterNErrs: 1, + errs: []string{`setup proxy connections: increment ip: max ipv4 address`}, + }, + "proxy setup, ip tables error": { + logOptions: noLogs, + test: func(t *testing.T) (test, DockerEngine) { + de := &dockerenginetest.Double{ + IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { + return true, nil + }, + ExecFn: func(ctx context.Context, ctr string, w io.Writer, cmd string, args ...string) error { + if cmd == "aws" { + fmt.Fprintf(w, "Port 61972 opened for sessionId mySessionId\n") + } else if cmd == "iptables" { + return errors.New("some error") + } + return nil + }, + } + return func(t *testing.T, o *Orchestrator) { + _, ipNet, err := net.ParseCIDR("172.20.0.0/16") + require.NoError(t, err) + + o.RunTask(Task{}, RunTaskWithProxy("ecs:cluster_task_ctr", *ipNet, Host{ + Name: "remote-foo", + Port: "80", + })) + }, de + }, + stopAfterNErrs: 1, + errs: []string{`setup proxy connections: modify iptables: some error`}, + }, + "proxy setup, ip tables save error": { + logOptions: noLogs, + test: func(t *testing.T) (test, DockerEngine) { + de := &dockerenginetest.Double{ + IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { + return true, nil + }, + ExecFn: func(ctx context.Context, ctr string, w io.Writer, cmd string, args ...string) error { + if cmd == "aws" { + fmt.Fprintf(w, "Port 61972 opened for sessionId mySessionId\n") + } else if cmd == "iptables-save" { + return errors.New("some error") + } + return nil + }, + } + return func(t *testing.T, o *Orchestrator) { + _, ipNet, err := net.ParseCIDR("172.20.0.0/16") + require.NoError(t, err) + + o.RunTask(Task{}, RunTaskWithProxy("ecs:cluster_task_ctr", *ipNet, Host{ + Name: "remote-foo", + Port: "80", + })) + }, de + }, + stopAfterNErrs: 1, + errs: []string{`setup proxy connections: save iptables: some error`}, + }, + "proxy setup, /etc/hosts error": { + logOptions: noLogs, + test: func(t *testing.T) (test, DockerEngine) { + de := &dockerenginetest.Double{ + IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { + return true, nil + }, + ExecFn: func(ctx context.Context, ctr string, w io.Writer, cmd string, args ...string) error { + if cmd == "aws" { + fmt.Fprintf(w, "Port 61972 opened for sessionId mySessionId\n") + } else if cmd == "/bin/bash" { + return errors.New("some error") + } + return nil + }, + } + return func(t *testing.T, o *Orchestrator) { + _, ipNet, err := net.ParseCIDR("172.20.0.0/16") + require.NoError(t, err) + + o.RunTask(Task{}, RunTaskWithProxy("ecs:cluster_task_ctr", *ipNet, Host{ + Name: "remote-foo", + Port: "80", + })) + }, de + }, + stopAfterNErrs: 1, + errs: []string{`setup proxy connections: update /etc/hosts: some error`}, + }, + "proxy success": { + logOptions: noLogs, + test: func(t *testing.T) (test, DockerEngine) { + waitUntilRun := make(chan struct{}) + de := &dockerenginetest.Double{ + IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { + return true, nil + }, + ExecFn: func(ctx context.Context, ctr string, w io.Writer, cmd string, args ...string) error { + if cmd == "aws" { + fmt.Fprintf(w, "Port 61972 opened for sessionId mySessionId\n") + } + return nil + }, + RunFn: func(ctx context.Context, opts *dockerengine.RunOptions) error { + if opts.ContainerName == "prefix-foo" { + close(waitUntilRun) + } + return nil + }, + } + return func(t *testing.T, o *Orchestrator) { + _, ipNet, err := net.ParseCIDR("172.20.0.0/16") + require.NoError(t, err) + + o.RunTask(Task{ + Containers: map[string]ContainerDefinition{ + "foo": {}, + }, + }, RunTaskWithProxy("ecs:cluster_task_ctr", *ipNet, Host{ + Name: "remote-foo", + Port: "80", + })) + + <-waitUntilRun + }, de + }, + }, } for name, tc := range tests { t.Run(name, func(t *testing.T) { - syncCh := make(chan struct{}) - o := New(tc.dockerEngine(t, syncCh), "prefix-", tc.logOptions) + test, dockerEngine := tc.test(t) + o := New(dockerEngine, "prefix-", tc.logOptions) wg := &sync.WaitGroup{} wg.Add(2) + stopCh := make(chan struct{}) errs := o.Start() go func() { defer wg.Done() + if tc.stopAfterNErrs == 0 { + close(stopCh) + } var actualErrs []string for err := range errs { actualErrs = append(actualErrs, strings.Split(err.Error(), "\n")...) + if len(actualErrs) == tc.stopAfterNErrs { + close(stopCh) + } } require.ElementsMatch(t, tc.errs, actualErrs) @@ -198,7 +393,8 @@ func TestOrchestrator(t *testing.T) { go func() { defer wg.Done() - tc.test(t, o, syncCh) + test(t, o) + <-stopCh o.Stop() }() @@ -206,3 +402,44 @@ func TestOrchestrator(t *testing.T) { }) } } + +func TestIPv4Increment(t *testing.T) { + tests := map[string]struct { + cidrIP string + + wantErr string + wantIP string + }{ + "increment": { + cidrIP: "10.0.0.15/24", + wantIP: "10.0.0.16", + }, + "overflows to next octet": { + cidrIP: "10.0.0.255/16", + wantIP: "10.0.1.0", + }, + "error if no more ipv4 addresses": { + cidrIP: "255.255.255.255/16", + wantErr: "max ipv4 address", + }, + "error if out of network": { + cidrIP: "10.0.0.255/24", + wantErr: "no more addresses in network", + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + ip, network, err := net.ParseCIDR(tc.cidrIP) + require.NoError(t, err) + + gotIP, gotErr := ipv4Increment(ip, network) + if tc.wantErr != "" { + require.EqualError(t, gotErr, tc.wantErr) + } else { + require.NoError(t, gotErr) + require.Equal(t, tc.wantIP, gotIP.String()) + } + }) + } +} diff --git a/internal/pkg/docker/orchestrator/orchestratortest/orchestratortest.go b/internal/pkg/docker/orchestrator/orchestratortest/orchestratortest.go index 59d4dc40779..8c08a0b167d 100644 --- a/internal/pkg/docker/orchestrator/orchestratortest/orchestratortest.go +++ b/internal/pkg/docker/orchestrator/orchestratortest/orchestratortest.go @@ -8,7 +8,7 @@ import "github.com/aws/copilot-cli/internal/pkg/docker/orchestrator" // Double is a test double for orchestrator.Orchestrator type Double struct { StartFn func() <-chan error - RunTaskFn func(task orchestrator.Task) + RunTaskFn func(orchestrator.Task, ...orchestrator.RunTaskOption) StopFn func() } @@ -21,11 +21,11 @@ func (d *Double) Start() <-chan error { } // RunTask calls the stubbed function. -func (d *Double) RunTask(task orchestrator.Task) { +func (d *Double) RunTask(task orchestrator.Task, opts ...orchestrator.RunTaskOption) { if d.RunTaskFn == nil { return } - d.RunTaskFn(task) + d.RunTaskFn(task, opts...) } // Stop calls the stubbed function.