From 6cc9959559a1aba56983503a40eca5c53dcb3e5c Mon Sep 17 00:00:00 2001 From: Richa Gangwar Date: Thu, 18 May 2023 12:02:25 -0700 Subject: [PATCH] wsclient: Add new test for context cancellation in client. --- ecs-agent/wsclient/client_test.go | 51 +++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/ecs-agent/wsclient/client_test.go b/ecs-agent/wsclient/client_test.go index ca1c40de242..ce2a313d232 100644 --- a/ecs-agent/wsclient/client_test.go +++ b/ecs-agent/wsclient/client_test.go @@ -57,7 +57,7 @@ func TestClientProxy(t *testing.T) { defer os.Unsetenv("HTTP_PROXY") types := []interface{}{ecsacs.AckRequest{}} - cs := getTestClientServer("http://www.amazon.com", types) + cs := getTestClientServer("http://www.amazon.com", types, 1) err := cs.Connect() assert.Error(t, err) assert.True(t, strings.Contains(err.Error(), proxy_url), "proxy not found: %s", err.Error()) @@ -86,7 +86,7 @@ func TestConcurrentWritesDontPanic(t *testing.T) { req := ecsacs.AckRequest{Cluster: aws.String("test"), ContainerInstance: aws.String("test"), MessageId: aws.String("test")} types := []interface{}{ecsacs.AckRequest{}} - cs := getTestClientServer(mockServer.URL, types) + cs := getTestClientServer(mockServer.URL, types, 1) require.NoError(t, cs.Connect()) executeTenRequests := func() { @@ -104,7 +104,7 @@ func TestConcurrentWritesDontPanic(t *testing.T) { waitForRequests.Wait() } -func getTestClientServer(url string, msgType []interface{}) *ClientServerImpl { +func getTestClientServer(url string, msgType []interface{}, rwTimeout time.Duration) *ClientServerImpl { testCreds := credentials.NewStaticCredentials("test-id", "test-secret", "test-token") return &ClientServerImpl{ @@ -117,7 +117,7 @@ func getTestClientServer(url string, msgType []interface{}) *ClientServerImpl { }, CredentialProvider: testCreds, TypeDecoder: BuildTypeDecoder(msgType), - RWTimeout: time.Second, + RWTimeout: rwTimeout * time.Second, RequestHandlers: make(map[string]RequestHandler), } } @@ -135,7 +135,7 @@ func TestProxyVariableCustomValue(t *testing.T) { testString := "Custom no proxy string" os.Setenv("NO_PROXY", testString) types := []interface{}{ecsacs.AckRequest{}} - require.NoError(t, getTestClientServer(mockServer.URL, types).Connect()) + require.NoError(t, getTestClientServer(mockServer.URL, types, 1).Connect()) assert.Equal(t, os.Getenv("NO_PROXY"), testString, "NO_PROXY should match user-supplied variable") } @@ -152,7 +152,7 @@ func TestProxyVariableDefaultValue(t *testing.T) { os.Unsetenv("NO_PROXY") types := []interface{}{ecsacs.AckRequest{}} - getTestClientServer(mockServer.URL, types).Connect() + getTestClientServer(mockServer.URL, types, 1).Connect() expectedEnvVar := "169.254.169.254,169.254.170.2," + dockerEndpoint @@ -171,7 +171,7 @@ func TestHandleMessagePermissibleCloseCode(t *testing.T) { mockServer.StartTLS() types := []interface{}{ecsacs.AckRequest{}} - cs := getTestClientServer(mockServer.URL, types) + cs := getTestClientServer(mockServer.URL, types, 1) require.NoError(t, cs.Connect()) assert.True(t, cs.IsReady(), "expected websocket connection to be ready") @@ -194,7 +194,7 @@ func TestHandleMessageUnexpectedCloseCode(t *testing.T) { mockServer.StartTLS() types := []interface{}{ecsacs.AckRequest{}} - cs := getTestClientServer(mockServer.URL, types) + cs := getTestClientServer(mockServer.URL, types, 1) require.NoError(t, cs.Connect()) assert.True(t, cs.IsReady(), "expected websocket connection to be ready") @@ -218,7 +218,7 @@ func TestHandleNonHTTPSEndpoint(t *testing.T) { defer mockServer.Close() types := []interface{}{ecsacs.AckRequest{}} - cs := getTestClientServer(mockServer.URL, types) + cs := getTestClientServer(mockServer.URL, types, 1) require.NoError(t, cs.Connect()) assert.True(t, cs.IsReady(), "expected websocket connection to be ready") @@ -243,7 +243,7 @@ func TestHandleIncorrectURLScheme(t *testing.T) { mockServerURL.Scheme = "notaparticularlyrealscheme" types := []interface{}{ecsacs.AckRequest{}} - cs := getTestClientServer(mockServerURL.String(), types) + cs := getTestClientServer(mockServerURL.String(), types, 1) err := cs.Connect() assert.Error(t, err, "Expected error for incorrect URL scheme") @@ -321,7 +321,7 @@ func TestAddRequestPayloadHandler(t *testing.T) { types := []interface{}{ecsacs.PayloadMessage{}} messageError := make(chan error) - cs := getTestClientServer(mockServer.URL, types) + cs := getTestClientServer(mockServer.URL, types, 1) cs.conn = conn defer cs.Close() @@ -364,7 +364,7 @@ func TestMakeUnrecognizedRequest(t *testing.T) { mockServer.StartTLS() types := []interface{}{ecsacs.PayloadMessage{}} - cs := getTestClientServer(mockServer.URL, types) + cs := getTestClientServer(mockServer.URL, types, 1) cs.conn = conn defer cs.Close() @@ -388,7 +388,7 @@ func TestWriteCloseMessage(t *testing.T) { mockServer.StartTLS() types := []interface{}{ecsacs.PayloadMessage{}} - cs := getTestClientServer(mockServer.URL, types) + cs := getTestClientServer(mockServer.URL, types, 1) cs.Connect() defer cs.Close() @@ -397,3 +397,28 @@ func TestWriteCloseMessage(t *testing.T) { assert.NoError(t, err) assert.Error(t, <-errChan) } + +// TestCtxCancel tests if the passed context, on receiving the cancel +// on the created ctx.Done channel, performs the expected behavior of +// closing the connection and returns the ctx error. +func TestCtxCancel(t *testing.T) { + closeWS := make(chan []byte) + defer close(closeWS) + + ctx, cancel := context.WithCancel(context.Background()) + messageError := make(chan error) + mockServer, _, _, _, _ := utils.GetMockServer(closeWS) + mockServer.StartTLS() + + types := []interface{}{ecsacs.AckRequest{}} + cs := getTestClientServer(mockServer.URL, types, 2) + require.NoError(t, cs.Connect()) + assert.True(t, cs.IsReady(), "expected websocket connection to be ready") + + go func() { + messageError <- cs.ConsumeMessages(ctx) + }() + // Cancel the context. + cancel() + assert.EqualError(t, <-messageError, "context canceled") +}