Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -479,26 +479,43 @@ internal async Task ProcessConnectionDataMessageAsync(ConnectionDataMessage conn
{
var owner = ExactSizeMemoryPool.Shared.Rent((int)connectionDataMessage.Payload.Length);
connectionDataMessage.Payload.CopyTo(owner.Memory.Span);
// make sure there is no await operation before _bufferingMessages.
// make sure there is no await operation before _bufferingMessages.Enqueue.
_bufferedMessages.Enqueue(owner);
}
else
{
long length = 0;
foreach (var owner in _bufferedMessages)
int length = 0;
if (_bufferedMessages.Count > 0)
{
using (owner)
length += (int)connectionDataMessage.Payload.Length;
foreach (var buffered in _bufferedMessages)
{
await WriteToApplicationAsync(new ReadOnlySequence<byte>(owner.Memory));
length += owner.Memory.Length;
length += buffered.Memory.Length;
}
using var memoryOwner = ExactSizeMemoryPool.Shared.Rent(length);
var destination = memoryOwner.Memory.Span;
while (_bufferedMessages.Count > 0)
{
using var owner = _bufferedMessages.Dequeue();
owner.Memory.Span.CopyTo(destination);
destination = destination.Slice(owner.Memory.Length);
}
foreach (var memory in connectionDataMessage.Payload)
{
memory.Span.CopyTo(destination);
destination = destination.Slice(memory.Length);
}
// make sure there is no await operation before WriteToApplicationAsync.
await WriteToApplicationAsync(new ReadOnlySequence<byte>(memoryOwner.Memory));
}
else
{
var payload = connectionDataMessage.Payload;
length += (int)payload.Length;
Log.WriteMessageToApplication(Logger, length, connectionDataMessage.ConnectionId);
// make sure there is no await operation before WriteToApplicationAsync.
await WriteToApplicationAsync(payload);
}
_bufferedMessages.Clear();

var payload = connectionDataMessage.Payload;
length += payload.Length;
Log.WriteMessageToApplication(Logger, length, connectionDataMessage.ConnectionId);
await WriteToApplicationAsync(payload);
}
}
catch (Exception ex)
Expand Down
102 changes: 102 additions & 0 deletions test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,108 @@ await transportConnection.Application.Output.WriteAsync(
}
}

[Fact]
public async Task TestPartialMessagesProcessingShouldBeThreadSafe()
{
var ccm = new TestClientConnectionManager();
var ccf = new ClientConnectionFactory(NullLoggerFactory.Instance, closeTimeOutMilliseconds: 500);
var protocol = new ServiceProtocol();
var hubProtocol = new JsonHubProtocol();
TestConnection transportConnection = null;
var connectionFactory = new TestConnectionFactory(conn =>
{
transportConnection = conn;
return Task.CompletedTask;
});
var services = new ServiceCollection();

var connectionHandler = new TextContentConnectionHandler();
services.AddSingleton(connectionHandler);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TextContentConnectionHandler>();
var handler = builder.Build();
var hubProtocolResolver = new DefaultHubProtocolResolver(new[] { hubProtocol }, NullLogger<DefaultHubProtocolResolver>.Instance);
var connection = new ServiceConnection(protocol,
ccm,
connectionFactory,
NullLoggerFactory.Instance,
handler,
ccf,
"serverId",
Guid.NewGuid().ToString("N"),
null,
null,
null,
new DefaultClientInvocationManager(),
hubProtocolResolver,
null);

var connectionTask = connection.StartAsync();

// completed handshake
await connection.ConnectionInitializedTask.OrTimeout();
Assert.Equal(ServiceConnectionStatus.Connected, connection.Status);
var clientConnectionId = Guid.NewGuid().ToString();

var waitClientTask = ccm.WaitForClientConnectionAsync(clientConnectionId);
await transportConnection.Application.Output.WriteAsync(
protocol.GetMessageBytes(new OpenConnectionMessage(clientConnectionId, Array.Empty<Claim>()) { Protocol = hubProtocol.Name }));

var clientConnection = await waitClientTask.OrTimeout();

const string messageContent = "{\"type\":1,\"target\":\"method\"}\u001e";
var message = Encoding.UTF8.GetBytes(messageContent);
var messageBytes =
(from b in message
select protocol.GetMessageBytes(new ConnectionDataMessage(clientConnectionId, new byte[] { b }) { IsPartial = b != '\u001e' })).ToArray();
var reconnectBytes = protocol.GetMessageBytes(new ConnectionReconnectMessage(clientConnectionId));
var enumerator = connectionHandler.EnumerateContent().GetAsyncEnumerator();
var moveNextTask = enumerator.MoveNextAsync().AsTask();
var sb = new StringBuilder();
for (int i = 0; i < 1000; i++)
{
foreach (var bytes in messageBytes)
{
await transportConnection.Application.Output.WriteAsync(bytes);
}
foreach (var bytes in messageBytes.Take(5))
{
await transportConnection.Application.Output.WriteAsync(bytes);
}
await transportConnection.Application.Output.WriteAsync(reconnectBytes);

moveNextTask = await ValidateRecievedMessage(messageContent, enumerator, moveNextTask, sb);
}

// complete reading to end the connection
transportConnection.Application.Output.Complete();
moveNextTask = await ValidateRecievedMessage(messageContent, enumerator, moveNextTask, sb);

await clientConnection.LifetimeTask.OrTimeout();
moveNextTask = await ValidateRecievedMessage(messageContent, enumerator, moveNextTask, sb);

// 1s for application task to timeout
await connectionTask.OrTimeout(1000);
Assert.Equal(ServiceConnectionStatus.Disconnected, connection.Status);
Assert.Empty(ccm.ClientConnections);
}

private static async Task<Task<bool>> ValidateRecievedMessage(string expectedMessageContent, IAsyncEnumerator<string> enumerator, Task<bool> moveNextTask, StringBuilder sb)
{
while (moveNextTask.IsCompletedSuccessfully && await moveNextTask)
{
sb.Append(enumerator.Current);
moveNextTask = enumerator.MoveNextAsync().AsTask();
if (sb.Length >= expectedMessageContent.Length)
{
Assert.Equal(expectedMessageContent, sb.ToString(0, expectedMessageContent.Length));
sb.Remove(0, expectedMessageContent.Length);
}
}

return moveNextTask;
}

private static async Task<ClientConnectionContext> CreateClientConnectionAsync(ServiceProtocol protocol,
IHubProtocol hubProtocol,
TestClientConnectionManager ccm,
Expand Down