diff --git a/agent/dockerclient/dockerapi/docker_client.go b/agent/dockerclient/dockerapi/docker_client.go index aa0ed331c66..306cee28e59 100644 --- a/agent/dockerclient/dockerapi/docker_client.go +++ b/agent/dockerclient/dockerapi/docker_client.go @@ -1441,7 +1441,7 @@ func (dg *dockerGoClient) APIVersion() (dockerclient.DockerVersion, error) { func (dg *dockerGoClient) Stats(ctx context.Context, id string, inactivityTimeout time.Duration) (<-chan *types.StatsJSON, <-chan error) { subCtx, cancelRequest := context.WithCancel(ctx) - errC := make(chan error) + errC := make(chan error, 1) statsC := make(chan *types.StatsJSON) client, err := dg.sdkDockerClient() if err != nil { @@ -1487,7 +1487,12 @@ func (dg *dockerGoClient) Stats(ctx context.Context, id string, inactivityTimeou return } - statsC <- data + select { + case <-ctx.Done(): + return + case statsC <- data: + } + data = new(types.StatsJSON) } }() @@ -1504,8 +1509,11 @@ func (dg *dockerGoClient) Stats(ctx context.Context, id string, inactivityTimeou errC <- err return } - statsC <- stats - + select { + case <-ctx.Done(): + return + case statsC <- stats: + } // sleeping here jitters the time at which the ticker is created, so that // containers do not synchronize on calling the docker stats api. // the max sleep is 80% of the polling interval so that we have a chance to @@ -1519,7 +1527,11 @@ func (dg *dockerGoClient) Stats(ctx context.Context, id string, inactivityTimeou errC <- err return } - statsC <- stats + select { + case <-ctx.Done(): + return + case statsC <- stats: + } } }() } diff --git a/agent/dockerclient/dockerapi/docker_client_test.go b/agent/dockerclient/dockerapi/docker_client_test.go index 7a65cd2c9ab..9bfae2af7b0 100644 --- a/agent/dockerclient/dockerapi/docker_client_test.go +++ b/agent/dockerclient/dockerapi/docker_client_test.go @@ -1336,6 +1336,12 @@ func TestStatsNormalExit(t *testing.T) { assert.Equal(t, uint64(50), newStat.MemoryStats.Usage) assert.Equal(t, uint64(100), newStat.CPUStats.SystemUsage) + + // stop container stats + cancel() + // verify stats chan was closed to avoid goroutine leaks + _, ok := <-stats + assert.False(t, ok, "stats channel was not properly closed") } func TestStatsErrorReading(t *testing.T) { @@ -1349,9 +1355,12 @@ func TestStatsErrorReading(t *testing.T) { }, errors.New("test error")) ctx, cancel := context.WithCancel(context.TODO()) defer cancel() - _, errC := client.Stats(ctx, "foo", dockerclient.StatsInactivityTimeout) + statsC, errC := client.Stats(ctx, "foo", dockerclient.StatsInactivityTimeout) assert.Error(t, <-errC) + // verify stats chan was closed to avoid goroutine leaks + _, ok := <-statsC + assert.False(t, ok, "stats channel was not properly closed") } func TestStatsErrorDecoding(t *testing.T) { @@ -1365,8 +1374,11 @@ func TestStatsErrorDecoding(t *testing.T) { }, nil) ctx, cancel := context.WithCancel(context.TODO()) defer cancel() - _, errC := client.Stats(ctx, "foo", dockerclient.StatsInactivityTimeout) + statsC, errC := client.Stats(ctx, "foo", dockerclient.StatsInactivityTimeout) assert.Error(t, <-errC) + // verify stats chan was closed to avoid goroutine leaks + _, ok := <-statsC + assert.False(t, ok, "stats channel was not properly closed") } func TestStatsClientError(t *testing.T) {