diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiEndpoint.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiEndpoint.cs index c569c8da0..2f66a1cae 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiEndpoint.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiEndpoint.cs @@ -5,17 +5,18 @@ using Microsoft.Extensions.Primitives; -namespace Microsoft.Azure.SignalR +#nullable enable + +namespace Microsoft.Azure.SignalR; + +internal class RestApiEndpoint { - internal class RestApiEndpoint - { - public string Audience { get; } + public string Audience { get; } - public IDictionary Query { get; set; } + public IDictionary? Query { get; set; } - public RestApiEndpoint(string endpoint) - { - Audience = endpoint; - } + public RestApiEndpoint(string endpoint) + { + Audience = endpoint; } } diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs index bb76ad482..866aa34a4 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs @@ -30,7 +30,7 @@ public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder c _httpClientFactory = httpClientFactory; _payloadContentBuilder = contentBuilder; } - + // TODO: Test only, will remove later internal RestClient(IHttpClientFactory httpClientFactory) : this(httpClientFactory, new JsonPayloadContentBuilder(new JsonObjectSerializer())) { @@ -64,7 +64,16 @@ public Task SendWithRetryAsync( Func? handleExpectedResponse = null, CancellationToken cancellationToken = default) { - return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, null, null, AsAsync(handleExpectedResponse), cancellationToken); + return SendWithRetryAsync(api, httpMethod, AsAsync(handleExpectedResponse), cancellationToken); + } + + public Task SendWithRetryAsync( + RestApiEndpoint api, + HttpMethod httpMethod, + Func>? handleExpectedResponseAsync = null, + CancellationToken cancellationToken = default) + { + return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, null, null, handleExpectedResponseAsync, cancellationToken); } public Task SendMessageWithRetryAsync( @@ -186,7 +195,7 @@ private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMeth return GenerateHttpRequest(api.Audience, api.Query, httpMethod, body, typeHint); } - private HttpRequestMessage GenerateHttpRequest(string url, IDictionary query, HttpMethod httpMethod, HubMessage? body, Type? typeHint) + private HttpRequestMessage GenerateHttpRequest(string url, IDictionary? query, HttpMethod httpMethod, HubMessage? body, Type? typeHint) { var request = new HttpRequestMessage(httpMethod, GetUri(url, query)); request.Content = _payloadContentBuilder.Build(body, typeHint); diff --git a/src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs b/src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs index e85df9b37..f424fd1be 100644 --- a/src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs +++ b/src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs @@ -100,6 +100,11 @@ public RestApiEndpoint GetSendStreamCompletionEndpoint(string appName, string hu return GenerateRestApiEndpoint(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/streams/{Uri.EscapeDataString(streamId)}/:complete"); } + public RestApiEndpoint GetListConnectionsInGroupEndpoint(string appName, string hubName, string groupName) + { + return GenerateRestApiEndpoint(appName, hubName, $"/groups/{Uri.EscapeDataString(groupName)}/connections"); + } + private RestApiEndpoint GenerateRestApiEndpoint(string appName, string hubName, string pathAfterHub, IDictionary queries = null) { var requestPrefixWithHub = $"{_serverEndpoint}api/hubs/{Uri.EscapeDataString(hubName.ToLowerInvariant())}"; diff --git a/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs b/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs index 10afefca2..68f24204f 100644 --- a/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs @@ -7,19 +7,21 @@ using System.Linq; using System.Net; using System.Net.Http; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Azure; using Microsoft.AspNetCore.SignalR; -using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Primitives; using static Microsoft.Azure.SignalR.Constants; namespace Microsoft.Azure.SignalR.Management; +#nullable enable + internal class RestHubLifetimeManager : HubLifetimeManager, IServiceHubLifetimeManager where THub : Hub { private const string NullOrEmptyStringErrorMessage = "Argument cannot be null or empty."; @@ -91,12 +93,12 @@ public async Task RemoveFromAllGroupsAsync(string connectionId, CancellationToke await _restClient.SendWithRetryAsync(api, HttpMethod.Delete, handleExpectedResponse: null, cancellationToken: cancellationToken); } - public override Task SendAllAsync(string methodName, object[] args, CancellationToken cancellationToken = default) + public override Task SendAllAsync(string methodName, object?[] args, CancellationToken cancellationToken = default) { return SendAllExceptAsync(methodName, args, null, cancellationToken); } - public override async Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList excludedConnectionIds, CancellationToken cancellationToken = default) + public override async Task SendAllExceptAsync(string methodName, object?[] args, IReadOnlyList? excludedConnectionIds, CancellationToken cancellationToken = default) { if (string.IsNullOrEmpty(methodName)) { @@ -107,7 +109,7 @@ public override async Task SendAllExceptAsync(string methodName, object[] args, await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken); } - public override async Task SendConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default) + public override async Task SendConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default) { if (string.IsNullOrEmpty(methodName)) { @@ -123,17 +125,17 @@ public override async Task SendConnectionAsync(string connectionId, string metho await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken); } - public override async Task SendConnectionsAsync(IReadOnlyList connectionIds, string methodName, object[] args, CancellationToken cancellationToken = default) + public override async Task SendConnectionsAsync(IReadOnlyList connectionIds, string methodName, object?[] args, CancellationToken cancellationToken = default) { await Task.WhenAll(connectionIds.Select(id => SendConnectionAsync(id, methodName, args, cancellationToken))); } - public override Task SendGroupAsync(string groupName, string methodName, object[] args, CancellationToken cancellationToken = default) + public override Task SendGroupAsync(string groupName, string methodName, object?[] args, CancellationToken cancellationToken = default) { return SendGroupExceptAsync(groupName, methodName, args, null, cancellationToken); } - public override async Task SendGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList excludedConnectionIds, CancellationToken cancellationToken = default) + public override async Task SendGroupExceptAsync(string groupName, string methodName, object?[] args, IReadOnlyList? excludedConnectionIds, CancellationToken cancellationToken = default) { if (string.IsNullOrEmpty(methodName)) { @@ -149,9 +151,9 @@ public override async Task SendGroupExceptAsync(string groupName, string methodN await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken); } - public override async Task SendGroupsAsync(IReadOnlyList groupNames, string methodName, object[] args, CancellationToken cancellationToken = default) + public override async Task SendGroupsAsync(IReadOnlyList groupNames, string methodName, object?[] args, CancellationToken cancellationToken = default) { - Task all = null; + Task? all = null; try { @@ -160,11 +162,11 @@ public override async Task SendGroupsAsync(IReadOnlyList groupNames, str } catch { - throw all.Exception; + throw all!.Exception!; } } - public override async Task SendUserAsync(string userId, string methodName, object[] args, CancellationToken cancellationToken = default) + public override async Task SendUserAsync(string userId, string methodName, object?[] args, CancellationToken cancellationToken = default) { if (string.IsNullOrEmpty(methodName)) { @@ -180,9 +182,9 @@ public override async Task SendUserAsync(string userId, string methodName, objec await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken); } - public override async Task SendUsersAsync(IReadOnlyList userIds, string methodName, object[] args, CancellationToken cancellationToken = default) + public override async Task SendUsersAsync(IReadOnlyList userIds, string methodName, object?[] args, CancellationToken cancellationToken = default) { - Task all = null; + Task? all = null; try { @@ -191,7 +193,7 @@ public override async Task SendUsersAsync(IReadOnlyList userIds, string } catch { - throw all.Exception; + throw all!.Exception!; } } @@ -348,7 +350,7 @@ public async Task SendStreamCompletionAsync(string connectionId, string streamId ["error"] = error, }; } - await _restClient.SendWithRetryAsync(api, HttpMethod.Post, cancellationToken: cancellationToken); + await _restClient.SendWithRetryAsync(api, HttpMethod.Post, handleExpectedResponse: null, cancellationToken: cancellationToken); } private static bool FilterExpectedResponse(HttpResponseMessage response, string expectedErrorCode) => @@ -357,6 +359,76 @@ private static bool FilterExpectedResponse(HttpResponseMessage response, string public AsyncPageable ListConnectionsInGroup(string groupName, int? top = null, CancellationToken token = default) { - throw new NotImplementedException(); + if (string.IsNullOrWhiteSpace(groupName)) + { + throw new ArgumentException($"'{nameof(groupName)}' cannot be null or whitespace.", nameof(groupName)); + } + + if (top < 0) + { + throw new ArgumentOutOfRangeException(nameof(top), "The value must be greater than or equal to 0."); + } + + return new PageableGroupMember(FetchPages, token); + + async IAsyncEnumerable> FetchPages(string? continuationToken, int? pageSizeHint) + { + // Calculate the api for the first page + var api = _restApiProvider.GetListConnectionsInGroupEndpoint(_appName, _hubName, groupName); + if (top.HasValue) + { + api.Query = new Dictionary + { + ["top"] = top.Value.ToString(CultureInfo.InvariantCulture), + }; + } + if (pageSizeHint.HasValue) + { + api.Query ??= new Dictionary(); + api.Query["maxPageSize"] = pageSizeHint.Value.ToString(CultureInfo.InvariantCulture); + } + if (!string.IsNullOrEmpty(continuationToken)) + { + api.Query ??= new Dictionary(); + api.Query["continuationToken"] = continuationToken; + } + do + { + var page = await FetchSinglePage(api, token); + continuationToken = page.ContinuationToken; + yield return page; + if (page.ContinuationToken == null) + { + yield break; + } + if (top != null) + { + top -= page.Values.Count; + if (top <= 0) + { + yield break; + } + } + // Actually it's the next link + api = new RestApiEndpoint(page.ContinuationToken); + } while (true); + } + + async Task> FetchSinglePage(RestApiEndpoint api, CancellationToken cancellationToken = default) + { + var page = default(Page); + + await _restClient.SendWithRetryAsync(api, HttpMethod.Get, async response => + { + if (!response.IsSuccessStatusCode) + { + return false; + } + var contentStream = await response.Content.ReadAsStreamAsync(); + page = await JsonSerializer.DeserializeAsync(contentStream, cancellationToken: token); + return true; + }, cancellationToken: token); + return page!; + } } } diff --git a/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs b/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs index 173e7ca79..de1ef2974 100644 --- a/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs +++ b/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs @@ -76,7 +76,7 @@ internal async Task BroadcastExceptTest(ServiceTransportType serviceTransportTyp var hubContext = await serviceManager.CreateHubContextAsync(HubName) as ServiceHubContextImpl; var connectionCount = 3; var tcsDict = new ConcurrentDictionary(); - logger.LogInformation($"Message is {msg}"); + TestOutputHelper.WriteLine($"Message is {msg}"); var connections = await Task.WhenAll(Enumerable.Range(0, connectionCount).Select(async _ => { var negotiationResponse = await hubContext.NegotiateAsync(null, default); @@ -86,7 +86,7 @@ internal async Task BroadcastExceptTest(ServiceTransportType serviceTransportTyp tcsDict.TryAdd(connection.ConnectionId, src); connection.On(method, (string receivedMsg) => { - logger.LogInformation($"Connection {connection.ConnectionId} received msg : {receivedMsg}"); + TestOutputHelper.WriteLine($"Connection {connection.ConnectionId} received msg : {receivedMsg}"); if (receivedMsg == msg) { src.SetResult(); @@ -206,7 +206,7 @@ internal async Task SendToGroupExceptTest(ServiceTransportType serviceTransportT var hubContext = await serviceManager.CreateHubContextAsync(HubName) as ServiceHubContextImpl; var connectionCount = 3; var tcsDict = new ConcurrentDictionary(); - logger.LogInformation($"Message is {msg}"); + TestOutputHelper.WriteLine($"Message is {msg}"); var connections = await Task.WhenAll(Enumerable.Range(0, connectionCount).Select(async _ => { var negotiationResponse = await hubContext.NegotiateAsync(null, default); @@ -216,7 +216,7 @@ internal async Task SendToGroupExceptTest(ServiceTransportType serviceTransportT tcsDict.TryAdd(connection.ConnectionId, src); connection.On(method, (string receivedMsg) => { - logger.LogInformation($"Connection {connection.ConnectionId} received msg : {receivedMsg}"); + TestOutputHelper.WriteLine($"Connection {connection.ConnectionId} received msg : {receivedMsg}"); if (receivedMsg == msg) { src.SetResult(); @@ -912,6 +912,65 @@ public async Task RemoveNonexistentUserFromAllGroupsRestApiTest() await context.UserGroups.RemoveFromAllGroupsAsync(Guid.NewGuid().ToString()); } + private static readonly IEnumerable ListConnectionsInGroupTestData = + [ + [6, 6, null, 6, 1], + [6, 3, null, 3, 1], + [6, null, 2, 6, 3], + [6, 5, 2, 5, 3], + ]; + public static readonly IEnumerable ListConnectionsInGroupTestDataWithTransport = + from serviceTransportType in ServiceTransportType + from data in ListConnectionsInGroupTestData + select new object[] { serviceTransportType, data[0], data[1], data[2], data[3], data[4] }; + + [ConditionalTheory] + [SkipIfConnectionStringNotPresent] + [MemberData(nameof(ListConnectionsInGroupTestDataWithTransport))] + public async Task ListConnectionsInGroupTest(ServiceTransportType serviceTransportType, int totalConnectionCount, int? maxCountToList, int? maxPageSize, int expectedTotalCount, int expectedPageCount) + { + using var logger = StartLog(out var loggerFactory, nameof(ListConnectionsInGroupTest)); + using var serviceManager = new ServiceManagerBuilder().WithOptions(o => + { + o.ConnectionString = TestConfiguration.Instance.ConnectionString; + o.ServiceTransportType = serviceTransportType; + o.HttpClientTimeout = TimeSpan.FromHours(1); + }) + .WithLoggerFactory(loggerFactory) + .BuildServiceManager(); + using var hubContext = await serviceManager.CreateHubContextAsync(HubName, default); + var groupName = nameof(ListConnectionsInGroupTest) + Guid.NewGuid().ToString(); + var negotationResponse = await hubContext.NegotiateAsync(); + var clientConnections = await CreateAndStartClientConnections(negotationResponse.Url, Enumerable.Repeat(negotationResponse.AccessToken, totalConnectionCount)); + foreach (var connection in clientConnections) + { + await hubContext.Groups.AddToGroupAsync(connection.ConnectionId, groupName); + TestOutputHelper.WriteLine("Created connection: " + connection.ConnectionId); + } + + var actualPageCount = 0; + var actualConnectionCount = 0; + + await foreach (var page in hubContext.Groups.ListConnectionsInGroup(groupName, maxCountToList).AsPages(null, maxPageSize)) + { + //actualPageCount++; + actualConnectionCount += page.Values.Count; + actualPageCount++; + TestOutputHelper.WriteLine($"The {actualPageCount} page:"); + foreach (var connection in page.Values) + { + TestOutputHelper.WriteLine($"Listed connection: {connection.ConnectionId}"); + } + } + + Assert.Equal(expectedPageCount, actualPageCount); + Assert.Equal(expectedTotalCount, actualConnectionCount); + foreach (var connection in clientConnections) + { + await connection.StopAsync(); + } + } + private static IDictionary> GenerateUserGroupDict(IList userNames, IList groupNames) { return (from i in Enumerable.Range(0, userNames.Count)