diff --git a/api/Microsoft.Azure.SignalR.Protocols.netstandard2.0.cs b/api/Microsoft.Azure.SignalR.Protocols.netstandard2.0.cs index 998c65c61..c4fb259b4 100644 --- a/api/Microsoft.Azure.SignalR.Protocols.netstandard2.0.cs +++ b/api/Microsoft.Azure.SignalR.Protocols.netstandard2.0.cs @@ -213,6 +213,12 @@ public abstract partial class ExtensibleServiceMessage : Microsoft.Azure.SignalR { protected ExtensibleServiceMessage() { } } + public partial class GetConnectionClaimsMessage : Microsoft.Azure.SignalR.Protocol.ExtensibleServiceMessage, Microsoft.Azure.SignalR.Protocol.IAckableMessage + { + public GetConnectionClaimsMessage(string connectionToken, int ackId) { } + public int AckId { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute] set { } } + public string ConnectionToken { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute] set { } } + } public partial class GroupBroadcastDataMessage : Microsoft.Azure.SignalR.Protocol.MulticastDataMessage, Microsoft.Azure.SignalR.Protocol.IPartitionableMessage { public GroupBroadcastDataMessage(string groupName, System.Collections.Generic.IDictionary> payloads, ulong? tracingId = default(ulong?)) : base (default(System.Collections.Generic.IDictionary>), default(ulong?)) { } @@ -489,6 +495,7 @@ public static partial class ServiceProtocolConstants public const int ConnectionFlowControlMessageType = 39; public const int ConnectionReconnectMessageType = 38; public const int ErrorCompletionMessageType = 36; + public const int GetConnectionClaimsMessageType = 42; public const int GroupBroadcastDataMessageType = 13; public const int GroupMemberQueryMessageType = 40; public const int HandshakeRequestType = 1; @@ -506,12 +513,19 @@ public static partial class ServiceProtocolConstants public const int ServiceErrorMessageType = 15; public const int ServiceEventMessageType = 22; public const int ServiceMappingMessageType = 37; + public const int UpdateConnectionClaimsMessageType = 43; public const int UserDataMessageType = 8; public const int UserJoinGroupMessageType = 16; public const int UserJoinGroupWithAckMessageType = 26; public const int UserLeaveGroupMessageType = 17; public const int UserLeaveGroupWithAckMessageType = 27; } + public partial class UpdateConnectionClaimsMessage : Microsoft.Azure.SignalR.Protocol.ExtensibleServiceMessage + { + public UpdateConnectionClaimsMessage(string connectionId, System.Security.Claims.Claim[]? claims) { } + public System.Security.Claims.Claim[]? Claims { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute] set { } } + public string ConnectionId { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute] set { } } + } public partial class UserDataMessage : Microsoft.Azure.SignalR.Protocol.MulticastDataMessage, Microsoft.Azure.SignalR.Protocol.IPartitionableMessage { public UserDataMessage(string userId, System.Collections.Generic.IDictionary> payloads, ulong? tracingId = default(ulong?)) : base (default(System.Collections.Generic.IDictionary>), default(ulong?)) { } diff --git a/specs/ServiceProtocol.md b/specs/ServiceProtocol.md index 45b874568..e86e3e6ef 100644 --- a/specs/ServiceProtocol.md +++ b/specs/ServiceProtocol.md @@ -657,3 +657,27 @@ MessagePack uses different formats to encode values. Refer to the [MessagePack F - ExtensionMembers - A MessagePack Map indicates the extensible members. #### Example: TODO + +### GetConnectionClaims Message +`GetConnectionClaims` messages have the following structure: +``` +[42, ConnectionToken, AckId, ExtensionMembers] +``` +- 42 - Message Type, indicating this is a `GetConnectionClaims` message. +- ConnectionToken - A `String` indicating the connection token of the live client connection whose current user claims are being fetched. +- AckId - An `Int32` encoding Id number to identify the corresponding ack message. +- ExtensionMembers - A MessagePack Map indicates the extensible members. + +#### Example: TODO + +### UpdateConnectionClaims Message +`UpdateConnectionClaims` messages have the following structure: +``` +[43, ConnectionId, Claims, ExtensionMembers] +``` +- 43 - Message Type, indicating this is an `UpdateConnectionClaims` message. +- ConnectionId - A `String` indicating the connection id of the live client connection whose user claims are being updated on the owning server. +- Claims - A MessagePack Map of `String` to `String` indicating the refreshed user claims to apply. +- ExtensionMembers - A MessagePack Map indicates the extensible members. + +#### Example: TODO diff --git a/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs b/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs index 7f55fb7a3..c7b8f7946 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs +++ b/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs @@ -365,6 +365,60 @@ public RefreshAuthMessage(string connectionToken, System.Security.Claims.Claim[] } } + /// + /// A read-only message to fetch the current user claims of an existing client connection. + /// + public class GetConnectionClaimsMessage : ExtensibleServiceMessage, IAckableMessage + { + /// + /// Gets or sets the connection token that identifies the live client connection whose claims are being fetched. + /// + public string ConnectionToken { get; set; } + + /// + /// Gets or sets the protocol correlation id used to acknowledge this read operation. + /// + public int AckId { get; set; } + + /// + /// Initializes a new instance of the class. + /// + /// The connection token that identifies the live client connection. + /// The protocol correlation id used to acknowledge this read operation. + public GetConnectionClaimsMessage(string connectionToken, int ackId) + { + ConnectionToken = connectionToken ?? throw new ArgumentNullException(nameof(connectionToken)); + AckId = ackId; + } + } + + /// + /// A server-bound message that pushes refreshed user claims of a client connection to its owning app server. + /// + public class UpdateConnectionClaimsMessage : ExtensibleServiceMessage + { + /// + /// Gets or sets the connection id of the live client connection whose claims are being updated. + /// + public string ConnectionId { get; set; } + + /// + /// Gets or sets the refreshed user claims to apply on the owning app server. + /// + public System.Security.Claims.Claim[]? Claims { get; set; } + + /// + /// Initializes a new instance of the class. + /// + /// The connection id of the live client connection. + /// The refreshed user claims to apply on the owning app server. + public UpdateConnectionClaimsMessage(string connectionId, System.Security.Claims.Claim[]? claims) + { + ConnectionId = connectionId ?? throw new ArgumentNullException(nameof(connectionId)); + Claims = claims; + } + } + /// /// A handshake request message. /// diff --git a/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs b/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs index 54f166f08..6a5c005a4 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs +++ b/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs @@ -159,6 +159,10 @@ public bool TryParseMessage(ref ReadOnlySequence input, out ServiceMessage return CreateGroupMemberQueryMessage(ref reader, arrayLength); case ServiceProtocolConstants.RefreshAuthMessageType: return CreateRefreshAuthMessage(ref reader, arrayLength); + case ServiceProtocolConstants.GetConnectionClaimsMessageType: + return CreateGetConnectionClaimsMessage(ref reader, arrayLength); + case ServiceProtocolConstants.UpdateConnectionClaimsMessageType: + return CreateUpdateConnectionClaimsMessage(ref reader, arrayLength); default: // Future protocol changes can add message types, old clients can ignore them return null; @@ -348,6 +352,12 @@ private static void WriteMessageCore(ref MessagePackWriter writer, ServiceMessag case RefreshAuthMessage refreshAuthMessage: WriteRefreshAuthMessage(ref writer, refreshAuthMessage); break; + case GetConnectionClaimsMessage getConnectionClaimsMessage: + WriteGetConnectionClaimsMessage(ref writer, getConnectionClaimsMessage); + break; + case UpdateConnectionClaimsMessage updateConnectionClaimsMessage: + WriteUpdateConnectionClaimsMessage(ref writer, updateConnectionClaimsMessage); + break; default: throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}"); } @@ -812,6 +822,36 @@ private static void WriteRefreshAuthMessage(ref MessagePackWriter writer, Refres message.WriteExtensionMembers(ref writer); } + private static void WriteGetConnectionClaimsMessage(ref MessagePackWriter writer, GetConnectionClaimsMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.GetConnectionClaimsMessageType); + writer.Write(message.ConnectionToken); + writer.Write(message.AckId); + message.WriteExtensionMembers(ref writer); + } + + private static void WriteUpdateConnectionClaimsMessage(ref MessagePackWriter writer, UpdateConnectionClaimsMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.UpdateConnectionClaimsMessageType); + writer.Write(message.ConnectionId); + if (message.Claims?.Length > 0) + { + writer.WriteMapHeader(message.Claims.Length); + foreach (var claim in message.Claims) + { + writer.Write(claim.Type); + writer.Write(claim.Value); + } + } + else + { + writer.WriteMapHeader(0); + } + message.WriteExtensionMembers(ref writer); + } + private static void WriteStringArray(ref MessagePackWriter writer, IReadOnlyList? array) { if (array?.Count > 0) @@ -1463,4 +1503,22 @@ private static RefreshAuthMessage CreateRefreshAuthMessage(ref MessagePackReader return message; } + private static GetConnectionClaimsMessage CreateGetConnectionClaimsMessage(ref MessagePackReader reader, int arrayLength) + { + var connectionToken = ReadStringNotNull(ref reader, "connectionToken"); + var ackId = ReadInt32(ref reader, "ackId"); + var message = new GetConnectionClaimsMessage(connectionToken, ackId); + message.ReadExtensionMembers(ref reader); + return message; + } + + private static UpdateConnectionClaimsMessage CreateUpdateConnectionClaimsMessage(ref MessagePackReader reader, int arrayLength) + { + var connectionId = ReadStringNotNull(ref reader, "connectionId"); + var claims = ReadClaims(ref reader); + var message = new UpdateConnectionClaimsMessage(connectionId, claims); + message.ReadExtensionMembers(ref reader); + return message; + } + } diff --git a/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocolConstants.cs b/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocolConstants.cs index 5a12c28a1..0c72ff5c7 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocolConstants.cs +++ b/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocolConstants.cs @@ -47,4 +47,6 @@ public static class ServiceProtocolConstants public const int ConnectionFlowControlMessageType = 39; public const int GroupMemberQueryMessageType = 40; public const int RefreshAuthMessageType = 41; + public const int GetConnectionClaimsMessageType = 42; + public const int UpdateConnectionClaimsMessageType = 43; } diff --git a/test/Microsoft.Azure.SignalR.Protocols.Tests/ServiceMessageEqualityComparer.cs b/test/Microsoft.Azure.SignalR.Protocols.Tests/ServiceMessageEqualityComparer.cs index 5c3852c9a..a1f72e63b 100644 --- a/test/Microsoft.Azure.SignalR.Protocols.Tests/ServiceMessageEqualityComparer.cs +++ b/test/Microsoft.Azure.SignalR.Protocols.Tests/ServiceMessageEqualityComparer.cs @@ -108,6 +108,10 @@ public bool Equals(ServiceMessage x, ServiceMessage y) return GroupMemberQueryMessageEqual(groupMemberQueryMessage, (GroupMemberQueryMessage)y); case RefreshAuthMessage refreshAuthMessage: return RefreshAuthMessageEqual(refreshAuthMessage, (RefreshAuthMessage)y); + case GetConnectionClaimsMessage getConnectionClaimsMessage: + return GetConnectionClaimsMessageEqual(getConnectionClaimsMessage, (GetConnectionClaimsMessage)y); + case UpdateConnectionClaimsMessage updateConnectionClaimsMessage: + return UpdateConnectionClaimsMessageEqual(updateConnectionClaimsMessage, (UpdateConnectionClaimsMessage)y); default: throw new InvalidOperationException($"Unknown message type: {x.GetType().FullName}"); } @@ -420,6 +424,18 @@ private static bool RefreshAuthMessageEqual(RefreshAuthMessage x, RefreshAuthMes x.ExpireTime.UtcDateTime == y.ExpireTime.UtcDateTime; } + private static bool GetConnectionClaimsMessageEqual(GetConnectionClaimsMessage x, GetConnectionClaimsMessage y) + { + return StringEqual(x.ConnectionToken, y.ConnectionToken) && + x.AckId == y.AckId; + } + + private static bool UpdateConnectionClaimsMessageEqual(UpdateConnectionClaimsMessage x, UpdateConnectionClaimsMessage y) + { + return StringEqual(x.ConnectionId, y.ConnectionId) && + ClaimsEqual(x.Claims, y.Claims); + } + private static bool StringEqual(string x, string y) { return string.Equals(x, y, StringComparison.Ordinal); diff --git a/test/Microsoft.Azure.SignalR.Protocols.Tests/ServiceProtocolFacts.cs b/test/Microsoft.Azure.SignalR.Protocols.Tests/ServiceProtocolFacts.cs index e61944a38..99cad3bf0 100644 --- a/test/Microsoft.Azure.SignalR.Protocols.Tests/ServiceProtocolFacts.cs +++ b/test/Microsoft.Azure.SignalR.Protocols.Tests/ServiceProtocolFacts.cs @@ -787,6 +787,17 @@ public static IEnumerable TestParseOldData new System.Security.Claims.Claim("role", "reader"), }, new DateTimeOffset(2024, 1, 1, 0, 0, 0, TimeSpan.Zero), 1), binary: "limlY29ubjGBpHJvbGWmcmVhZGVy1v9lkgCAAYA="), + new ProtocolTestData( + name: "GetConnectionClaimsMessage", + message: new GetConnectionClaimsMessage("conn1", 1), + binary: "lCqlY29ubjEBgA=="), + new ProtocolTestData( + name: "UpdateConnectionClaimsMessage", + message: new UpdateConnectionClaimsMessage("conn1", new[] + { + new System.Security.Claims.Claim("role", "reader"), + }), + binary: "lCulY29ubjGBpHJvbGWmcmVhZGVygA=="), }.ToDictionary(t => t.Name); #pragma warning restore CS0618 // Type or member is obsolete