Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ public override Task CloseClientConnections(CancellationToken token)
throw new NotSupportedException();
}

protected override Task<ConnectionContext> CreateConnection(string target = null)
protected override Task<ConnectionContext> CreateConnection(string target = null, CancellationToken cancellationToken = default)
{
return _connectionFactory.ConnectAsync(HubEndpoint, TransferFormat.Binary, ConnectionId, target);
return _connectionFactory.ConnectAsync(HubEndpoint, TransferFormat.Binary, ConnectionId, target, cancellationToken);
}

protected override Task DisposeConnection(ConnectionContext connection)
Expand Down
5 changes: 3 additions & 2 deletions src/Microsoft.Azure.SignalR.Common/Auth/LocalTokenProvider.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Azure.SignalR;
Expand Down Expand Up @@ -34,5 +35,5 @@ public LocalTokenProvider(
_tokenLifetime = tokenLifetime ?? Constants.Periods.DefaultAccessTokenLifetime;
}

public Task<string> ProvideAsync() => _accessKey.GenerateAccessTokenAsync(_audience, _claims, _tokenLifetime, _algorithm);
public Task<string> ProvideAsync(CancellationToken cancellationToken) => _accessKey.GenerateAccessTokenAsync(_audience, _claims, _tokenLifetime, _algorithm, cancellationToken);
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Azure.SignalR;
Expand All @@ -15,5 +16,5 @@ public MicrosoftEntraTokenProvider(MicrosoftEntraAccessKey accessKey)
_accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey));
}

public Task<string> ProvideAsync() => _accessKey.GetMicrosoftEntraTokenAsync();
public Task<string> ProvideAsync(CancellationToken cancellationToken) => _accessKey.GetMicrosoftEntraTokenAsync(cancellationToken);
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Azure.SignalR;

internal interface IAccessTokenProvider
{
Task<string> ProvideAsync();
Task<string> ProvideAsync(CancellationToken cancellationToken = default);
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public async Task StartAsync(Uri url, CancellationToken cancellationToken = defa
// We don't need to capture to a local because we never change this delegate.
if (_accessTokenProvider != null)
{
accessToken = await _accessTokenProvider.ProvideAsync();
accessToken = await _accessTokenProvider.ProvideAsync(cancellationToken);
if (!string.IsNullOrEmpty(accessToken))
{
_webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ internal abstract partial class ServiceConnectionBase : IServiceConnection

private readonly TaskCompletionSource<object> _serviceConnectionOfflineTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

private readonly CancellationTokenSource _connectionStartCts = new();

private readonly ServiceConnectionType _connectionType;

private readonly IServiceMessageHandler _serviceMessageHandler;
Expand Down Expand Up @@ -157,70 +159,75 @@ public async Task StartAsync(string target = null)
}

Status = ServiceConnectionStatus.Connecting;

var connection = await EstablishConnectionAsync(target);
if (connection != null)
try
{
_connectionContext = connection;
Status = ServiceConnectionStatus.Connected;
_serviceConnectionStartTcs.TrySetResult(true);
try
var connection = await EstablishConnectionAsync(target, _connectionStartCts.Token);
if (connection != null)
{
TimerAwaitable syncTimer = null;
_connectionContext = connection;
Status = ServiceConnectionStatus.Connected;
_serviceConnectionStartTcs.TrySetResult(true);
try
{
if (HubEndpoint != null && HubEndpoint.AccessKey is MicrosoftEntraAccessKey key)
TimerAwaitable syncTimer = null;
try
{
syncTimer = new TimerAwaitable(TimeSpan.Zero, DefaultSyncAzureIdentityInterval);
_ = UpdateAzureIdentityAsync(key, syncTimer);
if (HubEndpoint != null && HubEndpoint.AccessKey is MicrosoftEntraAccessKey key)
{
syncTimer = new TimerAwaitable(TimeSpan.Zero, DefaultSyncAzureIdentityInterval);
_ = UpdateAzureIdentityAsync(key, syncTimer);
}
await ProcessIncomingAsync(connection);
}
await ProcessIncomingAsync(connection);
}
finally
{
// mark the status as Disconnected so that no one will write to this connection anymore
Status = ServiceConnectionStatus.Disconnected;
syncTimer?.Stop();
finally
{
// mark the status as Disconnected so that no one will write to this connection anymore
Status = ServiceConnectionStatus.Disconnected;
syncTimer?.Stop();

// when ProcessIncoming completes, clean up the connection
// when ProcessIncoming completes, clean up the connection

// TODO: Never cleanup connections unless Service asks us to do that
// Current implementation is based on assumption that Service will drop clients
// if server connection fails.
await CleanupClientConnections();
// TODO: Never cleanup connections unless Service asks us to do that
// Current implementation is based on assumption that Service will drop clients
// if server connection fails.
await CleanupClientConnections();
}
}
}
catch (Exception ex)
{
Log.ConnectionDropped(Logger, _endpointName, ConnectionId, ex);
}
finally
{
// wait until all the connections are cleaned up to close the outgoing pipe
// Don't allow write anymore when the connection is disconnected
await _writeLock.WaitAsync();
try
catch (Exception ex)
{
// close the underlying connection
await DisposeConnection(connection);
Log.ConnectionDropped(Logger, _endpointName, ConnectionId, ex);
}
finally
{
_writeLock.Release();
// wait until all the connections are cleaned up to close the outgoing pipe
// Don't allow write anymore when the connection is disconnected
await _writeLock.WaitAsync();
try
{
// close the underlying connection
await DisposeConnection(connection);
}
finally
{
_writeLock.Release();
}
}
}
}
else
finally
{
Status = ServiceConnectionStatus.Disconnected;
_serviceConnectionStartTcs.TrySetResult(false);
_serviceConnectionOfflineTcs.TrySetResult(false);
}
}

public Task StopAsync()
{
try
{
// to avoid the connection hung in connecting state
_connectionStartCts.Cancel();
_connectionContext?.Transport.Input.CancelPendingRead();
}
catch (Exception ex)
Expand Down Expand Up @@ -277,7 +284,7 @@ public virtual async Task<bool> SafeWriteAsync(ServiceMessage serviceMessage)

public abstract bool TryRemoveClientConnection(string connectionId, out IClientConnection connection);

protected abstract Task<ConnectionContext> CreateConnection(string target = null);
protected abstract Task<ConnectionContext> CreateConnection(string target = null, CancellationToken cancellationToken = default);

protected abstract Task DisposeConnection(ConnectionContext connection);

Expand Down Expand Up @@ -492,11 +499,11 @@ private Task OnFlowControlMessageAsync(ConnectionFlowControlMessage flowControlM
throw new NotImplementedException($"Unsupported connection type: {flowControlMessage.ConnectionType}");
}

private async Task<ConnectionContext> EstablishConnectionAsync(string target)
private async Task<ConnectionContext> EstablishConnectionAsync(string target, CancellationToken cancellationToken)
{
try
{
var connectionContext = await CreateConnection(target);
var connectionContext = await CreateConnection(target, cancellationToken);
try
{
if (await HandshakeAsync(connectionContext))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ public async IAsyncEnumerable<Page<SignalRGroupConnection>> ListConnectionsInGro
public virtual Task OfflineAsync(GracefulShutdownMode mode, CancellationToken token)
{
_terminated = true;
return Task.WhenAll(ServiceConnections.Select(c => RemoveConnectionAsync(c, mode, token)));
return Task.WhenAll(ServiceConnections.Select(c => RemoveConnectionFromServiceAsync(c, mode, token)));
}

public virtual Task CloseClientConnections(CancellationToken token)
Expand Down Expand Up @@ -476,8 +476,23 @@ protected virtual ServiceConnectionStatus GetStatus()
: ServiceConnectionStatus.Disconnected;
}

protected async Task RemoveConnectionAsync(IServiceConnection c, GracefulShutdownMode mode, CancellationToken token)
/// <summary>
/// TODO: this logic sounds more fit into the serviceConnection class
/// </summary>
/// <param name="c">The service connection instance</param>
/// <param name="mode">The graceful shutdown mode</param>
/// <param name="token">The cancellation token</param>
/// <returns></returns>
protected async Task RemoveConnectionFromServiceAsync(IServiceConnection c, GracefulShutdownMode mode, CancellationToken token)
{
if (c.Status != ServiceConnectionStatus.Connected)
{
// if the connection is not yet connected
// we stop the connection in case it is connecting
// otherwise ConnectionOfflineTask should be set
await c.StopAsync();
return;
}
var retry = 0;
while (retry < MaxRetryRemoveSeverConnection && !token.IsCancellationRequested)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ public override async Task CloseClientConnections(CancellationToken token)
}
}

protected override Task<ConnectionContext> CreateConnection(string target = null)
protected override Task<ConnectionContext> CreateConnection(string target = null, CancellationToken cancellationToken = default)
{
return _connectionFactory.ConnectAsync(HubEndpoint, TransferFormat.Binary, ConnectionId, target);
return _connectionFactory.ConnectAsync(HubEndpoint, TransferFormat.Binary, ConnectionId, target, cancellationToken);
}

protected override Task DisposeConnection(ConnectionContext connection)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.AspNetCore.Connections;
Expand Down Expand Up @@ -83,7 +84,7 @@ public override async Task<bool> SafeWriteAsync(ServiceMessage serviceMessage)
return result;
}

protected override async Task<ConnectionContext> CreateConnection(string target = null)
protected override async Task<ConnectionContext> CreateConnection(string target = null, CancellationToken cancellationToken = default)
{
TestConnectionContext = await base.CreateConnection() as TestConnectionContext;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
Expand Down Expand Up @@ -93,7 +93,7 @@ protected override Task CleanupClientConnections(string? fromInstanceId = null)
return Task.CompletedTask;
}

protected override Task<ConnectionContext> CreateConnection(string? target = null)
protected override Task<ConnectionContext> CreateConnection(string? target = null, CancellationToken cancellationToken = default)
{
var pipeOptions = new PipeOptions();
var duplex = DuplexPipe.CreateConnectionPair(pipeOptions, pipeOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,19 @@ public Task StartAsync(string target = null)

public Task StopAsync()
{
_offline.TrySetResult(true);
return Task.CompletedTask;
}

public Task WriteAsync(ServiceMessage serviceMessage)
{
if (RuntimeServicePingMessage.IsFin(serviceMessage))
{
_offline.SetResult(true);
_offline.TrySetResult(true);
}
if (RuntimeServicePingMessage.IsGetServers(serviceMessage))
{
_serversPing.SetResult(true);
_serversPing.TrySetResult(true);
}
return Task.CompletedTask;
}
Expand Down
Loading