diff --git a/sdk/messaging/azservicebus/admin/admin_client.go b/sdk/messaging/azservicebus/admin/admin_client.go
index 38f1c51d3a7b..0cfe8ddbc270 100644
--- a/sdk/messaging/azservicebus/admin/admin_client.go
+++ b/sdk/messaging/azservicebus/admin/admin_client.go
@@ -12,6 +12,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/atom"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
)
// Client allows you to administer resources in a Service Bus Namespace.
@@ -21,8 +22,13 @@ type Client struct {
em atom.EntityManager
}
+// RetryOptions represent the options for retries.
+type RetryOptions = utils.RetryOptions
+
+// ClientOptions allows you to set optional configuration for `Client`.
type ClientOptions struct {
- // for future expansion
+ // RetryOptions controls how often operations are retried from this client.
+ RetryOptions *RetryOptions
}
// NewClientFromConnectionString creates a Client authenticating using a connection string.
@@ -38,7 +44,13 @@ func NewClientFromConnectionString(connectionString string, options *ClientOptio
// NewClient creates a Client authenticating using a TokenCredential.
func NewClient(fullyQualifiedNamespace string, tokenCredential azcore.TokenCredential, options *ClientOptions) (*Client, error) {
- em, err := atom.NewEntityManager(fullyQualifiedNamespace, tokenCredential, internal.Version)
+ var retryOptions utils.RetryOptions
+
+ if options != nil && options.RetryOptions != nil {
+ retryOptions = *options.RetryOptions
+ }
+
+ em, err := atom.NewEntityManager(fullyQualifiedNamespace, tokenCredential, internal.Version, retryOptions)
if err != nil {
return nil, err
diff --git a/sdk/messaging/azservicebus/admin/admin_client_queue.go b/sdk/messaging/azservicebus/admin/admin_client_queue.go
index 72f48b13d4a6..c79d4ef49c86 100644
--- a/sdk/messaging/azservicebus/admin/admin_client_queue.go
+++ b/sdk/messaging/azservicebus/admin/admin_client_queue.go
@@ -196,7 +196,7 @@ func (ac *Client) GetQueue(ctx context.Context, queueName string, options *GetQu
props, err := newQueueProperties(&atomResp.Content.QueueDescription)
if err != nil {
- return nil, atom.NewResponseError(err, resp)
+ return nil, err
}
return &GetQueueResponse{
@@ -234,7 +234,7 @@ func (ac *Client) GetQueueRuntimeProperties(ctx context.Context, queueName strin
props, err := newQueueRuntimeProperties(&atomResp.Content.QueueDescription)
if err != nil {
- return nil, atom.NewResponseError(err, resp)
+ return nil, err
}
return &GetQueueRuntimePropertiesResponse{
@@ -349,7 +349,7 @@ func (ac *Client) createOrUpdateQueueImpl(ctx context.Context, queueName string,
newProps, err := newQueueProperties(&atomResp.Content.QueueDescription)
if err != nil {
- return nil, nil, atom.NewResponseError(err, resp)
+ return nil, nil, err
}
return newProps, resp, nil
@@ -395,7 +395,7 @@ func (p *QueuePager) getNextPage(ctx context.Context) (*ListQueuesResponse, erro
props, err := newQueueProperties(&env.Content.QueueDescription)
if err != nil {
- return nil, atom.NewResponseError(err, resp)
+ return nil, err
}
all = append(all, &QueueItem{
diff --git a/sdk/messaging/azservicebus/admin/admin_client_subscription.go b/sdk/messaging/azservicebus/admin/admin_client_subscription.go
index 6cdd394eb747..d9fc400ec898 100644
--- a/sdk/messaging/azservicebus/admin/admin_client_subscription.go
+++ b/sdk/messaging/azservicebus/admin/admin_client_subscription.go
@@ -149,7 +149,7 @@ func (ac *Client) GetSubscription(ctx context.Context, topicName string, subscri
props, err := newSubscriptionProperties(&atomResp.Content.SubscriptionDescription)
if err != nil {
- return nil, atom.NewResponseError(err, resp)
+ return nil, err
}
return &GetSubscriptionResponse{
@@ -186,7 +186,7 @@ func (ac *Client) GetSubscriptionRuntimeProperties(ctx context.Context, topicNam
props, err := newSubscriptionRuntimeProperties(&atomResp.Content.SubscriptionDescription)
if err != nil {
- return nil, atom.NewResponseError(err, resp)
+ return nil, err
}
return &GetSubscriptionRuntimePropertiesResponse{
@@ -257,7 +257,7 @@ func (p *SubscriptionPager) getNext(ctx context.Context) (*ListSubscriptionsResp
props, err := newSubscriptionProperties(&env.Content.SubscriptionDescription)
if err != nil {
- return nil, atom.NewResponseError(err, resp)
+ return nil, err
}
all = append(all, &SubscriptionPropertiesItem{
@@ -346,7 +346,7 @@ func (p *SubscriptionRuntimePropertiesPager) getNextPage(ctx context.Context) (*
props, err := newSubscriptionRuntimeProperties(&entry.Content.SubscriptionDescription)
if err != nil {
- return nil, atom.NewResponseError(err, resp)
+ return nil, err
}
all = append(all, &SubscriptionRuntimePropertiesItem{
@@ -452,7 +452,7 @@ func (ac *Client) createOrUpdateSubscriptionImpl(ctx context.Context, topicName
newProps, err := newSubscriptionProperties(&atomResp.Content.SubscriptionDescription)
if err != nil {
- return nil, nil, atom.NewResponseError(err, resp)
+ return nil, nil, err
}
return newProps, resp, nil
diff --git a/sdk/messaging/azservicebus/admin/admin_client_test.go b/sdk/messaging/azservicebus/admin/admin_client_test.go
index 8ff788c909fe..ab5c74954d33 100644
--- a/sdk/messaging/azservicebus/admin/admin_client_test.go
+++ b/sdk/messaging/azservicebus/admin/admin_client_test.go
@@ -13,6 +13,7 @@ import (
"testing"
"time"
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/atom"
@@ -204,7 +205,12 @@ func TestAdminClient_UpdateQueue(t *testing.T) {
updatedProps, err = adminClient.UpdateQueue(context.Background(), "non-existent-queue", createdProps.QueueProperties, nil)
// a little awkward, we'll make these programatically inspectable as we add in better error handling.
- require.Contains(t, err.Error(), "error code: 404")
+ require.Contains(t, err.Error(), "404 Not Found")
+
+ var asResponseErr *azcore.ResponseError
+ require.ErrorAs(t, err, &asResponseErr)
+ require.EqualValues(t, 404, asResponseErr.StatusCode)
+
require.Nil(t, updatedProps)
}
@@ -475,7 +481,12 @@ func TestAdminClient_UpdateTopic(t *testing.T) {
updateResp, err = adminClient.UpdateTopic(context.Background(), "non-existent-topic", addResp.TopicProperties, nil)
// a little awkward, we'll make these programatically inspectable as we add in better error handling.
- require.Contains(t, err.Error(), "error code: 404")
+ require.Contains(t, err.Error(), "404 Not Found")
+
+ var asResponseErr *azcore.ResponseError
+ require.ErrorAs(t, err, &asResponseErr)
+ require.EqualValues(t, 404, asResponseErr.StatusCode)
+
require.Nil(t, updateResp)
}
@@ -738,8 +749,12 @@ func TestAdminClient_UpdateSubscription(t *testing.T) {
require.Nil(t, updateResp)
updateResp, err = adminClient.UpdateSubscription(context.Background(), topicName, "non-existent-subscription", addResp.CreateSubscriptionResult.SubscriptionProperties, nil)
- // a little awkward, we'll make these programatically inspectable as we add in better error handling.
- require.Contains(t, err.Error(), "error code: 404")
+ require.Contains(t, err.Error(), "404 Not Found")
+
+ var asResponseErr *azcore.ResponseError
+ require.ErrorAs(t, err, &asResponseErr)
+ require.EqualValues(t, 404, asResponseErr.StatusCode)
+
require.Nil(t, updateResp)
}
@@ -754,21 +769,33 @@ func TestAdminClient_LackPermissions_Queue(t *testing.T) {
require.True(t, notFound)
require.NotNil(t, resp)
+ var re *azcore.ResponseError
+
_, err = testData.Client.GetQueue(ctx, testData.QueueName, nil)
- require.Contains(t, err.Error(), "error code: 401, Details: Manage,EntityRead claims")
+ require.Contains(t, err.Error(), "Manage,EntityRead claims required for this operation")
+ require.ErrorAs(t, err, &re)
+ require.EqualValues(t, 401, re.StatusCode)
pager := testData.Client.ListQueues(nil)
require.False(t, pager.NextPage(context.Background()))
- require.Contains(t, pager.Err().Error(), "error code: 401, Details: Manage,EntityRead claims required for this operation")
+ require.Contains(t, pager.Err().Error(), "Manage,EntityRead claims required for this operation")
+ require.ErrorAs(t, err, &re)
+ require.EqualValues(t, 401, re.StatusCode)
_, err = testData.Client.CreateQueue(ctx, "canneverbecreated", nil, nil)
- require.Contains(t, err.Error(), "error code: 401, Details: Authorization failed for specified action: Manage,EntityWrite")
+ require.Contains(t, err.Error(), "Authorization failed for specified action: Manage,EntityWrite")
+ require.ErrorAs(t, err, &re)
+ require.EqualValues(t, 401, re.StatusCode)
_, err = testData.Client.UpdateQueue(ctx, "canneverbecreated", QueueProperties{}, nil)
- require.Contains(t, err.Error(), "error code: 401, Details: Authorization failed for specified action: Manage,EntityWrite")
+ require.Contains(t, err.Error(), "Authorization failed for specified action: Manage,EntityWrite")
+ require.ErrorAs(t, err, &re)
+ require.EqualValues(t, 401, re.StatusCode)
_, err = testData.Client.DeleteQueue(ctx, testData.QueueName, nil)
- require.Contains(t, err.Error(), "error code: 401, Details: Authorization failed for specified action: Manage,EntityDelete.")
+ require.Contains(t, err.Error(), "Authorization failed for specified action: Manage,EntityDelete.")
+ require.ErrorAs(t, err, &re)
+ require.EqualValues(t, 401, re.StatusCode)
}
func TestAdminClient_LackPermissions_Topic(t *testing.T) {
@@ -782,21 +809,33 @@ func TestAdminClient_LackPermissions_Topic(t *testing.T) {
require.True(t, notFound)
require.NotNil(t, resp)
+ var asResponseErr *azcore.ResponseError
+
_, err = testData.Client.GetTopic(ctx, testData.TopicName, nil)
- require.Contains(t, err.Error(), "error code: 401, Details: Manage,EntityRead claims")
+ require.Contains(t, err.Error(), ">Manage,EntityRead claims required for this operation")
+ require.ErrorAs(t, err, &asResponseErr)
+ require.EqualValues(t, 401, asResponseErr.StatusCode)
pager := testData.Client.ListTopics(nil)
require.False(t, pager.NextPage(context.Background()))
- require.Contains(t, pager.Err().Error(), "error code: 401, Details: Manage,EntityRead claims required for this operation")
+ require.Contains(t, pager.Err().Error(), ">Manage,EntityRead claims required for this operation")
+ require.ErrorAs(t, err, &asResponseErr)
+ require.EqualValues(t, 401, asResponseErr.StatusCode)
_, err = testData.Client.CreateTopic(ctx, "canneverbecreated", nil, nil)
- require.Contains(t, err.Error(), "error code: 401, Details: Authorization failed for specified action")
+ require.Contains(t, err.Error(), "Authorization failed for specified action")
+ require.ErrorAs(t, err, &asResponseErr)
+ require.EqualValues(t, 401, asResponseErr.StatusCode)
_, err = testData.Client.UpdateTopic(ctx, "canneverbecreated", TopicProperties{}, nil)
- require.Contains(t, err.Error(), "error code: 401, Details: Authorization failed for specified action")
+ require.Contains(t, err.Error(), "Authorization failed for specified action")
+ require.ErrorAs(t, err, &asResponseErr)
+ require.EqualValues(t, 401, asResponseErr.StatusCode)
_, err = testData.Client.DeleteTopic(ctx, testData.TopicName, nil)
- require.Contains(t, err.Error(), "error code: 401, Details: Authorization failed for specified action: Manage,EntityDelete.")
+ require.Contains(t, err.Error(), "Authorization failed for specified action: Manage,EntityDelete.")
+ require.ErrorAs(t, err, &asResponseErr)
+ require.EqualValues(t, 401, asResponseErr.StatusCode)
}
func TestAdminClient_LackPermissions_Subscription(t *testing.T) {
diff --git a/sdk/messaging/azservicebus/admin/admin_client_topic.go b/sdk/messaging/azservicebus/admin/admin_client_topic.go
index ae11bb9358be..12cca7a25f19 100644
--- a/sdk/messaging/azservicebus/admin/admin_client_topic.go
+++ b/sdk/messaging/azservicebus/admin/admin_client_topic.go
@@ -170,7 +170,7 @@ func (ac *Client) GetTopicRuntimeProperties(ctx context.Context, topicName strin
props, err := newTopicRuntimeProperties(&atomResp.Content.TopicDescription)
if err != nil {
- return nil, atom.NewResponseError(err, resp)
+ return nil, err
}
return &GetTopicRuntimePropertiesResponse{
@@ -239,7 +239,7 @@ func (p *TopicsPager) getNextPage(ctx context.Context) (*ListTopicsResponse, err
props, err := newTopicProperties(&env.Content.TopicDescription)
if err != nil {
- return nil, atom.NewResponseError(err, resp)
+ return nil, err
}
all = append(all, &TopicItem{
@@ -326,7 +326,7 @@ func (p *TopicRuntimePropertiesPager) getNextPage(ctx context.Context) (*ListTop
props, err := newTopicRuntimeProperties(&entry.Content.TopicDescription)
if err != nil {
- return nil, atom.NewResponseError(err, resp)
+ return nil, err
}
all = append(all, &TopicRuntimePropertiesItem{
@@ -436,7 +436,7 @@ func (ac *Client) createOrUpdateTopicImpl(ctx context.Context, topicName string,
topicProps, err := newTopicProperties(&atomResp.Content.TopicDescription)
if err != nil {
- return nil, nil, atom.NewResponseError(err, resp)
+ return nil, nil, err
}
return topicProps, resp, nil
diff --git a/sdk/messaging/azservicebus/client.go b/sdk/messaging/azservicebus/client.go
index 5451c95688ec..f7f6a667c027 100644
--- a/sdk/messaging/azservicebus/client.go
+++ b/sdk/messaging/azservicebus/client.go
@@ -14,23 +14,30 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
"github.com/devigned/tab"
)
// Client provides methods to create Sender and Receiver
// instances to send and receive messages from Service Bus.
type Client struct {
+ // NOTE: values need to be 64-bit aligned. Simplest way to make sure this happens
+ // is just to make it the first value in the struct
+ // See:
+ // Godoc: https://pkg.go.dev/sync/atomic#pkg-note-BUG
+ // PR: https://github.com/Azure/azure-sdk-for-go/pull/16847
linkCounter uint64
- links map[uint64]internal.Closeable
- creds clientCreds
- namespace interface {
+
+ linksMu *sync.Mutex
+ links map[uint64]internal.Closeable
+ creds clientCreds
+ namespace interface {
// used internally by `Client`
internal.NamespaceWithNewAMQPLinks
// for child clients
internal.NamespaceForAMQPLinks
- internal.NamespaceForMgmtClient
}
- linksMu *sync.Mutex
+ retryOptions RetryOptions
}
// ClientOptions contains options for the `NewClient` and `NewClientFromConnectionString`
@@ -45,8 +52,16 @@ type ClientOptions struct {
// NewWebSocketConn is a function that can create a net.Conn for use with websockets.
// For an example, see ExampleNewClient_usingWebsockets() function in example_client_test.go.
NewWebSocketConn func(ctx context.Context, args NewWebSocketConnArgs) (net.Conn, error)
+
+ // RetryOptions controls how often operations are retried from this client and any
+ // Receivers and Senders created from this client.
+ RetryOptions RetryOptions
}
+// RetryOptions controls how often operations are retried from this client and any
+// Receivers and Senders created from this client.
+type RetryOptions = utils.RetryOptions
+
// NewWebSocketConnArgs are passed to your web socket creation function (ClientOptions.NewWebSocketConn)
type NewWebSocketConnArgs = internal.NewWebSocketConnArgs
@@ -107,7 +122,7 @@ func newClientImpl(creds clientCreds, options *ClientOptions) (*Client, error) {
if client.creds.connectionString != "" {
nsOptions = append(nsOptions, internal.NamespaceWithConnectionString(client.creds.connectionString))
} else if client.creds.credential != nil {
- option := internal.NamespacesWithTokenCredential(
+ option := internal.NamespaceWithTokenCredential(
client.creds.fullyQualifiedNamespace,
client.creds.credential)
@@ -126,6 +141,8 @@ func newClientImpl(creds clientCreds, options *ClientOptions) (*Client, error) {
if options.ApplicationID != "" {
nsOptions = append(nsOptions, internal.NamespaceWithUserAgent(options.ApplicationID))
}
+
+ nsOptions = append(nsOptions, internal.NamespaceWithRetryOptions((utils.RetryOptions)(options.RetryOptions)))
}
client.namespace, err = internal.NewNamespace(nsOptions...)
@@ -135,7 +152,11 @@ func newClientImpl(creds clientCreds, options *ClientOptions) (*Client, error) {
// NewReceiver creates a Receiver for a queue. A receiver allows you to receive messages.
func (client *Client) NewReceiverForQueue(queueName string, options *ReceiverOptions) (*Receiver, error) {
id, cleanupOnClose := client.getCleanupForCloseable()
- receiver, err := newReceiver(client.namespace, &entity{Queue: queueName}, cleanupOnClose, options, nil)
+ receiver, err := newReceiver(newReceiverArgs{
+ cleanupOnClose: cleanupOnClose,
+ ns: client.namespace,
+ entity: entity{Queue: queueName},
+ }, options)
if err != nil {
return nil, err
@@ -148,7 +169,11 @@ func (client *Client) NewReceiverForQueue(queueName string, options *ReceiverOpt
// NewReceiver creates a Receiver for a subscription. A receiver allows you to receive messages.
func (client *Client) NewReceiverForSubscription(topicName string, subscriptionName string, options *ReceiverOptions) (*Receiver, error) {
id, cleanupOnClose := client.getCleanupForCloseable()
- receiver, err := newReceiver(client.namespace, &entity{Topic: topicName, Subscription: subscriptionName}, cleanupOnClose, options, nil)
+ receiver, err := newReceiver(newReceiverArgs{
+ cleanupOnClose: cleanupOnClose,
+ ns: client.namespace,
+ entity: entity{Topic: topicName, Subscription: subscriptionName},
+ }, options)
if err != nil {
return nil, err
@@ -165,7 +190,11 @@ type NewSenderOptions struct {
// NewSender creates a Sender, which allows you to send messages or schedule messages.
func (client *Client) NewSender(queueOrTopic string, options *NewSenderOptions) (*Sender, error) {
id, cleanupOnClose := client.getCleanupForCloseable()
- sender, err := newSender(client.namespace, queueOrTopic, cleanupOnClose)
+ sender, err := newSender(newSenderArgs{
+ ns: client.namespace,
+ queueOrTopic: queueOrTopic,
+ cleanupOnClose: cleanupOnClose,
+ }, client.retryOptions)
if err != nil {
return nil, err
@@ -183,7 +212,7 @@ func (client *Client) AcceptSessionForQueue(ctx context.Context, queueName strin
ctx,
&sessionID,
client.namespace,
- &entity{Queue: queueName},
+ entity{Queue: queueName},
cleanupOnClose,
toReceiverOptions(options))
@@ -207,7 +236,7 @@ func (client *Client) AcceptSessionForSubscription(ctx context.Context, topicNam
ctx,
&sessionID,
client.namespace,
- &entity{Topic: topicName, Subscription: subscriptionName},
+ entity{Topic: topicName, Subscription: subscriptionName},
cleanupOnClose,
toReceiverOptions(options))
@@ -231,7 +260,7 @@ func (client *Client) AcceptNextSessionForQueue(ctx context.Context, queueName s
ctx,
nil,
client.namespace,
- &entity{Queue: queueName},
+ entity{Queue: queueName},
cleanupOnClose,
toReceiverOptions(options))
@@ -255,7 +284,7 @@ func (client *Client) AcceptNextSessionForSubscription(ctx context.Context, topi
ctx,
nil,
client.namespace,
- &entity{Topic: topicName, Subscription: subscriptionName},
+ entity{Topic: topicName, Subscription: subscriptionName},
cleanupOnClose,
toReceiverOptions(options))
diff --git a/sdk/messaging/azservicebus/client_test.go b/sdk/messaging/azservicebus/client_test.go
index fa187148f12d..9fc38f533e81 100644
--- a/sdk/messaging/azservicebus/client_test.go
+++ b/sdk/messaging/azservicebus/client_test.go
@@ -130,7 +130,7 @@ func TestNewClientUnitTests(t *testing.T) {
// (really all part of the same functionality)
ns := &internal.Namespace{}
- require.NoError(t, internal.NamespacesWithTokenCredential("mysb.windows.servicebus.net",
+ require.NoError(t, internal.NamespaceWithTokenCredential("mysb.windows.servicebus.net",
fakeTokenCredential)(ns))
require.EqualValues(t, ns.FQDN, "mysb.windows.servicebus.net")
@@ -180,25 +180,5 @@ func TestNewClientUnitTests(t *testing.T) {
require.NoError(t, client.Close(context.Background()))
require.Empty(t, client.links)
require.EqualValues(t, 1, ns.AMQPLinks.Closed)
-
- client, ns = setupClient()
- _, err = newProcessorForQueue(client, "hello", nil)
-
- require.NoError(t, err)
- require.EqualValues(t, 1, len(client.links))
- require.NotNil(t, client.links[1])
- require.NoError(t, client.Close(context.Background()))
- require.Empty(t, client.links)
- require.EqualValues(t, 1, ns.AMQPLinks.Closed)
-
- client, ns = setupClient()
- _, err = newProcessorForSubscription(client, "hello", "world", nil)
-
- require.NoError(t, err)
- require.EqualValues(t, 1, len(client.links))
- require.NotNil(t, client.links[1])
- require.NoError(t, client.Close(context.Background()))
- require.Empty(t, client.links)
- require.EqualValues(t, 1, ns.AMQPLinks.Closed)
})
}
diff --git a/sdk/messaging/azservicebus/errors.go b/sdk/messaging/azservicebus/errors.go
deleted file mode 100644
index f11f111b2a11..000000000000
--- a/sdk/messaging/azservicebus/errors.go
+++ /dev/null
@@ -1,16 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-package azservicebus
-
-import "fmt"
-
-// implements `internal/errorinfo/NonRetriable`
-type errClosed struct {
- link string
-}
-
-func (ec errClosed) NonRetriable() {}
-func (ec errClosed) Error() string {
- return fmt.Sprintf("%s is closed and can no longer be used", ec.link)
-}
diff --git a/sdk/messaging/azservicebus/errors_test.go b/sdk/messaging/azservicebus/errors_test.go
deleted file mode 100644
index 9132d5660442..000000000000
--- a/sdk/messaging/azservicebus/errors_test.go
+++ /dev/null
@@ -1,19 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-package azservicebus
-
-import (
- "testing"
-
- "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo"
- "github.com/stretchr/testify/require"
-)
-
-func TestErrClosed(t *testing.T) {
- var err error = errClosed{link: "hello"}
-
- _, ok := err.(errorinfo.NonRetriable)
- require.True(t, ok, "ErrClosed is a errorinfo.NonRetriable")
- require.EqualValues(t, "hello is closed and can no longer be used", err.Error())
-}
diff --git a/sdk/messaging/azservicebus/go.mod b/sdk/messaging/azservicebus/go.mod
index a3ba068fdd30..2675f82d6c51 100644
--- a/sdk/messaging/azservicebus/go.mod
+++ b/sdk/messaging/azservicebus/go.mod
@@ -9,10 +9,17 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/messaging/internal v0.0.0-20211208010914-2b10e91d237e
github.com/Azure/go-amqp v0.17.0
github.com/devigned/tab v0.1.1
+ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
+)
+
+require (
+ // used in tests only
github.com/joho/godotenv v1.3.0
- github.com/jpillora/backoff v1.0.0
+
+ // used in stress tests
github.com/microsoft/ApplicationInsights-Go v0.4.4
github.com/stretchr/testify v1.7.0
- golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
+
+ // used in examples only
nhooyr.io/websocket v1.8.6
)
diff --git a/sdk/messaging/azservicebus/go.sum b/sdk/messaging/azservicebus/go.sum
index 8f235251df4f..ddffa1b097c6 100644
--- a/sdk/messaging/azservicebus/go.sum
+++ b/sdk/messaging/azservicebus/go.sum
@@ -69,8 +69,6 @@ github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc=
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
-github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA=
-github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8=
diff --git a/sdk/messaging/azservicebus/internal/amqpInterfaces.go b/sdk/messaging/azservicebus/internal/amqpInterfaces.go
index eb17ea50420c..cfcc41301858 100644
--- a/sdk/messaging/azservicebus/internal/amqpInterfaces.go
+++ b/sdk/messaging/azservicebus/internal/amqpInterfaces.go
@@ -5,9 +5,7 @@ package internal
import (
"context"
- "time"
- "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/rpc"
"github.com/Azure/go-amqp"
)
@@ -61,7 +59,7 @@ type AMQPSenderCloser interface {
// RPCLink is implemented by *rpc.Link
type RPCLink interface {
Close(ctx context.Context) error
- RetryableRPC(ctx context.Context, times int, delay time.Duration, msg *amqp.Message) (*rpc.Response, error)
+ RPC(ctx context.Context, msg *amqp.Message) (*RPCResponse, error)
}
// Closeable is implemented by pretty much any AMQP link/client
diff --git a/sdk/messaging/azservicebus/internal/amqpLinks.go b/sdk/messaging/azservicebus/internal/amqpLinks.go
index 48a3ade96995..2f2575715156 100644
--- a/sdk/messaging/azservicebus/internal/amqpLinks.go
+++ b/sdk/messaging/azservicebus/internal/amqpLinks.go
@@ -10,19 +10,21 @@ import (
"strings"
"sync"
+ "github.com/Azure/azure-sdk-for-go/sdk/internal/log"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
"github.com/devigned/tab"
)
-type errClosedPermanently struct{}
-
-func (e errClosedPermanently) Error() string { return "Link has been closed permanently" }
-func (e errClosedPermanently) NonRetriable() {}
-
-func ShouldRecover(ctx context.Context, err error) bool {
- return shouldRecreateConnection(ctx, err) || shouldRecreateLink(err)
+type LinksWithID struct {
+ Sender AMQPSender
+ Receiver AMQPReceiver
+ RPC RPCLink
+ ID LinkID
}
+type RetryWithLinksFn func(ctx context.Context, lwid *LinksWithID, args *utils.RetryFnArgs) error
+
type AMQPLinks interface {
EntityPath() string
ManagementPath() string
@@ -31,11 +33,14 @@ type AMQPLinks interface {
// Get will initialize a session and call its link.linkCreator function.
// If this link has been closed via Close() it will return an non retriable error.
- Get(ctx context.Context) (AMQPSender, AMQPReceiver, MgmtClient, uint64, error)
+ Get(ctx context.Context) (*LinksWithID, error)
+
+ // Retry will run your callback, recovering links when necessary.
+ Retry(ctx context.Context, name string, fn RetryWithLinksFn, o utils.RetryOptions) error
// RecoverIfNeeded will check if an error requires recovery, and will recover
// the link or, possibly, the connection.
- RecoverIfNeeded(ctx context.Context, linksRevision uint64, err error) error
+ RecoverIfNeeded(ctx context.Context, linkID LinkID, err error) error
// Close will close the the link.
// If permanent is true the link will not be auto-recreated if Get/Recover
@@ -46,40 +51,39 @@ type AMQPLinks interface {
ClosedPermanently() bool
}
-// amqpLinks manages the set of AMQP links (and detritus) typically needed to work
+// AMQPLinksImpl manages the set of AMQP links (and detritus) typically needed to work
// within Service Bus:
// - An *goamqp.Sender or *goamqp.Receiver AMQP link (could also be 'both' if needed)
// - A `$management` link
// - an *goamqp.Session
//
// State management can be done through Recover (close and reopen), Close (close permanently, return failures)
-// and Get() (retrieve latest version of all amqpLinks, or create if needed).
-type amqpLinks struct {
+// and Get() (retrieve latest version of all AMQPLinksImpl, or create if needed).
+type AMQPLinksImpl struct {
+ // NOTE: values need to be 64-bit aligned. Simplest way to make sure this happens
+ // is just to make it the first value in the struct
+ // See:
+ // Godoc: https://pkg.go.dev/sync/atomic#pkg-note-BUG
+ // PR: https://github.com/Azure/azure-sdk-for-go/pull/16847
+ id LinkID
+
entityPath string
managementPath string
audience string
createLink CreateLinkFunc
- baseRetrier Retrier
mu sync.RWMutex
- // mgmt lets you interact with the $management link for your entity.
- mgmt MgmtClient
+ // RPCLink lets you interact with the $management link for your entity.
+ RPCLink RPCLink
// the AMQP session for either the 'sender' or 'receiver' link
session AMQPSessionCloser
// these are populated by your `createLinkFunc` when you construct
// the amqpLinks
- sender AMQPSenderCloser
- receiver AMQPReceiverCloser
-
- // last connection revision seen by this links instance.
- clientRevision uint64
-
- // the current 'revision' of our set of links.
- // starts at 1, increments each time you call Recover().
- revision uint64
+ Sender AMQPSenderCloser
+ Receiver AMQPReceiverCloser
// whether this links set has been closed permanently (via Close)
// Recover() does not affect this value.
@@ -97,15 +101,13 @@ type CreateLinkFunc func(ctx context.Context, session AMQPSession) (AMQPSenderCl
// NewAMQPLinks creates a session, starts the claim refresher and creates an associated
// management link for a specific entity path.
-func newAMQPLinks(ns NamespaceForAMQPLinks, entityPath string, baseRetrier Retrier, createLink CreateLinkFunc) AMQPLinks {
- l := &amqpLinks{
+func NewAMQPLinks(ns NamespaceForAMQPLinks, entityPath string, createLink CreateLinkFunc) AMQPLinks {
+ l := &AMQPLinksImpl{
entityPath: entityPath,
managementPath: fmt.Sprintf("%s/$management", entityPath),
audience: ns.GetEntityAudience(entityPath),
createLink: createLink,
- baseRetrier: baseRetrier,
closedPermanently: false,
- revision: 1,
ns: ns,
}
@@ -113,53 +115,42 @@ func newAMQPLinks(ns NamespaceForAMQPLinks, entityPath string, baseRetrier Retri
}
// ManagementPath is the management path for the associated entity.
-func (links *amqpLinks) ManagementPath() string {
+func (links *AMQPLinksImpl) ManagementPath() string {
return links.managementPath
}
// recoverLink will recycle all associated links (mgmt, receiver, sender and session)
// and recreate them using the link.linkCreator function.
-func (links *amqpLinks) recoverLink(ctx context.Context, theirLinkRevision *uint64) error {
+func (links *AMQPLinksImpl) recoverLink(ctx context.Context, theirLinkRevision LinkID) error {
ctx, span := tab.StartSpan(ctx, tracing.SpanRecoverLink)
defer span.End()
links.mu.RLock()
closedPermanently := links.closedPermanently
- ourLinkRevision := links.revision
+ ourLinkRevision := links.id
links.mu.RUnlock()
if closedPermanently {
span.AddAttributes(tab.StringAttribute("outcome", "was_closed_permanently"))
- return errClosedPermanently{}
+ return ErrNonRetriable{Message: "Link has been closed permanently"}
}
- if theirLinkRevision != nil && ourLinkRevision > *theirLinkRevision {
+ // cheap check before we do the lock
+ if ourLinkRevision.Link != theirLinkRevision.Link {
// we've already recovered past their failure.
- span.AddAttributes(
- tab.StringAttribute("outcome", "already_recovered"),
- tab.StringAttribute("lock", "readlock"),
- tab.StringAttribute("revisions", fmt.Sprintf("ours(%d), theirs(%d)", ourLinkRevision, *theirLinkRevision)),
- )
return nil
}
links.mu.Lock()
defer links.mu.Unlock()
- if theirLinkRevision != nil && ourLinkRevision > *theirLinkRevision {
+ // check once more, just in case someone else modified it before we took
+ // the lock.
+ if links.id.Link != theirLinkRevision.Link {
// we've already recovered past their failure.
- span.AddAttributes(
- tab.StringAttribute("outcome", "already_recovered"),
- tab.StringAttribute("lock", "writelock"),
- tab.StringAttribute("revisions", fmt.Sprintf("ours(%d), theirs(%d)", ourLinkRevision, *theirLinkRevision)),
- )
return nil
}
- if err := links.closeWithoutLocking(ctx, false); err != nil {
- span.Logger().Error(err)
- }
-
err := links.initWithoutLocking(ctx)
if err != nil {
@@ -167,149 +158,161 @@ func (links *amqpLinks) recoverLink(ctx context.Context, theirLinkRevision *uint
return err
}
- links.revision++
-
span.AddAttributes(
tab.StringAttribute("outcome", "recovered"),
- tab.StringAttribute("revision_new", fmt.Sprintf("%d", links.revision)),
+ tab.StringAttribute("revision_new", fmt.Sprintf("%d", links.id)),
)
return nil
}
// Recover will recover the links or the connection, depending
-// on the severity of the error. This function uses the `baseRetrier`
-// defined in the links struct.
-func (links *amqpLinks) RecoverIfNeeded(ctx context.Context, linksRevision uint64, origErr error) error {
+// on the severity of the error.
+func (links *AMQPLinksImpl) RecoverIfNeeded(ctx context.Context, theirID LinkID, origErr error) error {
ctx, span := tab.StartSpan(ctx, tracing.SpanRecover)
defer span.End()
- var err error = origErr
-
- retrier := links.baseRetrier.Copy()
-
- for retrier.Try(ctx) {
- span.AddAttributes(tab.StringAttribute("recover_attempt", fmt.Sprintf("%d", retrier.CurrentTry())))
-
- err = links.recoverImpl(ctx, retrier.CurrentTry(), linksRevision, err)
-
- if err == nil {
- return nil
- }
- }
-
- return err
-}
-
-func (links *amqpLinks) recoverImpl(ctx context.Context, try int, linksRevision uint64, origErr error) error {
- _, span := tab.StartSpan(ctx, tracing.SpanRecoverLink)
- defer span.End()
-
if origErr == nil || IsCancelError(origErr) {
return nil
}
+ log.Writef(EventConn, "Recovering link for error %s", origErr.Error())
+
select {
case <-ctx.Done():
return ctx.Err()
default:
}
- span.AddAttributes(tab.Int64Attribute("attempt", int64(try)))
-
- if shouldRecreateLink(origErr) {
- span.AddAttributes(
- tab.StringAttribute("recovery_kind", "link"),
- tab.StringAttribute("error", origErr.Error()),
- tab.StringAttribute("error_type", fmt.Sprintf("%T", origErr)))
+ sbe := GetSBErrInfo(origErr)
- if err := links.recoverLink(ctx, &linksRevision); err != nil {
- span.AddAttributes(tab.StringAttribute("recoveryFailure", err.Error()))
+ if sbe.RecoveryKind == RecoveryKindLink {
+ if err := links.recoverLink(ctx, theirID); err != nil {
+ log.Writef(EventConn, "failed to recreate link: %s", err.Error())
return err
}
+ log.Writef(EventConn, "Recovered links")
return nil
- } else if shouldRecreateConnection(ctx, origErr) {
- span.AddAttributes(
- tab.StringAttribute("recovery_kind", "connection"),
- tab.StringAttribute("error", origErr.Error()),
- tab.StringAttribute("error_type", fmt.Sprintf("%T", origErr)))
-
- if err := links.recoverConnection(ctx); err != nil {
- span.Logger().Error(fmt.Errorf("failed to recreate connection: %w", err))
- return err
- }
-
- // unconditionally recover the link if the connection died.
- if err := links.recoverLink(ctx, nil); err != nil {
- span.Logger().Error(fmt.Errorf("failed to recover links after connection restarted: %w", err))
+ } else if sbe.RecoveryKind == RecoveryKindConn {
+ if err := links.recoverConnection(ctx, theirID); err != nil {
+ log.Writef(EventConn, "failed to recreate connection: %s", err.Error())
return err
}
+ log.Writef(EventConn, "Recovered connection and links")
return nil
}
- span.AddAttributes(
- tab.StringAttribute("recovery", "none"),
- tab.StringAttribute("error", origErr.Error()),
- tab.StringAttribute("errorType", fmt.Sprintf("%T", origErr)))
-
+ log.Writef(EventConn, "Recovered, no action needed")
return nil
}
-func (links *amqpLinks) recoverConnection(ctx context.Context) error {
+func (links *AMQPLinksImpl) recoverConnection(ctx context.Context, theirID LinkID) error {
tab.For(ctx).Info("Connection is dead, recovering")
- links.mu.RLock()
- clientRevision := links.clientRevision
- links.mu.RUnlock()
+ links.mu.Lock()
+ defer links.mu.Unlock()
- err := links.ns.Recover(ctx, clientRevision)
+ created, err := links.ns.Recover(ctx, uint64(theirID.Conn))
if err != nil {
- tab.For(ctx).Error(fmt.Errorf("Recover connection failure: %w", err))
+ log.Writef(EventConn, "Recover connection failure: %s", err)
return err
}
+ // We'll recreate the link if:
+ // - `created` is true, meaning we recreated the AMQP connection (ie, all old links are invalid)
+ // - the link they received an error on is our current link, so it needs to be recreated.
+ // (if it wasn't the same then we've already recovered and created a new link,
+ // so no recovery would be needed)
+ if created || theirID.Link == links.id.Link {
+ log.Writef(EventConn, "recreating link: c: %v, current:%v, old:%v", created, links.id, theirID)
+ if err := links.initWithoutLocking(ctx); err != nil {
+ return err
+ }
+ }
+
return nil
}
+// LinkID is an ID that represents our current link and the client used to create it.
+// These are used when trying to determine what parts need to be recreated when
+// an error occurs, to prevent recovering a connection/link repeatedly.
+// See AMQPLinksImpl.RecoverIfNeeded() for usage.
+type LinkID struct {
+ // Conn is the ID of the connection we used to create our links.
+ Conn uint64
+
+ // Link is the ID of our current link.
+ Link uint64
+}
+
// Get will initialize a session and call its link.linkCreator function.
// If this link has been closed via Close() it will return an non retriable error.
-func (l *amqpLinks) Get(ctx context.Context) (AMQPSender, AMQPReceiver, MgmtClient, uint64, error) {
+func (l *AMQPLinksImpl) Get(ctx context.Context) (*LinksWithID, error) {
l.mu.RLock()
- sender, receiver, mgmt, revision, closedPermanently := l.sender, l.receiver, l.mgmt, l.revision, l.closedPermanently
+ sender, receiver, mgmtLink, revision, closedPermanently := l.Sender, l.Receiver, l.RPCLink, l.id, l.closedPermanently
l.mu.RUnlock()
if closedPermanently {
- return nil, nil, nil, 0, errClosedPermanently{}
+ return nil, ErrNonRetriable{}
}
if sender != nil || receiver != nil {
- return sender, receiver, mgmt, revision, nil
+ return &LinksWithID{
+ Sender: sender,
+ Receiver: receiver,
+ RPC: mgmtLink,
+ ID: revision,
+ }, nil
}
l.mu.Lock()
defer l.mu.Unlock()
if err := l.initWithoutLocking(ctx); err != nil {
- return nil, nil, nil, 0, err
+ return nil, err
}
- return l.sender, l.receiver, l.mgmt, l.revision, nil
+ return &LinksWithID{
+ Sender: l.Sender,
+ Receiver: l.Receiver,
+ RPC: l.RPCLink,
+ ID: l.id,
+ }, nil
+}
+
+func (l *AMQPLinksImpl) Retry(ctx context.Context, name string, fn RetryWithLinksFn, o utils.RetryOptions) error {
+ var lastID LinkID
+
+ return utils.Retry(ctx, name, func(ctx context.Context, args *utils.RetryFnArgs) error {
+ if err := l.RecoverIfNeeded(ctx, lastID, args.LastErr); err != nil {
+ return err
+ }
+
+ linksWithVersion, err := l.Get(ctx)
+
+ if err != nil {
+ return err
+ }
+
+ lastID = linksWithVersion.ID
+ return fn(ctx, linksWithVersion, args)
+ }, IsFatalSBError, o)
}
// EntityPath is the full entity path for the queue/topic/subscription.
-func (l *amqpLinks) EntityPath() string {
+func (l *AMQPLinksImpl) EntityPath() string {
return l.entityPath
}
// EntityPath is the audience for the queue/topic/subscription.
-func (l *amqpLinks) Audience() string {
+func (l *AMQPLinksImpl) Audience() string {
return l.audience
}
// ClosedPermanently is true if AMQPLinks.Close(ctx, true) has been called.
-func (l *amqpLinks) ClosedPermanently() bool {
+func (l *AMQPLinksImpl) ClosedPermanently() bool {
l.mu.RLock()
defer l.mu.RUnlock()
return l.closedPermanently
@@ -317,20 +320,23 @@ func (l *amqpLinks) ClosedPermanently() bool {
// Close will close the the link permanently.
// Any further calls to Get()/Recover() to return ErrLinksClosed.
-func (l *amqpLinks) Close(ctx context.Context, permanent bool) error {
+func (l *AMQPLinksImpl) Close(ctx context.Context, permanent bool) error {
l.mu.Lock()
defer l.mu.Unlock()
return l.closeWithoutLocking(ctx, permanent)
}
// initWithoutLocking will create a new link, unconditionally.
-func (l *amqpLinks) initWithoutLocking(ctx context.Context) error {
+func (l *AMQPLinksImpl) initWithoutLocking(ctx context.Context) error {
+ // shut down any links we have
+ _ = l.closeWithoutLocking(ctx, false)
+
var err error
l.cancelAuthRefreshLink, err = l.ns.NegotiateClaim(ctx, l.entityPath)
if err != nil {
if err := l.closeWithoutLocking(ctx, false); err != nil {
- tab.For(ctx).Debug(fmt.Sprintf("Failure during link cleanup after negotiateClaim: %s", err.Error()))
+ log.Writef(EventConn, "Failure during link cleanup after negotiateClaim: %s", err.Error())
}
return err
}
@@ -339,44 +345,49 @@ func (l *amqpLinks) initWithoutLocking(ctx context.Context) error {
if err != nil {
if err := l.closeWithoutLocking(ctx, false); err != nil {
- tab.For(ctx).Debug(fmt.Sprintf("Failure during link cleanup after negotiate claim for mgmt link: %s", err.Error()))
+ log.Writef(EventConn, "Failure during link cleanup after negotiate claim for mgmt link: %s", err.Error())
}
return err
}
- l.session, l.clientRevision, err = l.ns.NewAMQPSession(ctx)
+ sess, cr, err := l.ns.NewAMQPSession(ctx)
if err != nil {
if err := l.closeWithoutLocking(ctx, false); err != nil {
- tab.For(ctx).Debug(fmt.Sprintf("Failure during link cleanup after creating AMQP session: %s", err.Error()))
+ log.Writef(EventConn, "Failure during link cleanup after creating AMQP session: %s", err.Error())
}
return err
}
- l.sender, l.receiver, err = l.createLink(ctx, l.session)
+ l.session = sess
+ l.id.Conn = cr
+
+ l.Sender, l.Receiver, err = l.createLink(ctx, l.session)
if err != nil {
if err := l.closeWithoutLocking(ctx, false); err != nil {
- tab.For(ctx).Debug(fmt.Sprintf("Failure during link cleanup after creating link: %s", err.Error()))
+ log.Writef(EventConn, "Failure during link cleanup after creating link: %s", err.Error())
}
return err
}
- l.mgmt, err = l.ns.NewMgmtClient(ctx, l)
+ rpcLink, err := l.ns.NewRPCLink(ctx, l.ManagementPath())
if err != nil {
if err := l.closeWithoutLocking(ctx, false); err != nil {
- tab.For(ctx).Debug(fmt.Sprintf("Failure during link cleanup after creating mgmt client: %s", err.Error()))
+ log.Writef(EventConn, "Failure during link cleanup after creating mgmt client: %s", err.Error())
}
return err
}
+ l.RPCLink = rpcLink
+ l.id.Link++
return nil
}
// close closes the link.
// NOTE: No locking is done in this function, call `Close` if you require locking.
-func (l *amqpLinks) closeWithoutLocking(ctx context.Context, permanent bool) error {
+func (l *AMQPLinksImpl) closeWithoutLocking(ctx context.Context, permanent bool) error {
if l.closedPermanently {
return nil
}
@@ -397,18 +408,18 @@ func (l *amqpLinks) closeWithoutLocking(ctx context.Context, permanent bool) err
l.cancelAuthRefreshMgmtLink()
}
- if l.sender != nil {
- if err := l.sender.Close(ctx); err != nil {
+ if l.Sender != nil {
+ if err := l.Sender.Close(ctx); err != nil {
messages = append(messages, fmt.Sprintf("amqp sender close error: %s", err.Error()))
}
- l.sender = nil
+ l.Sender = nil
}
- if l.receiver != nil {
- if err := l.receiver.Close(ctx); err != nil {
+ if l.Receiver != nil {
+ if err := l.Receiver.Close(ctx); err != nil {
messages = append(messages, fmt.Sprintf("amqp receiver close error: %s", err.Error()))
}
- l.receiver = nil
+ l.Receiver = nil
}
if l.session != nil {
@@ -418,11 +429,11 @@ func (l *amqpLinks) closeWithoutLocking(ctx context.Context, permanent bool) err
l.session = nil
}
- if l.mgmt != nil {
- if err := l.mgmt.Close(ctx); err != nil {
+ if l.RPCLink != nil {
+ if err := l.RPCLink.Close(ctx); err != nil {
messages = append(messages, fmt.Sprintf("$management link close error: %s", err.Error()))
}
- l.mgmt = nil
+ l.RPCLink = nil
}
if len(messages) > 0 {
diff --git a/sdk/messaging/azservicebus/internal/amqpLinks_test.go b/sdk/messaging/azservicebus/internal/amqpLinks_test.go
index 8a2b11fbf4e4..886632c93d39 100644
--- a/sdk/messaging/azservicebus/internal/amqpLinks_test.go
+++ b/sdk/messaging/azservicebus/internal/amqpLinks_test.go
@@ -5,78 +5,20 @@ package internal
import (
"context"
- "errors"
+ "log"
+ "sync"
"testing"
+ "time"
+ azlog "github.com/Azure/azure-sdk-for-go/sdk/internal/log"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/test"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
"github.com/Azure/go-amqp"
"github.com/stretchr/testify/require"
)
-func TestAMQPLinks(t *testing.T) {
- fakeSender := &FakeAMQPSender{}
- fakeSession := &FakeAMQPSession{}
- fakeMgmtClient := &fakeMgmtClient{}
-
- createLinkFunc, createLinkCallCount := setupCreateLinkResponses(t, []createLinkResponse{
- {sender: fakeSender},
- })
-
- links := newAMQPLinks(&FakeNS{
- Session: fakeSession,
- MgmtClient: fakeMgmtClient,
- }, "entityPath", &fakeRetrier{}, createLinkFunc)
-
- require.EqualValues(t, "entityPath", links.EntityPath())
- require.EqualValues(t, "audience: entityPath", links.Audience())
-
- // successful Get() where a Sender was initialized
- sender, receiver, mgmt, linkRevision, err := links.Get(context.Background())
- require.NotNil(t, sender)
- require.NotNil(t, mgmt) // you always get a free mgmt link
- require.Nil(t, receiver)
- require.Nil(t, err)
- require.EqualValues(t, 1, linkRevision)
- require.EqualValues(t, 1, *createLinkCallCount)
-
- // further calls should just be cached instances
- sender2, receiver2, mgmt2, linkRevision2, err2 := links.Get(context.Background())
- require.EqualValues(t, sender, sender2)
- require.EqualValues(t, mgmt, mgmt2)
- require.Nil(t, receiver2)
- require.Nil(t, err2)
- require.EqualValues(t, 1, linkRevision2, "No recover calls, so link revision remains the same")
- require.EqualValues(t, 1, *createLinkCallCount, "No create call needed since an instance was cached")
-
- // closing multiple times is fine.
- asAMQPLinks, ok := links.(*amqpLinks)
- require.True(t, ok)
-
- require.NoError(t, links.Close(context.Background(), false))
- require.False(t, asAMQPLinks.closedPermanently)
-
- require.NoError(t, links.Close(context.Background(), true))
- require.True(t, asAMQPLinks.closedPermanently)
-
- require.NoError(t, links.Close(context.Background(), true))
- require.True(t, asAMQPLinks.closedPermanently)
-
- require.NoError(t, links.Close(context.Background(), false))
- require.True(t, asAMQPLinks.closedPermanently)
-
- // and the individual links are closed as well
- require.EqualValues(t, 1, fakeSender.Closed)
- require.EqualValues(t, 1, fakeSession.closed)
- require.EqualValues(t, 1, fakeMgmtClient.closed)
-
- // and calls to Get() will indicate the amqpLinks has been closed permanently
- sender, receiver, mgmt, linkRevision, err = links.Get(context.Background())
- require.Nil(t, sender)
- require.Nil(t, receiver)
- require.Nil(t, mgmt)
- require.EqualValues(t, 0, linkRevision)
-
- _, ok = err.(NonRetriable)
- require.True(t, ok)
+var retryOptionsOnlyOnce = utils.RetryOptions{
+ MaxRetries: 0,
}
type fakeNetError struct {
@@ -88,112 +30,368 @@ func (pe fakeNetError) Timeout() bool { return pe.timeout }
func (pe fakeNetError) Temporary() bool { return pe.temp }
func (pe fakeNetError) Error() string { return "Fake but very permanent error" }
-func TestAMQPLinksRecovery(t *testing.T) {
- sess := &FakeAMQPSession{}
- ns := &FakeNS{
- Session: sess,
- }
- sender := &FakeAMQPSender{}
+func assertFailedLinks(t *testing.T, lwid *LinksWithID, expectedErr error) {
+ err := lwid.Sender.Send(context.TODO(), &amqp.Message{
+ Data: [][]byte{
+ {0},
+ },
+ })
+ require.ErrorIs(t, err, expectedErr)
+
+ _, err = PeekMessages(context.TODO(), lwid.RPC, 0, 1)
+ require.ErrorIs(t, err, expectedErr)
+
+ msg, err := lwid.Receiver.Receive(context.TODO())
+ require.ErrorIs(t, err, expectedErr)
+ require.Nil(t, msg)
+
+}
+
+func assertLinks(t *testing.T, lwid *LinksWithID) {
+ err := lwid.Sender.Send(context.TODO(), &amqp.Message{
+ Data: [][]byte{
+ {0},
+ },
+ })
+ require.NoError(t, err)
+
+ _, err = PeekMessages(context.TODO(), lwid.RPC, 0, 1)
+ require.NoError(t, err)
+
+ require.NoError(t, lwid.Receiver.IssueCredit(1))
+ msg, err := lwid.Receiver.Receive(context.TODO())
+ require.NoError(t, err)
+ require.NotNil(t, msg)
+}
+
+func TestAMQPLinksBasic(t *testing.T) {
+ entityPath, cleanup := test.CreateExpiringQueue(t, nil)
+ defer cleanup()
+
+ cs := test.GetConnectionString(t)
+ ns, err := NewNamespace(NamespaceWithConnectionString(cs))
+ require.NoError(t, err)
+
+ links := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) {
+ return newLinksForAMQPLinksTest(entityPath, session)
+ })
+
+ lwr, err := links.Get(context.Background())
+ require.NoError(t, err)
+
+ assertLinks(t, lwr)
+
+ require.EqualValues(t, entityPath, links.EntityPath())
+}
+
+func TestAMQPLinksLive(t *testing.T) {
+ // we're not going to use this client for these tests.
+ entityPath, cleanup := test.CreateExpiringQueue(t, nil)
+ defer cleanup()
+
+ cs := test.GetConnectionString(t)
+ ns, err := NewNamespace(NamespaceWithConnectionString(cs))
+ require.NoError(t, err)
+
+ defer func() { _ = ns.Close(context.Background()) }()
+
+ createLinksCalled := 0
+
+ links := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) {
+ createLinksCalled++
+ return newLinksForAMQPLinksTest(entityPath, session)
+ })
+
+ require.EqualValues(t, 0, createLinksCalled)
+ require.NoError(t, links.RecoverIfNeeded(context.Background(), LinkID{}, amqp.ErrConnClosed))
+ require.EqualValues(t, 1, createLinksCalled)
+
+ lwr, err := links.Get(context.Background())
+ require.NoError(t, err)
+
+ amqpClient, clientRev, err := ns.GetAMQPClientImpl(context.Background())
+ require.NoError(t, err)
+ require.EqualValues(t, 1, clientRev)
+ require.NoError(t, amqpClient.Close())
+
+ // all the links are dead because the connection is dead.
+ assertFailedLinks(t, lwr, amqp.ErrConnClosed)
+
+ // now we'll recover, which should recreate everything
+ require.NoError(t, links.RecoverIfNeeded(context.Background(), lwr.ID, amqp.ErrConnClosed))
+ require.EqualValues(t, 2, createLinksCalled)
+
+ lwr, err = links.Get(context.Background())
+ require.NoError(t, err)
+
+ // should work now, connection should be reopened
+ assertLinks(t, lwr)
+
+ // cheat a bit and close the links out from under us (but leave them in place)
+ actualLinks := links.(*AMQPLinksImpl)
+ _ = actualLinks.Sender.Close(context.Background())
+ _ = actualLinks.Receiver.Close(context.Background())
+ _ = actualLinks.RPCLink.Close(context.Background())
+
+ assertFailedLinks(t, lwr, amqp.ErrLinkClosed)
+
+ lwr, err = links.Get(context.Background())
+ require.NoError(t, err)
+
+ require.NoError(t, links.RecoverIfNeeded(context.Background(), lwr.ID, amqp.ErrLinkClosed))
+ require.EqualValues(t, 3, createLinksCalled)
+
+ lwr, err = links.Get(context.Background())
+ require.NoError(t, err)
+
+ assertLinks(t, lwr)
+}
+
+func TestAMQPLinksLiveRecoverLink(t *testing.T) {
+ // we're not going to use this client for these tests.
+ entityPath, cleanup := test.CreateExpiringQueue(t, nil)
+ defer cleanup()
+
+ cs := test.GetConnectionString(t)
+ ns, err := NewNamespace(NamespaceWithConnectionString(cs))
+ require.NoError(t, err)
+
+ defer func() { _ = ns.Close(context.Background()) }()
+
+ createLinksCalled := 0
+
+ links := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) {
+ createLinksCalled++
+ return newLinksForAMQPLinksTest(entityPath, session)
+ })
+
+ require.EqualValues(t, 0, createLinksCalled)
+ require.NoError(t, links.RecoverIfNeeded(context.Background(), LinkID{}, amqp.ErrConnClosed))
+ require.EqualValues(t, 1, createLinksCalled)
+
+ lwr, err := links.Get(context.Background())
+ require.NoError(t, err)
+
+ require.NoError(t, links.RecoverIfNeeded(context.Background(), lwr.ID, amqp.ErrLinkClosed))
+ require.EqualValues(t, 2, createLinksCalled)
+}
+
+func TestAMQPLinksLiveRace(t *testing.T) {
+ entityPath, cleanup := test.CreateExpiringQueue(t, nil)
+ defer cleanup()
- createLinkCalled := 0
+ cs := test.GetConnectionString(t)
+ ns, err := NewNamespace(NamespaceWithConnectionString(cs))
+ require.NoError(t, err)
+
+ defer func() { _ = ns.Close(context.Background()) }()
+
+ createLinksCalled := 0
- tmpLinks := newAMQPLinks(ns, "entity path", NewBackoffRetrier(BackoffRetrierParams{}), func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) {
- createLinkCalled++
- return sender, nil, nil
+ links := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) {
+ createLinksCalled++
+ return newLinksForAMQPLinksTest(entityPath, session)
})
- links, _ := tmpLinks.(*amqpLinks)
-
- links.clientRevision = 2001
- links.sender = sender
-
- ctx := context.TODO()
-
- require.Nil(t, links.RecoverIfNeeded(ctx, 0, nil))
- require.EqualValues(t, 0, sess.closed)
- require.EqualValues(t, 0, ns.recovered)
- require.EqualValues(t, 0, createLinkCalled, "new links aren't needed")
- require.False(t, links.closedPermanently, "link should still be usable")
- require.Empty(t, ns.clientRevisions, "no connection recoveries happened")
-
- require.Nil(t, links.RecoverIfNeeded(ctx, 0, errors.New("Passes through")))
- require.EqualValues(t, 0, sess.closed)
- require.EqualValues(t, 0, ns.recovered)
- require.EqualValues(t, 0, createLinkCalled, "new links aren't needed")
- require.False(t, links.closedPermanently, "link should still be usable")
- require.Empty(t, ns.clientRevisions, "no connection recoveries happened")
-
- // now let's initiate a recovery at the connection level
- require.NoError(t, links.RecoverIfNeeded(ctx, 0, fakeNetError{}), fakeNetError{}.Error())
- require.EqualValues(t, 1, ns.recovered, "client gets recovered")
- require.EqualValues(t, 1, sender.Closed, "link is closed")
- require.EqualValues(t, 1, createLinkCalled, "link is created")
- require.False(t, links.closedPermanently, "link should still be usable")
- require.EqualValues(t, []uint64{2001}, ns.clientRevisions, "links handed us the client revision it got last")
-
- // validate that our linkRevision got updated and that we're returning it.
- // (note that link revisions start at 1, so we're not at 2, even though
- // only one recover has happened)
- _, _, _, linkRevision, err := links.Get(ctx)
- require.NoError(t, err)
- require.EqualValues(t, uint64(2), linkRevision)
-
- ns.recovered = 0
- sender.Closed = 0
- createLinkCalled = 0
-
- // let's do just a link level one
- require.NoError(t, links.RecoverIfNeeded(ctx, links.revision+1, &amqp.DetachError{}), &amqp.DetachError{})
- require.EqualValues(t, 0, ns.recovered)
- require.EqualValues(t, 1, sender.Closed)
- require.EqualValues(t, 1, createLinkCalled)
-
- _, _, _, linkRevision, err = links.Get(ctx)
- require.NoError(t, err)
- require.EqualValues(t, uint64(3), linkRevision)
-
- // cancellation
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
-
- ns.recovered = 0
- sender.Closed = 0
- createLinkCalled = 0
-
- // cancellation overrides any other logic.
- require.Error(t, links.RecoverIfNeeded(ctx, links.revision+1, &amqp.DetachError{}), &amqp.DetachError{})
- require.EqualValues(t, 0, ns.recovered)
- require.EqualValues(t, 0, sender.Closed)
- require.EqualValues(t, 0, createLinkCalled)
+ wg := sync.WaitGroup{}
+
+ for i := 0; i < 20; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ err := links.RecoverIfNeeded(context.Background(), LinkID{}, amqp.ErrConnClosed)
+ require.NoError(t, err)
+ }()
+ }
+
+ wg.Wait()
+
+ // TODO: also check that the connection hasn't recycled multiple times.
+ require.EqualValues(t, 1, createLinksCalled)
}
-func TestAMQPLinks_Closed(t *testing.T) {
- createLinks := func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) {
- return nil, nil, nil
+func TestAMQPLinksLiveRaceLink(t *testing.T) {
+ entityPath, cleanup := test.CreateExpiringQueue(t, nil)
+ defer cleanup()
+
+ cs := test.GetConnectionString(t)
+ ns, err := NewNamespace(NamespaceWithConnectionString(cs))
+ require.NoError(t, err)
+
+ defer func() { _ = ns.Close(context.Background()) }()
+
+ createLinksCalled := 0
+
+ enableLogging()
+
+ links := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) {
+ createLinksCalled++
+ return newLinksForAMQPLinksTest(entityPath, session)
+ })
+
+ wg := sync.WaitGroup{}
+
+ for i := 0; i < 20; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ err := links.RecoverIfNeeded(context.Background(), LinkID{}, &amqp.DetachError{})
+ require.NoError(t, err)
+ }()
}
- links := newAMQPLinks(&FakeNS{}, "hello", &backoffRetrier{}, createLinks)
- links.Close(context.Background(), true)
+ wg.Wait()
+
+ // TODO: also check that the connection hasn't recycled multiple times.
+ require.EqualValues(t, 1, createLinksCalled)
+}
+
+func TestAMQPLinksRetry(t *testing.T) {
+ entityPath, cleanup := test.CreateExpiringQueue(t, nil)
+ defer cleanup()
+
+ cs := test.GetConnectionString(t)
+ ns, err := NewNamespace(NamespaceWithConnectionString(cs))
+ require.NoError(t, err)
+
+ defer func() { _ = ns.Close(context.Background()) }()
+
+ createLinksCalled := 0
+
+ links := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) {
+ createLinksCalled++
+ return newLinksForAMQPLinksTest(entityPath, session)
+ })
+
+ err = links.Retry(context.Background(), "retryOp", func(ctx context.Context, lwid *LinksWithID, args *utils.RetryFnArgs) error {
+ // force recoveries
+ return &amqp.DetachError{}
+ }, utils.RetryOptions{
+ MaxRetries: 2,
+ // note: omitting MaxRetries just to give a sanity check that
+ // we do setDefaults() before we run.
+ RetryDelay: time.Millisecond,
+ MaxRetryDelay: time.Millisecond,
+ })
+
+ var detachErr *amqp.DetachError
+ require.ErrorAs(t, err, &detachErr)
+ require.EqualValues(t, 3, createLinksCalled)
+}
+
+func TestAMQPLinksMultipleWithSameConnection(t *testing.T) {
+ entityPath, cleanup := test.CreateExpiringQueue(t, nil)
+ defer cleanup()
+
+ cs := test.GetConnectionString(t)
+ ns, err := NewNamespace(NamespaceWithConnectionString(cs))
+ require.NoError(t, err)
+
+ defer func() { _ = ns.Close(context.Background()) }()
- _, _, _, _, err := links.Get(context.Background())
+ createLinksCalled := 0
- require.True(t, IsNonRetriable(err))
+ links := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) {
+ createLinksCalled++
+ return newLinksForAMQPLinksTest(entityPath, session)
+ })
+
+ createLinksCalled2 := 0
+
+ links2 := NewAMQPLinks(ns, entityPath, func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) {
+ createLinksCalled2++
+ return newLinksForAMQPLinksTest(entityPath, session)
+ })
+
+ wg := sync.WaitGroup{}
+
+ lwr, err := links.Get(context.Background())
+ require.NoError(t, err)
+ require.EqualValues(t, 1, createLinksCalled)
+ require.EqualValues(t, 1, lwr.ID.Link)
+
+ lwr2, err := links2.Get(context.Background())
+ require.NoError(t, err)
+ require.EqualValues(t, 1, createLinksCalled2)
+ require.EqualValues(t, 1, lwr2.ID.Link)
+
+ wg.Add(1)
+
+ go func() {
+ defer wg.Done()
+ err = links.RecoverIfNeeded(context.Background(), lwr.ID, &amqp.DetachError{})
+ require.NoError(t, err)
+ }()
+
+ wg.Add(1)
+
+ go func() {
+ defer wg.Done()
+
+ err = links2.RecoverIfNeeded(context.Background(), lwr2.ID, &amqp.DetachError{})
+ require.NoError(t, err)
+ }()
+
+ wg.Wait()
+
+ // TODO: also check that the connection hasn't recycled multiple times.
+ require.EqualValues(t, 2, createLinksCalled)
+ require.EqualValues(t, 2, createLinksCalled2)
+
+ _, clientRev, err := ns.GetAMQPClientImpl(context.Background())
+ require.NoError(t, err)
+ require.EqualValues(t, 1, clientRev)
+
+ recovered, err := ns.Recover(context.Background(), clientRev)
+ require.NoError(t, err)
+ require.True(t, recovered)
+
+ _, clientRev, err = ns.GetAMQPClientImpl(context.Background())
+ require.NoError(t, err)
+ require.EqualValues(t, 2, clientRev)
+
+ // now attempt a recover but with an older revision (won't do anything since we've
+ // already recovered past that older rev. They should just call Get())
+ recovered, err = ns.Recover(context.Background(), clientRev-1)
+ require.NoError(t, err)
+ require.False(t, recovered)
+
+ _, clientRev, err = ns.GetAMQPClientImpl(context.Background())
+ require.NoError(t, err)
+ require.EqualValues(t, 2, clientRev)
}
-func setupCreateLinkResponses(t *testing.T, responses []createLinkResponse) (CreateLinkFunc, *int) {
- callCount := 0
- testCreateLinkFunc := func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) {
- callCount++
+func newLinksForAMQPLinksTest(entityPath string, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) {
+ receiveMode := amqp.ModeSecond
+
+ opts := []amqp.LinkOption{
+ amqp.LinkSourceAddress(entityPath),
+ amqp.LinkReceiverSettle(receiveMode),
+ amqp.LinkWithManualCredits(),
+ amqp.LinkCredit(1000),
+ }
+
+ receiver, err := session.NewReceiver(opts...)
- if len(responses) == 0 {
- require.Fail(t, "createLinkFunc called too many times")
- }
+ if err != nil {
+ return nil, nil, err
+ }
- r := responses[0]
- responses = responses[1:]
+ sender, err := session.NewSender(
+ amqp.LinkSenderSettle(amqp.ModeMixed),
+ amqp.LinkReceiverSettle(amqp.ModeFirst),
+ amqp.LinkTargetAddress(entityPath))
- return r.sender, r.receiver, r.err
+ if err != nil {
+ _ = receiver.Close(context.Background())
+ return nil, nil, err
}
- return testCreateLinkFunc, &callCount
+ return sender, receiver, nil
+}
+
+func enableLogging() {
+ azlog.SetListener(func(e azlog.Event, s string) {
+ log.Printf("%s %s", e, s)
+ })
}
diff --git a/sdk/messaging/azservicebus/internal/amqp_test_utils.go b/sdk/messaging/azservicebus/internal/amqp_test_utils.go
index 6feee7d1848c..014f26eb5309 100644
--- a/sdk/messaging/azservicebus/internal/amqp_test_utils.go
+++ b/sdk/messaging/azservicebus/internal/amqp_test_utils.go
@@ -7,15 +7,14 @@ import (
"context"
"fmt"
- "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/rpc"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
)
type FakeNS struct {
claimNegotiated int
recovered uint64
clientRevisions []uint64
- MgmtClient MgmtClient
- RPCLink *rpc.Link
+ RPCLink RPCLink
Session AMQPSessionCloser
AMQPLinks *FakeAMQPLinks
}
@@ -30,21 +29,16 @@ type FakeAMQPSession struct {
closed int
}
-type fakeMgmtClient struct {
- MgmtClient
- closed int
-}
-
type FakeAMQPLinks struct {
AMQPLinks
Closed int
// values to be returned for each `Get` call
- Revision uint64
+ Revision LinkID
Receiver AMQPReceiver
Sender AMQPSender
- Mgmt MgmtClient
+ RPC RPCLink
Err error
permanently bool
@@ -66,8 +60,23 @@ func (r *FakeAMQPReceiver) Close(ctx context.Context) error {
return nil
}
-func (l *FakeAMQPLinks) Get(ctx context.Context) (AMQPSender, AMQPReceiver, MgmtClient, uint64, error) {
- return l.Sender, l.Receiver, l.Mgmt, l.Revision, l.Err
+func (l *FakeAMQPLinks) Get(ctx context.Context) (*LinksWithID, error) {
+ return &LinksWithID{
+ Sender: l.Sender,
+ Receiver: l.Receiver,
+ RPC: l.RPC,
+ ID: l.Revision,
+ }, l.Err
+}
+
+func (l *FakeAMQPLinks) Retry(ctx context.Context, name string, fn RetryWithLinksFn, o utils.RetryOptions) error {
+ lwr, err := l.Get(ctx)
+
+ if err != nil {
+ return err
+ }
+
+ return fn(ctx, lwr, &utils.RetryFnArgs{})
}
func (l *FakeAMQPLinks) Close(ctx context.Context, permanently bool) error {
@@ -93,11 +102,6 @@ func (s *FakeAMQPSession) Close(ctx context.Context) error {
return nil
}
-func (m *fakeMgmtClient) Close(ctx context.Context) error {
- m.closed++
- return nil
-}
-
func (ns *FakeNS) NegotiateClaim(ctx context.Context, entityPath string) (func() <-chan struct{}, error) {
ch := make(chan struct{})
close(ch)
@@ -117,26 +121,20 @@ func (ns *FakeNS) NewAMQPSession(ctx context.Context) (AMQPSessionCloser, uint64
return ns.Session, ns.recovered + 100, nil
}
-func (ns *FakeNS) NewMgmtClient(ctx context.Context, links AMQPLinks) (MgmtClient, error) {
- return ns.MgmtClient, nil
-}
-
-func (ns *FakeNS) NewRPCLink(ctx context.Context, managementPath string) (*rpc.Link, error) {
+func (ns *FakeNS) NewRPCLink(ctx context.Context, managementPath string) (RPCLink, error) {
return ns.RPCLink, nil
}
-func (ns *FakeNS) Recover(ctx context.Context, clientRevision uint64) error {
+func (ns *FakeNS) Recover(ctx context.Context, clientRevision uint64) (bool, error) {
ns.clientRevisions = append(ns.clientRevisions, clientRevision)
ns.recovered++
+ return true, nil
+}
+
+func (ns *FakeNS) CloseIfNeeded(ctx context.Context, clientRevision uint64) error {
return nil
}
func (ns *FakeNS) NewAMQPLinks(entityPath string, createLinkFunc CreateLinkFunc) AMQPLinks {
return ns.AMQPLinks
}
-
-type createLinkResponse struct {
- sender AMQPSenderCloser
- receiver AMQPReceiverCloser
- err error
-}
diff --git a/sdk/messaging/azservicebus/internal/atom/entity_manager.go b/sdk/messaging/azservicebus/internal/atom/entity_manager.go
index 179b5b60d678..ddbbae249fc2 100644
--- a/sdk/messaging/azservicebus/internal/atom/entity_manager.go
+++ b/sdk/messaging/azservicebus/internal/atom/entity_manager.go
@@ -11,6 +11,7 @@ import (
"fmt"
"io"
"io/ioutil"
+ "net"
"net/http"
"net/http/httputil"
"strings"
@@ -19,6 +20,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/sbauth"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/auth"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/conn"
"github.com/devigned/tab"
@@ -44,6 +46,7 @@ type (
Host string
mwStack []MiddlewareFunc
version string
+ retryOptions utils.RetryOptions
}
// BaseEntityDescription provides common fields which are part of Queues, Topics and Subscriptions
@@ -52,7 +55,8 @@ type (
ServiceBusSchema *string `xml:"xmlns,attr,omitempty"`
}
- managementError struct {
+ // example: 401Manage,EntityRead claims required for this operation.
+ ManagementError struct {
XMLName xml.Name `xml:"Error"`
Code int `xml:"Code"`
Detail string `xml:"Detail"`
@@ -145,7 +149,7 @@ const (
Unknown EntityStatus = "Unknown"
)
-func (m *managementError) String() string {
+func (m *ManagementError) String() string {
return fmt.Sprintf("Code: %d, Details: %s", m.Code, m.Detail)
}
@@ -179,7 +183,7 @@ func NewEntityManagerWithConnectionString(connectionString string, version strin
}
// NewEntityManager creates an entity manager using a TokenCredential.
-func NewEntityManager(ns string, tokenCredential azcore.TokenCredential, version string) (EntityManager, error) {
+func NewEntityManager(ns string, tokenCredential azcore.TokenCredential, version string, retryOptions utils.RetryOptions) (EntityManager, error) {
return &entityManager{
Host: fmt.Sprintf("https://%s/", ns),
version: version,
@@ -190,6 +194,7 @@ func NewEntityManager(ns string, tokenCredential azcore.TokenCredential, version
addAuthorization(sbauth.NewTokenProvider(tokenCredential)),
applyTracing(version),
},
+ retryOptions: retryOptions,
}, nil
}
@@ -238,61 +243,90 @@ func (em *entityManager) Delete(ctx context.Context, entityPath string, mw ...Mi
}
func (em *entityManager) execute(ctx context.Context, method string, entityPath string, body io.Reader, mw ...MiddlewareFunc) (*http.Response, error) {
- ctx, span := em.startSpanFromContext(ctx, "sb.ATOM.Execute")
- defer span.End()
+ var finalResp *http.Response
- req, err := http.NewRequest(method, em.Host+strings.TrimPrefix(entityPath, "/"), body)
- if err != nil {
- tab.For(ctx).Error(err)
- return nil, err
- }
+ err := utils.Retry(ctx, fmt.Sprintf("%s %s", method, entityPath), func(ctx context.Context, args *utils.RetryFnArgs) error {
+ ctx, span := em.startSpanFromContext(ctx, "sb.ATOM.Execute")
+ defer span.End()
- final := func(_ RestHandler) RestHandler {
- return func(reqCtx context.Context, request *http.Request) (*http.Response, error) {
- client := &http.Client{
- Timeout: 60 * time.Second,
- }
- request = request.WithContext(reqCtx)
- return client.Do(request)
+ req, err := http.NewRequest(method, em.Host+strings.TrimPrefix(entityPath, "/"), body)
+ if err != nil {
+ tab.For(ctx).Error(err)
+ return err
}
- }
- mwStack := []MiddlewareFunc{final}
- sl := len(em.mwStack) - 1
- for i := sl; i >= 0; i-- {
- mwStack = append(mwStack, em.mwStack[i])
- }
+ final := func(_ RestHandler) RestHandler {
+ return func(reqCtx context.Context, request *http.Request) (*http.Response, error) {
+ client := &http.Client{
+ Timeout: 60 * time.Second,
+ }
+ request = request.WithContext(reqCtx)
+ return client.Do(request)
+ }
+ }
- for i := len(mw) - 1; i >= 0; i-- {
- mwStack = append(mwStack, mw[i])
- }
+ mwStack := []MiddlewareFunc{final}
+ sl := len(em.mwStack) - 1
+ for i := sl; i >= 0; i-- {
+ mwStack = append(mwStack, em.mwStack[i])
+ }
- var h RestHandler
- for _, mw := range mwStack {
- h = mw(h)
- }
+ for i := len(mw) - 1; i >= 0; i-- {
+ mwStack = append(mwStack, mw[i])
+ }
- resp, err := h(ctx, req)
+ var h RestHandler
+ for _, mw := range mwStack {
+ h = mw(h)
+ }
- if err == nil {
- if resp.StatusCode >= http.StatusBadRequest {
- bytes, err := ioutil.ReadAll(resp.Body)
+ resp, err := h(ctx, req)
- if err == nil {
- err = FormatManagementError(bytes, err)
+ if err == nil {
+ if resp.StatusCode >= http.StatusBadRequest {
+ return NewResponseError(resp)
}
- return nil, NewResponseError(err, resp)
+ finalResp = resp
+ return nil
+ }
+
+ if resp != nil {
+ return NewResponseError(resp)
}
- return resp, nil
+ return err
+ }, isFatalHTTPError, em.retryOptions)
+
+ if err != nil {
+ return nil, err
+ }
+
+ return finalResp, nil
+}
+
+func isFatalHTTPError(err error) bool {
+ var netErr net.Error
+
+ if errors.As(err, &netErr) {
+ return false
}
- if resp != nil {
- return nil, NewResponseError(err, resp)
+ var respErr *azcore.ResponseError
+
+ // TODO: this is very much temporary. We need to move this over to the azcore HTTP stack.
+ if errors.As(err, &respErr) {
+ if respErr.StatusCode == http.StatusRequestTimeout || // 408
+ respErr.StatusCode == http.StatusTooManyRequests || // 429
+ respErr.StatusCode == http.StatusInternalServerError || // 500
+ respErr.StatusCode == http.StatusBadGateway || // 502
+ respErr.StatusCode == http.StatusServiceUnavailable || // 503
+ respErr.StatusCode == http.StatusGatewayTimeout { // 504
+ return false
+ }
}
- return nil, err
+ return true
}
// Use adds middleware to the middleware mwStack
@@ -306,7 +340,7 @@ func (em *entityManager) TokenProvider() auth.TokenProvider {
}
func FormatManagementError(body []byte, origErr error) error {
- var mgmtError managementError
+ var mgmtError ManagementError
unmarshalErr := xml.Unmarshal(body, &mgmtError)
if unmarshalErr != nil {
return origErr
diff --git a/sdk/messaging/azservicebus/internal/atom/errors.go b/sdk/messaging/azservicebus/internal/atom/errors.go
index e8182a433746..0c50018dc7fb 100644
--- a/sdk/messaging/azservicebus/internal/atom/errors.go
+++ b/sdk/messaging/azservicebus/internal/atom/errors.go
@@ -6,10 +6,15 @@ package atom
import (
"fmt"
"net/http"
+
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore"
)
-func NewResponseError(inner error, resp *http.Response) error {
- return ResponseError{inner, resp}
+func NewResponseError(resp *http.Response) error {
+ return &azcore.ResponseError{
+ StatusCode: resp.StatusCode,
+ RawResponse: resp,
+ }
}
// ResponseError conforms to the older azcore.HTTPResponse
diff --git a/sdk/messaging/azservicebus/internal/atom/manager_common_test.go b/sdk/messaging/azservicebus/internal/atom/manager_common_test.go
index 9945fff0b8ec..12b7b47a0562 100644
--- a/sdk/messaging/azservicebus/internal/atom/manager_common_test.go
+++ b/sdk/messaging/azservicebus/internal/atom/manager_common_test.go
@@ -4,40 +4,66 @@
package atom
import (
+ "bytes"
"context"
- "errors"
"io"
"net/http"
+ "net/url"
"strings"
"testing"
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/stretchr/testify/require"
)
+func newFakeResponse(statusCode int, status string, contents string) *http.Response {
+ var body io.ReadCloser = http.NoBody
+
+ if contents != "" {
+ body = &FakeReader{
+ Reader: bytes.NewBufferString(contents),
+ }
+ }
+
+ return &http.Response{
+ Request: &http.Request{
+ URL: &url.URL{},
+ },
+ StatusCode: statusCode,
+ Status: status,
+ Body: body,
+ }
+}
+
func TestResponseError(t *testing.T) {
- require.EqualValues(t, "this is now the error message: 409", NewResponseError(nil, &http.Response{
- StatusCode: http.StatusConflict,
- Status: "this is now the error message",
- }).Error())
-
- require.EqualValues(t, "inner errors message takes precedence", NewResponseError(errors.New("inner errors message takes precedence"), &http.Response{
- StatusCode: http.StatusConflict,
- Status: "going to be ignored",
- }).Error())
+ resp := newFakeResponse(http.StatusConflict, "statusString", "")
+ require.Contains(t, NewResponseError(resp).Error(), "statusString")
+
+ resp = newFakeResponse(http.StatusConflict, "statusString", "contents")
+ require.Contains(t, NewResponseError(resp).Error(), "statusString")
+
+ resp = newFakeResponse(http.StatusBadGateway, "statusString", "401Manage,EntityRead claims required for this operation.")
+ err := NewResponseError(resp)
+
+ re, ok := err.(*azcore.ResponseError)
+ require.True(t, ok)
+
+ require.Contains(t, re.Error(), "statusString")
+ require.EqualValues(t, http.StatusBadGateway, re.StatusCode)
}
type FakeReader struct {
io.Reader
- closed bool
+ closed bool
+ closeErr error
}
func (f *FakeReader) Close() error {
f.closed = true
- return nil
+ return f.closeErr
}
func TestCloseRes(t *testing.T) {
-
reader := strings.NewReader("hello")
body := &FakeReader{Reader: reader}
diff --git a/sdk/messaging/azservicebus/internal/errors.go b/sdk/messaging/azservicebus/internal/errors.go
index d833bf0609f4..7edefc24bb56 100644
--- a/sdk/messaging/azservicebus/internal/errors.go
+++ b/sdk/messaging/azservicebus/internal/errors.go
@@ -11,193 +11,17 @@ import (
"net"
"reflect"
"strings"
- "time"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/rpc"
"github.com/Azure/go-amqp"
- "github.com/devigned/tab"
)
-type NonRetriable interface {
- error
- NonRetriable()
-}
-
-// IsNonRetriable indicates an error is fatal. Typically, this means
-// the connection or link has been closed.
-func IsNonRetriable(err error) bool {
- var nonRetriable NonRetriable
- return errors.As(err, &nonRetriable)
-}
-
type ErrNonRetriable struct {
Message string
}
func (e ErrNonRetriable) Error() string { return e.Message }
-// Error Conditions
-const (
- // Service Bus Errors
- errorServerBusy amqp.ErrorCondition = "com.microsoft:server-busy"
- errorTimeout amqp.ErrorCondition = "com.microsoft:timeout"
- errorOperationCancelled amqp.ErrorCondition = "com.microsoft:operation-cancelled"
- errorContainerClose amqp.ErrorCondition = "com.microsoft:container-close"
-)
-
-const (
- amqpRetryDefaultTimes int = 3
- amqpRetryDefaultDelay time.Duration = time.Second
-)
-
-type (
- // ErrMissingField indicates that an expected property was missing from an AMQP message. This should only be
- // encountered when there is an error with this library, or the server has altered its behavior unexpectedly.
- ErrMissingField string
-
- // ErrMalformedMessage indicates that a message was expected in the form of []byte was not a []byte. This is likely
- // a bug and should be reported.
- ErrMalformedMessage string
-
- // ErrIncorrectType indicates that type assertion failed. This should only be encountered when there is an error
- // with this library, or the server has altered its behavior unexpectedly.
- ErrIncorrectType struct {
- Key string
- ExpectedType reflect.Type
- ActualValue interface{}
- }
-
- // ErrAMQP indicates that the server communicated an AMQP error with a particular
- ErrAMQP rpc.Response
-
- // ErrNoMessages is returned when an operation returned no messages. It is not indicative that there will not be
- // more messages in the future.
- ErrNoMessages struct{}
-
- // ErrNotFound is returned when an entity is not found (404)
- ErrNotFound struct {
- EntityPath string
- }
-
- // ErrConnectionClosed indicates that the connection has been closed.
- ErrConnectionClosed string
-)
-
-func (e ErrMissingField) Error() string {
- return fmt.Sprintf("missing value %q", string(e))
-}
-
-func (e ErrMalformedMessage) Error() string {
- return "message was expected in the form of []byte was not a []byte"
-}
-
-// NewErrIncorrectType lets you skip using the `reflect` package. Just provide a variable of the desired type as
-// 'expected'.
-func NewErrIncorrectType(key string, expected, actual interface{}) ErrIncorrectType {
- return ErrIncorrectType{
- Key: key,
- ExpectedType: reflect.TypeOf(expected),
- ActualValue: actual,
- }
-}
-
-func (e ErrIncorrectType) Error() string {
- return fmt.Sprintf(
- "value at %q was expected to be of type %q but was actually of type %q",
- e.Key,
- e.ExpectedType,
- reflect.TypeOf(e.ActualValue))
-}
-
-func (e ErrAMQP) Error() string {
- return fmt.Sprintf("server says (%d) %s", e.Code, e.Description)
-}
-
-func (e ErrNoMessages) Error() string {
- return "no messages available"
-}
-
-func (e ErrNotFound) Error() string {
- return fmt.Sprintf("entity at %s not found", e.EntityPath)
-}
-
-// IsErrNotFound returns true if the error argument is an ErrNotFound type
-func IsErrNotFound(err error) bool {
- _, ok := err.(ErrNotFound)
- return ok
-}
-
-func (e ErrConnectionClosed) Error() string {
- return fmt.Sprintf("the connection has been closed: %s", string(e))
-}
-
-// Leveraging @serbrech's fine work from go-shuttle:
-// https://github.com/Azure/go-shuttle/blob/ea882947109ade9b34d4d69642fdf7aec4570fee/common/errorhandling/recovery.go
-
-var retryableAMQPConditions = map[string]bool{
- string(amqp.ErrorInternalError): true,
- string(errorServerBusy): true, // "com.microsoft:server-busy"
- string(errorTimeout): true, // "com.microsoft:timeout"
- string(errorOperationCancelled): true, // "com.microsoft:operation-cancelled"
- "client.sender:not-enough-link-credit": true,
- string(amqp.ErrorUnauthorizedAccess): true,
- string(amqp.ErrorDetachForced): true,
- string(amqp.ErrorConnectionForced): true,
- string(amqp.ErrorTransferLimitExceeded): true,
- "amqp: connection closed": true,
- "unexpected frame": true,
- string(amqp.ErrorNotFound): true,
-}
-
-func isRetryableAMQPError(ctxForLogging context.Context, err error) bool {
- var amqpErr *amqp.Error
- var isAMQPError = errors.As(err, &amqpErr)
-
- if isAMQPError {
- _, ok := retryableAMQPConditions[string(amqpErr.Condition)]
- return ok
- }
-
- // TODO: there is a bug somewhere that seems to be errorString'ing errors. Need to track that down.
- // In the meantime, try string matching instead
- for condition := range retryableAMQPConditions {
- if strings.Contains(err.Error(), condition) {
- tab.For(ctxForLogging).Error(fmt.Errorf("error needed to be matched by a string matcher, rather than by type: %w", err))
- return true
- }
- }
-
- return false
-}
-
-func shouldRecreateLink(err error) bool {
- if err == nil {
- return false
- }
-
- var detachError *amqp.DetachError
-
- return errors.As(err, &detachError) ||
- // TODO: proper error types needs to happen
- strings.Contains(err.Error(), "detach frame link detached")
-}
-
-func shouldRecreateConnection(ctxForLogging context.Context, err error) bool {
- if err == nil {
- return false
- }
-
- shouldRecreate := isPermanentNetError(err) ||
- isRetryableAMQPError(ctxForLogging, err) ||
- errors.Is(err, io.EOF) ||
- // these are distinct from a detach and probably indicate something
- // wrong with the connection itself, rather than just the link
- errors.Is(err, amqp.ErrSessionClosed) ||
- errors.Is(err, amqp.ErrLinkClosed)
-
- return shouldRecreate
-}
-
type recoveryKind string
const RecoveryKindNone recoveryKind = ""
@@ -205,48 +29,40 @@ const RecoveryKindFatal recoveryKind = "fatal"
const RecoveryKindLink recoveryKind = "link"
const RecoveryKindConn recoveryKind = "connection"
-type ServiceBusError struct {
+type SBErrInfo struct {
inner error
RecoveryKind recoveryKind
}
-func (sbe *ServiceBusError) String() string {
+func (sbe *SBErrInfo) String() string {
return sbe.inner.Error()
}
-func (sbe *ServiceBusError) AsError() error {
+func (sbe *SBErrInfo) AsError() error {
return sbe.inner
}
-// ToSBE wraps the passed in 'err' with a proper error with one of either:
+func IsFatalSBError(err error) bool {
+ return GetSBErrInfo(err).RecoveryKind == RecoveryKindFatal
+}
+
+// GetSBErrInfo wraps the passed in 'err' with a proper error with one of either:
// - `fatalServiceBusError` if no recovery is possible.
// - `serviceBusError` if the error is recoverable. The `recoveryKind` field contains the
// type of recovery needed.
-func ToSBE(loggingCtx context.Context, err error) *ServiceBusError {
+func GetSBErrInfo(err error) *SBErrInfo {
if err == nil {
return nil
}
- sbe := &ServiceBusError{
+ sbe := &SBErrInfo{
inner: err,
- RecoveryKind: GetRecoveryKind(loggingCtx, err),
+ RecoveryKind: GetRecoveryKind(err),
}
return sbe
}
-func isPermanentNetError(err error) bool {
- var netErr net.Error
-
- if errors.As(err, &netErr) {
- temp := netErr.Temporary()
- timeout := netErr.Timeout()
- return !temp && !timeout
- }
-
- return false
-}
-
func IsCancelError(err error) bool {
if err == nil {
return false
@@ -291,7 +107,7 @@ var amqpConditionsToRecoveryKind = map[amqp.ErrorCondition]recoveryKind{
amqp.ErrorCondition("com.microsoft:message-lock-lost"): RecoveryKindFatal,
}
-func GetRecoveryKind(ctxForLogging context.Context, err error) recoveryKind {
+func GetRecoveryKind(err error) recoveryKind {
if IsCancelError(err) {
return RecoveryKindFatal
}
@@ -345,7 +161,106 @@ func GetRecoveryKind(ctxForLogging context.Context, err error) recoveryKind {
}
}
+ var me mgmtError
+
+ if errors.As(err, &me) {
+ code := me.RPCCode()
+
+ // this can happen when we're recovering the link - the client gets closed and the old link is still being
+ // used by this instance of the client. It needs to recover and attempt it again.
+ if code == 401 ||
+ // we lost the session lock, attempt link recovery
+ code == 410 {
+ return RecoveryKindLink
+ }
+
+ // simple timeouts
+ if me.Resp.Code == 408 || me.Resp.Code == 503 {
+ return RecoveryKindNone
+ }
+ }
+
// this is some error type we've never seen.
- tab.For(ctxForLogging).Fatal(fmt.Sprintf("No recovery possible with error: %#v", err))
return RecoveryKindFatal
}
+
+type (
+ // ErrMissingField indicates that an expected property was missing from an AMQP message. This should only be
+ // encountered when there is an error with this library, or the server has altered its behavior unexpectedly.
+ ErrMissingField string
+
+ // ErrMalformedMessage indicates that a message was expected in the form of []byte was not a []byte. This is likely
+ // a bug and should be reported.
+ ErrMalformedMessage string
+
+ // ErrIncorrectType indicates that type assertion failed. This should only be encountered when there is an error
+ // with this library, or the server has altered its behavior unexpectedly.
+ ErrIncorrectType struct {
+ Key string
+ ExpectedType reflect.Type
+ ActualValue interface{}
+ }
+
+ // ErrAMQP indicates that the server communicated an AMQP error with a particular
+ ErrAMQP rpc.Response
+
+ // ErrNoMessages is returned when an operation returned no messages. It is not indicative that there will not be
+ // more messages in the future.
+ ErrNoMessages struct{}
+
+ // ErrNotFound is returned when an entity is not found (404)
+ ErrNotFound struct {
+ EntityPath string
+ }
+
+ // ErrConnectionClosed indicates that the connection has been closed.
+ ErrConnectionClosed string
+)
+
+func (e ErrMissingField) Error() string {
+ return fmt.Sprintf("missing value %q", string(e))
+}
+
+func (e ErrMalformedMessage) Error() string {
+ return "message was expected in the form of []byte was not a []byte"
+}
+
+// NewErrIncorrectType lets you skip using the `reflect` package. Just provide a variable of the desired type as
+// 'expected'.
+func NewErrIncorrectType(key string, expected, actual interface{}) ErrIncorrectType {
+ return ErrIncorrectType{
+ Key: key,
+ ExpectedType: reflect.TypeOf(expected),
+ ActualValue: actual,
+ }
+}
+
+func (e ErrIncorrectType) Error() string {
+ return fmt.Sprintf(
+ "value at %q was expected to be of type %q but was actually of type %q",
+ e.Key,
+ e.ExpectedType,
+ reflect.TypeOf(e.ActualValue))
+}
+
+func (e ErrAMQP) Error() string {
+ return fmt.Sprintf("server says (%d) %s", e.Code, e.Description)
+}
+
+func (e ErrNoMessages) Error() string {
+ return "no messages available"
+}
+
+func (e ErrNotFound) Error() string {
+ return fmt.Sprintf("entity at %s not found", e.EntityPath)
+}
+
+// IsErrNotFound returns true if the error argument is an ErrNotFound type
+func IsErrNotFound(err error) bool {
+ _, ok := err.(ErrNotFound)
+ return ok
+}
+
+func (e ErrConnectionClosed) Error() string {
+ return fmt.Sprintf("the connection has been closed: %s", string(e))
+}
diff --git a/sdk/messaging/azservicebus/internal/errors_test.go b/sdk/messaging/azservicebus/internal/errors_test.go
index 471e7d96ad6d..743f2894f2af 100644
--- a/sdk/messaging/azservicebus/internal/errors_test.go
+++ b/sdk/messaging/azservicebus/internal/errors_test.go
@@ -77,96 +77,92 @@ func TestErrNotFound_Error(t *testing.T) {
assert.False(t, IsErrNotFound(otherErr))
}
-func Test_isPermanentNetError(t *testing.T) {
- require.False(t, isPermanentNetError(&fakeNetError{
- temp: true,
- }))
-
- require.False(t, isPermanentNetError(&fakeNetError{
- timeout: true,
- }))
-
- require.False(t, isPermanentNetError(errors.New("not a net error")))
-
- require.True(t, isPermanentNetError(&fakeNetError{}))
+func Test_recoveryKind(t *testing.T) {
+ t.Run("link", func(t *testing.T) {
+ linkErrorCodes := []string{
+ string(amqp.ErrorDetachForced),
+ }
+
+ for _, code := range linkErrorCodes {
+ t.Run(code, func(t *testing.T) {
+ sbe := GetSBErrInfo(&amqp.Error{Condition: amqp.ErrorCondition(code)})
+ require.EqualValues(t, RecoveryKindLink, sbe.RecoveryKind, fmt.Sprintf("requires link recovery: %s", code))
+ })
+ }
+
+ t.Run("sentinel errors", func(t *testing.T) {
+ sbe := GetSBErrInfo(amqp.ErrLinkClosed)
+ require.EqualValues(t, RecoveryKindLink, sbe.RecoveryKind)
+
+ sbe = GetSBErrInfo(amqp.ErrSessionClosed)
+ require.EqualValues(t, RecoveryKindLink, sbe.RecoveryKind)
+ })
+ })
+
+ t.Run("connection", func(t *testing.T) {
+ codes := []string{
+ string(amqp.ErrorConnectionForced),
+ }
+
+ for _, code := range codes {
+ t.Run(code, func(t *testing.T) {
+ sbe := GetSBErrInfo(&amqp.Error{Condition: amqp.ErrorCondition(code)})
+ require.EqualValues(t, RecoveryKindConn, sbe.RecoveryKind, fmt.Sprintf("requires connection recovery: %s", code))
+ })
+ }
+
+ t.Run("sentinel errors", func(t *testing.T) {
+ sbe := GetSBErrInfo(amqp.ErrConnClosed)
+ require.EqualValues(t, RecoveryKindConn, sbe.RecoveryKind)
+ })
+ })
+
+ t.Run("nonretriable", func(t *testing.T) {
+ codes := []string{
+ string(amqp.ErrorTransferLimitExceeded),
+ string(amqp.ErrorInternalError),
+ string(amqp.ErrorUnauthorizedAccess),
+ string(amqp.ErrorNotFound),
+ string(amqp.ErrorMessageSizeExceeded),
+ }
+
+ for _, code := range codes {
+ t.Run(code, func(t *testing.T) {
+ sbe := GetSBErrInfo(&amqp.Error{Condition: amqp.ErrorCondition(code)})
+ require.EqualValues(t, RecoveryKindFatal, sbe.RecoveryKind, fmt.Sprintf("cannot be recovered: %s", code))
+ })
+ }
+ })
+
+ t.Run("none", func(t *testing.T) {
+ codes := []string{
+ string("com.microsoft:operation-cancelled"),
+ string("com.microsoft:server-busy"),
+ string("com.microsoft:timeout"),
+ }
+
+ for _, code := range codes {
+ t.Run(code, func(t *testing.T) {
+ sbe := GetSBErrInfo(&amqp.Error{Condition: amqp.ErrorCondition(code)})
+ require.EqualValues(t, RecoveryKindNone, sbe.RecoveryKind, fmt.Sprintf("no recovery needed: %s", code))
+ })
+ }
+ })
}
-func Test_isRetryableAMQPError(t *testing.T) {
- ctx := context.Background()
-
- retryableCodes := []string{
- string(amqp.ErrorInternalError),
- string(errorServerBusy),
- string(errorTimeout),
- string(errorOperationCancelled),
- "client.sender:not-enough-link-credit",
- string(amqp.ErrorUnauthorizedAccess),
- string(amqp.ErrorDetachForced),
- string(amqp.ErrorConnectionForced),
- string(amqp.ErrorTransferLimitExceeded),
- "amqp: connection closed",
- "unexpected frame",
- string(amqp.ErrorNotFound),
+func Test_IsNonRetriable(t *testing.T) {
+ errs := []error{
+ context.Canceled,
+ context.DeadlineExceeded,
+ ErrNonRetriable{Message: "any message"},
+ fmt.Errorf("wrapped: %w", context.Canceled),
+ fmt.Errorf("wrapped: %w", context.DeadlineExceeded),
+ fmt.Errorf("wrapped: %w", ErrNonRetriable{Message: "any message"}),
}
- for _, code := range retryableCodes {
- require.True(t, isRetryableAMQPError(ctx, &amqp.Error{
- Condition: amqp.ErrorCondition(code),
- }))
-
- // it works equally well if the error is just in the String().
- // Need to narrow this down some more to see where the errors
- // might not be getting converted properly.
- require.True(t, isRetryableAMQPError(ctx, errors.New(code)))
+ for _, err := range errs {
+ require.EqualValues(t, RecoveryKindFatal, GetSBErrInfo(err).RecoveryKind)
}
-
- require.False(t, isRetryableAMQPError(ctx, errors.New("some non-amqp related error")))
-}
-
-func Test_shouldRecreateLink(t *testing.T) {
- require.False(t, shouldRecreateLink(nil))
-
- require.True(t, shouldRecreateLink(&amqp.DetachError{}))
-
- // going to treat these as "connection troubles" and throw them into the
- // connection recovery scenario instead.
- require.False(t, shouldRecreateLink(amqp.ErrLinkClosed))
- require.False(t, shouldRecreateLink(amqp.ErrSessionClosed))
-}
-
-func Test_shouldRecreateConnection(t *testing.T) {
- ctx := context.Background()
-
- require.False(t, shouldRecreateConnection(ctx, nil))
- require.True(t, shouldRecreateConnection(ctx, &fakeNetError{}))
- require.True(t, shouldRecreateConnection(ctx, fmt.Errorf("%w", &fakeNetError{})))
-
- require.False(t, shouldRecreateLink(amqp.ErrLinkClosed))
- require.False(t, shouldRecreateLink(fmt.Errorf("wrapped: %w", amqp.ErrLinkClosed)))
-
- require.False(t, shouldRecreateLink(amqp.ErrSessionClosed))
- require.False(t, shouldRecreateLink(fmt.Errorf("wrapped: %w", amqp.ErrSessionClosed)))
-}
-
-// TODO: while testing it appeared there were some errors that were getting string-ized
-// We want to eliminate these. 'stress.go' reproduces most of these as you disconnect
-// and reconnect.
-func Test_stringErrorsToEliminate(t *testing.T) {
- require.True(t, shouldRecreateLink(errors.New("detach frame link detached")))
- require.True(t, isRetryableAMQPError(context.Background(), errors.New("amqp: connection closed")))
- require.True(t, IsCancelError(errors.New("context canceled")))
-}
-
-func Test_IsCancelError(t *testing.T) {
- require.False(t, IsCancelError(nil))
- require.False(t, IsCancelError(errors.New("not a cancel error")))
-
- require.True(t, IsCancelError(errors.New("context canceled")))
-
- require.True(t, IsCancelError(context.Canceled))
- require.True(t, IsCancelError(context.DeadlineExceeded))
- require.True(t, IsCancelError(fmt.Errorf("wrapped: %w", context.Canceled)))
- require.True(t, IsCancelError(fmt.Errorf("wrapped: %w", context.DeadlineExceeded)))
}
func Test_ServiceBusError_NoRecoveryNeeded(t *testing.T) {
@@ -178,10 +174,13 @@ func Test_ServiceBusError_NoRecoveryNeeded(t *testing.T) {
fakeNetError{temp: true},
fakeNetError{timeout: true},
fakeNetError{temp: false, timeout: false},
+ // simple timeouts from the mgmt link
+ mgmtError{Resp: &RPCResponse{Code: 408}},
+ mgmtError{Resp: &RPCResponse{Code: 503}},
}
for i, err := range tempErrors {
- rk := ToSBE(context.Background(), err).RecoveryKind
+ rk := GetSBErrInfo(err).RecoveryKind
require.EqualValues(t, RecoveryKindNone, rk, fmt.Sprintf("[%d] %v", i, err))
}
}
@@ -194,7 +193,7 @@ func Test_ServiceBusError_ConnectionRecoveryNeeded(t *testing.T) {
}
for i, err := range connErrors {
- rk := ToSBE(context.Background(), err).RecoveryKind
+ rk := GetSBErrInfo(err).RecoveryKind
require.EqualValues(t, RecoveryKindConn, rk, fmt.Sprintf("[%d] %v", i, err))
}
}
@@ -205,10 +204,15 @@ func Test_ServiceBusError_LinkRecoveryNeeded(t *testing.T) {
amqp.ErrLinkClosed,
&amqp.DetachError{},
&amqp.Error{Condition: amqp.ErrorDetachForced},
+ // we lost the session lock, attempt link recovery
+ mgmtError{Resp: &RPCResponse{Code: 410}},
+ // this can happen when we're recovering the link - the client gets closed and the old link is still being
+ // used by this instance of the client. It needs to recover and attempt it again.
+ mgmtError{Resp: &RPCResponse{Code: 401}},
}
for i, err := range linkErrors {
- rk := ToSBE(context.Background(), err).RecoveryKind
+ rk := GetSBErrInfo(err).RecoveryKind
require.EqualValues(t, RecoveryKindLink, rk, fmt.Sprintf("[%d] %v", i, err))
}
}
@@ -226,11 +230,14 @@ func Test_ServiceBusError_Fatal(t *testing.T) {
}
for i, cond := range fatalConditions {
- rk := ToSBE(context.Background(), &amqp.Error{Condition: cond}).RecoveryKind
+ rk := GetSBErrInfo(&amqp.Error{Condition: cond}).RecoveryKind
require.EqualValues(t, RecoveryKindFatal, rk, fmt.Sprintf("[%d] %s", i, cond))
}
// unknown errors are also considered fatal
- rk := ToSBE(context.Background(), errors.New("Some unknown error")).RecoveryKind
+ rk := GetSBErrInfo(errors.New("Some unknown error")).RecoveryKind
+ require.EqualValues(t, RecoveryKindFatal, rk, "some unknown error")
+
+ rk = GetSBErrInfo(mgmtError{Resp: &RPCResponse{Code: 500}}).RecoveryKind
require.EqualValues(t, RecoveryKindFatal, rk, "mgmt error with RPC code 500 is fatal")
}
diff --git a/sdk/messaging/azservicebus/internal/log.go b/sdk/messaging/azservicebus/internal/log.go
index c500f1f9c7c5..cd04f8c3a996 100644
--- a/sdk/messaging/azservicebus/internal/log.go
+++ b/sdk/messaging/azservicebus/internal/log.go
@@ -3,7 +3,21 @@
package internal
+import "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
+
const (
+ // Link/connection creation
+ EventConn = "azsb.Conn"
+
+ // authentication/claims negotiation
+ EventAuth = "azsb.Auth"
+
+ // receiver operations
+ EventReceiver = "azsb.Receiver"
+
+ // mgmt link
+ EventMgmtLink = "azsb.Mgmt"
+
// internal operations
- EventRetry = "azsb.Retry"
+ EventRetry = utils.EventRetry
)
diff --git a/sdk/messaging/azservicebus/internal/mgmt.go b/sdk/messaging/azservicebus/internal/mgmt.go
index b266fc95c856..1d9534fc7bc2 100644
--- a/sdk/messaging/azservicebus/internal/mgmt.go
+++ b/sdk/messaging/azservicebus/internal/mgmt.go
@@ -7,13 +7,10 @@ import (
"context"
"errors"
"fmt"
- "sync"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/internal/uuid"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing"
- common "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal"
- "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/rpc"
"github.com/Azure/go-amqp"
"github.com/devigned/tab"
)
@@ -34,171 +31,39 @@ const (
DeferredDisposition DispositionStatus = "defered"
)
-type (
- mgmtClient struct {
- ns NamespaceForMgmtClient
- links AMQPLinks
-
- clientMu sync.Mutex
- rpcLink RPCLink
-
- sessionID *string
- isSessionFilterSet bool
- }
-)
-
-type MgmtClient interface {
- Close(ctx context.Context) error
- SendDisposition(ctx context.Context, lockToken *amqp.UUID, state Disposition, propertiesToModify map[string]interface{}) error
- ReceiveDeferred(ctx context.Context, mode ReceiveMode, sequenceNumbers []int64) ([]*amqp.Message, error)
- PeekMessages(ctx context.Context, fromSequenceNumber int64, messageCount int32) ([]*amqp.Message, error)
-
- ScheduleMessages(ctx context.Context, enqueueTime time.Time, messages ...*amqp.Message) ([]int64, error)
- CancelScheduled(ctx context.Context, seq ...int64) error
-
- RenewLocks(ctx context.Context, linkName string, lockTokens []amqp.UUID) ([]time.Time, error)
- RenewSessionLock(ctx context.Context, sessionID string) (time.Time, error)
-
- GetSessionState(ctx context.Context, sessionID string) ([]byte, error)
- SetSessionState(ctx context.Context, sessionID string, state []byte) error
+type mgmtError struct {
+ Resp *RPCResponse
+ Message string
}
-func newMgmtClient(ctx context.Context, links AMQPLinks, ns NamespaceForMgmtClient) (MgmtClient, error) {
- r := &mgmtClient{
- ns: ns,
- links: links,
- }
-
- return r, nil
+func (me mgmtError) Error() string {
+ return me.Message
}
-// Recover will attempt to close the current session and link, then rebuild them
-func (mc *mgmtClient) recover(ctx context.Context) error {
- mc.clientMu.Lock()
- defer mc.clientMu.Unlock()
-
- ctx, span := mc.startSpanFromContext(ctx, string(tracing.SpanNameRecover))
- defer span.End()
-
- if mc.rpcLink != nil {
- if err := mc.rpcLink.Close(ctx); err != nil {
- tab.For(ctx).Debug(fmt.Sprintf("Error while closing old link in recovery: %s", err.Error()))
- }
- mc.rpcLink = nil
- }
-
- if _, err := mc.getLinkWithoutLock(ctx); err != nil {
- return err
- }
-
- return nil
+func (me mgmtError) RPCCode() int {
+ return me.Resp.Code
}
-// getLinkWithoutLock returns the currently cached link (or creates a new one)
-func (mc *mgmtClient) getLinkWithoutLock(ctx context.Context) (RPCLink, error) {
- if mc.rpcLink != nil {
- return mc.rpcLink, nil
- }
-
- var err error
- mc.rpcLink, err = mc.ns.NewRPCLink(ctx, mc.links.ManagementPath())
+// doRPC sends the RPC request over the provided link and validates the response status code
+func doRPC(ctx context.Context, name string, rpcLink RPCLink, msg *amqp.Message) (*RPCResponse, error) {
+ res, err := rpcLink.RPC(ctx, msg)
if err != nil {
return nil, err
}
- return mc.rpcLink, nil
-}
-
-// Close will close the AMQP connection
-func (mc *mgmtClient) Close(ctx context.Context) error {
- mc.clientMu.Lock()
- defer mc.clientMu.Unlock()
-
- if mc.rpcLink == nil {
- return nil
+ if res.Code >= 200 && res.Code < 300 {
+ tab.For(ctx).Debug(fmt.Sprintf("rpc: success, status code %d and description: %s", res.Code, res.Description))
+ return res, nil
}
- err := mc.rpcLink.Close(ctx)
- mc.rpcLink = nil
- return err
-}
-
-// creates a new link and sends the RPC request, recovering and retrying on certain AMQP errors
-func (mc *mgmtClient) doRPCWithRetry(ctx context.Context, msg *amqp.Message, times int, delay time.Duration, opts ...rpc.LinkOption) (*rpc.Response, error) {
- // track the number of times we attempt to perform the RPC call.
- // this is to avoid a potential infinite loop if the returned error
- // is always transient and Recover() doesn't fail.
- sendCount := 0
-
- for {
- mc.clientMu.Lock()
- rpcLink, err := mc.getLinkWithoutLock(ctx)
- mc.clientMu.Unlock()
-
- var rsp *rpc.Response
-
- if err == nil {
- rsp, err = rpcLink.RetryableRPC(ctx, times, delay, msg)
-
- if err == nil {
- return rsp, err
- }
- }
-
- if sendCount >= amqpRetryDefaultTimes || !isAMQPTransientError(ctx, err) {
- return nil, err
- }
- sendCount++
- // if we get here, recover and try again
- tab.For(ctx).Debug("recovering RPC connection")
-
- _, retryErr := common.Retry(amqpRetryDefaultTimes, amqpRetryDefaultDelay, func() (interface{}, error) {
- ctx, sp := mc.startProducerSpanFromContext(ctx, string(tracing.SpanTryRecover))
- defer sp.End()
-
- if err := mc.recover(ctx); err == nil {
- tab.For(ctx).Debug("recovered RPC connection")
- return nil, nil
- }
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- default:
- return nil, common.Retryable(err.Error())
- }
- })
-
- if retryErr != nil {
- tab.For(ctx).Debug("RPC recovering retried, but error was unrecoverable")
- return nil, retryErr
- }
+ return nil, mgmtError{
+ Message: fmt.Sprintf("rpc: failed, status code %d and description: %s", res.Code, res.Description),
+ Resp: res,
}
}
-// returns true if the AMQP error is considered transient
-func isAMQPTransientError(ctx context.Context, err error) bool {
- // always retry on a detach error
- var amqpDetach *amqp.DetachError
- if errors.As(err, &amqpDetach) {
- return true
- }
- // for an AMQP error, only retry depending on the condition
- var amqpErr *amqp.Error
- if errors.As(err, &amqpErr) {
- switch amqpErr.Condition {
- case errorServerBusy, errorTimeout, errorOperationCancelled, errorContainerClose:
- return true
- default:
- tab.For(ctx).Debug(fmt.Sprintf("isAMQPTransientError: condition %s is not transient", amqpErr.Condition))
- return false
- }
- }
- tab.For(ctx).Debug(fmt.Sprintf("isAMQPTransientError: %T is not transient", err))
- return false
-}
-
-func (mc *mgmtClient) ReceiveDeferred(ctx context.Context, mode ReceiveMode, sequenceNumbers []int64) ([]*amqp.Message, error) {
+func ReceiveDeferred(ctx context.Context, rpcLink RPCLink, mode ReceiveMode, sequenceNumbers []int64) ([]*amqp.Message, error) {
ctx, span := tracing.StartConsumerSpanFromContext(ctx, tracing.SpanReceiveDeferred, Version)
defer span.End()
@@ -214,12 +79,6 @@ func (mc *mgmtClient) ReceiveDeferred(ctx context.Context, mode ReceiveMode, seq
"receiver-settle-mode": uint32(backwardsMode), // pick up messages with peek lock
}
- var opts []rpc.LinkOption
- if mc.isSessionFilterSet {
- opts = append(opts, rpc.LinkWithSessionFilter(mc.sessionID))
- values["session-id"] = mc.sessionID
- }
-
msg := &amqp.Message{
ApplicationProperties: map[string]interface{}{
"operation": "com.microsoft:receive-by-sequence-number",
@@ -227,9 +86,8 @@ func (mc *mgmtClient) ReceiveDeferred(ctx context.Context, mode ReceiveMode, seq
Value: values,
}
- rsp, err := mc.doRPCWithRetry(ctx, msg, 5, 5*time.Second, opts...)
+ rsp, err := doRPC(ctx, "receiveDeferred", rpcLink, msg)
if err != nil {
- tab.For(ctx).Error(err)
return nil, err
}
@@ -286,7 +144,7 @@ func (mc *mgmtClient) ReceiveDeferred(ctx context.Context, mode ReceiveMode, seq
return transformedMessages, nil
}
-func (mc *mgmtClient) PeekMessages(ctx context.Context, fromSequenceNumber int64, messageCount int32) ([]*amqp.Message, error) {
+func PeekMessages(ctx context.Context, rpcLink RPCLink, fromSequenceNumber int64, messageCount int32) ([]*amqp.Message, error) {
ctx, span := tracing.StartConsumerSpanFromContext(ctx, tracing.SpanPeekFromSequenceNumber, Version)
defer span.End()
@@ -306,7 +164,7 @@ func (mc *mgmtClient) PeekMessages(ctx context.Context, fromSequenceNumber int64
msg.ApplicationProperties["server-timeout"] = uint(time.Until(deadline) / time.Millisecond)
}
- rsp, err := mc.doRPCWithRetry(ctx, msg, 5, 5*time.Second)
+ rsp, err := doRPC(ctx, "peek", rpcLink, msg)
if err != nil {
tab.For(ctx).Error(err)
return nil, err
@@ -397,7 +255,7 @@ func (mc *mgmtClient) PeekMessages(ctx context.Context, fromSequenceNumber int64
// RenewLocks renews the locks in a single 'com.microsoft:renew-lock' operation.
// NOTE: this function assumes all the messages received on the same link.
-func (mc *mgmtClient) RenewLocks(ctx context.Context, linkName string, lockTokens []amqp.UUID) ([]time.Time, error) {
+func RenewLocks(ctx context.Context, rpcLink RPCLink, linkName string, lockTokens []amqp.UUID) ([]time.Time, error) {
ctx, span := tracing.StartConsumerSpanFromContext(ctx, tracing.SpanRenewLock, Version)
defer span.End()
@@ -414,7 +272,7 @@ func (mc *mgmtClient) RenewLocks(ctx context.Context, linkName string, lockToken
renewRequestMsg.ApplicationProperties["associated-link-name"] = linkName
}
- response, err := mc.doRPCWithRetry(ctx, renewRequestMsg, 3, 1*time.Second)
+ response, err := doRPC(ctx, "renewlocks", rpcLink, renewRequestMsg)
if err != nil {
tab.For(ctx).Error(err)
@@ -451,7 +309,7 @@ func (mc *mgmtClient) RenewLocks(ctx context.Context, linkName string, lockToken
}
// RenewSessionLocks renews a session lock.
-func (mc *mgmtClient) RenewSessionLock(ctx context.Context, sessionID string) (time.Time, error) {
+func RenewSessionLock(ctx context.Context, rpcLink RPCLink, sessionID string) (time.Time, error) {
body := map[string]interface{}{
"session-id": sessionID,
}
@@ -463,7 +321,7 @@ func (mc *mgmtClient) RenewSessionLock(ctx context.Context, sessionID string) (t
},
}
- resp, err := mc.doRPCWithRetry(ctx, msg, 5, 5*time.Second)
+ resp, err := doRPC(ctx, "renewsessionlock", rpcLink, msg)
if err != nil {
return time.Time{}, err
@@ -485,7 +343,7 @@ func (mc *mgmtClient) RenewSessionLock(ctx context.Context, sessionID string) (t
}
// GetSessionState retrieves state associated with the session.
-func (mc *mgmtClient) GetSessionState(ctx context.Context, sessionID string) ([]byte, error) {
+func GetSessionState(ctx context.Context, rpcLink RPCLink, sessionID string) ([]byte, error) {
amqpMsg := &amqp.Message{
Value: map[string]interface{}{
"session-id": sessionID,
@@ -495,7 +353,7 @@ func (mc *mgmtClient) GetSessionState(ctx context.Context, sessionID string) ([]
},
}
- resp, err := mc.doRPCWithRetry(ctx, amqpMsg, 5, 5*time.Second)
+ resp, err := doRPC(ctx, "getsessionstate", rpcLink, amqpMsg)
if err != nil {
return nil, err
@@ -528,7 +386,7 @@ func (mc *mgmtClient) GetSessionState(ctx context.Context, sessionID string) ([]
}
// SetSessionState sets the state associated with the session.
-func (mc *mgmtClient) SetSessionState(ctx context.Context, sessionID string, state []byte) error {
+func SetSessionState(ctx context.Context, rpcLink RPCLink, sessionID string, state []byte) error {
uuid, err := uuid.New()
if err != nil {
@@ -546,7 +404,7 @@ func (mc *mgmtClient) SetSessionState(ctx context.Context, sessionID string, sta
},
}
- resp, err := mc.doRPCWithRetry(ctx, amqpMsg, 5, 5*time.Second)
+ resp, err := doRPC(ctx, "setsessionstate", rpcLink, amqpMsg)
if err != nil {
return err
@@ -562,7 +420,7 @@ func (mc *mgmtClient) SetSessionState(ctx context.Context, sessionID string, sta
// SendDisposition allows you settle a message using the management link, rather than via your
// *amqp.Receiver. Use this if the receiver has been closed/lost or if the message isn't associated
// with a link (ex: deferred messages).
-func (mc *mgmtClient) SendDisposition(ctx context.Context, lockToken *amqp.UUID, state Disposition, propertiesToModify map[string]interface{}) error {
+func SendDisposition(ctx context.Context, rpcLink RPCLink, lockToken *amqp.UUID, state Disposition, propertiesToModify map[string]interface{}) error {
ctx, span := tracing.StartConsumerSpanFromContext(ctx, tracing.SpanSendDisposition, Version)
defer span.End()
@@ -572,7 +430,6 @@ func (mc *mgmtClient) SendDisposition(ctx context.Context, lockToken *amqp.UUID,
return err
}
- var opts []rpc.LinkOption
value := map[string]interface{}{
"disposition-status": string(state.Status),
"lock-tokens": []amqp.UUID{*lockToken},
@@ -598,7 +455,7 @@ func (mc *mgmtClient) SendDisposition(ctx context.Context, lockToken *amqp.UUID,
}
// no error, then it was successful
- _, err := mc.doRPCWithRetry(ctx, msg, 5, 5*time.Second, opts...)
+ _, err := doRPC(ctx, "senddisposition", rpcLink, msg)
if err != nil {
tab.For(ctx).Error(err)
return err
@@ -609,7 +466,7 @@ func (mc *mgmtClient) SendDisposition(ctx context.Context, lockToken *amqp.UUID,
// ScheduleMessages will send a batch of messages to a Queue, schedule them to be enqueued, and return the sequence numbers
// that can be used to cancel each message.
-func (mc *mgmtClient) ScheduleMessages(ctx context.Context, enqueueTime time.Time, messages ...*amqp.Message) ([]int64, error) {
+func ScheduleMessages(ctx context.Context, rpcLink RPCLink, enqueueTime time.Time, messages []*amqp.Message) ([]int64, error) {
ctx, span := tracing.StartConsumerSpanFromContext(ctx, tracing.SpanScheduleMessage, Version)
defer span.End()
@@ -677,7 +534,7 @@ func (mc *mgmtClient) ScheduleMessages(ctx context.Context, enqueueTime time.Tim
msg.ApplicationProperties["com.microsoft:server-timeout"] = uint(time.Until(deadline) / time.Millisecond)
}
- resp, err := mc.doRPCWithRetry(ctx, msg, 5, 5*time.Second)
+ resp, err := doRPC(ctx, "schedule", rpcLink, msg)
if err != nil {
tab.For(ctx).Error(err)
return nil, err
@@ -704,9 +561,9 @@ func (mc *mgmtClient) ScheduleMessages(ctx context.Context, enqueueTime time.Tim
return nil, NewErrIncorrectType("value", map[string]interface{}{}, resp.Message.Value)
}
-// CancelScheduled allows for removal of messages that have been handed to the Service Bus broker for later delivery,
+// CancelScheduledMessages allows for removal of messages that have been handed to the Service Bus broker for later delivery,
// but have not yet ben enqueued.
-func (mc *mgmtClient) CancelScheduled(ctx context.Context, seq ...int64) error {
+func CancelScheduledMessages(ctx context.Context, rpcLink RPCLink, seq []int64) error {
ctx, span := tracing.StartConsumerSpanFromContext(ctx, tracing.SpanCancelScheduledMessage, Version)
defer span.End()
@@ -723,7 +580,7 @@ func (mc *mgmtClient) CancelScheduled(ctx context.Context, seq ...int64) error {
msg.ApplicationProperties["com.microsoft:server-timeout"] = uint(time.Until(deadline) / time.Millisecond)
}
- resp, err := mc.doRPCWithRetry(ctx, msg, 5, 5*time.Second)
+ resp, err := doRPC(ctx, "cancelscheduled", rpcLink, msg)
if err != nil {
tab.For(ctx).Error(err)
return err
@@ -735,19 +592,3 @@ func (mc *mgmtClient) CancelScheduled(ctx context.Context, seq ...int64) error {
return nil
}
-
-func (mc *mgmtClient) startSpanFromContext(ctx context.Context, operationName string) (context.Context, tab.Spanner) {
- ctx, span := tracing.StartConsumerSpanFromContext(ctx, operationName, Version)
- span.AddAttributes(tab.StringAttribute("message_bus.destination", mc.links.ManagementPath()))
- return ctx, span
-}
-
-func (mc *mgmtClient) startProducerSpanFromContext(ctx context.Context, operationName string) (context.Context, tab.Spanner) {
- ctx, span := tab.StartSpan(ctx, operationName)
- tracing.ApplyComponentInfo(span, Version)
- span.AddAttributes(
- tab.StringAttribute("span.kind", "producer"),
- tab.StringAttribute("message_bus.destination", mc.links.ManagementPath()),
- )
- return ctx, span
-}
diff --git a/sdk/messaging/azservicebus/internal/namespace.go b/sdk/messaging/azservicebus/internal/namespace.go
index 806283596ee1..214d2d49963b 100644
--- a/sdk/messaging/azservicebus/internal/namespace.go
+++ b/sdk/messaging/azservicebus/internal/namespace.go
@@ -13,12 +13,13 @@ import (
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
+ "github.com/Azure/azure-sdk-for-go/sdk/internal/log"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/sbauth"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/auth"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/cbs"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/conn"
- "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/rpc"
"github.com/Azure/go-amqp"
"github.com/devigned/tab"
)
@@ -31,6 +32,13 @@ type (
// Namespace is an abstraction over an amqp.Client, allowing us to hold onto a single
// instance of a connection per ServiceBusClient.
Namespace struct {
+ // NOTE: values need to be 64-bit aligned. Simplest way to make sure this happens
+ // is just to make it the first value in the struct
+ // See:
+ // Godoc: https://pkg.go.dev/sync/atomic#pkg-note-BUG
+ // PR: https://github.com/Azure/azure-sdk-for-go/pull/16847
+ connID uint64
+
FQDN string
TokenProvider *sbauth.TokenProvider
tlsConfig *tls.Config
@@ -38,12 +46,10 @@ type (
newWebSocketConn func(ctx context.Context, args NewWebSocketConnArgs) (net.Conn, error)
- baseRetrier Retrier
-
- clientMu sync.Mutex
- clientRevision uint64
- client *amqp.Client
+ retryOptions utils.RetryOptions
+ clientMu sync.RWMutex
+ client *amqp.Client
negotiateClaimMu sync.Mutex
}
@@ -60,14 +66,9 @@ type NamespaceWithNewAMQPLinks interface {
type NamespaceForAMQPLinks interface {
NegotiateClaim(ctx context.Context, entityPath string) (func() <-chan struct{}, error)
NewAMQPSession(ctx context.Context) (AMQPSessionCloser, uint64, error)
- NewMgmtClient(ctx context.Context, links AMQPLinks) (MgmtClient, error)
+ NewRPCLink(ctx context.Context, managementPath string) (RPCLink, error)
GetEntityAudience(entityPath string) string
- Recover(ctx context.Context, clientRevision uint64) error
-}
-
-// NamespaceForAMQPLinks is the Namespace surface needed for the *MgmtClient.
-type NamespaceForMgmtClient interface {
- NewRPCLink(ctx context.Context, managementPath string) (*rpc.Link, error)
+ Recover(ctx context.Context, clientRevision uint64) (bool, error)
}
// NamespaceWithConnectionString configures a namespace with the information provided in a Service Bus connection string
@@ -125,9 +126,9 @@ func NamespaceWithWebSocket(newWebSocketConn func(ctx context.Context, args NewW
}
}
-// NamespacesWithTokenCredential sets the token provider on the namespace
+// NamespaceWithTokenCredential sets the token provider on the namespace
// fullyQualifiedNamespace is the Service Bus namespace name (ex: myservicebus.servicebus.windows.net)
-func NamespacesWithTokenCredential(fullyQualifiedNamespace string, tokenCredential azcore.TokenCredential) NamespaceOption {
+func NamespaceWithTokenCredential(fullyQualifiedNamespace string, tokenCredential azcore.TokenCredential) NamespaceOption {
return func(ns *Namespace) error {
ns.TokenProvider = sbauth.NewTokenProvider(tokenCredential)
ns.FQDN = fullyQualifiedNamespace
@@ -135,22 +136,16 @@ func NamespacesWithTokenCredential(fullyQualifiedNamespace string, tokenCredenti
}
}
+func NamespaceWithRetryOptions(retryOptions utils.RetryOptions) NamespaceOption {
+ return func(ns *Namespace) error {
+ ns.retryOptions = retryOptions
+ return nil
+ }
+}
+
// NewNamespace creates a new namespace configured through NamespaceOption(s)
func NewNamespace(opts ...NamespaceOption) (*Namespace, error) {
- ns := &Namespace{
- baseRetrier: NewBackoffRetrier(struct {
- MaxRetries int
- Factor float64
- Jitter bool
- Min time.Duration
- Max time.Duration
- }{
- Factor: 2,
- Min: time.Second,
- Max: time.Minute,
- MaxRetries: 10,
- }),
- }
+ ns := &Namespace{}
for _, opt := range opts {
err := opt(ns)
@@ -199,8 +194,9 @@ func (ns *Namespace) newClient(ctx context.Context) (*amqp.Client, error) {
}
// NewAMQPSession creates a new AMQP session with the internally cached *amqp.Client.
+// Returns a closeable AMQP session and the current client revision.
func (ns *Namespace) NewAMQPSession(ctx context.Context) (AMQPSessionCloser, uint64, error) {
- client, clientRevision, err := ns.getAMQPClientImpl(ctx)
+ client, clientRevision, err := ns.GetAMQPClientImpl(ctx)
if err != nil {
return nil, 0, err
@@ -215,26 +211,21 @@ func (ns *Namespace) NewAMQPSession(ctx context.Context) (AMQPSessionCloser, uin
return session, clientRevision, err
}
-// NewMgmtClient creates a new management client with the internally cached *amqp.Client.
-func (ns *Namespace) NewMgmtClient(ctx context.Context, l AMQPLinks) (MgmtClient, error) {
- return newMgmtClient(ctx, l, ns)
-}
-
// NewRPCLink creates a new amqp-common *rpc.Link with the internally cached *amqp.Client.
-func (ns *Namespace) NewRPCLink(ctx context.Context, managementPath string) (*rpc.Link, error) {
- client, _, err := ns.getAMQPClientImpl(ctx)
+func (ns *Namespace) NewRPCLink(ctx context.Context, managementPath string) (RPCLink, error) {
+ client, _, err := ns.GetAMQPClientImpl(ctx)
if err != nil {
return nil, err
}
- return rpc.NewLink(client, managementPath)
+ return NewRPCLink(client, managementPath)
}
// NewAMQPLinks creates an AMQPLinks struct, which groups together the commonly needed links for
// working with Service Bus.
func (ns *Namespace) NewAMQPLinks(entityPath string, createLinkFunc CreateLinkFunc) AMQPLinks {
- return newAMQPLinks(ns, entityPath, ns.baseRetrier, createLinkFunc)
+ return NewAMQPLinks(ns, entityPath, createLinkFunc)
}
// Close closes the current cached client.
@@ -249,9 +240,11 @@ func (ns *Namespace) Close(ctx context.Context) error {
return nil
}
-// Recover destroys the currently held client and recreates it.
-// clientRevision being nil will recover without a revision check.
-func (ns *Namespace) Recover(ctx context.Context, clientRevision uint64) error {
+// Recover destroys the currently held AMQP connection and recreates it, if needed.
+// If a new connection is actually created (rather than just reusing the cached one) then the returned bool
+// will be true. Any links that were created from the original connection will need to
+// be recreated.
+func (ns *Namespace) Recover(ctx context.Context, theirConnID uint64) (bool, error) {
ns.clientMu.Lock()
defer ns.clientMu.Unlock()
@@ -259,13 +252,13 @@ func (ns *Namespace) Recover(ctx context.Context, clientRevision uint64) error {
defer span.End()
span.AddAttributes(
- tab.Int64Attribute("revision", int64(ns.clientRevision)),
- tab.Int64Attribute("requested", int64(clientRevision)))
+ tab.Int64Attribute("connID", int64(ns.connID)),
+ tab.Int64Attribute("theirConnID", int64(theirConnID)))
- if ns.clientRevision > clientRevision {
- span.Logger().Info(fmt.Sprintf("Skipping recovery, already recovered: %d vs %d", ns.clientRevision, clientRevision))
+ if ns.connID != theirConnID {
+ log.Writef(EventConn, "Skipping connection recovery, already recovered: %d vs %d", ns.connID, theirConnID)
// we've already recovered since the client last tried.
- return nil
+ return false, nil
}
if ns.client != nil {
@@ -273,23 +266,20 @@ func (ns *Namespace) Recover(ctx context.Context, clientRevision uint64) error {
ns.client = nil
// the error on close isn't critical
- go func() {
- span.Logger().Info(fmt.Sprintf("Closing old client (client:%d,passed in:%d)", ns.clientRevision, clientRevision))
- err := oldClient.Close()
- tab.For(ctx).Error(err)
- }()
+ _ = oldClient.Close()
}
var err error
- span.Logger().Info(fmt.Sprintf("Creating a new client (client:%d,passed in:%d)", ns.clientRevision, clientRevision))
+ log.Writef(EventConn, "Creating a new client (rev:%d)", ns.connID)
ns.client, err = ns.newClient(ctx)
- if err == nil {
- span.AddAttributes(tab.Int64Attribute("newcr", int64(ns.clientRevision)))
- ns.clientRevision++
+ if err != nil {
+ return false, err
}
- return err
+ ns.connID++
+ log.Writef(EventConn, "New client created, (rev: %d)", ns.connID)
+ return true, nil
}
// negotiateClaim performs initial authentication and starts periodic refresh of credentials.
@@ -298,7 +288,7 @@ func (ns *Namespace) NegotiateClaim(ctx context.Context, entityPath string) (fun
return ns.startNegotiateClaimRenewer(ctx,
entityPath,
cbs.NegotiateClaim,
- ns.getAMQPClientImpl,
+ ns.GetAMQPClientImpl,
nextClaimRefreshDuration)
}
@@ -309,61 +299,52 @@ func (ns *Namespace) startNegotiateClaimRenewer(ctx context.Context,
nextClaimRefreshDurationFn func(expirationTime time.Time, currentTime time.Time) time.Duration) (func() <-chan struct{}, error) {
audience := ns.GetEntityAudience(entityPath)
- refreshClaim := func() (time.Time, error) {
- retrier := ns.baseRetrier.Copy()
+ refreshClaim := func(ctx context.Context) (time.Time, error) {
+ log.Writef(EventAuth, "(%s) refreshing claim", entityPath)
+ ctx, span := ns.startSpanFromContext(ctx, tracing.SpanNegotiateClaim)
+ defer span.End()
- var lastErr error
- var expiration time.Time
+ amqpClient, clientRevision, err := nsGetAMQPClientImpl(ctx)
- for retrier.Try(ctx) {
- expiration, lastErr = func() (time.Time, error) {
- ctx, span := ns.startSpanFromContext(ctx, tracing.SpanNegotiateClaim)
- defer span.End()
+ if err != nil {
+ return time.Time{}, err
+ }
- amqpClient, clientRevision, err := nsGetAMQPClientImpl(ctx)
+ token, expiration, err := ns.TokenProvider.GetTokenAsTokenProvider(audience)
- if err != nil {
- span.Logger().Error(err)
- return time.Time{}, err
- }
+ if err != nil {
+ log.Writef(EventAuth, "(%s) negotiate claim, failed getting token: %s", entityPath, err.Error())
+ return time.Time{}, err
+ }
- token, expiration, err := ns.TokenProvider.GetTokenAsTokenProvider(audience)
+ log.Writef(EventAuth, "(%s) negotiate claim, token expires on %s", entityPath, expiration.Format(time.RFC3339))
- if err != nil {
- span.Logger().Error(err)
- return time.Time{}, err
- }
+ // You're not allowed to have multiple $cbs links open in a single connection.
+ // The current cbs.NegotiateClaim implementation automatically creates and shuts
+ // down its own link so we have to guard against that here.
+ ns.negotiateClaimMu.Lock()
+ err = cbsNegotiateClaim(ctx, audience, amqpClient, token)
+ ns.negotiateClaimMu.Unlock()
- // You're not allowed to have multiple $cbs links open in a single connection.
- // The current cbs.NegotiateClaim implementation automatically creates and shuts
- // down it's own link so we have to guard against that here.
- ns.negotiateClaimMu.Lock()
- err = cbsNegotiateClaim(ctx, audience, amqpClient, token)
- ns.negotiateClaimMu.Unlock()
-
- if err != nil {
- if shouldRecreateConnection(ctx, err) {
- if err := ns.Recover(ctx, clientRevision); err != nil {
- span.Logger().Error(fmt.Errorf("connection recovery failed: %w", err))
- }
- }
+ sbe := GetSBErrInfo(err)
- span.Logger().Error(err)
- return time.Time{}, err
+ if sbe != nil {
+ // Note we only handle connection recovery here since (currently)
+ // the negotiateClaim code creates its own link each time.
+ if sbe.RecoveryKind == RecoveryKindConn {
+ if _, err := ns.Recover(ctx, clientRevision); err != nil {
+ log.Writef(EventAuth, "(%s) negotiate claim, failed in connection recovery: %s", entityPath, err)
}
-
- return expiration, nil
- }()
-
- if lastErr == nil {
- break
}
+
+ log.Writef(EventAuth, "(%s) negotiate claim, failed: %s", entityPath, err.Error())
+ return time.Time{}, err
}
- return expiration, lastErr
+ return expiration, nil
}
- expiresOn, err := refreshClaim()
+ expiresOn, err := refreshClaim(ctx)
if err != nil {
return nil, err
@@ -373,15 +354,38 @@ func (ns *Namespace) startNegotiateClaimRenewer(ctx context.Context,
refreshCtx, cancel := context.WithCancel(context.Background())
go func() {
+ TokenRefreshLoop:
for {
+ nextClaimAt := nextClaimRefreshDurationFn(expiresOn, time.Now())
+
+ log.Writef(EventAuth, "(%s) next refresh in %s", entityPath, nextClaimAt)
+
select {
case <-refreshCtx.Done():
return
- case <-time.After(nextClaimRefreshDurationFn(expiresOn, time.Now())):
- tmpExpiresOn, err := refreshClaim() // logging will report the error for now
+ case <-time.After(nextClaimAt):
+ for {
+ err := utils.Retry(refreshCtx, "claimrefresh", func(ctx context.Context, args *utils.RetryFnArgs) error {
+ tmpExpiresOn, err := refreshClaim(ctx)
- if err == nil {
- expiresOn = tmpExpiresOn
+ if err != nil {
+ return err
+ }
+
+ expiresOn = tmpExpiresOn
+ return nil
+ }, IsFatalSBError, ns.retryOptions)
+
+ if err == nil {
+ break
+ }
+
+ // if we fail our retries _and_ we've exceeded the window where our token would have
+ // been good we can just stop.
+ if time.Since(expiresOn) <= 0 {
+ log.Writef(EventAuth, "[%s] token has expired, stopping refresh loop", entityPath)
+ break TokenRefreshLoop
+ }
}
}
}
@@ -395,26 +399,22 @@ func (ns *Namespace) startNegotiateClaimRenewer(ctx context.Context,
return cancelRefresh, nil
}
-func (ns *Namespace) getAMQPClientImpl(ctx context.Context) (*amqp.Client, uint64, error) {
+func (ns *Namespace) GetAMQPClientImpl(ctx context.Context) (*amqp.Client, uint64, error) {
ns.clientMu.Lock()
defer ns.clientMu.Unlock()
if ns.client != nil {
- return ns.client, ns.clientRevision, nil
+ return ns.client, ns.connID, nil
}
var err error
- retrier := ns.baseRetrier.Copy()
-
- for retrier.Try(ctx) {
- ns.client, err = ns.newClient(ctx)
+ ns.client, err = ns.newClient(ctx)
- if err == nil {
- break
- }
+ if err == nil {
+ ns.connID++
}
- return ns.client, ns.clientRevision, err
+ return ns.client, ns.connID, err
}
func (ns *Namespace) getWSSHostURI() string {
diff --git a/sdk/messaging/azservicebus/internal/namespace_test.go b/sdk/messaging/azservicebus/internal/namespace_test.go
index dcdcf8449b78..b7bb69bd738b 100644
--- a/sdk/messaging/azservicebus/internal/namespace_test.go
+++ b/sdk/messaging/azservicebus/internal/namespace_test.go
@@ -17,39 +17,6 @@ import (
"github.com/stretchr/testify/require"
)
-// implements `Retrier` interface.
-type fakeRetrier struct {
- tryCalled int
- copyCalled int
-}
-
-func (r *fakeRetrier) Copy() Retrier {
- r.copyCalled++
-
- // NOTE: purposefully not making a copy so I can keep track of the
- // try/copy counts.
- return r
-}
-
-func (r *fakeRetrier) Exhausted() bool {
- return false
-}
-
-func (r *fakeRetrier) Try(ctx context.Context) bool {
- select {
- case <-ctx.Done():
- return false
- default:
- }
-
- r.tryCalled++
- return true
-}
-
-func (r *fakeRetrier) CurrentTry() int {
- return r.tryCalled
-}
-
type fakeTokenCredential struct {
azcore.TokenCredential
expires time.Time
@@ -62,12 +29,10 @@ func (ftc *fakeTokenCredential) GetToken(ctx context.Context, options policy.Tok
}
func TestNamespaceNegotiateClaim(t *testing.T) {
- retrier := &fakeRetrier{}
-
expires := time.Now().Add(24 * time.Hour)
ns := &Namespace{
- baseRetrier: retrier,
+ retryOptions: retryOptionsOnlyOnce,
TokenProvider: sbauth.NewTokenProvider(&fakeTokenCredential{expires: expires}),
}
@@ -106,17 +71,13 @@ func TestNamespaceNegotiateClaim(t *testing.T) {
require.EqualValues(t, getAMQPClientCalled, 1)
require.EqualValues(t, 1, cbsNegotiateClaimCalled)
- require.EqualValues(t, 1, retrier.copyCalled)
- require.EqualValues(t, 1, retrier.tryCalled)
}
func TestNamespaceNegotiateClaimRenewal(t *testing.T) {
- retrier := &fakeRetrier{}
-
expires := time.Now().Add(24 * time.Hour)
ns := &Namespace{
- baseRetrier: retrier,
+ retryOptions: retryOptionsOnlyOnce,
TokenProvider: sbauth.NewTokenProvider(&fakeTokenCredential{expires: expires}),
}
@@ -163,8 +124,6 @@ func TestNamespaceNegotiateClaimRenewal(t *testing.T) {
require.GreaterOrEqual(t, getAMQPClientCalled, 2+1) // that last +1 is when we blocked to prevent us renewing too much for our test!
- require.EqualValues(t, 3, retrier.copyCalled)
- require.EqualValues(t, 3, retrier.tryCalled)
require.EqualValues(t, 2, nextRefreshDurationChecks)
require.EqualValues(t, 2, cbsNegotiateClaimCalled)
@@ -175,7 +134,6 @@ func TestNamespaceNegotiateClaimRenewal(t *testing.T) {
func TestNamespaceNegotiateClaimFailsToGetClient(t *testing.T) {
ns := &Namespace{
- baseRetrier: noRetryRetrier.Copy(),
TokenProvider: sbauth.NewTokenProvider(&fakeTokenCredential{expires: time.Now()}),
}
@@ -197,7 +155,6 @@ func TestNamespaceNegotiateClaimFailsToGetClient(t *testing.T) {
func TestNamespaceNegotiateClaimFails(t *testing.T) {
ns := &Namespace{
- baseRetrier: noRetryRetrier.Copy(),
TokenProvider: sbauth.NewTokenProvider(&fakeTokenCredential{expires: time.Now()}),
}
@@ -232,13 +189,3 @@ func TestNamespaceNextClaimRefreshDuration(t *testing.T) {
require.EqualValues(t, 3*time.Minute, nextClaimRefreshDuration(now.Add(3*time.Minute+clockDrift), now))
}
-
-var noRetryRetrier = NewBackoffRetrier(struct {
- MaxRetries int
- Factor float64
- Jitter bool
- Min time.Duration
- Max time.Duration
-}{
- MaxRetries: 0,
-})
diff --git a/sdk/messaging/azservicebus/internal/retrier.go b/sdk/messaging/azservicebus/internal/retrier.go
deleted file mode 100644
index 700abfa3ef61..000000000000
--- a/sdk/messaging/azservicebus/internal/retrier.go
+++ /dev/null
@@ -1,268 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-package internal
-
-import (
- "context"
- "math"
- "math/rand"
- "time"
-
- "github.com/Azure/azure-sdk-for-go/sdk/internal/log"
- "github.com/jpillora/backoff"
-)
-
-// A retrier that allows you to do a basic for loop and get backoff
-// and retry limits. See `Try` for more details on how to use it.
-type Retrier interface {
- // Copies the retrier. Retriers are stateful and must be copied
- // before starting a set of retries.
- Copy() Retrier
-
- // Exhausted is true if the retries were exhausted.
- Exhausted() bool
-
- // CurrentTry is the current try (0 for the first run before retries)
- CurrentTry() int
-
- // Try marks an attempt to call (first call to Try() does not sleep).
- // Will return false if the `ctx` is cancelled or if we exhaust our retries.
- //
- // rp := RetryPolicy{Backoff:defaultBackoffPolicy, MaxRetries:5}
- //
- // for rp.Try(ctx) {
- //
- // }
- //
- // if rp.Cancelled() || rp.Exhausted() {
- // // no more retries needed
- // }
- //
- Try(ctx context.Context) bool
-}
-
-// Encapsulates a backoff policy, which allows you to configure the amount of
-// time in between retries as well as the maximum retries allowed (via MaxRetries)
-// NOTE: this should be copied by the caller as it is stateful.
-type backoffRetrier struct {
- backoff backoff.Backoff
- MaxRetries int
-
- tries int
-}
-
-// BackoffRetrierParams are parameters for NewBackoffRetrier.
-type BackoffRetrierParams struct {
- // MaxRetries is the maximum number of tries (after the first attempt)
- // that are allowed.
- MaxRetries int
- // Factor is the multiplying factor for each increment step
- Factor float64
- // Jitter eases contention by randomizing backoff steps
- Jitter bool
- // Min and Max are the minimum and maximum values of the counter
- Min, Max time.Duration
-}
-
-// NewBackoffRetrier creates a retrier that allows for configurable
-// min/max times, jitter and maximum retries.
-func NewBackoffRetrier(params BackoffRetrierParams) Retrier {
- return &backoffRetrier{
- backoff: backoff.Backoff{
- Factor: params.Factor,
- Jitter: params.Jitter,
- Min: params.Min,
- Max: params.Max,
- },
- MaxRetries: params.MaxRetries,
- }
-}
-
-// Copies the backoff retrier since it's stateful.
-func (rp *backoffRetrier) Copy() Retrier {
- copy := *rp
-	return &copy
-}
-
-// Exhausted is true if all the retries have been used.
-func (rp *backoffRetrier) Exhausted() bool {
- return rp.tries > rp.MaxRetries
-}
-
-// CurrentTry is the current try number (0 for the first run before retries)
-func (rp *backoffRetrier) CurrentTry() int {
- return rp.tries
-}
-
-// Try marks an attempt to call (first call to Try() does not sleep).
-// Will return false if the `ctx` is cancelled or if we exhaust our retries.
-//
-// rp := RetryPolicy{Backoff:defaultBackoffPolicy, MaxRetries:5}
-//
-// for rp.Try(ctx) {
-//
-// }
-//
-// if rp.Cancelled() || rp.Exhausted() {
-// // no more retries needed
-// }
-//
-func (rp *backoffRetrier) Try(ctx context.Context) bool {
- defer func() { rp.tries++ }()
-
- select {
- case <-ctx.Done():
- return false
- default:
- }
-
- if rp.tries == 0 {
- // first 'try' is always free
- return true
- }
-
- if rp.Exhausted() {
- return false
- }
-
- select {
- case <-time.After(rp.backoff.Duration()):
- return true
- case <-ctx.Done():
- return false
- }
-}
-
-type RetryFnArgs struct {
- I int32
- // LastErr is the returned error from the previous loop.
- // If you have potentially expensive
- LastErr error
-
- resetAttempts bool
-}
-
-// ResetAttempts causes the current retry attempt number to be reset
-// in time for the next recovery (should we fail).
-// NOTE: Use of this should be pretty rare, it's really only needed when you have
-// a situation like Receiver.ReceiveMessages() that can recovery but intentionally
-// does not return.
-func (rf *RetryFnArgs) ResetAttempts() {
- rf.resetAttempts = true
-}
-
-// Retry runs a standard retry loop. It executes your passed in fn as the body of the loop.
-// 'isFatal' can be nil, and defaults to just checking that ServiceBusError(err).recoveryKind != recoveryKindNonRetriable.
-// It returns if it exceeds the number of configured retry options or if 'isFatal' returns true.
-func Retry(ctx context.Context, name string, fn func(ctx context.Context, args *RetryFnArgs) error, isFatalFn func(err error) bool, o RetryOptions) error {
- var ro RetryOptions = o
- setDefaults(&ro)
-
- var err error
-
- for i := int32(0); i <= ro.MaxRetries; i++ {
- if i > 0 {
- sleep := calcDelay(ro, i)
- log.Writef(EventRetry, "(%s) Attempt %d sleeping for %s", name, i, sleep)
- time.Sleep(sleep)
- }
-
- args := RetryFnArgs{
- I: i,
- LastErr: err,
- }
- err = fn(ctx, &args)
-
- if args.resetAttempts {
- log.Writef(EventRetry, "(%s) Resetting attempts", name)
- i = int32(0)
- }
-
- if err != nil {
- if isFatalFn != nil {
- if isFatalFn(err) {
- log.Writef(EventRetry, "(%s) Attempt %d returned non-retryable error: %s", name, i, err.Error())
- return err
- } else {
- log.Writef(EventRetry, "(%s) Attempt %d returned retryable error: %s", name, i, err.Error())
- }
- } else {
- recoveryKind := ToSBE(ctx, err).RecoveryKind
- if recoveryKind == RecoveryKindFatal {
- log.Writef(EventRetry, "(%s) Attempt %d returned non-retryable error: %s", name, i, err.Error())
- return err
- } else {
- log.Writef(EventRetry, "(%s) Attempt %d returned retryable error with recovery kind %s: %s", name, i, recoveryKind, err.Error())
- }
- }
-
- continue
- }
-
- return nil
- }
-
- return err
-}
-
-// RetryOptions represent the options for retries.
-type RetryOptions struct {
- // MaxRetries specifies the maximum number of attempts a failed operation will be retried
- // before producing an error.
- // The default value is three. A value less than zero means one try and no retries.
- MaxRetries int32
-
- // RetryDelay specifies the initial amount of delay to use before retrying an operation.
- // The delay increases exponentially with each retry up to the maximum specified by MaxRetryDelay.
- // The default value is four seconds. A value less than zero means no delay between retries.
- RetryDelay time.Duration
-
- // MaxRetryDelay specifies the maximum delay allowed before retrying an operation.
- // Typically the value is greater than or equal to the value specified in RetryDelay.
- // The default Value is 120 seconds. A value less than zero means there is no cap.
- MaxRetryDelay time.Duration
-}
-
-func setDefaults(o *RetryOptions) {
- if o.MaxRetries == 0 {
- o.MaxRetries = 3
- } else if o.MaxRetries < 0 {
- o.MaxRetries = 0
- }
- if o.MaxRetryDelay == 0 {
- o.MaxRetryDelay = 120 * time.Second
- } else if o.MaxRetryDelay < 0 {
- // not really an unlimited cap, but sufficiently large enough to be considered as such
- o.MaxRetryDelay = math.MaxInt64
- }
- if o.RetryDelay == 0 {
- o.RetryDelay = 4 * time.Second
- } else if o.RetryDelay < 0 {
- o.RetryDelay = 0
- }
-}
-
-// (adapted from from azcore/policy_retry)
-func calcDelay(o RetryOptions, try int32) time.Duration {
- if try == 0 {
- return 0
- }
-
- pow := func(number int64, exponent int32) int64 { // pow is nested helper function
- var result int64 = 1
- for n := int32(0); n < exponent; n++ {
- result *= number
- }
- return result
- }
-
- delay := time.Duration(pow(2, try)-1) * o.RetryDelay
-
- // Introduce some jitter: [0.0, 1.0) / 2 = [0.0, 0.5) + 0.8 = [0.8, 1.3)
- delay = time.Duration(delay.Seconds() * (rand.Float64()/2 + 0.8) * float64(time.Second)) // NOTE: We want math/rand; not crypto/rand
- if delay > o.MaxRetryDelay {
- delay = o.MaxRetryDelay
- }
- return delay
-}
diff --git a/sdk/messaging/azservicebus/internal/retrier_test.go b/sdk/messaging/azservicebus/internal/retrier_test.go
deleted file mode 100644
index 4cfdc52818c7..000000000000
--- a/sdk/messaging/azservicebus/internal/retrier_test.go
+++ /dev/null
@@ -1,174 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-package internal
-
-import (
- "context"
- "errors"
- "math"
- "testing"
- "time"
-
- "github.com/Azure/go-amqp"
- "github.com/stretchr/testify/require"
-)
-
-func TestRetrier(t *testing.T) {
- t.Run("Basic", func(t *testing.T) {
- retrier := NewBackoffRetrier(struct {
- MaxRetries int
- Factor float64
- Jitter bool
- Min time.Duration
- Max time.Duration
- }{
- MaxRetries: 5,
- Factor: 0,
- })
-
- require := require.New(t)
-
- // first iteration is always free (ie, that's not the
- // retry part)
- require.True(retrier.Try(context.Background()))
-
- // now we're doing retries
- require.True(retrier.Try(context.Background()))
- require.True(retrier.Try(context.Background()))
- require.True(retrier.Try(context.Background()))
- require.True(retrier.Try(context.Background()))
- require.True(retrier.Try(context.Background()))
-
- // and it's the 6th retry that fails since we've exhausted
- // the retries we're allotted.
- require.False(retrier.Try(context.Background()))
- require.True(retrier.Exhausted())
- })
-
- t.Run("Cancellation", func(t *testing.T) {
- retrier := NewBackoffRetrier(struct {
- MaxRetries int
- Factor float64
- Jitter bool
- Min time.Duration
- Max time.Duration
- }{
- MaxRetries: 5,
- Factor: 0,
- })
-
- // first iteration is always free (ie, that's not the
- // retry part)
- cancelledContext, cancel := context.WithCancel(context.Background())
- cancel()
- require.False(t, retrier.Try(cancelledContext))
- })
-}
-
-var fastRetryOptions = RetryOptions{
- // note: omitting MaxRetries just to give a sanity check that
- // we do setDefaults() before we run.
- RetryDelay: time.Millisecond,
- MaxRetryDelay: time.Millisecond,
-}
-
-func TestRetryBasic(t *testing.T) {
- called := 0
-
- err := Retry(context.Background(), "retrytest", func(ctx context.Context, args *RetryFnArgs) error {
- require.NotNil(t, args)
- require.NotNil(t, ctx)
-
- called++
-
- return &amqp.DetachError{}
- }, nil, fastRetryOptions)
-
- var de *amqp.DetachError
- require.ErrorAs(t, err, &de)
- require.EqualValues(t, 4, called)
-}
-
-func TestRetryWithFatalError(t *testing.T) {
- called := 0
-
- err := Retry(context.Background(), "retrytest", func(ctx context.Context, args *RetryFnArgs) error {
- require.NotNil(t, args)
- require.NotNil(t, ctx)
-
- called++
-
- return &amqp.Error{
- // this is just a basic non-recoverable situation - typically happens if the
- // lock period expires.
- Condition: amqp.ErrorCondition("com.microsoft:message-lock-lost"),
- }
- }, nil, fastRetryOptions)
-
- // fatal error so we only called the function once
- require.EqualValues(t, 1, called)
-
- var testErr *amqp.Error
-
- require.ErrorAs(t, err, &testErr)
- require.EqualValues(t, "com.microsoft:message-lock-lost", testErr.Condition)
-}
-
-func TestRetryCustomIsFatal(t *testing.T) {
- called := 0
- var totallyHarmlessErrorAsFatal = errors.New("I'm supposed to be harmless but the custom error handler is going to make me fatal")
- var isFatalErr error
-
- err := Retry(context.Background(), "retrytest", func(ctx context.Context, args *RetryFnArgs) error {
- require.NotNil(t, args)
- require.NotNil(t, ctx)
-
- called++
-
- return totallyHarmlessErrorAsFatal
- }, func(err error) bool {
- require.Nil(t, isFatalErr, "should only get called once")
- isFatalErr = err
- return true
- }, fastRetryOptions)
-
- // fatal error so we only called the function once
- require.EqualValues(t, 1, called)
-
- require.ErrorIs(t, err, totallyHarmlessErrorAsFatal)
- require.ErrorIs(t, isFatalErr, totallyHarmlessErrorAsFatal)
-}
-
-func TestRetryDefaults(t *testing.T) {
- ro := RetryOptions{}
- setDefaults(&ro)
-
- require.EqualValues(t, 3, ro.MaxRetries)
- require.EqualValues(t, 4*time.Second, ro.RetryDelay)
- require.EqualValues(t, 2*time.Minute, ro.MaxRetryDelay)
-
- // this is an interesting default. Anything < 0 basically
- // causes the max delay to be "infinite"
- ro.MaxRetryDelay = -1
- // whereas this just normalizes to '0'
- ro.RetryDelay = -1
- ro.MaxRetries = -1
- setDefaults(&ro)
- require.EqualValues(t, time.Duration(math.MaxInt64), ro.MaxRetryDelay)
- require.EqualValues(t, 0, ro.MaxRetries)
- require.EqualValues(t, time.Duration(0), ro.RetryDelay)
-}
-
-func TestCalcDelay(t *testing.T) {
- // calcDelay introduces some jitter, automatically.
- ro := RetryOptions{}
- setDefaults(&ro)
- d := calcDelay(ro, 0)
- require.EqualValues(t, 0, d)
-
- // by default the first calc is 2^attempt
- d = calcDelay(ro, 1)
- require.LessOrEqual(t, d, 6*time.Second)
- require.GreaterOrEqual(t, d, time.Second)
-}
diff --git a/sdk/messaging/azservicebus/internal/rpc.go b/sdk/messaging/azservicebus/internal/rpc.go
new file mode 100644
index 000000000000..5875dc6b5c48
--- /dev/null
+++ b/sdk/messaging/azservicebus/internal/rpc.go
@@ -0,0 +1,445 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package internal
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/devigned/tab"
+
+ "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing"
+ "github.com/Azure/go-amqp"
+)
+
+const (
+ replyPostfix = "-reply-to-"
+ statusCodeKey = "status-code"
+ descriptionKey = "status-description"
+ defaultReceiverCredits = 1000
+)
+
+type (
+ // rpcLink is the bidirectional communication structure used for CBS negotiation
+ rpcLink struct {
+ session *amqp.Session
+
+ receiver amqpReceiver // *amqp.Receiver
+ sender amqpSender // *amqp.Sender
+
+ clientAddress string
+ sessionID *string
+ id string
+
+ responseMu sync.Mutex
+ startResponseRouterOnce *sync.Once
+ responseMap map[string]chan rpcResponse
+ broadcastErr error // the error that caused the responseMap to be nil'd
+
+ // for unit tests
+ uuidNewV4 func() (uuid.UUID, error)
+ messageAccept func(ctx context.Context, message *amqp.Message) error
+ }
+
+ // RPCResponse is the simplified response structure from an RPC like call
+ RPCResponse struct {
+ Code int
+ Description string
+ Message *amqp.Message
+ }
+
+ // RPCLinkOption provides a way to customize the construction of a Link
+ RPCLinkOption func(link *rpcLink) error
+
+ rpcResponse struct {
+ message *amqp.Message
+ err error
+ }
+
+ // Actually: *amqp.Receiver
+ amqpReceiver interface {
+ Receive(ctx context.Context) (*amqp.Message, error)
+ Close(ctx context.Context) error
+ }
+
+ amqpSender interface {
+ Send(ctx context.Context, msg *amqp.Message) error
+ Close(ctx context.Context) error
+ }
+)
+
+// NewLink will build a new request response link
+func NewRPCLink(conn *amqp.Client, address string, opts ...RPCLinkOption) (*rpcLink, error) {
+ authSession, err := conn.NewSession()
+ if err != nil {
+ return nil, err
+ }
+
+ return newRPCLinkWithSession(authSession, address, opts...)
+}
+
+// NewLinkWithSession will build a new request response link, but will reuse an existing AMQP session
+func newRPCLinkWithSession(session *amqp.Session, address string, opts ...RPCLinkOption) (*rpcLink, error) {
+ linkID, err := uuid.New()
+ if err != nil {
+ return nil, err
+ }
+
+ id := linkID.String()
+ link := &rpcLink{
+ session: session,
+ clientAddress: strings.Replace("$", "", address, -1) + replyPostfix + id,
+ id: id,
+
+ uuidNewV4: uuid.New,
+ responseMap: map[string]chan rpcResponse{},
+ startResponseRouterOnce: &sync.Once{},
+ }
+
+ for _, opt := range opts {
+ if err := opt(link); err != nil {
+ return nil, err
+ }
+ }
+
+ sender, err := session.NewSender(
+ amqp.LinkTargetAddress(address),
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ receiverOpts := []amqp.LinkOption{
+ amqp.LinkSourceAddress(address),
+ amqp.LinkTargetAddress(link.clientAddress),
+ amqp.LinkCredit(defaultReceiverCredits),
+ }
+
+ if link.sessionID != nil {
+ const name = "com.microsoft:session-filter"
+ const code = uint64(0x00000137000000C)
+ if link.sessionID == nil {
+ receiverOpts = append(receiverOpts, amqp.LinkSourceFilter(name, code, nil))
+ } else {
+ receiverOpts = append(receiverOpts, amqp.LinkSourceFilter(name, code, link.sessionID))
+ }
+ }
+
+ receiver, err := session.NewReceiver(receiverOpts...)
+ if err != nil {
+ // make sure we close the sender
+ clsCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ _ = sender.Close(clsCtx)
+ return nil, err
+ }
+
+ link.sender = sender
+ link.receiver = receiver
+ link.messageAccept = receiver.AcceptMessage
+
+ return link, nil
+}
+
+// startResponseRouter is responsible for taking any messages received on the 'response'
+// link and forwarding it to the proper channel. The channel is being select'd by the
+// original `RPC` call.
+func (l *rpcLink) startResponseRouter() {
+ for {
+ res, err := l.receiver.Receive(context.Background())
+
+ // You'll see this when the link is shutting down (either
+ // service-initiated via 'detach' or a user-initiated shutdown)
+ if isClosedError(err) {
+ l.broadcastError(err)
+ break
+ }
+
+ // I don't believe this should happen. The JS version of this same code
+ // ignores errors as well since responses should always be correlated
+ // to actual send requests. So this is just here for completeness.
+ if res == nil {
+ continue
+ }
+
+ autogenMessageId, ok := res.Properties.CorrelationID.(string)
+
+ if !ok {
+ // TODO: it'd be good to track these in some way. We don't have a good way to
+ // forward this on at this point.
+ continue
+ }
+
+ ch := l.deleteChannelFromMap(autogenMessageId)
+
+ if ch != nil {
+ ch <- rpcResponse{message: res, err: err}
+ }
+ }
+}
+
+// RPC sends a request and waits on a response for that request
+func (l *rpcLink) RPC(ctx context.Context, msg *amqp.Message) (*RPCResponse, error) {
+ l.startResponseRouterOnce.Do(func() {
+ go l.startResponseRouter()
+ })
+
+ copiedMessage, messageID, err := addMessageID(msg, l.uuidNewV4)
+
+ if err != nil {
+ return nil, err
+ }
+
+ // use the copiedMessage from this point
+ msg = copiedMessage
+
+ const altStatusCodeKey, altDescriptionKey = "statusCode", "statusDescription"
+
+ ctx, span := tab.StartSpan(ctx, "rpc.RPC")
+ tracing.ApplyComponentInfo(span, Version)
+ defer span.End()
+
+ msg.Properties.ReplyTo = &l.clientAddress
+
+ if msg.ApplicationProperties == nil {
+ msg.ApplicationProperties = make(map[string]interface{})
+ }
+
+ if _, ok := msg.ApplicationProperties["server-timeout"]; !ok {
+ if deadline, ok := ctx.Deadline(); ok {
+ msg.ApplicationProperties["server-timeout"] = uint(time.Until(deadline) / time.Millisecond)
+ }
+ }
+
+ responseCh := l.addChannelToMap(messageID)
+
+ if responseCh == nil {
+ return nil, l.broadcastErr
+ }
+
+ err = l.sender.Send(ctx, msg)
+
+ if err != nil {
+ l.deleteChannelFromMap(messageID)
+ tab.For(ctx).Error(err)
+ return nil, err
+ }
+
+ var res *amqp.Message
+
+ select {
+ case <-ctx.Done():
+ l.deleteChannelFromMap(messageID)
+ res, err = nil, ctx.Err()
+ case resp := <-responseCh:
+ // this will get triggered by the loop in 'startReceiverRouter' when it receives
+ // a message with our autoGenMessageID set in the correlation_id property.
+ res, err = resp.message, resp.err
+ }
+
+ if err != nil {
+ tab.For(ctx).Error(err)
+ return nil, err
+ }
+
+ var statusCode int
+ statusCodeCandidates := []string{statusCodeKey, altStatusCodeKey}
+ for i := range statusCodeCandidates {
+ if rawStatusCode, ok := res.ApplicationProperties[statusCodeCandidates[i]]; ok {
+ if cast, ok := rawStatusCode.(int32); ok {
+ statusCode = int(cast)
+ break
+ }
+
+ err := errors.New("status code was not of expected type int32")
+ tab.For(ctx).Error(err)
+ return nil, err
+ }
+ }
+ if statusCode == 0 {
+ err := errors.New("status codes was not found on rpc message")
+ tab.For(ctx).Error(err)
+ return nil, err
+ }
+
+ var description string
+ descriptionCandidates := []string{descriptionKey, altDescriptionKey}
+ for i := range descriptionCandidates {
+ if rawDescription, ok := res.ApplicationProperties[descriptionCandidates[i]]; ok {
+ if description, ok = rawDescription.(string); ok || rawDescription == nil {
+ break
+ } else {
+ return nil, errors.New("status description was not of expected type string")
+ }
+ }
+ }
+
+ span.AddAttributes(tab.StringAttribute("http.status_code", fmt.Sprintf("%d", statusCode)))
+
+ response := &RPCResponse{
+ Code: int(statusCode),
+ Description: description,
+ Message: res,
+ }
+
+ if err := l.messageAccept(ctx, res); err != nil {
+ tab.For(ctx).Error(err)
+ return response, err
+ }
+
+ return response, err
+}
+
+// Close the link receiver, sender and session
+func (l *rpcLink) Close(ctx context.Context) error {
+ ctx, span := startRPCSpan(ctx, "rpc.Close")
+ defer span.End()
+
+ if err := l.closeReceiver(ctx); err != nil {
+ _ = l.closeSender(ctx)
+ _ = l.closeSession(ctx)
+ return err
+ }
+
+ if err := l.closeSender(ctx); err != nil {
+ _ = l.closeSession(ctx)
+ return err
+ }
+
+ return l.closeSession(ctx)
+}
+
+func (l *rpcLink) closeReceiver(ctx context.Context) error {
+ ctx, span := startRPCSpan(ctx, "rpc.closeReceiver")
+ defer span.End()
+
+ if l.receiver != nil {
+ return l.receiver.Close(ctx)
+ }
+ return nil
+}
+
+func (l *rpcLink) closeSender(ctx context.Context) error {
+ ctx, span := startRPCSpan(ctx, "rpc.closeSender")
+ defer span.End()
+
+ if l.sender != nil {
+ return l.sender.Close(ctx)
+ }
+ return nil
+}
+
+func (l *rpcLink) closeSession(ctx context.Context) error {
+ ctx, span := startRPCSpan(ctx, "rpc.closeSession")
+ defer span.End()
+
+ if l.session != nil {
+ return l.session.Close(ctx)
+ }
+ return nil
+}
+
+// addChannelToMap adds a channel which will be used by the response router to
+// notify when there is a response to the request.
+// If l.responseMap is nil (for instance, via broadcastError) this function will
+// return nil.
+func (l *rpcLink) addChannelToMap(messageID string) chan rpcResponse {
+ l.responseMu.Lock()
+ defer l.responseMu.Unlock()
+
+ if l.responseMap == nil {
+ return nil
+ }
+
+ responseCh := make(chan rpcResponse, 1)
+ l.responseMap[messageID] = responseCh
+
+ return responseCh
+}
+
+// deleteChannelFromMap removes the message from our internal map and returns
+// a channel that the corresponding RPC() call is waiting on.
+// If l.responseMap is nil (for instance, via broadcastError) this function will
+// return nil.
+func (l *rpcLink) deleteChannelFromMap(messageID string) chan rpcResponse {
+ l.responseMu.Lock()
+ defer l.responseMu.Unlock()
+
+ if l.responseMap == nil {
+ return nil
+ }
+
+ ch := l.responseMap[messageID]
+ delete(l.responseMap, messageID)
+
+ return ch
+}
+
+// broadcastError notifies the anyone waiting for a response that the link/session/connection
+// has closed.
+func (l *rpcLink) broadcastError(err error) {
+ l.responseMu.Lock()
+ defer l.responseMu.Unlock()
+
+ for _, ch := range l.responseMap {
+ ch <- rpcResponse{err: err}
+ }
+
+ l.broadcastErr = err
+ l.responseMap = nil
+}
+
+// addMessageID generates a unique UUID for the message. When the service
+// responds it will fill out the correlation ID property of the response
+// with this ID, allowing us to link the request and response together.
+//
+// NOTE: this function copies 'message', adding in a 'Properties' object
+// if it does not already exist.
+func addMessageID(message *amqp.Message, uuidNewV4 func() (uuid.UUID, error)) (*amqp.Message, string, error) {
+ uuid, err := uuidNewV4()
+
+ if err != nil {
+ return nil, "", err
+ }
+
+ autoGenMessageID := uuid.String()
+
+ // we need to modify the message so we'll make a copy
+ copiedMessage := *message
+
+ if message.Properties == nil {
+ copiedMessage.Properties = &amqp.MessageProperties{
+ MessageID: autoGenMessageID,
+ }
+ } else {
+ // properties already exist, make a copy and then update
+ // the message ID
+ copiedProperties := *message.Properties
+ copiedProperties.MessageID = autoGenMessageID
+
+ copiedMessage.Properties = &copiedProperties
+ }
+
+ return &copiedMessage, autoGenMessageID, nil
+}
+
+func isClosedError(err error) bool {
+ var detachError *amqp.DetachError
+
+ return errors.Is(err, amqp.ErrLinkClosed) ||
+ errors.As(err, &detachError) ||
+ errors.Is(err, amqp.ErrConnClosed) ||
+ errors.Is(err, amqp.ErrSessionClosed)
+}
+
+func startRPCSpan(ctx context.Context, operation string) (context.Context, tab.Spanner) {
+ ctx, span := tab.StartSpan(ctx, operation)
+ tracing.ApplyComponentInfo(span, Version)
+ return ctx, span
+}
diff --git a/sdk/messaging/azservicebus/internal/sbauth/token_provider.go b/sdk/messaging/azservicebus/internal/sbauth/token_provider.go
index 8aa15a6f1890..89fbf7b229cc 100644
--- a/sdk/messaging/azservicebus/internal/sbauth/token_provider.go
+++ b/sdk/messaging/azservicebus/internal/sbauth/token_provider.go
@@ -58,7 +58,6 @@ func (tp *TokenProvider) GetToken(uri string) (*auth.Token, error) {
// GetToken returns a token (that is compatible as an auth.TokenProvider) and
// the calculated time when you should renew your token.
func (tp *TokenProvider) GetTokenAsTokenProvider(uri string) (*singleUseTokenProvider, time.Time, error) {
-
token, renewAt, err := tp.getTokenImpl(uri)
if err != nil {
diff --git a/sdk/messaging/azservicebus/internal/stress/Chart.lock b/sdk/messaging/azservicebus/internal/stress/Chart.lock
new file mode 100644
index 000000000000..96e9785515bb
--- /dev/null
+++ b/sdk/messaging/azservicebus/internal/stress/Chart.lock
@@ -0,0 +1,6 @@
+dependencies:
+- name: stress-test-addons
+ repository: https://stresstestcharts.blob.core.windows.net/helm/
+ version: 0.1.12
+digest: sha256:a2afbc89375d1518cd44194ca861e8a7a1ca1c16cfaea3409cdb48bfc910d878
+generated: "2021-11-16T19:17:10.5545801-05:00"
diff --git a/sdk/messaging/azservicebus/internal/stress/Chart.yaml b/sdk/messaging/azservicebus/internal/stress/Chart.yaml
new file mode 100644
index 000000000000..044770e86ef2
--- /dev/null
+++ b/sdk/messaging/azservicebus/internal/stress/Chart.yaml
@@ -0,0 +1,14 @@
+apiVersion: v2
+name: go-sb-stress
+description: Stress tests for Go
+version: 0.1.1
+appVersion: v0.1
+annotations:
+ stressTest: 'true' # enable auto-discovery of this test via `find-all-stress-packages.ps1`
+ namespace: 'go'
+ dockerbuilddir: '../..'
+ dockerfile: './Dockerfile'
+dependencies:
+- name: stress-test-addons
+ version: 0.1.12
+ repository: https://stresstestcharts.blob.core.windows.net/helm/
\ No newline at end of file
diff --git a/sdk/messaging/azservicebus/internal/stress/Dockerfile b/sdk/messaging/azservicebus/internal/stress/Dockerfile
index e553930c2ea3..3159517f3471 100644
--- a/sdk/messaging/azservicebus/internal/stress/Dockerfile
+++ b/sdk/messaging/azservicebus/internal/stress/Dockerfile
@@ -1,4 +1,14 @@
-FROM alpine:latest
-ADD ./stress /
-WORKDIR /
-ENTRYPOINT ["/stress"]
+FROM golang as build
+# you'll need to run this build from the root of the repo
+ENV GOOS=linux
+ENV GOARCH=amd64
+ENV CGO_ENABLED=0
+ADD . /src
+WORKDIR /src/internal/stress
+RUN go build -o stress .
+
+FROM alpine
+RUN mkdir -p /app
+COPY --from=build /src/internal/stress/stress /app/stress
+WORKDIR /app
+ENTRYPOINT ["./stress"]
\ No newline at end of file
diff --git a/sdk/messaging/azservicebus/internal/stress/shared/stats.go b/sdk/messaging/azservicebus/internal/stress/shared/stats.go
index 72b8250960d7..858aff921a0f 100644
--- a/sdk/messaging/azservicebus/internal/stress/shared/stats.go
+++ b/sdk/messaging/azservicebus/internal/stress/shared/stats.go
@@ -80,28 +80,30 @@ func newStatsPrinter(ctx context.Context, prefix string, interval time.Duration,
}
sp.mu.RLock()
+ sp.PrintStats()
+ sp.mu.RUnlock()
+ }
+ }(ctx)
- log.Printf("Stats:")
-
- for _, stats := range sp.all {
- log.Printf(" %s", stats.String())
+ return sp
+}
- if stats.Sent > 0 {
- sp.tc.TrackMetric(fmt.Sprintf("%s.TotalSent", stats.name), float64(stats.Sent))
- }
+func (sp *statsPrinter) PrintStats() {
+ log.Printf("Stats:")
- if stats.Received > 0 {
- sp.tc.TrackMetric(fmt.Sprintf("%s.TotalReceived", stats.name), float64(stats.Received))
- }
+ for _, stats := range sp.all {
+ log.Printf(" %s", stats.String())
- sp.tc.TrackMetric(fmt.Sprintf("%s.TotalErrors", stats.name), float64(stats.Errors))
- }
+ if stats.Sent > 0 {
+ sp.tc.TrackMetric(fmt.Sprintf("%s.TotalSent", stats.name), float64(stats.Sent))
+ }
- sp.mu.RUnlock()
+ if stats.Received > 0 {
+ sp.tc.TrackMetric(fmt.Sprintf("%s.TotalReceived", stats.name), float64(stats.Received))
}
- }(ctx)
- return sp
+ sp.tc.TrackMetric(fmt.Sprintf("%s.TotalErrors", stats.name), float64(stats.Errors))
+ }
}
// NewStat creates a new stat with `name` and adds it to the list of statistics that will
diff --git a/sdk/messaging/azservicebus/internal/stress/shared/stress_context.go b/sdk/messaging/azservicebus/internal/stress/shared/stress_context.go
index 4116f10e4385..3422787c171b 100644
--- a/sdk/messaging/azservicebus/internal/stress/shared/stress_context.go
+++ b/sdk/messaging/azservicebus/internal/stress/shared/stress_context.go
@@ -12,6 +12,7 @@ import (
"sync/atomic"
"time"
+ azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log"
"github.com/microsoft/ApplicationInsights-Go/appinsights"
)
@@ -32,6 +33,8 @@ type StressContext struct {
// ConnectionString represents the value of the environment variable SERVICEBUS_CONNECTION_STRING.
ConnectionString string
+ logMessages chan string
+
cancel context.CancelFunc
}
@@ -61,12 +64,33 @@ func MustCreateStressContext(testName string) *StressContext {
ctx, cancel := NewCtrlCContext()
+ azlog.SetEvents("azsb.Conn", "azsb.Auth", "azsb.Retry", "azsb.Mgmt")
+
+ logMessages := make(chan string, 10000)
+
+ go func() {
+ PrintLoop:
+ for {
+ select {
+ case <-ctx.Done():
+ break PrintLoop
+ case msg := <-logMessages:
+ fmt.Println(msg)
+ }
+ }
+ }()
+
+ azlog.SetListener(func(e azlog.Event, msg string) {
+ logMessages <- fmt.Sprintf("%s %10s %s", time.Now().Format(time.RFC3339), e, msg)
+ })
+
return &StressContext{
TestRunID: testRunID,
Nano: testRunID, // the same for now
ConnectionString: cs,
TelemetryClient: telemetryClient,
statsPrinter: newStatsPrinter(ctx, testName, 5*time.Second, telemetryClient),
+ logMessages: logMessages,
Context: ctx,
cancel: cancel,
}
@@ -90,12 +114,29 @@ func (sc *StressContext) Start(entityName string, attributes map[string]string)
func (sc *StressContext) End() {
log.Printf("Stopping and flushing telemetry")
+ sc.cancel()
+
sc.TrackEvent("End")
+
sc.Channel().Flush()
<-sc.Channel().Close()
time.Sleep(5 * time.Second)
+ // dump out any remaining log messages
+PrintLoop:
+ for {
+ select {
+ case msg := <-sc.logMessages:
+ fmt.Println(msg)
+ default:
+ break PrintLoop
+ }
+ }
+
+ // dump out the last stats.
+ sc.PrintStats()
+
log.Printf("Done")
}
@@ -108,6 +149,14 @@ func (tracker *StressContext) PanicOnError(message string, err error) {
}
}
+func (tracker *StressContext) Assert(condition bool, message string) {
+ tracker.LogIfFailed(message, nil, nil)
+
+ if !condition {
+ panic(message)
+ }
+}
+
func (sc *StressContext) LogIfFailed(message string, err error, stats *Stats) {
if err != nil {
log.Printf("Error: %s: %#v, %T", message, err, err)
diff --git a/sdk/messaging/azservicebus/internal/stress/shared/utils.go b/sdk/messaging/azservicebus/internal/stress/shared/utils.go
index b99f3dda9618..ceeff921d608 100644
--- a/sdk/messaging/azservicebus/internal/stress/shared/utils.go
+++ b/sdk/messaging/azservicebus/internal/stress/shared/utils.go
@@ -50,18 +50,24 @@ func MustGenerateMessages(sc *StressContext, sender *azservicebus.Sender, messag
}
// MustCreateAutoDeletingQueue creates a queue that will auto-delete 10 minutes after activity has ceased.
-func MustCreateAutoDeletingQueue(sc *StressContext, queueName string) {
+func MustCreateAutoDeletingQueue(sc *StressContext, queueName string, qp *admin.QueueProperties) {
adminClient, err := admin.NewClientFromConnectionString(sc.ConnectionString, nil)
sc.PanicOnError("failed to create adminClient", err)
autoDeleteOnIdle := 10 * time.Minute
- _, err = adminClient.CreateQueue(context.Background(), queueName, &admin.QueueProperties{
- AutoDeleteOnIdle: &autoDeleteOnIdle,
+ var newQP admin.QueueProperties
- // mostly useful for tracking backwards in case something goes wrong.
- UserMetadata: &sc.TestRunID,
- }, nil)
+ if qp != nil {
+ newQP = *qp
+ }
+
+ newQP.AutoDeleteOnIdle = &autoDeleteOnIdle
+
+ // mostly useful for tracking backwards in case something goes wrong.
+ newQP.UserMetadata = &sc.TestRunID
+
+ _, err = adminClient.CreateQueue(context.Background(), queueName, &newQP, nil)
sc.PanicOnError("failed to create queue", err)
}
diff --git a/sdk/messaging/azservicebus/internal/stress/stress-test-resources.bicep b/sdk/messaging/azservicebus/internal/stress/stress-test-resources.bicep
new file mode 100644
index 000000000000..7245dfec35df
--- /dev/null
+++ b/sdk/messaging/azservicebus/internal/stress/stress-test-resources.bicep
@@ -0,0 +1,107 @@
+@description('The base resource name.')
+param baseName string = resourceGroup().name
+
+@description('The client OID to grant access to test resources.')
+param testApplicationOid string
+
+var apiVersion = '2017-04-01'
+var location = resourceGroup().location
+var authorizationRuleName_var = '${baseName}/RootManageSharedAccessKey'
+var authorizationRuleNameNoManage_var = '${baseName}/NoManage'
+var serviceBusDataOwnerRoleId = '/subscriptions/${subscription().subscriptionId}/providers/Microsoft.Authorization/roleDefinitions/090c5cfd-751d-490a-894a-3ce6f1109419'
+
+resource servicebus 'Microsoft.ServiceBus/namespaces@2018-01-01-preview' = {
+ name: baseName
+ location: location
+ sku: {
+ name: 'Standard'
+ tier: 'Standard'
+ }
+ properties: {
+ zoneRedundant: false
+ }
+}
+
+resource authorizationRuleName 'Microsoft.ServiceBus/namespaces/AuthorizationRules@2015-08-01' = {
+ name: authorizationRuleName_var
+ location: location
+ properties: {
+ rights: [
+ 'Listen'
+ 'Manage'
+ 'Send'
+ ]
+ }
+ dependsOn: [
+ servicebus
+ ]
+}
+
+resource authorizationRuleNameNoManage 'Microsoft.ServiceBus/namespaces/AuthorizationRules@2015-08-01' = {
+ name: authorizationRuleNameNoManage_var
+ location: location
+ properties: {
+ rights: [
+ 'Listen'
+ 'Send'
+ ]
+ }
+ dependsOn: [
+ servicebus
+ ]
+}
+
+
+
+resource dataOwnerRoleId 'Microsoft.Authorization/roleAssignments@2018-01-01-preview' = {
+ name: guid('dataOwnerRoleId${baseName}')
+ properties: {
+ roleDefinitionId: serviceBusDataOwnerRoleId
+ principalId: testApplicationOid
+ }
+ dependsOn: [
+ servicebus
+ ]
+}
+
+resource testQueue 'Microsoft.ServiceBus/namespaces/queues@2017-04-01' = {
+ parent: servicebus
+ name: 'testQueue'
+ properties: {
+ lockDuration: 'PT5M'
+ maxSizeInMegabytes: 1024
+ requiresDuplicateDetection: false
+ requiresSession: false
+ defaultMessageTimeToLive: 'P10675199DT2H48M5.4775807S'
+ deadLetteringOnMessageExpiration: false
+ duplicateDetectionHistoryTimeWindow: 'PT10M'
+ maxDeliveryCount: 10
+ autoDeleteOnIdle: 'P10675199DT2H48M5.4775807S'
+ enablePartitioning: false
+ enableExpress: false
+ }
+}
+
+resource testQueueWithSessions 'Microsoft.ServiceBus/namespaces/queues@2017-04-01' = {
+ parent: servicebus
+ name: 'testQueueWithSessions'
+ properties: {
+ lockDuration: 'PT5M'
+ maxSizeInMegabytes: 1024
+ requiresDuplicateDetection: false
+ requiresSession: true
+ defaultMessageTimeToLive: 'P10675199DT2H48M5.4775807S'
+ deadLetteringOnMessageExpiration: false
+ duplicateDetectionHistoryTimeWindow: 'PT10M'
+ maxDeliveryCount: 10
+ autoDeleteOnIdle: 'P10675199DT2H48M5.4775807S'
+ enablePartitioning: false
+ enableExpress: false
+ }
+}
+
+output SERVICEBUS_CONNECTION_STRING string = listKeys(resourceId('Microsoft.ServiceBus/namespaces/authorizationRules', baseName, 'RootManageSharedAccessKey'), apiVersion).primaryConnectionString
+output SERVICEBUS_CONNECTION_STRING_NO_MANAGE string = listKeys(resourceId('Microsoft.ServiceBus/namespaces/authorizationRules', baseName, 'NoManage'), apiVersion).primaryConnectionString
+output SERVICEBUS_ENDPOINT string = replace(servicebus.properties.serviceBusEndpoint, ':443/', '')
+output QUEUE_NAME string = 'testQueue'
+output QUEUE_NAME_WITH_SESSIONS string = 'testQueueWithSessions'
diff --git a/sdk/messaging/azservicebus/internal/stress/templates/deploy-job.yaml b/sdk/messaging/azservicebus/internal/stress/templates/deploy-job.yaml
new file mode 100644
index 000000000000..37cf472c5c8c
--- /dev/null
+++ b/sdk/messaging/azservicebus/internal/stress/templates/deploy-job.yaml
@@ -0,0 +1,17 @@
+{{- include "stress-test-addons.deploy-job-template.from-pod" (list . "stress.deploy-example") -}}
+{{- define "stress.deploy-example" -}}
+metadata:
+ labels:
+ testName: "go-servicebus"
+spec:
+ containers:
+ - name: main
+ # az acr list -g rg-stress-cluster-test --subscription "Azure SDK Developer Playground" --query "[0].loginServer"
+ image: {{ .Values.image }}
+ command: ['/app/stress']
+ args:
+ - "tests"
+ # (this is injected automatically. The full list of scenarios is in `../values.yaml`)
+ - {{ .Scenario }}
+ {{- include "stress-test-addons.container-env" . | nindent 6 }}
+{{- end -}}
diff --git a/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment.go b/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment.go
index 75ad7dd8184c..ad1c532b77d8 100644
--- a/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment.go
+++ b/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment.go
@@ -18,7 +18,7 @@ func ConstantDetachment(remainingArgs []string) {
queueName := fmt.Sprintf("detach-tester-%s", sc.Nano)
- shared.MustCreateAutoDeletingQueue(sc, queueName)
+ shared.MustCreateAutoDeletingQueue(sc, queueName, nil)
client, err := azservicebus.NewClientFromConnectionString(sc.ConnectionString, nil)
sc.PanicOnError("failed to create client", err)
diff --git a/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment_sender.go b/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment_sender.go
index 4395edda3a5f..371335f43fc9 100644
--- a/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment_sender.go
+++ b/sdk/messaging/azservicebus/internal/stress/tests/constant_detachment_sender.go
@@ -78,7 +78,7 @@ func ConstantDetachmentSender(remainingArgs []string) {
func createDetachResources(sc *shared.StressContext, name string) (string, *shared.Stats, *azservicebus.Sender) {
queueName := fmt.Sprintf("detach_%s-%s", name, sc.Nano)
- shared.MustCreateAutoDeletingQueue(sc, queueName)
+ shared.MustCreateAutoDeletingQueue(sc, queueName, nil)
client, err := azservicebus.NewClientFromConnectionString(sc.ConnectionString, nil)
sc.PanicOnError("failed to create client", err)
diff --git a/sdk/messaging/azservicebus/internal/stress/tests/finite_peeks.go b/sdk/messaging/azservicebus/internal/stress/tests/finite_peeks.go
new file mode 100644
index 000000000000..66a1b365b2c5
--- /dev/null
+++ b/sdk/messaging/azservicebus/internal/stress/tests/finite_peeks.go
@@ -0,0 +1,68 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package tests
+
+import (
+ "context"
+ "fmt"
+ "log"
+ "time"
+
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/stress/shared"
+)
+
+func FinitePeeks(remainingArgs []string) {
+ sc := shared.MustCreateStressContext("FinitePeeks")
+ defer sc.Done()
+
+ queueName := fmt.Sprintf("finite-peeks-%s", sc.Nano)
+ shared.MustCreateAutoDeletingQueue(sc, queueName, nil)
+
+ client, err := azservicebus.NewClientFromConnectionString(sc.ConnectionString, nil)
+ sc.PanicOnError("failed to create client", err)
+
+ receiverStats := sc.NewStat("peeks")
+
+ sender, err := client.NewSender(queueName, nil)
+ sc.PanicOnError("failed to create sender", err)
+
+ err = sender.SendMessage(sc.Context, &azservicebus.Message{
+ Body: []byte("peekable message"),
+ })
+ sc.PanicOnError("failed to send message", err)
+
+ _ = sender.Close(sc.Context)
+
+ receiver, err := client.NewReceiverForQueue(queueName, nil)
+ sc.PanicOnError("failed to create receiver", err)
+
+ // receiving here just guarantees the message has arrived and is available (sometimes
+ // there's a slight delay)
+ receiveCtx, cancel := context.WithTimeout(sc.Context, time.Minute)
+ defer cancel()
+
+ tmp, err := receiver.ReceiveMessages(receiveCtx, 1, nil)
+ sc.PanicOnError("failed to receive messages", err)
+ sc.Assert(len(tmp) == 1, "message was never available")
+
+ // return the message back from whence it came.
+ sc.PanicOnError("failed to abandon message",
+ receiver.AbandonMessage(sc.Context, tmp[0], nil))
+
+ for i := 0; i < 10000; i++ {
+ log.Printf("Sleeping for 1 second before iteration %d", i)
+ time.Sleep(time.Second)
+
+ seqNum := int64(0)
+
+ messages, err := receiver.PeekMessages(sc.Context, 1, &azservicebus.PeekMessagesOptions{
+ FromSequenceNumber: &seqNum,
+ })
+ sc.PanicOnError("failed to peek messages", err)
+ sc.Assert(len(messages) == 1, "no messages returned in peek")
+
+ receiverStats.AddReceived(int32(1))
+ }
+}
diff --git a/sdk/messaging/azservicebus/internal/stress/tests/finite_send_and_receive.go b/sdk/messaging/azservicebus/internal/stress/tests/finite_send_and_receive.go
index 6d21be448476..257fc0833395 100644
--- a/sdk/messaging/azservicebus/internal/stress/tests/finite_send_and_receive.go
+++ b/sdk/messaging/azservicebus/internal/stress/tests/finite_send_and_receive.go
@@ -5,6 +5,7 @@ package tests
import (
"context"
+ "errors"
"fmt"
"log"
"strings"
@@ -12,6 +13,7 @@ import (
"time"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/admin"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/stress/shared"
)
@@ -19,23 +21,29 @@ func FiniteSendAndReceiveTest(remainingArgs []string) {
sc := shared.MustCreateStressContext("FiniteSendAndReceiveTest")
sc.TrackEvent("Start")
- defer sc.TrackEvent("End")
+ defer sc.End()
queueName := strings.ToLower(fmt.Sprintf("queue-%X", time.Now().UnixNano()))
log.Printf("Creating queue")
- shared.MustCreateAutoDeletingQueue(sc, queueName)
+
+ lockDuration := 5 * time.Minute
+
+ shared.MustCreateAutoDeletingQueue(sc, queueName, &admin.QueueProperties{
+ LockDuration: &lockDuration,
+ })
client, err := azservicebus.NewClientFromConnectionString(sc.ConnectionString, nil)
sc.PanicOnError("failed to create client", err)
sender, err := client.NewSender(queueName, nil)
sc.PanicOnError("failed to create sender", err)
- const messageLimit = 50000
- shared.MustGenerateMessages(sc, sender, messageLimit, 100, sc.NewStat("sender"))
+ const messageLimit = 500
log.Printf("Sending %d messages (all messages will be sent before receiving begins)", messageLimit)
+ shared.MustGenerateMessages(sc, sender, messageLimit, 100, sc.NewStat("sender"))
+
log.Printf("Starting receiving...")
receiver, err := client.NewReceiverForQueue(queueName, nil)
@@ -45,7 +53,7 @@ func FiniteSendAndReceiveTest(remainingArgs []string) {
receiverStats := sc.NewStat("receiver")
- for receiverStats.Received == messageLimit {
+ for receiverStats.Received < messageLimit {
log.Printf("[start] Receiving messages...")
messages, err := receiver.ReceiveMessages(context.Background(), 100, nil)
log.Printf("[done] Receiving messages... %v, %v", len(messages), err)
@@ -53,23 +61,36 @@ func FiniteSendAndReceiveTest(remainingArgs []string) {
wg := sync.WaitGroup{}
+ log.Printf("About to complete %d messages", len(messages))
+ time.Sleep(10 * time.Second)
+
for _, msg := range messages {
wg.Add(1)
+
go func(msg *azservicebus.ReceivedMessage) {
completions <- struct{}{}
defer wg.Done()
defer func() { <-completions }()
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
err := receiver.CompleteMessage(ctx, msg)
+
+ var rpcCodeErr interface{ RPCCode() int }
+
+ if errors.As(err, &rpcCodeErr) {
+ if rpcCodeErr.RPCCode() == 410 {
+ receiverStats.AddError("lock lost", err)
+ return
+ }
+ }
+
sc.PanicOnError("failed to complete message", err)
+ receiverStats.AddReceived(1)
}(msg)
}
wg.Wait()
-
- receiverStats.AddReceived(int32(len(messages)))
}
}
diff --git a/sdk/messaging/azservicebus/internal/stress/tests/long_running_renew_lock.go b/sdk/messaging/azservicebus/internal/stress/tests/long_running_renew_lock.go
index 25c1af5d1664..86b06d42195c 100644
--- a/sdk/messaging/azservicebus/internal/stress/tests/long_running_renew_lock.go
+++ b/sdk/messaging/azservicebus/internal/stress/tests/long_running_renew_lock.go
@@ -17,7 +17,7 @@ func LongRunningRenewLockTest(remainingArgs []string) {
sc := shared.MustCreateStressContext("LongRunningRenewLockTest")
queueName := fmt.Sprintf("renew-lock-test-%s", sc.Nano)
- shared.MustCreateAutoDeletingQueue(sc, queueName)
+ shared.MustCreateAutoDeletingQueue(sc, queueName, nil)
client, err := azservicebus.NewClientFromConnectionString(sc.ConnectionString, nil)
sc.PanicOnError("failed to create admin.Client", err)
diff --git a/sdk/messaging/azservicebus/internal/stress/tests/rapid_open_close.go b/sdk/messaging/azservicebus/internal/stress/tests/rapid_open_close.go
index 5af1fd96639b..1279dfe8dcf3 100644
--- a/sdk/messaging/azservicebus/internal/stress/tests/rapid_open_close.go
+++ b/sdk/messaging/azservicebus/internal/stress/tests/rapid_open_close.go
@@ -17,7 +17,7 @@ func RapidOpenCloseTest(remainingArgs []string) {
sc := shared.MustCreateStressContext("RapidOpenCloseTest")
queueName := fmt.Sprintf("rapid_open_close-%X", time.Now().UnixNano())
- shared.MustCreateAutoDeletingQueue(sc, queueName)
+ shared.MustCreateAutoDeletingQueue(sc, queueName, nil)
for i := 0; i < 100; i++ {
log.Printf("[%d] Open/Close", i)
diff --git a/sdk/messaging/azservicebus/internal/stress/tests/tests.go b/sdk/messaging/azservicebus/internal/stress/tests/tests.go
index 2c50db3a56b8..d4ae280f55c0 100644
--- a/sdk/messaging/azservicebus/internal/stress/tests/tests.go
+++ b/sdk/messaging/azservicebus/internal/stress/tests/tests.go
@@ -11,9 +11,6 @@ import (
"strings"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/stress/shared"
- "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing"
- "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
- "github.com/devigned/tab"
)
// Simple query to view some of the stats reported by these stress tests.
@@ -25,22 +22,23 @@ import (
func Run(remainingArgs []string) {
// turn on some simple stderr diagnostics
- tracer := utils.NewSimpleTracer(map[string]bool{
- tracing.SpanRecover: true,
- tracing.SpanNegotiateClaim: true,
- tracing.SpanRecoverClient: true,
- tracing.SpanRecoverLink: true,
- }, nil)
+ // tracer := utils.NewSimpleTracer(map[string]bool{
+ // // tracing.SpanRecover: true,
+ // //tracing.SpanNegotiateClaim: true,
+ // // tracing.SpanRecoverClient: true,
+ // // tracing.SpanRecoverLink: true,
+ // }, nil)
- tab.Register(tracer)
+ // tab.Register(tracer)
allTests := map[string]func(args []string){
- "infiniteSendAndReceive": InfiniteSendAndReceiveRun,
- "finiteSendAndReceive": FiniteSendAndReceiveTest,
- "rapidOpenClose": RapidOpenCloseTest,
- "longRunningRenewLock": LongRunningRenewLockTest,
"constantDetach": ConstantDetachment,
"constantDetachmentSender": ConstantDetachmentSender,
+ "finitePeeks": FinitePeeks,
+ "finiteSendAndReceive": FiniteSendAndReceiveTest,
+ "infiniteSendAndReceive": InfiniteSendAndReceiveRun,
+ "longRunningRenewLock": LongRunningRenewLockTest,
+ "rapidOpenClose": RapidOpenCloseTest,
}
if len(remainingArgs) == 0 {
diff --git a/sdk/messaging/azservicebus/internal/stress/values.yaml b/sdk/messaging/azservicebus/internal/stress/values.yaml
new file mode 100644
index 000000000000..5fbec5a3ab77
--- /dev/null
+++ b/sdk/messaging/azservicebus/internal/stress/values.yaml
@@ -0,0 +1,11 @@
+# Optional list of scenarios. If specified multiple stress test jobs will be generated,
+# one for each scenario in the list. The pod spec can then be configured to pass the
+# scenario name down to the test command, e.g. `command: ["node", "{{ .Scenario }}.js"]`
+scenarios:
+- "constantDetach"
+- "constantDetachmentSender"
+- "finitePeeks"
+- "finiteSendAndReceive"
+- "infiniteSendAndReceive"
+- "longRunningRenewLock"
+- "rapidOpenClose"
diff --git a/sdk/messaging/azservicebus/internal/test/test_helpers.go b/sdk/messaging/azservicebus/internal/test/test_helpers.go
index 6ddc3f37eeae..40d0fdceaf8d 100644
--- a/sdk/messaging/azservicebus/internal/test/test_helpers.go
+++ b/sdk/messaging/azservicebus/internal/test/test_helpers.go
@@ -4,10 +4,16 @@
package test
import (
+ "context"
"math/rand"
+ "net/http"
"os"
"testing"
"time"
+
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/atom"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
+ "github.com/stretchr/testify/require"
)
var (
@@ -46,3 +52,30 @@ func GetConnectionStringWithoutManagePerms(t *testing.T) string {
return cs
}
+
+func CreateExpiringQueue(t *testing.T, qd *atom.QueueDescription) (string, func()) {
+ cs := GetConnectionString(t)
+ em, err := atom.NewEntityManagerWithConnectionString(cs, "")
+ require.NoError(t, err)
+
+ queueName := RandomString("queue", 5)
+
+ if qd == nil {
+ qd = &atom.QueueDescription{}
+ }
+
+ deleteAfter := 5 * time.Minute
+ qd.AutoDeleteOnIdle = utils.DurationToStringPtr(&deleteAfter)
+
+ env, _ := atom.WrapWithQueueEnvelope(qd, em.TokenProvider())
+
+ var qe *atom.QueueEnvelope
+ resp, err := em.Put(context.Background(), queueName, env, &qe)
+ require.NoError(t, err)
+ require.EqualValues(t, http.StatusCreated, resp.StatusCode)
+
+ return queueName, func() {
+ _, err := em.Delete(context.Background(), queueName)
+ require.NoError(t, err)
+ }
+}
diff --git a/sdk/messaging/azservicebus/internal/utils/retrier.go b/sdk/messaging/azservicebus/internal/utils/retrier.go
new file mode 100644
index 000000000000..e880c870b4b1
--- /dev/null
+++ b/sdk/messaging/azservicebus/internal/utils/retrier.go
@@ -0,0 +1,142 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package utils
+
+import (
+ "context"
+ "math"
+ "math/rand"
+ "time"
+
+ "github.com/Azure/azure-sdk-for-go/sdk/internal/log"
+)
+
+// EventRetry is the name for retry events
+const EventRetry = "azsb.Retry"
+
+type RetryFnArgs struct {
+ I int32
+ // LastErr is the returned error from the previous loop.
+ // If you have potentially expensive
+ LastErr error
+
+ resetAttempts bool
+}
+
+// ResetAttempts causes the current retry attempt number to be reset
+// in time for the next recovery (should we fail).
+// NOTE: Use of this should be pretty rare, it's really only needed when you have
+ // a situation like Receiver.ReceiveMessages() that can recover but intentionally
+// does not return.
+func (rf *RetryFnArgs) ResetAttempts() {
+ rf.resetAttempts = true
+}
+
+// Retry runs a standard retry loop. It executes your passed in fn as the body of the loop.
+// It returns if it exceeds the number of configured retry options or if 'isFatal' returns true.
+func Retry(ctx context.Context, name string, fn func(ctx context.Context, args *RetryFnArgs) error, isFatalFn func(err error) bool, o RetryOptions) error {
+ if isFatalFn == nil {
+ panic("isFatalFn is nil, errors would panic")
+ }
+
+ var ro RetryOptions = o
+ setDefaults(&ro)
+
+ var err error
+
+ for i := int32(0); i <= ro.MaxRetries; i++ {
+ if i > 0 {
+ sleep := calcDelay(ro, i)
+ log.Writef(EventRetry, "(%s) Attempt %d sleeping for %s", name, i, sleep)
+ time.Sleep(sleep)
+ }
+
+ args := RetryFnArgs{
+ I: i,
+ LastErr: err,
+ }
+ err = fn(ctx, &args)
+
+ if args.resetAttempts {
+ log.Writef(EventRetry, "(%s) Resetting attempts", name)
+ i = int32(0)
+ }
+
+ if err != nil {
+ if isFatalFn(err) {
+ log.Writef(EventRetry, "(%s) Attempt %d returned non-retryable error: %s", name, i, err.Error())
+ return err
+ } else {
+ log.Writef(EventRetry, "(%s) Attempt %d returned retryable error: %s", name, i, err.Error())
+ }
+
+ continue
+ }
+
+ return nil
+ }
+
+ return err
+}
+
+// RetryOptions represent the options for retries.
+type RetryOptions struct {
+ // MaxRetries specifies the maximum number of attempts a failed operation will be retried
+ // before producing an error.
+ // The default value is three. A value less than zero means one try and no retries.
+ MaxRetries int32
+
+ // RetryDelay specifies the initial amount of delay to use before retrying an operation.
+ // The delay increases exponentially with each retry up to the maximum specified by MaxRetryDelay.
+ // The default value is four seconds. A value less than zero means no delay between retries.
+ RetryDelay time.Duration
+
+ // MaxRetryDelay specifies the maximum delay allowed before retrying an operation.
+ // Typically the value is greater than or equal to the value specified in RetryDelay.
+ // The default Value is 120 seconds. A value less than zero means there is no cap.
+ MaxRetryDelay time.Duration
+}
+
+func setDefaults(o *RetryOptions) {
+ if o.MaxRetries == 0 {
+ o.MaxRetries = 3
+ } else if o.MaxRetries < 0 {
+ o.MaxRetries = 0
+ }
+ if o.MaxRetryDelay == 0 {
+ o.MaxRetryDelay = 120 * time.Second
+ } else if o.MaxRetryDelay < 0 {
+ // not an unlimited cap, but sufficiently large to be treated as one
+ o.MaxRetryDelay = math.MaxInt64
+ }
+ if o.RetryDelay == 0 {
+ o.RetryDelay = 4 * time.Second
+ } else if o.RetryDelay < 0 {
+ o.RetryDelay = 0
+ }
+}
+
+// (adapted from from azcore/policy_retry)
+func calcDelay(o RetryOptions, try int32) time.Duration {
+ if try == 0 {
+ return 0
+ }
+
+ pow := func(number int64, exponent int32) int64 { // pow is nested helper function
+ var result int64 = 1
+ for n := int32(0); n < exponent; n++ {
+ result *= number
+ }
+ return result
+ }
+
+ delay := time.Duration(pow(2, try)-1) * o.RetryDelay
+
+ // Introduce some jitter: [0.0, 1.0) / 2 = [0.0, 0.5) + 0.8 = [0.8, 1.3)
+ delay = time.Duration(delay.Seconds() * (rand.Float64()/2 + 0.8) * float64(time.Second)) // NOTE: We want math/rand; not crypto/rand
+ if delay > o.MaxRetryDelay {
+ delay = o.MaxRetryDelay
+ }
+ return delay
+}
diff --git a/sdk/messaging/azservicebus/internal/utils/retrier_test.go b/sdk/messaging/azservicebus/internal/utils/retrier_test.go
new file mode 100644
index 000000000000..fc5826f58860
--- /dev/null
+++ b/sdk/messaging/azservicebus/internal/utils/retrier_test.go
@@ -0,0 +1,235 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package utils
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "math"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestRetrier(t *testing.T) {
+ t.Run("Succeeds", func(t *testing.T) {
+ ctx := context.Background()
+
+ called := 0
+
+ err := Retry(ctx, "Retrier", func(ctx context.Context, args *RetryFnArgs) error {
+ called++
+ return nil
+ }, func(err error) bool {
+ panic("won't get called")
+ }, RetryOptions{})
+
+ require.Nil(t, err)
+ require.EqualValues(t, 1, called)
+ })
+
+ t.Run("FailsThenSucceeds", func(t *testing.T) {
+ ctx := context.Background()
+
+ called := 0
+ isFatalCalled := 0
+
+ isFatalFn := func(err error) bool {
+ require.NotNil(t, err)
+ // we'll just keep saying the errors aren't fatal.
+ isFatalCalled++
+ return false
+ }
+
+ err := Retry(ctx, "FailsThenSucceeds", func(ctx context.Context, args *RetryFnArgs) error {
+ called++
+
+ if args.I == 3 {
+ // we're on the last iteration, succeed
+ return nil
+ }
+
+ return fmt.Errorf("Error, iteration %d", args.I)
+ }, isFatalFn, fastRetryOptions)
+
+ require.EqualValues(t, 4, called)
+ require.EqualValues(t, 3, isFatalCalled)
+
+ // if an attempt succeeds then there's no error (despite previous failed tries)
+ require.NoError(t, err)
+ })
+
+ t.Run("FatalFailure", func(t *testing.T) {
+ ctx := context.Background()
+ called := 0
+
+ isFatalFn := func(err error) bool {
+ require.EqualValues(t, "isFatalFn says this is a fatal error", err.Error())
+ return true
+ }
+
+ err := Retry(ctx, "FatalFailure", func(ctx context.Context, args *RetryFnArgs) error {
+ called++
+ return errors.New("isFatalFn says this is a fatal error")
+ }, isFatalFn, RetryOptions{})
+
+ require.EqualValues(t, "isFatalFn says this is a fatal error", err.Error())
+ require.EqualValues(t, 1, called)
+ })
+
+ t.Run("Cancellation", func(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ isFatalFn := func(err error) bool {
+ return errors.Is(err, context.Canceled)
+ }
+
+ // NOTE: handling cancellation is the callback's responsibility, not Retry's.
+ err := Retry(ctx, "Cancellation", func(ctx context.Context, args *RetryFnArgs) error {
+ // NOTE: it's up to the underlying function to handle cancellation. `Retry` doesn't
+ // do anything but propagate it.
+ select {
+ case <-ctx.Done():
+ default:
+ require.Fail(t, "Context should have been cancelled")
+ }
+
+ return context.Canceled
+ }, isFatalFn, RetryOptions{})
+
+ require.ErrorIs(t, context.Canceled, err)
+ })
+
+ t.Run("ResetAttempts", func(t *testing.T) {
+ isFatalFn := func(err error) bool {
+ return errors.Is(err, context.Canceled)
+ }
+
+ customRetryOptions := fastRetryOptions
+ customRetryOptions.MaxRetries = 1
+
+ var actualAttempts []int32
+
+ err := Retry(context.Background(), "ResetAttempts", func(ctx context.Context, args *RetryFnArgs) error {
+ actualAttempts = append(actualAttempts, args.I)
+
+ if len(actualAttempts) == 3 {
+ args.ResetAttempts()
+ }
+
+ return errors.New("whatever")
+ }, isFatalFn, RetryOptions{
+ MaxRetries: 2,
+ RetryDelay: time.Millisecond,
+ MaxRetryDelay: time.Millisecond,
+ })
+
+ expectedAttempts := []int32{
+ 0, 1, 2, // we reset attempts here.
+ 1, 2, // and we start at the first retry attempt again.
+ }
+
+ require.EqualValues(t, "whatever", err.Error())
+ require.EqualValues(t, expectedAttempts, actualAttempts)
+ })
+
+ t.Run("DisableRetries", func(t *testing.T) {
+ isFatalFn := func(err error) bool {
+ return errors.Is(err, context.Canceled)
+ }
+
+ customRetryOptions := fastRetryOptions
+ customRetryOptions.MaxRetries = -1
+
+ called := 0
+
+ err := Retry(context.Background(), "ResetAttempts", func(ctx context.Context, args *RetryFnArgs) error {
+ called++
+ return errors.New("whatever")
+ }, isFatalFn, customRetryOptions)
+
+ require.EqualValues(t, 1, called)
+ require.EqualValues(t, "whatever", err.Error())
+ })
+}
+
+func Test_calcDelay(t *testing.T) {
+ t.Run("can't exceed max retry delay", func(t *testing.T) {
+ duration := calcDelay(RetryOptions{
+ RetryDelay: time.Hour,
+ MaxRetryDelay: time.Minute,
+ }, 1)
+
+ require.EqualValues(t, time.Minute, duration)
+ })
+
+ t.Run("increases with jitter", func(t *testing.T) {
+ duration := calcDelay(RetryOptions{
+ RetryDelay: time.Minute,
+ MaxRetryDelay: time.Hour,
+ }, 1)
+
+ require.GreaterOrEqual(t, duration, time.Duration((2-1)*time.Minute.Seconds()*0.8*float64(time.Second)))
+ require.LessOrEqual(t, duration, time.Duration((2-1)*time.Minute.Seconds()*1.3*float64(time.Second)))
+
+ duration = calcDelay(RetryOptions{
+ RetryDelay: time.Minute,
+ MaxRetryDelay: time.Hour,
+ }, 2)
+
+ require.GreaterOrEqual(t, duration, time.Duration((2*2-1)*time.Minute.Seconds()*0.8*float64(time.Second)))
+ require.LessOrEqual(t, duration, time.Duration((2*2-1)*time.Minute.Seconds()*1.3*float64(time.Second)))
+
+ duration = calcDelay(RetryOptions{
+ RetryDelay: time.Minute,
+ MaxRetryDelay: time.Hour,
+ }, 3)
+
+ require.GreaterOrEqual(t, duration, time.Duration((2*2*2-1)*time.Minute.Seconds()*0.8*float64(time.Second)))
+ require.LessOrEqual(t, duration, time.Duration((2*2*2-1)*time.Minute.Seconds()*1.3*float64(time.Second)))
+ })
+}
+
+var fastRetryOptions = RetryOptions{
+ // note: omitting MaxRetries just to give a sanity check that
+ // we do setDefaults() before we run.
+ RetryDelay: time.Millisecond,
+ MaxRetryDelay: time.Millisecond,
+}
+
+func TestRetryDefaults(t *testing.T) {
+ ro := RetryOptions{}
+ setDefaults(&ro)
+
+ require.EqualValues(t, 3, ro.MaxRetries)
+ require.EqualValues(t, 4*time.Second, ro.RetryDelay)
+ require.EqualValues(t, 2*time.Minute, ro.MaxRetryDelay)
+
+ // this is an interesting default. Anything < 0 basically
+ // causes the max delay to be "infinite"
+ ro.MaxRetryDelay = -1
+ // whereas this just normalizes to '0'
+ ro.RetryDelay = -1
+ ro.MaxRetries = -1
+ setDefaults(&ro)
+ require.EqualValues(t, time.Duration(math.MaxInt64), ro.MaxRetryDelay)
+ require.EqualValues(t, 0, ro.MaxRetries)
+ require.EqualValues(t, time.Duration(0), ro.RetryDelay)
+}
+
+func TestCalcDelay(t *testing.T) {
+ // calcDelay introduces some jitter, automatically.
+ ro := RetryOptions{}
+ setDefaults(&ro)
+ d := calcDelay(ro, 0)
+ require.EqualValues(t, 0, d)
+
+ // by default the delay for attempt n is (2^n - 1) * RetryDelay, plus jitter
+ d = calcDelay(ro, 1)
+ require.LessOrEqual(t, d, 6*time.Second)
+ require.GreaterOrEqual(t, d, time.Second)
+}
diff --git a/sdk/messaging/azservicebus/liveTestHelpers_test.go b/sdk/messaging/azservicebus/liveTestHelpers_test.go
index 850fb5afb504..353183ef141f 100644
--- a/sdk/messaging/azservicebus/liveTestHelpers_test.go
+++ b/sdk/messaging/azservicebus/liveTestHelpers_test.go
@@ -45,6 +45,9 @@ func createQueue(t *testing.T, connectionString string, queueProperties *admin.Q
queueProperties = &admin.QueueProperties{}
}
+ autoDeleteOnIdle := 5 * time.Minute
+ queueProperties.AutoDeleteOnIdle = &autoDeleteOnIdle
+
_, err = adminClient.CreateQueue(context.Background(), queueName, queueProperties, nil)
require.NoError(t, err)
diff --git a/sdk/messaging/azservicebus/message.go b/sdk/messaging/azservicebus/message.go
index 44fa4963528f..f71dbc9f14f7 100644
--- a/sdk/messaging/azservicebus/message.go
+++ b/sdk/messaging/azservicebus/message.go
@@ -4,14 +4,14 @@
package azservicebus
import (
- "context"
"errors"
"fmt"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
+ "github.com/Azure/azure-sdk-for-go/sdk/internal/log"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal"
"github.com/Azure/go-amqp"
- "github.com/devigned/tab"
)
// ReceivedMessage is a received message from a Client.NewReceiver().
@@ -193,7 +193,7 @@ func (m *Message) toAMQPMessage() *amqp.Message {
// newReceivedMessage creates a received message from an AMQP message.
// NOTE: this converter assumes that the Body of this message will be the first
// serialized byte array in the Data section of the messsage.
-func newReceivedMessage(ctxForLogging context.Context, amqpMsg *amqp.Message) *ReceivedMessage {
+func newReceivedMessage(amqpMsg *amqp.Message) *ReceivedMessage {
msg := &ReceivedMessage{
rawAMQPMessage: amqpMsg,
}
@@ -302,7 +302,7 @@ func newReceivedMessage(ctxForLogging context.Context, amqpMsg *amqp.Message) *R
if err == nil {
msg.LockToken = *(*amqp.UUID)(lockToken)
} else {
- tab.For(ctxForLogging).Info(fmt.Sprintf("msg.DeliveryTag could not be converted into a UUID: %s", err.Error()))
+ log.Writef(internal.EventReceiver, "msg.DeliveryTag could not be converted into a UUID: %s", err.Error())
}
}
diff --git a/sdk/messaging/azservicebus/messageSettler.go b/sdk/messaging/azservicebus/messageSettler.go
index 18b967bf5e86..cc54c49c1c2b 100644
--- a/sdk/messaging/azservicebus/messageSettler.go
+++ b/sdk/messaging/azservicebus/messageSettler.go
@@ -5,14 +5,12 @@ package azservicebus
import (
"context"
- "errors"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
"github.com/Azure/go-amqp"
)
-var errReceiveAndDeleteReceiver = errors.New("messages that are received in `ReceiveModeReceiveAndDelete` mode are not settleable")
-
type settler interface {
CompleteMessage(ctx context.Context, message *ReceivedMessage) error
AbandonMessage(ctx context.Context, message *ReceivedMessage, options *AbandonMessageOptions) error
@@ -21,15 +19,17 @@ type settler interface {
}
type messageSettler struct {
- links internal.AMQPLinks
+ links internal.AMQPLinks
+ retryOptions utils.RetryOptions
+
+ // used only for tests
onlyDoBackupSettlement bool
- baseRetrier internal.Retrier
}
-func newMessageSettler(links internal.AMQPLinks, baseRetrier internal.Retrier) settler {
+func newMessageSettler(links internal.AMQPLinks, retryOptions utils.RetryOptions) settler {
return &messageSettler{
- links: links,
- baseRetrier: baseRetrier,
+ links: links,
+ retryOptions: retryOptions,
}
}
@@ -39,44 +39,27 @@ func (s *messageSettler) useManagementLink(m *ReceivedMessage, receiver internal
m.rawAMQPMessage.LinkName() != receiver.LinkName()
}
-func (s *messageSettler) settleWithRetries(ctx context.Context, message *ReceivedMessage, settleFn func(receiver internal.AMQPReceiver, mgmt internal.MgmtClient) error) error {
+func (s *messageSettler) settleWithRetries(ctx context.Context, message *ReceivedMessage, settleFn func(receiver internal.AMQPReceiver, rpcLink internal.RPCLink) error) error {
if s == nil {
- return errReceiveAndDeleteReceiver
+ return internal.ErrNonRetriable{Message: "messages that are received in `ReceiveModeReceiveAndDelete` mode are not settleable"}
}
- retrier := s.baseRetrier.Copy()
- var lastErr error
-
- for retrier.Try(ctx) {
- var receiver internal.AMQPReceiver
- var mgmt internal.MgmtClient
- var linkRevision uint64
-
- _, receiver, mgmt, linkRevision, lastErr = s.links.Get(ctx)
-
- if lastErr != nil {
- _ = s.links.RecoverIfNeeded(ctx, linkRevision, lastErr)
- continue
+ err := s.links.Retry(ctx, "settle", func(ctx context.Context, lwid *internal.LinksWithID, args *utils.RetryFnArgs) error {
+ if err := settleFn(lwid.Receiver, lwid.RPC); err != nil {
+ return err
}
- lastErr := settleFn(receiver, mgmt)
-
- if lastErr != nil {
- _ = s.links.RecoverIfNeeded(ctx, linkRevision, lastErr)
- continue
- }
-
- break
- }
+ return nil
+ }, utils.RetryOptions{})
- return lastErr
+ return err
}
// CompleteMessage completes a message, deleting it from the queue or subscription.
func (s *messageSettler) CompleteMessage(ctx context.Context, message *ReceivedMessage) error {
- return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, mgmt internal.MgmtClient) error {
+ return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, rpcLink internal.RPCLink) error {
if s.useManagementLink(message, receiver) {
- return mgmt.SendDisposition(ctx, bytesToAMQPUUID(message.LockToken), internal.Disposition{Status: internal.CompletedDisposition}, nil)
+ return internal.SendDisposition(ctx, rpcLink, bytesToAMQPUUID(message.LockToken), internal.Disposition{Status: internal.CompletedDisposition}, nil)
} else {
return receiver.AcceptMessage(ctx, message.rawAMQPMessage)
}
@@ -92,7 +75,7 @@ type AbandonMessageOptions struct {
// This will increment its delivery count, and potentially cause it to be dead lettered
// depending on your queue or subscription's configuration.
func (s *messageSettler) AbandonMessage(ctx context.Context, message *ReceivedMessage, options *AbandonMessageOptions) error {
- return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, mgmt internal.MgmtClient) error {
+ return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, rpcLink internal.RPCLink) error {
if s.useManagementLink(message, receiver) {
d := internal.Disposition{
Status: internal.AbandonedDisposition,
@@ -104,7 +87,7 @@ func (s *messageSettler) AbandonMessage(ctx context.Context, message *ReceivedMe
propertiesToModify = options.PropertiesToModify
}
- return mgmt.SendDisposition(ctx, bytesToAMQPUUID(message.LockToken), d, propertiesToModify)
+ return internal.SendDisposition(ctx, rpcLink, bytesToAMQPUUID(message.LockToken), d, propertiesToModify)
}
var annotations amqp.Annotations
@@ -125,7 +108,7 @@ type DeferMessageOptions struct {
// DeferMessage will cause a message to be deferred. Deferred messages
// can be received using `Receiver.ReceiveDeferredMessages`.
func (s *messageSettler) DeferMessage(ctx context.Context, message *ReceivedMessage, options *DeferMessageOptions) error {
- return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, mgmt internal.MgmtClient) error {
+ return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, rpcLink internal.RPCLink) error {
if s.useManagementLink(message, receiver) {
d := internal.Disposition{
Status: internal.DeferredDisposition,
@@ -137,7 +120,7 @@ func (s *messageSettler) DeferMessage(ctx context.Context, message *ReceivedMess
propertiesToModify = options.PropertiesToModify
}
- return mgmt.SendDisposition(ctx, bytesToAMQPUUID(message.LockToken), d, propertiesToModify)
+ return internal.SendDisposition(ctx, rpcLink, bytesToAMQPUUID(message.LockToken), d, propertiesToModify)
}
var annotations amqp.Annotations
@@ -167,7 +150,7 @@ type DeadLetterOptions struct {
// queue or subscription. To receive these messages create a receiver with `Client.NewReceiver()`
// using the `SubQueue` option.
func (s *messageSettler) DeadLetterMessage(ctx context.Context, message *ReceivedMessage, options *DeadLetterOptions) error {
- return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, mgmt internal.MgmtClient) error {
+ return s.settleWithRetries(ctx, message, func(receiver internal.AMQPReceiver, rpcLink internal.RPCLink) error {
reason := ""
description := ""
@@ -194,7 +177,7 @@ func (s *messageSettler) DeadLetterMessage(ctx context.Context, message *Receive
propertiesToModify = options.PropertiesToModify
}
- return mgmt.SendDisposition(ctx, bytesToAMQPUUID(message.LockToken), d, propertiesToModify)
+ return internal.SendDisposition(ctx, rpcLink, bytesToAMQPUUID(message.LockToken), d, propertiesToModify)
}
info := map[string]interface{}{
diff --git a/sdk/messaging/azservicebus/messageSettler_test.go b/sdk/messaging/azservicebus/messageSettler_test.go
index a88a577dfece..91b09139c19d 100644
--- a/sdk/messaging/azservicebus/messageSettler_test.go
+++ b/sdk/messaging/azservicebus/messageSettler_test.go
@@ -8,6 +8,7 @@ import (
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal"
"github.com/stretchr/testify/require"
)
@@ -109,10 +110,12 @@ func TestMessageSettlementUsingReceiverWithReceiveAndDelete(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, messages)
- require.EqualError(t, receiver.AbandonMessage(ctx, messages[0], nil), "messages that are received in `ReceiveModeReceiveAndDelete` mode are not settleable")
- require.EqualError(t, receiver.CompleteMessage(ctx, messages[0]), "messages that are received in `ReceiveModeReceiveAndDelete` mode are not settleable")
+ require.EqualValues(t, internal.RecoveryKindFatal, internal.GetSBErrInfo(receiver.AbandonMessage(ctx, messages[0], nil)).RecoveryKind)
+ require.EqualValues(t, internal.RecoveryKindFatal, internal.GetSBErrInfo(receiver.CompleteMessage(ctx, messages[0])).RecoveryKind)
+ require.EqualValues(t, internal.RecoveryKindFatal, internal.GetSBErrInfo(receiver.DeferMessage(ctx, messages[0], nil)).RecoveryKind)
+ require.EqualValues(t, internal.RecoveryKindFatal, internal.GetSBErrInfo(receiver.DeadLetterMessage(ctx, messages[0], nil)).RecoveryKind)
+
require.EqualError(t, receiver.DeadLetterMessage(ctx, messages[0], nil), "messages that are received in `ReceiveModeReceiveAndDelete` mode are not settleable")
- require.EqualError(t, receiver.DeferMessage(ctx, messages[0], nil), "messages that are received in `ReceiveModeReceiveAndDelete` mode are not settleable")
}
func TestDeferredMessages(t *testing.T) {
diff --git a/sdk/messaging/azservicebus/message_test.go b/sdk/messaging/azservicebus/message_test.go
index 6c65c311cc35..823fae6764c2 100644
--- a/sdk/messaging/azservicebus/message_test.go
+++ b/sdk/messaging/azservicebus/message_test.go
@@ -4,7 +4,6 @@
package azservicebus
import (
- "context"
"testing"
"time"
@@ -52,7 +51,7 @@ func TestMessageUnitTest(t *testing.T) {
func TestAMQPMessageToReceivedMessage(t *testing.T) {
t.Run("empty_message", func(t *testing.T) {
// nothing should blow up.
- rm := newReceivedMessage(context.Background(), &amqp.Message{})
+ rm := newReceivedMessage(&amqp.Message{})
require.NotNil(t, rm)
})
@@ -72,7 +71,7 @@ func TestAMQPMessageToReceivedMessage(t *testing.T) {
},
}
- receivedMessage := newReceivedMessage(context.Background(), amqpMessage)
+ receivedMessage := newReceivedMessage(amqpMessage)
require.EqualValues(t, lockedUntil, *receivedMessage.LockedUntil)
require.EqualValues(t, int64(101), *receivedMessage.SequenceNumber)
@@ -132,7 +131,7 @@ func TestAMQPMessageToMessage(t *testing.T) {
Data: [][]byte{[]byte("foo")},
}
- msg := newReceivedMessage(context.Background(), amqpMsg)
+ msg := newReceivedMessage(amqpMsg)
require.EqualValues(t, msg.MessageID, amqpMsg.Properties.MessageID, "messageID")
require.EqualValues(t, msg.SessionID, amqpMsg.Properties.GroupID, "groupID")
diff --git a/sdk/messaging/azservicebus/processor.go b/sdk/messaging/azservicebus/processor.go
deleted file mode 100644
index a3ec0683afa2..000000000000
--- a/sdk/messaging/azservicebus/processor.go
+++ /dev/null
@@ -1,408 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-package azservicebus
-
-import (
- "context"
- "errors"
- "fmt"
- "sync"
- "time"
-
- "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal"
- "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing"
- "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
- "github.com/Azure/go-amqp"
- "github.com/devigned/tab"
-)
-
-// NOTE: this type is experimental
-
-// processorOptions contains options for the `Client.NewProcessorForQueue` or
-// `Client.NewProcessorForSubscription` functions.
-type processorOptions struct {
- // ReceiveMode controls when a message is deleted from Service Bus.
- //
- // `azservicebus.PeekLock` is the default. The message is locked, preventing multiple
- // receivers from processing the message at once. You control the lock state of the message
- // using one of the message settlement functions like processor.CompleteMessage(), which removes
- // it from Service Bus, or processor.AbandonMessage(), which makes it available again.
- //
- // `azservicebus.ReceiveAndDelete` causes Service Bus to remove the message as soon
- // as it's received.
- //
- // More information about receive modes:
- // https://docs.microsoft.com/azure/service-bus-messaging/message-transfers-locks-settlement#settling-receive-operations
- ReceiveMode ReceiveMode
-
- // SubQueue should be set to connect to the sub queue (ex: dead letter queue)
- // of the queue or subscription.
- SubQueue SubQueue
-
- // DisableAutoComplete controls whether messages must be settled explicitly via the
- // settlement methods (ie, Complete, Abandon) or if the processor will automatically
- // settle messages.
- //
- // If true, no automatic settlement is done.
- // If false, the return value of your `handleMessage` function will control if the
- // message is abandoned (non-nil error return) or completed (nil error return).
- //
- // This option is false, by default.
- DisableAutoComplete bool
-
- // MaxConcurrentCalls controls the maximum number of message processing
- // goroutines that are active at any time.
- // Default is 1.
- MaxConcurrentCalls int
-}
-
-// processor is a push-based receiver for Service Bus.
-type processor struct {
- receiveMode ReceiveMode
- autoComplete bool
- maxConcurrentCalls int
-
- settler settler
- amqpLinks internal.AMQPLinks
-
- mu *sync.Mutex
-
- userMessageHandler func(message *ReceivedMessage) error
- userErrorHandler func(err error)
-
- receiversCtx context.Context
- cancelReceivers func()
-
- wg sync.WaitGroup
-
- baseRetrier internal.Retrier
- cleanupOnClose func()
-}
-
-func applyProcessorOptions(p *processor, entity *entity, options *processorOptions) error {
- if options == nil {
- p.maxConcurrentCalls = 1
- p.receiveMode = ReceiveModePeekLock
- p.autoComplete = true
-
- return nil
- }
-
- p.autoComplete = !options.DisableAutoComplete
-
- if err := checkReceiverMode(options.ReceiveMode); err != nil {
- return err
- }
-
- p.receiveMode = options.ReceiveMode
-
- if err := entity.SetSubQueue(options.SubQueue); err != nil {
- return err
- }
-
- if options.MaxConcurrentCalls > 0 {
- p.maxConcurrentCalls = options.MaxConcurrentCalls
- }
-
- return nil
-}
-
-func newProcessor(ns internal.NamespaceWithNewAMQPLinks, entity *entity, cleanupOnClose func(), options *processorOptions) (*processor, error) {
- processor := &processor{
- // TODO: make this configurable
- baseRetrier: internal.NewBackoffRetrier(internal.BackoffRetrierParams{
- Factor: 1.5,
- Min: time.Second,
- Max: time.Minute,
- MaxRetries: 10,
- }),
- cleanupOnClose: cleanupOnClose,
- mu: &sync.Mutex{},
- }
-
- if err := applyProcessorOptions(processor, entity, options); err != nil {
- return nil, err
- }
-
- entityPath, err := entity.String()
-
- if err != nil {
- return nil, err
- }
-
- processor.amqpLinks = ns.NewAMQPLinks(entityPath, func(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) {
- linkOptions := createLinkOptions(processor.receiveMode, entityPath)
- _, receiver, err := createReceiverLink(ctx, session, linkOptions)
-
- if err != nil {
- return nil, nil, err
- }
-
- if err := receiver.IssueCredit(uint32(processor.maxConcurrentCalls)); err != nil {
- _ = receiver.Close(ctx)
- return nil, nil, err
- }
-
- return nil, receiver, nil
- })
-
- processor.settler = newMessageSettler(processor.amqpLinks, processor.baseRetrier)
- processor.receiversCtx, processor.cancelReceivers = context.WithCancel(context.Background())
-
- return processor, nil
-}
-
-// Start will start receiving messages from the queue or subscription.
-//
-// if err := processor.Start(context.TODO(), messageHandler, errorHandler); err != nil {
-// log.Fatalf("Processor failed to start: %s", err.Error())
-// }
-//
-// Any errors that occur (such as network disconnects, failures in handleMessage) will be
-// sent to your handleError function. The processor will retry and restart as needed -
-// no user intervention is required.
-func (p *processor) Start(ctx context.Context, handleMessage func(message *ReceivedMessage) error, handleError func(err error)) error {
- ctx, span := tab.StartSpan(ctx, tracing.SpanProcessorLoop)
- defer span.End()
-
- err := func() error {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if p.userMessageHandler != nil {
- return errors.New("processor already started")
- }
-
- p.userMessageHandler = handleMessage
- p.userErrorHandler = handleError
-
- p.receiversCtx, p.cancelReceivers = context.WithCancel(ctx)
-
- return nil
- }()
-
- if err != nil {
- return err
- }
-
- for {
- retrier := p.baseRetrier.Copy()
-
- for retrier.Try(p.receiversCtx) {
- if err := p.subscribe(); err != nil {
- if internal.IsCancelError(err) {
- break
- }
- }
- }
-
- select {
- case <-p.receiversCtx.Done():
- // check, did they cancel or did we cancel?
- select {
- case <-ctx.Done():
- return ctx.Err()
- default:
- return nil
- }
- default:
- }
- }
-
-}
-
-// Close will wait for any pending callbacks to complete.
-// NOTE: Close() cannot be called synchronously in a message
-// or error handler. You must run it asynchronously using
-// `go processor.Close(ctx)` or similar.
-func (p *processor) Close(ctx context.Context) error {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if p.amqpLinks.ClosedPermanently() {
- return nil
- }
-
- ctx, span := tab.StartSpan(ctx, tracing.SpanProcessorClose)
- defer span.End()
-
- defer func() {
- if err := p.amqpLinks.Close(ctx, true); err != nil {
- span.Logger().Debug(fmt.Sprintf("Error closing amqpLinks on processor.Close(): %s", err.Error()))
- }
- }()
-
- p.cleanupOnClose()
-
- _, receiver, _, _, err := p.amqpLinks.Get(ctx)
-
- if err != nil {
- span.Logger().Error(err)
- return err
- }
-
- if err := receiver.DrainCredit(ctx); err != nil {
- span.Logger().Error(err)
- // fall through for now and just let whatever is going on finish
- // otherwise they might not be able to actually close.
- }
-
- p.cancelReceivers()
- return utils.WaitForGroupOrContext(ctx, &p.wg)
-}
-
-// CompleteMessage completes a message, deleting it from the queue or subscription.
-func (p *processor) CompleteMessage(ctx context.Context, message *ReceivedMessage) error {
- return p.settler.CompleteMessage(ctx, message)
-}
-
-// AbandonMessage will cause a message to be returned to the queue or subscription.
-// This will increment its delivery count, and potentially cause it to be dead lettered
-// depending on your queue or subscription's configuration.
-func (p *processor) AbandonMessage(ctx context.Context, message *ReceivedMessage, options *AbandonMessageOptions) error {
- return p.settler.AbandonMessage(ctx, message, options)
-}
-
-// DeferMessage will cause a message to be deferred. Deferred messages
-// can be received using `Receiver.ReceiveDeferredMessages`.
-func (p *processor) DeferMessage(ctx context.Context, message *ReceivedMessage, options *DeferMessageOptions) error {
- return p.settler.DeferMessage(ctx, message, options)
-}
-
-// DeadLetterMessage settles a message by moving it to the dead letter queue for a
-// queue or subscription. To receive these messages create a processor with `Client.NewProcessorForQueue()`
-// or `Client.NewProcessorForSubscription()` using the `ProcessorOptions.SubQueue` option.
-func (p *processor) DeadLetterMessage(ctx context.Context, message *ReceivedMessage, options *DeadLetterOptions) error {
- return p.settler.DeadLetterMessage(ctx, message, options)
-}
-
-// subscribe continually receives messages from Service Bus, stopping
-// if a fatal link/connection error occurs.
-func (p *processor) subscribe() error {
- p.wg.Add(1)
- defer p.wg.Done()
-
- for {
- _, receiver, _, linkRevision, err := p.amqpLinks.Get(p.receiversCtx)
-
- if err != nil {
- if internal.IsCancelError(err) {
- return err
- }
-
- if err := p.amqpLinks.RecoverIfNeeded(p.receiversCtx, linkRevision, err); err != nil {
- p.userErrorHandler(err)
- return err
- }
- }
-
- amqpMessage, err := receiver.Receive(p.receiversCtx)
-
- if err != nil {
- if internal.IsCancelError(err) {
- return err
- }
-
- if err := p.amqpLinks.RecoverIfNeeded(p.receiversCtx, linkRevision, err); err != nil {
- p.userErrorHandler(err)
- }
-
- return nil
- }
-
- if amqpMessage == nil {
- // amqpMessage shouldn't be nil here, but somehow it is.
- // need to track this down in the AMQP library.
- continue
- }
-
- p.wg.Add(1)
-
- go func() {
- defer p.wg.Done()
-
- // purposefully avoiding using `ctx`. We always let processing complete
- // for message threads to avoid potential message loss.
- _ = p.processMessage(context.Background(), receiver, amqpMessage)
- }()
- }
-}
-
-func (p *processor) processMessage(ctx context.Context, receiver internal.AMQPReceiver, amqpMessage *amqp.Message) error {
- ctx, span := tab.StartSpan(ctx, tracing.SpanProcessorMessage)
- defer span.End()
-
- receivedMessage := newReceivedMessage(ctx, amqpMessage)
- messageHandlerErr := p.userMessageHandler(receivedMessage)
-
- if messageHandlerErr != nil {
- p.userErrorHandler(messageHandlerErr)
- }
-
- if p.autoComplete {
- var settleErr error
-
- if messageHandlerErr != nil {
- settleErr = p.settler.AbandonMessage(ctx, receivedMessage, nil)
- } else {
- settleErr = p.settler.CompleteMessage(ctx, receivedMessage)
- }
-
- if settleErr != nil {
- p.userErrorHandler(fmt.Errorf("failed to settle message with ID '%s': %w", receivedMessage.MessageID, settleErr))
- return settleErr
- }
- }
-
- select {
- case <-p.receiversCtx.Done():
- return nil
- default:
- }
-
- if err := receiver.IssueCredit(1); err != nil {
- if !internal.IsDrainingError(err) {
- p.userErrorHandler(err)
- return fmt.Errorf("failed issuing additional credit, processor will be restarted: %w", err)
- }
- }
-
- return nil
-}
-
-func checkReceiverMode(receiveMode ReceiveMode) error {
- if receiveMode == ReceiveModePeekLock || receiveMode == ReceiveModeReceiveAndDelete {
- return nil
- } else {
- return fmt.Errorf("invalid receive mode %d, must be either azservicebus.PeekLock or azservicebus.ReceiveAndDelete", receiveMode)
- }
-}
-
-// newProcessorForQueue creates a Processor for a queue.
-func newProcessorForQueue(client *Client, queue string, options *processorOptions) (*processor, error) {
- id, cleanupOnClose := client.getCleanupForCloseable()
-
- processor, err := newProcessor(client.namespace, &entity{Queue: queue}, cleanupOnClose, options)
-
- if err != nil {
- return nil, err
- }
-
- client.addCloseable(id, processor)
- return processor, nil
-}
-
-// newProcessorForQueue creates a Processor for a subscription.
-func newProcessorForSubscription(client *Client, topic string, subscription string, options *processorOptions) (*processor, error) {
- id, cleanupOnClose := client.getCleanupForCloseable()
-
- processor, err := newProcessor(client.namespace, &entity{Topic: topic, Subscription: subscription}, cleanupOnClose, options)
-
- if err != nil {
- return nil, err
- }
-
- client.addCloseable(id, processor)
- return processor, nil
-}
diff --git a/sdk/messaging/azservicebus/processor_test.go b/sdk/messaging/azservicebus/processor_test.go
deleted file mode 100644
index 73e85cb4101c..000000000000
--- a/sdk/messaging/azservicebus/processor_test.go
+++ /dev/null
@@ -1,368 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-package azservicebus
-
-import (
- "context"
- "errors"
- "fmt"
- "sort"
- "sync"
- "testing"
- "time"
-
- "github.com/stretchr/testify/require"
-)
-
-func TestProcessorReceiveWithDefaults(t *testing.T) {
- serviceBusClient, cleanup, queueName := setupLiveTest(t, nil)
- defer cleanup()
-
- go func() {
- sender, err := serviceBusClient.NewSender(queueName, nil)
- require.NoError(t, err)
-
- defer sender.Close(context.Background())
-
- // it's perfectly fine to have the processor started before the messages
- // have been sent.
- for i := 0; i < 5; i++ {
- err = sender.SendMessage(context.Background(), &Message{
- Body: []byte(fmt.Sprintf("hello world %d", i)),
- })
-
- time.Sleep(time.Second)
- }
- require.NoError(t, err)
- }()
-
- processor, err := newProcessorForQueue(serviceBusClient, queueName, nil)
- require.NoError(t, err)
-
- defer processor.Close(context.Background()) // multiple close is fine
-
- var messages []string
- mu := sync.Mutex{}
-
- ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
- defer cancel()
-
- err = processor.Start(ctx, func(m *ReceivedMessage) error {
- mu.Lock()
- defer mu.Unlock()
-
- body, err := m.Body()
- require.NoError(t, err)
-
- messages = append(messages, string(body))
-
- if len(messages) == 5 {
- cancel()
- }
-
- return nil
- }, func(err error) {
- if errors.Is(err, context.Canceled) {
- return
- }
-
- require.NoError(t, err)
- })
-
- require.ErrorIs(t, err, context.Canceled, "cancelling the context stops the processor")
-
- sort.Strings(messages)
-
- require.EqualValues(t, []string{
- "hello world 0",
- "hello world 1",
- "hello world 2",
- "hello world 3",
- "hello world 4",
- }, messages)
-
- require.NoError(t, processor.Close(context.Background()))
-}
-
-func TestProcessorReceiveWith100MessagesWithMaxConcurrency(t *testing.T) {
- serviceBusClient, cleanup, queueName := setupLiveTest(t, nil)
- defer cleanup()
-
- const numMessages = 100
- var expectedBodies []string
-
- go func() {
- sender, err := serviceBusClient.NewSender(queueName, nil)
- require.NoError(t, err)
-
- defer sender.Close(context.Background())
-
- batch, err := sender.NewMessageBatch(context.Background(), nil)
- require.NoError(t, err)
-
- // it's perfectly fine to have the processor started before the messages
- // have been sent.
- for i := 0; i < numMessages; i++ {
- expectedBodies = append(expectedBodies, fmt.Sprintf("hello world %03d", i))
- err := batch.AddMessage(&Message{
- Body: []byte(expectedBodies[len(expectedBodies)-1]),
- })
- require.NoError(t, err)
- }
-
- require.NoError(t, sender.SendMessageBatch(context.Background(), batch))
- }()
-
- processor, err := newProcessorForQueue(
- serviceBusClient,
- queueName,
- &processorOptions{
- MaxConcurrentCalls: 20,
- })
-
- require.NoError(t, err)
-
- defer func() {
- require.NoError(t, processor.Close(context.Background())) // multiple close is fine
- }()
-
- var messages []string
- mu := sync.Mutex{}
-
- ctx, cancel := context.WithTimeout(context.Background(), time.Minute*2)
- defer cancel()
-
- err = processor.Start(ctx, func(m *ReceivedMessage) error {
- mu.Lock()
- defer mu.Unlock()
-
- body, err := m.Body()
- require.NoError(t, err)
- messages = append(messages, string(body))
-
- if len(messages) == 100 {
- go processor.Close(context.Background())
- }
-
- return nil
- }, func(err error) {
- if errors.Is(err, context.Canceled) {
- return
- }
-
- require.NoError(t, err)
- })
-
- require.NoError(t, err)
-
- sort.Strings(messages)
- require.EqualValues(t, expectedBodies, messages)
-
- require.NoError(t, processor.Close(ctx))
-}
-
-func TestProcessorUnitTests(t *testing.T) {
- p := &processor{}
- e := &entity{}
-
- require.NoError(t, applyProcessorOptions(p, e, nil))
- require.True(t, p.autoComplete)
- require.EqualValues(t, 1, p.maxConcurrentCalls)
- require.EqualValues(t, ReceiveModePeekLock, p.receiveMode)
-
- p = &processor{}
- e = &entity{
- Queue: "queue",
- }
-
- require.NoError(t, applyProcessorOptions(p, e, &processorOptions{
- ReceiveMode: ReceiveModeReceiveAndDelete,
- SubQueue: SubQueueDeadLetter,
- DisableAutoComplete: true,
- MaxConcurrentCalls: 101,
- }))
-
- require.False(t, p.autoComplete)
- require.EqualValues(t, 101, p.maxConcurrentCalls)
- require.EqualValues(t, ReceiveModeReceiveAndDelete, p.receiveMode)
- fullEntityPath, err := e.String()
- require.NoError(t, err)
- require.EqualValues(t, "queue/$DeadLetterQueue", fullEntityPath)
-}
-
-// func TestProcessorUnitTests(t *testing.T) {
-// t.Run("Processor", func(t *testing.T) {
-// t.Run("StartAndClose", func(t *testing.T) {
-
-// })
-
-// t.Run("CloseWaitsForActiveSubscribersToExit", func(t *testing.T) {
-// })
-
-// t.Run("CloseWithoutStart", func(t *testing.T) {
-
-// })
-
-// t.Run("DoubleClose", func(t *testing.T) {
-
-// })
-// })
-
-// t.Run("subscribe", func(t *testing.T) {
-// t.Run("cancelled by user does not retry", func(t *testing.T) {
-// ctx, cancel := context.WithCancel(context.Background())
-// cancel() // pre-cancel this context
-
-// receiver := internal.NewFakeLegacyReceiver()
-// var cancelledError error
-
-// retry := subscribe(ctx, receiver, true, func(message *ReceivedMessage) error {
-// return nil
-// }, func(err error) {
-// cancelledError = err
-// })
-
-// require.EqualError(t, cancelledError, context.Canceled.Error())
-// require.False(t, retry, "User cancelling the context will not be retried")
-// require.False(t, receiver.CloseCalled) // subscribe() is not responsible for the lifetime of the receiver
-// })
-
-// t.Run("error in the listener is retryable", func(t *testing.T) {
-// receiver := internal.NewFakeLegacyReceiver()
-
-// receiver.ListenImpl = func(ctx context.Context, handler internal.Handler) internal.ListenerHandle {
-// ch := make(chan struct{})
-// close(ch)
-// return &internal.FakeListenerHandle{
-// DoneChan: ch,
-// ErrValue: errors.New("Some AMQP related error"),
-// }
-// }
-
-// var errorFromListener error
-
-// retry := subscribe(context.Background(), receiver, true, func(message *ReceivedMessage) error {
-// return nil
-// }, func(err error) {
-// errorFromListener = err
-// })
-
-// require.EqualError(t, errorFromListener, "Some AMQP related error")
-// require.True(t, retry, "AMQP errors will cause us to retry")
-// require.False(t, receiver.CloseCalled) // subscribe() is not responsible for the lifetime of the receiver
-// })
-
-// })
-
-// t.Run("handleSingleMessage", func(t *testing.T) {
-// fakeMessage := &internal.Message{
-// ID: "fakeID",
-// LockToken: &uuid.UUID{},
-// SystemProperties: &internal.SystemProperties{
-// SequenceNumber: to.Int64Ptr(1),
-// },
-// }
-
-// setup := func() *internal.FakeInternalReceiver {
-// receiver := internal.NewFakeLegacyReceiver()
-// return receiver
-// }
-
-// t.Run("AutoCompleteCompleteMessage", func(t *testing.T) {
-// receiver := setup()
-
-// handleSingleMessage(func(message *ReceivedMessage) error {
-// // successful return
-// return nil
-// }, func(err error) {
-// require.NoError(t, err)
-// }, true, receiver, fakeMessage)
-
-// require.True(t, receiver.CompleteCalled)
-// require.False(t, receiver.AbandonCalled)
-// })
-
-// t.Run("AutoCompleteAbandonMessage", func(t *testing.T) {
-// receiver := setup()
-
-// handleSingleMessage(func(message *ReceivedMessage) error {
-// // error returned will abandon
-// return errors.New("Purposefully reported error")
-// }, func(err error) {
-// require.EqualErrorf(t, err, "Purposefully reported error", "Error from the handler gets forwarded")
-// }, true, receiver, fakeMessage)
-
-// require.True(t, receiver.AbandonCalled)
-// require.False(t, receiver.CompleteCalled)
-// })
-
-// t.Run("AutoCompleteAlreadySettledDoNotSettleTwice)", func(t *testing.T) {
-// receiver := setup()
-
-// handleSingleMessage(func(message *ReceivedMessage) error {
-// // error returned will abandon
-// return errors.New("Purposefully reported error")
-// }, func(err error) {
-// require.EqualErrorf(t, err, "Purposefully reported error", "Error from the handler gets forwarded")
-// }, true, receiver, fakeMessage)
-
-// // TODO: neither should be called - the message was already settled.
-// require.True(t, receiver.AbandonCalled)
-// require.False(t, receiver.CompleteCalled)
-// })
-
-// t.Run("autoComplete (off)", func(t *testing.T) {
-// receiver := setup()
-
-// handleSingleMessage(func(message *ReceivedMessage) error {
-// // successful return
-// return nil
-// }, func(err error) {
-// require.NoError(t, err)
-// }, false, receiver, fakeMessage)
-
-// require.False(t, receiver.CompleteCalled)
-// require.False(t, receiver.AbandonCalled)
-// })
-
-// t.Run("SettlementErrorsAreForwarded(complete)", func(t *testing.T) {
-// receiver := setup()
-
-// receiver.CompleteMessageImpl = func(ctx context.Context, msg *internal.Message) error {
-// return errors.New("Complete failed")
-// }
-
-// var settleError error
-
-// handleSingleMessage(func(message *ReceivedMessage) error {
-// return nil
-// }, func(err error) {
-// settleError = err
-// }, true, receiver, fakeMessage)
-
-// require.EqualError(t, settleError, "Complete failed")
-// })
-
-// t.Run("SettlementErrorsAreForwarded(abandon)", func(t *testing.T) {
-// receiver := setup()
-
-// receiver.AbandonMessageImpl = func(ctx context.Context, msg *internal.Message) error {
-// return errors.New("Abandon failed")
-// }
-
-// var settleErrors []string
-
-// handleSingleMessage(func(message *ReceivedMessage) error {
-// return errors.New("Error that caused the abandon")
-// }, func(err error) {
-// settleErrors = append(settleErrors, err.Error())
-// }, true, receiver, fakeMessage)
-
-// require.EqualValues(t, settleErrors, []string{
-// "Error that caused the abandon",
-// "Abandon failed",
-// })
-// })
-// })
-// }
diff --git a/sdk/messaging/azservicebus/receiver.go b/sdk/messaging/azservicebus/receiver.go
index 55d92bf87d81..a0519406501e 100644
--- a/sdk/messaging/azservicebus/receiver.go
+++ b/sdk/messaging/azservicebus/receiver.go
@@ -10,7 +10,9 @@ import (
"sync"
"time"
+ "github.com/Azure/azure-sdk-for-go/sdk/internal/log"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
"github.com/Azure/go-amqp"
"github.com/devigned/tab"
)
@@ -42,9 +44,10 @@ const (
// Receiver receives messages using pull based functions (ReceiveMessages).
type Receiver struct {
receiveMode ReceiveMode
+ entityPath string
settler settler
- baseRetrier internal.Retrier
+ retryOptions utils.RetryOptions
cleanupOnClose func()
lastPeekedSequenceNumber int64
@@ -74,65 +77,68 @@ type ReceiverOptions struct {
// SubQueue should be set to connect to the sub queue (ex: dead letter queue)
// of the queue or subscription.
SubQueue SubQueue
+
+ retryOptions utils.RetryOptions
}
const defaultLinkRxBuffer = 2048
func applyReceiverOptions(receiver *Receiver, entity *entity, options *ReceiverOptions) error {
+
if options == nil {
receiver.receiveMode = ReceiveModePeekLock
- return nil
- }
+ } else {
+ if err := checkReceiverMode(options.ReceiveMode); err != nil {
+ return err
+ }
- if err := checkReceiverMode(options.ReceiveMode); err != nil {
- return err
+ receiver.receiveMode = options.ReceiveMode
+
+ if err := entity.SetSubQueue(options.SubQueue); err != nil {
+ return err
+ }
+
+ receiver.retryOptions = options.retryOptions
}
- receiver.receiveMode = options.ReceiveMode
+ entityPath, err := entity.String()
- if err := entity.SetSubQueue(options.SubQueue); err != nil {
+ if err != nil {
return err
}
+ receiver.entityPath = entityPath
return nil
}
-func newReceiver(ns internal.NamespaceWithNewAMQPLinks, entity *entity, cleanupOnClose func(), options *ReceiverOptions, newLinksFn func(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error)) (*Receiver, error) {
+type newReceiverArgs struct {
+ ns internal.NamespaceWithNewAMQPLinks
+ entity entity
+ cleanupOnClose func()
+ newLinkFn func(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error)
+}
+
+func newReceiver(args newReceiverArgs, options *ReceiverOptions) (*Receiver, error) {
receiver := &Receiver{
lastPeekedSequenceNumber: 0,
- // TODO: make this configurable
- baseRetrier: internal.NewBackoffRetrier(internal.BackoffRetrierParams{
- Factor: 1.5,
- Jitter: true,
- Min: time.Second,
- Max: time.Minute,
- MaxRetries: 10,
- }),
- cleanupOnClose: cleanupOnClose,
+ cleanupOnClose: args.cleanupOnClose,
}
- if err := applyReceiverOptions(receiver, entity, options); err != nil {
+ if err := applyReceiverOptions(receiver, &args.entity, options); err != nil {
return nil, err
}
- entityPath, err := entity.String()
+ newLinkFn := receiver.newReceiverLink
- if err != nil {
- return nil, err
- }
-
- if newLinksFn == nil {
- newLinksFn = func(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) {
- linkOptions := createLinkOptions(receiver.receiveMode, entityPath)
- return createReceiverLink(ctx, session, linkOptions)
- }
+ if args.newLinkFn != nil {
+ newLinkFn = args.newLinkFn
}
- receiver.amqpLinks = ns.NewAMQPLinks(entityPath, newLinksFn)
+ receiver.amqpLinks = args.ns.NewAMQPLinks(receiver.entityPath, newLinkFn)
// 'nil' settler handles returning an error message for receiveAndDelete links.
if receiver.receiveMode == ReceiveModePeekLock {
- receiver.settler = newMessageSettler(receiver.amqpLinks, receiver.baseRetrier)
+ receiver.settler = newMessageSettler(receiver.amqpLinks, receiver.retryOptions)
} else {
receiver.settler = (*messageSettler)(nil)
}
@@ -140,6 +146,12 @@ func newReceiver(ns internal.NamespaceWithNewAMQPLinks, entity *entity, cleanupO
return receiver, nil
}
+func (r *Receiver) newReceiverLink(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) {
+ linkOptions := createLinkOptions(r.receiveMode, r.entityPath)
+ link, err := createReceiverLink(ctx, session, linkOptions)
+ return nil, link, err
+}
+
// ReceiveMessagesOptions are options for the ReceiveMessages function.
type ReceiveMessagesOptions struct {
// For future expansion
@@ -174,25 +186,27 @@ func (r *Receiver) ReceiveMessages(ctx context.Context, maxMessages int, options
// ReceiveDeferredMessages receives messages that were deferred using `Receiver.DeferMessage`.
func (r *Receiver) ReceiveDeferredMessages(ctx context.Context, sequenceNumbers []int64) ([]*ReceivedMessage, error) {
- _, _, mgmt, _, err := r.amqpLinks.Get(ctx)
+ var receivedMessages []*ReceivedMessage
- if err != nil {
- return nil, err
- }
+ err := r.amqpLinks.Retry(ctx, "receiveDeferredMessage", func(ctx context.Context, lwid *internal.LinksWithID, args *utils.RetryFnArgs) error {
+ amqpMessages, err := internal.ReceiveDeferred(ctx, lwid.RPC, r.receiveMode, sequenceNumbers)
- amqpMessages, err := mgmt.ReceiveDeferred(ctx, r.receiveMode, sequenceNumbers)
+ if err != nil {
+ return err
+ }
- if err != nil {
- return nil, err
- }
+ for _, amqpMsg := range amqpMessages {
+ receivedMsg := newReceivedMessage(amqpMsg)
+ receivedMsg.deferred = true
- var receivedMessages []*ReceivedMessage
+ receivedMessages = append(receivedMessages, receivedMsg)
+ }
- for _, amqpMsg := range amqpMessages {
- receivedMsg := newReceivedMessage(ctx, amqpMsg)
- receivedMsg.deferred = true
+ return nil
+ }, utils.RetryOptions(r.retryOptions))
- receivedMessages = append(receivedMessages, receivedMsg)
+ if err != nil {
+ return nil, err
}
return receivedMessages, nil
@@ -210,35 +224,39 @@ type PeekMessagesOptions struct {
// like CompleteMessage, AbandonMessage, DeferMessage or DeadLetterMessage
// will not work with them.
func (r *Receiver) PeekMessages(ctx context.Context, maxMessageCount int, options *PeekMessagesOptions) ([]*ReceivedMessage, error) {
- _, _, mgmt, _, err := r.amqpLinks.Get(ctx)
+ var receivedMessages []*ReceivedMessage
- if err != nil {
- return nil, err
- }
+ err := r.amqpLinks.Retry(ctx, "peekMessages", func(ctx context.Context, links *internal.LinksWithID, args *utils.RetryFnArgs) error {
+ var sequenceNumber = r.lastPeekedSequenceNumber + 1
+ updateInternalSequenceNumber := true
- var sequenceNumber = r.lastPeekedSequenceNumber + 1
- updateInternalSequenceNumber := true
+ if options != nil && options.FromSequenceNumber != nil {
+ sequenceNumber = *options.FromSequenceNumber
+ updateInternalSequenceNumber = false
+ }
- if options != nil && options.FromSequenceNumber != nil {
- sequenceNumber = *options.FromSequenceNumber
- updateInternalSequenceNumber = false
- }
+ messages, err := internal.PeekMessages(ctx, links.RPC, sequenceNumber, int32(maxMessageCount))
- messages, err := mgmt.PeekMessages(ctx, sequenceNumber, int32(maxMessageCount))
+ if err != nil {
+ return err
+ }
- if err != nil {
- return nil, err
- }
+ receivedMessages = make([]*ReceivedMessage, len(messages))
- receivedMessages := make([]*ReceivedMessage, len(messages))
+ for i := 0; i < len(messages); i++ {
+ receivedMessages[i] = newReceivedMessage(messages[i])
+ }
- for i := 0; i < len(messages); i++ {
- receivedMessages[i] = newReceivedMessage(ctx, messages[i])
- }
+ if len(receivedMessages) > 0 && updateInternalSequenceNumber {
+ // only update this if they're doing the implicit iteration as part of the receiver.
+ r.lastPeekedSequenceNumber = *receivedMessages[len(receivedMessages)-1].SequenceNumber
+ }
+
+ return nil
+ }, r.retryOptions)
- if len(receivedMessages) > 0 && updateInternalSequenceNumber {
- // only update this if they're doing the implicit iteration as part of the receiver.
- r.lastPeekedSequenceNumber = *receivedMessages[len(receivedMessages)-1].SequenceNumber
+ if err != nil {
+ return nil, err
}
return receivedMessages, nil
@@ -246,22 +264,19 @@ func (r *Receiver) PeekMessages(ctx context.Context, maxMessageCount int, option
// RenewLock renews the lock on a message, updating the `LockedUntil` field on `msg`.
func (r *Receiver) RenewMessageLock(ctx context.Context, msg *ReceivedMessage) error {
- _, _, mgmt, _, err := r.amqpLinks.Get(ctx)
+ return r.amqpLinks.Retry(ctx, "renewMessageLock", func(ctx context.Context, linksWithVersion *internal.LinksWithID, args *utils.RetryFnArgs) error {
+ newExpirationTime, err := internal.RenewLocks(ctx, linksWithVersion.RPC, msg.rawAMQPMessage.LinkName(), []amqp.UUID{
+ (amqp.UUID)(msg.LockToken),
+ })
- if err != nil {
- return err
- }
-
- newExpirationTime, err := mgmt.RenewLocks(ctx, msg.rawAMQPMessage.LinkName(), []amqp.UUID{
- (amqp.UUID)(msg.LockToken),
- })
+ if err != nil {
+ return err
+ }
- if err != nil {
- return err
- }
+ msg.LockedUntil = &newExpirationTime[0]
+ return nil
+ }, r.retryOptions)
- msg.LockedUntil = &newExpirationTime[0]
- return nil
}
// Close permanently closes the receiver.
@@ -333,20 +348,10 @@ func (r *Receiver) receiveMessagesImpl(ctx context.Context, maxMessages int, opt
// user isn't actually waiting for anymore. So we make sure that #3 runs if the
// link is still valid.
// Phase 3.
- _, receiver, _, linksRevision, err := r.amqpLinks.Get(ctx)
- if err != nil {
- if err := r.amqpLinks.RecoverIfNeeded(ctx, linksRevision, err); err != nil {
- return nil, err
- }
-
- return nil, err
- }
-
- if err := receiver.IssueCredit(uint32(maxMessages)); err != nil {
- _ = r.amqpLinks.RecoverIfNeeded(ctx, linksRevision, err)
- return nil, err
- }
+ var all []*ReceivedMessage
+ var cancel context.CancelFunc
+ needsDrain := true
defaultTimeAfterFirstMessage := 20 * time.Millisecond
@@ -354,97 +359,103 @@ func (r *Receiver) receiveMessagesImpl(ctx context.Context, maxMessages int, opt
defaultTimeAfterFirstMessage = time.Second
}
- messages, err := r.getMessages(ctx, receiver, maxMessages, defaultTimeAfterFirstMessage)
+ var linksWithID *internal.LinksWithID
+
+ err := r.amqpLinks.Retry(ctx, "receiveMessages", func(ctx context.Context, lwid *internal.LinksWithID, args *utils.RetryFnArgs) error {
+ if args.LastErr != nil {
+ // we recovered successfully (amqpLinks does it for us), so we can reset our retry attempts.
+ // This fixes a potential problem where something like this happens:
+ // a. amqplink.Receive() returns an error
+ // b. we attempt recovery a few times and recover just in the nick of time at attempt #3.
+ // c. amqplink.Receive() blocks (no messages in queue, for instance)
+ //
+ // d. amqpLink.Receive() returns an error
+ //
+ // We don't want the # of attempts to already be '3' when we get to
+ // step d, so we reset here so the next set of retry attempts starts fresh.
+ args.ResetAttempts()
+ }
- if err != nil {
- return nil, err
- }
+ linksWithID = lwid
+ credits := maxMessages - len(all)
- if len(messages) == maxMessages {
- // no drain needed, all messages arrived.
- return messages, nil
- }
+ if err := lwid.Receiver.IssueCredit(uint32(credits)); err != nil {
+ return err
+ }
- return r.drainLink(ctx, receiver, messages)
-}
+ got := 0
-// drainLink initiates a drainLink on the link. Service Bus will send whatever messages it might have still had and
-// set our link credit to 0.
-// ctxForLoggingOnly is literally only used for when we need to extract context for logging. This function will always attempt
-// to complete, ignoring cancellation, otherwise we can leave the link with messages that haven't been returned to the user.
-func (r *Receiver) drainLink(ctxForLoggingOnly context.Context, receiver internal.AMQPReceiver, messages []*ReceivedMessage) ([]*ReceivedMessage, error) {
- // start the drain asynchronously. Note that we ignore the user's context at this point
- // since draining makes sure we don't get messages when nobody is receiving.
- if err := receiver.DrainCredit(context.Background()); err != nil {
- tab.For(ctxForLoggingOnly).Debug(fmt.Sprintf("Draining of credit failed. link will be closed and will re-open on next receive: %s", err.Error()))
-
- // if the drain fails we just close the link so it'll re-open at the next receive.
- if err := r.amqpLinks.Close(context.Background(), false); err != nil {
- tab.For(ctxForLoggingOnly).Debug(fmt.Sprintf("Failed to close links on ReceiveMessages cleanup. Not fatal: %s", err.Error()))
- }
- }
+ for {
+ amqpMessage, err := lwid.Receiver.Receive(ctx)
- // Draining data from the receiver's prefetched queue. This won't wait for new messages to
- // arrive, so it'll only receive messages that arrived prior to the drain.
- for {
- am, err := receiver.Prefetched(context.Background())
+ if err != nil {
+ return err
+ }
- if am == nil || internal.IsCancelError(err) {
- break
- }
+ all = append(all, newReceivedMessage(amqpMessage))
+ got++
- if err != nil {
- // something fatal happened, we will just
- _ = r.amqpLinks.Close(context.TODO(), false)
+ if got == credits {
+ // no excess credits on link
+ needsDrain = false
+ break
+ }
- if len(messages) > 0 {
- return messages, nil
- } else {
- return nil, err
+ if cancel == nil {
+ // replace the context that we're using for everything with a new one that will cancel
+ // after a period of time.
+ ctx, cancel = context.WithTimeout(ctx, defaultTimeAfterFirstMessage)
+ defer cancel()
}
}
- messages = append(messages, newReceivedMessage(ctxForLoggingOnly, am))
+ return nil
+ }, utils.RetryOptions(r.retryOptions))
+
+ ret := func(err error) ([]*ReceivedMessage, error) {
+ if len(all) > 0 {
+ // we don't return the error here because we did retrieve _some_ messages and you can still
+ // use them.
+ return all, nil
+ } else {
+ return nil, err
+ }
}
- return messages, nil
-}
-
-// getMessages receives messages until a link failure, timeout or the user
-// cancels their context.
-func (r *Receiver) getMessages(ctx context.Context, receiver internal.AMQPReceiver, maxMessages int, maxWaitTimeAfterFirstMessage time.Duration) ([]*ReceivedMessage, error) {
- var messages []*ReceivedMessage
-
- for {
- var amqpMessage *amqp.Message
- amqpMessage, err := receiver.Receive(ctx)
+ if err != nil && !internal.IsCancelError(err) {
+ return ret(err)
+ }
- if err != nil {
- if internal.IsCancelError(err) {
- return messages, nil
+ if needsDrain {
+ // start the drain asynchronously. Note that we ignore the user's context at this point
+ // since draining makes sure we don't get messages when nobody is receiving.
+ if err := linksWithID.Receiver.DrainCredit(context.Background()); err != nil {
+ if err := r.amqpLinks.RecoverIfNeeded(context.Background(), linksWithID.ID, err); err != nil {
+ log.Writef(internal.EventReceiver, "Failed to recover links after a failed drain: %s", err.Error())
+ return ret(err)
}
- // we'll close (instead of recovering) since we are holding onto messages
- // and want to get them back to the user ASAP. (recovery will just happen
- // on the next call to receive)
- if err := r.amqpLinks.Close(context.Background(), false); err != nil {
- tab.For(ctx).Debug(fmt.Sprintf("Failed to close links on ReceiveMessages cleanup. Not fatal: %s", err.Error()))
- }
- return nil, err
+ return ret(err)
}
- messages = append(messages, newReceivedMessage(ctx, amqpMessage))
+ // Draining data from the receiver's prefetched queue. This won't wait for new messages to
+ // arrive, so it'll only receive messages that arrived prior to the drain.
+ for {
+ am, err := linksWithID.Receiver.Prefetched(context.Background())
- if len(messages) == maxMessages {
- return messages, nil
- }
+ if am == nil || internal.IsCancelError(err) {
+ return all, nil
+ }
- if len(messages) == 1 {
- var cancel context.CancelFunc
- ctx, cancel = context.WithTimeout(ctx, maxWaitTimeAfterFirstMessage)
- defer cancel()
+ if err != nil {
+ return ret(err)
+ }
+
+ all = append(all, newReceivedMessage(am))
}
}
+
+ return all, nil
}
type entity struct {
@@ -485,15 +496,15 @@ func (e *entity) SetSubQueue(subQueue SubQueue) error {
return fmt.Errorf("unknown SubQueue %d", subQueue)
}
-func createReceiverLink(ctx context.Context, session internal.AMQPSession, linkOptions []amqp.LinkOption) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) {
+func createReceiverLink(ctx context.Context, session internal.AMQPSession, linkOptions []amqp.LinkOption) (internal.AMQPReceiverCloser, error) {
amqpReceiver, err := session.NewReceiver(linkOptions...)
if err != nil {
tab.For(ctx).Error(err)
- return nil, nil, err
+ return nil, err
}
- return nil, amqpReceiver, nil
+ return amqpReceiver, nil
}
func createLinkOptions(mode ReceiveMode, entityPath string) []amqp.LinkOption {
@@ -516,3 +527,11 @@ func createLinkOptions(mode ReceiveMode, entityPath string) []amqp.LinkOption {
return opts
}
+
+func checkReceiverMode(receiveMode ReceiveMode) error {
+ if receiveMode == ReceiveModePeekLock || receiveMode == ReceiveModeReceiveAndDelete {
+ return nil
+ }
+
+ return fmt.Errorf("invalid receive mode %d, must be either azservicebus.PeekLock or azservicebus.ReceiveAndDelete", receiveMode)
+}
diff --git a/sdk/messaging/azservicebus/receiver_test.go b/sdk/messaging/azservicebus/receiver_test.go
index c0363442f460..6cbd067462a2 100644
--- a/sdk/messaging/azservicebus/receiver_test.go
+++ b/sdk/messaging/azservicebus/receiver_test.go
@@ -11,7 +11,10 @@ import (
"testing"
"time"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/admin"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/test"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
"github.com/Azure/go-amqp"
"github.com/stretchr/testify/require"
)
@@ -343,6 +346,62 @@ func TestReceiverPeek(t *testing.T) {
require.Empty(t, noMessagesExpected)
}
+func TestReceiverDetach(t *testing.T) {
+ // NOTE: uncomment this to see some of the background reconnects
+ // azlog.SetListener(func(e azlog.Event, s string) {
+ // log.Printf("%s %s", e, s)
+ // })
+
+ serviceBusClient, cleanup, queueName := setupLiveTest(t, nil)
+ defer cleanup()
+
+ adminClient, err := admin.NewClientFromConnectionString(test.GetConnectionString(t), nil)
+ require.NoError(t, err)
+
+ receiver, err := serviceBusClient.NewReceiverForQueue(queueName, nil)
+ require.NoError(t, err)
+
+ // make sure the receiver link and connection are live.
+ _, err = receiver.PeekMessages(context.Background(), 1, nil)
+ require.NoError(t, err)
+
+ sender, err := serviceBusClient.NewSender(queueName, nil)
+ require.NoError(t, err)
+
+ err = sender.SendMessage(context.Background(), &Message{
+ Body: []byte("hello world"),
+ })
+ require.NoError(t, err)
+ require.NoError(t, sender.Close(context.Background()))
+
+ // force a detach to happen
+ _, err = adminClient.UpdateQueue(context.Background(), queueName, admin.QueueProperties{}, nil)
+ require.NoError(t, err)
+
+ messages, err := receiver.ReceiveMessages(context.Background(), 1, nil)
+ require.NoError(t, err)
+ require.EqualValues(t, []string{"hello world"}, getSortedBodies(messages))
+
+ // force a detach to happen
+ _, err = adminClient.UpdateQueue(context.Background(), queueName, admin.QueueProperties{}, nil)
+ require.NoError(t, err)
+
+ peekedMessages, err := receiver.PeekMessages(context.Background(), 1, nil)
+ require.NoError(t, err)
+ require.EqualValues(t, []string{"hello world"}, getSortedBodies(peekedMessages))
+
+ // force a detach to happen
+ _, err = adminClient.UpdateQueue(context.Background(), queueName, admin.QueueProperties{}, nil)
+ require.NoError(t, err)
+
+ require.NoError(t, receiver.CompleteMessage(context.Background(), messages[0]))
+
+ // and last, check that the queue is properly empty
+ peekedMessages, err = receiver.PeekMessages(context.Background(), 1, nil)
+ require.NoError(t, err)
+ require.Empty(t, peekedMessages)
+}
+
func TestReceiver_RenewMessageLock(t *testing.T) {
client, cleanup, queueName := setupLiveTest(t, nil)
defer cleanup()
@@ -401,19 +460,23 @@ func TestReceiverOptions(t *testing.T) {
require.NoError(t, applyReceiverOptions(receiver, e, &ReceiverOptions{
ReceiveMode: ReceiveModeReceiveAndDelete,
SubQueue: SubQueueTransfer,
+ retryOptions: utils.RetryOptions{
+ MaxRetries: 101,
+ },
}))
require.EqualValues(t, ReceiveModeReceiveAndDelete, receiver.receiveMode)
path, err = e.String()
require.NoError(t, err)
require.EqualValues(t, "topic/Subscriptions/subscription/$Transfer/$DeadLetterQueue", path)
+ require.EqualValues(t, 101, receiver.retryOptions.MaxRetries)
}
-type badMgmtClient struct {
- internal.MgmtClient
+type badRPCLink struct {
+ internal.RPCLink
}
-func (b badMgmtClient) ReceiveDeferred(ctx context.Context, mode ReceiveMode, sequenceNumbers []int64) ([]*amqp.Message, error) {
+func (br *badRPCLink) RPC(ctx context.Context, msg *amqp.Message) (*internal.RPCResponse, error) {
return nil, errors.New("receive deferred messages failed")
}
@@ -430,7 +493,7 @@ func TestReceiverDeferUnitTests(t *testing.T) {
r = &Receiver{
amqpLinks: &internal.FakeAMQPLinks{
- Mgmt: &badMgmtClient{},
+ RPC: &badRPCLink{},
},
}
diff --git a/sdk/messaging/azservicebus/sender.go b/sdk/messaging/azservicebus/sender.go
index b8613e4c906c..e6457fc685d1 100644
--- a/sdk/messaging/azservicebus/sender.go
+++ b/sdk/messaging/azservicebus/sender.go
@@ -9,6 +9,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
"github.com/Azure/go-amqp"
"github.com/devigned/tab"
)
@@ -19,7 +20,7 @@ type (
queueOrTopic string
cleanupOnClose func()
links internal.AMQPLinks
- retryOptions internal.RetryOptions
+ retryOptions utils.RetryOptions
}
)
@@ -40,25 +41,10 @@ type MessageBatchOptions struct {
// messages. Sending a batch of messages is more efficient than sending the
// messages one at a time.
func (s *Sender) NewMessageBatch(ctx context.Context, options *MessageBatchOptions) (*MessageBatch, error) {
- var lastRevision uint64
var batch *MessageBatch
- err := internal.Retry(ctx, "send", func(ctx context.Context, args *internal.RetryFnArgs) error {
- if args.LastErr != nil {
- if err := s.links.RecoverIfNeeded(ctx, lastRevision, args.LastErr); err != nil {
- return err
- }
- }
-
- sender, _, _, lr, err := s.links.Get(ctx)
-
- if err != nil {
- return err
- }
-
- lastRevision = lr
-
- maxBytes := sender.MaxMessageSize()
+ err := s.links.Retry(ctx, "send", func(ctx context.Context, lwid *internal.LinksWithID, args *utils.RetryFnArgs) error {
+ maxBytes := lwid.Sender.MaxMessageSize()
if options != nil && options.MaxBytes != 0 {
maxBytes = options.MaxBytes
@@ -66,7 +52,7 @@ func (s *Sender) NewMessageBatch(ctx context.Context, options *MessageBatchOptio
batch = newMessageBatch(maxBytes)
return nil
- }, nil, s.retryOptions)
+ }, s.retryOptions)
if err != nil {
return nil, err
@@ -77,55 +63,23 @@ func (s *Sender) NewMessageBatch(ctx context.Context, options *MessageBatchOptio
// SendMessage sends a Message to a queue or topic.
func (s *Sender) SendMessage(ctx context.Context, message *Message) error {
- ctx, span := s.startProducerSpanFromContext(ctx, spanNameSendMessage)
- defer span.End()
-
- var lastRevision uint64
-
- return internal.Retry(ctx, "send", func(ctx context.Context, args *internal.RetryFnArgs) error {
- if args.LastErr != nil {
- if err := s.links.RecoverIfNeeded(ctx, lastRevision, args.LastErr); err != nil {
- return err
- }
- }
-
- sender, _, _, lr, err := s.links.Get(ctx)
-
- if err != nil {
- return err
- }
-
- lastRevision = lr
+ return s.links.Retry(ctx, "SendMessage", func(ctx context.Context, lwid *internal.LinksWithID, args *utils.RetryFnArgs) error {
+ ctx, span := s.startProducerSpanFromContext(ctx, spanNameSendMessage)
+ defer span.End()
- return sender.Send(ctx, message.toAMQPMessage())
- }, nil, s.retryOptions)
+ return lwid.Sender.Send(ctx, message.toAMQPMessage())
+ }, utils.RetryOptions(s.retryOptions))
}
// SendMessageBatch sends a MessageBatch to a queue or topic.
// Message batches can be created using `Sender.NewMessageBatch`.
func (s *Sender) SendMessageBatch(ctx context.Context, batch *MessageBatch) error {
- ctx, span := s.startProducerSpanFromContext(ctx, spanNameSendBatch)
- defer span.End()
-
- var lastRevision uint64
-
- return internal.Retry(ctx, "send", func(ctx context.Context, args *internal.RetryFnArgs) error {
- if args.LastErr != nil {
- if err := s.links.RecoverIfNeeded(ctx, lastRevision, args.LastErr); err != nil {
- return err
- }
- }
+ return s.links.Retry(ctx, "SendMessageBatch", func(ctx context.Context, lwid *internal.LinksWithID, args *utils.RetryFnArgs) error {
+ ctx, span := s.startProducerSpanFromContext(ctx, spanNameSendBatch)
+ defer span.End()
- sender, _, _, lr, err := s.links.Get(ctx)
-
- if err != nil {
- return err
- }
-
- lastRevision = lr
-
- return sender.Send(ctx, batch.toAMQPMessage())
- }, nil, s.retryOptions)
+ return lwid.Sender.Send(ctx, batch.toAMQPMessage())
+ }, utils.RetryOptions(s.retryOptions))
}
// ScheduleMessages schedules a slice of Messages to appear on Service Bus Queue/Subscription at a later time.
@@ -144,14 +98,10 @@ func (s *Sender) ScheduleMessages(ctx context.Context, messages []*Message, sche
// MessageBatch changes
// CancelScheduledMessages cancels multiple messages that were scheduled.
-func (s *Sender) CancelScheduledMessages(ctx context.Context, sequenceNumber []int64) error {
- _, _, mgmt, _, err := s.links.Get(ctx)
-
- if err != nil {
- return err
- }
-
- return mgmt.CancelScheduled(ctx, sequenceNumber...)
+func (s *Sender) CancelScheduledMessages(ctx context.Context, sequenceNumbers []int64) error {
+ return s.links.Retry(ctx, "cancelScheduledMessage", func(ctx context.Context, lwv *internal.LinksWithID, args *utils.RetryFnArgs) error {
+ return internal.CancelScheduledMessages(ctx, lwv.RPC, sequenceNumbers)
+ }, s.retryOptions)
}
// Close permanently closes the Sender.
@@ -161,13 +111,19 @@ func (s *Sender) Close(ctx context.Context) error {
}
func (s *Sender) scheduleAMQPMessages(ctx context.Context, messages []*amqp.Message, scheduledEnqueueTime time.Time) ([]int64, error) {
- _, _, mgmt, _, err := s.links.Get(ctx)
+ var sequenceNumbers []int64
- if err != nil {
- return nil, err
- }
+ err := s.links.Retry(ctx, "scheduleMessages", func(ctx context.Context, lwv *internal.LinksWithID, args *utils.RetryFnArgs) error {
+ sn, err := internal.ScheduleMessages(ctx, lwv.RPC, scheduledEnqueueTime, messages)
+
+ if err != nil {
+ return err
+ }
+ sequenceNumbers = sn
+ return nil
+ }, s.retryOptions)
- return mgmt.ScheduleMessages(ctx, scheduledEnqueueTime, messages...)
+ return sequenceNumbers, err
}
func (sender *Sender) createSenderLink(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) {
@@ -184,14 +140,20 @@ func (sender *Sender) createSenderLink(ctx context.Context, session internal.AMQ
return amqpSender, nil, nil
}
-func newSender(ns internal.NamespaceWithNewAMQPLinks, queueOrTopic string, cleanupOnClose func()) (*Sender, error) {
+type newSenderArgs struct {
+ ns internal.NamespaceWithNewAMQPLinks
+ queueOrTopic string
+ cleanupOnClose func()
+}
+
+func newSender(args newSenderArgs, retryOptions RetryOptions) (*Sender, error) {
sender := &Sender{
- queueOrTopic: queueOrTopic,
- cleanupOnClose: cleanupOnClose,
- retryOptions: internal.RetryOptions{},
+ queueOrTopic: args.queueOrTopic,
+ cleanupOnClose: args.cleanupOnClose,
+ retryOptions: utils.RetryOptions(retryOptions),
}
- sender.links = ns.NewAMQPLinks(queueOrTopic, sender.createSenderLink)
+ sender.links = args.ns.NewAMQPLinks(args.queueOrTopic, sender.createSenderLink)
return sender, nil
}
diff --git a/sdk/messaging/azservicebus/sender_test.go b/sdk/messaging/azservicebus/sender_test.go
index 585ef801ee23..801cb52d3a90 100644
--- a/sdk/messaging/azservicebus/sender_test.go
+++ b/sdk/messaging/azservicebus/sender_test.go
@@ -5,12 +5,14 @@ package azservicebus
import (
"context"
+ "fmt"
"sort"
"testing"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/admin"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/test"
"github.com/stretchr/testify/require"
)
@@ -276,6 +278,118 @@ func Test_Sender_ScheduleMessages(t *testing.T) {
}
}
+func TestSender_SendMessagesDetach(t *testing.T) {
+ // NOTE: uncomment this to see some of the background reconnects
+ // azlog.SetListener(func(e azlog.Event, s string) {
+ // log.Printf("%s %s", e, s)
+ // })
+
+ sbc, cleanup, queueName := setupLiveTest(t, nil)
+ defer cleanup()
+
+ adminClient, err := admin.NewClientFromConnectionString(test.GetConnectionString(t), nil)
+ require.NoError(t, err)
+
+ sender, err := sbc.NewSender(queueName, nil)
+ require.NoError(t, err)
+
+ // make sure the sender link is open and active.
+ err = sender.SendMessage(context.Background(), &Message{
+ Body: []byte("0"),
+ })
+ require.NoError(t, err)
+
+ // now force a detach to happen
+ _, err = adminClient.UpdateQueue(context.Background(), queueName, admin.QueueProperties{}, nil)
+ require.NoError(t, err)
+
+ for i := 1; i < 5; i++ {
+ err = sender.SendMessage(context.Background(), &Message{
+ Body: []byte(fmt.Sprintf("%d", i)),
+ })
+ require.NoError(t, err)
+ }
+
+ receiver, err := sbc.NewReceiverForQueue(queueName, &ReceiverOptions{
+ ReceiveMode: ReceiveModeReceiveAndDelete,
+ })
+ require.NoError(t, err)
+
+ // get all the messages
+ var all []*ReceivedMessage
+
+ for {
+ messages, err := receiver.ReceiveMessages(context.Background(), 5, nil)
+ require.NoError(t, err)
+
+ all = append(messages, all...)
+
+ if len(all) == 5 {
+ break
+ }
+ }
+
+ require.EqualValues(t, []string{"0", "1", "2", "3", "4"}, getSortedBodies(all))
+}
+
+func TestSender_SendMessageBatchDetach(t *testing.T) {
+ // NOTE: uncomment this to see some of the background reconnects
+ // azlog.SetListener(func(e azlog.Event, s string) {
+ // log.Printf("%s %s", e, s)
+ // })
+
+ sbc, cleanup, queueName := setupLiveTest(t, nil)
+ defer cleanup()
+
+ adminClient, err := admin.NewClientFromConnectionString(test.GetConnectionString(t), nil)
+ require.NoError(t, err)
+
+ sender, err := sbc.NewSender(queueName, nil)
+ require.NoError(t, err)
+
+ // make sure the sender link is open and active.
+ err = sender.SendMessage(context.Background(), &Message{
+ Body: []byte("0"),
+ })
+ require.NoError(t, err)
+
+ // now force a detach to happen
+ _, err = adminClient.UpdateQueue(context.Background(), queueName, admin.QueueProperties{}, nil)
+ require.NoError(t, err)
+
+ for i := 1; i < 5; i++ {
+ batch, err := sender.NewMessageBatch(context.Background(), nil)
+ require.NoError(t, err)
+ require.NoError(t, batch.AddMessage(&Message{
+ Body: []byte(fmt.Sprintf("%d", i)),
+ }))
+
+ err = sender.SendMessageBatch(context.Background(), batch)
+ require.NoError(t, err)
+ }
+
+ receiver, err := sbc.NewReceiverForQueue(queueName, &ReceiverOptions{
+ ReceiveMode: ReceiveModeReceiveAndDelete,
+ })
+ require.NoError(t, err)
+
+ // get all the messages
+ var all []*ReceivedMessage
+
+ for {
+ messages, err := receiver.ReceiveMessages(context.Background(), 5, nil)
+ require.NoError(t, err)
+
+ all = append(messages, all...)
+
+ if len(all) == 5 {
+ break
+ }
+ }
+
+ require.EqualValues(t, []string{"0", "1", "2", "3", "4"}, getSortedBodies(all))
+}
+
func getSortedBodies(messages []*ReceivedMessage) []string {
sort.Sort(receivedMessages(messages))
diff --git a/sdk/messaging/azservicebus/session_receiver.go b/sdk/messaging/azservicebus/session_receiver.go
index 95517483734d..9a7af6c11ba2 100644
--- a/sdk/messaging/azservicebus/session_receiver.go
+++ b/sdk/messaging/azservicebus/session_receiver.go
@@ -9,6 +9,7 @@ import (
"time"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
"github.com/Azure/go-amqp"
)
@@ -47,57 +48,64 @@ func toReceiverOptions(sropts *SessionReceiverOptions) *ReceiverOptions {
}
}
-func newSessionReceiver(ctx context.Context, sessionID *string, ns internal.NamespaceWithNewAMQPLinks, entity *entity, cleanupOnClose func(), options *ReceiverOptions) (*SessionReceiver, error) {
- const sessionFilterName = "com.microsoft:session-filter"
- const code = uint64(0x00000137000000C)
-
+func newSessionReceiver(ctx context.Context, sessionID *string, ns internal.NamespaceWithNewAMQPLinks, entity entity, cleanupOnClose func(), options *ReceiverOptions) (*SessionReceiver, error) {
sessionReceiver := &SessionReceiver{
sessionID: sessionID,
lockedUntil: time.Time{},
}
- var err error
+ r, err := newReceiver(newReceiverArgs{
+ ns: ns,
+ entity: entity,
+ cleanupOnClose: cleanupOnClose,
+ newLinkFn: sessionReceiver.newLink,
+ }, options)
- sessionReceiver.inner, err = newReceiver(ns, entity, cleanupOnClose, options, func(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) {
- linkOptions := createLinkOptions(sessionReceiver.inner.receiveMode, sessionReceiver.inner.amqpLinks.EntityPath())
+ if err != nil {
+ return nil, err
+ }
- if sessionID == nil {
- linkOptions = append(linkOptions, amqp.LinkSourceFilter(sessionFilterName, code, nil))
- } else {
- linkOptions = append(linkOptions, amqp.LinkSourceFilter(sessionFilterName, code, sessionID))
- }
+ sessionReceiver.inner = r
- _, link, err := createReceiverLink(ctx, session, linkOptions)
+ // temp workaround until we expose the session expiration time from the receiver in go-amqp
+ if err := sessionReceiver.RenewSessionLock(ctx); err != nil {
+ _ = sessionReceiver.Close(context.Background())
+ return nil, err
+ }
- if err != nil {
- return nil, nil, err
- }
+ return sessionReceiver, nil
+}
- // check the session ID that came back - if we asked for a named session ID and didn't get it then
- // we failed to get the lock.
- // if we specified nil then we can _set_ our internally held session ID now that we know the value.
- receivedSessionID := link.LinkSourceFilterValue(sessionFilterName)
- asStr, ok := receivedSessionID.(string)
+func (r *SessionReceiver) newLink(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) {
+ const sessionFilterName = "com.microsoft:session-filter"
+ const code = uint64(0x00000137000000C)
- if !ok || (sessionID != nil && asStr != *sessionID) {
- return nil, nil, fmt.Errorf("invalid type/value for returned sessionID(type:%T, value:%v)", receivedSessionID, receivedSessionID)
- }
+ linkOptions := createLinkOptions(r.inner.receiveMode, r.inner.amqpLinks.EntityPath())
- sessionReceiver.sessionID = &asStr
- return nil, link, nil
- })
+ if r.sessionID == nil {
+ linkOptions = append(linkOptions, amqp.LinkSourceFilter(sessionFilterName, code, nil))
+ } else {
+ linkOptions = append(linkOptions, amqp.LinkSourceFilter(sessionFilterName, code, r.sessionID))
+ }
+
+ link, err := createReceiverLink(ctx, session, linkOptions)
if err != nil {
- return nil, err
+ return nil, nil, err
}
- // temp workaround until we expose the session expiration time from the receiver in go-amqp
- if err := sessionReceiver.RenewSessionLock(ctx); err != nil {
- _ = sessionReceiver.Close(context.Background())
- return nil, err
+ // check the session ID that came back - if we asked for a named session ID and didn't get it then
+ // we failed to get the lock.
+ // if we specified nil then we can _set_ our internally held session ID now that we know the value.
+ receivedSessionID := link.LinkSourceFilterValue(sessionFilterName)
+ receivedSessionIDStr, ok := receivedSessionID.(string)
+
+ if !ok || (r.sessionID != nil && receivedSessionIDStr != *r.sessionID) {
+ return nil, nil, fmt.Errorf("invalid type/value for returned sessionID(type:%T, value:%v)", receivedSessionID, receivedSessionID)
}
- return sessionReceiver, nil
+ r.sessionID = &receivedSessionIDStr
+ return nil, link, nil
}
// ReceiveMessages receives a fixed number of messages, up to numMessages.
@@ -172,48 +180,47 @@ func (sr *SessionReceiver) LockedUntil() time.Time {
// GetSessionState retrieves state associated with the session.
func (sr *SessionReceiver) GetSessionState(ctx context.Context) ([]byte, error) {
- _, _, mgmt, _, err := sr.inner.amqpLinks.Get(ctx)
+ var sessionState []byte
- if err != nil {
- return nil, err
- }
+ err := sr.inner.amqpLinks.Retry(ctx, "GetSessionState", func(ctx context.Context, lwv *internal.LinksWithID, args *utils.RetryFnArgs) error {
+ s, err := internal.GetSessionState(ctx, lwv.RPC, sr.SessionID())
- return mgmt.GetSessionState(ctx, sr.SessionID())
+ if err != nil {
+ return err
+ }
+
+ sessionState = s
+ return nil
+ }, sr.inner.retryOptions)
+
+ return sessionState, err
}
// SetSessionState sets the state associated with the session.
func (sr *SessionReceiver) SetSessionState(ctx context.Context, state []byte) error {
- _, _, mgmt, _, err := sr.inner.amqpLinks.Get(ctx)
-
- if err != nil {
- return err
- }
-
- return mgmt.SetSessionState(ctx, sr.SessionID(), state)
+ return sr.inner.amqpLinks.Retry(ctx, "SetSessionState", func(ctx context.Context, lwv *internal.LinksWithID, args *utils.RetryFnArgs) error {
+ return internal.SetSessionState(ctx, lwv.RPC, sr.SessionID(), state)
+ }, sr.inner.retryOptions)
}
// RenewSessionLock renews this session's lock. The new expiration time is available
// using `LockedUntil`.
func (sr *SessionReceiver) RenewSessionLock(ctx context.Context) error {
- _, _, mgmt, _, err := sr.inner.amqpLinks.Get(ctx)
-
- if err != nil {
- return err
- }
+ return sr.inner.amqpLinks.Retry(ctx, "RenewSessionLock", func(ctx context.Context, lwv *internal.LinksWithID, args *utils.RetryFnArgs) error {
+ newLockedUntil, err := internal.RenewSessionLock(ctx, lwv.RPC, *sr.sessionID)
- newLockedUntil, err := mgmt.RenewSessionLock(ctx, *sr.sessionID)
-
- if err != nil {
- return err
- }
+ if err != nil {
+ return err
+ }
- sr.lockedUntil = newLockedUntil
- return nil
+ sr.lockedUntil = newLockedUntil
+ return nil
+ }, sr.inner.retryOptions)
}
// init ensures the link was created, guaranteeing that we get our expected session lock.
func (sr *SessionReceiver) init(ctx context.Context) error {
// initialize the links
- _, _, _, _, err := sr.inner.amqpLinks.Get(ctx)
+ _, err := sr.inner.amqpLinks.Get(ctx)
return err
}
diff --git a/sdk/messaging/azservicebus/session_receiver_test.go b/sdk/messaging/azservicebus/session_receiver_test.go
index ffb0b3cb9592..e9baa824ae44 100644
--- a/sdk/messaging/azservicebus/session_receiver_test.go
+++ b/sdk/messaging/azservicebus/session_receiver_test.go
@@ -12,6 +12,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/admin"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal"
+ "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/test"
"github.com/Azure/go-amqp"
"github.com/stretchr/testify/require"
)
@@ -132,7 +133,9 @@ func TestSessionReceiver_acceptSessionButAlreadyLocked(t *testing.T) {
// You can address a session by name which makes lock contention possible (unlike
// messages where the lock token is not a predefined value)
receiver, err = client.AcceptSessionForQueue(ctx, queueName, "session-1", nil)
- require.EqualValues(t, internal.RecoveryKindFatal, internal.ToSBE(context.Background(), err).RecoveryKind)
+
+ sbe := internal.GetSBErrInfo(err)
+ require.EqualValues(t, internal.RecoveryKindFatal, sbe.RecoveryKind)
require.Nil(t, receiver)
}
@@ -254,19 +257,59 @@ func TestSessionReceiver_RenewSessionLock(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, messages)
- // surprisingly this works. Not sure what it accomplishes though. C# has a manual check for it.
- // err = sessionReceiver.RenewMessageLock(context.Background(), messages[0])
- // require.NoError(t, err)
-
orig := sessionReceiver.LockedUntil()
require.NoError(t, sessionReceiver.RenewSessionLock(context.Background()))
require.Greater(t, sessionReceiver.LockedUntil().UnixNano(), orig.UnixNano())
+}
+
+func TestSessionReceiver_Detach(t *testing.T) {
+ serviceBusClient, cleanup, queueName := setupLiveTest(t, &admin.QueueProperties{
+ RequiresSession: to.BoolPtr(true),
+ })
+ defer cleanup()
+
+ adminClient, err := admin.NewClientFromConnectionString(test.GetConnectionString(t), nil)
+ require.NoError(t, err)
+
+ receiver, err := serviceBusClient.AcceptSessionForQueue(context.Background(), queueName, "test-session", nil)
+ require.NoError(t, err)
+
+ sender, err := serviceBusClient.NewSender(queueName, nil)
+ require.NoError(t, err)
+
+ err = sender.SendMessage(context.Background(), &Message{
+ Body: []byte("hello"),
+ SessionID: to.StringPtr("test-session"),
+ })
+ require.NoError(t, err)
+ require.NoError(t, sender.Close(context.Background()))
+
+ state, err := receiver.GetSessionState(context.Background())
+ require.NoError(t, err)
+ require.Nil(t, state)
+
+ // force a detach to happen
+ _, err = adminClient.UpdateQueue(context.Background(), queueName, admin.QueueProperties{
+ RequiresSession: to.BoolPtr(true),
+ }, nil)
+ require.NoError(t, err)
+
+ state, err = receiver.GetSessionState(context.Background())
+ require.NoError(t, err)
+ require.Nil(t, state)
- // bogus renewal
- sessionReceiver.sessionID = to.StringPtr("bogus")
+ // force a detach to happen
+ _, err = adminClient.UpdateQueue(context.Background(), queueName, admin.QueueProperties{
+ RequiresSession: to.BoolPtr(true),
+ }, nil)
+ require.NoError(t, err)
+
+ messages, err := receiver.ReceiveMessages(context.Background(), 1, nil)
+ require.NoError(t, err)
+ require.NotEmpty(t, messages)
- err = sessionReceiver.RenewSessionLock(context.Background())
- require.Contains(t, err.Error(), "status code 410 and description: The session lock has expired on the MessageSession")
+ require.NoError(t, receiver.CompleteMessage(context.Background(), messages[0]))
+ require.NoError(t, receiver.Close(context.Background()))
}
func Test_toReceiverOptions(t *testing.T) {