diff --git a/src/Microsoft.Azure.SignalR/ClientConnections/ClientConnectionContext.cs b/src/Microsoft.Azure.SignalR/ClientConnections/ClientConnectionContext.cs index cca771a7c..6e487beb8 100644 --- a/src/Microsoft.Azure.SignalR/ClientConnections/ClientConnectionContext.cs +++ b/src/Microsoft.Azure.SignalR/ClientConnections/ClientConnectionContext.cs @@ -479,26 +479,43 @@ internal async Task ProcessConnectionDataMessageAsync(ConnectionDataMessage conn { var owner = ExactSizeMemoryPool.Shared.Rent((int)connectionDataMessage.Payload.Length); connectionDataMessage.Payload.CopyTo(owner.Memory.Span); - // make sure there is no await operation before _bufferingMessages. + // make sure there is no await operation before _bufferingMessages.Enqueue. _bufferedMessages.Enqueue(owner); } else { - long length = 0; - foreach (var owner in _bufferedMessages) + int length = 0; + if (_bufferedMessages.Count > 0) { - using (owner) + length += (int)connectionDataMessage.Payload.Length; + foreach (var buffered in _bufferedMessages) { - await WriteToApplicationAsync(new ReadOnlySequence(owner.Memory)); - length += owner.Memory.Length; + length += buffered.Memory.Length; } + using var memoryOwner = ExactSizeMemoryPool.Shared.Rent(length); + var destination = memoryOwner.Memory.Span; + while (_bufferedMessages.Count > 0) + { + using var owner = _bufferedMessages.Dequeue(); + owner.Memory.Span.CopyTo(destination); + destination = destination.Slice(owner.Memory.Length); + } + foreach (var memory in connectionDataMessage.Payload) + { + memory.Span.CopyTo(destination); + destination = destination.Slice(memory.Length); + } + // make sure there is no await operation before WriteToApplicationAsync. + await WriteToApplicationAsync(new ReadOnlySequence(memoryOwner.Memory)); + } + else + { + var payload = connectionDataMessage.Payload; + length += (int)payload.Length; + Log.WriteMessageToApplication(Logger, length, connectionDataMessage.ConnectionId); + // make sure there is no await operation before WriteToApplicationAsync. + await WriteToApplicationAsync(payload); } - _bufferedMessages.Clear(); - - var payload = connectionDataMessage.Payload; - length += payload.Length; - Log.WriteMessageToApplication(Logger, length, connectionDataMessage.ConnectionId); - await WriteToApplicationAsync(payload); } } catch (Exception ex) diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs index 780b9bea4..e1827af86 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs @@ -945,6 +945,108 @@ await transportConnection.Application.Output.WriteAsync( } } + [Fact] + public async Task TestPartialMessagesProcessingShouldBeThreadSafe() + { + var ccm = new TestClientConnectionManager(); + var ccf = new ClientConnectionFactory(NullLoggerFactory.Instance, closeTimeOutMilliseconds: 500); + var protocol = new ServiceProtocol(); + var hubProtocol = new JsonHubProtocol(); + TestConnection transportConnection = null; + var connectionFactory = new TestConnectionFactory(conn => + { + transportConnection = conn; + return Task.CompletedTask; + }); + var services = new ServiceCollection(); + + var connectionHandler = new TextContentConnectionHandler(); + services.AddSingleton(connectionHandler); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var handler = builder.Build(); + var hubProtocolResolver = new DefaultHubProtocolResolver(new[] { hubProtocol }, NullLogger.Instance); + var connection = new ServiceConnection(protocol, + ccm, + connectionFactory, + NullLoggerFactory.Instance, + handler, + ccf, + "serverId", + Guid.NewGuid().ToString("N"), + null, + null, + null, + new DefaultClientInvocationManager(), + hubProtocolResolver, + null); + + var connectionTask = connection.StartAsync(); + + // completed handshake + await connection.ConnectionInitializedTask.OrTimeout(); + Assert.Equal(ServiceConnectionStatus.Connected, connection.Status); + var clientConnectionId = Guid.NewGuid().ToString(); + + var waitClientTask = ccm.WaitForClientConnectionAsync(clientConnectionId); + await transportConnection.Application.Output.WriteAsync( + protocol.GetMessageBytes(new OpenConnectionMessage(clientConnectionId, Array.Empty()) { Protocol = hubProtocol.Name })); + + var clientConnection = await waitClientTask.OrTimeout(); + + const string messageContent = "{\"type\":1,\"target\":\"method\"}\u001e"; + var message = Encoding.UTF8.GetBytes(messageContent); + var messageBytes = + (from b in message + select protocol.GetMessageBytes(new ConnectionDataMessage(clientConnectionId, new byte[] { b }) { IsPartial = b != '\u001e' })).ToArray(); + var reconnectBytes = protocol.GetMessageBytes(new ConnectionReconnectMessage(clientConnectionId)); + var enumerator = connectionHandler.EnumerateContent().GetAsyncEnumerator(); + var moveNextTask = enumerator.MoveNextAsync().AsTask(); + var sb = new StringBuilder(); + for (int i = 0; i < 1000; i++) + { + foreach (var bytes in messageBytes) + { + await transportConnection.Application.Output.WriteAsync(bytes); + } + foreach (var bytes in messageBytes.Take(5)) + { + await transportConnection.Application.Output.WriteAsync(bytes); + } + await transportConnection.Application.Output.WriteAsync(reconnectBytes); + + moveNextTask = await ValidateRecievedMessage(messageContent, enumerator, moveNextTask, sb); + } + + // complete reading to end the connection + transportConnection.Application.Output.Complete(); + moveNextTask = await ValidateRecievedMessage(messageContent, enumerator, moveNextTask, sb); + + await clientConnection.LifetimeTask.OrTimeout(); + moveNextTask = await ValidateRecievedMessage(messageContent, enumerator, moveNextTask, sb); + + // 1s for application task to timeout + await connectionTask.OrTimeout(1000); + Assert.Equal(ServiceConnectionStatus.Disconnected, connection.Status); + Assert.Empty(ccm.ClientConnections); + } + + private static async Task> ValidateRecievedMessage(string expectedMessageContent, IAsyncEnumerator enumerator, Task moveNextTask, StringBuilder sb) + { + while (moveNextTask.IsCompletedSuccessfully && await moveNextTask) + { + sb.Append(enumerator.Current); + moveNextTask = enumerator.MoveNextAsync().AsTask(); + if (sb.Length >= expectedMessageContent.Length) + { + Assert.Equal(expectedMessageContent, sb.ToString(0, expectedMessageContent.Length)); + sb.Remove(0, expectedMessageContent.Length); + } + } + + return moveNextTask; + } + private static async Task CreateClientConnectionAsync(ServiceProtocol protocol, IHubProtocol hubProtocol, TestClientConnectionManager ccm,