diff --git a/sdk/servicebus/azure-messaging-servicebus/src/main/java/com/azure/messaging/servicebus/ServiceBusMessageBatch.java b/sdk/servicebus/azure-messaging-servicebus/src/main/java/com/azure/messaging/servicebus/ServiceBusMessageBatch.java index a3ddce1cb739..d3a3c4d9176b 100644 --- a/sdk/servicebus/azure-messaging-servicebus/src/main/java/com/azure/messaging/servicebus/ServiceBusMessageBatch.java +++ b/sdk/servicebus/azure-messaging-servicebus/src/main/java/com/azure/messaging/servicebus/ServiceBusMessageBatch.java @@ -11,10 +11,12 @@ import com.azure.core.util.logging.ClientLogger; import java.nio.BufferOverflowException; +import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Locale; import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; import static com.azure.messaging.servicebus.implementation.MessageUtils.traceMessageSpan; @@ -30,7 +32,7 @@ public final class ServiceBusMessageBatch { private final MessageSerializer serializer; private final List serviceBusMessageList; private final byte[] eventBytes; - private int sizeInBytes; + private final AtomicInteger sizeInBytes; private final TracerProvider tracerProvider; private final String entityPath; private final String hostname; @@ -40,8 +42,8 @@ public final class ServiceBusMessageBatch { this.maxMessageSize = maxMessageSize; this.contextProvider = contextProvider; this.serializer = serializer; - this.serviceBusMessageList = new LinkedList<>(); - this.sizeInBytes = (maxMessageSize / 65536) * 1024; // reserve 1KB for every 64KB + this.serviceBusMessageList = Collections.synchronizedList(new LinkedList<>()); + this.sizeInBytes = new AtomicInteger((maxMessageSize / 65536) * 1024); // reserve 1KB for every 64KB this.eventBytes = new byte[maxMessageSize]; this.tracerProvider = tracerProvider; this.entityPath = entityPath; @@ -72,7 +74,7 @@ public int getMaxSizeInBytes() { * @return The size of the {@link ServiceBusMessageBatch batch} in bytes. */ public int getSizeInBytes() { - return this.sizeInBytes; + return this.sizeInBytes.get(); } /** @@ -97,9 +99,9 @@ public boolean tryAddMessage(final ServiceBusMessage serviceBusMessage) { tracerProvider) : serviceBusMessage; - final int size; + final AtomicInteger size = new AtomicInteger(); try { - size = getSize(serviceBusMessageUpdated, serviceBusMessageList.isEmpty()); + size.set(getSize(serviceBusMessageUpdated, serviceBusMessageList.isEmpty())); } catch (BufferOverflowException exception) { final RuntimeException ex = new ServiceBusException( new AmqpException(false, AmqpErrorCondition.LINK_PAYLOAD_SIZE_EXCEEDED, @@ -109,12 +111,9 @@ public boolean tryAddMessage(final ServiceBusMessage serviceBusMessage) { throw logger.logExceptionAsWarning(ex); } - synchronized (lock) { - if (this.sizeInBytes + size > this.maxMessageSize) { - return false; - } - - this.sizeInBytes += size; + if (this.sizeInBytes.addAndGet(size.get()) > this.maxMessageSize) { + this.sizeInBytes.addAndGet(-1 * size.get()); + return false; } this.serviceBusMessageList.add(serviceBusMessageUpdated); diff --git a/sdk/servicebus/azure-messaging-servicebus/src/main/java/com/azure/messaging/servicebus/ServiceBusSenderAsyncClient.java b/sdk/servicebus/azure-messaging-servicebus/src/main/java/com/azure/messaging/servicebus/ServiceBusSenderAsyncClient.java index fd058abf6ba3..f387c02e7de1 100644 --- a/sdk/servicebus/azure-messaging-servicebus/src/main/java/com/azure/messaging/servicebus/ServiceBusSenderAsyncClient.java +++ b/sdk/servicebus/azure-messaging-servicebus/src/main/java/com/azure/messaging/servicebus/ServiceBusSenderAsyncClient.java @@ -42,6 +42,7 @@ import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collector; +import java.util.stream.StreamSupport; import static com.azure.core.amqp.implementation.RetryUtil.getRetryPolicy; import static com.azure.core.amqp.implementation.RetryUtil.withRetry; @@ -574,7 +575,8 @@ private Mono sendIterable(Iterable messages, ServiceBus } return createMessageBatch().flatMap(messageBatch -> { - messages.forEach(message -> messageBatch.tryAddMessage(message)); + StreamSupport.stream(messages.spliterator(), true) + .forEach(message -> messageBatch.tryAddMessage(message)); return sendInternal(messageBatch, transaction); }); } @@ -635,32 +637,29 @@ private Mono sendInternal(ServiceBusMessageBatch batch, ServiceBusTransact logger.info("Sending batch with size[{}].", batch.getCount()); - Context sharedContext = null; - final List messages = new ArrayList<>(); - - for (int i = 0; i < batch.getMessages().size(); i++) { - final ServiceBusMessage event = batch.getMessages().get(i); + AtomicReference sharedContext = new AtomicReference<>(Context.NONE); + final List messages = Collections.synchronizedList(new ArrayList<>()); + batch.getMessages().parallelStream().forEach(serviceBusMessage -> { if (isTracingEnabled) { - parentContext.set(event.getContext()); - if (i == 0) { - sharedContext = tracerProvider.getSharedSpanBuilder(SERVICE_BASE_NAME, parentContext.get()); + parentContext.set(serviceBusMessage.getContext()); + if (sharedContext.get().equals(Context.NONE)) { + sharedContext.set(tracerProvider.getSharedSpanBuilder(SERVICE_BASE_NAME, parentContext.get())); } - tracerProvider.addSpanLinks(sharedContext.addData(SPAN_CONTEXT_KEY, event.getContext())); + tracerProvider.addSpanLinks(sharedContext.get().addData(SPAN_CONTEXT_KEY, serviceBusMessage.getContext())); } - final org.apache.qpid.proton.message.Message message = messageSerializer.serialize(event); - + final org.apache.qpid.proton.message.Message message = messageSerializer.serialize(serviceBusMessage); final MessageAnnotations messageAnnotations = message.getMessageAnnotations() == null ? new MessageAnnotations(new HashMap<>()) : message.getMessageAnnotations(); message.setMessageAnnotations(messageAnnotations); messages.add(message); - } + }); if (isTracingEnabled) { - final Context finalSharedContext = sharedContext == null + final Context finalSharedContext = sharedContext.get().equals(Context.NONE) ? Context.NONE - : sharedContext + : sharedContext.get() .addData(ENTITY_PATH_KEY, entityName) .addData(HOST_NAME_KEY, connectionProcessor.getFullyQualifiedNamespace()) .addData(AZ_TRACING_NAMESPACE_KEY, AZ_TRACING_NAMESPACE_VALUE); diff --git a/sdk/servicebus/azure-messaging-servicebus/src/test/java/com/azure/messaging/servicebus/ServiceBusProcessorTest.java b/sdk/servicebus/azure-messaging-servicebus/src/test/java/com/azure/messaging/servicebus/ServiceBusProcessorTest.java index a4a97d9b7053..6c6c1ae46121 100644 --- a/sdk/servicebus/azure-messaging-servicebus/src/test/java/com/azure/messaging/servicebus/ServiceBusProcessorTest.java +++ b/sdk/servicebus/azure-messaging-servicebus/src/test/java/com/azure/messaging/servicebus/ServiceBusProcessorTest.java @@ -194,7 +194,6 @@ public void testStartStopResume() throws InterruptedException { */ @Test public void testErrorRecovery() throws InterruptedException { - List messageList = new ArrayList<>(); for (int i = 0; i < 2; i++) { ServiceBusReceivedMessage serviceBusReceivedMessage = @@ -204,6 +203,7 @@ public void testErrorRecovery() throws InterruptedException { new ServiceBusMessageContext(serviceBusReceivedMessage); messageList.add(serviceBusMessageContext); } + final Flux messageFlux = Flux.generate(() -> 0, (state, sink) -> { ServiceBusReceivedMessage serviceBusReceivedMessage = @@ -219,11 +219,9 @@ public void testErrorRecovery() throws InterruptedException { }); ServiceBusClientBuilder.ServiceBusReceiverClientBuilder receiverBuilder = getBuilder(messageFlux); - AtomicInteger messageId = new AtomicInteger(); AtomicReference countDownLatch = new AtomicReference<>(); countDownLatch.set(new CountDownLatch(4)); - AtomicBoolean assertionFailed = new AtomicBoolean(); ServiceBusProcessorClient serviceBusProcessorClient = new ServiceBusProcessorClient(receiverBuilder, messageContext -> {