diff --git a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/src/Config/ServiceBusExtensionConfigProvider.cs b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/src/Config/ServiceBusExtensionConfigProvider.cs index 69c329fac75d..e1a66314b88c 100644 --- a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/src/Config/ServiceBusExtensionConfigProvider.cs +++ b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/src/Config/ServiceBusExtensionConfigProvider.cs @@ -7,6 +7,7 @@ using Azure.Messaging.ServiceBus.Primitives; using Microsoft.Azure.WebJobs.Description; using Microsoft.Azure.WebJobs.Extensions.ServiceBus.Config; +using Microsoft.Azure.WebJobs.Host; using Microsoft.Azure.WebJobs.Host.Bindings; using Microsoft.Azure.WebJobs.Host.Config; using Microsoft.Azure.WebJobs.Host.Scale; @@ -33,6 +34,7 @@ internal class ServiceBusExtensionConfigProvider : IExtensionConfigProvider private readonly IConverterManager _converterManager; private readonly ServiceBusClientFactory _clientFactory; private readonly ConcurrencyManager _concurrencyManager; + private readonly IDrainModeManager _drainModeManager; /// /// Creates a new instance. @@ -45,7 +47,8 @@ public ServiceBusExtensionConfigProvider( ILoggerFactory loggerFactory, IConverterManager converterManager, ServiceBusClientFactory clientFactory, - ConcurrencyManager concurrencyManager) + ConcurrencyManager concurrencyManager, + IDrainModeManager drainModeManager) { _options = options.Value; _messagingProvider = messagingProvider; @@ -54,6 +57,7 @@ public ServiceBusExtensionConfigProvider( _converterManager = converterManager; _clientFactory = clientFactory; _concurrencyManager = concurrencyManager; + _drainModeManager = drainModeManager; } /// @@ -101,7 +105,16 @@ public void Initialize(ExtensionConfigContext context) .AddOpenConverter(typeof(MessageToPocoConverter<>), _options.JsonSerializerSettings); // register our trigger binding provider - ServiceBusTriggerAttributeBindingProvider triggerBindingProvider = new ServiceBusTriggerAttributeBindingProvider(_nameResolver, _options, _messagingProvider, _loggerFactory, _converterManager, _clientFactory, _concurrencyManager); + ServiceBusTriggerAttributeBindingProvider triggerBindingProvider = new ServiceBusTriggerAttributeBindingProvider( + _nameResolver, + _options, + _messagingProvider, + _loggerFactory, + _converterManager, + _clientFactory, + _concurrencyManager, + _drainModeManager); + context.AddBindingRule() .BindToTrigger(triggerBindingProvider); diff --git a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/src/Listeners/ServiceBusListener.cs b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/src/Listeners/ServiceBusListener.cs index da891766c114..909d81b55889 100644 --- a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/src/Listeners/ServiceBusListener.cs +++ b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/src/Listeners/ServiceBusListener.cs @@ -16,6 +16,7 @@ using Azure.Messaging.ServiceBus.Diagnostics; using Microsoft.Azure.WebJobs.Extensions.ServiceBus.Config; using Microsoft.Azure.WebJobs.Extensions.ServiceBus.Listeners; +using Microsoft.Azure.WebJobs.Host; using Microsoft.Azure.WebJobs.Host.Executors; using Microsoft.Azure.WebJobs.Host.Listeners; using Microsoft.Azure.WebJobs.Host.Scale; @@ -29,7 +30,8 @@ internal sealed class ServiceBusListener : IListener, IScaleMonitorProvider, ITa private readonly string _entityPath; private readonly bool _isSessionsEnabled; private readonly bool _autoCompleteMessages; - private readonly CancellationTokenSource _cancellationTokenSource; + private readonly CancellationTokenSource _stoppingCancellationTokenSource; + private readonly CancellationTokenSource _functionExecutionCancellationTokenSource; private readonly ServiceBusOptions _serviceBusOptions; private readonly bool _singleDispatch; private readonly ILogger _logger; @@ -65,6 +67,7 @@ internal sealed class ServiceBusListener : IListener, IScaleMonitorProvider, ITa private Task _batchLoop; private Lazy _details; private Lazy _clientDiagnostics; + private readonly IDrainModeManager _drainModeManager; public ServiceBusListener( string functionId, @@ -79,15 +82,18 @@ public ServiceBusListener( ILoggerFactory loggerFactory, bool singleDispatch, ServiceBusClientFactory clientFactory, - ConcurrencyManager concurrencyManager) + ConcurrencyManager concurrencyManager, + IDrainModeManager drainModeManager) { _entityPath = entityPath; _isSessionsEnabled = isSessionsEnabled; _autoCompleteMessages = autoCompleteMessages; _triggerExecutor = triggerExecutor; - _cancellationTokenSource = new CancellationTokenSource(); + _stoppingCancellationTokenSource = new CancellationTokenSource(); + _functionExecutionCancellationTokenSource = new CancellationTokenSource(); _logger = loggerFactory.CreateLogger(); _functionId = functionId; + _drainModeManager = drainModeManager; _client = new Lazy( () => clientFactory.CreateClientFromSetting(connection)); @@ -195,7 +201,7 @@ public async Task StartAsync(CancellationToken cancellationToken) } else { - _batchLoop = RunBatchReceiveLoopAsync(_cancellationTokenSource); + _batchLoop = RunBatchReceiveLoopAsync(_stoppingCancellationTokenSource); } } catch @@ -211,6 +217,11 @@ public async Task StartAsync(CancellationToken cancellationToken) public async Task StopAsync(CancellationToken cancellationToken) { + if (!_drainModeManager.IsDrainModeEnabled) + { + _functionExecutionCancellationTokenSource.Cancel(); + } + ThrowIfDisposed(); _logger.LogDebug($"Attempting to stop ServiceBus listener ({_details.Value})"); @@ -226,7 +237,7 @@ public async Task StopAsync(CancellationToken cancellationToken) } // This will also cancel the background monitoring task through the linked cancellation token source. - _cancellationTokenSource.Cancel(); + _stoppingCancellationTokenSource.Cancel(); // CloseAsync method stop new messages from being processed while allowing in-flight messages to be processed. if (_singleDispatch) @@ -273,7 +284,7 @@ public async Task StopAsync(CancellationToken cancellationToken) public void Cancel() { ThrowIfDisposed(); - _cancellationTokenSource.Cancel(); + _stoppingCancellationTokenSource.Cancel(); } [SuppressMessage("Microsoft.Usage", "CA2213:DisposableFieldsShouldBeDisposed", MessageId = "_cancellationTokenSource")] @@ -290,7 +301,8 @@ public void Dispose() // Mark it canceled but don't dispose of the source while the callers are running. // Otherwise, callers would receive ObjectDisposedException when calling token.Register. // For now, rely on finalization to clean up _cancellationTokenSource's wait handle (if allocated). - _cancellationTokenSource.Cancel(); + _stoppingCancellationTokenSource.Cancel(); + _functionExecutionCancellationTokenSource.Cancel(); if (_batchReceiver.IsValueCreated) { @@ -321,10 +333,14 @@ public void Dispose() } _stopAsyncSemaphore.Dispose(); - _cancellationTokenSource.Dispose(); + _stoppingCancellationTokenSource.Dispose(); _batchReceiveRegistration.Dispose(); _concurrencyUpdateManager?.Dispose(); + // No need to dispose the _functionExecutionCancellationTokenSource since we don't create it as a linked token and + // it won't use a timer, so the Dispose method is essentially a no-op. The downside to disposing it is that + // any customers who are trying to use it to cancel their own operations would get an ObjectDisposedException. + Disposed = true; _logger.LogDebug($"ServiceBus listener disposed({_details.Value})"); @@ -336,7 +352,7 @@ internal async Task ProcessMessageAsync(ProcessMessageEventArgs args) _concurrencyUpdateManager?.MessageProcessed(); - using (CancellationTokenSource linkedCts = CancellationTokenSource.CreateLinkedTokenSource(args.CancellationToken, _cancellationTokenSource.Token)) + using (CancellationTokenSource linkedCts = CancellationTokenSource.CreateLinkedTokenSource(args.CancellationToken, _stoppingCancellationTokenSource.Token)) { var actions = new ServiceBusMessageActions(args); if (!await _messageProcessor.Value.BeginProcessingMessageAsync(actions, args.Message, linkedCts.Token).ConfigureAwait(false)) @@ -349,7 +365,7 @@ internal async Task ProcessMessageAsync(ProcessMessageEventArgs args) TriggeredFunctionData data = input.GetTriggerFunctionData(); - FunctionResult result = await _triggerExecutor.TryExecuteAsync(data, linkedCts.Token).ConfigureAwait(false); + FunctionResult result = await _triggerExecutor.TryExecuteAsync(data, _functionExecutionCancellationTokenSource.Token).ConfigureAwait(false); try { await _messageProcessor.Value.CompleteProcessingMessageAsync(actions, args.Message, result, linkedCts.Token) @@ -369,7 +385,7 @@ internal async Task ProcessSessionMessageAsync(ProcessSessionMessageEventArgs ar _concurrencyUpdateManager?.MessageProcessed(); using (CancellationTokenSource linkedCts = - CancellationTokenSource.CreateLinkedTokenSource(args.CancellationToken, _cancellationTokenSource.Token)) + CancellationTokenSource.CreateLinkedTokenSource(args.CancellationToken, _stoppingCancellationTokenSource.Token)) { var actions = new ServiceBusSessionMessageActions(args); if (!await _sessionMessageProcessor.Value.BeginProcessingMessageAsync(actions, args.Message, linkedCts.Token) @@ -382,7 +398,7 @@ internal async Task ProcessSessionMessageAsync(ProcessSessionMessageEventArgs ar ServiceBusTriggerInput input = ServiceBusTriggerInput.CreateSingle(args.Message, actions, receiveActions, _client.Value); TriggeredFunctionData data = input.GetTriggerFunctionData(); - FunctionResult result = await _triggerExecutor.TryExecuteAsync(data, linkedCts.Token).ConfigureAwait(false); + FunctionResult result = await _triggerExecutor.TryExecuteAsync(data, _functionExecutionCancellationTokenSource.Token).ConfigureAwait(false); if (actions.ShouldReleaseSession) { @@ -643,7 +659,7 @@ private async Task TriggerAndCompleteMessagesInternal(ServiceBusReceivedMessage[ scope.SetMessageData(messagesArray); scope.Start(); - FunctionResult result = await _triggerExecutor.TryExecuteAsync(input.GetTriggerFunctionData(), cancellationToken).ConfigureAwait(false); + FunctionResult result = await _triggerExecutor.TryExecuteAsync(input.GetTriggerFunctionData(), _functionExecutionCancellationTokenSource.Token).ConfigureAwait(false); if (result.Exception != null) { scope.Failed(result.Exception); diff --git a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/src/Triggers/ServiceBusTriggerAttributeBindingProvider.cs b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/src/Triggers/ServiceBusTriggerAttributeBindingProvider.cs index 814ac3f0db61..3744f011043d 100644 --- a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/src/Triggers/ServiceBusTriggerAttributeBindingProvider.cs +++ b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/src/Triggers/ServiceBusTriggerAttributeBindingProvider.cs @@ -28,6 +28,7 @@ internal class ServiceBusTriggerAttributeBindingProvider : ITriggerBindingProvid private readonly ServiceBusClientFactory _clientFactory; private readonly ILogger _logger; private readonly ConcurrencyManager _concurrencyManager; + private readonly IDrainModeManager _drainModeManager; public ServiceBusTriggerAttributeBindingProvider( INameResolver nameResolver, @@ -36,7 +37,8 @@ public ServiceBusTriggerAttributeBindingProvider( ILoggerFactory loggerFactory, IConverterManager converterManager, ServiceBusClientFactory clientFactory, - ConcurrencyManager concurrencyManager) + ConcurrencyManager concurrencyManager, + IDrainModeManager drainModeManager) { _nameResolver = nameResolver ?? throw new ArgumentNullException(nameof(nameResolver)); _options = options ?? throw new ArgumentNullException(nameof(options)); @@ -46,6 +48,7 @@ public ServiceBusTriggerAttributeBindingProvider( _clientFactory = clientFactory; _logger = _loggerFactory.CreateLogger(); _concurrencyManager = concurrencyManager; + _drainModeManager = drainModeManager; } public Task TryCreateAsync(TriggerBindingProviderContext context) @@ -84,7 +87,21 @@ public Task TryCreateAsync(TriggerBindingProviderContext contex (factoryContext, singleDispatch) => { var autoCompleteMessagesOptionEvaluatedValue = GetAutoCompleteMessagesOptionToUse(attribute, factoryContext.Descriptor.ShortName); - IListener listener = new ServiceBusListener(factoryContext.Descriptor.Id, serviceBusEntityType, entityPath, attribute.IsSessionsEnabled, autoCompleteMessagesOptionEvaluatedValue, factoryContext.Executor, _options, attribute.Connection, _messagingProvider, _loggerFactory, singleDispatch, _clientFactory, _concurrencyManager); + IListener listener = new ServiceBusListener( + factoryContext.Descriptor.Id, + serviceBusEntityType, + entityPath, + attribute.IsSessionsEnabled, + autoCompleteMessagesOptionEvaluatedValue, + factoryContext.Executor, + _options, + attribute.Connection, + _messagingProvider, + _loggerFactory, + singleDispatch, + _clientFactory, + _concurrencyManager, + _drainModeManager); return Task.FromResult(listener); }; diff --git a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Bindings/ServiceBusTriggerAttributeBindingProviderTests.cs b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Bindings/ServiceBusTriggerAttributeBindingProviderTests.cs index 7b33576b61b3..34dcc203fdc4 100644 --- a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Bindings/ServiceBusTriggerAttributeBindingProviderTests.cs +++ b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Bindings/ServiceBusTriggerAttributeBindingProviderTests.cs @@ -46,7 +46,7 @@ public ServiceBusTriggerAttributeBindingProviderTests() Mock convertManager = new Mock(MockBehavior.Default); var provider = new MessagingProvider(new OptionsWrapper(options)); var factory = new ServiceBusClientFactory(configuration, new Mock().Object, provider, new AzureEventSourceLogForwarder(new NullLoggerFactory()), new OptionsWrapper(options)); - _provider = new ServiceBusTriggerAttributeBindingProvider(mockResolver.Object, options, provider, NullLoggerFactory.Instance, convertManager.Object, factory, concurrencyManager); + _provider = new ServiceBusTriggerAttributeBindingProvider(mockResolver.Object, options, provider, NullLoggerFactory.Instance, convertManager.Object, factory, concurrencyManager, default); } [Test] diff --git a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Listeners/ServiceBusListenerTests.cs b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Listeners/ServiceBusListenerTests.cs index 66d06fcc9c83..3fbdb1e26f7f 100644 --- a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Listeners/ServiceBusListenerTests.cs +++ b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Listeners/ServiceBusListenerTests.cs @@ -9,6 +9,7 @@ using System.Threading.Tasks; using Azure.Messaging.ServiceBus; using Microsoft.Azure.WebJobs.Extensions.ServiceBus.Config; +using Microsoft.Azure.WebJobs.Host; using Microsoft.Azure.WebJobs.Host.Executors; using Microsoft.Azure.WebJobs.Host.Scale; using Microsoft.Azure.WebJobs.Host.TestCommon; @@ -39,6 +40,7 @@ public class ServiceBusListenerTests private readonly Mock _mockConcurrencyThrottleManager; private readonly ServiceBusClient _client; private readonly ConcurrencyManager _concurrencyManager; + private readonly Mock _mockDrainModeManager; public ServiceBusListenerTests() { @@ -58,6 +60,9 @@ public ServiceBusListenerTests() _mockMessagingProvider = new Mock(new OptionsWrapper(config)); _mockClientFactory = new Mock(configuration, Mock.Of(), _mockMessagingProvider.Object, new AzureEventSourceLogForwarder(new NullLoggerFactory()), new OptionsWrapper(new ServiceBusOptions())); + _mockDrainModeManager = new Mock(); + _mockDrainModeManager.Setup(p => p.IsDrainModeEnabled).Returns(true); + _mockMessagingProvider .Setup(p => p.CreateMessageProcessor(It.IsAny(), _entityPath, It.IsAny())) .Returns(_mockMessageProcessor.Object); @@ -87,7 +92,8 @@ public ServiceBusListenerTests() _loggerFactory, false, _mockClientFactory.Object, - _concurrencyManager); + _concurrencyManager, + _mockDrainModeManager.Object); } [SetUp] @@ -246,7 +252,8 @@ public async Task SessionIdleTimeoutRespected() _loggerFactory, false, _mockClientFactory.Object, - _concurrencyManager); + _concurrencyManager, + _mockDrainModeManager.Object); await listener.StartAsync(CancellationToken.None); await listener.StopAsync(CancellationToken.None); @@ -290,7 +297,8 @@ public async Task SessionIdleTimeoutIgnoredWhenNotUsingSessions() _loggerFactory, false, _mockClientFactory.Object, - _concurrencyManager); + _concurrencyManager, + _mockDrainModeManager.Object); await listener.StartAsync(CancellationToken.None); await listener.StopAsync(CancellationToken.None); diff --git a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Listeners/ServiceBusScaleMonitorTests.cs b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Listeners/ServiceBusScaleMonitorTests.cs index 484ca9d024f6..f5d774bc1b06 100644 --- a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Listeners/ServiceBusScaleMonitorTests.cs +++ b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Listeners/ServiceBusScaleMonitorTests.cs @@ -120,7 +120,8 @@ public void Setup() _loggerFactory, false, _mockClientFactory.Object, - concurrencyManager); + concurrencyManager, + default); _scaleMonitor = (ServiceBusScaleMonitor)_listener.GetMonitor(); } @@ -539,7 +540,8 @@ private ServiceBusListener CreateListener(bool useDeadletterQueue = false) _loggerFactory, false, _mockClientFactory.Object, - concurrencyManager); + concurrencyManager, + default); } [Test] diff --git a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/ServiceBusEndToEndTests.cs b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/ServiceBusEndToEndTests.cs index 4cc580f298fb..e4ebe3bb657c 100644 --- a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/ServiceBusEndToEndTests.cs +++ b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/ServiceBusEndToEndTests.cs @@ -242,6 +242,50 @@ public async Task TestSingle_AutoCompleteEnabledOnTrigger_CompleteInFunction() } } + [Test] + public async Task TestSingle_Dispose() + { + await WriteQueueMessage("{'Name': 'Test1', 'Value': 'Value'}"); + var host = BuildHost(); + + bool result = _waitHandle1.WaitOne(SBTimeoutMills); + Assert.True(result); + host.Dispose(); + } + + [Test] + public async Task TestSingle_StopWithoutDrain() + { + await WriteQueueMessage("{'Name': 'Test1', 'Value': 'Value'}"); + var host = BuildHost(); + + bool result = _waitHandle1.WaitOne(SBTimeoutMills); + Assert.True(result); + await host.StopAsync(); + } + + [Test] + public async Task TestBatch_Dispose() + { + await WriteQueueMessage("{'Name': 'Test1', 'Value': 'Value'}"); + var host = BuildHost(); + + bool result = _waitHandle1.WaitOne(SBTimeoutMills); + Assert.True(result); + await host.StopAsync(); + } + + [Test] + public async Task TestBatch_StopWithoutDrain() + { + await WriteQueueMessage("{'Name': 'Test1', 'Value': 'Value'}"); + var host = BuildHost(); + + bool result = _waitHandle1.WaitOne(SBTimeoutMills); + Assert.True(result); + host.Dispose(); + } + [Test] public async Task TestSingle_AutoCompleteDisabledOnTrigger_AbandonsWhenException() { @@ -1588,6 +1632,34 @@ public static async Task RunAsync( } } + public class TestSingleDispose + { + public static async Task RunAsync( + [ServiceBusTrigger(FirstQueueNameKey)] + ServiceBusReceivedMessage message, + CancellationToken cancellationToken) + { + _waitHandle1.Set(); + // wait a small amount of time for the host to call dispose + await Task.Delay(2000, CancellationToken.None); + Assert.IsTrue(cancellationToken.IsCancellationRequested); + } + } + + public class TestBatchDispose + { + public static async Task RunAsync( + [ServiceBusTrigger(FirstQueueNameKey)] + ServiceBusReceivedMessage[] message, + CancellationToken cancellationToken) + { + _waitHandle1.Set(); + // wait a small amount of time for the host to call dispose + await Task.Delay(2000, CancellationToken.None); + Assert.IsTrue(cancellationToken.IsCancellationRequested); + } + } + public class TestCrossEntityTransaction { public static async Task RunAsync( @@ -1750,8 +1822,7 @@ public static async Task RunAsync( { logger.LogInformation($"DrainModeValidationFunctions.QueueNoSessions: message data {msg.Body}"); _drainValidationPreDelay.Set(); - await DrainModeHelper.WaitForCancellationAsync(cancellationToken); - Assert.True(cancellationToken.IsCancellationRequested); + Assert.False(cancellationToken.IsCancellationRequested); await messageActions.CompleteMessageAsync(msg); _drainValidationPostDelay.Set(); } @@ -1768,8 +1839,7 @@ public static async Task RunAsync( { logger.LogInformation($"DrainModeValidationFunctions.NoSessions: message data {msg.Body}"); _drainValidationPreDelay.Set(); - await DrainModeHelper.WaitForCancellationAsync(cancellationToken); - Assert.True(cancellationToken.IsCancellationRequested); + Assert.False(cancellationToken.IsCancellationRequested); await messageActions.CompleteMessageAsync(msg); _drainValidationPostDelay.Set(); } @@ -1787,8 +1857,7 @@ public static async Task RunAsync( Assert.True(array.Length > 0); logger.LogInformation($"DrainModeTestJobBatch.QueueNoSessionsBatch: received {array.Length} messages"); _drainValidationPreDelay.Set(); - await DrainModeHelper.WaitForCancellationAsync(cancellationToken); - Assert.True(cancellationToken.IsCancellationRequested); + Assert.False(cancellationToken.IsCancellationRequested); foreach (ServiceBusReceivedMessage msg in array) { await messageActions.CompleteMessageAsync(msg); @@ -1808,8 +1877,7 @@ public static async Task RunAsync( Assert.True(array.Length > 0); logger.LogInformation($"DrainModeTestJobBatch.TopicNoSessionsBatch: received {array.Length} messages"); _drainValidationPreDelay.Set(); - await DrainModeHelper.WaitForCancellationAsync(cancellationToken); - Assert.True(cancellationToken.IsCancellationRequested); + Assert.False(cancellationToken.IsCancellationRequested); foreach (ServiceBusReceivedMessage msg in array) { // validate that manual lock renewal works diff --git a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/ServiceBusSessionsEndToEndTests.cs b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/ServiceBusSessionsEndToEndTests.cs index b950b7564991..b4e218359f5b 100644 --- a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/ServiceBusSessionsEndToEndTests.cs +++ b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/ServiceBusSessionsEndToEndTests.cs @@ -677,6 +677,50 @@ public async Task TestBatch_ProcessMessagesSpan_FailedScope() Assert.IsTrue(scope.IsFailed); } + [Test] + public async Task TestSingle_Dispose() + { + await WriteQueueMessage("{'Name': 'Test1', 'Value': 'Value'}", "sessionId1"); + var host = BuildHost(); + + bool result = _waitHandle1.WaitOne(SBTimeoutMills); + Assert.True(result); + host.Dispose(); + } + + [Test] + public async Task TestSingle_StopWithoutDrain() + { + await WriteQueueMessage("{'Name': 'Test1', 'Value': 'Value'}", "sessionId1"); + var host = BuildHost(); + + bool result = _waitHandle1.WaitOne(SBTimeoutMills); + Assert.True(result); + await host.StopAsync(); + } + + [Test] + public async Task TestBatch_Dispose() + { + await WriteQueueMessage("{'Name': 'Test1', 'Value': 'Value'}", "sessionId1"); + var host = BuildHost(); + + bool result = _waitHandle1.WaitOne(SBTimeoutMills); + Assert.True(result); + host.Dispose(); + } + + [Test] + public async Task TestBatch_StopWithoutDrain() + { + await WriteQueueMessage("{'Name': 'Test1', 'Value': 'Value'}", "sessionId1"); + var host = BuildHost(); + + bool result = _waitHandle1.WaitOne(SBTimeoutMills); + Assert.True(result); + await host.StopAsync(); + } + private async Task TestMultiple(bool isXml = false) { if (isXml) @@ -830,8 +874,7 @@ public static async Task QueueWithSessions( Assert.AreEqual(msg.PartitionKey, partitionKey); Assert.AreEqual(msg.TransactionPartitionKey, transactionPartitionKey); _drainValidationPreDelay.Set(); - await DrainModeHelper.WaitForCancellationAsync(cancellationToken); - Assert.True(cancellationToken.IsCancellationRequested); + Assert.False(cancellationToken.IsCancellationRequested); try { await messageActions.CompleteMessageAsync(msg); @@ -856,8 +899,7 @@ public static async Task TopicWithSessions( $"DrainModeValidationFunctions.TopicWithSessions: message data {msg.Body} with session id {msg.SessionId}"); Assert.AreEqual(_drainModeSessionId, msg.SessionId); _drainValidationPreDelay.Set(); - await DrainModeHelper.WaitForCancellationAsync(cancellationToken); - Assert.True(cancellationToken.IsCancellationRequested); + Assert.False(cancellationToken.IsCancellationRequested); try { await messageSession.CompleteMessageAsync(msg); @@ -885,8 +927,7 @@ public static async Task QueueWithSessionsBatch( $"DrainModeTestJobBatch.QueueWithSessionsBatch: received {array.Length} messages with session id {array[0].SessionId}"); Assert.AreEqual(_drainModeSessionId, array[0].SessionId); _drainValidationPreDelay.Set(); - await DrainModeHelper.WaitForCancellationAsync(cancellationToken); - Assert.True(cancellationToken.IsCancellationRequested); + Assert.False(cancellationToken.IsCancellationRequested); for (int i = 0; i < array.Length; i++) { var message = array[i]; @@ -918,8 +959,7 @@ public static async Task TopicWithSessionsBatch( $"DrainModeTestJobBatch.TopicWithSessionsBatch: received {array.Length} messages with session id {array[0].SessionId}"); Assert.AreEqual(_drainModeSessionId, array[0].SessionId); _drainValidationPreDelay.Set(); - await DrainModeHelper.WaitForCancellationAsync(cancellationToken); - Assert.True(cancellationToken.IsCancellationRequested); + Assert.False(cancellationToken.IsCancellationRequested); foreach (ServiceBusReceivedMessage msg in array) { await messageSession.CompleteMessageAsync(msg); @@ -1160,6 +1200,34 @@ public static void RunAsync( } } + public class TestSingleDispose + { + public static async Task RunAsync( + [ServiceBusTrigger(FirstQueueNameKey, IsSessionsEnabled = true)] + ServiceBusReceivedMessage message, + CancellationToken cancellationToken) + { + _waitHandle1.Set(); + // wait a small amount of time for the host to call dispose + await Task.Delay(2000, CancellationToken.None); + Assert.IsTrue(cancellationToken.IsCancellationRequested); + } + } + + public class TestBatchDispose + { + public static async Task RunAsync( + [ServiceBusTrigger(FirstQueueNameKey, IsSessionsEnabled = true)] + ServiceBusReceivedMessage[] message, + CancellationToken cancellationToken) + { + _waitHandle1.Set(); + // wait a small amount of time for the host to call dispose + await Task.Delay(2000, CancellationToken.None); + Assert.IsTrue(cancellationToken.IsCancellationRequested); + } + } + public class CustomMessagingProvider : MessagingProvider { public const string CustomMessagingCategory = "CustomMessagingProvider";