diff --git a/sdk/messaging/azservicebus/admin/admin_client.go b/sdk/messaging/azservicebus/admin/admin_client.go index 38f1c51d3a7b..0cfe8ddbc270 100644 --- a/sdk/messaging/azservicebus/admin/admin_client.go +++ b/sdk/messaging/azservicebus/admin/admin_client.go @@ -12,6 +12,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/atom" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" ) // Client allows you to administer resources in a Service Bus Namespace. @@ -21,8 +22,13 @@ type Client struct { em atom.EntityManager } +// RetryOptions represent the options for retries. +type RetryOptions = utils.RetryOptions + +// ClientOptions allows you to set optional configuration for `Client`. type ClientOptions struct { - // for future expansion + // RetryOptions controls how often operations are retried from this client. + RetryOptions *RetryOptions } // NewClientFromConnectionString creates a Client authenticating using a connection string. @@ -38,7 +44,13 @@ func NewClientFromConnectionString(connectionString string, options *ClientOptio // NewClient creates a Client authenticating using a TokenCredential. 
func NewClient(fullyQualifiedNamespace string, tokenCredential azcore.TokenCredential, options *ClientOptions) (*Client, error) { - em, err := atom.NewEntityManager(fullyQualifiedNamespace, tokenCredential, internal.Version) + var retryOptions utils.RetryOptions + + if options != nil && options.RetryOptions != nil { + retryOptions = *options.RetryOptions + } + + em, err := atom.NewEntityManager(fullyQualifiedNamespace, tokenCredential, internal.Version, retryOptions) if err != nil { return nil, err diff --git a/sdk/messaging/azservicebus/admin/admin_client_queue.go b/sdk/messaging/azservicebus/admin/admin_client_queue.go index 72f48b13d4a6..c79d4ef49c86 100644 --- a/sdk/messaging/azservicebus/admin/admin_client_queue.go +++ b/sdk/messaging/azservicebus/admin/admin_client_queue.go @@ -196,7 +196,7 @@ func (ac *Client) GetQueue(ctx context.Context, queueName string, options *GetQu props, err := newQueueProperties(&atomResp.Content.QueueDescription) if err != nil { - return nil, atom.NewResponseError(err, resp) + return nil, err } return &GetQueueResponse{ @@ -234,7 +234,7 @@ func (ac *Client) GetQueueRuntimeProperties(ctx context.Context, queueName strin props, err := newQueueRuntimeProperties(&atomResp.Content.QueueDescription) if err != nil { - return nil, atom.NewResponseError(err, resp) + return nil, err } return &GetQueueRuntimePropertiesResponse{ @@ -349,7 +349,7 @@ func (ac *Client) createOrUpdateQueueImpl(ctx context.Context, queueName string, newProps, err := newQueueProperties(&atomResp.Content.QueueDescription) if err != nil { - return nil, nil, atom.NewResponseError(err, resp) + return nil, nil, err } return newProps, resp, nil @@ -395,7 +395,7 @@ func (p *QueuePager) getNextPage(ctx context.Context) (*ListQueuesResponse, erro props, err := newQueueProperties(&env.Content.QueueDescription) if err != nil { - return nil, atom.NewResponseError(err, resp) + return nil, err } all = append(all, &QueueItem{ diff --git 
a/sdk/messaging/azservicebus/admin/admin_client_subscription.go b/sdk/messaging/azservicebus/admin/admin_client_subscription.go index 6cdd394eb747..d9fc400ec898 100644 --- a/sdk/messaging/azservicebus/admin/admin_client_subscription.go +++ b/sdk/messaging/azservicebus/admin/admin_client_subscription.go @@ -149,7 +149,7 @@ func (ac *Client) GetSubscription(ctx context.Context, topicName string, subscri props, err := newSubscriptionProperties(&atomResp.Content.SubscriptionDescription) if err != nil { - return nil, atom.NewResponseError(err, resp) + return nil, err } return &GetSubscriptionResponse{ @@ -186,7 +186,7 @@ func (ac *Client) GetSubscriptionRuntimeProperties(ctx context.Context, topicNam props, err := newSubscriptionRuntimeProperties(&atomResp.Content.SubscriptionDescription) if err != nil { - return nil, atom.NewResponseError(err, resp) + return nil, err } return &GetSubscriptionRuntimePropertiesResponse{ @@ -257,7 +257,7 @@ func (p *SubscriptionPager) getNext(ctx context.Context) (*ListSubscriptionsResp props, err := newSubscriptionProperties(&env.Content.SubscriptionDescription) if err != nil { - return nil, atom.NewResponseError(err, resp) + return nil, err } all = append(all, &SubscriptionPropertiesItem{ @@ -346,7 +346,7 @@ func (p *SubscriptionRuntimePropertiesPager) getNextPage(ctx context.Context) (* props, err := newSubscriptionRuntimeProperties(&entry.Content.SubscriptionDescription) if err != nil { - return nil, atom.NewResponseError(err, resp) + return nil, err } all = append(all, &SubscriptionRuntimePropertiesItem{ @@ -452,7 +452,7 @@ func (ac *Client) createOrUpdateSubscriptionImpl(ctx context.Context, topicName newProps, err := newSubscriptionProperties(&atomResp.Content.SubscriptionDescription) if err != nil { - return nil, nil, atom.NewResponseError(err, resp) + return nil, nil, err } return newProps, resp, nil diff --git a/sdk/messaging/azservicebus/admin/admin_client_test.go b/sdk/messaging/azservicebus/admin/admin_client_test.go index 
8ff788c909fe..ab5c74954d33 100644 --- a/sdk/messaging/azservicebus/admin/admin_client_test.go +++ b/sdk/messaging/azservicebus/admin/admin_client_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/atom" @@ -204,7 +205,12 @@ func TestAdminClient_UpdateQueue(t *testing.T) { updatedProps, err = adminClient.UpdateQueue(context.Background(), "non-existent-queue", createdProps.QueueProperties, nil) // a little awkward, we'll make these programatically inspectable as we add in better error handling. - require.Contains(t, err.Error(), "error code: 404") + require.Contains(t, err.Error(), "404 Not Found") + + var asResponseErr *azcore.ResponseError + require.ErrorAs(t, err, &asResponseErr) + require.EqualValues(t, 404, asResponseErr.StatusCode) + require.Nil(t, updatedProps) } @@ -475,7 +481,12 @@ func TestAdminClient_UpdateTopic(t *testing.T) { updateResp, err = adminClient.UpdateTopic(context.Background(), "non-existent-topic", addResp.TopicProperties, nil) // a little awkward, we'll make these programatically inspectable as we add in better error handling. - require.Contains(t, err.Error(), "error code: 404") + require.Contains(t, err.Error(), "404 Not Found") + + var asResponseErr *azcore.ResponseError + require.ErrorAs(t, err, &asResponseErr) + require.EqualValues(t, 404, asResponseErr.StatusCode) + require.Nil(t, updateResp) } @@ -738,8 +749,12 @@ func TestAdminClient_UpdateSubscription(t *testing.T) { require.Nil(t, updateResp) updateResp, err = adminClient.UpdateSubscription(context.Background(), topicName, "non-existent-subscription", addResp.CreateSubscriptionResult.SubscriptionProperties, nil) - // a little awkward, we'll make these programatically inspectable as we add in better error handling. 
- require.Contains(t, err.Error(), "error code: 404") + require.Contains(t, err.Error(), "404 Not Found") + + var asResponseErr *azcore.ResponseError + require.ErrorAs(t, err, &asResponseErr) + require.EqualValues(t, 404, asResponseErr.StatusCode) + require.Nil(t, updateResp) } @@ -754,21 +769,33 @@ func TestAdminClient_LackPermissions_Queue(t *testing.T) { require.True(t, notFound) require.NotNil(t, resp) + var re *azcore.ResponseError + _, err = testData.Client.GetQueue(ctx, testData.QueueName, nil) - require.Contains(t, err.Error(), "error code: 401, Details: Manage,EntityRead claims") + require.Contains(t, err.Error(), "Manage,EntityRead claims required for this operation") + require.ErrorAs(t, err, &re) + require.EqualValues(t, 401, re.StatusCode) pager := testData.Client.ListQueues(nil) require.False(t, pager.NextPage(context.Background())) - require.Contains(t, pager.Err().Error(), "error code: 401, Details: Manage,EntityRead claims required for this operation") + require.Contains(t, pager.Err().Error(), "Manage,EntityRead claims required for this operation") + require.ErrorAs(t, err, &re) + require.EqualValues(t, 401, re.StatusCode) _, err = testData.Client.CreateQueue(ctx, "canneverbecreated", nil, nil) - require.Contains(t, err.Error(), "error code: 401, Details: Authorization failed for specified action: Manage,EntityWrite") + require.Contains(t, err.Error(), "Authorization failed for specified action: Manage,EntityWrite") + require.ErrorAs(t, err, &re) + require.EqualValues(t, 401, re.StatusCode) _, err = testData.Client.UpdateQueue(ctx, "canneverbecreated", QueueProperties{}, nil) - require.Contains(t, err.Error(), "error code: 401, Details: Authorization failed for specified action: Manage,EntityWrite") + require.Contains(t, err.Error(), "Authorization failed for specified action: Manage,EntityWrite") + require.ErrorAs(t, err, &re) + require.EqualValues(t, 401, re.StatusCode) _, err = testData.Client.DeleteQueue(ctx, testData.QueueName, nil) - 
require.Contains(t, err.Error(), "error code: 401, Details: Authorization failed for specified action: Manage,EntityDelete.") + require.Contains(t, err.Error(), "Authorization failed for specified action: Manage,EntityDelete.") + require.ErrorAs(t, err, &re) + require.EqualValues(t, 401, re.StatusCode) } func TestAdminClient_LackPermissions_Topic(t *testing.T) { @@ -782,21 +809,33 @@ func TestAdminClient_LackPermissions_Topic(t *testing.T) { require.True(t, notFound) require.NotNil(t, resp) + var asResponseErr *azcore.ResponseError + _, err = testData.Client.GetTopic(ctx, testData.TopicName, nil) - require.Contains(t, err.Error(), "error code: 401, Details: Manage,EntityRead claims") + require.Contains(t, err.Error(), ">Manage,EntityRead claims required for this operation") + require.ErrorAs(t, err, &asResponseErr) + require.EqualValues(t, 401, asResponseErr.StatusCode) pager := testData.Client.ListTopics(nil) require.False(t, pager.NextPage(context.Background())) - require.Contains(t, pager.Err().Error(), "error code: 401, Details: Manage,EntityRead claims required for this operation") + require.Contains(t, pager.Err().Error(), ">Manage,EntityRead claims required for this operation") + require.ErrorAs(t, err, &asResponseErr) + require.EqualValues(t, 401, asResponseErr.StatusCode) _, err = testData.Client.CreateTopic(ctx, "canneverbecreated", nil, nil) - require.Contains(t, err.Error(), "error code: 401, Details: Authorization failed for specified action") + require.Contains(t, err.Error(), "Authorization failed for specified action") + require.ErrorAs(t, err, &asResponseErr) + require.EqualValues(t, 401, asResponseErr.StatusCode) _, err = testData.Client.UpdateTopic(ctx, "canneverbecreated", TopicProperties{}, nil) - require.Contains(t, err.Error(), "error code: 401, Details: Authorization failed for specified action") + require.Contains(t, err.Error(), "Authorization failed for specified action") + require.ErrorAs(t, err, &asResponseErr) + require.EqualValues(t, 
401, asResponseErr.StatusCode) _, err = testData.Client.DeleteTopic(ctx, testData.TopicName, nil) - require.Contains(t, err.Error(), "error code: 401, Details: Authorization failed for specified action: Manage,EntityDelete.") + require.Contains(t, err.Error(), "Authorization failed for specified action: Manage,EntityDelete.") + require.ErrorAs(t, err, &asResponseErr) + require.EqualValues(t, 401, asResponseErr.StatusCode) } func TestAdminClient_LackPermissions_Subscription(t *testing.T) { diff --git a/sdk/messaging/azservicebus/admin/admin_client_topic.go b/sdk/messaging/azservicebus/admin/admin_client_topic.go index ae11bb9358be..12cca7a25f19 100644 --- a/sdk/messaging/azservicebus/admin/admin_client_topic.go +++ b/sdk/messaging/azservicebus/admin/admin_client_topic.go @@ -170,7 +170,7 @@ func (ac *Client) GetTopicRuntimeProperties(ctx context.Context, topicName strin props, err := newTopicRuntimeProperties(&atomResp.Content.TopicDescription) if err != nil { - return nil, atom.NewResponseError(err, resp) + return nil, err } return &GetTopicRuntimePropertiesResponse{ @@ -239,7 +239,7 @@ func (p *TopicsPager) getNextPage(ctx context.Context) (*ListTopicsResponse, err props, err := newTopicProperties(&env.Content.TopicDescription) if err != nil { - return nil, atom.NewResponseError(err, resp) + return nil, err } all = append(all, &TopicItem{ @@ -326,7 +326,7 @@ func (p *TopicRuntimePropertiesPager) getNextPage(ctx context.Context) (*ListTop props, err := newTopicRuntimeProperties(&entry.Content.TopicDescription) if err != nil { - return nil, atom.NewResponseError(err, resp) + return nil, err } all = append(all, &TopicRuntimePropertiesItem{ @@ -436,7 +436,7 @@ func (ac *Client) createOrUpdateTopicImpl(ctx context.Context, topicName string, topicProps, err := newTopicProperties(&atomResp.Content.TopicDescription) if err != nil { - return nil, nil, atom.NewResponseError(err, resp) + return nil, nil, err } return topicProps, resp, nil diff --git 
a/sdk/messaging/azservicebus/client.go b/sdk/messaging/azservicebus/client.go index 5451c95688ec..f7f6a667c027 100644 --- a/sdk/messaging/azservicebus/client.go +++ b/sdk/messaging/azservicebus/client.go @@ -14,23 +14,30 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" "github.com/devigned/tab" ) // Client provides methods to create Sender and Receiver // instances to send and receive messages from Service Bus. type Client struct { + // NOTE: values need to be 64-bit aligned. Simplest way to make sure this happens + // is just to make it the first value in the struct + // See: + // Godoc: https://pkg.go.dev/sync/atomic#pkg-note-BUG + // PR: https://github.com/Azure/azure-sdk-for-go/pull/16847 linkCounter uint64 - links map[uint64]internal.Closeable - creds clientCreds - namespace interface { + + linksMu *sync.Mutex + links map[uint64]internal.Closeable + creds clientCreds + namespace interface { // used internally by `Client` internal.NamespaceWithNewAMQPLinks // for child clients internal.NamespaceForAMQPLinks - internal.NamespaceForMgmtClient } - linksMu *sync.Mutex + retryOptions RetryOptions } // ClientOptions contains options for the `NewClient` and `NewClientFromConnectionString` @@ -45,8 +52,16 @@ type ClientOptions struct { // NewWebSocketConn is a function that can create a net.Conn for use with websockets. // For an example, see ExampleNewClient_usingWebsockets() function in example_client_test.go. NewWebSocketConn func(ctx context.Context, args NewWebSocketConnArgs) (net.Conn, error) + + // RetryOptions controls how often operations are retried from this client and any + // Receivers and Senders created from this client. + RetryOptions RetryOptions } +// RetryOptions controls how often operations are retried from this client and any +// Receivers and Senders created from this client. 
+type RetryOptions = utils.RetryOptions + // NewWebSocketConnArgs are passed to your web socket creation function (ClientOptions.NewWebSocketConn) type NewWebSocketConnArgs = internal.NewWebSocketConnArgs @@ -107,7 +122,7 @@ func newClientImpl(creds clientCreds, options *ClientOptions) (*Client, error) { if client.creds.connectionString != "" { nsOptions = append(nsOptions, internal.NamespaceWithConnectionString(client.creds.connectionString)) } else if client.creds.credential != nil { - option := internal.NamespacesWithTokenCredential( + option := internal.NamespaceWithTokenCredential( client.creds.fullyQualifiedNamespace, client.creds.credential) @@ -126,6 +141,8 @@ func newClientImpl(creds clientCreds, options *ClientOptions) (*Client, error) { if options.ApplicationID != "" { nsOptions = append(nsOptions, internal.NamespaceWithUserAgent(options.ApplicationID)) } + + nsOptions = append(nsOptions, internal.NamespaceWithRetryOptions((utils.RetryOptions)(options.RetryOptions))) } client.namespace, err = internal.NewNamespace(nsOptions...) @@ -135,7 +152,11 @@ func newClientImpl(creds clientCreds, options *ClientOptions) (*Client, error) { // NewReceiver creates a Receiver for a queue. A receiver allows you to receive messages. func (client *Client) NewReceiverForQueue(queueName string, options *ReceiverOptions) (*Receiver, error) { id, cleanupOnClose := client.getCleanupForCloseable() - receiver, err := newReceiver(client.namespace, &entity{Queue: queueName}, cleanupOnClose, options, nil) + receiver, err := newReceiver(newReceiverArgs{ + cleanupOnClose: cleanupOnClose, + ns: client.namespace, + entity: entity{Queue: queueName}, + }, options) if err != nil { return nil, err @@ -148,7 +169,11 @@ func (client *Client) NewReceiverForQueue(queueName string, options *ReceiverOpt // NewReceiver creates a Receiver for a subscription. A receiver allows you to receive messages. 
func (client *Client) NewReceiverForSubscription(topicName string, subscriptionName string, options *ReceiverOptions) (*Receiver, error) { id, cleanupOnClose := client.getCleanupForCloseable() - receiver, err := newReceiver(client.namespace, &entity{Topic: topicName, Subscription: subscriptionName}, cleanupOnClose, options, nil) + receiver, err := newReceiver(newReceiverArgs{ + cleanupOnClose: cleanupOnClose, + ns: client.namespace, + entity: entity{Topic: topicName, Subscription: subscriptionName}, + }, options) if err != nil { return nil, err @@ -165,7 +190,11 @@ type NewSenderOptions struct { // NewSender creates a Sender, which allows you to send messages or schedule messages. func (client *Client) NewSender(queueOrTopic string, options *NewSenderOptions) (*Sender, error) { id, cleanupOnClose := client.getCleanupForCloseable() - sender, err := newSender(client.namespace, queueOrTopic, cleanupOnClose) + sender, err := newSender(newSenderArgs{ + ns: client.namespace, + queueOrTopic: queueOrTopic, + cleanupOnClose: cleanupOnClose, + }, client.retryOptions) if err != nil { return nil, err @@ -183,7 +212,7 @@ func (client *Client) AcceptSessionForQueue(ctx context.Context, queueName strin ctx, &sessionID, client.namespace, - &entity{Queue: queueName}, + entity{Queue: queueName}, cleanupOnClose, toReceiverOptions(options)) @@ -207,7 +236,7 @@ func (client *Client) AcceptSessionForSubscription(ctx context.Context, topicNam ctx, &sessionID, client.namespace, - &entity{Topic: topicName, Subscription: subscriptionName}, + entity{Topic: topicName, Subscription: subscriptionName}, cleanupOnClose, toReceiverOptions(options)) @@ -231,7 +260,7 @@ func (client *Client) AcceptNextSessionForQueue(ctx context.Context, queueName s ctx, nil, client.namespace, - &entity{Queue: queueName}, + entity{Queue: queueName}, cleanupOnClose, toReceiverOptions(options)) @@ -255,7 +284,7 @@ func (client *Client) AcceptNextSessionForSubscription(ctx context.Context, topi ctx, nil, 
client.namespace, - &entity{Topic: topicName, Subscription: subscriptionName}, + entity{Topic: topicName, Subscription: subscriptionName}, cleanupOnClose, toReceiverOptions(options)) diff --git a/sdk/messaging/azservicebus/client_test.go b/sdk/messaging/azservicebus/client_test.go index fa187148f12d..9fc38f533e81 100644 --- a/sdk/messaging/azservicebus/client_test.go +++ b/sdk/messaging/azservicebus/client_test.go @@ -130,7 +130,7 @@ func TestNewClientUnitTests(t *testing.T) { // (really all part of the same functionality) ns := &internal.Namespace{} - require.NoError(t, internal.NamespacesWithTokenCredential("mysb.windows.servicebus.net", + require.NoError(t, internal.NamespaceWithTokenCredential("mysb.windows.servicebus.net", fakeTokenCredential)(ns)) require.EqualValues(t, ns.FQDN, "mysb.windows.servicebus.net") @@ -180,25 +180,5 @@ func TestNewClientUnitTests(t *testing.T) { require.NoError(t, client.Close(context.Background())) require.Empty(t, client.links) require.EqualValues(t, 1, ns.AMQPLinks.Closed) - - client, ns = setupClient() - _, err = newProcessorForQueue(client, "hello", nil) - - require.NoError(t, err) - require.EqualValues(t, 1, len(client.links)) - require.NotNil(t, client.links[1]) - require.NoError(t, client.Close(context.Background())) - require.Empty(t, client.links) - require.EqualValues(t, 1, ns.AMQPLinks.Closed) - - client, ns = setupClient() - _, err = newProcessorForSubscription(client, "hello", "world", nil) - - require.NoError(t, err) - require.EqualValues(t, 1, len(client.links)) - require.NotNil(t, client.links[1]) - require.NoError(t, client.Close(context.Background())) - require.Empty(t, client.links) - require.EqualValues(t, 1, ns.AMQPLinks.Closed) }) } diff --git a/sdk/messaging/azservicebus/errors.go b/sdk/messaging/azservicebus/errors.go deleted file mode 100644 index f11f111b2a11..000000000000 --- a/sdk/messaging/azservicebus/errors.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. 
All rights reserved. -// Licensed under the MIT License. - -package azservicebus - -import "fmt" - -// implements `internal/errorinfo/NonRetriable` -type errClosed struct { - link string -} - -func (ec errClosed) NonRetriable() {} -func (ec errClosed) Error() string { - return fmt.Sprintf("%s is closed and can no longer be used", ec.link) -} diff --git a/sdk/messaging/azservicebus/errors_test.go b/sdk/messaging/azservicebus/errors_test.go deleted file mode 100644 index 9132d5660442..000000000000 --- a/sdk/messaging/azservicebus/errors_test.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azservicebus - -import ( - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" - "github.com/stretchr/testify/require" -) - -func TestErrClosed(t *testing.T) { - var err error = errClosed{link: "hello"} - - _, ok := err.(errorinfo.NonRetriable) - require.True(t, ok, "ErrClosed is a errorinfo.NonRetriable") - require.EqualValues(t, "hello is closed and can no longer be used", err.Error()) -} diff --git a/sdk/messaging/azservicebus/go.mod b/sdk/messaging/azservicebus/go.mod index a3ba068fdd30..2675f82d6c51 100644 --- a/sdk/messaging/azservicebus/go.mod +++ b/sdk/messaging/azservicebus/go.mod @@ -9,10 +9,17 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/messaging/internal v0.0.0-20211208010914-2b10e91d237e github.com/Azure/go-amqp v0.17.0 github.com/devigned/tab v0.1.1 + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect +) + +require ( + // used in tests only github.com/joho/godotenv v1.3.0 - github.com/jpillora/backoff v1.0.0 + + // used in stress tests github.com/microsoft/ApplicationInsights-Go v0.4.4 github.com/stretchr/testify v1.7.0 - golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + + // used in examples only nhooyr.io/websocket v1.8.6 ) diff --git a/sdk/messaging/azservicebus/go.sum b/sdk/messaging/azservicebus/go.sum 
index 8f235251df4f..ddffa1b097c6 100644 --- a/sdk/messaging/azservicebus/go.sum +++ b/sdk/messaging/azservicebus/go.sum @@ -69,8 +69,6 @@ github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= -github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= -github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8= diff --git a/sdk/messaging/azservicebus/internal/amqpInterfaces.go b/sdk/messaging/azservicebus/internal/amqpInterfaces.go index eb17ea50420c..cfcc41301858 100644 --- a/sdk/messaging/azservicebus/internal/amqpInterfaces.go +++ b/sdk/messaging/azservicebus/internal/amqpInterfaces.go @@ -5,9 +5,7 @@ package internal import ( "context" - "time" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/rpc" "github.com/Azure/go-amqp" ) @@ -61,7 +59,7 @@ type AMQPSenderCloser interface { // RPCLink is implemented by *rpc.Link type RPCLink interface { Close(ctx context.Context) error - RetryableRPC(ctx context.Context, times int, delay time.Duration, msg *amqp.Message) (*rpc.Response, error) + RPC(ctx context.Context, msg *amqp.Message) (*RPCResponse, error) } // Closeable is implemented by pretty much any AMQP link/client diff --git a/sdk/messaging/azservicebus/internal/amqpLinks.go b/sdk/messaging/azservicebus/internal/amqpLinks.go index 48a3ade96995..2f2575715156 100644 --- a/sdk/messaging/azservicebus/internal/amqpLinks.go +++ b/sdk/messaging/azservicebus/internal/amqpLinks.go 
@@ -10,19 +10,21 @@ import ( "strings" "sync" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" "github.com/devigned/tab" ) -type errClosedPermanently struct{} - -func (e errClosedPermanently) Error() string { return "Link has been closed permanently" } -func (e errClosedPermanently) NonRetriable() {} - -func ShouldRecover(ctx context.Context, err error) bool { - return shouldRecreateConnection(ctx, err) || shouldRecreateLink(err) +type LinksWithID struct { + Sender AMQPSender + Receiver AMQPReceiver + RPC RPCLink + ID LinkID } +type RetryWithLinksFn func(ctx context.Context, lwid *LinksWithID, args *utils.RetryFnArgs) error + type AMQPLinks interface { EntityPath() string ManagementPath() string @@ -31,11 +33,14 @@ type AMQPLinks interface { // Get will initialize a session and call its link.linkCreator function. // If this link has been closed via Close() it will return an non retriable error. - Get(ctx context.Context) (AMQPSender, AMQPReceiver, MgmtClient, uint64, error) + Get(ctx context.Context) (*LinksWithID, error) + + // Retry will run your callback, recovering links when necessary. + Retry(ctx context.Context, name string, fn RetryWithLinksFn, o utils.RetryOptions) error // RecoverIfNeeded will check if an error requires recovery, and will recover // the link or, possibly, the connection. - RecoverIfNeeded(ctx context.Context, linksRevision uint64, err error) error + RecoverIfNeeded(ctx context.Context, linkID LinkID, err error) error // Close will close the the link. 
// If permanent is true the link will not be auto-recreated if Get/Recover @@ -46,40 +51,39 @@ type AMQPLinks interface { ClosedPermanently() bool } -// amqpLinks manages the set of AMQP links (and detritus) typically needed to work +// AMQPLinksImpl manages the set of AMQP links (and detritus) typically needed to work // within Service Bus: // - An *goamqp.Sender or *goamqp.Receiver AMQP link (could also be 'both' if needed) // - A `$management` link // - an *goamqp.Session // // State management can be done through Recover (close and reopen), Close (close permanently, return failures) -// and Get() (retrieve latest version of all amqpLinks, or create if needed). -type amqpLinks struct { +// and Get() (retrieve latest version of all AMQPLinksImpl, or create if needed). +type AMQPLinksImpl struct { + // NOTE: values need to be 64-bit aligned. Simplest way to make sure this happens + // is just to make it the first value in the struct + // See: + // Godoc: https://pkg.go.dev/sync/atomic#pkg-note-BUG + // PR: https://github.com/Azure/azure-sdk-for-go/pull/16847 + id LinkID + entityPath string managementPath string audience string createLink CreateLinkFunc - baseRetrier Retrier mu sync.RWMutex - // mgmt lets you interact with the $management link for your entity. - mgmt MgmtClient + // RPCLink lets you interact with the $management link for your entity. + RPCLink RPCLink // the AMQP session for either the 'sender' or 'receiver' link session AMQPSessionCloser // these are populated by your `createLinkFunc` when you construct // the amqpLinks - sender AMQPSenderCloser - receiver AMQPReceiverCloser - - // last connection revision seen by this links instance. - clientRevision uint64 - - // the current 'revision' of our set of links. - // starts at 1, increments each time you call Recover(). - revision uint64 + Sender AMQPSenderCloser + Receiver AMQPReceiverCloser // whether this links set has been closed permanently (via Close) // Recover() does not affect this value. 
@@ -97,15 +101,13 @@ type CreateLinkFunc func(ctx context.Context, session AMQPSession) (AMQPSenderCl // NewAMQPLinks creates a session, starts the claim refresher and creates an associated // management link for a specific entity path. -func newAMQPLinks(ns NamespaceForAMQPLinks, entityPath string, baseRetrier Retrier, createLink CreateLinkFunc) AMQPLinks { - l := &amqpLinks{ +func NewAMQPLinks(ns NamespaceForAMQPLinks, entityPath string, createLink CreateLinkFunc) AMQPLinks { + l := &AMQPLinksImpl{ entityPath: entityPath, managementPath: fmt.Sprintf("%s/$management", entityPath), audience: ns.GetEntityAudience(entityPath), createLink: createLink, - baseRetrier: baseRetrier, closedPermanently: false, - revision: 1, ns: ns, } @@ -113,53 +115,42 @@ func newAMQPLinks(ns NamespaceForAMQPLinks, entityPath string, baseRetrier Retri } // ManagementPath is the management path for the associated entity. -func (links *amqpLinks) ManagementPath() string { +func (links *AMQPLinksImpl) ManagementPath() string { return links.managementPath } // recoverLink will recycle all associated links (mgmt, receiver, sender and session) // and recreate them using the link.linkCreator function. 
-func (links *amqpLinks) recoverLink(ctx context.Context, theirLinkRevision *uint64) error { +func (links *AMQPLinksImpl) recoverLink(ctx context.Context, theirLinkRevision LinkID) error { ctx, span := tab.StartSpan(ctx, tracing.SpanRecoverLink) defer span.End() links.mu.RLock() closedPermanently := links.closedPermanently - ourLinkRevision := links.revision + ourLinkRevision := links.id links.mu.RUnlock() if closedPermanently { span.AddAttributes(tab.StringAttribute("outcome", "was_closed_permanently")) - return errClosedPermanently{} + return ErrNonRetriable{Message: "Link has been closed permanently"} } - if theirLinkRevision != nil && ourLinkRevision > *theirLinkRevision { + // cheap check before we do the lock + if ourLinkRevision.Link != theirLinkRevision.Link { // we've already recovered past their failure. - span.AddAttributes( - tab.StringAttribute("outcome", "already_recovered"), - tab.StringAttribute("lock", "readlock"), - tab.StringAttribute("revisions", fmt.Sprintf("ours(%d), theirs(%d)", ourLinkRevision, *theirLinkRevision)), - ) return nil } links.mu.Lock() defer links.mu.Unlock() - if theirLinkRevision != nil && ourLinkRevision > *theirLinkRevision { + // check once more, just in case someone else modified it before we took + // the lock. + if links.id.Link != theirLinkRevision.Link { // we've already recovered past their failure. 
- span.AddAttributes( - tab.StringAttribute("outcome", "already_recovered"), - tab.StringAttribute("lock", "writelock"), - tab.StringAttribute("revisions", fmt.Sprintf("ours(%d), theirs(%d)", ourLinkRevision, *theirLinkRevision)), - ) return nil } - if err := links.closeWithoutLocking(ctx, false); err != nil { - span.Logger().Error(err) - } - err := links.initWithoutLocking(ctx) if err != nil { @@ -167,149 +158,161 @@ func (links *amqpLinks) recoverLink(ctx context.Context, theirLinkRevision *uint return err } - links.revision++ - span.AddAttributes( tab.StringAttribute("outcome", "recovered"), - tab.StringAttribute("revision_new", fmt.Sprintf("%d", links.revision)), + tab.StringAttribute("revision_new", fmt.Sprintf("%d", links.id)), ) return nil } // Recover will recover the links or the connection, depending -// on the severity of the error. This function uses the `baseRetrier` -// defined in the links struct. -func (links *amqpLinks) RecoverIfNeeded(ctx context.Context, linksRevision uint64, origErr error) error { +// on the severity of the error. 
+func (links *AMQPLinksImpl) RecoverIfNeeded(ctx context.Context, theirID LinkID, origErr error) error { ctx, span := tab.StartSpan(ctx, tracing.SpanRecover) defer span.End() - var err error = origErr - - retrier := links.baseRetrier.Copy() - - for retrier.Try(ctx) { - span.AddAttributes(tab.StringAttribute("recover_attempt", fmt.Sprintf("%d", retrier.CurrentTry()))) - - err = links.recoverImpl(ctx, retrier.CurrentTry(), linksRevision, err) - - if err == nil { - return nil - } - } - - return err -} - -func (links *amqpLinks) recoverImpl(ctx context.Context, try int, linksRevision uint64, origErr error) error { - _, span := tab.StartSpan(ctx, tracing.SpanRecoverLink) - defer span.End() - if origErr == nil || IsCancelError(origErr) { return nil } + log.Writef(EventConn, "Recovering link for error %s", origErr.Error()) + select { case <-ctx.Done(): return ctx.Err() default: } - span.AddAttributes(tab.Int64Attribute("attempt", int64(try))) - - if shouldRecreateLink(origErr) { - span.AddAttributes( - tab.StringAttribute("recovery_kind", "link"), - tab.StringAttribute("error", origErr.Error()), - tab.StringAttribute("error_type", fmt.Sprintf("%T", origErr))) + sbe := GetSBErrInfo(origErr) - if err := links.recoverLink(ctx, &linksRevision); err != nil { - span.AddAttributes(tab.StringAttribute("recoveryFailure", err.Error())) + if sbe.RecoveryKind == RecoveryKindLink { + if err := links.recoverLink(ctx, theirID); err != nil { + log.Writef(EventConn, "failed to recreate link: %s", err.Error()) return err } + log.Writef(EventConn, "Recovered links") return nil - } else if shouldRecreateConnection(ctx, origErr) { - span.AddAttributes( - tab.StringAttribute("recovery_kind", "connection"), - tab.StringAttribute("error", origErr.Error()), - tab.StringAttribute("error_type", fmt.Sprintf("%T", origErr))) - - if err := links.recoverConnection(ctx); err != nil { - span.Logger().Error(fmt.Errorf("failed to recreate connection: %w", err)) - return err - } - - // unconditionally 
recover the link if the connection died. - if err := links.recoverLink(ctx, nil); err != nil { - span.Logger().Error(fmt.Errorf("failed to recover links after connection restarted: %w", err)) + } else if sbe.RecoveryKind == RecoveryKindConn { + if err := links.recoverConnection(ctx, theirID); err != nil { + log.Writef(EventConn, "failed to recreate connection: %s", err.Error()) return err } + log.Writef(EventConn, "Recovered connection and links") return nil } - span.AddAttributes( - tab.StringAttribute("recovery", "none"), - tab.StringAttribute("error", origErr.Error()), - tab.StringAttribute("errorType", fmt.Sprintf("%T", origErr))) - + log.Writef(EventConn, "Recovered, no action needed") return nil } -func (links *amqpLinks) recoverConnection(ctx context.Context) error { +func (links *AMQPLinksImpl) recoverConnection(ctx context.Context, theirID LinkID) error { tab.For(ctx).Info("Connection is dead, recovering") - links.mu.RLock() - clientRevision := links.clientRevision - links.mu.RUnlock() + links.mu.Lock() + defer links.mu.Unlock() - err := links.ns.Recover(ctx, clientRevision) + created, err := links.ns.Recover(ctx, uint64(theirID.Conn)) if err != nil { - tab.For(ctx).Error(fmt.Errorf("Recover connection failure: %w", err)) + log.Writef(EventConn, "Recover connection failure: %s", err) return err } + // We'll recreate the link if: + // - `created` is true, meaning we recreated the AMQP connection (ie, all old links are invalid) + // - the link they received an error on is our current link, so it needs to be recreated. + // (if it wasn't the same then we've already recovered and created a new link, + // so no recovery would be needed) + if created || theirID.Link == links.id.Link { + log.Writef(EventConn, "recreating link: c: %v, current:%v, old:%v", created, links.id, theirID) + if err := links.initWithoutLocking(ctx); err != nil { + return err + } + } + return nil } +// LinkID is ID that represent our current link and the client used to create it. 
+// These are used when trying to determine what parts need to be recreated when +// an error occurs, to prevent recovering a connection/link repeatedly. +// See amqpLinks.RecoverIfNeeded() for usage. +type LinkID struct { + // Conn is the ID of the connection we used to create our links. + Conn uint64 + + // Link is the ID of our current link. + Link uint64 +} + // Get will initialize a session and call its link.linkCreator function. // If this link has been closed via Close() it will return an non retriable error. -func (l *amqpLinks) Get(ctx context.Context) (AMQPSender, AMQPReceiver, MgmtClient, uint64, error) { +func (l *AMQPLinksImpl) Get(ctx context.Context) (*LinksWithID, error) { l.mu.RLock() - sender, receiver, mgmt, revision, closedPermanently := l.sender, l.receiver, l.mgmt, l.revision, l.closedPermanently + sender, receiver, mgmtLink, revision, closedPermanently := l.Sender, l.Receiver, l.RPCLink, l.id, l.closedPermanently l.mu.RUnlock() if closedPermanently { - return nil, nil, nil, 0, errClosedPermanently{} + return nil, ErrNonRetriable{} } if sender != nil || receiver != nil { - return sender, receiver, mgmt, revision, nil + return &LinksWithID{ + Sender: sender, + Receiver: receiver, + RPC: mgmtLink, + ID: revision, + }, nil } l.mu.Lock() defer l.mu.Unlock() if err := l.initWithoutLocking(ctx); err != nil { - return nil, nil, nil, 0, err + return nil, err } - return l.sender, l.receiver, l.mgmt, l.revision, nil + return &LinksWithID{ + Sender: l.Sender, + Receiver: l.Receiver, + RPC: l.RPCLink, + ID: l.id, + }, nil +} + +func (l *AMQPLinksImpl) Retry(ctx context.Context, name string, fn RetryWithLinksFn, o utils.RetryOptions) error { + var lastID LinkID + + return utils.Retry(ctx, name, func(ctx context.Context, args *utils.RetryFnArgs) error { + if err := l.RecoverIfNeeded(ctx, lastID, args.LastErr); err != nil { + return err + } + + linksWithVersion, err := l.Get(ctx) + + if err != nil { + return err + } + + lastID = linksWithVersion.ID + return 
fn(ctx, linksWithVersion, args) + }, IsFatalSBError, o) } // EntityPath is the full entity path for the queue/topic/subscription. -func (l *amqpLinks) EntityPath() string { +func (l *AMQPLinksImpl) EntityPath() string { return l.entityPath } // EntityPath is the audience for the queue/topic/subscription. -func (l *amqpLinks) Audience() string { +func (l *AMQPLinksImpl) Audience() string { return l.audience } // ClosedPermanently is true if AMQPLinks.Close(ctx, true) has been called. -func (l *amqpLinks) ClosedPermanently() bool { +func (l *AMQPLinksImpl) ClosedPermanently() bool { l.mu.RLock() defer l.mu.RUnlock() return l.closedPermanently @@ -317,20 +320,23 @@ func (l *amqpLinks) ClosedPermanently() bool { // Close will close the the link permanently. // Any further calls to Get()/Recover() to return ErrLinksClosed. -func (l *amqpLinks) Close(ctx context.Context, permanent bool) error { +func (l *AMQPLinksImpl) Close(ctx context.Context, permanent bool) error { l.mu.Lock() defer l.mu.Unlock() return l.closeWithoutLocking(ctx, permanent) } // initWithoutLocking will create a new link, unconditionally. 
-func (l *amqpLinks) initWithoutLocking(ctx context.Context) error { +func (l *AMQPLinksImpl) initWithoutLocking(ctx context.Context) error { + // shut down any links we have + _ = l.closeWithoutLocking(ctx, false) + var err error l.cancelAuthRefreshLink, err = l.ns.NegotiateClaim(ctx, l.entityPath) if err != nil { if err := l.closeWithoutLocking(ctx, false); err != nil { - tab.For(ctx).Debug(fmt.Sprintf("Failure during link cleanup after negotiateClaim: %s", err.Error())) + log.Writef(EventConn, "Failure during link cleanup after negotiateClaim: %s", err.Error()) } return err } @@ -339,44 +345,49 @@ func (l *amqpLinks) initWithoutLocking(ctx context.Context) error { if err != nil { if err := l.closeWithoutLocking(ctx, false); err != nil { - tab.For(ctx).Debug(fmt.Sprintf("Failure during link cleanup after negotiate claim for mgmt link: %s", err.Error())) + log.Writef(EventConn, "Failure during link cleanup after negotiate claim for mgmt link: %s", err.Error()) } return err } - l.session, l.clientRevision, err = l.ns.NewAMQPSession(ctx) + sess, cr, err := l.ns.NewAMQPSession(ctx) if err != nil { if err := l.closeWithoutLocking(ctx, false); err != nil { - tab.For(ctx).Debug(fmt.Sprintf("Failure during link cleanup after creating AMQP session: %s", err.Error())) + log.Writef(EventConn, "Failure during link cleanup after creating AMQP session: %s", err.Error()) } return err } - l.sender, l.receiver, err = l.createLink(ctx, l.session) + l.session = sess + l.id.Conn = cr + + l.Sender, l.Receiver, err = l.createLink(ctx, l.session) if err != nil { if err := l.closeWithoutLocking(ctx, false); err != nil { - tab.For(ctx).Debug(fmt.Sprintf("Failure during link cleanup after creating link: %s", err.Error())) + log.Writef(EventConn, "Failure during link cleanup after creating link: %s", err.Error()) } return err } - l.mgmt, err = l.ns.NewMgmtClient(ctx, l) + rpcLink, err := l.ns.NewRPCLink(ctx, l.ManagementPath()) if err != nil { if err := l.closeWithoutLocking(ctx, false); 
err != nil { - tab.For(ctx).Debug(fmt.Sprintf("Failure during link cleanup after creating mgmt client: %s", err.Error())) + log.Writef(EventConn, "Failure during link cleanup after creating mgmt client: %s", err.Error()) } return err } + l.RPCLink = rpcLink + l.id.Link++ return nil } // close closes the link. // NOTE: No locking is done in this function, call `Close` if you require locking. -func (l *amqpLinks) closeWithoutLocking(ctx context.Context, permanent bool) error { +func (l *AMQPLinksImpl) closeWithoutLocking(ctx context.Context, permanent bool) error { if l.closedPermanently { return nil } @@ -397,18 +408,18 @@ func (l *amqpLinks) closeWithoutLocking(ctx context.Context, permanent bool) err l.cancelAuthRefreshMgmtLink() } - if l.sender != nil { - if err := l.sender.Close(ctx); err != nil { + if l.Sender != nil { + if err := l.Sender.Close(ctx); err != nil { messages = append(messages, fmt.Sprintf("amqp sender close error: %s", err.Error())) } - l.sender = nil + l.Sender = nil } - if l.receiver != nil { - if err := l.receiver.Close(ctx); err != nil { + if l.Receiver != nil { + if err := l.Receiver.Close(ctx); err != nil { messages = append(messages, fmt.Sprintf("amqp receiver close error: %s", err.Error())) } - l.receiver = nil + l.Receiver = nil } if l.session != nil { @@ -418,11 +429,11 @@ func (l *amqpLinks) closeWithoutLocking(ctx context.Context, permanent bool) err l.session = nil } - if l.mgmt != nil { - if err := l.mgmt.Close(ctx); err != nil { + if l.RPCLink != nil { + if err := l.RPCLink.Close(ctx); err != nil { messages = append(messages, fmt.Sprintf("$management link close error: %s", err.Error())) } - l.mgmt = nil + l.RPCLink = nil } if len(messages) > 0 { diff --git a/sdk/messaging/azservicebus/internal/amqpLinks_test.go b/sdk/messaging/azservicebus/internal/amqpLinks_test.go index 8a2b11fbf4e4..886632c93d39 100644 --- a/sdk/messaging/azservicebus/internal/amqpLinks_test.go +++ b/sdk/messaging/azservicebus/internal/amqpLinks_test.go @@ -5,78 +5,20 @@ 
package internal import ( "context" - "errors" + "log" + "sync" "testing" + "time" + azlog "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/test" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" "github.com/Azure/go-amqp" "github.com/stretchr/testify/require" ) -func TestAMQPLinks(t *testing.T) { - fakeSender := &FakeAMQPSender{} - fakeSession := &FakeAMQPSession{} - fakeMgmtClient := &fakeMgmtClient{} - - createLinkFunc, createLinkCallCount := setupCreateLinkResponses(t, []createLinkResponse{ - {sender: fakeSender}, - }) - - links := newAMQPLinks(&FakeNS{ - Session: fakeSession, - MgmtClient: fakeMgmtClient, - }, "entityPath", &fakeRetrier{}, createLinkFunc) - - require.EqualValues(t, "entityPath", links.EntityPath()) - require.EqualValues(t, "audience: entityPath", links.Audience()) - - // successful Get() where a Sender was initialized - sender, receiver, mgmt, linkRevision, err := links.Get(context.Background()) - require.NotNil(t, sender) - require.NotNil(t, mgmt) // you always get a free mgmt link - require.Nil(t, receiver) - require.Nil(t, err) - require.EqualValues(t, 1, linkRevision) - require.EqualValues(t, 1, *createLinkCallCount) - - // further calls should just be cached instances - sender2, receiver2, mgmt2, linkRevision2, err2 := links.Get(context.Background()) - require.EqualValues(t, sender, sender2) - require.EqualValues(t, mgmt, mgmt2) - require.Nil(t, receiver2) - require.Nil(t, err2) - require.EqualValues(t, 1, linkRevision2, "No recover calls, so link revision remains the same") - require.EqualValues(t, 1, *createLinkCallCount, "No create call needed since an instance was cached") - - // closing multiple times is fine. 
- asAMQPLinks, ok := links.(*amqpLinks) - require.True(t, ok) - - require.NoError(t, links.Close(context.Background(), false)) - require.False(t, asAMQPLinks.closedPermanently) - - require.NoError(t, links.Close(context.Background(), true)) - require.True(t, asAMQPLinks.closedPermanently) - - require.NoError(t, links.Close(context.Background(), true)) - require.True(t, asAMQPLinks.closedPermanently) - - require.NoError(t, links.Close(context.Background(), false)) - require.True(t, asAMQPLinks.closedPermanently) - - // and the individual links are closed as well - require.EqualValues(t, 1, fakeSender.Closed) - require.EqualValues(t, 1, fakeSession.closed) - require.EqualValues(t, 1, fakeMgmtClient.closed) - - // and calls to Get() will indicate the amqpLinks has been closed permanently - sender, receiver, mgmt, linkRevision, err = links.Get(context.Background()) - require.Nil(t, sender) - require.Nil(t, receiver) - require.Nil(t, mgmt) - require.EqualValues(t, 0, linkRevision) - - _, ok = err.(NonRetriable) - require.True(t, ok) +var retryOptionsOnlyOnce = utils.RetryOptions{ + MaxRetries: 0, } type fakeNetError struct { @@ -88,112 +30,368 @@ func (pe fakeNetError) Timeout() bool { return pe.timeout } func (pe fakeNetError) Temporary() bool { return pe.temp } func (pe fakeNetError) Error() string { return "Fake but very permanent error" } -func TestAMQPLinksRecovery(t *testing.T) { - sess := &FakeAMQPSession{} - ns := &FakeNS{ - Session: sess, - } - sender := &FakeAMQPSender{} +func assertFailedLinks(t *testing.T, lwid *LinksWithID, expectedErr error) { + err := lwid.Sender.Send(context.TODO(), &amqp.Message{ + Data: [][]byte{ + {0}, + }, + }) + require.ErrorIs(t, err, expectedErr) + + _, err = PeekMessages(context.TODO(), lwid.RPC, 0, 1) + require.ErrorIs(t, err, expectedErr) + + msg, err := lwid.Receiver.Receive(context.TODO()) + require.ErrorIs(t, err, expectedErr) + require.Nil(t, msg) + +} + +func assertLinks(t *testing.T, lwid *LinksWithID) { + err := 
lwid.Sender.Send(context.TODO(), &amqp.Message{ + Data: [][]byte{ + {0}, + }, + }) + require.NoError(t, err) + + _, err = PeekMessages(context.TODO(), lwid.RPC, 0, 1) + require.NoError(t, err) + + require.NoError(t, lwid.Receiver.IssueCredit(1)) + msg, err := lwid.Receiver.Receive(context.TODO()) + require.NoError(t, err) + require.NotNil(t, msg) +} + +func TestAMQPLinksBasic(t *testing.T) { + entityPath, cleanup := test.CreateExpiringQueue(t, nil) + defer cleanup() + + cs := test.GetConnectionString(t) + ns, err := NewNamespace(NamespaceWithConnectionString(cs)) + require.NoError(t, err) + + links := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + return newLinksForAMQPLinksTest(entityPath, session) + }) + + lwr, err := links.Get(context.Background()) + require.NoError(t, err) + + assertLinks(t, lwr) + + require.EqualValues(t, entityPath, links.EntityPath()) +} + +func TestAMQPLinksLive(t *testing.T) { + // we're not going to use this client for these tests. 
+ entityPath, cleanup := test.CreateExpiringQueue(t, nil) + defer cleanup() + + cs := test.GetConnectionString(t) + ns, err := NewNamespace(NamespaceWithConnectionString(cs)) + require.NoError(t, err) + + defer func() { _ = ns.Close(context.Background()) }() + + createLinksCalled := 0 + + links := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + createLinksCalled++ + return newLinksForAMQPLinksTest(entityPath, session) + }) + + require.EqualValues(t, 0, createLinksCalled) + require.NoError(t, links.RecoverIfNeeded(context.Background(), LinkID{}, amqp.ErrConnClosed)) + require.EqualValues(t, 1, createLinksCalled) + + lwr, err := links.Get(context.Background()) + require.NoError(t, err) + + amqpClient, clientRev, err := ns.GetAMQPClientImpl(context.Background()) + require.NoError(t, err) + require.EqualValues(t, 1, clientRev) + require.NoError(t, amqpClient.Close()) + + // all the links are dead because the connection is dead. 
+ assertFailedLinks(t, lwr, amqp.ErrConnClosed) + + // now we'll recover, which should recreate everything + require.NoError(t, links.RecoverIfNeeded(context.Background(), lwr.ID, amqp.ErrConnClosed)) + require.EqualValues(t, 2, createLinksCalled) + + lwr, err = links.Get(context.Background()) + require.NoError(t, err) + + // should work now, connection should be reopened + assertLinks(t, lwr) + + // cheat a bit and close the links out from under us (but leave them in place) + actualLinks := links.(*AMQPLinksImpl) + _ = actualLinks.Sender.Close(context.Background()) + _ = actualLinks.Receiver.Close(context.Background()) + _ = actualLinks.RPCLink.Close(context.Background()) + + assertFailedLinks(t, lwr, amqp.ErrLinkClosed) + + lwr, err = links.Get(context.Background()) + require.NoError(t, err) + + require.NoError(t, links.RecoverIfNeeded(context.Background(), lwr.ID, amqp.ErrLinkClosed)) + require.EqualValues(t, 3, createLinksCalled) + + lwr, err = links.Get(context.Background()) + require.NoError(t, err) + + assertLinks(t, lwr) +} + +func TestAMQPLinksLiveRecoverLink(t *testing.T) { + // we're not going to use this client for these tests. 
+ entityPath, cleanup := test.CreateExpiringQueue(t, nil) + defer cleanup() + + cs := test.GetConnectionString(t) + ns, err := NewNamespace(NamespaceWithConnectionString(cs)) + require.NoError(t, err) + + defer func() { _ = ns.Close(context.Background()) }() + + createLinksCalled := 0 + + links := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + createLinksCalled++ + return newLinksForAMQPLinksTest(entityPath, session) + }) + + require.EqualValues(t, 0, createLinksCalled) + require.NoError(t, links.RecoverIfNeeded(context.Background(), LinkID{}, amqp.ErrConnClosed)) + require.EqualValues(t, 1, createLinksCalled) + + lwr, err := links.Get(context.Background()) + require.NoError(t, err) + + require.NoError(t, links.RecoverIfNeeded(context.Background(), lwr.ID, amqp.ErrLinkClosed)) + require.EqualValues(t, 2, createLinksCalled) +} + +func TestAMQPLinksLiveRace(t *testing.T) { + entityPath, cleanup := test.CreateExpiringQueue(t, nil) + defer cleanup() - createLinkCalled := 0 + cs := test.GetConnectionString(t) + ns, err := NewNamespace(NamespaceWithConnectionString(cs)) + require.NoError(t, err) + + defer func() { _ = ns.Close(context.Background()) }() + + createLinksCalled := 0 - tmpLinks := newAMQPLinks(ns, "entity path", NewBackoffRetrier(BackoffRetrierParams{}), func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { - createLinkCalled++ - return sender, nil, nil + links := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + createLinksCalled++ + return newLinksForAMQPLinksTest(entityPath, session) }) - links, _ := tmpLinks.(*amqpLinks) - - links.clientRevision = 2001 - links.sender = sender - - ctx := context.TODO() - - require.Nil(t, links.RecoverIfNeeded(ctx, 0, nil)) - require.EqualValues(t, 0, sess.closed) - require.EqualValues(t, 0, ns.recovered) - require.EqualValues(t, 
0, createLinkCalled, "new links aren't needed") - require.False(t, links.closedPermanently, "link should still be usable") - require.Empty(t, ns.clientRevisions, "no connection recoveries happened") - - require.Nil(t, links.RecoverIfNeeded(ctx, 0, errors.New("Passes through"))) - require.EqualValues(t, 0, sess.closed) - require.EqualValues(t, 0, ns.recovered) - require.EqualValues(t, 0, createLinkCalled, "new links aren't needed") - require.False(t, links.closedPermanently, "link should still be usable") - require.Empty(t, ns.clientRevisions, "no connection recoveries happened") - - // now let's initiate a recovery at the connection level - require.NoError(t, links.RecoverIfNeeded(ctx, 0, fakeNetError{}), fakeNetError{}.Error()) - require.EqualValues(t, 1, ns.recovered, "client gets recovered") - require.EqualValues(t, 1, sender.Closed, "link is closed") - require.EqualValues(t, 1, createLinkCalled, "link is created") - require.False(t, links.closedPermanently, "link should still be usable") - require.EqualValues(t, []uint64{2001}, ns.clientRevisions, "links handed us the client revision it got last") - - // validate that our linkRevision got updated and that we're returning it. 
- // (note that link revisions start at 1, so we're not at 2, even though - // only one recover has happened) - _, _, _, linkRevision, err := links.Get(ctx) - require.NoError(t, err) - require.EqualValues(t, uint64(2), linkRevision) - - ns.recovered = 0 - sender.Closed = 0 - createLinkCalled = 0 - - // let's do just a link level one - require.NoError(t, links.RecoverIfNeeded(ctx, links.revision+1, &amqp.DetachError{}), &amqp.DetachError{}) - require.EqualValues(t, 0, ns.recovered) - require.EqualValues(t, 1, sender.Closed) - require.EqualValues(t, 1, createLinkCalled) - - _, _, _, linkRevision, err = links.Get(ctx) - require.NoError(t, err) - require.EqualValues(t, uint64(3), linkRevision) - - // cancellation - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - ns.recovered = 0 - sender.Closed = 0 - createLinkCalled = 0 - - // cancellation overrides any other logic. - require.Error(t, links.RecoverIfNeeded(ctx, links.revision+1, &amqp.DetachError{}), &amqp.DetachError{}) - require.EqualValues(t, 0, ns.recovered) - require.EqualValues(t, 0, sender.Closed) - require.EqualValues(t, 0, createLinkCalled) + wg := sync.WaitGroup{} + + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := links.RecoverIfNeeded(context.Background(), LinkID{}, amqp.ErrConnClosed) + require.NoError(t, err) + }() + } + + wg.Wait() + + // TODO: also check that the connection hasn't recycled multiple times. 
+ require.EqualValues(t, 1, createLinksCalled) } -func TestAMQPLinks_Closed(t *testing.T) { - createLinks := func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { - return nil, nil, nil +func TestAMQPLinksLiveRaceLink(t *testing.T) { + entityPath, cleanup := test.CreateExpiringQueue(t, nil) + defer cleanup() + + cs := test.GetConnectionString(t) + ns, err := NewNamespace(NamespaceWithConnectionString(cs)) + require.NoError(t, err) + + defer func() { _ = ns.Close(context.Background()) }() + + createLinksCalled := 0 + + enableLogging() + + links := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + createLinksCalled++ + return newLinksForAMQPLinksTest(entityPath, session) + }) + + wg := sync.WaitGroup{} + + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := links.RecoverIfNeeded(context.Background(), LinkID{}, &amqp.DetachError{}) + require.NoError(t, err) + }() } - links := newAMQPLinks(&FakeNS{}, "hello", &backoffRetrier{}, createLinks) - links.Close(context.Background(), true) + wg.Wait() + + // TODO: also check that the connection hasn't recycled multiple times. 
+ require.EqualValues(t, 1, createLinksCalled) +} + +func TestAMQPLinksRetry(t *testing.T) { + entityPath, cleanup := test.CreateExpiringQueue(t, nil) + defer cleanup() + + cs := test.GetConnectionString(t) + ns, err := NewNamespace(NamespaceWithConnectionString(cs)) + require.NoError(t, err) + + defer func() { _ = ns.Close(context.Background()) }() + + createLinksCalled := 0 + + links := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + createLinksCalled++ + return newLinksForAMQPLinksTest(entityPath, session) + }) + + err = links.Retry(context.Background(), "retryOp", func(ctx context.Context, lwid *LinksWithID, args *utils.RetryFnArgs) error { + // force recoveries + return &amqp.DetachError{} + }, utils.RetryOptions{ + MaxRetries: 2, + // note: omitting MaxRetries just to give a sanity check that + // we do setDefaults() before we run. + RetryDelay: time.Millisecond, + MaxRetryDelay: time.Millisecond, + }) + + var detachErr *amqp.DetachError + require.ErrorAs(t, err, &detachErr) + require.EqualValues(t, 3, createLinksCalled) +} + +func TestAMQPLinksMultipleWithSameConnection(t *testing.T) { + entityPath, cleanup := test.CreateExpiringQueue(t, nil) + defer cleanup() + + cs := test.GetConnectionString(t) + ns, err := NewNamespace(NamespaceWithConnectionString(cs)) + require.NoError(t, err) + + defer func() { _ = ns.Close(context.Background()) }() - _, _, _, _, err := links.Get(context.Background()) + createLinksCalled := 0 - require.True(t, IsNonRetriable(err)) + links := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + createLinksCalled++ + return newLinksForAMQPLinksTest(entityPath, session) + }) + + createLinksCalled2 := 0 + + links2 := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + createLinksCalled2++ + return 
newLinksForAMQPLinksTest(entityPath, session) + }) + + wg := sync.WaitGroup{} + + lwr, err := links.Get(context.Background()) + require.NoError(t, err) + require.EqualValues(t, 1, createLinksCalled) + require.EqualValues(t, 1, lwr.ID.Link) + + lwr2, err := links2.Get(context.Background()) + require.NoError(t, err) + require.EqualValues(t, 1, createLinksCalled2) + require.EqualValues(t, 1, lwr2.ID.Link) + + wg.Add(1) + + go func() { + defer wg.Done() + err = links.RecoverIfNeeded(context.Background(), lwr.ID, &amqp.DetachError{}) + require.NoError(t, err) + }() + + wg.Add(1) + + go func() { + defer wg.Done() + + err = links2.RecoverIfNeeded(context.Background(), lwr2.ID, &amqp.DetachError{}) + require.NoError(t, err) + }() + + wg.Wait() + + // TODO: also check that the connection hasn't recycled multiple times. + require.EqualValues(t, 2, createLinksCalled) + require.EqualValues(t, 2, createLinksCalled2) + + _, clientRev, err := ns.GetAMQPClientImpl(context.Background()) + require.NoError(t, err) + require.EqualValues(t, 1, clientRev) + + recovered, err := ns.Recover(context.Background(), clientRev) + require.NoError(t, err) + require.True(t, recovered) + + _, clientRev, err = ns.GetAMQPClientImpl(context.Background()) + require.NoError(t, err) + require.EqualValues(t, 2, clientRev) + + // now attempt a recover but with an older revision (won't do anything since we've + // already recovered past that older rev. 
They should just call Get()) + recovered, err = ns.Recover(context.Background(), clientRev-1) + require.NoError(t, err) + require.False(t, recovered) + + _, clientRev, err = ns.GetAMQPClientImpl(context.Background()) + require.NoError(t, err) + require.EqualValues(t, 2, clientRev) } -func setupCreateLinkResponses(t *testing.T, responses []createLinkResponse) (CreateLinkFunc, *int) { - callCount := 0 - testCreateLinkFunc := func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { - callCount++ +func newLinksForAMQPLinksTest(entityPath string, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + receiveMode := amqp.ModeSecond + + opts := []amqp.LinkOption{ + amqp.LinkSourceAddress(entityPath), + amqp.LinkReceiverSettle(receiveMode), + amqp.LinkWithManualCredits(), + amqp.LinkCredit(1000), + } + + receiver, err := session.NewReceiver(opts...) - if len(responses) == 0 { - require.Fail(t, "createLinkFunc called too many times") - } + if err != nil { + return nil, nil, err + } - r := responses[0] - responses = responses[1:] + sender, err := session.NewSender( + amqp.LinkSenderSettle(amqp.ModeMixed), + amqp.LinkReceiverSettle(amqp.ModeFirst), + amqp.LinkTargetAddress(entityPath)) - return r.sender, r.receiver, r.err + if err != nil { + _ = receiver.Close(context.Background()) + return nil, nil, err } - return testCreateLinkFunc, &callCount + return sender, receiver, nil +} + +func enableLogging() { + azlog.SetListener(func(e azlog.Event, s string) { + log.Printf("%s %s", e, s) + }) } diff --git a/sdk/messaging/azservicebus/internal/amqp_test_utils.go b/sdk/messaging/azservicebus/internal/amqp_test_utils.go index 6feee7d1848c..014f26eb5309 100644 --- a/sdk/messaging/azservicebus/internal/amqp_test_utils.go +++ b/sdk/messaging/azservicebus/internal/amqp_test_utils.go @@ -7,15 +7,14 @@ import ( "context" "fmt" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/rpc" + 
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" ) type FakeNS struct { claimNegotiated int recovered uint64 clientRevisions []uint64 - MgmtClient MgmtClient - RPCLink *rpc.Link + RPCLink RPCLink Session AMQPSessionCloser AMQPLinks *FakeAMQPLinks } @@ -30,21 +29,16 @@ type FakeAMQPSession struct { closed int } -type fakeMgmtClient struct { - MgmtClient - closed int -} - type FakeAMQPLinks struct { AMQPLinks Closed int // values to be returned for each `Get` call - Revision uint64 + Revision LinkID Receiver AMQPReceiver Sender AMQPSender - Mgmt MgmtClient + RPC RPCLink Err error permanently bool @@ -66,8 +60,23 @@ func (r *FakeAMQPReceiver) Close(ctx context.Context) error { return nil } -func (l *FakeAMQPLinks) Get(ctx context.Context) (AMQPSender, AMQPReceiver, MgmtClient, uint64, error) { - return l.Sender, l.Receiver, l.Mgmt, l.Revision, l.Err +func (l *FakeAMQPLinks) Get(ctx context.Context) (*LinksWithID, error) { + return &LinksWithID{ + Sender: l.Sender, + Receiver: l.Receiver, + RPC: l.RPC, + ID: l.Revision, + }, l.Err +} + +func (l *FakeAMQPLinks) Retry(ctx context.Context, name string, fn RetryWithLinksFn, o utils.RetryOptions) error { + lwr, err := l.Get(ctx) + + if err != nil { + return err + } + + return fn(ctx, lwr, &utils.RetryFnArgs{}) } func (l *FakeAMQPLinks) Close(ctx context.Context, permanently bool) error { @@ -93,11 +102,6 @@ func (s *FakeAMQPSession) Close(ctx context.Context) error { return nil } -func (m *fakeMgmtClient) Close(ctx context.Context) error { - m.closed++ - return nil -} - func (ns *FakeNS) NegotiateClaim(ctx context.Context, entityPath string) (func() <-chan struct{}, error) { ch := make(chan struct{}) close(ch) @@ -117,26 +121,20 @@ func (ns *FakeNS) NewAMQPSession(ctx context.Context) (AMQPSessionCloser, uint64 return ns.Session, ns.recovered + 100, nil } -func (ns *FakeNS) NewMgmtClient(ctx context.Context, links AMQPLinks) (MgmtClient, error) { - return ns.MgmtClient, nil -} - -func (ns 
*FakeNS) NewRPCLink(ctx context.Context, managementPath string) (*rpc.Link, error) { +func (ns *FakeNS) NewRPCLink(ctx context.Context, managementPath string) (RPCLink, error) { return ns.RPCLink, nil } -func (ns *FakeNS) Recover(ctx context.Context, clientRevision uint64) error { +func (ns *FakeNS) Recover(ctx context.Context, clientRevision uint64) (bool, error) { ns.clientRevisions = append(ns.clientRevisions, clientRevision) ns.recovered++ + return true, nil +} + +func (ns *FakeNS) CloseIfNeeded(ctx context.Context, clientRevision uint64) error { return nil } func (ns *FakeNS) NewAMQPLinks(entityPath string, createLinkFunc CreateLinkFunc) AMQPLinks { return ns.AMQPLinks } - -type createLinkResponse struct { - sender AMQPSenderCloser - receiver AMQPReceiverCloser - err error -} diff --git a/sdk/messaging/azservicebus/internal/atom/entity_manager.go b/sdk/messaging/azservicebus/internal/atom/entity_manager.go index 179b5b60d678..ddbbae249fc2 100644 --- a/sdk/messaging/azservicebus/internal/atom/entity_manager.go +++ b/sdk/messaging/azservicebus/internal/atom/entity_manager.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "io/ioutil" + "net" "net/http" "net/http/httputil" "strings" @@ -19,6 +20,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/sbauth" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/auth" "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/conn" "github.com/devigned/tab" @@ -44,6 +46,7 @@ type ( Host string mwStack []MiddlewareFunc version string + retryOptions utils.RetryOptions } // BaseEntityDescription provides common fields which are part of Queues, Topics and Subscriptions @@ -52,7 +55,8 @@ type ( ServiceBusSchema *string `xml:"xmlns,attr,omitempty"` } - managementError struct { + // example: 
401Manage,EntityRead claims required for this operation. + ManagementError struct { XMLName xml.Name `xml:"Error"` Code int `xml:"Code"` Detail string `xml:"Detail"` @@ -145,7 +149,7 @@ const ( Unknown EntityStatus = "Unknown" ) -func (m *managementError) String() string { +func (m *ManagementError) String() string { return fmt.Sprintf("Code: %d, Details: %s", m.Code, m.Detail) } @@ -179,7 +183,7 @@ func NewEntityManagerWithConnectionString(connectionString string, version strin } // NewEntityManager creates an entity manager using a TokenCredential. -func NewEntityManager(ns string, tokenCredential azcore.TokenCredential, version string) (EntityManager, error) { +func NewEntityManager(ns string, tokenCredential azcore.TokenCredential, version string, retryOptions utils.RetryOptions) (EntityManager, error) { return &entityManager{ Host: fmt.Sprintf("https://%s/", ns), version: version, @@ -190,6 +194,7 @@ func NewEntityManager(ns string, tokenCredential azcore.TokenCredential, version addAuthorization(sbauth.NewTokenProvider(tokenCredential)), applyTracing(version), }, + retryOptions: retryOptions, }, nil } @@ -238,61 +243,90 @@ func (em *entityManager) Delete(ctx context.Context, entityPath string, mw ...Mi } func (em *entityManager) execute(ctx context.Context, method string, entityPath string, body io.Reader, mw ...MiddlewareFunc) (*http.Response, error) { - ctx, span := em.startSpanFromContext(ctx, "sb.ATOM.Execute") - defer span.End() + var finalResp *http.Response - req, err := http.NewRequest(method, em.Host+strings.TrimPrefix(entityPath, "/"), body) - if err != nil { - tab.For(ctx).Error(err) - return nil, err - } + err := utils.Retry(ctx, fmt.Sprintf("%s %s", method, entityPath), func(ctx context.Context, args *utils.RetryFnArgs) error { + ctx, span := em.startSpanFromContext(ctx, "sb.ATOM.Execute") + defer span.End() - final := func(_ RestHandler) RestHandler { - return func(reqCtx context.Context, request *http.Request) (*http.Response, error) { - client 
:= &http.Client{ - Timeout: 60 * time.Second, - } - request = request.WithContext(reqCtx) - return client.Do(request) + req, err := http.NewRequest(method, em.Host+strings.TrimPrefix(entityPath, "/"), body) + if err != nil { + tab.For(ctx).Error(err) + return err } - } - mwStack := []MiddlewareFunc{final} - sl := len(em.mwStack) - 1 - for i := sl; i >= 0; i-- { - mwStack = append(mwStack, em.mwStack[i]) - } + final := func(_ RestHandler) RestHandler { + return func(reqCtx context.Context, request *http.Request) (*http.Response, error) { + client := &http.Client{ + Timeout: 60 * time.Second, + } + request = request.WithContext(reqCtx) + return client.Do(request) + } + } - for i := len(mw) - 1; i >= 0; i-- { - mwStack = append(mwStack, mw[i]) - } + mwStack := []MiddlewareFunc{final} + sl := len(em.mwStack) - 1 + for i := sl; i >= 0; i-- { + mwStack = append(mwStack, em.mwStack[i]) + } - var h RestHandler - for _, mw := range mwStack { - h = mw(h) - } + for i := len(mw) - 1; i >= 0; i-- { + mwStack = append(mwStack, mw[i]) + } - resp, err := h(ctx, req) + var h RestHandler + for _, mw := range mwStack { + h = mw(h) + } - if err == nil { - if resp.StatusCode >= http.StatusBadRequest { - bytes, err := ioutil.ReadAll(resp.Body) + resp, err := h(ctx, req) - if err == nil { - err = FormatManagementError(bytes, err) + if err == nil { + if resp.StatusCode >= http.StatusBadRequest { + return NewResponseError(resp) } - return nil, NewResponseError(err, resp) + finalResp = resp + return nil + } + + if resp != nil { + return NewResponseError(resp) } - return resp, nil + return err + }, isFatalHTTPError, em.retryOptions) + + if err != nil { + return nil, err + } + + return finalResp, nil +} + +func isFatalHTTPError(err error) bool { + var netErr net.Error + + if errors.As(err, &netErr) { + return false } - if resp != nil { - return nil, NewResponseError(err, resp) + var respErr *azcore.ResponseError + + // TODO: this is very much temporary. 
We need to move this over to the azcore HTTP stack. + if errors.As(err, &respErr) { + if respErr.StatusCode == http.StatusRequestTimeout || // 408 + respErr.StatusCode == http.StatusTooManyRequests || // 429 + respErr.StatusCode == http.StatusInternalServerError || // 500 + respErr.StatusCode == http.StatusBadGateway || // 502 + respErr.StatusCode == http.StatusServiceUnavailable || // 503 + respErr.StatusCode == http.StatusGatewayTimeout { // 504 ) + return false + } } - return nil, err + return true } // Use adds middleware to the middleware mwStack @@ -306,7 +340,7 @@ func (em *entityManager) TokenProvider() auth.TokenProvider { } func FormatManagementError(body []byte, origErr error) error { - var mgmtError managementError + var mgmtError ManagementError unmarshalErr := xml.Unmarshal(body, &mgmtError) if unmarshalErr != nil { return origErr diff --git a/sdk/messaging/azservicebus/internal/atom/errors.go b/sdk/messaging/azservicebus/internal/atom/errors.go index e8182a433746..0c50018dc7fb 100644 --- a/sdk/messaging/azservicebus/internal/atom/errors.go +++ b/sdk/messaging/azservicebus/internal/atom/errors.go @@ -6,10 +6,15 @@ package atom import ( "fmt" "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" ) -func NewResponseError(inner error, resp *http.Response) error { - return ResponseError{inner, resp} +func NewResponseError(resp *http.Response) error { + return &azcore.ResponseError{ + StatusCode: resp.StatusCode, + RawResponse: resp, + } } // ResponseError conforms to the older azcore.HTTPResponse diff --git a/sdk/messaging/azservicebus/internal/atom/manager_common_test.go b/sdk/messaging/azservicebus/internal/atom/manager_common_test.go index 9945fff0b8ec..12b7b47a0562 100644 --- a/sdk/messaging/azservicebus/internal/atom/manager_common_test.go +++ b/sdk/messaging/azservicebus/internal/atom/manager_common_test.go @@ -4,40 +4,66 @@ package atom import ( + "bytes" "context" - "errors" "io" "net/http" + "net/url" "strings" "testing" + 
"github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/stretchr/testify/require" ) +func newFakeResponse(statusCode int, status string, contents string) *http.Response { + var body io.ReadCloser = http.NoBody + + if contents != "" { + body = &FakeReader{ + Reader: bytes.NewBufferString(contents), + } + } + + return &http.Response{ + Request: &http.Request{ + URL: &url.URL{}, + }, + StatusCode: statusCode, + Status: status, + Body: body, + } +} + func TestResponseError(t *testing.T) { - require.EqualValues(t, "this is now the error message: 409", NewResponseError(nil, &http.Response{ - StatusCode: http.StatusConflict, - Status: "this is now the error message", - }).Error()) - - require.EqualValues(t, "inner errors message takes precedence", NewResponseError(errors.New("inner errors message takes precedence"), &http.Response{ - StatusCode: http.StatusConflict, - Status: "going to be ignored", - }).Error()) + resp := newFakeResponse(http.StatusConflict, "statusString", "") + require.Contains(t, NewResponseError(resp).Error(), "statusString") + + resp = newFakeResponse(http.StatusConflict, "statusString", "contents") + require.Contains(t, NewResponseError(resp).Error(), "statusString") + + resp = newFakeResponse(http.StatusBadGateway, "statusString", "401Manage,EntityRead claims required for this operation.") + err := NewResponseError(resp) + + re, ok := err.(*azcore.ResponseError) + require.True(t, ok) + + require.Contains(t, re.Error(), "statusString") + require.EqualValues(t, http.StatusBadGateway, re.StatusCode) } type FakeReader struct { io.Reader - closed bool + closed bool + closeErr error } func (f *FakeReader) Close() error { f.closed = true - return nil + return f.closeErr } func TestCloseRes(t *testing.T) { - reader := strings.NewReader("hello") body := &FakeReader{Reader: reader} diff --git a/sdk/messaging/azservicebus/internal/errors.go b/sdk/messaging/azservicebus/internal/errors.go index d833bf0609f4..7edefc24bb56 100644 --- 
a/sdk/messaging/azservicebus/internal/errors.go +++ b/sdk/messaging/azservicebus/internal/errors.go @@ -11,193 +11,17 @@ import ( "net" "reflect" "strings" - "time" "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/rpc" "github.com/Azure/go-amqp" - "github.com/devigned/tab" ) -type NonRetriable interface { - error - NonRetriable() -} - -// IsNonRetriable indicates an error is fatal. Typically, this means -// the connection or link has been closed. -func IsNonRetriable(err error) bool { - var nonRetriable NonRetriable - return errors.As(err, &nonRetriable) -} - type ErrNonRetriable struct { Message string } func (e ErrNonRetriable) Error() string { return e.Message } -// Error Conditions -const ( - // Service Bus Errors - errorServerBusy amqp.ErrorCondition = "com.microsoft:server-busy" - errorTimeout amqp.ErrorCondition = "com.microsoft:timeout" - errorOperationCancelled amqp.ErrorCondition = "com.microsoft:operation-cancelled" - errorContainerClose amqp.ErrorCondition = "com.microsoft:container-close" -) - -const ( - amqpRetryDefaultTimes int = 3 - amqpRetryDefaultDelay time.Duration = time.Second -) - -type ( - // ErrMissingField indicates that an expected property was missing from an AMQP message. This should only be - // encountered when there is an error with this library, or the server has altered its behavior unexpectedly. - ErrMissingField string - - // ErrMalformedMessage indicates that a message was expected in the form of []byte was not a []byte. This is likely - // a bug and should be reported. - ErrMalformedMessage string - - // ErrIncorrectType indicates that type assertion failed. This should only be encountered when there is an error - // with this library, or the server has altered its behavior unexpectedly. 
- ErrIncorrectType struct { - Key string - ExpectedType reflect.Type - ActualValue interface{} - } - - // ErrAMQP indicates that the server communicated an AMQP error with a particular - ErrAMQP rpc.Response - - // ErrNoMessages is returned when an operation returned no messages. It is not indicative that there will not be - // more messages in the future. - ErrNoMessages struct{} - - // ErrNotFound is returned when an entity is not found (404) - ErrNotFound struct { - EntityPath string - } - - // ErrConnectionClosed indicates that the connection has been closed. - ErrConnectionClosed string -) - -func (e ErrMissingField) Error() string { - return fmt.Sprintf("missing value %q", string(e)) -} - -func (e ErrMalformedMessage) Error() string { - return "message was expected in the form of []byte was not a []byte" -} - -// NewErrIncorrectType lets you skip using the `reflect` package. Just provide a variable of the desired type as -// 'expected'. -func NewErrIncorrectType(key string, expected, actual interface{}) ErrIncorrectType { - return ErrIncorrectType{ - Key: key, - ExpectedType: reflect.TypeOf(expected), - ActualValue: actual, - } -} - -func (e ErrIncorrectType) Error() string { - return fmt.Sprintf( - "value at %q was expected to be of type %q but was actually of type %q", - e.Key, - e.ExpectedType, - reflect.TypeOf(e.ActualValue)) -} - -func (e ErrAMQP) Error() string { - return fmt.Sprintf("server says (%d) %s", e.Code, e.Description) -} - -func (e ErrNoMessages) Error() string { - return "no messages available" -} - -func (e ErrNotFound) Error() string { - return fmt.Sprintf("entity at %s not found", e.EntityPath) -} - -// IsErrNotFound returns true if the error argument is an ErrNotFound type -func IsErrNotFound(err error) bool { - _, ok := err.(ErrNotFound) - return ok -} - -func (e ErrConnectionClosed) Error() string { - return fmt.Sprintf("the connection has been closed: %s", string(e)) -} - -// Leveraging @serbrech's fine work from go-shuttle: -// 
https://github.com/Azure/go-shuttle/blob/ea882947109ade9b34d4d69642fdf7aec4570fee/common/errorhandling/recovery.go - -var retryableAMQPConditions = map[string]bool{ - string(amqp.ErrorInternalError): true, - string(errorServerBusy): true, // "com.microsoft:server-busy" - string(errorTimeout): true, // "com.microsoft:timeout" - string(errorOperationCancelled): true, // "com.microsoft:operation-cancelled" - "client.sender:not-enough-link-credit": true, - string(amqp.ErrorUnauthorizedAccess): true, - string(amqp.ErrorDetachForced): true, - string(amqp.ErrorConnectionForced): true, - string(amqp.ErrorTransferLimitExceeded): true, - "amqp: connection closed": true, - "unexpected frame": true, - string(amqp.ErrorNotFound): true, -} - -func isRetryableAMQPError(ctxForLogging context.Context, err error) bool { - var amqpErr *amqp.Error - var isAMQPError = errors.As(err, &amqpErr) - - if isAMQPError { - _, ok := retryableAMQPConditions[string(amqpErr.Condition)] - return ok - } - - // TODO: there is a bug somewhere that seems to be errorString'ing errors. Need to track that down. 
- // In the meantime, try string matching instead - for condition := range retryableAMQPConditions { - if strings.Contains(err.Error(), condition) { - tab.For(ctxForLogging).Error(fmt.Errorf("error needed to be matched by a string matcher, rather than by type: %w", err)) - return true - } - } - - return false -} - -func shouldRecreateLink(err error) bool { - if err == nil { - return false - } - - var detachError *amqp.DetachError - - return errors.As(err, &detachError) || - // TODO: proper error types needs to happen - strings.Contains(err.Error(), "detach frame link detached") -} - -func shouldRecreateConnection(ctxForLogging context.Context, err error) bool { - if err == nil { - return false - } - - shouldRecreate := isPermanentNetError(err) || - isRetryableAMQPError(ctxForLogging, err) || - errors.Is(err, io.EOF) || - // these are distinct from a detach and probably indicate something - // wrong with the connection itself, rather than just the link - errors.Is(err, amqp.ErrSessionClosed) || - errors.Is(err, amqp.ErrLinkClosed) - - return shouldRecreate -} - type recoveryKind string const RecoveryKindNone recoveryKind = "" @@ -205,48 +29,40 @@ const RecoveryKindFatal recoveryKind = "fatal" const RecoveryKindLink recoveryKind = "link" const RecoveryKindConn recoveryKind = "connection" -type ServiceBusError struct { +type SBErrInfo struct { inner error RecoveryKind recoveryKind } -func (sbe *ServiceBusError) String() string { +func (sbe *SBErrInfo) String() string { return sbe.inner.Error() } -func (sbe *ServiceBusError) AsError() error { +func (sbe *SBErrInfo) AsError() error { return sbe.inner } -// ToSBE wraps the passed in 'err' with a proper error with one of either: +func IsFatalSBError(err error) bool { + return GetSBErrInfo(err).RecoveryKind == RecoveryKindFatal +} + +// GetSBErrInfo wraps the passed in 'err' with a proper error with one of either: // - `fatalServiceBusError` if no recovery is possible. // - `serviceBusError` if the error is recoverable. 
The `recoveryKind` field contains the // type of recovery needed. -func ToSBE(loggingCtx context.Context, err error) *ServiceBusError { +func GetSBErrInfo(err error) *SBErrInfo { if err == nil { return nil } - sbe := &ServiceBusError{ + sbe := &SBErrInfo{ inner: err, - RecoveryKind: GetRecoveryKind(loggingCtx, err), + RecoveryKind: GetRecoveryKind(err), } return sbe } -func isPermanentNetError(err error) bool { - var netErr net.Error - - if errors.As(err, &netErr) { - temp := netErr.Temporary() - timeout := netErr.Timeout() - return !temp && !timeout - } - - return false -} - func IsCancelError(err error) bool { if err == nil { return false @@ -291,7 +107,7 @@ var amqpConditionsToRecoveryKind = map[amqp.ErrorCondition]recoveryKind{ amqp.ErrorCondition("com.microsoft:message-lock-lost"): RecoveryKindFatal, } -func GetRecoveryKind(ctxForLogging context.Context, err error) recoveryKind { +func GetRecoveryKind(err error) recoveryKind { if IsCancelError(err) { return RecoveryKindFatal } @@ -345,7 +161,106 @@ func GetRecoveryKind(ctxForLogging context.Context, err error) recoveryKind { } } + var me mgmtError + + if errors.As(err, &me) { + code := me.RPCCode() + + // this can happen when we're recovering the link - the client gets closed and the old link is still being + // used by this instance of the client. It needs to recover and attempt it again. + if code == 401 || + // we lost the session lock, attempt link recovery + code == 410 { + return RecoveryKindLink + } + + // simple timeouts + if me.Resp.Code == 408 || me.Resp.Code == 503 { + return RecoveryKindNone + } + } + // this is some error type we've never seen. - tab.For(ctxForLogging).Fatal(fmt.Sprintf("No recovery possible with error: %#v", err)) return RecoveryKindFatal } + +type ( + // ErrMissingField indicates that an expected property was missing from an AMQP message. This should only be + // encountered when there is an error with this library, or the server has altered its behavior unexpectedly. 
+ ErrMissingField string + + // ErrMalformedMessage indicates that a message was expected in the form of []byte was not a []byte. This is likely + // a bug and should be reported. + ErrMalformedMessage string + + // ErrIncorrectType indicates that type assertion failed. This should only be encountered when there is an error + // with this library, or the server has altered its behavior unexpectedly. + ErrIncorrectType struct { + Key string + ExpectedType reflect.Type + ActualValue interface{} + } + + // ErrAMQP indicates that the server communicated an AMQP error with a particular + ErrAMQP rpc.Response + + // ErrNoMessages is returned when an operation returned no messages. It is not indicative that there will not be + // more messages in the future. + ErrNoMessages struct{} + + // ErrNotFound is returned when an entity is not found (404) + ErrNotFound struct { + EntityPath string + } + + // ErrConnectionClosed indicates that the connection has been closed. + ErrConnectionClosed string +) + +func (e ErrMissingField) Error() string { + return fmt.Sprintf("missing value %q", string(e)) +} + +func (e ErrMalformedMessage) Error() string { + return "message was expected in the form of []byte was not a []byte" +} + +// NewErrIncorrectType lets you skip using the `reflect` package. Just provide a variable of the desired type as +// 'expected'. 
+func NewErrIncorrectType(key string, expected, actual interface{}) ErrIncorrectType { + return ErrIncorrectType{ + Key: key, + ExpectedType: reflect.TypeOf(expected), + ActualValue: actual, + } +} + +func (e ErrIncorrectType) Error() string { + return fmt.Sprintf( + "value at %q was expected to be of type %q but was actually of type %q", + e.Key, + e.ExpectedType, + reflect.TypeOf(e.ActualValue)) +} + +func (e ErrAMQP) Error() string { + return fmt.Sprintf("server says (%d) %s", e.Code, e.Description) +} + +func (e ErrNoMessages) Error() string { + return "no messages available" +} + +func (e ErrNotFound) Error() string { + return fmt.Sprintf("entity at %s not found", e.EntityPath) +} + +// IsErrNotFound returns true if the error argument is an ErrNotFound type +func IsErrNotFound(err error) bool { + _, ok := err.(ErrNotFound) + return ok +} + +func (e ErrConnectionClosed) Error() string { + return fmt.Sprintf("the connection has been closed: %s", string(e)) +} diff --git a/sdk/messaging/azservicebus/internal/errors_test.go b/sdk/messaging/azservicebus/internal/errors_test.go index 471e7d96ad6d..743f2894f2af 100644 --- a/sdk/messaging/azservicebus/internal/errors_test.go +++ b/sdk/messaging/azservicebus/internal/errors_test.go @@ -77,96 +77,92 @@ func TestErrNotFound_Error(t *testing.T) { assert.False(t, IsErrNotFound(otherErr)) } -func Test_isPermanentNetError(t *testing.T) { - require.False(t, isPermanentNetError(&fakeNetError{ - temp: true, - })) - - require.False(t, isPermanentNetError(&fakeNetError{ - timeout: true, - })) - - require.False(t, isPermanentNetError(errors.New("not a net error"))) - - require.True(t, isPermanentNetError(&fakeNetError{})) +func Test_recoveryKind(t *testing.T) { + t.Run("link", func(t *testing.T) { + linkErrorCodes := []string{ + string(amqp.ErrorDetachForced), + } + + for _, code := range linkErrorCodes { + t.Run(code, func(t *testing.T) { + sbe := GetSBErrInfo(&amqp.Error{Condition: amqp.ErrorCondition(code)}) + 
require.EqualValues(t, RecoveryKindLink, sbe.RecoveryKind, fmt.Sprintf("requires link recovery: %s", code)) + }) + } + + t.Run("sentintel errors", func(t *testing.T) { + sbe := GetSBErrInfo(amqp.ErrLinkClosed) + require.EqualValues(t, RecoveryKindLink, sbe.RecoveryKind) + + sbe = GetSBErrInfo(amqp.ErrSessionClosed) + require.EqualValues(t, RecoveryKindLink, sbe.RecoveryKind) + }) + }) + + t.Run("connection", func(t *testing.T) { + codes := []string{ + string(amqp.ErrorConnectionForced), + } + + for _, code := range codes { + t.Run(code, func(t *testing.T) { + sbe := GetSBErrInfo(&amqp.Error{Condition: amqp.ErrorCondition(code)}) + require.EqualValues(t, RecoveryKindConn, sbe.RecoveryKind, fmt.Sprintf("requires connection recovery: %s", code)) + }) + } + + t.Run("sentinel errors", func(t *testing.T) { + sbe := GetSBErrInfo(amqp.ErrConnClosed) + require.EqualValues(t, RecoveryKindConn, sbe.RecoveryKind) + }) + }) + + t.Run("nonretriable", func(t *testing.T) { + codes := []string{ + string(amqp.ErrorTransferLimitExceeded), + string(amqp.ErrorInternalError), + string(amqp.ErrorUnauthorizedAccess), + string(amqp.ErrorNotFound), + string(amqp.ErrorMessageSizeExceeded), + } + + for _, code := range codes { + t.Run(code, func(t *testing.T) { + sbe := GetSBErrInfo(&amqp.Error{Condition: amqp.ErrorCondition(code)}) + require.EqualValues(t, RecoveryKindFatal, sbe.RecoveryKind, fmt.Sprintf("cannot be recovered: %s", code)) + }) + } + }) + + t.Run("none", func(t *testing.T) { + codes := []string{ + string("com.microsoft:operation-cancelled"), + string("com.microsoft:server-busy"), + string("com.microsoft:timeout"), + } + + for _, code := range codes { + t.Run(code, func(t *testing.T) { + sbe := GetSBErrInfo(&amqp.Error{Condition: amqp.ErrorCondition(code)}) + require.EqualValues(t, RecoveryKindNone, sbe.RecoveryKind, fmt.Sprintf("no recovery needed: %s", code)) + }) + } + }) } -func Test_isRetryableAMQPError(t *testing.T) { - ctx := context.Background() - - retryableCodes := 
[]string{ - string(amqp.ErrorInternalError), - string(errorServerBusy), - string(errorTimeout), - string(errorOperationCancelled), - "client.sender:not-enough-link-credit", - string(amqp.ErrorUnauthorizedAccess), - string(amqp.ErrorDetachForced), - string(amqp.ErrorConnectionForced), - string(amqp.ErrorTransferLimitExceeded), - "amqp: connection closed", - "unexpected frame", - string(amqp.ErrorNotFound), +func Test_IsNonRetriable(t *testing.T) { + errs := []error{ + context.Canceled, + context.DeadlineExceeded, + ErrNonRetriable{Message: "any message"}, + fmt.Errorf("wrapped: %w", context.Canceled), + fmt.Errorf("wrapped: %w", context.DeadlineExceeded), + fmt.Errorf("wrapped: %w", ErrNonRetriable{Message: "any message"}), } - for _, code := range retryableCodes { - require.True(t, isRetryableAMQPError(ctx, &amqp.Error{ - Condition: amqp.ErrorCondition(code), - })) - - // it works equally well if the error is just in the String(). - // Need to narrow this down some more to see where the errors - // might not be getting converted properly. - require.True(t, isRetryableAMQPError(ctx, errors.New(code))) + for _, err := range errs { + require.EqualValues(t, RecoveryKindFatal, GetSBErrInfo(err).RecoveryKind) } - - require.False(t, isRetryableAMQPError(ctx, errors.New("some non-amqp related error"))) -} - -func Test_shouldRecreateLink(t *testing.T) { - require.False(t, shouldRecreateLink(nil)) - - require.True(t, shouldRecreateLink(&amqp.DetachError{})) - - // going to treat these as "connection troubles" and throw them into the - // connection recovery scenario instead. 
- require.False(t, shouldRecreateLink(amqp.ErrLinkClosed)) - require.False(t, shouldRecreateLink(amqp.ErrSessionClosed)) -} - -func Test_shouldRecreateConnection(t *testing.T) { - ctx := context.Background() - - require.False(t, shouldRecreateConnection(ctx, nil)) - require.True(t, shouldRecreateConnection(ctx, &fakeNetError{})) - require.True(t, shouldRecreateConnection(ctx, fmt.Errorf("%w", &fakeNetError{}))) - - require.False(t, shouldRecreateLink(amqp.ErrLinkClosed)) - require.False(t, shouldRecreateLink(fmt.Errorf("wrapped: %w", amqp.ErrLinkClosed))) - - require.False(t, shouldRecreateLink(amqp.ErrSessionClosed)) - require.False(t, shouldRecreateLink(fmt.Errorf("wrapped: %w", amqp.ErrSessionClosed))) -} - -// TODO: while testing it appeared there were some errors that were getting string-ized -// We want to eliminate these. 'stress.go' reproduces most of these as you disconnect -// and reconnect. -func Test_stringErrorsToEliminate(t *testing.T) { - require.True(t, shouldRecreateLink(errors.New("detach frame link detached"))) - require.True(t, isRetryableAMQPError(context.Background(), errors.New("amqp: connection closed"))) - require.True(t, IsCancelError(errors.New("context canceled"))) -} - -func Test_IsCancelError(t *testing.T) { - require.False(t, IsCancelError(nil)) - require.False(t, IsCancelError(errors.New("not a cancel error"))) - - require.True(t, IsCancelError(errors.New("context canceled"))) - - require.True(t, IsCancelError(context.Canceled)) - require.True(t, IsCancelError(context.DeadlineExceeded)) - require.True(t, IsCancelError(fmt.Errorf("wrapped: %w", context.Canceled))) - require.True(t, IsCancelError(fmt.Errorf("wrapped: %w", context.DeadlineExceeded))) } func Test_ServiceBusError_NoRecoveryNeeded(t *testing.T) { @@ -178,10 +174,13 @@ func Test_ServiceBusError_NoRecoveryNeeded(t *testing.T) { fakeNetError{temp: true}, fakeNetError{timeout: true}, fakeNetError{temp: false, timeout: false}, + // simple timeouts from the mgmt link + 
mgmtError{Resp: &RPCResponse{Code: 408}}, + mgmtError{Resp: &RPCResponse{Code: 503}}, } for i, err := range tempErrors { - rk := ToSBE(context.Background(), err).RecoveryKind + rk := GetSBErrInfo(err).RecoveryKind require.EqualValues(t, RecoveryKindNone, rk, fmt.Sprintf("[%d] %v", i, err)) } } @@ -194,7 +193,7 @@ func Test_ServiceBusError_ConnectionRecoveryNeeded(t *testing.T) { } for i, err := range connErrors { - rk := ToSBE(context.Background(), err).RecoveryKind + rk := GetSBErrInfo(err).RecoveryKind require.EqualValues(t, RecoveryKindConn, rk, fmt.Sprintf("[%d] %v", i, err)) } } @@ -205,10 +204,15 @@ func Test_ServiceBusError_LinkRecoveryNeeded(t *testing.T) { amqp.ErrLinkClosed, &amqp.DetachError{}, &amqp.Error{Condition: amqp.ErrorDetachForced}, + // we lost the session lock, attempt link recovery + mgmtError{Resp: &RPCResponse{Code: 410}}, + // this can happen when we're recovering the link - the client gets closed and the old link is still being + // used by this instance of the client. It needs to recover and attempt it again. 
+ mgmtError{Resp: &RPCResponse{Code: 401}}, } for i, err := range linkErrors { - rk := ToSBE(context.Background(), err).RecoveryKind + rk := GetSBErrInfo(err).RecoveryKind require.EqualValues(t, RecoveryKindLink, rk, fmt.Sprintf("[%d] %v", i, err)) } } @@ -226,11 +230,14 @@ func Test_ServiceBusError_Fatal(t *testing.T) { } for i, cond := range fatalConditions { - rk := ToSBE(context.Background(), &amqp.Error{Condition: cond}).RecoveryKind + rk := GetSBErrInfo(&amqp.Error{Condition: cond}).RecoveryKind require.EqualValues(t, RecoveryKindFatal, rk, fmt.Sprintf("[%d] %s", i, cond)) } // unknown errors are also considered fatal - rk := ToSBE(context.Background(), errors.New("Some unknown error")).RecoveryKind + rk := GetSBErrInfo(errors.New("Some unknown error")).RecoveryKind + require.EqualValues(t, RecoveryKindFatal, rk, "some unknown error") + + rk = GetSBErrInfo(mgmtError{Resp: &RPCResponse{Code: 500}}).RecoveryKind require.EqualValues(t, RecoveryKindFatal, rk, "some unknown error") } diff --git a/sdk/messaging/azservicebus/internal/log.go b/sdk/messaging/azservicebus/internal/log.go index c500f1f9c7c5..cd04f8c3a996 100644 --- a/sdk/messaging/azservicebus/internal/log.go +++ b/sdk/messaging/azservicebus/internal/log.go @@ -3,7 +3,21 @@ package internal +import "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" + const ( + // Link/connection creation + EventConn = "azsb.Conn" + + // authentication/claims negotiation + EventAuth = "azsb.Auth" + + // receiver operations + EventReceiver = "azsb.Receiver" + + // mgmt link + EventMgmtLink = "azsb.Mgmt" + // internal operations - EventRetry = "azsb.Retry" + EventRetry = utils.EventRetry ) diff --git a/sdk/messaging/azservicebus/internal/mgmt.go b/sdk/messaging/azservicebus/internal/mgmt.go index b266fc95c856..1d9534fc7bc2 100644 --- a/sdk/messaging/azservicebus/internal/mgmt.go +++ b/sdk/messaging/azservicebus/internal/mgmt.go @@ -7,13 +7,10 @@ import ( "context" "errors" "fmt" - "sync" "time" 
"github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing" - common "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/rpc" "github.com/Azure/go-amqp" "github.com/devigned/tab" ) @@ -34,171 +31,39 @@ const ( DeferredDisposition DispositionStatus = "defered" ) -type ( - mgmtClient struct { - ns NamespaceForMgmtClient - links AMQPLinks - - clientMu sync.Mutex - rpcLink RPCLink - - sessionID *string - isSessionFilterSet bool - } -) - -type MgmtClient interface { - Close(ctx context.Context) error - SendDisposition(ctx context.Context, lockToken *amqp.UUID, state Disposition, propertiesToModify map[string]interface{}) error - ReceiveDeferred(ctx context.Context, mode ReceiveMode, sequenceNumbers []int64) ([]*amqp.Message, error) - PeekMessages(ctx context.Context, fromSequenceNumber int64, messageCount int32) ([]*amqp.Message, error) - - ScheduleMessages(ctx context.Context, enqueueTime time.Time, messages ...*amqp.Message) ([]int64, error) - CancelScheduled(ctx context.Context, seq ...int64) error - - RenewLocks(ctx context.Context, linkName string, lockTokens []amqp.UUID) ([]time.Time, error) - RenewSessionLock(ctx context.Context, sessionID string) (time.Time, error) - - GetSessionState(ctx context.Context, sessionID string) ([]byte, error) - SetSessionState(ctx context.Context, sessionID string, state []byte) error +type mgmtError struct { + Resp *RPCResponse + Message string } -func newMgmtClient(ctx context.Context, links AMQPLinks, ns NamespaceForMgmtClient) (MgmtClient, error) { - r := &mgmtClient{ - ns: ns, - links: links, - } - - return r, nil +func (me mgmtError) Error() string { + return me.Message } -// Recover will attempt to close the current session and link, then rebuild them -func (mc *mgmtClient) recover(ctx context.Context) error { - mc.clientMu.Lock() - defer mc.clientMu.Unlock() - - ctx, span := 
mc.startSpanFromContext(ctx, string(tracing.SpanNameRecover)) - defer span.End() - - if mc.rpcLink != nil { - if err := mc.rpcLink.Close(ctx); err != nil { - tab.For(ctx).Debug(fmt.Sprintf("Error while closing old link in recovery: %s", err.Error())) - } - mc.rpcLink = nil - } - - if _, err := mc.getLinkWithoutLock(ctx); err != nil { - return err - } - - return nil +func (me mgmtError) RPCCode() int { + return me.Resp.Code } -// getLinkWithoutLock returns the currently cached link (or creates a new one) -func (mc *mgmtClient) getLinkWithoutLock(ctx context.Context) (RPCLink, error) { - if mc.rpcLink != nil { - return mc.rpcLink, nil - } - - var err error - mc.rpcLink, err = mc.ns.NewRPCLink(ctx, mc.links.ManagementPath()) +// creates a new link and sends the RPC request, recovering and retrying on certain AMQP errors +func doRPC(ctx context.Context, name string, rpcLink RPCLink, msg *amqp.Message) (*RPCResponse, error) { + res, err := rpcLink.RPC(ctx, msg) if err != nil { return nil, err } - return mc.rpcLink, nil -} - -// Close will close the AMQP connection -func (mc *mgmtClient) Close(ctx context.Context) error { - mc.clientMu.Lock() - defer mc.clientMu.Unlock() - - if mc.rpcLink == nil { - return nil + if res.Code >= 200 && res.Code < 300 { + tab.For(ctx).Debug(fmt.Sprintf("rpc: success, status code %d and description: %s", res.Code, res.Description)) + return res, nil } - err := mc.rpcLink.Close(ctx) - mc.rpcLink = nil - return err -} - -// creates a new link and sends the RPC request, recovering and retrying on certain AMQP errors -func (mc *mgmtClient) doRPCWithRetry(ctx context.Context, msg *amqp.Message, times int, delay time.Duration, opts ...rpc.LinkOption) (*rpc.Response, error) { - // track the number of times we attempt to perform the RPC call. - // this is to avoid a potential infinite loop if the returned error - // is always transient and Recover() doesn't fail. 
- sendCount := 0 - - for { - mc.clientMu.Lock() - rpcLink, err := mc.getLinkWithoutLock(ctx) - mc.clientMu.Unlock() - - var rsp *rpc.Response - - if err == nil { - rsp, err = rpcLink.RetryableRPC(ctx, times, delay, msg) - - if err == nil { - return rsp, err - } - } - - if sendCount >= amqpRetryDefaultTimes || !isAMQPTransientError(ctx, err) { - return nil, err - } - sendCount++ - // if we get here, recover and try again - tab.For(ctx).Debug("recovering RPC connection") - - _, retryErr := common.Retry(amqpRetryDefaultTimes, amqpRetryDefaultDelay, func() (interface{}, error) { - ctx, sp := mc.startProducerSpanFromContext(ctx, string(tracing.SpanTryRecover)) - defer sp.End() - - if err := mc.recover(ctx); err == nil { - tab.For(ctx).Debug("recovered RPC connection") - return nil, nil - } - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - return nil, common.Retryable(err.Error()) - } - }) - - if retryErr != nil { - tab.For(ctx).Debug("RPC recovering retried, but error was unrecoverable") - return nil, retryErr - } + return nil, mgmtError{ + Message: fmt.Sprintf("rpc: failed, status code %d and description: %s", res.Code, res.Description), + Resp: res, } } -// returns true if the AMQP error is considered transient -func isAMQPTransientError(ctx context.Context, err error) bool { - // always retry on a detach error - var amqpDetach *amqp.DetachError - if errors.As(err, &amqpDetach) { - return true - } - // for an AMQP error, only retry depending on the condition - var amqpErr *amqp.Error - if errors.As(err, &amqpErr) { - switch amqpErr.Condition { - case errorServerBusy, errorTimeout, errorOperationCancelled, errorContainerClose: - return true - default: - tab.For(ctx).Debug(fmt.Sprintf("isAMQPTransientError: condition %s is not transient", amqpErr.Condition)) - return false - } - } - tab.For(ctx).Debug(fmt.Sprintf("isAMQPTransientError: %T is not transient", err)) - return false -} - -func (mc *mgmtClient) ReceiveDeferred(ctx context.Context, mode 
ReceiveMode, sequenceNumbers []int64) ([]*amqp.Message, error) { +func ReceiveDeferred(ctx context.Context, rpcLink RPCLink, mode ReceiveMode, sequenceNumbers []int64) ([]*amqp.Message, error) { ctx, span := tracing.StartConsumerSpanFromContext(ctx, tracing.SpanReceiveDeferred, Version) defer span.End() @@ -214,12 +79,6 @@ func (mc *mgmtClient) ReceiveDeferred(ctx context.Context, mode ReceiveMode, seq "receiver-settle-mode": uint32(backwardsMode), // pick up messages with peek lock } - var opts []rpc.LinkOption - if mc.isSessionFilterSet { - opts = append(opts, rpc.LinkWithSessionFilter(mc.sessionID)) - values["session-id"] = mc.sessionID - } - msg := &amqp.Message{ ApplicationProperties: map[string]interface{}{ "operation": "com.microsoft:receive-by-sequence-number", @@ -227,9 +86,8 @@ func (mc *mgmtClient) ReceiveDeferred(ctx context.Context, mode ReceiveMode, seq Value: values, } - rsp, err := mc.doRPCWithRetry(ctx, msg, 5, 5*time.Second, opts...) + rsp, err := doRPC(ctx, "receiveDeferred", rpcLink, msg) if err != nil { - tab.For(ctx).Error(err) return nil, err } @@ -286,7 +144,7 @@ func (mc *mgmtClient) ReceiveDeferred(ctx context.Context, mode ReceiveMode, seq return transformedMessages, nil } -func (mc *mgmtClient) PeekMessages(ctx context.Context, fromSequenceNumber int64, messageCount int32) ([]*amqp.Message, error) { +func PeekMessages(ctx context.Context, rpcLink RPCLink, fromSequenceNumber int64, messageCount int32) ([]*amqp.Message, error) { ctx, span := tracing.StartConsumerSpanFromContext(ctx, tracing.SpanPeekFromSequenceNumber, Version) defer span.End() @@ -306,7 +164,7 @@ func (mc *mgmtClient) PeekMessages(ctx context.Context, fromSequenceNumber int64 msg.ApplicationProperties["server-timeout"] = uint(time.Until(deadline) / time.Millisecond) } - rsp, err := mc.doRPCWithRetry(ctx, msg, 5, 5*time.Second) + rsp, err := doRPC(ctx, "peek", rpcLink, msg) if err != nil { tab.For(ctx).Error(err) return nil, err @@ -397,7 +255,7 @@ func (mc *mgmtClient) 
PeekMessages(ctx context.Context, fromSequenceNumber int64 // RenewLocks renews the locks in a single 'com.microsoft:renew-lock' operation. // NOTE: this function assumes all the messages received on the same link. -func (mc *mgmtClient) RenewLocks(ctx context.Context, linkName string, lockTokens []amqp.UUID) ([]time.Time, error) { +func RenewLocks(ctx context.Context, rpcLink RPCLink, linkName string, lockTokens []amqp.UUID) ([]time.Time, error) { ctx, span := tracing.StartConsumerSpanFromContext(ctx, tracing.SpanRenewLock, Version) defer span.End() @@ -414,7 +272,7 @@ func (mc *mgmtClient) RenewLocks(ctx context.Context, linkName string, lockToken renewRequestMsg.ApplicationProperties["associated-link-name"] = linkName } - response, err := mc.doRPCWithRetry(ctx, renewRequestMsg, 3, 1*time.Second) + response, err := doRPC(ctx, "renewlocks", rpcLink, renewRequestMsg) if err != nil { tab.For(ctx).Error(err) @@ -451,7 +309,7 @@ func (mc *mgmtClient) RenewLocks(ctx context.Context, linkName string, lockToken } // RenewSessionLocks renews a session lock. -func (mc *mgmtClient) RenewSessionLock(ctx context.Context, sessionID string) (time.Time, error) { +func RenewSessionLock(ctx context.Context, rpcLink RPCLink, sessionID string) (time.Time, error) { body := map[string]interface{}{ "session-id": sessionID, } @@ -463,7 +321,7 @@ func (mc *mgmtClient) RenewSessionLock(ctx context.Context, sessionID string) (t }, } - resp, err := mc.doRPCWithRetry(ctx, msg, 5, 5*time.Second) + resp, err := doRPC(ctx, "renewsessionlock", rpcLink, msg) if err != nil { return time.Time{}, err @@ -485,7 +343,7 @@ func (mc *mgmtClient) RenewSessionLock(ctx context.Context, sessionID string) (t } // GetSessionState retrieves state associated with the session. 
-func (mc *mgmtClient) GetSessionState(ctx context.Context, sessionID string) ([]byte, error) { +func GetSessionState(ctx context.Context, rpcLink RPCLink, sessionID string) ([]byte, error) { amqpMsg := &amqp.Message{ Value: map[string]interface{}{ "session-id": sessionID, @@ -495,7 +353,7 @@ func (mc *mgmtClient) GetSessionState(ctx context.Context, sessionID string) ([] }, } - resp, err := mc.doRPCWithRetry(ctx, amqpMsg, 5, 5*time.Second) + resp, err := doRPC(ctx, "getsessionstate", rpcLink, amqpMsg) if err != nil { return nil, err @@ -528,7 +386,7 @@ func (mc *mgmtClient) GetSessionState(ctx context.Context, sessionID string) ([] } // SetSessionState sets the state associated with the session. -func (mc *mgmtClient) SetSessionState(ctx context.Context, sessionID string, state []byte) error { +func SetSessionState(ctx context.Context, rpcLink RPCLink, sessionID string, state []byte) error { uuid, err := uuid.New() if err != nil { @@ -546,7 +404,7 @@ func (mc *mgmtClient) SetSessionState(ctx context.Context, sessionID string, sta }, } - resp, err := mc.doRPCWithRetry(ctx, amqpMsg, 5, 5*time.Second) + resp, err := doRPC(ctx, "setsessionstate", rpcLink, amqpMsg) if err != nil { return err @@ -562,7 +420,7 @@ func (mc *mgmtClient) SetSessionState(ctx context.Context, sessionID string, sta // SendDisposition allows you settle a message using the management link, rather than via your // *amqp.Receiver. Use this if the receiver has been closed/lost or if the message isn't associated // with a link (ex: deferred messages). 
-func (mc *mgmtClient) SendDisposition(ctx context.Context, lockToken *amqp.UUID, state Disposition, propertiesToModify map[string]interface{}) error { +func SendDisposition(ctx context.Context, rpcLink RPCLink, lockToken *amqp.UUID, state Disposition, propertiesToModify map[string]interface{}) error { ctx, span := tracing.StartConsumerSpanFromContext(ctx, tracing.SpanSendDisposition, Version) defer span.End() @@ -572,7 +430,6 @@ func (mc *mgmtClient) SendDisposition(ctx context.Context, lockToken *amqp.UUID, return err } - var opts []rpc.LinkOption value := map[string]interface{}{ "disposition-status": string(state.Status), "lock-tokens": []amqp.UUID{*lockToken}, @@ -598,7 +455,7 @@ func (mc *mgmtClient) SendDisposition(ctx context.Context, lockToken *amqp.UUID, } // no error, then it was successful - _, err := mc.doRPCWithRetry(ctx, msg, 5, 5*time.Second, opts...) + _, err := doRPC(ctx, "senddisposition", rpcLink, msg) if err != nil { tab.For(ctx).Error(err) return err @@ -609,7 +466,7 @@ func (mc *mgmtClient) SendDisposition(ctx context.Context, lockToken *amqp.UUID, // ScheduleMessages will send a batch of messages to a Queue, schedule them to be enqueued, and return the sequence numbers // that can be used to cancel each message. 
-func (mc *mgmtClient) ScheduleMessages(ctx context.Context, enqueueTime time.Time, messages ...*amqp.Message) ([]int64, error) { +func ScheduleMessages(ctx context.Context, rpcLink RPCLink, enqueueTime time.Time, messages []*amqp.Message) ([]int64, error) { ctx, span := tracing.StartConsumerSpanFromContext(ctx, tracing.SpanScheduleMessage, Version) defer span.End() @@ -677,7 +534,7 @@ func (mc *mgmtClient) ScheduleMessages(ctx context.Context, enqueueTime time.Tim msg.ApplicationProperties["com.microsoft:server-timeout"] = uint(time.Until(deadline) / time.Millisecond) } - resp, err := mc.doRPCWithRetry(ctx, msg, 5, 5*time.Second) + resp, err := doRPC(ctx, "schedule", rpcLink, msg) if err != nil { tab.For(ctx).Error(err) return nil, err @@ -704,9 +561,9 @@ func (mc *mgmtClient) ScheduleMessages(ctx context.Context, enqueueTime time.Tim return nil, NewErrIncorrectType("value", map[string]interface{}{}, resp.Message.Value) } -// CancelScheduled allows for removal of messages that have been handed to the Service Bus broker for later delivery, +// CancelScheduledMessages allows for removal of messages that have been handed to the Service Bus broker for later delivery, // but have not yet ben enqueued. 
-func (mc *mgmtClient) CancelScheduled(ctx context.Context, seq ...int64) error { +func CancelScheduledMessages(ctx context.Context, rpcLink RPCLink, seq []int64) error { ctx, span := tracing.StartConsumerSpanFromContext(ctx, tracing.SpanCancelScheduledMessage, Version) defer span.End() @@ -723,7 +580,7 @@ func (mc *mgmtClient) CancelScheduled(ctx context.Context, seq ...int64) error { msg.ApplicationProperties["com.microsoft:server-timeout"] = uint(time.Until(deadline) / time.Millisecond) } - resp, err := mc.doRPCWithRetry(ctx, msg, 5, 5*time.Second) + resp, err := doRPC(ctx, "cancelscheduled", rpcLink, msg) if err != nil { tab.For(ctx).Error(err) return err @@ -735,19 +592,3 @@ func (mc *mgmtClient) CancelScheduled(ctx context.Context, seq ...int64) error { return nil } - -func (mc *mgmtClient) startSpanFromContext(ctx context.Context, operationName string) (context.Context, tab.Spanner) { - ctx, span := tracing.StartConsumerSpanFromContext(ctx, operationName, Version) - span.AddAttributes(tab.StringAttribute("message_bus.destination", mc.links.ManagementPath())) - return ctx, span -} - -func (mc *mgmtClient) startProducerSpanFromContext(ctx context.Context, operationName string) (context.Context, tab.Spanner) { - ctx, span := tab.StartSpan(ctx, operationName) - tracing.ApplyComponentInfo(span, Version) - span.AddAttributes( - tab.StringAttribute("span.kind", "producer"), - tab.StringAttribute("message_bus.destination", mc.links.ManagementPath()), - ) - return ctx, span -} diff --git a/sdk/messaging/azservicebus/internal/namespace.go b/sdk/messaging/azservicebus/internal/namespace.go index 806283596ee1..214d2d49963b 100644 --- a/sdk/messaging/azservicebus/internal/namespace.go +++ b/sdk/messaging/azservicebus/internal/namespace.go @@ -13,12 +13,13 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/sbauth" 
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/auth" "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/cbs" "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/conn" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/rpc" "github.com/Azure/go-amqp" "github.com/devigned/tab" ) @@ -31,6 +32,13 @@ type ( // Namespace is an abstraction over an amqp.Client, allowing us to hold onto a single // instance of a connection per ServiceBusClient. Namespace struct { + // NOTE: values need to be 64-bit aligned. Simplest way to make sure this happens + // is just to make it the first value in the struct + // See: + // Godoc: https://pkg.go.dev/sync/atomic#pkg-note-BUG + // PR: https://github.com/Azure/azure-sdk-for-go/pull/16847 + connID uint64 + FQDN string TokenProvider *sbauth.TokenProvider tlsConfig *tls.Config @@ -38,12 +46,10 @@ type ( newWebSocketConn func(ctx context.Context, args NewWebSocketConnArgs) (net.Conn, error) - baseRetrier Retrier - - clientMu sync.Mutex - clientRevision uint64 - client *amqp.Client + retryOptions utils.RetryOptions + clientMu sync.RWMutex + client *amqp.Client negotiateClaimMu sync.Mutex } @@ -60,14 +66,9 @@ type NamespaceWithNewAMQPLinks interface { type NamespaceForAMQPLinks interface { NegotiateClaim(ctx context.Context, entityPath string) (func() <-chan struct{}, error) NewAMQPSession(ctx context.Context) (AMQPSessionCloser, uint64, error) - NewMgmtClient(ctx context.Context, links AMQPLinks) (MgmtClient, error) + NewRPCLink(ctx context.Context, managementPath string) (RPCLink, error) GetEntityAudience(entityPath string) string - Recover(ctx context.Context, clientRevision uint64) error -} - -// NamespaceForAMQPLinks is the Namespace surface needed for the *MgmtClient. 
-type NamespaceForMgmtClient interface { - NewRPCLink(ctx context.Context, managementPath string) (*rpc.Link, error) + Recover(ctx context.Context, clientRevision uint64) (bool, error) } // NamespaceWithConnectionString configures a namespace with the information provided in a Service Bus connection string @@ -125,9 +126,9 @@ func NamespaceWithWebSocket(newWebSocketConn func(ctx context.Context, args NewW } } -// NamespacesWithTokenCredential sets the token provider on the namespace +// NamespaceWithTokenCredential sets the token provider on the namespace // fullyQualifiedNamespace is the Service Bus namespace name (ex: myservicebus.servicebus.windows.net) -func NamespacesWithTokenCredential(fullyQualifiedNamespace string, tokenCredential azcore.TokenCredential) NamespaceOption { +func NamespaceWithTokenCredential(fullyQualifiedNamespace string, tokenCredential azcore.TokenCredential) NamespaceOption { return func(ns *Namespace) error { ns.TokenProvider = sbauth.NewTokenProvider(tokenCredential) ns.FQDN = fullyQualifiedNamespace @@ -135,22 +136,16 @@ func NamespacesWithTokenCredential(fullyQualifiedNamespace string, tokenCredenti } } +func NamespaceWithRetryOptions(retryOptions utils.RetryOptions) NamespaceOption { + return func(ns *Namespace) error { + ns.retryOptions = retryOptions + return nil + } +} + // NewNamespace creates a new namespace configured through NamespaceOption(s) func NewNamespace(opts ...NamespaceOption) (*Namespace, error) { - ns := &Namespace{ - baseRetrier: NewBackoffRetrier(struct { - MaxRetries int - Factor float64 - Jitter bool - Min time.Duration - Max time.Duration - }{ - Factor: 2, - Min: time.Second, - Max: time.Minute, - MaxRetries: 10, - }), - } + ns := &Namespace{} for _, opt := range opts { err := opt(ns) @@ -199,8 +194,9 @@ func (ns *Namespace) newClient(ctx context.Context) (*amqp.Client, error) { } // NewAMQPSession creates a new AMQP session with the internally cached *amqp.Client. 
+// Returns a closeable AMQP session and the current client revision. func (ns *Namespace) NewAMQPSession(ctx context.Context) (AMQPSessionCloser, uint64, error) { - client, clientRevision, err := ns.getAMQPClientImpl(ctx) + client, clientRevision, err := ns.GetAMQPClientImpl(ctx) if err != nil { return nil, 0, err @@ -215,26 +211,21 @@ func (ns *Namespace) NewAMQPSession(ctx context.Context) (AMQPSessionCloser, uin return session, clientRevision, err } -// NewMgmtClient creates a new management client with the internally cached *amqp.Client. -func (ns *Namespace) NewMgmtClient(ctx context.Context, l AMQPLinks) (MgmtClient, error) { - return newMgmtClient(ctx, l, ns) -} - // NewRPCLink creates a new amqp-common *rpc.Link with the internally cached *amqp.Client. -func (ns *Namespace) NewRPCLink(ctx context.Context, managementPath string) (*rpc.Link, error) { - client, _, err := ns.getAMQPClientImpl(ctx) +func (ns *Namespace) NewRPCLink(ctx context.Context, managementPath string) (RPCLink, error) { + client, _, err := ns.GetAMQPClientImpl(ctx) if err != nil { return nil, err } - return rpc.NewLink(client, managementPath) + return NewRPCLink(client, managementPath) } // NewAMQPLinks creates an AMQPLinks struct, which groups together the commonly needed links for // working with Service Bus. func (ns *Namespace) NewAMQPLinks(entityPath string, createLinkFunc CreateLinkFunc) AMQPLinks { - return newAMQPLinks(ns, entityPath, ns.baseRetrier, createLinkFunc) + return NewAMQPLinks(ns, entityPath, createLinkFunc) } // Close closes the current cached client. @@ -249,9 +240,11 @@ func (ns *Namespace) Close(ctx context.Context) error { return nil } -// Recover destroys the currently held client and recreates it. -// clientRevision being nil will recover without a revision check. -func (ns *Namespace) Recover(ctx context.Context, clientRevision uint64) error { +// Recover destroys the currently held AMQP connection and recreates it, if needed. 
+// If a new connection is actually created (rather than just cached) then the returned bool +// will be true. Any links that were created from the original connection will need to +// be recreated. +func (ns *Namespace) Recover(ctx context.Context, theirConnID uint64) (bool, error) { ns.clientMu.Lock() defer ns.clientMu.Unlock() @@ -259,13 +252,13 @@ func (ns *Namespace) Recover(ctx context.Context, clientRevision uint64) error { defer span.End() span.AddAttributes( - tab.Int64Attribute("revision", int64(ns.clientRevision)), - tab.Int64Attribute("requested", int64(clientRevision))) + tab.Int64Attribute("connID", int64(ns.connID)), + tab.Int64Attribute("theirConnID", int64(theirConnID))) - if ns.clientRevision > clientRevision { - span.Logger().Info(fmt.Sprintf("Skipping recovery, already recovered: %d vs %d", ns.clientRevision, clientRevision)) + if ns.connID != theirConnID { + log.Writef(EventConn, "Skipping connection recovery, already recovered: %d vs %d", ns.connID, theirConnID) // we've already recovered since the client last tried.
- return nil + return false, nil } if ns.client != nil { @@ -273,23 +266,20 @@ func (ns *Namespace) Recover(ctx context.Context, clientRevision uint64) error { ns.client = nil // the error on close isn't critical - go func() { - span.Logger().Info(fmt.Sprintf("Closing old client (client:%d,passed in:%d)", ns.clientRevision, clientRevision)) - err := oldClient.Close() - tab.For(ctx).Error(err) - }() + _ = oldClient.Close() } var err error - span.Logger().Info(fmt.Sprintf("Creating a new client (client:%d,passed in:%d)", ns.clientRevision, clientRevision)) + log.Writef(EventConn, "Creating a new client (rev:%d)", ns.connID) ns.client, err = ns.newClient(ctx) - if err == nil { - span.AddAttributes(tab.Int64Attribute("newcr", int64(ns.clientRevision))) - ns.clientRevision++ + if err != nil { + return false, err } - return err + ns.connID++ + log.Writef(EventConn, "New client created, (rev: %d)", ns.connID) + return true, nil } // negotiateClaim performs initial authentication and starts periodic refresh of credentials. 
@@ -298,7 +288,7 @@ func (ns *Namespace) NegotiateClaim(ctx context.Context, entityPath string) (fun return ns.startNegotiateClaimRenewer(ctx, entityPath, cbs.NegotiateClaim, - ns.getAMQPClientImpl, + ns.GetAMQPClientImpl, nextClaimRefreshDuration) } @@ -309,61 +299,52 @@ func (ns *Namespace) startNegotiateClaimRenewer(ctx context.Context, nextClaimRefreshDurationFn func(expirationTime time.Time, currentTime time.Time) time.Duration) (func() <-chan struct{}, error) { audience := ns.GetEntityAudience(entityPath) - refreshClaim := func() (time.Time, error) { - retrier := ns.baseRetrier.Copy() + refreshClaim := func(ctx context.Context) (time.Time, error) { + log.Writef(EventAuth, "(%s) refreshing claim", entityPath) + ctx, span := ns.startSpanFromContext(ctx, tracing.SpanNegotiateClaim) + defer span.End() - var lastErr error - var expiration time.Time + amqpClient, clientRevision, err := nsGetAMQPClientImpl(ctx) - for retrier.Try(ctx) { - expiration, lastErr = func() (time.Time, error) { - ctx, span := ns.startSpanFromContext(ctx, tracing.SpanNegotiateClaim) - defer span.End() + if err != nil { + return time.Time{}, err + } - amqpClient, clientRevision, err := nsGetAMQPClientImpl(ctx) + token, expiration, err := ns.TokenProvider.GetTokenAsTokenProvider(audience) - if err != nil { - span.Logger().Error(err) - return time.Time{}, err - } + if err != nil { + log.Writef(EventAuth, "(%s) negotiate claim, failed getting token: %s", entityPath, err.Error()) + return time.Time{}, err + } - token, expiration, err := ns.TokenProvider.GetTokenAsTokenProvider(audience) + log.Writef(EventAuth, "(%s) negotiate claim, token expires on %s", entityPath, expiration.Format(time.RFC3339)) - if err != nil { - span.Logger().Error(err) - return time.Time{}, err - } + // You're not allowed to have multiple $cbs links open in a single connection. + // The current cbs.NegotiateClaim implementation automatically creates and shuts + // down it's own link so we have to guard against that here. 
+ ns.negotiateClaimMu.Lock() + err = cbsNegotiateClaim(ctx, audience, amqpClient, token) + ns.negotiateClaimMu.Unlock() - // You're not allowed to have multiple $cbs links open in a single connection. - // The current cbs.NegotiateClaim implementation automatically creates and shuts - // down it's own link so we have to guard against that here. - ns.negotiateClaimMu.Lock() - err = cbsNegotiateClaim(ctx, audience, amqpClient, token) - ns.negotiateClaimMu.Unlock() - - if err != nil { - if shouldRecreateConnection(ctx, err) { - if err := ns.Recover(ctx, clientRevision); err != nil { - span.Logger().Error(fmt.Errorf("connection recovery failed: %w", err)) - } - } + sbe := GetSBErrInfo(err) - span.Logger().Error(err) - return time.Time{}, err + if sbe != nil { + // Note we only handle connection recovery here since (currently) + // the negotiateClaim code creates it's own link each time. + if sbe.RecoveryKind == RecoveryKindConn { + if _, err := ns.Recover(ctx, clientRevision); err != nil { + log.Writef(EventAuth, "(%s) negotiate claim, failed in connection recovery: %s", entityPath, err) } - - return expiration, nil - }() - - if lastErr == nil { - break } + + log.Writef(EventAuth, "(%s) negotiate claim, failed: %s", entityPath, err.Error()) + return time.Time{}, err } - return expiration, lastErr + return expiration, nil } - expiresOn, err := refreshClaim() + expiresOn, err := refreshClaim(ctx) if err != nil { return nil, err @@ -373,15 +354,38 @@ func (ns *Namespace) startNegotiateClaimRenewer(ctx context.Context, refreshCtx, cancel := context.WithCancel(context.Background()) go func() { + TokenRefreshLoop: for { + nextClaimAt := nextClaimRefreshDurationFn(expiresOn, time.Now()) + + log.Writef(EventAuth, "(%s) next refresh in %s", entityPath, nextClaimAt) + select { case <-refreshCtx.Done(): return - case <-time.After(nextClaimRefreshDurationFn(expiresOn, time.Now())): - tmpExpiresOn, err := refreshClaim() // logging will report the error for now + case 
<-time.After(nextClaimAt): + for { + err := utils.Retry(refreshCtx, "claimrefresh", func(ctx context.Context, args *utils.RetryFnArgs) error { + tmpExpiresOn, err := refreshClaim(ctx) - if err == nil { - expiresOn = tmpExpiresOn + if err != nil { + return err + } + + expiresOn = tmpExpiresOn + return nil + }, IsFatalSBError, ns.retryOptions) + + if err == nil { + break + } + + // if we fail our retries _and_ we've exceeded the window where our token would have + // been good we can just stop. + if time.Since(expiresOn) <= 0 { + log.Writef(EventAuth, "[%s] token has expired, stopping refresh loop", entityPath) + break TokenRefreshLoop + } } } } @@ -395,26 +399,22 @@ func (ns *Namespace) startNegotiateClaimRenewer(ctx context.Context, return cancelRefresh, nil } -func (ns *Namespace) getAMQPClientImpl(ctx context.Context) (*amqp.Client, uint64, error) { +func (ns *Namespace) GetAMQPClientImpl(ctx context.Context) (*amqp.Client, uint64, error) { ns.clientMu.Lock() defer ns.clientMu.Unlock() if ns.client != nil { - return ns.client, ns.clientRevision, nil + return ns.client, ns.connID, nil } var err error - retrier := ns.baseRetrier.Copy() - - for retrier.Try(ctx) { - ns.client, err = ns.newClient(ctx) + ns.client, err = ns.newClient(ctx) - if err == nil { - break - } + if err == nil { + ns.connID++ } - return ns.client, ns.clientRevision, err + return ns.client, ns.connID, err } func (ns *Namespace) getWSSHostURI() string { diff --git a/sdk/messaging/azservicebus/internal/namespace_test.go b/sdk/messaging/azservicebus/internal/namespace_test.go index dcdcf8449b78..b7bb69bd738b 100644 --- a/sdk/messaging/azservicebus/internal/namespace_test.go +++ b/sdk/messaging/azservicebus/internal/namespace_test.go @@ -17,39 +17,6 @@ import ( "github.com/stretchr/testify/require" ) -// implements `Retrier` interface. 
-type fakeRetrier struct { - tryCalled int - copyCalled int -} - -func (r *fakeRetrier) Copy() Retrier { - r.copyCalled++ - - // NOTE: purposefully not making a copy so I can keep track of the - // try/copy counts. - return r -} - -func (r *fakeRetrier) Exhausted() bool { - return false -} - -func (r *fakeRetrier) Try(ctx context.Context) bool { - select { - case <-ctx.Done(): - return false - default: - } - - r.tryCalled++ - return true -} - -func (r *fakeRetrier) CurrentTry() int { - return r.tryCalled -} - type fakeTokenCredential struct { azcore.TokenCredential expires time.Time @@ -62,12 +29,10 @@ func (ftc *fakeTokenCredential) GetToken(ctx context.Context, options policy.Tok } func TestNamespaceNegotiateClaim(t *testing.T) { - retrier := &fakeRetrier{} - expires := time.Now().Add(24 * time.Hour) ns := &Namespace{ - baseRetrier: retrier, + retryOptions: retryOptionsOnlyOnce, TokenProvider: sbauth.NewTokenProvider(&fakeTokenCredential{expires: expires}), } @@ -106,17 +71,13 @@ func TestNamespaceNegotiateClaim(t *testing.T) { require.EqualValues(t, getAMQPClientCalled, 1) require.EqualValues(t, 1, cbsNegotiateClaimCalled) - require.EqualValues(t, 1, retrier.copyCalled) - require.EqualValues(t, 1, retrier.tryCalled) } func TestNamespaceNegotiateClaimRenewal(t *testing.T) { - retrier := &fakeRetrier{} - expires := time.Now().Add(24 * time.Hour) ns := &Namespace{ - baseRetrier: retrier, + retryOptions: retryOptionsOnlyOnce, TokenProvider: sbauth.NewTokenProvider(&fakeTokenCredential{expires: expires}), } @@ -163,8 +124,6 @@ func TestNamespaceNegotiateClaimRenewal(t *testing.T) { require.GreaterOrEqual(t, getAMQPClientCalled, 2+1) // that last +1 is when we blocked to prevent us renewing too much for our test! 
- require.EqualValues(t, 3, retrier.copyCalled) - require.EqualValues(t, 3, retrier.tryCalled) require.EqualValues(t, 2, nextRefreshDurationChecks) require.EqualValues(t, 2, cbsNegotiateClaimCalled) @@ -175,7 +134,6 @@ func TestNamespaceNegotiateClaimRenewal(t *testing.T) { func TestNamespaceNegotiateClaimFailsToGetClient(t *testing.T) { ns := &Namespace{ - baseRetrier: noRetryRetrier.Copy(), TokenProvider: sbauth.NewTokenProvider(&fakeTokenCredential{expires: time.Now()}), } @@ -197,7 +155,6 @@ func TestNamespaceNegotiateClaimFailsToGetClient(t *testing.T) { func TestNamespaceNegotiateClaimFails(t *testing.T) { ns := &Namespace{ - baseRetrier: noRetryRetrier.Copy(), TokenProvider: sbauth.NewTokenProvider(&fakeTokenCredential{expires: time.Now()}), } @@ -232,13 +189,3 @@ func TestNamespaceNextClaimRefreshDuration(t *testing.T) { require.EqualValues(t, 3*time.Minute, nextClaimRefreshDuration(now.Add(3*time.Minute+clockDrift), now)) } - -var noRetryRetrier = NewBackoffRetrier(struct { - MaxRetries int - Factor float64 - Jitter bool - Min time.Duration - Max time.Duration -}{ - MaxRetries: 0, -}) diff --git a/sdk/messaging/azservicebus/internal/retrier.go b/sdk/messaging/azservicebus/internal/retrier.go deleted file mode 100644 index 700abfa3ef61..000000000000 --- a/sdk/messaging/azservicebus/internal/retrier.go +++ /dev/null @@ -1,268 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package internal - -import ( - "context" - "math" - "math/rand" - "time" - - "github.com/Azure/azure-sdk-for-go/sdk/internal/log" - "github.com/jpillora/backoff" -) - -// A retrier that allows you to do a basic for loop and get backoff -// and retry limits. See `Try` for more details on how to use it. -type Retrier interface { - // Copies the retrier. Retriers are stateful and must be copied - // before starting a set of retries. - Copy() Retrier - - // Exhausted is true if the retries were exhausted. 
- Exhausted() bool - - // CurrentTry is the current try (0 for the first run before retries) - CurrentTry() int - - // Try marks an attempt to call (first call to Try() does not sleep). - // Will return false if the `ctx` is cancelled or if we exhaust our retries. - // - // rp := RetryPolicy{Backoff:defaultBackoffPolicy, MaxRetries:5} - // - // for rp.Try(ctx) { - // - // } - // - // if rp.Cancelled() || rp.Exhausted() { - // // no more retries needed - // } - // - Try(ctx context.Context) bool -} - -// Encapsulates a backoff policy, which allows you to configure the amount of -// time in between retries as well as the maximum retries allowed (via MaxRetries) -// NOTE: this should be copied by the caller as it is stateful. -type backoffRetrier struct { - backoff backoff.Backoff - MaxRetries int - - tries int -} - -// BackoffRetrierParams are parameters for NewBackoffRetrier. -type BackoffRetrierParams struct { - // MaxRetries is the maximum number of tries (after the first attempt) - // that are allowed. - MaxRetries int - // Factor is the multiplying factor for each increment step - Factor float64 - // Jitter eases contention by randomizing backoff steps - Jitter bool - // Min and Max are the minimum and maximum values of the counter - Min, Max time.Duration -} - -// NewBackoffRetrier creates a retrier that allows for configurable -// min/max times, jitter and maximum retries. -func NewBackoffRetrier(params BackoffRetrierParams) Retrier { - return &backoffRetrier{ - backoff: backoff.Backoff{ - Factor: params.Factor, - Jitter: params.Jitter, - Min: params.Min, - Max: params.Max, - }, - MaxRetries: params.MaxRetries, - } -} - -// Copies the backoff retrier since it's stateful. -func (rp *backoffRetrier) Copy() Retrier { - copy := *rp - return &copy -} - -// Exhausted is true if all the retries have been used.
-func (rp *backoffRetrier) Exhausted() bool { - return rp.tries > rp.MaxRetries -} - -// CurrentTry is the current try number (0 for the first run before retries) -func (rp *backoffRetrier) CurrentTry() int { - return rp.tries -} - -// Try marks an attempt to call (first call to Try() does not sleep). -// Will return false if the `ctx` is cancelled or if we exhaust our retries. -// -// rp := RetryPolicy{Backoff:defaultBackoffPolicy, MaxRetries:5} -// -// for rp.Try(ctx) { -// -// } -// -// if rp.Cancelled() || rp.Exhausted() { -// // no more retries needed -// } -// -func (rp *backoffRetrier) Try(ctx context.Context) bool { - defer func() { rp.tries++ }() - - select { - case <-ctx.Done(): - return false - default: - } - - if rp.tries == 0 { - // first 'try' is always free - return true - } - - if rp.Exhausted() { - return false - } - - select { - case <-time.After(rp.backoff.Duration()): - return true - case <-ctx.Done(): - return false - } -} - -type RetryFnArgs struct { - I int32 - // LastErr is the returned error from the previous loop. - // If you have potentially expensive - LastErr error - - resetAttempts bool -} - -// ResetAttempts causes the current retry attempt number to be reset -// in time for the next recovery (should we fail). -// NOTE: Use of this should be pretty rare, it's really only needed when you have -// a situation like Receiver.ReceiveMessages() that can recovery but intentionally -// does not return. -func (rf *RetryFnArgs) ResetAttempts() { - rf.resetAttempts = true -} - -// Retry runs a standard retry loop. It executes your passed in fn as the body of the loop. -// 'isFatal' can be nil, and defaults to just checking that ServiceBusError(err).recoveryKind != recoveryKindNonRetriable. -// It returns if it exceeds the number of configured retry options or if 'isFatal' returns true. 
-func Retry(ctx context.Context, name string, fn func(ctx context.Context, args *RetryFnArgs) error, isFatalFn func(err error) bool, o RetryOptions) error { - var ro RetryOptions = o - setDefaults(&ro) - - var err error - - for i := int32(0); i <= ro.MaxRetries; i++ { - if i > 0 { - sleep := calcDelay(ro, i) - log.Writef(EventRetry, "(%s) Attempt %d sleeping for %s", name, i, sleep) - time.Sleep(sleep) - } - - args := RetryFnArgs{ - I: i, - LastErr: err, - } - err = fn(ctx, &args) - - if args.resetAttempts { - log.Writef(EventRetry, "(%s) Resetting attempts", name) - i = int32(0) - } - - if err != nil { - if isFatalFn != nil { - if isFatalFn(err) { - log.Writef(EventRetry, "(%s) Attempt %d returned non-retryable error: %s", name, i, err.Error()) - return err - } else { - log.Writef(EventRetry, "(%s) Attempt %d returned retryable error: %s", name, i, err.Error()) - } - } else { - recoveryKind := ToSBE(ctx, err).RecoveryKind - if recoveryKind == RecoveryKindFatal { - log.Writef(EventRetry, "(%s) Attempt %d returned non-retryable error: %s", name, i, err.Error()) - return err - } else { - log.Writef(EventRetry, "(%s) Attempt %d returned retryable error with recovery kind %s: %s", name, i, recoveryKind, err.Error()) - } - } - - continue - } - - return nil - } - - return err -} - -// RetryOptions represent the options for retries. -type RetryOptions struct { - // MaxRetries specifies the maximum number of attempts a failed operation will be retried - // before producing an error. - // The default value is three. A value less than zero means one try and no retries. - MaxRetries int32 - - // RetryDelay specifies the initial amount of delay to use before retrying an operation. - // The delay increases exponentially with each retry up to the maximum specified by MaxRetryDelay. - // The default value is four seconds. A value less than zero means no delay between retries. 
- RetryDelay time.Duration - - // MaxRetryDelay specifies the maximum delay allowed before retrying an operation. - // Typically the value is greater than or equal to the value specified in RetryDelay. - // The default Value is 120 seconds. A value less than zero means there is no cap. - MaxRetryDelay time.Duration -} - -func setDefaults(o *RetryOptions) { - if o.MaxRetries == 0 { - o.MaxRetries = 3 - } else if o.MaxRetries < 0 { - o.MaxRetries = 0 - } - if o.MaxRetryDelay == 0 { - o.MaxRetryDelay = 120 * time.Second - } else if o.MaxRetryDelay < 0 { - // not really an unlimited cap, but sufficiently large enough to be considered as such - o.MaxRetryDelay = math.MaxInt64 - } - if o.RetryDelay == 0 { - o.RetryDelay = 4 * time.Second - } else if o.RetryDelay < 0 { - o.RetryDelay = 0 - } -} - -// (adapted from from azcore/policy_retry) -func calcDelay(o RetryOptions, try int32) time.Duration { - if try == 0 { - return 0 - } - - pow := func(number int64, exponent int32) int64 { // pow is nested helper function - var result int64 = 1 - for n := int32(0); n < exponent; n++ { - result *= number - } - return result - } - - delay := time.Duration(pow(2, try)-1) * o.RetryDelay - - // Introduce some jitter: [0.0, 1.0) / 2 = [0.0, 0.5) + 0.8 = [0.8, 1.3) - delay = time.Duration(delay.Seconds() * (rand.Float64()/2 + 0.8) * float64(time.Second)) // NOTE: We want math/rand; not crypto/rand - if delay > o.MaxRetryDelay { - delay = o.MaxRetryDelay - } - return delay -} diff --git a/sdk/messaging/azservicebus/internal/retrier_test.go b/sdk/messaging/azservicebus/internal/retrier_test.go deleted file mode 100644 index 4cfdc52818c7..000000000000 --- a/sdk/messaging/azservicebus/internal/retrier_test.go +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -package internal - -import ( - "context" - "errors" - "math" - "testing" - "time" - - "github.com/Azure/go-amqp" - "github.com/stretchr/testify/require" -) - -func TestRetrier(t *testing.T) { - t.Run("Basic", func(t *testing.T) { - retrier := NewBackoffRetrier(struct { - MaxRetries int - Factor float64 - Jitter bool - Min time.Duration - Max time.Duration - }{ - MaxRetries: 5, - Factor: 0, - }) - - require := require.New(t) - - // first iteration is always free (ie, that's not the - // retry part) - require.True(retrier.Try(context.Background())) - - // now we're doing retries - require.True(retrier.Try(context.Background())) - require.True(retrier.Try(context.Background())) - require.True(retrier.Try(context.Background())) - require.True(retrier.Try(context.Background())) - require.True(retrier.Try(context.Background())) - - // and it's the 6th retry that fails since we've exhausted - // the retries we're allotted. - require.False(retrier.Try(context.Background())) - require.True(retrier.Exhausted()) - }) - - t.Run("Cancellation", func(t *testing.T) { - retrier := NewBackoffRetrier(struct { - MaxRetries int - Factor float64 - Jitter bool - Min time.Duration - Max time.Duration - }{ - MaxRetries: 5, - Factor: 0, - }) - - // first iteration is always free (ie, that's not the - // retry part) - cancelledContext, cancel := context.WithCancel(context.Background()) - cancel() - require.False(t, retrier.Try(cancelledContext)) - }) -} - -var fastRetryOptions = RetryOptions{ - // note: omitting MaxRetries just to give a sanity check that - // we do setDefaults() before we run. 
- RetryDelay: time.Millisecond, - MaxRetryDelay: time.Millisecond, -} - -func TestRetryBasic(t *testing.T) { - called := 0 - - err := Retry(context.Background(), "retrytest", func(ctx context.Context, args *RetryFnArgs) error { - require.NotNil(t, args) - require.NotNil(t, ctx) - - called++ - - return &amqp.DetachError{} - }, nil, fastRetryOptions) - - var de *amqp.DetachError - require.ErrorAs(t, err, &de) - require.EqualValues(t, 4, called) -} - -func TestRetryWithFatalError(t *testing.T) { - called := 0 - - err := Retry(context.Background(), "retrytest", func(ctx context.Context, args *RetryFnArgs) error { - require.NotNil(t, args) - require.NotNil(t, ctx) - - called++ - - return &amqp.Error{ - // this is just a basic non-recoverable situation - typically happens if the - // lock period expires. - Condition: amqp.ErrorCondition("com.microsoft:message-lock-lost"), - } - }, nil, fastRetryOptions) - - // fatal error so we only called the function once - require.EqualValues(t, 1, called) - - var testErr *amqp.Error - - require.ErrorAs(t, err, &testErr) - require.EqualValues(t, "com.microsoft:message-lock-lost", testErr.Condition) -} - -func TestRetryCustomIsFatal(t *testing.T) { - called := 0 - var totallyHarmlessErrorAsFatal = errors.New("I'm supposed to be harmless but the custom error handler is going to make me fatal") - var isFatalErr error - - err := Retry(context.Background(), "retrytest", func(ctx context.Context, args *RetryFnArgs) error { - require.NotNil(t, args) - require.NotNil(t, ctx) - - called++ - - return totallyHarmlessErrorAsFatal - }, func(err error) bool { - require.Nil(t, isFatalErr, "should only get called once") - isFatalErr = err - return true - }, fastRetryOptions) - - // fatal error so we only called the function once - require.EqualValues(t, 1, called) - - require.ErrorIs(t, err, totallyHarmlessErrorAsFatal) - require.ErrorIs(t, isFatalErr, totallyHarmlessErrorAsFatal) -} - -func TestRetryDefaults(t *testing.T) { - ro := RetryOptions{} - 
setDefaults(&ro) - - require.EqualValues(t, 3, ro.MaxRetries) - require.EqualValues(t, 4*time.Second, ro.RetryDelay) - require.EqualValues(t, 2*time.Minute, ro.MaxRetryDelay) - - // this is an interesting default. Anything < 0 basically - // causes the max delay to be "infinite" - ro.MaxRetryDelay = -1 - // whereas this just normalizes to '0' - ro.RetryDelay = -1 - ro.MaxRetries = -1 - setDefaults(&ro) - require.EqualValues(t, time.Duration(math.MaxInt64), ro.MaxRetryDelay) - require.EqualValues(t, 0, ro.MaxRetries) - require.EqualValues(t, time.Duration(0), ro.RetryDelay) -} - -func TestCalcDelay(t *testing.T) { - // calcDelay introduces some jitter, automatically. - ro := RetryOptions{} - setDefaults(&ro) - d := calcDelay(ro, 0) - require.EqualValues(t, 0, d) - - // by default the first calc is 2^attempt - d = calcDelay(ro, 1) - require.LessOrEqual(t, d, 6*time.Second) - require.GreaterOrEqual(t, d, time.Second) -} diff --git a/sdk/messaging/azservicebus/internal/rpc.go b/sdk/messaging/azservicebus/internal/rpc.go new file mode 100644 index 000000000000..5875dc6b5c48 --- /dev/null +++ b/sdk/messaging/azservicebus/internal/rpc.go @@ -0,0 +1,445 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+package internal + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/devigned/tab" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing" + "github.com/Azure/go-amqp" +) + +const ( + replyPostfix = "-reply-to-" + statusCodeKey = "status-code" + descriptionKey = "status-description" + defaultReceiverCredits = 1000 +) + +type ( + // rpcLink is the bidirectional communication structure used for CBS negotiation + rpcLink struct { + session *amqp.Session + + receiver amqpReceiver // *amqp.Receiver + sender amqpSender // *amqp.Sender + + clientAddress string + sessionID *string + id string + + responseMu sync.Mutex + startResponseRouterOnce *sync.Once + responseMap map[string]chan rpcResponse + broadcastErr error // the error that caused the responseMap to be nil'd + + // for unit tests + uuidNewV4 func() (uuid.UUID, error) + messageAccept func(ctx context.Context, message *amqp.Message) error + } + + // RPCResponse is the simplified response structure from an RPC like call + RPCResponse struct { + Code int + Description string + Message *amqp.Message + } + + // RPCLinkOption provides a way to customize the construction of a Link + RPCLinkOption func(link *rpcLink) error + + rpcResponse struct { + message *amqp.Message + err error + } + + // Actually: *amqp.Receiver + amqpReceiver interface { + Receive(ctx context.Context) (*amqp.Message, error) + Close(ctx context.Context) error + } + + amqpSender interface { + Send(ctx context.Context, msg *amqp.Message) error + Close(ctx context.Context) error + } +) + +// NewLink will build a new request response link +func NewRPCLink(conn *amqp.Client, address string, opts ...RPCLinkOption) (*rpcLink, error) { + authSession, err := conn.NewSession() + if err != nil { + return nil, err + } + + return newRPCLinkWithSession(authSession, address, opts...) 
+} + +// NewLinkWithSession will build a new request response link, but will reuse an existing AMQP session +func newRPCLinkWithSession(session *amqp.Session, address string, opts ...RPCLinkOption) (*rpcLink, error) { + linkID, err := uuid.New() + if err != nil { + return nil, err + } + + id := linkID.String() + link := &rpcLink{ + session: session, + clientAddress: strings.Replace(address, "$", "", -1) + replyPostfix + id, + id: id, + + uuidNewV4: uuid.New, + responseMap: map[string]chan rpcResponse{}, + startResponseRouterOnce: &sync.Once{}, + } + + for _, opt := range opts { + if err := opt(link); err != nil { + return nil, err + } + } + + sender, err := session.NewSender( + amqp.LinkTargetAddress(address), + ) + if err != nil { + return nil, err + } + + receiverOpts := []amqp.LinkOption{ + amqp.LinkSourceAddress(address), + amqp.LinkTargetAddress(link.clientAddress), + amqp.LinkCredit(defaultReceiverCredits), + } + + if link.sessionID != nil { + const name = "com.microsoft:session-filter" + const code = uint64(0x00000137000000C) + if link.sessionID == nil { + receiverOpts = append(receiverOpts, amqp.LinkSourceFilter(name, code, nil)) + } else { + receiverOpts = append(receiverOpts, amqp.LinkSourceFilter(name, code, link.sessionID)) + } + } + + receiver, err := session.NewReceiver(receiverOpts...) + if err != nil { + // make sure we close the sender + clsCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _ = sender.Close(clsCtx) + return nil, err + } + + link.sender = sender + link.receiver = receiver + link.messageAccept = receiver.AcceptMessage + + return link, nil +} + +// startResponseRouter is responsible for taking any messages received on the 'response' +// link and forwarding it to the proper channel. The channel is being select'd by the +// original `RPC` call.
+func (l *rpcLink) startResponseRouter() { + for { + res, err := l.receiver.Receive(context.Background()) + + // You'll see this when the link is shutting down (either + // service-initiated via 'detach' or a user-initiated shutdown) + if isClosedError(err) { + l.broadcastError(err) + break + } + + // I don't believe this should happen. The JS version of this same code + // ignores errors as well since responses should always be correlated + // to actual send requests. So this is just here for completeness. + if res == nil { + continue + } + + autogenMessageId, ok := res.Properties.CorrelationID.(string) + + if !ok { + // TODO: it'd be good to track these in some way. We don't have a good way to + // forward this on at this point. + continue + } + + ch := l.deleteChannelFromMap(autogenMessageId) + + if ch != nil { + ch <- rpcResponse{message: res, err: err} + } + } +} + +// RPC sends a request and waits on a response for that request +func (l *rpcLink) RPC(ctx context.Context, msg *amqp.Message) (*RPCResponse, error) { + l.startResponseRouterOnce.Do(func() { + go l.startResponseRouter() + }) + + copiedMessage, messageID, err := addMessageID(msg, l.uuidNewV4) + + if err != nil { + return nil, err + } + + // use the copiedMessage from this point + msg = copiedMessage + + const altStatusCodeKey, altDescriptionKey = "statusCode", "statusDescription" + + ctx, span := tab.StartSpan(ctx, "rpc.RPC") + tracing.ApplyComponentInfo(span, Version) + defer span.End() + + msg.Properties.ReplyTo = &l.clientAddress + + if msg.ApplicationProperties == nil { + msg.ApplicationProperties = make(map[string]interface{}) + } + + if _, ok := msg.ApplicationProperties["server-timeout"]; !ok { + if deadline, ok := ctx.Deadline(); ok { + msg.ApplicationProperties["server-timeout"] = uint(time.Until(deadline) / time.Millisecond) + } + } + + responseCh := l.addChannelToMap(messageID) + + if responseCh == nil { + return nil, l.broadcastErr + } + + err = l.sender.Send(ctx, msg) + + if err != nil 
{ + l.deleteChannelFromMap(messageID) + tab.For(ctx).Error(err) + return nil, err + } + + var res *amqp.Message + + select { + case <-ctx.Done(): + l.deleteChannelFromMap(messageID) + res, err = nil, ctx.Err() + case resp := <-responseCh: + // this will get triggered by the loop in 'startReceiverRouter' when it receives + // a message with our autoGenMessageID set in the correlation_id property. + res, err = resp.message, resp.err + } + + if err != nil { + tab.For(ctx).Error(err) + return nil, err + } + + var statusCode int + statusCodeCandidates := []string{statusCodeKey, altStatusCodeKey} + for i := range statusCodeCandidates { + if rawStatusCode, ok := res.ApplicationProperties[statusCodeCandidates[i]]; ok { + if cast, ok := rawStatusCode.(int32); ok { + statusCode = int(cast) + break + } + + err := errors.New("status code was not of expected type int32") + tab.For(ctx).Error(err) + return nil, err + } + } + if statusCode == 0 { + err := errors.New("status codes was not found on rpc message") + tab.For(ctx).Error(err) + return nil, err + } + + var description string + descriptionCandidates := []string{descriptionKey, altDescriptionKey} + for i := range descriptionCandidates { + if rawDescription, ok := res.ApplicationProperties[descriptionCandidates[i]]; ok { + if description, ok = rawDescription.(string); ok || rawDescription == nil { + break + } else { + return nil, errors.New("status description was not of expected type string") + } + } + } + + span.AddAttributes(tab.StringAttribute("http.status_code", fmt.Sprintf("%d", statusCode))) + + response := &RPCResponse{ + Code: int(statusCode), + Description: description, + Message: res, + } + + if err := l.messageAccept(ctx, res); err != nil { + tab.For(ctx).Error(err) + return response, err + } + + return response, err +} + +// Close the link receiver, sender and session +func (l *rpcLink) Close(ctx context.Context) error { + ctx, span := startRPCSpan(ctx, "rpc.Close") + defer span.End() + + if err := 
l.closeReceiver(ctx); err != nil { + _ = l.closeSender(ctx) + _ = l.closeSession(ctx) + return err + } + + if err := l.closeSender(ctx); err != nil { + _ = l.closeSession(ctx) + return err + } + + return l.closeSession(ctx) +} + +func (l *rpcLink) closeReceiver(ctx context.Context) error { + ctx, span := startRPCSpan(ctx, "rpc.closeReceiver") + defer span.End() + + if l.receiver != nil { + return l.receiver.Close(ctx) + } + return nil +} + +func (l *rpcLink) closeSender(ctx context.Context) error { + ctx, span := startRPCSpan(ctx, "rpc.closeSender") + defer span.End() + + if l.sender != nil { + return l.sender.Close(ctx) + } + return nil +} + +func (l *rpcLink) closeSession(ctx context.Context) error { + ctx, span := startRPCSpan(ctx, "rpc.closeSession") + defer span.End() + + if l.session != nil { + return l.session.Close(ctx) + } + return nil +} + +// addChannelToMap adds a channel which will be used by the response router to +// notify when there is a response to the request. +// If l.responseMap is nil (for instance, via broadcastError) this function will +// return nil. +func (l *rpcLink) addChannelToMap(messageID string) chan rpcResponse { + l.responseMu.Lock() + defer l.responseMu.Unlock() + + if l.responseMap == nil { + return nil + } + + responseCh := make(chan rpcResponse, 1) + l.responseMap[messageID] = responseCh + + return responseCh +} + +// deleteChannelFromMap removes the message from our internal map and returns +// a channel that the corresponding RPC() call is waiting on. +// If l.responseMap is nil (for instance, via broadcastError) this function will +// return nil. +func (l *rpcLink) deleteChannelFromMap(messageID string) chan rpcResponse { + l.responseMu.Lock() + defer l.responseMu.Unlock() + + if l.responseMap == nil { + return nil + } + + ch := l.responseMap[messageID] + delete(l.responseMap, messageID) + + return ch +} + +// broadcastError notifies the anyone waiting for a response that the link/session/connection +// has closed. 
+func (l *rpcLink) broadcastError(err error) { + l.responseMu.Lock() + defer l.responseMu.Unlock() + + for _, ch := range l.responseMap { + ch <- rpcResponse{err: err} + } + + l.broadcastErr = err + l.responseMap = nil +} + +// addMessageID generates a unique UUID for the message. When the service +// responds it will fill out the correlation ID property of the response +// with this ID, allowing us to link the request and response together. +// +// NOTE: this function copies 'message', adding in a 'Properties' object +// if it does not already exist. +func addMessageID(message *amqp.Message, uuidNewV4 func() (uuid.UUID, error)) (*amqp.Message, string, error) { + uuid, err := uuidNewV4() + + if err != nil { + return nil, "", err + } + + autoGenMessageID := uuid.String() + + // we need to modify the message so we'll make a copy + copiedMessage := *message + + if message.Properties == nil { + copiedMessage.Properties = &amqp.MessageProperties{ + MessageID: autoGenMessageID, + } + } else { + // properties already exist, make a copy and then update + // the message ID + copiedProperties := *message.Properties + copiedProperties.MessageID = autoGenMessageID + + copiedMessage.Properties = &copiedProperties + } + + return &copiedMessage, autoGenMessageID, nil +} + +func isClosedError(err error) bool { + var detachError *amqp.DetachError + + return errors.Is(err, amqp.ErrLinkClosed) || + errors.As(err, &detachError) || + errors.Is(err, amqp.ErrConnClosed) || + errors.Is(err, amqp.ErrSessionClosed) +} + +func startRPCSpan(ctx context.Context, operation string) (context.Context, tab.Spanner) { + ctx, span := tab.StartSpan(ctx, operation) + tracing.ApplyComponentInfo(span, Version) + return ctx, span +} diff --git a/sdk/messaging/azservicebus/internal/sbauth/token_provider.go b/sdk/messaging/azservicebus/internal/sbauth/token_provider.go index 8aa15a6f1890..89fbf7b229cc 100644 --- a/sdk/messaging/azservicebus/internal/sbauth/token_provider.go +++ 
b/sdk/messaging/azservicebus/internal/sbauth/token_provider.go @@ -58,7 +58,6 @@ func (tp *TokenProvider) GetToken(uri string) (*auth.Token, error) { // GetToken returns a token (that is compatible as an auth.TokenProvider) and // the calculated time when you should renew your token. func (tp *TokenProvider) GetTokenAsTokenProvider(uri string) (*singleUseTokenProvider, time.Time, error) { - token, renewAt, err := tp.getTokenImpl(uri) if err != nil { diff --git a/sdk/messaging/azservicebus/internal/stress/Chart.lock b/sdk/messaging/azservicebus/internal/stress/Chart.lock new file mode 100644 index 000000000000..96e9785515bb --- /dev/null +++ b/sdk/messaging/azservicebus/internal/stress/Chart.lock @@ -0,0 +1,6 @@ +dependencies: +- name: stress-test-addons + repository: https://stresstestcharts.blob.core.windows.net/helm/ + version: 0.1.12 +digest: sha256:a2afbc89375d1518cd44194ca861e8a7a1ca1c16cfaea3409cdb48bfc910d878 +generated: "2021-11-16T19:17:10.5545801-05:00" diff --git a/sdk/messaging/azservicebus/internal/stress/Chart.yaml b/sdk/messaging/azservicebus/internal/stress/Chart.yaml new file mode 100644 index 000000000000..044770e86ef2 --- /dev/null +++ b/sdk/messaging/azservicebus/internal/stress/Chart.yaml @@ -0,0 +1,14 @@ +apiVersion: v2 +name: go-sb-stress +description: Stress tests for Go +version: 0.1.1 +appVersion: v0.1 +annotations: + stressTest: 'true' # enable auto-discovery of this test via `find-all-stress-packages.ps1` + namespace: 'go' + dockerbuilddir: '../..' 
+ dockerfile: './Dockerfile' +dependencies: +- name: stress-test-addons + version: 0.1.12 + repository: https://stresstestcharts.blob.core.windows.net/helm/ \ No newline at end of file diff --git a/sdk/messaging/azservicebus/internal/stress/Dockerfile b/sdk/messaging/azservicebus/internal/stress/Dockerfile index e553930c2ea3..3159517f3471 100644 --- a/sdk/messaging/azservicebus/internal/stress/Dockerfile +++ b/sdk/messaging/azservicebus/internal/stress/Dockerfile @@ -1,4 +1,14 @@ -FROM alpine:latest -ADD ./stress / -WORKDIR / -ENTRYPOINT ["/stress"] +FROM golang as build +# you'll need to run this build from the root of the repo +ENV GOOS=linux +ENV GOARCH=amd64 +ENV CGO_ENABLED=0 +ADD . /src +WORKDIR /src/internal/stress +RUN go build -o stress . + +FROM alpine +RUN mkdir -p /app +COPY --from=build /src/internal/stress/stress /app/stress +WORKDIR /app +ENTRYPOINT ["./stress"] \ No newline at end of file diff --git a/sdk/messaging/azservicebus/internal/stress/shared/stats.go b/sdk/messaging/azservicebus/internal/stress/shared/stats.go index 72b8250960d7..858aff921a0f 100644 --- a/sdk/messaging/azservicebus/internal/stress/shared/stats.go +++ b/sdk/messaging/azservicebus/internal/stress/shared/stats.go @@ -80,28 +80,30 @@ func newStatsPrinter(ctx context.Context, prefix string, interval time.Duration, } sp.mu.RLock() + sp.PrintStats() + sp.mu.RUnlock() + } + }(ctx) - log.Printf("Stats:") - - for _, stats := range sp.all { - log.Printf(" %s", stats.String()) + return sp +} - if stats.Sent > 0 { - sp.tc.TrackMetric(fmt.Sprintf("%s.TotalSent", stats.name), float64(stats.Sent)) - } +func (sp *statsPrinter) PrintStats() { + log.Printf("Stats:") - if stats.Received > 0 { - sp.tc.TrackMetric(fmt.Sprintf("%s.TotalReceived", stats.name), float64(stats.Received)) - } + for _, stats := range sp.all { + log.Printf(" %s", stats.String()) - sp.tc.TrackMetric(fmt.Sprintf("%s.TotalErrors", stats.name), float64(stats.Errors)) - } + if stats.Sent > 0 { + 
sp.tc.TrackMetric(fmt.Sprintf("%s.TotalSent", stats.name), float64(stats.Sent)) + } - sp.mu.RUnlock() + if stats.Received > 0 { + sp.tc.TrackMetric(fmt.Sprintf("%s.TotalReceived", stats.name), float64(stats.Received)) } - }(ctx) - return sp + sp.tc.TrackMetric(fmt.Sprintf("%s.TotalErrors", stats.name), float64(stats.Errors)) + } } // NewStat creates a new stat with `name` and adds it to the list of statistics that will diff --git a/sdk/messaging/azservicebus/internal/stress/shared/stress_context.go b/sdk/messaging/azservicebus/internal/stress/shared/stress_context.go index 4116f10e4385..3422787c171b 100644 --- a/sdk/messaging/azservicebus/internal/stress/shared/stress_context.go +++ b/sdk/messaging/azservicebus/internal/stress/shared/stress_context.go @@ -12,6 +12,7 @@ import ( "sync/atomic" "time" + azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" "github.com/microsoft/ApplicationInsights-Go/appinsights" ) @@ -32,6 +33,8 @@ type StressContext struct { // ConnectionString represents the value of the environment variable SERVICEBUS_CONNECTION_STRING. 
ConnectionString string + logMessages chan string + cancel context.CancelFunc } @@ -61,12 +64,33 @@ func MustCreateStressContext(testName string) *StressContext { ctx, cancel := NewCtrlCContext() + azlog.SetEvents("azsb.Conn", "azsb.Auth", "azsb.Retry", "azsb.Mgmt") + + logMessages := make(chan string, 10000) + + go func() { + PrintLoop: + for { + select { + case <-ctx.Done(): + break PrintLoop + case msg := <-logMessages: + fmt.Println(msg) + } + } + }() + + azlog.SetListener(func(e azlog.Event, msg string) { + logMessages <- fmt.Sprintf("%s %10s %s", time.Now().Format(time.RFC3339), e, msg) + }) + return &StressContext{ TestRunID: testRunID, Nano: testRunID, // the same for now ConnectionString: cs, TelemetryClient: telemetryClient, statsPrinter: newStatsPrinter(ctx, testName, 5*time.Second, telemetryClient), + logMessages: logMessages, Context: ctx, cancel: cancel, } @@ -90,12 +114,29 @@ func (sc *StressContext) Start(entityName string, attributes map[string]string) func (sc *StressContext) End() { log.Printf("Stopping and flushing telemetry") + sc.cancel() + sc.TrackEvent("End") + sc.Channel().Flush() <-sc.Channel().Close() time.Sleep(5 * time.Second) + // dump out any remaining log messages +PrintLoop: + for { + select { + case msg := <-sc.logMessages: + fmt.Println(msg) + default: + break PrintLoop + } + } + + // dump out the last stats. 
+ sc.PrintStats() + log.Printf("Done") } @@ -108,6 +149,14 @@ func (tracker *StressContext) PanicOnError(message string, err error) { } } +func (tracker *StressContext) Assert(condition bool, message string) { + tracker.LogIfFailed(message, nil, nil) + + if !condition { + panic(message) + } +} + func (sc *StressContext) LogIfFailed(message string, err error, stats *Stats) { if err != nil { log.Printf("Error: %s: %#v, %T", message, err, err) diff --git a/sdk/messaging/azservicebus/internal/stress/shared/utils.go b/sdk/messaging/azservicebus/internal/stress/shared/utils.go index b99f3dda9618..ceeff921d608 100644 --- a/sdk/messaging/azservicebus/internal/stress/shared/utils.go +++ b/sdk/messaging/azservicebus/internal/stress/shared/utils.go @@ -50,18 +50,24 @@ func MustGenerateMessages(sc *StressContext, sender *azservicebus.Sender, messag } // MustCreateAutoDeletingQueue creates a queue that will auto-delete 10 minutes after activity has ceased. -func MustCreateAutoDeletingQueue(sc *StressContext, queueName string) { +func MustCreateAutoDeletingQueue(sc *StressContext, queueName string, qp *admin.QueueProperties) { adminClient, err := admin.NewClientFromConnectionString(sc.ConnectionString, nil) sc.PanicOnError("failed to create adminClient", err) autoDeleteOnIdle := 10 * time.Minute - _, err = adminClient.CreateQueue(context.Background(), queueName, &admin.QueueProperties{ - AutoDeleteOnIdle: &autoDeleteOnIdle, + var newQP admin.QueueProperties - // mostly useful for tracking backwards in case something goes wrong. - UserMetadata: &sc.TestRunID, - }, nil) + if qp != nil { + newQP = *qp + } + + newQP.AutoDeleteOnIdle = &autoDeleteOnIdle + + // mostly useful for tracking backwards in case something goes wrong. 
+ newQP.UserMetadata = &sc.TestRunID + + _, err = adminClient.CreateQueue(context.Background(), queueName, &newQP, nil) sc.PanicOnError("failed to create queue", err) } diff --git a/sdk/messaging/azservicebus/internal/stress/stress-test-resources.bicep b/sdk/messaging/azservicebus/internal/stress/stress-test-resources.bicep new file mode 100644 index 000000000000..7245dfec35df --- /dev/null +++ b/sdk/messaging/azservicebus/internal/stress/stress-test-resources.bicep @@ -0,0 +1,107 @@ +@description('The base resource name.') +param baseName string = resourceGroup().name + +@description('The client OID to grant access to test resources.') +param testApplicationOid string + +var apiVersion = '2017-04-01' +var location = resourceGroup().location +var authorizationRuleName_var = '${baseName}/RootManageSharedAccessKey' +var authorizationRuleNameNoManage_var = '${baseName}/NoManage' +var serviceBusDataOwnerRoleId = '/subscriptions/${subscription().subscriptionId}/providers/Microsoft.Authorization/roleDefinitions/090c5cfd-751d-490a-894a-3ce6f1109419' + +resource servicebus 'Microsoft.ServiceBus/namespaces@2018-01-01-preview' = { + name: baseName + location: location + sku: { + name: 'Standard' + tier: 'Standard' + } + properties: { + zoneRedundant: false + } +} + +resource authorizationRuleName 'Microsoft.ServiceBus/namespaces/AuthorizationRules@2015-08-01' = { + name: authorizationRuleName_var + location: location + properties: { + rights: [ + 'Listen' + 'Manage' + 'Send' + ] + } + dependsOn: [ + servicebus + ] +} + +resource authorizationRuleNameNoManage 'Microsoft.ServiceBus/namespaces/AuthorizationRules@2015-08-01' = { + name: authorizationRuleNameNoManage_var + location: location + properties: { + rights: [ + 'Listen' + 'Send' + ] + } + dependsOn: [ + servicebus + ] +} + + + +resource dataOwnerRoleId 'Microsoft.Authorization/roleAssignments@2018-01-01-preview' = { + name: guid('dataOwnerRoleId${baseName}') + properties: { + roleDefinitionId: serviceBusDataOwnerRoleId 
+ principalId: testApplicationOid + } + dependsOn: [ + servicebus + ] +} + +resource testQueue 'Microsoft.ServiceBus/namespaces/queues@2017-04-01' = { + parent: servicebus + name: 'testQueue' + properties: { + lockDuration: 'PT5M' + maxSizeInMegabytes: 1024 + requiresDuplicateDetection: false + requiresSession: false + defaultMessageTimeToLive: 'P10675199DT2H48M5.4775807S' + deadLetteringOnMessageExpiration: false + duplicateDetectionHistoryTimeWindow: 'PT10M' + maxDeliveryCount: 10 + autoDeleteOnIdle: 'P10675199DT2H48M5.4775807S' + enablePartitioning: false + enableExpress: false + } +} + +resource testQueueWithSessions 'Microsoft.ServiceBus/namespaces/queues@2017-04-01' = { + parent: servicebus + name: 'testQueueWithSessions' + properties: { + lockDuration: 'PT5M' + maxSizeInMegabytes: 1024 + requiresDuplicateDetection: false + requiresSession: true + defaultMessageTimeToLive: 'P10675199DT2H48M5.4775807S' + deadLetteringOnMessageExpiration: false + duplicateDetectionHistoryTimeWindow: 'PT10M' + maxDeliveryCount: 10 + autoDeleteOnIdle: 'P10675199DT2H48M5.4775807S' + enablePartitioning: false + enableExpress: false + } +} + +output SERVICEBUS_CONNECTION_STRING string = listKeys(resourceId('Microsoft.ServiceBus/namespaces/authorizationRules', baseName, 'RootManageSharedAccessKey'), apiVersion).primaryConnectionString +output SERVICEBUS_CONNECTION_STRING_NO_MANAGE string = listKeys(resourceId('Microsoft.ServiceBus/namespaces/authorizationRules', baseName, 'NoManage'), apiVersion).primaryConnectionString +output SERVICEBUS_ENDPOINT string = replace(servicebus.properties.serviceBusEndpoint, ':443/', '') +output QUEUE_NAME string = 'testQueue' +output QUEUE_NAME_WITH_SESSIONS string = 'testQueueWithSessions' diff --git a/sdk/messaging/azservicebus/internal/stress/templates/deploy-job.yaml b/sdk/messaging/azservicebus/internal/stress/templates/deploy-job.yaml new file mode 100644 index 000000000000..37cf472c5c8c --- /dev/null +++ 
b/sdk/messaging/azservicebus/internal/stress/templates/deploy-job.yaml @@ -0,0 +1,17 @@ +{{- include "stress-test-addons.deploy-job-template.from-pod" (list . "stress.deploy-example") -}} +{{- define "stress.deploy-example" -}} +metadata: + labels: + testName: "go-servicebus" +spec: + containers: + - name: main + # az acr list -g rg-stress-cluster-test --subscription "Azure SDK Developer Playground" --query "[0].loginServer" + image: {{ .Values.image }} + command: ['/app/stress'] + args: + - "tests" + # (this is injected automatically. The full list of scenarios is in `../values.yaml`) + - {{ .Scenario }} + {{- include "stress-test-addons.container-env" . | nindent 6 }} +{{- end -}} diff --git a/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment.go b/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment.go index 75ad7dd8184c..ad1c532b77d8 100644 --- a/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment.go +++ b/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment.go @@ -18,7 +18,7 @@ func ConstantDetachment(remainingArgs []string) { queueName := fmt.Sprintf("detach-tester-%s", sc.Nano) - shared.MustCreateAutoDeletingQueue(sc, queueName) + shared.MustCreateAutoDeletingQueue(sc, queueName, nil) client, err := azservicebus.NewClientFromConnectionString(sc.ConnectionString, nil) sc.PanicOnError("failed to create client", err) diff --git a/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment_sender.go b/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment_sender.go index 4395edda3a5f..371335f43fc9 100644 --- a/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment_sender.go +++ b/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment_sender.go @@ -78,7 +78,7 @@ func ConstantDetachmentSender(remainingArgs []string) { func createDetachResources(sc *shared.StressContext, name string) (string, *shared.Stats, *azservicebus.Sender) { queueName := 
fmt.Sprintf("detach_%s-%s", name, sc.Nano) - shared.MustCreateAutoDeletingQueue(sc, queueName) + shared.MustCreateAutoDeletingQueue(sc, queueName, nil) client, err := azservicebus.NewClientFromConnectionString(sc.ConnectionString, nil) sc.PanicOnError("failed to create client", err) diff --git a/sdk/messaging/azservicebus/internal/stress/tests/finite_peeks.go b/sdk/messaging/azservicebus/internal/stress/tests/finite_peeks.go new file mode 100644 index 000000000000..66a1b365b2c5 --- /dev/null +++ b/sdk/messaging/azservicebus/internal/stress/tests/finite_peeks.go @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package tests + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/stress/shared" +) + +func FinitePeeks(remainingArgs []string) { + sc := shared.MustCreateStressContext("FinitePeeks") + defer sc.Done() + + queueName := fmt.Sprintf("finite-peeks-%s", sc.Nano) + shared.MustCreateAutoDeletingQueue(sc, queueName, nil) + + client, err := azservicebus.NewClientFromConnectionString(sc.ConnectionString, nil) + sc.PanicOnError("failed to create client", err) + + receiverStats := sc.NewStat("peeks") + + sender, err := client.NewSender(queueName, nil) + sc.PanicOnError("failed to create sender", err) + + err = sender.SendMessage(sc.Context, &azservicebus.Message{ + Body: []byte("peekable message"), + }) + sc.PanicOnError("failed to send message", err) + + _ = sender.Close(sc.Context) + + receiver, err := client.NewReceiverForQueue(queueName, nil) + sc.PanicOnError("failed to create receiver", err) + + // receiving here just guarantees the message has arrived and is available (sometimes + // there's a slight delay) + receiveCtx, cancel := context.WithTimeout(sc.Context, time.Minute) + defer cancel() + + tmp, err := receiver.ReceiveMessages(receiveCtx, 1, nil) + 
sc.PanicOnError("failed to receive messages", err) + sc.Assert(len(tmp) == 1, "message was never available") + + // return the message back from whence it came. + sc.PanicOnError("failed to abandon message", + receiver.AbandonMessage(sc.Context, tmp[0], nil)) + + for i := 0; i < 10000; i++ { + log.Printf("Sleeping for 1 second before iteration %d", i) + time.Sleep(time.Second) + + seqNum := int64(0) + + messages, err := receiver.PeekMessages(sc.Context, 1, &azservicebus.PeekMessagesOptions{ + FromSequenceNumber: &seqNum, + }) + sc.PanicOnError("failed to peek messages", err) + sc.Assert(len(messages) == 1, "no messages returned in peek") + + receiverStats.AddReceived(int32(1)) + } +} diff --git a/sdk/messaging/azservicebus/internal/stress/tests/finite_send_and_receive.go b/sdk/messaging/azservicebus/internal/stress/tests/finite_send_and_receive.go index 6d21be448476..257fc0833395 100644 --- a/sdk/messaging/azservicebus/internal/stress/tests/finite_send_and_receive.go +++ b/sdk/messaging/azservicebus/internal/stress/tests/finite_send_and_receive.go @@ -5,6 +5,7 @@ package tests import ( "context" + "errors" "fmt" "log" "strings" @@ -12,6 +13,7 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/admin" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/stress/shared" ) @@ -19,23 +21,29 @@ func FiniteSendAndReceiveTest(remainingArgs []string) { sc := shared.MustCreateStressContext("FiniteSendAndReceiveTest") sc.TrackEvent("Start") - defer sc.TrackEvent("End") + defer sc.End() queueName := strings.ToLower(fmt.Sprintf("queue-%X", time.Now().UnixNano())) log.Printf("Creating queue") - shared.MustCreateAutoDeletingQueue(sc, queueName) + + lockDuration := 5 * time.Minute + + shared.MustCreateAutoDeletingQueue(sc, queueName, &admin.QueueProperties{ + LockDuration: &lockDuration, + }) client, err := azservicebus.NewClientFromConnectionString(sc.ConnectionString, 
nil) sc.PanicOnError("failed to create client", err) sender, err := client.NewSender(queueName, nil) sc.PanicOnError("failed to create sender", err) - const messageLimit = 50000 - shared.MustGenerateMessages(sc, sender, messageLimit, 100, sc.NewStat("sender")) + const messageLimit = 500 log.Printf("Sending %d messages (all messages will be sent before receiving begins)", messageLimit) + shared.MustGenerateMessages(sc, sender, messageLimit, 100, sc.NewStat("sender")) + log.Printf("Starting receiving...") receiver, err := client.NewReceiverForQueue(queueName, nil) @@ -45,7 +53,7 @@ func FiniteSendAndReceiveTest(remainingArgs []string) { receiverStats := sc.NewStat("receiver") - for receiverStats.Received == messageLimit { + for receiverStats.Received < messageLimit { log.Printf("[start] Receiving messages...") messages, err := receiver.ReceiveMessages(context.Background(), 100, nil) log.Printf("[done] Receiving messages... %v, %v", len(messages), err) @@ -53,23 +61,36 @@ func FiniteSendAndReceiveTest(remainingArgs []string) { wg := sync.WaitGroup{} + log.Printf("About to complete %d messages", len(messages)) + time.Sleep(10 * time.Second) + for _, msg := range messages { wg.Add(1) + go func(msg *azservicebus.ReceivedMessage) { completions <- struct{}{} defer wg.Done() defer func() { <-completions }() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() err := receiver.CompleteMessage(ctx, msg) + + var rpcCodeErr interface{ RPCCode() int } + + if errors.As(err, &rpcCodeErr) { + if rpcCodeErr.RPCCode() == 410 { + receiverStats.AddError("lock lost", err) + return + } + } + sc.PanicOnError("failed to complete message", err) + receiverStats.AddReceived(1) }(msg) } wg.Wait() - - receiverStats.AddReceived(int32(len(messages))) } } diff --git a/sdk/messaging/azservicebus/internal/stress/tests/long_running_renew_lock.go 
b/sdk/messaging/azservicebus/internal/stress/tests/long_running_renew_lock.go index 25c1af5d1664..86b06d42195c 100644 --- a/sdk/messaging/azservicebus/internal/stress/tests/long_running_renew_lock.go +++ b/sdk/messaging/azservicebus/internal/stress/tests/long_running_renew_lock.go @@ -17,7 +17,7 @@ func LongRunningRenewLockTest(remainingArgs []string) { sc := shared.MustCreateStressContext("LongRunningRenewLockTest") queueName := fmt.Sprintf("renew-lock-test-%s", sc.Nano) - shared.MustCreateAutoDeletingQueue(sc, queueName) + shared.MustCreateAutoDeletingQueue(sc, queueName, nil) client, err := azservicebus.NewClientFromConnectionString(sc.ConnectionString, nil) sc.PanicOnError("failed to create admin.Client", err) diff --git a/sdk/messaging/azservicebus/internal/stress/tests/rapid_open_close.go b/sdk/messaging/azservicebus/internal/stress/tests/rapid_open_close.go index 5af1fd96639b..1279dfe8dcf3 100644 --- a/sdk/messaging/azservicebus/internal/stress/tests/rapid_open_close.go +++ b/sdk/messaging/azservicebus/internal/stress/tests/rapid_open_close.go @@ -17,7 +17,7 @@ func RapidOpenCloseTest(remainingArgs []string) { sc := shared.MustCreateStressContext("RapidOpenCloseTest") queueName := fmt.Sprintf("rapid_open_close-%X", time.Now().UnixNano()) - shared.MustCreateAutoDeletingQueue(sc, queueName) + shared.MustCreateAutoDeletingQueue(sc, queueName, nil) for i := 0; i < 100; i++ { log.Printf("[%d] Open/Close", i) diff --git a/sdk/messaging/azservicebus/internal/stress/tests/tests.go b/sdk/messaging/azservicebus/internal/stress/tests/tests.go index 2c50db3a56b8..d4ae280f55c0 100644 --- a/sdk/messaging/azservicebus/internal/stress/tests/tests.go +++ b/sdk/messaging/azservicebus/internal/stress/tests/tests.go @@ -11,9 +11,6 @@ import ( "strings" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/stress/shared" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing" - 
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" - "github.com/devigned/tab" ) // Simple query to view some of the stats reported by these stress tests. @@ -25,22 +22,23 @@ import ( func Run(remainingArgs []string) { // turn on some simple stderr diagnostics - tracer := utils.NewSimpleTracer(map[string]bool{ - tracing.SpanRecover: true, - tracing.SpanNegotiateClaim: true, - tracing.SpanRecoverClient: true, - tracing.SpanRecoverLink: true, - }, nil) + // tracer := utils.NewSimpleTracer(map[string]bool{ + // // tracing.SpanRecover: true, + // //tracing.SpanNegotiateClaim: true, + // // tracing.SpanRecoverClient: true, + // // tracing.SpanRecoverLink: true, + // }, nil) - tab.Register(tracer) + // tab.Register(tracer) allTests := map[string]func(args []string){ - "infiniteSendAndReceive": InfiniteSendAndReceiveRun, - "finiteSendAndReceive": FiniteSendAndReceiveTest, - "rapidOpenClose": RapidOpenCloseTest, - "longRunningRenewLock": LongRunningRenewLockTest, "constantDetach": ConstantDetachment, "constantDetachmentSender": ConstantDetachmentSender, + "finitePeeks": FinitePeeks, + "finiteSendAndReceive": FiniteSendAndReceiveTest, + "infiniteSendAndReceive": InfiniteSendAndReceiveRun, + "longRunningRenewLock": LongRunningRenewLockTest, + "rapidOpenClose": RapidOpenCloseTest, } if len(remainingArgs) == 0 { diff --git a/sdk/messaging/azservicebus/internal/stress/values.yaml b/sdk/messaging/azservicebus/internal/stress/values.yaml new file mode 100644 index 000000000000..5fbec5a3ab77 --- /dev/null +++ b/sdk/messaging/azservicebus/internal/stress/values.yaml @@ -0,0 +1,11 @@ +# Optional list of scenarios. If specified multiple stress test jobs will be generated, +# one for each scenario in the list. The pod spec can then be configured to pass the +# scenario name down to the test command, e.g. 
`command: ["node", "{{ .Scenario }}.js"]` +scenarios: +- "constantDetach" +- "constantDetachmentSender" +- "finitePeeks" +- "finiteSendAndReceive" +- "infiniteSendAndReceive" +- "longRunningRenewLock" +- "rapidOpenClose" diff --git a/sdk/messaging/azservicebus/internal/test/test_helpers.go b/sdk/messaging/azservicebus/internal/test/test_helpers.go index 6ddc3f37eeae..40d0fdceaf8d 100644 --- a/sdk/messaging/azservicebus/internal/test/test_helpers.go +++ b/sdk/messaging/azservicebus/internal/test/test_helpers.go @@ -4,10 +4,16 @@ package test import ( + "context" "math/rand" + "net/http" "os" "testing" "time" + + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/atom" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" + "github.com/stretchr/testify/require" ) var ( @@ -46,3 +52,30 @@ func GetConnectionStringWithoutManagePerms(t *testing.T) string { return cs } + +func CreateExpiringQueue(t *testing.T, qd *atom.QueueDescription) (string, func()) { + cs := GetConnectionString(t) + em, err := atom.NewEntityManagerWithConnectionString(cs, "") + require.NoError(t, err) + + queueName := RandomString("queue", 5) + + if qd == nil { + qd = &atom.QueueDescription{} + } + + deleteAfter := 5 * time.Minute + qd.AutoDeleteOnIdle = utils.DurationToStringPtr(&deleteAfter) + + env, _ := atom.WrapWithQueueEnvelope(qd, em.TokenProvider()) + + var qe *atom.QueueEnvelope + resp, err := em.Put(context.Background(), queueName, env, &qe) + require.NoError(t, err) + require.EqualValues(t, http.StatusCreated, resp.StatusCode) + + return queueName, func() { + _, err := em.Delete(context.Background(), queueName) + require.NoError(t, err) + } +} diff --git a/sdk/messaging/azservicebus/internal/utils/retrier.go b/sdk/messaging/azservicebus/internal/utils/retrier.go new file mode 100644 index 000000000000..e880c870b4b1 --- /dev/null +++ b/sdk/messaging/azservicebus/internal/utils/retrier.go @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved.
+// Licensed under the MIT License.
+
+package utils
+
+import (
+	"context"
+	"math"
+	"math/rand"
+	"time"
+
+	"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
+)
+
+// EventRetry is the name for retry events.
+const EventRetry = "azsb.Retry"
+
+type RetryFnArgs struct {
+	I int32
+	// LastErr is the returned error from the previous loop.
+	// It is nil on the first attempt.
+	LastErr error
+
+	resetAttempts bool
+}
+
+// ResetAttempts causes the current retry attempt number to be reset
+// in time for the next recovery (should we fail).
+// NOTE: Use of this should be pretty rare, it's really only needed when you have
+// a situation like Receiver.ReceiveMessages() that can recover but intentionally
+// does not return.
+func (rf *RetryFnArgs) ResetAttempts() {
+	rf.resetAttempts = true
+}
+
+// Retry runs a standard retry loop. It executes your passed in fn as the body of the loop.
+// It returns when it exceeds the configured number of retries or if 'isFatal' returns true.
+func Retry(ctx context.Context, name string, fn func(ctx context.Context, args *RetryFnArgs) error, isFatalFn func(err error) bool, o RetryOptions) error { + if isFatalFn == nil { + panic("isFatalFn is nil, errors would panic") + } + + var ro RetryOptions = o + setDefaults(&ro) + + var err error + + for i := int32(0); i <= ro.MaxRetries; i++ { + if i > 0 { + sleep := calcDelay(ro, i) + log.Writef(EventRetry, "(%s) Attempt %d sleeping for %s", name, i, sleep) + time.Sleep(sleep) + } + + args := RetryFnArgs{ + I: i, + LastErr: err, + } + err = fn(ctx, &args) + + if args.resetAttempts { + log.Writef(EventRetry, "(%s) Resetting attempts", name) + i = int32(0) + } + + if err != nil { + if isFatalFn(err) { + log.Writef(EventRetry, "(%s) Attempt %d returned non-retryable error: %s", name, i, err.Error()) + return err + } else { + log.Writef(EventRetry, "(%s) Attempt %d returned retryable error: %s", name, i, err.Error()) + } + + continue + } + + return nil + } + + return err +} + +// RetryOptions represent the options for retries. +type RetryOptions struct { + // MaxRetries specifies the maximum number of attempts a failed operation will be retried + // before producing an error. + // The default value is three. A value less than zero means one try and no retries. + MaxRetries int32 + + // RetryDelay specifies the initial amount of delay to use before retrying an operation. + // The delay increases exponentially with each retry up to the maximum specified by MaxRetryDelay. + // The default value is four seconds. A value less than zero means no delay between retries. + RetryDelay time.Duration + + // MaxRetryDelay specifies the maximum delay allowed before retrying an operation. + // Typically the value is greater than or equal to the value specified in RetryDelay. + // The default Value is 120 seconds. A value less than zero means there is no cap. 
+	MaxRetryDelay time.Duration
+}
+
+func setDefaults(o *RetryOptions) {
+	if o.MaxRetries == 0 {
+		o.MaxRetries = 3
+	} else if o.MaxRetries < 0 {
+		o.MaxRetries = 0
+	}
+	if o.MaxRetryDelay == 0 {
+		o.MaxRetryDelay = 120 * time.Second
+	} else if o.MaxRetryDelay < 0 {
+		// not really an unlimited cap, but sufficiently large enough to be considered as such
+		o.MaxRetryDelay = math.MaxInt64
+	}
+	if o.RetryDelay == 0 {
+		o.RetryDelay = 4 * time.Second
+	} else if o.RetryDelay < 0 {
+		o.RetryDelay = 0
+	}
+}
+
+// (adapted from azcore/policy_retry)
+func calcDelay(o RetryOptions, try int32) time.Duration {
+	if try == 0 {
+		return 0
+	}
+
+	pow := func(number int64, exponent int32) int64 { // pow is a nested helper function
+		var result int64 = 1
+		for n := int32(0); n < exponent; n++ {
+			result *= number
+		}
+		return result
+	}
+
+	delay := time.Duration(pow(2, try)-1) * o.RetryDelay
+
+	// Introduce some jitter: [0.0, 1.0) / 2 = [0.0, 0.5) + 0.8 = [0.8, 1.3)
+	delay = time.Duration(delay.Seconds() * (rand.Float64()/2 + 0.8) * float64(time.Second)) // NOTE: We want math/rand; not crypto/rand
+	if delay > o.MaxRetryDelay {
+		delay = o.MaxRetryDelay
+	}
+	return delay
+}
diff --git a/sdk/messaging/azservicebus/internal/utils/retrier_test.go b/sdk/messaging/azservicebus/internal/utils/retrier_test.go
new file mode 100644
index 000000000000..fc5826f58860
--- /dev/null
+++ b/sdk/messaging/azservicebus/internal/utils/retrier_test.go
@@ -0,0 +1,235 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +package utils + +import ( + "context" + "errors" + "fmt" + "math" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRetrier(t *testing.T) { + t.Run("Succeeds", func(t *testing.T) { + ctx := context.Background() + + called := 0 + + err := Retry(ctx, "Retrier", func(ctx context.Context, args *RetryFnArgs) error { + called++ + return nil + }, func(err error) bool { + panic("won't get called") + }, RetryOptions{}) + + require.Nil(t, err) + require.EqualValues(t, 1, called) + }) + + t.Run("FailsThenSucceeds", func(t *testing.T) { + ctx := context.Background() + + called := 0 + isFatalCalled := 0 + + isFatalFn := func(err error) bool { + require.NotNil(t, err) + // we'll just keep saying the errors aren't fatal. + isFatalCalled++ + return false + } + + err := Retry(ctx, "FailsThenSucceeds", func(ctx context.Context, args *RetryFnArgs) error { + called++ + + if args.I == 3 { + // we're on the last iteration, succeed + return nil + } + + return fmt.Errorf("Error, iteration %d", args.I) + }, isFatalFn, fastRetryOptions) + + require.EqualValues(t, 4, called) + require.EqualValues(t, 3, isFatalCalled) + + // if an attempt succeeds then there's no error (despite previous failed tries) + require.NoError(t, err) + }) + + t.Run("FatalFailure", func(t *testing.T) { + ctx := context.Background() + called := 0 + + isFatalFn := func(err error) bool { + require.EqualValues(t, "isFatalFn says this is a fatal error", err.Error()) + return true + } + + err := Retry(ctx, "FatalFailure", func(ctx context.Context, args *RetryFnArgs) error { + called++ + return errors.New("isFatalFn says this is a fatal error") + }, isFatalFn, RetryOptions{}) + + require.EqualValues(t, "isFatalFn says this is a fatal error", err.Error()) + require.EqualValues(t, 1, called) + }) + + t.Run("Cancellation", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + isFatalFn := func(err error) bool { + return errors.Is(err, context.Canceled) + } + 
+
+	// it's up to the retried function itself to handle cancellation (see NOTE below).
+	err := Retry(ctx, "Cancellation", func(ctx context.Context, args *RetryFnArgs) error {
+		// NOTE: it's up to the underlying function to handle cancellation. `Retry` doesn't
+		// do anything but propagate it.
+		select {
+		case <-ctx.Done():
+		default:
+			require.Fail(t, "Context should have been cancelled")
+		}
+
+		return context.Canceled
+	}, isFatalFn, RetryOptions{})
+
+	require.ErrorIs(t, context.Canceled, err)
+})
+
+t.Run("ResetAttempts", func(t *testing.T) {
+	isFatalFn := func(err error) bool {
+		return errors.Is(err, context.Canceled)
+	}
+
+	customRetryOptions := fastRetryOptions
+	customRetryOptions.MaxRetries = 1
+
+	var actualAttempts []int32
+
+	err := Retry(context.Background(), "ResetAttempts", func(ctx context.Context, args *RetryFnArgs) error {
+		actualAttempts = append(actualAttempts, args.I)
+
+		if len(actualAttempts) == 3 {
+			args.ResetAttempts()
+		}
+
+		return errors.New("whatever")
+	}, isFatalFn, RetryOptions{
+		MaxRetries:    2,
+		RetryDelay:    time.Millisecond,
+		MaxRetryDelay: time.Millisecond,
+	})
+
+	expectedAttempts := []int32{
+		0, 1, 2, // we reset attempts here.
+		1, 2, // and we start at the first retry attempt again.
+ } + + require.EqualValues(t, "whatever", err.Error()) + require.EqualValues(t, expectedAttempts, actualAttempts) + }) + + t.Run("DisableRetries", func(t *testing.T) { + isFatalFn := func(err error) bool { + return errors.Is(err, context.Canceled) + } + + customRetryOptions := fastRetryOptions + customRetryOptions.MaxRetries = -1 + + called := 0 + + err := Retry(context.Background(), "ResetAttempts", func(ctx context.Context, args *RetryFnArgs) error { + called++ + return errors.New("whatever") + }, isFatalFn, customRetryOptions) + + require.EqualValues(t, 1, called) + require.EqualValues(t, "whatever", err.Error()) + }) +} + +func Test_calcDelay(t *testing.T) { + t.Run("can't exceed max retry delay", func(t *testing.T) { + duration := calcDelay(RetryOptions{ + RetryDelay: time.Hour, + MaxRetryDelay: time.Minute, + }, 1) + + require.EqualValues(t, time.Minute, duration) + }) + + t.Run("increases with jitter", func(t *testing.T) { + duration := calcDelay(RetryOptions{ + RetryDelay: time.Minute, + MaxRetryDelay: time.Hour, + }, 1) + + require.GreaterOrEqual(t, duration, time.Duration((2-1)*time.Minute.Seconds()*0.8*float64(time.Second))) + require.LessOrEqual(t, duration, time.Duration((2-1)*time.Minute.Seconds()*1.3*float64(time.Second))) + + duration = calcDelay(RetryOptions{ + RetryDelay: time.Minute, + MaxRetryDelay: time.Hour, + }, 2) + + require.GreaterOrEqual(t, duration, time.Duration((2*2-1)*time.Minute.Seconds()*0.8*float64(time.Second))) + require.LessOrEqual(t, duration, time.Duration((2*2-1)*time.Minute.Seconds()*1.3*float64(time.Second))) + + duration = calcDelay(RetryOptions{ + RetryDelay: time.Minute, + MaxRetryDelay: time.Hour, + }, 3) + + require.GreaterOrEqual(t, duration, time.Duration((2*2*2-1)*time.Minute.Seconds()*0.8*float64(time.Second))) + require.LessOrEqual(t, duration, time.Duration((2*2*2-1)*time.Minute.Seconds()*1.3*float64(time.Second))) + }) +} + +var fastRetryOptions = RetryOptions{ + // note: omitting MaxRetries just to give a 
sanity check that + // we do setDefaults() before we run. + RetryDelay: time.Millisecond, + MaxRetryDelay: time.Millisecond, +} + +func TestRetryDefaults(t *testing.T) { + ro := RetryOptions{} + setDefaults(&ro) + + require.EqualValues(t, 3, ro.MaxRetries) + require.EqualValues(t, 4*time.Second, ro.RetryDelay) + require.EqualValues(t, 2*time.Minute, ro.MaxRetryDelay) + + // this is an interesting default. Anything < 0 basically + // causes the max delay to be "infinite" + ro.MaxRetryDelay = -1 + // whereas this just normalizes to '0' + ro.RetryDelay = -1 + ro.MaxRetries = -1 + setDefaults(&ro) + require.EqualValues(t, time.Duration(math.MaxInt64), ro.MaxRetryDelay) + require.EqualValues(t, 0, ro.MaxRetries) + require.EqualValues(t, time.Duration(0), ro.RetryDelay) +} + +func TestCalcDelay(t *testing.T) { + // calcDelay introduces some jitter, automatically. + ro := RetryOptions{} + setDefaults(&ro) + d := calcDelay(ro, 0) + require.EqualValues(t, 0, d) + + // by default the first calc is 2^attempt + d = calcDelay(ro, 1) + require.LessOrEqual(t, d, 6*time.Second) + require.GreaterOrEqual(t, d, time.Second) +} diff --git a/sdk/messaging/azservicebus/liveTestHelpers_test.go b/sdk/messaging/azservicebus/liveTestHelpers_test.go index 850fb5afb504..353183ef141f 100644 --- a/sdk/messaging/azservicebus/liveTestHelpers_test.go +++ b/sdk/messaging/azservicebus/liveTestHelpers_test.go @@ -45,6 +45,9 @@ func createQueue(t *testing.T, connectionString string, queueProperties *admin.Q queueProperties = &admin.QueueProperties{} } + autoDeleteOnIdle := 5 * time.Minute + queueProperties.AutoDeleteOnIdle = &autoDeleteOnIdle + _, err = adminClient.CreateQueue(context.Background(), queueName, queueProperties, nil) require.NoError(t, err) diff --git a/sdk/messaging/azservicebus/message.go b/sdk/messaging/azservicebus/message.go index 44fa4963528f..f71dbc9f14f7 100644 --- a/sdk/messaging/azservicebus/message.go +++ b/sdk/messaging/azservicebus/message.go @@ -4,14 +4,14 @@ package 
azservicebus import ( - "context" "errors" "fmt" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" "github.com/Azure/go-amqp" - "github.com/devigned/tab" ) // ReceivedMessage is a received message from a Client.NewReceiver(). @@ -193,7 +193,7 @@ func (m *Message) toAMQPMessage() *amqp.Message { // newReceivedMessage creates a received message from an AMQP message. // NOTE: this converter assumes that the Body of this message will be the first // serialized byte array in the Data section of the messsage. -func newReceivedMessage(ctxForLogging context.Context, amqpMsg *amqp.Message) *ReceivedMessage { +func newReceivedMessage(amqpMsg *amqp.Message) *ReceivedMessage { msg := &ReceivedMessage{ rawAMQPMessage: amqpMsg, } @@ -302,7 +302,7 @@ func newReceivedMessage(ctxForLogging context.Context, amqpMsg *amqp.Message) *R if err == nil { msg.LockToken = *(*amqp.UUID)(lockToken) } else { - tab.For(ctxForLogging).Info(fmt.Sprintf("msg.DeliveryTag could not be converted into a UUID: %s", err.Error())) + log.Writef(internal.EventReceiver, "msg.DeliveryTag could not be converted into a UUID: %s", err.Error()) } } diff --git a/sdk/messaging/azservicebus/messageSettler.go b/sdk/messaging/azservicebus/messageSettler.go index 18b967bf5e86..cc54c49c1c2b 100644 --- a/sdk/messaging/azservicebus/messageSettler.go +++ b/sdk/messaging/azservicebus/messageSettler.go @@ -5,14 +5,12 @@ package azservicebus import ( "context" - "errors" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" "github.com/Azure/go-amqp" ) -var errReceiveAndDeleteReceiver = errors.New("messages that are received in `ReceiveModeReceiveAndDelete` mode are not settleable") - type settler interface { CompleteMessage(ctx context.Context, message *ReceivedMessage) error AbandonMessage(ctx 
context.Context, message *ReceivedMessage, options *AbandonMessageOptions) error @@ -21,15 +19,17 @@ type settler interface { } type messageSettler struct { - links internal.AMQPLinks + links internal.AMQPLinks + retryOptions utils.RetryOptions + + // used only for tests onlyDoBackupSettlement bool - baseRetrier internal.Retrier } -func newMessageSettler(links internal.AMQPLinks, baseRetrier internal.Retrier) settler { +func newMessageSettler(links internal.AMQPLinks, retryOptions utils.RetryOptions) settler { return &messageSettler{ - links: links, - baseRetrier: baseRetrier, + links: links, + retryOptions: retryOptions, } } @@ -39,44 +39,27 @@ func (s *messageSettler) useManagementLink(m *ReceivedMessage, receiver internal m.rawAMQPMessage.LinkName() != receiver.LinkName() } -func (s *messageSettler) settleWithRetries(ctx context.Context, message *ReceivedMessage, settleFn func(receiver internal.AMQPReceiver, mgmt internal.MgmtClient) error) error { +func (s *messageSettler) settleWithRetries(ctx context.Context, message *ReceivedMessage, settleFn func(receiver internal.AMQPReceiver, rpcLink internal.RPCLink) error) error { if s == nil { - return errReceiveAndDeleteReceiver + return internal.ErrNonRetriable{Message: "messages that are received in `ReceiveModeReceiveAndDelete` mode are not settleable"} } - retrier := s.baseRetrier.Copy() - var lastErr error - - for retrier.Try(ctx) { - var receiver internal.AMQPReceiver - var mgmt internal.MgmtClient - var linkRevision uint64 - - _, receiver, mgmt, linkRevision, lastErr = s.links.Get(ctx) - - if lastErr != nil { - _ = s.links.RecoverIfNeeded(ctx, linkRevision, lastErr) - continue + err := s.links.Retry(ctx, "settle", func(ctx context.Context, lwid *internal.LinksWithID, args *utils.RetryFnArgs) error { + if err := settleFn(lwid.Receiver, lwid.RPC); err != nil { + return err } - lastErr := settleFn(receiver, mgmt) - - if lastErr != nil { - _ = s.links.RecoverIfNeeded(ctx, linkRevision, lastErr) - continue - } - - 
break - } + return nil + }, utils.RetryOptions{}) - return lastErr + return err } // CompleteMessage completes a message, deleting it from the queue or subscription. func (s *messageSettler) CompleteMessage(ctx context.Context, message *ReceivedMessage) error { - return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, mgmt internal.MgmtClient) error { + return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, rpcLink internal.RPCLink) error { if s.useManagementLink(message, receiver) { - return mgmt.SendDisposition(ctx, bytesToAMQPUUID(message.LockToken), internal.Disposition{Status: internal.CompletedDisposition}, nil) + return internal.SendDisposition(ctx, rpcLink, bytesToAMQPUUID(message.LockToken), internal.Disposition{Status: internal.CompletedDisposition}, nil) } else { return receiver.AcceptMessage(ctx, message.rawAMQPMessage) } @@ -92,7 +75,7 @@ type AbandonMessageOptions struct { // This will increment its delivery count, and potentially cause it to be dead lettered // depending on your queue or subscription's configuration. 
func (s *messageSettler) AbandonMessage(ctx context.Context, message *ReceivedMessage, options *AbandonMessageOptions) error { - return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, mgmt internal.MgmtClient) error { + return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, rpcLink internal.RPCLink) error { if s.useManagementLink(message, receiver) { d := internal.Disposition{ Status: internal.AbandonedDisposition, @@ -104,7 +87,7 @@ func (s *messageSettler) AbandonMessage(ctx context.Context, message *ReceivedMe propertiesToModify = options.PropertiesToModify } - return mgmt.SendDisposition(ctx, bytesToAMQPUUID(message.LockToken), d, propertiesToModify) + return internal.SendDisposition(ctx, rpcLink, bytesToAMQPUUID(message.LockToken), d, propertiesToModify) } var annotations amqp.Annotations @@ -125,7 +108,7 @@ type DeferMessageOptions struct { // DeferMessage will cause a message to be deferred. Deferred messages // can be received using `Receiver.ReceiveDeferredMessages`. func (s *messageSettler) DeferMessage(ctx context.Context, message *ReceivedMessage, options *DeferMessageOptions) error { - return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, mgmt internal.MgmtClient) error { + return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, rpcLink internal.RPCLink) error { if s.useManagementLink(message, receiver) { d := internal.Disposition{ Status: internal.DeferredDisposition, @@ -137,7 +120,7 @@ func (s *messageSettler) DeferMessage(ctx context.Context, message *ReceivedMess propertiesToModify = options.PropertiesToModify } - return mgmt.SendDisposition(ctx, bytesToAMQPUUID(message.LockToken), d, propertiesToModify) + return internal.SendDisposition(ctx, rpcLink, bytesToAMQPUUID(message.LockToken), d, propertiesToModify) } var annotations amqp.Annotations @@ -167,7 +150,7 @@ type DeadLetterOptions struct { // queue or subscription. 
To receive these messages create a receiver with `Client.NewReceiver()` // using the `SubQueue` option. func (s *messageSettler) DeadLetterMessage(ctx context.Context, message *ReceivedMessage, options *DeadLetterOptions) error { - return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, mgmt internal.MgmtClient) error { + return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, rpcLink internal.RPCLink) error { reason := "" description := "" @@ -194,7 +177,7 @@ func (s *messageSettler) DeadLetterMessage(ctx context.Context, message *Receive propertiesToModify = options.PropertiesToModify } - return mgmt.SendDisposition(ctx, bytesToAMQPUUID(message.LockToken), d, propertiesToModify) + return internal.SendDisposition(ctx, rpcLink, bytesToAMQPUUID(message.LockToken), d, propertiesToModify) } info := map[string]interface{}{ diff --git a/sdk/messaging/azservicebus/messageSettler_test.go b/sdk/messaging/azservicebus/messageSettler_test.go index a88a577dfece..91b09139c19d 100644 --- a/sdk/messaging/azservicebus/messageSettler_test.go +++ b/sdk/messaging/azservicebus/messageSettler_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" "github.com/stretchr/testify/require" ) @@ -109,10 +110,12 @@ func TestMessageSettlementUsingReceiverWithReceiveAndDelete(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, messages) - require.EqualError(t, receiver.AbandonMessage(ctx, messages[0], nil), "messages that are received in `ReceiveModeReceiveAndDelete` mode are not settleable") - require.EqualError(t, receiver.CompleteMessage(ctx, messages[0]), "messages that are received in `ReceiveModeReceiveAndDelete` mode are not settleable") + require.EqualValues(t, internal.RecoveryKindFatal, internal.GetSBErrInfo(receiver.AbandonMessage(ctx, messages[0], nil)).RecoveryKind) + require.EqualValues(t, internal.RecoveryKindFatal, 
internal.GetSBErrInfo(receiver.CompleteMessage(ctx, messages[0])).RecoveryKind) + require.EqualValues(t, internal.RecoveryKindFatal, internal.GetSBErrInfo(receiver.DeferMessage(ctx, messages[0], nil)).RecoveryKind) + require.EqualValues(t, internal.RecoveryKindFatal, internal.GetSBErrInfo(receiver.DeadLetterMessage(ctx, messages[0], nil)).RecoveryKind) + require.EqualError(t, receiver.DeadLetterMessage(ctx, messages[0], nil), "messages that are received in `ReceiveModeReceiveAndDelete` mode are not settleable") - require.EqualError(t, receiver.DeferMessage(ctx, messages[0], nil), "messages that are received in `ReceiveModeReceiveAndDelete` mode are not settleable") } func TestDeferredMessages(t *testing.T) { diff --git a/sdk/messaging/azservicebus/message_test.go b/sdk/messaging/azservicebus/message_test.go index 6c65c311cc35..823fae6764c2 100644 --- a/sdk/messaging/azservicebus/message_test.go +++ b/sdk/messaging/azservicebus/message_test.go @@ -4,7 +4,6 @@ package azservicebus import ( - "context" "testing" "time" @@ -52,7 +51,7 @@ func TestMessageUnitTest(t *testing.T) { func TestAMQPMessageToReceivedMessage(t *testing.T) { t.Run("empty_message", func(t *testing.T) { // nothing should blow up. 
- rm := newReceivedMessage(context.Background(), &amqp.Message{}) + rm := newReceivedMessage(&amqp.Message{}) require.NotNil(t, rm) }) @@ -72,7 +71,7 @@ func TestAMQPMessageToReceivedMessage(t *testing.T) { }, } - receivedMessage := newReceivedMessage(context.Background(), amqpMessage) + receivedMessage := newReceivedMessage(amqpMessage) require.EqualValues(t, lockedUntil, *receivedMessage.LockedUntil) require.EqualValues(t, int64(101), *receivedMessage.SequenceNumber) @@ -132,7 +131,7 @@ func TestAMQPMessageToMessage(t *testing.T) { Data: [][]byte{[]byte("foo")}, } - msg := newReceivedMessage(context.Background(), amqpMsg) + msg := newReceivedMessage(amqpMsg) require.EqualValues(t, msg.MessageID, amqpMsg.Properties.MessageID, "messageID") require.EqualValues(t, msg.SessionID, amqpMsg.Properties.GroupID, "groupID") diff --git a/sdk/messaging/azservicebus/processor.go b/sdk/messaging/azservicebus/processor.go deleted file mode 100644 index a3ec0683afa2..000000000000 --- a/sdk/messaging/azservicebus/processor.go +++ /dev/null @@ -1,408 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azservicebus - -import ( - "context" - "errors" - "fmt" - "sync" - "time" - - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" - "github.com/Azure/go-amqp" - "github.com/devigned/tab" -) - -// NOTE: this type is experimental - -// processorOptions contains options for the `Client.NewProcessorForQueue` or -// `Client.NewProcessorForSubscription` functions. -type processorOptions struct { - // ReceiveMode controls when a message is deleted from Service Bus. - // - // `azservicebus.PeekLock` is the default. The message is locked, preventing multiple - // receivers from processing the message at once. 
You control the lock state of the message - // using one of the message settlement functions like processor.CompleteMessage(), which removes - // it from Service Bus, or processor.AbandonMessage(), which makes it available again. - // - // `azservicebus.ReceiveAndDelete` causes Service Bus to remove the message as soon - // as it's received. - // - // More information about receive modes: - // https://docs.microsoft.com/azure/service-bus-messaging/message-transfers-locks-settlement#settling-receive-operations - ReceiveMode ReceiveMode - - // SubQueue should be set to connect to the sub queue (ex: dead letter queue) - // of the queue or subscription. - SubQueue SubQueue - - // DisableAutoComplete controls whether messages must be settled explicitly via the - // settlement methods (ie, Complete, Abandon) or if the processor will automatically - // settle messages. - // - // If true, no automatic settlement is done. - // If false, the return value of your `handleMessage` function will control if the - // message is abandoned (non-nil error return) or completed (nil error return). - // - // This option is false, by default. - DisableAutoComplete bool - - // MaxConcurrentCalls controls the maximum number of message processing - // goroutines that are active at any time. - // Default is 1. - MaxConcurrentCalls int -} - -// processor is a push-based receiver for Service Bus. 
-type processor struct { - receiveMode ReceiveMode - autoComplete bool - maxConcurrentCalls int - - settler settler - amqpLinks internal.AMQPLinks - - mu *sync.Mutex - - userMessageHandler func(message *ReceivedMessage) error - userErrorHandler func(err error) - - receiversCtx context.Context - cancelReceivers func() - - wg sync.WaitGroup - - baseRetrier internal.Retrier - cleanupOnClose func() -} - -func applyProcessorOptions(p *processor, entity *entity, options *processorOptions) error { - if options == nil { - p.maxConcurrentCalls = 1 - p.receiveMode = ReceiveModePeekLock - p.autoComplete = true - - return nil - } - - p.autoComplete = !options.DisableAutoComplete - - if err := checkReceiverMode(options.ReceiveMode); err != nil { - return err - } - - p.receiveMode = options.ReceiveMode - - if err := entity.SetSubQueue(options.SubQueue); err != nil { - return err - } - - if options.MaxConcurrentCalls > 0 { - p.maxConcurrentCalls = options.MaxConcurrentCalls - } - - return nil -} - -func newProcessor(ns internal.NamespaceWithNewAMQPLinks, entity *entity, cleanupOnClose func(), options *processorOptions) (*processor, error) { - processor := &processor{ - // TODO: make this configurable - baseRetrier: internal.NewBackoffRetrier(internal.BackoffRetrierParams{ - Factor: 1.5, - Min: time.Second, - Max: time.Minute, - MaxRetries: 10, - }), - cleanupOnClose: cleanupOnClose, - mu: &sync.Mutex{}, - } - - if err := applyProcessorOptions(processor, entity, options); err != nil { - return nil, err - } - - entityPath, err := entity.String() - - if err != nil { - return nil, err - } - - processor.amqpLinks = ns.NewAMQPLinks(entityPath, func(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) { - linkOptions := createLinkOptions(processor.receiveMode, entityPath) - _, receiver, err := createReceiverLink(ctx, session, linkOptions) - - if err != nil { - return nil, nil, err - } - - if err := 
receiver.IssueCredit(uint32(processor.maxConcurrentCalls)); err != nil { - _ = receiver.Close(ctx) - return nil, nil, err - } - - return nil, receiver, nil - }) - - processor.settler = newMessageSettler(processor.amqpLinks, processor.baseRetrier) - processor.receiversCtx, processor.cancelReceivers = context.WithCancel(context.Background()) - - return processor, nil -} - -// Start will start receiving messages from the queue or subscription. -// -// if err := processor.Start(context.TODO(), messageHandler, errorHandler); err != nil { -// log.Fatalf("Processor failed to start: %s", err.Error()) -// } -// -// Any errors that occur (such as network disconnects, failures in handleMessage) will be -// sent to your handleError function. The processor will retry and restart as needed - -// no user intervention is required. -func (p *processor) Start(ctx context.Context, handleMessage func(message *ReceivedMessage) error, handleError func(err error)) error { - ctx, span := tab.StartSpan(ctx, tracing.SpanProcessorLoop) - defer span.End() - - err := func() error { - p.mu.Lock() - defer p.mu.Unlock() - - if p.userMessageHandler != nil { - return errors.New("processor already started") - } - - p.userMessageHandler = handleMessage - p.userErrorHandler = handleError - - p.receiversCtx, p.cancelReceivers = context.WithCancel(ctx) - - return nil - }() - - if err != nil { - return err - } - - for { - retrier := p.baseRetrier.Copy() - - for retrier.Try(p.receiversCtx) { - if err := p.subscribe(); err != nil { - if internal.IsCancelError(err) { - break - } - } - } - - select { - case <-p.receiversCtx.Done(): - // check, did they cancel or did we cancel? - select { - case <-ctx.Done(): - return ctx.Err() - default: - return nil - } - default: - } - } - -} - -// Close will wait for any pending callbacks to complete. -// NOTE: Close() cannot be called synchronously in a message -// or error handler. You must run it asynchronously using -// `go processor.Close(ctx)` or similar. 
-func (p *processor) Close(ctx context.Context) error { - p.mu.Lock() - defer p.mu.Unlock() - - if p.amqpLinks.ClosedPermanently() { - return nil - } - - ctx, span := tab.StartSpan(ctx, tracing.SpanProcessorClose) - defer span.End() - - defer func() { - if err := p.amqpLinks.Close(ctx, true); err != nil { - span.Logger().Debug(fmt.Sprintf("Error closing amqpLinks on processor.Close(): %s", err.Error())) - } - }() - - p.cleanupOnClose() - - _, receiver, _, _, err := p.amqpLinks.Get(ctx) - - if err != nil { - span.Logger().Error(err) - return err - } - - if err := receiver.DrainCredit(ctx); err != nil { - span.Logger().Error(err) - // fall through for now and just let whatever is going on finish - // otherwise they might not be able to actually close. - } - - p.cancelReceivers() - return utils.WaitForGroupOrContext(ctx, &p.wg) -} - -// CompleteMessage completes a message, deleting it from the queue or subscription. -func (p *processor) CompleteMessage(ctx context.Context, message *ReceivedMessage) error { - return p.settler.CompleteMessage(ctx, message) -} - -// AbandonMessage will cause a message to be returned to the queue or subscription. -// This will increment its delivery count, and potentially cause it to be dead lettered -// depending on your queue or subscription's configuration. -func (p *processor) AbandonMessage(ctx context.Context, message *ReceivedMessage, options *AbandonMessageOptions) error { - return p.settler.AbandonMessage(ctx, message, options) -} - -// DeferMessage will cause a message to be deferred. Deferred messages -// can be received using `Receiver.ReceiveDeferredMessages`. -func (p *processor) DeferMessage(ctx context.Context, message *ReceivedMessage, options *DeferMessageOptions) error { - return p.settler.DeferMessage(ctx, message, options) -} - -// DeadLetterMessage settles a message by moving it to the dead letter queue for a -// queue or subscription. 
To receive these messages create a processor with `Client.NewProcessorForQueue()` -// or `Client.NewProcessorForSubscription()` using the `ProcessorOptions.SubQueue` option. -func (p *processor) DeadLetterMessage(ctx context.Context, message *ReceivedMessage, options *DeadLetterOptions) error { - return p.settler.DeadLetterMessage(ctx, message, options) -} - -// subscribe continually receives messages from Service Bus, stopping -// if a fatal link/connection error occurs. -func (p *processor) subscribe() error { - p.wg.Add(1) - defer p.wg.Done() - - for { - _, receiver, _, linkRevision, err := p.amqpLinks.Get(p.receiversCtx) - - if err != nil { - if internal.IsCancelError(err) { - return err - } - - if err := p.amqpLinks.RecoverIfNeeded(p.receiversCtx, linkRevision, err); err != nil { - p.userErrorHandler(err) - return err - } - } - - amqpMessage, err := receiver.Receive(p.receiversCtx) - - if err != nil { - if internal.IsCancelError(err) { - return err - } - - if err := p.amqpLinks.RecoverIfNeeded(p.receiversCtx, linkRevision, err); err != nil { - p.userErrorHandler(err) - } - - return nil - } - - if amqpMessage == nil { - // amqpMessage shouldn't be nil here, but somehow it is. - // need to track this down in the AMQP library. - continue - } - - p.wg.Add(1) - - go func() { - defer p.wg.Done() - - // purposefully avoiding using `ctx`. We always let processing complete - // for message threads to avoid potential message loss. 
- _ = p.processMessage(context.Background(), receiver, amqpMessage) - }() - } -} - -func (p *processor) processMessage(ctx context.Context, receiver internal.AMQPReceiver, amqpMessage *amqp.Message) error { - ctx, span := tab.StartSpan(ctx, tracing.SpanProcessorMessage) - defer span.End() - - receivedMessage := newReceivedMessage(ctx, amqpMessage) - messageHandlerErr := p.userMessageHandler(receivedMessage) - - if messageHandlerErr != nil { - p.userErrorHandler(messageHandlerErr) - } - - if p.autoComplete { - var settleErr error - - if messageHandlerErr != nil { - settleErr = p.settler.AbandonMessage(ctx, receivedMessage, nil) - } else { - settleErr = p.settler.CompleteMessage(ctx, receivedMessage) - } - - if settleErr != nil { - p.userErrorHandler(fmt.Errorf("failed to settle message with ID '%s': %w", receivedMessage.MessageID, settleErr)) - return settleErr - } - } - - select { - case <-p.receiversCtx.Done(): - return nil - default: - } - - if err := receiver.IssueCredit(1); err != nil { - if !internal.IsDrainingError(err) { - p.userErrorHandler(err) - return fmt.Errorf("failed issuing additional credit, processor will be restarted: %w", err) - } - } - - return nil -} - -func checkReceiverMode(receiveMode ReceiveMode) error { - if receiveMode == ReceiveModePeekLock || receiveMode == ReceiveModeReceiveAndDelete { - return nil - } else { - return fmt.Errorf("invalid receive mode %d, must be either azservicebus.PeekLock or azservicebus.ReceiveAndDelete", receiveMode) - } -} - -// newProcessorForQueue creates a Processor for a queue. -func newProcessorForQueue(client *Client, queue string, options *processorOptions) (*processor, error) { - id, cleanupOnClose := client.getCleanupForCloseable() - - processor, err := newProcessor(client.namespace, &entity{Queue: queue}, cleanupOnClose, options) - - if err != nil { - return nil, err - } - - client.addCloseable(id, processor) - return processor, nil -} - -// newProcessorForQueue creates a Processor for a subscription. 
-func newProcessorForSubscription(client *Client, topic string, subscription string, options *processorOptions) (*processor, error) { - id, cleanupOnClose := client.getCleanupForCloseable() - - processor, err := newProcessor(client.namespace, &entity{Topic: topic, Subscription: subscription}, cleanupOnClose, options) - - if err != nil { - return nil, err - } - - client.addCloseable(id, processor) - return processor, nil -} diff --git a/sdk/messaging/azservicebus/processor_test.go b/sdk/messaging/azservicebus/processor_test.go deleted file mode 100644 index 73e85cb4101c..000000000000 --- a/sdk/messaging/azservicebus/processor_test.go +++ /dev/null @@ -1,368 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azservicebus - -import ( - "context" - "errors" - "fmt" - "sort" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestProcessorReceiveWithDefaults(t *testing.T) { - serviceBusClient, cleanup, queueName := setupLiveTest(t, nil) - defer cleanup() - - go func() { - sender, err := serviceBusClient.NewSender(queueName, nil) - require.NoError(t, err) - - defer sender.Close(context.Background()) - - // it's perfectly fine to have the processor started before the messages - // have been sent. 
- for i := 0; i < 5; i++ { - err = sender.SendMessage(context.Background(), &Message{ - Body: []byte(fmt.Sprintf("hello world %d", i)), - }) - - time.Sleep(time.Second) - } - require.NoError(t, err) - }() - - processor, err := newProcessorForQueue(serviceBusClient, queueName, nil) - require.NoError(t, err) - - defer processor.Close(context.Background()) // multiple close is fine - - var messages []string - mu := sync.Mutex{} - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) - defer cancel() - - err = processor.Start(ctx, func(m *ReceivedMessage) error { - mu.Lock() - defer mu.Unlock() - - body, err := m.Body() - require.NoError(t, err) - - messages = append(messages, string(body)) - - if len(messages) == 5 { - cancel() - } - - return nil - }, func(err error) { - if errors.Is(err, context.Canceled) { - return - } - - require.NoError(t, err) - }) - - require.ErrorIs(t, err, context.Canceled, "cancelling the context stops the processor") - - sort.Strings(messages) - - require.EqualValues(t, []string{ - "hello world 0", - "hello world 1", - "hello world 2", - "hello world 3", - "hello world 4", - }, messages) - - require.NoError(t, processor.Close(context.Background())) -} - -func TestProcessorReceiveWith100MessagesWithMaxConcurrency(t *testing.T) { - serviceBusClient, cleanup, queueName := setupLiveTest(t, nil) - defer cleanup() - - const numMessages = 100 - var expectedBodies []string - - go func() { - sender, err := serviceBusClient.NewSender(queueName, nil) - require.NoError(t, err) - - defer sender.Close(context.Background()) - - batch, err := sender.NewMessageBatch(context.Background(), nil) - require.NoError(t, err) - - // it's perfectly fine to have the processor started before the messages - // have been sent. 
- for i := 0; i < numMessages; i++ { - expectedBodies = append(expectedBodies, fmt.Sprintf("hello world %03d", i)) - err := batch.AddMessage(&Message{ - Body: []byte(expectedBodies[len(expectedBodies)-1]), - }) - require.NoError(t, err) - } - - require.NoError(t, sender.SendMessageBatch(context.Background(), batch)) - }() - - processor, err := newProcessorForQueue( - serviceBusClient, - queueName, - &processorOptions{ - MaxConcurrentCalls: 20, - }) - - require.NoError(t, err) - - defer func() { - require.NoError(t, processor.Close(context.Background())) // multiple close is fine - }() - - var messages []string - mu := sync.Mutex{} - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*2) - defer cancel() - - err = processor.Start(ctx, func(m *ReceivedMessage) error { - mu.Lock() - defer mu.Unlock() - - body, err := m.Body() - require.NoError(t, err) - messages = append(messages, string(body)) - - if len(messages) == 100 { - go processor.Close(context.Background()) - } - - return nil - }, func(err error) { - if errors.Is(err, context.Canceled) { - return - } - - require.NoError(t, err) - }) - - require.NoError(t, err) - - sort.Strings(messages) - require.EqualValues(t, expectedBodies, messages) - - require.NoError(t, processor.Close(ctx)) -} - -func TestProcessorUnitTests(t *testing.T) { - p := &processor{} - e := &entity{} - - require.NoError(t, applyProcessorOptions(p, e, nil)) - require.True(t, p.autoComplete) - require.EqualValues(t, 1, p.maxConcurrentCalls) - require.EqualValues(t, ReceiveModePeekLock, p.receiveMode) - - p = &processor{} - e = &entity{ - Queue: "queue", - } - - require.NoError(t, applyProcessorOptions(p, e, &processorOptions{ - ReceiveMode: ReceiveModeReceiveAndDelete, - SubQueue: SubQueueDeadLetter, - DisableAutoComplete: true, - MaxConcurrentCalls: 101, - })) - - require.False(t, p.autoComplete) - require.EqualValues(t, 101, p.maxConcurrentCalls) - require.EqualValues(t, ReceiveModeReceiveAndDelete, p.receiveMode) - 
fullEntityPath, err := e.String() - require.NoError(t, err) - require.EqualValues(t, "queue/$DeadLetterQueue", fullEntityPath) -} - -// func TestProcessorUnitTests(t *testing.T) { -// t.Run("Processor", func(t *testing.T) { -// t.Run("StartAndClose", func(t *testing.T) { - -// }) - -// t.Run("CloseWaitsForActiveSubscribersToExit", func(t *testing.T) { -// }) - -// t.Run("CloseWithoutStart", func(t *testing.T) { - -// }) - -// t.Run("DoubleClose", func(t *testing.T) { - -// }) -// }) - -// t.Run("subscribe", func(t *testing.T) { -// t.Run("cancelled by user does not retry", func(t *testing.T) { -// ctx, cancel := context.WithCancel(context.Background()) -// cancel() // pre-cancel this context - -// receiver := internal.NewFakeLegacyReceiver() -// var cancelledError error - -// retry := subscribe(ctx, receiver, true, func(message *ReceivedMessage) error { -// return nil -// }, func(err error) { -// cancelledError = err -// }) - -// require.EqualError(t, cancelledError, context.Canceled.Error()) -// require.False(t, retry, "User cancelling the context will not be retried") -// require.False(t, receiver.CloseCalled) // subscribe() is not responsible for the lifetime of the receiver -// }) - -// t.Run("error in the listener is retryable", func(t *testing.T) { -// receiver := internal.NewFakeLegacyReceiver() - -// receiver.ListenImpl = func(ctx context.Context, handler internal.Handler) internal.ListenerHandle { -// ch := make(chan struct{}) -// close(ch) -// return &internal.FakeListenerHandle{ -// DoneChan: ch, -// ErrValue: errors.New("Some AMQP related error"), -// } -// } - -// var errorFromListener error - -// retry := subscribe(context.Background(), receiver, true, func(message *ReceivedMessage) error { -// return nil -// }, func(err error) { -// errorFromListener = err -// }) - -// require.EqualError(t, errorFromListener, "Some AMQP related error") -// require.True(t, retry, "AMQP errors will cause us to retry") -// require.False(t, receiver.CloseCalled) // 
subscribe() is not responsible for the lifetime of the receiver -// }) - -// }) - -// t.Run("handleSingleMessage", func(t *testing.T) { -// fakeMessage := &internal.Message{ -// ID: "fakeID", -// LockToken: &uuid.UUID{}, -// SystemProperties: &internal.SystemProperties{ -// SequenceNumber: to.Int64Ptr(1), -// }, -// } - -// setup := func() *internal.FakeInternalReceiver { -// receiver := internal.NewFakeLegacyReceiver() -// return receiver -// } - -// t.Run("AutoCompleteCompleteMessage", func(t *testing.T) { -// receiver := setup() - -// handleSingleMessage(func(message *ReceivedMessage) error { -// // successful return -// return nil -// }, func(err error) { -// require.NoError(t, err) -// }, true, receiver, fakeMessage) - -// require.True(t, receiver.CompleteCalled) -// require.False(t, receiver.AbandonCalled) -// }) - -// t.Run("AutoCompleteAbandonMessage", func(t *testing.T) { -// receiver := setup() - -// handleSingleMessage(func(message *ReceivedMessage) error { -// // error returned will abandon -// return errors.New("Purposefully reported error") -// }, func(err error) { -// require.EqualErrorf(t, err, "Purposefully reported error", "Error from the handler gets forwarded") -// }, true, receiver, fakeMessage) - -// require.True(t, receiver.AbandonCalled) -// require.False(t, receiver.CompleteCalled) -// }) - -// t.Run("AutoCompleteAlreadySettledDoNotSettleTwice)", func(t *testing.T) { -// receiver := setup() - -// handleSingleMessage(func(message *ReceivedMessage) error { -// // error returned will abandon -// return errors.New("Purposefully reported error") -// }, func(err error) { -// require.EqualErrorf(t, err, "Purposefully reported error", "Error from the handler gets forwarded") -// }, true, receiver, fakeMessage) - -// // TODO: neither should be called - the message was already settled. 
-// require.True(t, receiver.AbandonCalled) -// require.False(t, receiver.CompleteCalled) -// }) - -// t.Run("autoComplete (off)", func(t *testing.T) { -// receiver := setup() - -// handleSingleMessage(func(message *ReceivedMessage) error { -// // successful return -// return nil -// }, func(err error) { -// require.NoError(t, err) -// }, false, receiver, fakeMessage) - -// require.False(t, receiver.CompleteCalled) -// require.False(t, receiver.AbandonCalled) -// }) - -// t.Run("SettlementErrorsAreForwarded(complete)", func(t *testing.T) { -// receiver := setup() - -// receiver.CompleteMessageImpl = func(ctx context.Context, msg *internal.Message) error { -// return errors.New("Complete failed") -// } - -// var settleError error - -// handleSingleMessage(func(message *ReceivedMessage) error { -// return nil -// }, func(err error) { -// settleError = err -// }, true, receiver, fakeMessage) - -// require.EqualError(t, settleError, "Complete failed") -// }) - -// t.Run("SettlementErrorsAreForwarded(abandon)", func(t *testing.T) { -// receiver := setup() - -// receiver.AbandonMessageImpl = func(ctx context.Context, msg *internal.Message) error { -// return errors.New("Abandon failed") -// } - -// var settleErrors []string - -// handleSingleMessage(func(message *ReceivedMessage) error { -// return errors.New("Error that caused the abandon") -// }, func(err error) { -// settleErrors = append(settleErrors, err.Error()) -// }, true, receiver, fakeMessage) - -// require.EqualValues(t, settleErrors, []string{ -// "Error that caused the abandon", -// "Abandon failed", -// }) -// }) -// }) -// } diff --git a/sdk/messaging/azservicebus/receiver.go b/sdk/messaging/azservicebus/receiver.go index 55d92bf87d81..a0519406501e 100644 --- a/sdk/messaging/azservicebus/receiver.go +++ b/sdk/messaging/azservicebus/receiver.go @@ -10,7 +10,9 @@ import ( "sync" "time" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" 
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" "github.com/Azure/go-amqp" "github.com/devigned/tab" ) @@ -42,9 +44,10 @@ const ( // Receiver receives messages using pull based functions (ReceiveMessages). type Receiver struct { receiveMode ReceiveMode + entityPath string settler settler - baseRetrier internal.Retrier + retryOptions utils.RetryOptions cleanupOnClose func() lastPeekedSequenceNumber int64 @@ -74,65 +77,68 @@ type ReceiverOptions struct { // SubQueue should be set to connect to the sub queue (ex: dead letter queue) // of the queue or subscription. SubQueue SubQueue + + retryOptions utils.RetryOptions } const defaultLinkRxBuffer = 2048 func applyReceiverOptions(receiver *Receiver, entity *entity, options *ReceiverOptions) error { + if options == nil { receiver.receiveMode = ReceiveModePeekLock - return nil - } + } else { + if err := checkReceiverMode(options.ReceiveMode); err != nil { + return err + } - if err := checkReceiverMode(options.ReceiveMode); err != nil { - return err + receiver.receiveMode = options.ReceiveMode + + if err := entity.SetSubQueue(options.SubQueue); err != nil { + return err + } + + receiver.retryOptions = options.retryOptions } - receiver.receiveMode = options.ReceiveMode + entityPath, err := entity.String() - if err := entity.SetSubQueue(options.SubQueue); err != nil { + if err != nil { return err } + receiver.entityPath = entityPath return nil } -func newReceiver(ns internal.NamespaceWithNewAMQPLinks, entity *entity, cleanupOnClose func(), options *ReceiverOptions, newLinksFn func(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error)) (*Receiver, error) { +type newReceiverArgs struct { + ns internal.NamespaceWithNewAMQPLinks + entity entity + cleanupOnClose func() + newLinkFn func(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, 
internal.AMQPReceiverCloser, error) +} + +func newReceiver(args newReceiverArgs, options *ReceiverOptions) (*Receiver, error) { receiver := &Receiver{ lastPeekedSequenceNumber: 0, - // TODO: make this configurable - baseRetrier: internal.NewBackoffRetrier(internal.BackoffRetrierParams{ - Factor: 1.5, - Jitter: true, - Min: time.Second, - Max: time.Minute, - MaxRetries: 10, - }), - cleanupOnClose: cleanupOnClose, + cleanupOnClose: args.cleanupOnClose, } - if err := applyReceiverOptions(receiver, entity, options); err != nil { + if err := applyReceiverOptions(receiver, &args.entity, options); err != nil { return nil, err } - entityPath, err := entity.String() + newLinkFn := receiver.newReceiverLink - if err != nil { - return nil, err - } - - if newLinksFn == nil { - newLinksFn = func(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) { - linkOptions := createLinkOptions(receiver.receiveMode, entityPath) - return createReceiverLink(ctx, session, linkOptions) - } + if args.newLinkFn != nil { + newLinkFn = args.newLinkFn } - receiver.amqpLinks = ns.NewAMQPLinks(entityPath, newLinksFn) + receiver.amqpLinks = args.ns.NewAMQPLinks(receiver.entityPath, newLinkFn) // 'nil' settler handles returning an error message for receiveAndDelete links. 
if receiver.receiveMode == ReceiveModePeekLock { - receiver.settler = newMessageSettler(receiver.amqpLinks, receiver.baseRetrier) + receiver.settler = newMessageSettler(receiver.amqpLinks, receiver.retryOptions) } else { receiver.settler = (*messageSettler)(nil) } @@ -140,6 +146,12 @@ func newReceiver(ns internal.NamespaceWithNewAMQPLinks, entity *entity, cleanupO return receiver, nil } +func (r *Receiver) newReceiverLink(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) { + linkOptions := createLinkOptions(r.receiveMode, r.entityPath) + link, err := createReceiverLink(ctx, session, linkOptions) + return nil, link, err +} + // ReceiveMessagesOptions are options for the ReceiveMessages function. type ReceiveMessagesOptions struct { // For future expansion @@ -174,25 +186,27 @@ func (r *Receiver) ReceiveMessages(ctx context.Context, maxMessages int, options // ReceiveDeferredMessages receives messages that were deferred using `Receiver.DeferMessage`. 
func (r *Receiver) ReceiveDeferredMessages(ctx context.Context, sequenceNumbers []int64) ([]*ReceivedMessage, error) { - _, _, mgmt, _, err := r.amqpLinks.Get(ctx) + var receivedMessages []*ReceivedMessage - if err != nil { - return nil, err - } + err := r.amqpLinks.Retry(ctx, "receiveDeferredMessage", func(ctx context.Context, lwid *internal.LinksWithID, args *utils.RetryFnArgs) error { + amqpMessages, err := internal.ReceiveDeferred(ctx, lwid.RPC, r.receiveMode, sequenceNumbers) - amqpMessages, err := mgmt.ReceiveDeferred(ctx, r.receiveMode, sequenceNumbers) + if err != nil { + return err + } - if err != nil { - return nil, err - } + for _, amqpMsg := range amqpMessages { + receivedMsg := newReceivedMessage(amqpMsg) + receivedMsg.deferred = true - var receivedMessages []*ReceivedMessage + receivedMessages = append(receivedMessages, receivedMsg) + } - for _, amqpMsg := range amqpMessages { - receivedMsg := newReceivedMessage(ctx, amqpMsg) - receivedMsg.deferred = true + return nil + }, utils.RetryOptions(r.retryOptions)) - receivedMessages = append(receivedMessages, receivedMsg) + if err != nil { + return nil, err } return receivedMessages, nil @@ -210,35 +224,39 @@ type PeekMessagesOptions struct { // like CompleteMessage, AbandonMessage, DeferMessage or DeadLetterMessage // will not work with them. 
func (r *Receiver) PeekMessages(ctx context.Context, maxMessageCount int, options *PeekMessagesOptions) ([]*ReceivedMessage, error) { - _, _, mgmt, _, err := r.amqpLinks.Get(ctx) + var receivedMessages []*ReceivedMessage - if err != nil { - return nil, err - } + err := r.amqpLinks.Retry(ctx, "peekMessages", func(ctx context.Context, links *internal.LinksWithID, args *utils.RetryFnArgs) error { + var sequenceNumber = r.lastPeekedSequenceNumber + 1 + updateInternalSequenceNumber := true - var sequenceNumber = r.lastPeekedSequenceNumber + 1 - updateInternalSequenceNumber := true + if options != nil && options.FromSequenceNumber != nil { + sequenceNumber = *options.FromSequenceNumber + updateInternalSequenceNumber = false + } - if options != nil && options.FromSequenceNumber != nil { - sequenceNumber = *options.FromSequenceNumber - updateInternalSequenceNumber = false - } + messages, err := internal.PeekMessages(ctx, links.RPC, sequenceNumber, int32(maxMessageCount)) - messages, err := mgmt.PeekMessages(ctx, sequenceNumber, int32(maxMessageCount)) + if err != nil { + return err + } - if err != nil { - return nil, err - } + receivedMessages = make([]*ReceivedMessage, len(messages)) - receivedMessages := make([]*ReceivedMessage, len(messages)) + for i := 0; i < len(messages); i++ { + receivedMessages[i] = newReceivedMessage(messages[i]) + } - for i := 0; i < len(messages); i++ { - receivedMessages[i] = newReceivedMessage(ctx, messages[i]) - } + if len(receivedMessages) > 0 && updateInternalSequenceNumber { + // only update this if they're doing the implicit iteration as part of the receiver. + r.lastPeekedSequenceNumber = *receivedMessages[len(receivedMessages)-1].SequenceNumber + } + + return nil + }, r.retryOptions) - if len(receivedMessages) > 0 && updateInternalSequenceNumber { - // only update this if they're doing the implicit iteration as part of the receiver. 
- r.lastPeekedSequenceNumber = *receivedMessages[len(receivedMessages)-1].SequenceNumber + if err != nil { + return nil, err } return receivedMessages, nil @@ -246,22 +264,19 @@ func (r *Receiver) PeekMessages(ctx context.Context, maxMessageCount int, option // RenewLock renews the lock on a message, updating the `LockedUntil` field on `msg`. func (r *Receiver) RenewMessageLock(ctx context.Context, msg *ReceivedMessage) error { - _, _, mgmt, _, err := r.amqpLinks.Get(ctx) + return r.amqpLinks.Retry(ctx, "renewMessageLock", func(ctx context.Context, linksWithVersion *internal.LinksWithID, args *utils.RetryFnArgs) error { + newExpirationTime, err := internal.RenewLocks(ctx, linksWithVersion.RPC, msg.rawAMQPMessage.LinkName(), []amqp.UUID{ + (amqp.UUID)(msg.LockToken), + }) - if err != nil { - return err - } - - newExpirationTime, err := mgmt.RenewLocks(ctx, msg.rawAMQPMessage.LinkName(), []amqp.UUID{ - (amqp.UUID)(msg.LockToken), - }) + if err != nil { + return err + } - if err != nil { - return err - } + msg.LockedUntil = &newExpirationTime[0] + return nil + }, r.retryOptions) - msg.LockedUntil = &newExpirationTime[0] - return nil } // Close permanently closes the receiver. @@ -333,20 +348,10 @@ func (r *Receiver) receiveMessagesImpl(ctx context.Context, maxMessages int, opt // user isn't actually waiting for anymore. So we make sure that #3 runs if the // link is still valid. // Phase 3. 
- _, receiver, _, linksRevision, err := r.amqpLinks.Get(ctx) - if err != nil { - if err := r.amqpLinks.RecoverIfNeeded(ctx, linksRevision, err); err != nil { - return nil, err - } - - return nil, err - } - - if err := receiver.IssueCredit(uint32(maxMessages)); err != nil { - _ = r.amqpLinks.RecoverIfNeeded(ctx, linksRevision, err) - return nil, err - } + var all []*ReceivedMessage + var cancel context.CancelFunc + needsDrain := true defaultTimeAfterFirstMessage := 20 * time.Millisecond @@ -354,97 +359,103 @@ func (r *Receiver) receiveMessagesImpl(ctx context.Context, maxMessages int, opt defaultTimeAfterFirstMessage = time.Second } - messages, err := r.getMessages(ctx, receiver, maxMessages, defaultTimeAfterFirstMessage) + var linksWithID *internal.LinksWithID + + err := r.amqpLinks.Retry(ctx, "receiveMessages", func(ctx context.Context, lwid *internal.LinksWithID, args *utils.RetryFnArgs) error { + if args.LastErr != nil { + // we recovered successfully (amqpLinks does it for us), we can reset our retry attempts. + // This fixes a potential problem where something like this happens: + // a. amqplink.Receive() returns an error + // b. we attempt recovery a few times and recover just in the nick of time at attempt #3. + // c. amqplink.Receive() blocks (no messages in queue, for instance) + // + // d. amqpLink.Receive() returns an error + // + // We don't want the # of attempts to already be '3' when we get to + // step (d), so we reset here so the next set of retry attempts starts fresh. + args.ResetAttempts() + } - if err != nil { - return nil, err - } + linksWithID = lwid + credits := maxMessages - len(all) - if len(messages) == maxMessages { - // no drain needed, all messages arrived. - return messages, nil - } + if err := lwid.Receiver.IssueCredit(uint32(credits)); err != nil { + return err + } - return r.drainLink(ctx, receiver, messages) -} + got := 0 -// drainLink initiates a drainLink on the link. 
Service Bus will send whatever messages it might have still had and -// set our link credit to 0. -// ctxForLoggingOnly is literally only used for when we need to extract context for logging. This function will always attempt -// to complete, ignoring cancellation, otherwise we can leave the link with messages that haven't been returned to the user. -func (r *Receiver) drainLink(ctxForLoggingOnly context.Context, receiver internal.AMQPReceiver, messages []*ReceivedMessage) ([]*ReceivedMessage, error) { - // start the drain asynchronously. Note that we ignore the user's context at this point - // since draining makes sure we don't get messages when nobody is receiving. - if err := receiver.DrainCredit(context.Background()); err != nil { - tab.For(ctxForLoggingOnly).Debug(fmt.Sprintf("Draining of credit failed. link will be closed and will re-open on next receive: %s", err.Error())) - - // if the drain fails we just close the link so it'll re-open at the next receive. - if err := r.amqpLinks.Close(context.Background(), false); err != nil { - tab.For(ctxForLoggingOnly).Debug(fmt.Sprintf("Failed to close links on ReceiveMessages cleanup. Not fatal: %s", err.Error())) - } - } + for { + amqpMessage, err := lwid.Receiver.Receive(ctx) - // Draining data from the receiver's prefetched queue. This won't wait for new messages to - // arrive, so it'll only receive messages that arrived prior to the drain. 
- for { - am, err := receiver.Prefetched(context.Background()) + if err != nil { + return err + } - if am == nil || internal.IsCancelError(err) { - break - } + all = append(all, newReceivedMessage(amqpMessage)) + got++ - if err != nil { - // something fatal happened, we will just - _ = r.amqpLinks.Close(context.TODO(), false) + if got == credits { + // no excess credits on link + needsDrain = false + break + } - if len(messages) > 0 { - return messages, nil - } else { - return nil, err + if cancel == nil { + // replace the context that we're using for everything with a new one that will cancel + // after a period of time. + ctx, cancel = context.WithTimeout(ctx, defaultTimeAfterFirstMessage) + defer cancel() } } - messages = append(messages, newReceivedMessage(ctxForLoggingOnly, am)) + return nil + }, utils.RetryOptions(r.retryOptions)) + + ret := func(err error) ([]*ReceivedMessage, error) { + if len(all) > 0 { + // we don't return the error here because we did retrieve _some_ messages and you can still + // use them. + return all, nil + } else { + return nil, err + } } - return messages, nil -} - -// getMessages receives messages until a link failure, timeout or the user -// cancels their context. -func (r *Receiver) getMessages(ctx context.Context, receiver internal.AMQPReceiver, maxMessages int, maxWaitTimeAfterFirstMessage time.Duration) ([]*ReceivedMessage, error) { - var messages []*ReceivedMessage - - for { - var amqpMessage *amqp.Message - amqpMessage, err := receiver.Receive(ctx) + if err != nil && !internal.IsCancelError(err) { + return ret(err) + } - if err != nil { - if internal.IsCancelError(err) { - return messages, nil + if needsDrain { + // start the drain asynchronously. Note that we ignore the user's context at this point + // since draining makes sure we don't get messages when nobody is receiving. 
+ if err := linksWithID.Receiver.DrainCredit(context.Background()); err != nil { + if err := r.amqpLinks.RecoverIfNeeded(context.Background(), linksWithID.ID, err); err != nil { + log.Writef(internal.EventReceiver, "Failed to recover links after a failed drain: %s", err.Error()) + return ret(err) } - // we'll close (instead of recovering) since we are holding onto messages - // and want to get them back to the user ASAP. (recovery will just happen - // on the next call to receive) - if err := r.amqpLinks.Close(context.Background(), false); err != nil { - tab.For(ctx).Debug(fmt.Sprintf("Failed to close links on ReceiveMessages cleanup. Not fatal: %s", err.Error())) - } - return nil, err + return ret(err) } - messages = append(messages, newReceivedMessage(ctx, amqpMessage)) + // Draining data from the receiver's prefetched queue. This won't wait for new messages to + // arrive, so it'll only receive messages that arrived prior to the drain. + for { + am, err := linksWithID.Receiver.Prefetched(context.Background()) - if len(messages) == maxMessages { - return messages, nil - } + if am == nil || internal.IsCancelError(err) { + return all, nil + } - if len(messages) == 1 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, maxWaitTimeAfterFirstMessage) - defer cancel() + if err != nil { + return ret(err) + } + + all = append(all, newReceivedMessage(am)) } } + + return all, nil } type entity struct { @@ -485,15 +496,15 @@ func (e *entity) SetSubQueue(subQueue SubQueue) error { return fmt.Errorf("unknown SubQueue %d", subQueue) } -func createReceiverLink(ctx context.Context, session internal.AMQPSession, linkOptions []amqp.LinkOption) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) { +func createReceiverLink(ctx context.Context, session internal.AMQPSession, linkOptions []amqp.LinkOption) (internal.AMQPReceiverCloser, error) { amqpReceiver, err := session.NewReceiver(linkOptions...) 
if err != nil { tab.For(ctx).Error(err) - return nil, nil, err + return nil, err } - return nil, amqpReceiver, nil + return amqpReceiver, nil } func createLinkOptions(mode ReceiveMode, entityPath string) []amqp.LinkOption { @@ -516,3 +527,11 @@ func createLinkOptions(mode ReceiveMode, entityPath string) []amqp.LinkOption { return opts } + +func checkReceiverMode(receiveMode ReceiveMode) error { + if receiveMode == ReceiveModePeekLock || receiveMode == ReceiveModeReceiveAndDelete { + return nil + } + + return fmt.Errorf("invalid receive mode %d, must be either azservicebus.PeekLock or azservicebus.ReceiveAndDelete", receiveMode) +} diff --git a/sdk/messaging/azservicebus/receiver_test.go b/sdk/messaging/azservicebus/receiver_test.go index c0363442f460..6cbd067462a2 100644 --- a/sdk/messaging/azservicebus/receiver_test.go +++ b/sdk/messaging/azservicebus/receiver_test.go @@ -11,7 +11,10 @@ import ( "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/admin" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/test" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" "github.com/Azure/go-amqp" "github.com/stretchr/testify/require" ) @@ -343,6 +346,62 @@ func TestReceiverPeek(t *testing.T) { require.Empty(t, noMessagesExpected) } +func TestReceiverDetach(t *testing.T) { + // NOTE: uncomment this to see some of the background reconnects + // azlog.SetListener(func(e azlog.Event, s string) { + // log.Printf("%s %s", e, s) + // }) + + serviceBusClient, cleanup, queueName := setupLiveTest(t, nil) + defer cleanup() + + adminClient, err := admin.NewClientFromConnectionString(test.GetConnectionString(t), nil) + require.NoError(t, err) + + receiver, err := serviceBusClient.NewReceiverForQueue(queueName, nil) + require.NoError(t, err) + + // make sure the receiver link and connection are live. 
+ _, err = receiver.PeekMessages(context.Background(), 1, nil) + require.NoError(t, err) + + sender, err := serviceBusClient.NewSender(queueName, nil) + require.NoError(t, err) + + err = sender.SendMessage(context.Background(), &Message{ + Body: []byte("hello world"), + }) + require.NoError(t, err) + require.NoError(t, sender.Close(context.Background())) + + // force a detach to happen + _, err = adminClient.UpdateQueue(context.Background(), queueName, admin.QueueProperties{}, nil) + require.NoError(t, err) + + messages, err := receiver.ReceiveMessages(context.Background(), 1, nil) + require.NoError(t, err) + require.EqualValues(t, []string{"hello world"}, getSortedBodies(messages)) + + // force a detach to happen + _, err = adminClient.UpdateQueue(context.Background(), queueName, admin.QueueProperties{}, nil) + require.NoError(t, err) + + peekedMessages, err := receiver.PeekMessages(context.Background(), 1, nil) + require.NoError(t, err) + require.EqualValues(t, []string{"hello world"}, getSortedBodies(peekedMessages)) + + // force a detach to happen + _, err = adminClient.UpdateQueue(context.Background(), queueName, admin.QueueProperties{}, nil) + require.NoError(t, err) + + require.NoError(t, receiver.CompleteMessage(context.Background(), messages[0])) + + // and last, check that the queue is properly empty + peekedMessages, err = receiver.PeekMessages(context.Background(), 1, nil) + require.NoError(t, err) + require.Empty(t, peekedMessages) +} + func TestReceiver_RenewMessageLock(t *testing.T) { client, cleanup, queueName := setupLiveTest(t, nil) defer cleanup() @@ -401,19 +460,23 @@ func TestReceiverOptions(t *testing.T) { require.NoError(t, applyReceiverOptions(receiver, e, &ReceiverOptions{ ReceiveMode: ReceiveModeReceiveAndDelete, SubQueue: SubQueueTransfer, + retryOptions: utils.RetryOptions{ + MaxRetries: 101, + }, })) require.EqualValues(t, ReceiveModeReceiveAndDelete, receiver.receiveMode) path, err = e.String() require.NoError(t, err) 
require.EqualValues(t, "topic/Subscriptions/subscription/$Transfer/$DeadLetterQueue", path) + require.EqualValues(t, 101, receiver.retryOptions.MaxRetries) } -type badMgmtClient struct { - internal.MgmtClient +type badRPCLink struct { + internal.RPCLink } -func (b badMgmtClient) ReceiveDeferred(ctx context.Context, mode ReceiveMode, sequenceNumbers []int64) ([]*amqp.Message, error) { +func (br *badRPCLink) RPC(ctx context.Context, msg *amqp.Message) (*internal.RPCResponse, error) { return nil, errors.New("receive deferred messages failed") } @@ -430,7 +493,7 @@ func TestReceiverDeferUnitTests(t *testing.T) { r = &Receiver{ amqpLinks: &internal.FakeAMQPLinks{ - Mgmt: &badMgmtClient{}, + RPC: &badRPCLink{}, }, } diff --git a/sdk/messaging/azservicebus/sender.go b/sdk/messaging/azservicebus/sender.go index b8613e4c906c..e6457fc685d1 100644 --- a/sdk/messaging/azservicebus/sender.go +++ b/sdk/messaging/azservicebus/sender.go @@ -9,6 +9,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" "github.com/Azure/go-amqp" "github.com/devigned/tab" ) @@ -19,7 +20,7 @@ type ( queueOrTopic string cleanupOnClose func() links internal.AMQPLinks - retryOptions internal.RetryOptions + retryOptions utils.RetryOptions } ) @@ -40,25 +41,10 @@ type MessageBatchOptions struct { // messages. Sending a batch of messages is more efficient than sending the // messages one at a time. 
func (s *Sender) NewMessageBatch(ctx context.Context, options *MessageBatchOptions) (*MessageBatch, error) { - var lastRevision uint64 var batch *MessageBatch - err := internal.Retry(ctx, "send", func(ctx context.Context, args *internal.RetryFnArgs) error { - if args.LastErr != nil { - if err := s.links.RecoverIfNeeded(ctx, lastRevision, args.LastErr); err != nil { - return err - } - } - - sender, _, _, lr, err := s.links.Get(ctx) - - if err != nil { - return err - } - - lastRevision = lr - - maxBytes := sender.MaxMessageSize() + err := s.links.Retry(ctx, "send", func(ctx context.Context, lwid *internal.LinksWithID, args *utils.RetryFnArgs) error { + maxBytes := lwid.Sender.MaxMessageSize() if options != nil && options.MaxBytes != 0 { maxBytes = options.MaxBytes @@ -66,7 +52,7 @@ func (s *Sender) NewMessageBatch(ctx context.Context, options *MessageBatchOptio batch = newMessageBatch(maxBytes) return nil - }, nil, s.retryOptions) + }, s.retryOptions) if err != nil { return nil, err @@ -77,55 +63,23 @@ func (s *Sender) NewMessageBatch(ctx context.Context, options *MessageBatchOptio // SendMessage sends a Message to a queue or topic. 
func (s *Sender) SendMessage(ctx context.Context, message *Message) error { - ctx, span := s.startProducerSpanFromContext(ctx, spanNameSendMessage) - defer span.End() - - var lastRevision uint64 - - return internal.Retry(ctx, "send", func(ctx context.Context, args *internal.RetryFnArgs) error { - if args.LastErr != nil { - if err := s.links.RecoverIfNeeded(ctx, lastRevision, args.LastErr); err != nil { - return err - } - } - - sender, _, _, lr, err := s.links.Get(ctx) - - if err != nil { - return err - } - - lastRevision = lr + return s.links.Retry(ctx, "SendMessage", func(ctx context.Context, lwid *internal.LinksWithID, args *utils.RetryFnArgs) error { + ctx, span := s.startProducerSpanFromContext(ctx, spanNameSendMessage) + defer span.End() - return sender.Send(ctx, message.toAMQPMessage()) - }, nil, s.retryOptions) + return lwid.Sender.Send(ctx, message.toAMQPMessage()) + }, utils.RetryOptions(s.retryOptions)) } // SendMessageBatch sends a MessageBatch to a queue or topic. // Message batches can be created using `Sender.NewMessageBatch`. 
func (s *Sender) SendMessageBatch(ctx context.Context, batch *MessageBatch) error { - ctx, span := s.startProducerSpanFromContext(ctx, spanNameSendBatch) - defer span.End() - - var lastRevision uint64 - - return internal.Retry(ctx, "send", func(ctx context.Context, args *internal.RetryFnArgs) error { - if args.LastErr != nil { - if err := s.links.RecoverIfNeeded(ctx, lastRevision, args.LastErr); err != nil { - return err - } - } + return s.links.Retry(ctx, "SendMessageBatch", func(ctx context.Context, lwid *internal.LinksWithID, args *utils.RetryFnArgs) error { + ctx, span := s.startProducerSpanFromContext(ctx, spanNameSendBatch) + defer span.End() - sender, _, _, lr, err := s.links.Get(ctx) - - if err != nil { - return err - } - - lastRevision = lr - - return sender.Send(ctx, batch.toAMQPMessage()) - }, nil, s.retryOptions) + return lwid.Sender.Send(ctx, batch.toAMQPMessage()) + }, utils.RetryOptions(s.retryOptions)) } // ScheduleMessages schedules a slice of Messages to appear on Service Bus Queue/Subscription at a later time. @@ -144,14 +98,10 @@ func (s *Sender) ScheduleMessages(ctx context.Context, messages []*Message, sche // MessageBatch changes // CancelScheduledMessages cancels multiple messages that were scheduled. -func (s *Sender) CancelScheduledMessages(ctx context.Context, sequenceNumber []int64) error { - _, _, mgmt, _, err := s.links.Get(ctx) - - if err != nil { - return err - } - - return mgmt.CancelScheduled(ctx, sequenceNumber...) +func (s *Sender) CancelScheduledMessages(ctx context.Context, sequenceNumbers []int64) error { + return s.links.Retry(ctx, "cancelScheduledMessage", func(ctx context.Context, lwv *internal.LinksWithID, args *utils.RetryFnArgs) error { + return internal.CancelScheduledMessages(ctx, lwv.RPC, sequenceNumbers) + }, s.retryOptions) } // Close permanently closes the Sender. 
@@ -161,13 +111,19 @@ func (s *Sender) Close(ctx context.Context) error { } func (s *Sender) scheduleAMQPMessages(ctx context.Context, messages []*amqp.Message, scheduledEnqueueTime time.Time) ([]int64, error) { - _, _, mgmt, _, err := s.links.Get(ctx) + var sequenceNumbers []int64 - if err != nil { - return nil, err - } + err := s.links.Retry(ctx, "scheduleMessages", func(ctx context.Context, lwv *internal.LinksWithID, args *utils.RetryFnArgs) error { + sn, err := internal.ScheduleMessages(ctx, lwv.RPC, scheduledEnqueueTime, messages) + + if err != nil { + return err + } + sequenceNumbers = sn + return nil + }, s.retryOptions) - return mgmt.ScheduleMessages(ctx, scheduledEnqueueTime, messages...) + return sequenceNumbers, err } func (sender *Sender) createSenderLink(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) { @@ -184,14 +140,20 @@ func (sender *Sender) createSenderLink(ctx context.Context, session internal.AMQ return amqpSender, nil, nil } -func newSender(ns internal.NamespaceWithNewAMQPLinks, queueOrTopic string, cleanupOnClose func()) (*Sender, error) { +type newSenderArgs struct { + ns internal.NamespaceWithNewAMQPLinks + queueOrTopic string + cleanupOnClose func() +} + +func newSender(args newSenderArgs, retryOptions RetryOptions) (*Sender, error) { sender := &Sender{ - queueOrTopic: queueOrTopic, - cleanupOnClose: cleanupOnClose, - retryOptions: internal.RetryOptions{}, + queueOrTopic: args.queueOrTopic, + cleanupOnClose: args.cleanupOnClose, + retryOptions: utils.RetryOptions(retryOptions), } - sender.links = ns.NewAMQPLinks(queueOrTopic, sender.createSenderLink) + sender.links = args.ns.NewAMQPLinks(args.queueOrTopic, sender.createSenderLink) return sender, nil } diff --git a/sdk/messaging/azservicebus/sender_test.go b/sdk/messaging/azservicebus/sender_test.go index 585ef801ee23..801cb52d3a90 100644 --- a/sdk/messaging/azservicebus/sender_test.go +++ 
b/sdk/messaging/azservicebus/sender_test.go @@ -5,12 +5,14 @@ package azservicebus import ( "context" + "fmt" "sort" "testing" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/admin" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/test" "github.com/stretchr/testify/require" ) @@ -276,6 +278,118 @@ func Test_Sender_ScheduleMessages(t *testing.T) { } } +func TestSender_SendMessagesDetach(t *testing.T) { + // NOTE: uncomment this to see some of the background reconnects + // azlog.SetListener(func(e azlog.Event, s string) { + // log.Printf("%s %s", e, s) + // }) + + sbc, cleanup, queueName := setupLiveTest(t, nil) + defer cleanup() + + adminClient, err := admin.NewClientFromConnectionString(test.GetConnectionString(t), nil) + require.NoError(t, err) + + sender, err := sbc.NewSender(queueName, nil) + require.NoError(t, err) + + // make sure the sender link is open and active. + err = sender.SendMessage(context.Background(), &Message{ + Body: []byte("0"), + }) + require.NoError(t, err) + + // now force a detach to happen + _, err = adminClient.UpdateQueue(context.Background(), queueName, admin.QueueProperties{}, nil) + require.NoError(t, err) + + for i := 1; i < 5; i++ { + err = sender.SendMessage(context.Background(), &Message{ + Body: []byte(fmt.Sprintf("%d", i)), + }) + require.NoError(t, err) + } + + receiver, err := sbc.NewReceiverForQueue(queueName, &ReceiverOptions{ + ReceiveMode: ReceiveModeReceiveAndDelete, + }) + require.NoError(t, err) + + // get all the messages + var all []*ReceivedMessage + + for { + messages, err := receiver.ReceiveMessages(context.Background(), 5, nil) + require.NoError(t, err) + + all = append(messages, all...) 
+ + if len(all) == 5 { + break + } + } + + require.EqualValues(t, []string{"0", "1", "2", "3", "4"}, getSortedBodies(all)) +} + +func TestSender_SendMessageBatchDetach(t *testing.T) { + // NOTE: uncomment this to see some of the background reconnects + // azlog.SetListener(func(e azlog.Event, s string) { + // log.Printf("%s %s", e, s) + // }) + + sbc, cleanup, queueName := setupLiveTest(t, nil) + defer cleanup() + + adminClient, err := admin.NewClientFromConnectionString(test.GetConnectionString(t), nil) + require.NoError(t, err) + + sender, err := sbc.NewSender(queueName, nil) + require.NoError(t, err) + + // make sure the sender link is open and active. + err = sender.SendMessage(context.Background(), &Message{ + Body: []byte("0"), + }) + require.NoError(t, err) + + // now force a detach to happen + _, err = adminClient.UpdateQueue(context.Background(), queueName, admin.QueueProperties{}, nil) + require.NoError(t, err) + + for i := 1; i < 5; i++ { + batch, err := sender.NewMessageBatch(context.Background(), nil) + require.NoError(t, err) + require.NoError(t, batch.AddMessage(&Message{ + Body: []byte(fmt.Sprintf("%d", i)), + })) + + err = sender.SendMessageBatch(context.Background(), batch) + require.NoError(t, err) + } + + receiver, err := sbc.NewReceiverForQueue(queueName, &ReceiverOptions{ + ReceiveMode: ReceiveModeReceiveAndDelete, + }) + require.NoError(t, err) + + // get all the messages + var all []*ReceivedMessage + + for { + messages, err := receiver.ReceiveMessages(context.Background(), 5, nil) + require.NoError(t, err) + + all = append(messages, all...) 
+ + if len(all) == 5 { + break + } + } + + require.EqualValues(t, []string{"0", "1", "2", "3", "4"}, getSortedBodies(all)) +} + func getSortedBodies(messages []*ReceivedMessage) []string { sort.Sort(receivedMessages(messages)) diff --git a/sdk/messaging/azservicebus/session_receiver.go b/sdk/messaging/azservicebus/session_receiver.go index 95517483734d..9a7af6c11ba2 100644 --- a/sdk/messaging/azservicebus/session_receiver.go +++ b/sdk/messaging/azservicebus/session_receiver.go @@ -9,6 +9,7 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" "github.com/Azure/go-amqp" ) @@ -47,57 +48,64 @@ func toReceiverOptions(sropts *SessionReceiverOptions) *ReceiverOptions { } } -func newSessionReceiver(ctx context.Context, sessionID *string, ns internal.NamespaceWithNewAMQPLinks, entity *entity, cleanupOnClose func(), options *ReceiverOptions) (*SessionReceiver, error) { - const sessionFilterName = "com.microsoft:session-filter" - const code = uint64(0x00000137000000C) - +func newSessionReceiver(ctx context.Context, sessionID *string, ns internal.NamespaceWithNewAMQPLinks, entity entity, cleanupOnClose func(), options *ReceiverOptions) (*SessionReceiver, error) { sessionReceiver := &SessionReceiver{ sessionID: sessionID, lockedUntil: time.Time{}, } - var err error + r, err := newReceiver(newReceiverArgs{ + ns: ns, + entity: entity, + cleanupOnClose: cleanupOnClose, + newLinkFn: sessionReceiver.newLink, + }, options) - sessionReceiver.inner, err = newReceiver(ns, entity, cleanupOnClose, options, func(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) { - linkOptions := createLinkOptions(sessionReceiver.inner.receiveMode, sessionReceiver.inner.amqpLinks.EntityPath()) + if err != nil { + return nil, err + } - if sessionID == nil { - linkOptions = append(linkOptions, 
amqp.LinkSourceFilter(sessionFilterName, code, nil)) - } else { - linkOptions = append(linkOptions, amqp.LinkSourceFilter(sessionFilterName, code, sessionID)) - } + sessionReceiver.inner = r - _, link, err := createReceiverLink(ctx, session, linkOptions) + // temp workaround until we expose the session expiration time from the receiver in go-amqp + if err := sessionReceiver.RenewSessionLock(ctx); err != nil { + _ = sessionReceiver.Close(context.Background()) + return nil, err + } - if err != nil { - return nil, nil, err - } + return sessionReceiver, nil +} - // check the session ID that came back - if we asked for a named session ID and didn't get it then - // we failed to get the lock. - // if we specified nil then we can _set_ our internally held session ID now that we know the value. - receivedSessionID := link.LinkSourceFilterValue(sessionFilterName) - asStr, ok := receivedSessionID.(string) +func (r *SessionReceiver) newLink(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) { + const sessionFilterName = "com.microsoft:session-filter" + const code = uint64(0x00000137000000C) - if !ok || (sessionID != nil && asStr != *sessionID) { - return nil, nil, fmt.Errorf("invalid type/value for returned sessionID(type:%T, value:%v)", receivedSessionID, receivedSessionID) - } + linkOptions := createLinkOptions(r.inner.receiveMode, r.inner.amqpLinks.EntityPath()) - sessionReceiver.sessionID = &asStr - return nil, link, nil - }) + if r.sessionID == nil { + linkOptions = append(linkOptions, amqp.LinkSourceFilter(sessionFilterName, code, nil)) + } else { + linkOptions = append(linkOptions, amqp.LinkSourceFilter(sessionFilterName, code, r.sessionID)) + } + + link, err := createReceiverLink(ctx, session, linkOptions) if err != nil { - return nil, err + return nil, nil, err } - // temp workaround until we expose the session expiration time from the receiver in go-amqp - if err := 
sessionReceiver.RenewSessionLock(ctx); err != nil { - _ = sessionReceiver.Close(context.Background()) - return nil, err + // check the session ID that came back - if we asked for a named session ID and didn't get it then + // we failed to get the lock. + // if we specified nil then we can _set_ our internally held session ID now that we know the value. + receivedSessionID := link.LinkSourceFilterValue(sessionFilterName) + receivedSessionIDStr, ok := receivedSessionID.(string) + + if !ok || (r.sessionID != nil && receivedSessionIDStr != *r.sessionID) { + return nil, nil, fmt.Errorf("invalid type/value for returned sessionID(type:%T, value:%v)", receivedSessionID, receivedSessionID) } - return sessionReceiver, nil + r.sessionID = &receivedSessionIDStr + return nil, link, nil } // ReceiveMessages receives a fixed number of messages, up to numMessages. @@ -172,48 +180,47 @@ func (sr *SessionReceiver) LockedUntil() time.Time { // GetSessionState retrieves state associated with the session. func (sr *SessionReceiver) GetSessionState(ctx context.Context) ([]byte, error) { - _, _, mgmt, _, err := sr.inner.amqpLinks.Get(ctx) + var sessionState []byte - if err != nil { - return nil, err - } + err := sr.inner.amqpLinks.Retry(ctx, "GetSessionState", func(ctx context.Context, lwv *internal.LinksWithID, args *utils.RetryFnArgs) error { + s, err := internal.GetSessionState(ctx, lwv.RPC, sr.SessionID()) - return mgmt.GetSessionState(ctx, sr.SessionID()) + if err != nil { + return err + } + + sessionState = s + return nil + }, sr.inner.retryOptions) + + return sessionState, err } // SetSessionState sets the state associated with the session. 
func (sr *SessionReceiver) SetSessionState(ctx context.Context, state []byte) error { - _, _, mgmt, _, err := sr.inner.amqpLinks.Get(ctx) - - if err != nil { - return err - } - - return mgmt.SetSessionState(ctx, sr.SessionID(), state) + return sr.inner.amqpLinks.Retry(ctx, "SetSessionState", func(ctx context.Context, lwv *internal.LinksWithID, args *utils.RetryFnArgs) error { + return internal.SetSessionState(ctx, lwv.RPC, sr.SessionID(), state) + }, sr.inner.retryOptions) } // RenewSessionLock renews this session's lock. The new expiration time is available // using `LockedUntil`. func (sr *SessionReceiver) RenewSessionLock(ctx context.Context) error { - _, _, mgmt, _, err := sr.inner.amqpLinks.Get(ctx) - - if err != nil { - return err - } + return sr.inner.amqpLinks.Retry(ctx, "RenewSessionLock", func(ctx context.Context, lwv *internal.LinksWithID, args *utils.RetryFnArgs) error { + newLockedUntil, err := internal.RenewSessionLock(ctx, lwv.RPC, *sr.sessionID) - newLockedUntil, err := mgmt.RenewSessionLock(ctx, *sr.sessionID) - - if err != nil { - return err - } + if err != nil { + return err + } - sr.lockedUntil = newLockedUntil - return nil + sr.lockedUntil = newLockedUntil + return nil + }, sr.inner.retryOptions) } // init ensures the link was created, guaranteeing that we get our expected session lock.
func (sr *SessionReceiver) init(ctx context.Context) error { // initialize the links - _, _, _, _, err := sr.inner.amqpLinks.Get(ctx) + _, err := sr.inner.amqpLinks.Get(ctx) return err } diff --git a/sdk/messaging/azservicebus/session_receiver_test.go b/sdk/messaging/azservicebus/session_receiver_test.go index ffb0b3cb9592..e9baa824ae44 100644 --- a/sdk/messaging/azservicebus/session_receiver_test.go +++ b/sdk/messaging/azservicebus/session_receiver_test.go @@ -12,6 +12,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/admin" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/test" "github.com/Azure/go-amqp" "github.com/stretchr/testify/require" ) @@ -132,7 +133,9 @@ func TestSessionReceiver_acceptSessionButAlreadyLocked(t *testing.T) { // You can address a session by name which makes lock contention possible (unlike // messages where the lock token is not a predefined value) receiver, err = client.AcceptSessionForQueue(ctx, queueName, "session-1", nil) - require.EqualValues(t, internal.RecoveryKindFatal, internal.ToSBE(context.Background(), err).RecoveryKind) + + sbe := internal.GetSBErrInfo(err) + require.EqualValues(t, internal.RecoveryKindFatal, sbe.RecoveryKind) require.Nil(t, receiver) } @@ -254,19 +257,59 @@ func TestSessionReceiver_RenewSessionLock(t *testing.T) { require.NoError(t, err) require.NotNil(t, messages) - // surprisingly this works. Not sure what it accomplishes though. C# has a manual check for it. 
- // err = sessionReceiver.RenewMessageLock(context.Background(), messages[0]) - // require.NoError(t, err) - orig := sessionReceiver.LockedUntil() require.NoError(t, sessionReceiver.RenewSessionLock(context.Background())) require.Greater(t, sessionReceiver.LockedUntil().UnixNano(), orig.UnixNano()) +} + +func TestSessionReceiver_Detach(t *testing.T) { + serviceBusClient, cleanup, queueName := setupLiveTest(t, &admin.QueueProperties{ + RequiresSession: to.BoolPtr(true), + }) + defer cleanup() + + adminClient, err := admin.NewClientFromConnectionString(test.GetConnectionString(t), nil) + require.NoError(t, err) + + receiver, err := serviceBusClient.AcceptSessionForQueue(context.Background(), queueName, "test-session", nil) + require.NoError(t, err) + + sender, err := serviceBusClient.NewSender(queueName, nil) + require.NoError(t, err) + + err = sender.SendMessage(context.Background(), &Message{ + Body: []byte("hello"), + SessionID: to.StringPtr("test-session"), + }) + require.NoError(t, err) + require.NoError(t, sender.Close(context.Background())) + + state, err := receiver.GetSessionState(context.Background()) + require.NoError(t, err) + require.Nil(t, state) + + // force a detach to happen + _, err = adminClient.UpdateQueue(context.Background(), queueName, admin.QueueProperties{ + RequiresSession: to.BoolPtr(true), + }, nil) + require.NoError(t, err) + + state, err = receiver.GetSessionState(context.Background()) + require.NoError(t, err) + require.Nil(t, state) - // bogus renewal - sessionReceiver.sessionID = to.StringPtr("bogus") + // force a detach to happen + _, err = adminClient.UpdateQueue(context.Background(), queueName, admin.QueueProperties{ + RequiresSession: to.BoolPtr(true), + }, nil) + require.NoError(t, err) + + messages, err := receiver.ReceiveMessages(context.Background(), 1, nil) + require.NoError(t, err) + require.NotEmpty(t, messages) - err = sessionReceiver.RenewSessionLock(context.Background()) - require.Contains(t, err.Error(), "status code 
410 and description: The session lock has expired on the MessageSession") + require.NoError(t, receiver.CompleteMessage(context.Background(), messages[0])) + require.NoError(t, receiver.Close(context.Background())) } func Test_toReceiverOptions(t *testing.T) {