diff --git a/agent/api/container.go b/agent/api/container.go index 433bc436a42..2a3901ff135 100644 --- a/agent/api/container.go +++ b/agent/api/container.go @@ -493,10 +493,10 @@ func (c *Container) GetLabels() map[string]string { return c.labels } -// SetKnownPortBindings gets the ports for a container +// SetKnownPortBindings sets the ports for a container func (c *Container) SetKnownPortBindings(ports []PortBinding) { - c.lock.RLock() - defer c.lock.RUnlock() + c.lock.Lock() + defer c.lock.Unlock() c.KnownPortBindingsUnsafe = ports } diff --git a/agent/handlers/v1_handlers.go b/agent/handlers/v1_handlers.go index 7d627e586d7..7450396e895 100644 --- a/agent/handlers/v1_handlers.go +++ b/agent/handlers/v1_handlers.go @@ -104,7 +104,33 @@ func newContainerResponse(dockerContainer *api.DockerContainer, eni *api.ENI) v1 DockerName: dockerContainer.DockerName, } - for _, binding := range container.GetKnownPortBindings() { + resp.Ports = newPortBindingsResponse(dockerContainer, eni) + + if eni != nil { + resp.Networks = []containermetadata.Network{ + { + NetworkMode: networkModeAwsvpc, + IPv4Addresses: eni.GetIPV4Addresses(), + IPv6Addresses: eni.GetIPV6Addresses(), + }, + } + } + return resp +} + +func newPortBindingsResponse(dockerContainer *api.DockerContainer, eni *api.ENI) []v2.PortResponse { + container := dockerContainer.Container + resp := []v2.PortResponse{} + + bindings := container.GetKnownPortBindings() + + // if KnownPortBindings list is empty, then we use the port mapping + // information that was passed down from ACS. + if len(bindings) == 0 { + bindings = container.Ports + } + + for _, binding := range bindings { port := v2.PortResponse{ ContainerPort: binding.ContainerPort, Protocol: binding.Protocol.String(), @@ -116,17 +142,7 @@ func newContainerResponse(dockerContainer *api.DockerContainer, eni *api.ENI) v1 port.HostPort = port.ContainerPort } - resp.Ports = append(resp.Ports, port) - } - - if eni != nil { - resp.Networks = []containermetadata.Network{ - { - NetworkMode: networkModeAwsvpc, - IPv4Addresses: eni.GetIPV4Addresses(), - IPv6Addresses: eni.GetIPV6Addresses(), - }, - } + resp = append(resp, port) } return resp } diff --git a/agent/handlers/v1_handlers_test.go b/agent/handlers/v1_handlers_test.go index 48fb4c9aaef..11b83709ba8 100644 --- a/agent/handlers/v1_handlers_test.go +++ b/agent/handlers/v1_handlers_test.go @@ -138,12 +138,45 @@ func TestGetAWSVPCTaskByTaskArn(t *testing.T) { resp := v1.TasksResponse{Tasks: []*v1.TaskResponse{&taskResponse}} assert.Equal(t, eniIPV4Address, resp.Tasks[0].Containers[0].Networks[0].IPv4Addresses[0]) - assert.Equal(t, uint16(80), resp.Tasks[0].Containers[0].Ports[0].ContainerPort) assert.Equal(t, "tcp", resp.Tasks[0].Containers[0].Ports[0].Protocol) taskDiffHelper(t, []*api.Task{testTasks[3]}, resp) } +func TestGetHostNeworkingTaskByTaskArn(t *testing.T) { + recorder := performMockRequest(t, "/v1/tasks?taskarn=hostModeNetworkingTask") + + var taskResponse v1.TaskResponse + err := json.Unmarshal(recorder.Body.Bytes(), &taskResponse) + if err != nil { + t.Fatal(err) + } + + resp := v1.TasksResponse{Tasks: []*v1.TaskResponse{&taskResponse}} + + assert.Equal(t, uint16(80), resp.Tasks[0].Containers[0].Ports[0].ContainerPort) + assert.Equal(t, "tcp", resp.Tasks[0].Containers[0].Ports[0].Protocol) + + taskDiffHelper(t, []*api.Task{testTasks[4]}, resp) +} + +func TestGetBridgeNeworkingTaskByTaskArn(t *testing.T) { + recorder := performMockRequest(t, "/v1/tasks?taskarn=bridgeModeNetworkingTask") + + var taskResponse v1.TaskResponse + err := json.Unmarshal(recorder.Body.Bytes(), &taskResponse) + if err != nil { + t.Fatal(err) + } + + resp := v1.TasksResponse{Tasks: []*v1.TaskResponse{&taskResponse}} + + assert.Equal(t, uint16(80), resp.Tasks[0].Containers[0].Ports[0].ContainerPort) + assert.Equal(t, "tcp", resp.Tasks[0].Containers[0].Ports[0].Protocol) + + taskDiffHelper(t, []*api.Task{testTasks[5]}, resp) +} + func TestGetTaskByTaskArnNotFound(t *testing.T) { recorder := performMockRequest(t, "/v1/tasks?taskarn=doesnotexist") @@ -331,18 +364,50 @@ var testTasks = []*api.Task{ Containers: []*api.Container{ { Name: "awsvpc", - KnownPortBindingsUnsafe: []api.PortBinding{ + }, + }, + ENI: &api.ENI{ + IPV4Addresses: []*api.ENIIPV4Address{ + { + Address: eniIPV4Address, + }, + }, + }, + }, + { + Arn: "hostModeNetworkingTask", + DesiredStatusUnsafe: api.TaskRunning, + KnownStatusUnsafe: api.TaskRunning, + Family: "test", + Version: "1", + Containers: []*api.Container{ + { + Name: "awsvpc", + Ports: []api.PortBinding{ { ContainerPort: 80, + HostPort: 80, Protocol: api.TransportProtocolTCP, }, }, }, }, - ENI: &api.ENI{ - IPV4Addresses: []*api.ENIIPV4Address{ - { - Address: eniIPV4Address, + }, + { + Arn: "bridgeModeNetworkingTask", + DesiredStatusUnsafe: api.TaskRunning, + KnownStatusUnsafe: api.TaskRunning, + Family: "test", + Version: "1", + Containers: []*api.Container{ + { + Name: "awsvpc", + KnownPortBindingsUnsafe: []api.PortBinding{ + { + ContainerPort: 80, + HostPort: 80, + Protocol: api.TransportProtocolTCP, + }, }, }, },