Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions src/Microsoft.Azure.SignalR.Common/Utilities/RestApiEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@

using Microsoft.Extensions.Primitives;

namespace Microsoft.Azure.SignalR
#nullable enable

namespace Microsoft.Azure.SignalR;

internal class RestApiEndpoint
{
internal class RestApiEndpoint
{
public string Audience { get; }
public string Audience { get; }

public IDictionary<string, StringValues> Query { get; set; }
public IDictionary<string, StringValues>? Query { get; set; }

public RestApiEndpoint(string endpoint)
{
Audience = endpoint;
}
public RestApiEndpoint(string endpoint)
{
Audience = endpoint;
}
}
15 changes: 12 additions & 3 deletions src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder c
_httpClientFactory = httpClientFactory;
_payloadContentBuilder = contentBuilder;
}

// TODO: Test only, will remove later
internal RestClient(IHttpClientFactory httpClientFactory) : this(httpClientFactory, new JsonPayloadContentBuilder(new JsonObjectSerializer()))
{
Expand Down Expand Up @@ -64,7 +64,16 @@ public Task SendWithRetryAsync(
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, null, null, AsAsync(handleExpectedResponse), cancellationToken);
return SendWithRetryAsync(api, httpMethod, AsAsync(handleExpectedResponse), cancellationToken);
}

public Task SendWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, null, null, handleExpectedResponseAsync, cancellationToken);
}

public Task SendMessageWithRetryAsync(
Expand Down Expand Up @@ -186,7 +195,7 @@ private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMeth
return GenerateHttpRequest(api.Audience, api.Query, httpMethod, body, typeHint);
}

private HttpRequestMessage GenerateHttpRequest(string url, IDictionary<string, StringValues> query, HttpMethod httpMethod, HubMessage? body, Type? typeHint)
private HttpRequestMessage GenerateHttpRequest(string url, IDictionary<string, StringValues>? query, HttpMethod httpMethod, HubMessage? body, Type? typeHint)
{
var request = new HttpRequestMessage(httpMethod, GetUri(url, query));
request.Content = _payloadContentBuilder.Build(body, typeHint);
Expand Down
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ public RestApiEndpoint GetSendStreamCompletionEndpoint(string appName, string hu
return GenerateRestApiEndpoint(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/streams/{Uri.EscapeDataString(streamId)}/:complete");
}

public RestApiEndpoint GetListConnectionsInGroupEndpoint(string appName, string hubName, string groupName)
{
return GenerateRestApiEndpoint(appName, hubName, $"/groups/{Uri.EscapeDataString(groupName)}/connections");
}

private RestApiEndpoint GenerateRestApiEndpoint(string appName, string hubName, string pathAfterHub, IDictionary<string, StringValues> queries = null)
{
var requestPrefixWithHub = $"{_serverEndpoint}api/hubs/{Uri.EscapeDataString(hubName.ToLowerInvariant())}";
Expand Down
104 changes: 88 additions & 16 deletions src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,21 @@
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;

using Azure;

using Microsoft.AspNetCore.SignalR;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Primitives;

using static Microsoft.Azure.SignalR.Constants;

namespace Microsoft.Azure.SignalR.Management;

#nullable enable

internal class RestHubLifetimeManager<THub> : HubLifetimeManager<THub>, IServiceHubLifetimeManager<THub> where THub : Hub
{
private const string NullOrEmptyStringErrorMessage = "Argument cannot be null or empty.";
Expand Down Expand Up @@ -91,12 +93,12 @@ public async Task RemoveFromAllGroupsAsync(string connectionId, CancellationToke
await _restClient.SendWithRetryAsync(api, HttpMethod.Delete, handleExpectedResponse: null, cancellationToken: cancellationToken);
}

public override Task SendAllAsync(string methodName, object[] args, CancellationToken cancellationToken = default)
public override Task SendAllAsync(string methodName, object?[] args, CancellationToken cancellationToken = default)
{
return SendAllExceptAsync(methodName, args, null, cancellationToken);
}

public override async Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedConnectionIds, CancellationToken cancellationToken = default)
public override async Task SendAllExceptAsync(string methodName, object?[] args, IReadOnlyList<string>? excludedConnectionIds, CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(methodName))
{
Expand All @@ -107,7 +109,7 @@ public override async Task SendAllExceptAsync(string methodName, object[] args,
await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken);
}

public override async Task SendConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default)
public override async Task SendConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(methodName))
{
Expand All @@ -123,17 +125,17 @@ public override async Task SendConnectionAsync(string connectionId, string metho
await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken);
}

public override async Task SendConnectionsAsync(IReadOnlyList<string> connectionIds, string methodName, object[] args, CancellationToken cancellationToken = default)
public override async Task SendConnectionsAsync(IReadOnlyList<string> connectionIds, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
await Task.WhenAll(connectionIds.Select(id => SendConnectionAsync(id, methodName, args, cancellationToken)));
}

public override Task SendGroupAsync(string groupName, string methodName, object[] args, CancellationToken cancellationToken = default)
public override Task SendGroupAsync(string groupName, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
return SendGroupExceptAsync(groupName, methodName, args, null, cancellationToken);
}

public override async Task SendGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList<string> excludedConnectionIds, CancellationToken cancellationToken = default)
public override async Task SendGroupExceptAsync(string groupName, string methodName, object?[] args, IReadOnlyList<string>? excludedConnectionIds, CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(methodName))
{
Expand All @@ -149,9 +151,9 @@ public override async Task SendGroupExceptAsync(string groupName, string methodN
await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken);
}

public override async Task SendGroupsAsync(IReadOnlyList<string> groupNames, string methodName, object[] args, CancellationToken cancellationToken = default)
public override async Task SendGroupsAsync(IReadOnlyList<string> groupNames, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
Task all = null;
Task? all = null;

try
{
Expand All @@ -160,11 +162,11 @@ public override async Task SendGroupsAsync(IReadOnlyList<string> groupNames, str
}
catch
{
throw all.Exception;
throw all!.Exception!;
}
}

public override async Task SendUserAsync(string userId, string methodName, object[] args, CancellationToken cancellationToken = default)
public override async Task SendUserAsync(string userId, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(methodName))
{
Expand All @@ -180,9 +182,9 @@ public override async Task SendUserAsync(string userId, string methodName, objec
await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken);
}

public override async Task SendUsersAsync(IReadOnlyList<string> userIds, string methodName, object[] args, CancellationToken cancellationToken = default)
public override async Task SendUsersAsync(IReadOnlyList<string> userIds, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
Task all = null;
Task? all = null;

try
{
Expand All @@ -191,7 +193,7 @@ public override async Task SendUsersAsync(IReadOnlyList<string> userIds, string
}
catch
{
throw all.Exception;
throw all!.Exception!;
}
}

Expand Down Expand Up @@ -348,7 +350,7 @@ public async Task SendStreamCompletionAsync(string connectionId, string streamId
["error"] = error,
};
}
await _restClient.SendWithRetryAsync(api, HttpMethod.Post, cancellationToken: cancellationToken);
await _restClient.SendWithRetryAsync(api, HttpMethod.Post, handleExpectedResponse: null, cancellationToken: cancellationToken);
}

private static bool FilterExpectedResponse(HttpResponseMessage response, string expectedErrorCode) =>
Expand All @@ -357,6 +359,76 @@ private static bool FilterExpectedResponse(HttpResponseMessage response, string

public AsyncPageable<SignalRGroupConnection> ListConnectionsInGroup(string groupName, int? top = null, CancellationToken token = default)
{
throw new NotImplementedException();
if (string.IsNullOrWhiteSpace(groupName))
{
throw new ArgumentException($"'{nameof(groupName)}' cannot be null or whitespace.", nameof(groupName));
}

if (top < 0)
{
throw new ArgumentOutOfRangeException(nameof(top), "The value must be greater than or equal to 0.");
}

return new PageableGroupMember(FetchPages, token);

async IAsyncEnumerable<Page<SignalRGroupConnection>> FetchPages(string? continuationToken, int? pageSizeHint)
{
// Calculate the api for the first page
var api = _restApiProvider.GetListConnectionsInGroupEndpoint(_appName, _hubName, groupName);
if (top.HasValue)
{
api.Query = new Dictionary<string, StringValues>
{
["top"] = top.Value.ToString(CultureInfo.InvariantCulture),
};
}
if (pageSizeHint.HasValue)
{
api.Query ??= new Dictionary<string, StringValues>();
api.Query["maxPageSize"] = pageSizeHint.Value.ToString(CultureInfo.InvariantCulture);
}
if (!string.IsNullOrEmpty(continuationToken))
{
api.Query ??= new Dictionary<string, StringValues>();
api.Query["continuationToken"] = continuationToken;
}
do
{
Comment thread
Y-Sindo marked this conversation as resolved.
var page = await FetchSinglePage(api, token);
Comment thread
Y-Sindo marked this conversation as resolved.
continuationToken = page.ContinuationToken;
yield return page;
if (page.ContinuationToken == null)
{
yield break;
}
if (top != null)
{
top -= page.Values.Count;
if (top <= 0)
{
yield break;
}
}
// Actually it's the next link
api = new RestApiEndpoint(page.ContinuationToken);
} while (true);
}

async Task<Page<SignalRGroupConnection>> FetchSinglePage(RestApiEndpoint api, CancellationToken cancellationToken = default)
{
var page = default(Page<SignalRGroupConnection>);

await _restClient.SendWithRetryAsync(api, HttpMethod.Get, async response =>
{
if (!response.IsSuccessStatusCode)
{
return false;
Comment thread
Y-Sindo marked this conversation as resolved.
}
var contentStream = await response.Content.ReadAsStreamAsync();
page = await JsonSerializer.DeserializeAsync<GroupMemberQueryResultPage>(contentStream, cancellationToken: token);
return true;
}, cancellationToken: token);
return page!;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ internal async Task BroadcastExceptTest(ServiceTransportType serviceTransportTyp
var hubContext = await serviceManager.CreateHubContextAsync(HubName) as ServiceHubContextImpl;
var connectionCount = 3;
var tcsDict = new ConcurrentDictionary<string, TaskCompletionSource>();
logger.LogInformation($"Message is {msg}");
TestOutputHelper.WriteLine($"Message is {msg}");
var connections = await Task.WhenAll(Enumerable.Range(0, connectionCount).Select(async _ =>
{
var negotiationResponse = await hubContext.NegotiateAsync(null, default);
Expand All @@ -86,7 +86,7 @@ internal async Task BroadcastExceptTest(ServiceTransportType serviceTransportTyp
tcsDict.TryAdd(connection.ConnectionId, src);
connection.On(method, (string receivedMsg) =>
{
logger.LogInformation($"Connection {connection.ConnectionId} received msg : {receivedMsg}");
TestOutputHelper.WriteLine($"Connection {connection.ConnectionId} received msg : {receivedMsg}");
if (receivedMsg == msg)
{
src.SetResult();
Expand Down Expand Up @@ -206,7 +206,7 @@ internal async Task SendToGroupExceptTest(ServiceTransportType serviceTransportT
var hubContext = await serviceManager.CreateHubContextAsync(HubName) as ServiceHubContextImpl;
var connectionCount = 3;
var tcsDict = new ConcurrentDictionary<string, TaskCompletionSource>();
logger.LogInformation($"Message is {msg}");
TestOutputHelper.WriteLine($"Message is {msg}");
var connections = await Task.WhenAll(Enumerable.Range(0, connectionCount).Select(async _ =>
{
var negotiationResponse = await hubContext.NegotiateAsync(null, default);
Expand All @@ -216,7 +216,7 @@ internal async Task SendToGroupExceptTest(ServiceTransportType serviceTransportT
tcsDict.TryAdd(connection.ConnectionId, src);
connection.On(method, (string receivedMsg) =>
{
logger.LogInformation($"Connection {connection.ConnectionId} received msg : {receivedMsg}");
TestOutputHelper.WriteLine($"Connection {connection.ConnectionId} received msg : {receivedMsg}");
if (receivedMsg == msg)
{
src.SetResult();
Expand Down Expand Up @@ -912,6 +912,65 @@ public async Task RemoveNonexistentUserFromAllGroupsRestApiTest()
await context.UserGroups.RemoveFromAllGroupsAsync(Guid.NewGuid().ToString());
}

private static readonly IEnumerable<object[]> ListConnectionsInGroupTestData =
[
[6, 6, null, 6, 1],
[6, 3, null, 3, 1],
[6, null, 2, 6, 3],
[6, 5, 2, 5, 3],
];
public static readonly IEnumerable<object[]> ListConnectionsInGroupTestDataWithTransport =
from serviceTransportType in ServiceTransportType
from data in ListConnectionsInGroupTestData
select new object[] { serviceTransportType, data[0], data[1], data[2], data[3], data[4] };

[ConditionalTheory]
[SkipIfConnectionStringNotPresent]
[MemberData(nameof(ListConnectionsInGroupTestDataWithTransport))]
public async Task ListConnectionsInGroupTest(ServiceTransportType serviceTransportType, int totalConnectionCount, int? maxCountToList, int? maxPageSize, int expectedTotalCount, int expectedPageCount)
{
using var logger = StartLog(out var loggerFactory, nameof(ListConnectionsInGroupTest));
using var serviceManager = new ServiceManagerBuilder().WithOptions(o =>
{
o.ConnectionString = TestConfiguration.Instance.ConnectionString;
o.ServiceTransportType = serviceTransportType;
o.HttpClientTimeout = TimeSpan.FromHours(1);
})
.WithLoggerFactory(loggerFactory)
.BuildServiceManager();
using var hubContext = await serviceManager.CreateHubContextAsync(HubName, default);
var groupName = nameof(ListConnectionsInGroupTest) + Guid.NewGuid().ToString();
var negotationResponse = await hubContext.NegotiateAsync();
var clientConnections = await CreateAndStartClientConnections(negotationResponse.Url, Enumerable.Repeat(negotationResponse.AccessToken, totalConnectionCount));
foreach (var connection in clientConnections)
{
await hubContext.Groups.AddToGroupAsync(connection.ConnectionId, groupName);
TestOutputHelper.WriteLine("Created connection: " + connection.ConnectionId);
}

var actualPageCount = 0;
var actualConnectionCount = 0;

await foreach (var page in hubContext.Groups.ListConnectionsInGroup(groupName, maxCountToList).AsPages(null, maxPageSize))
{
//actualPageCount++;
actualConnectionCount += page.Values.Count;
actualPageCount++;
TestOutputHelper.WriteLine($"The {actualPageCount} page:");
foreach (var connection in page.Values)
{
TestOutputHelper.WriteLine($"Listed connection: {connection.ConnectionId}");
}
}

Assert.Equal(expectedPageCount, actualPageCount);
Assert.Equal(expectedTotalCount, actualConnectionCount);
foreach (var connection in clientConnections)
{
await connection.StopAsync();
}
}

private static IDictionary<string, List<string>> GenerateUserGroupDict(IList<string> userNames, IList<string> groupNames)
{
return (from i in Enumerable.Range(0, userNames.Count)
Expand Down
Loading