diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquirer.cs b/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquirer.cs index 26676da84..ab333dc78 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquirer.cs +++ b/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquirer.cs @@ -28,15 +28,23 @@ async Task ITokenAcquirer.GetTokenForUserAsync( { string? authenticationScheme = tokenAcquisitionOptions?.AuthenticationOptionsName ?? _authenticationScheme; + var effectiveOptions = GetEffectiveTokenAcquisitionOptions(tokenAcquisitionOptions, authenticationScheme, cancellationToken); var result = await _tokenAcquisition.GetAuthenticationResultForUserAsync( scopes, authenticationScheme, tokenAcquisitionOptions?.Tenant, tokenAcquisitionOptions?.UserFlow, user, - GetEffectiveTokenAcquisitionOptions(tokenAcquisitionOptions, authenticationScheme, cancellationToken) + effectiveOptions ).ConfigureAwait(false); + // Propagate LongRunningWebApiSessionKey (possibly auto-generated) back to the caller + if (tokenAcquisitionOptions is not null && effectiveOptions is not null + && !string.IsNullOrEmpty(effectiveOptions.LongRunningWebApiSessionKey)) + { + tokenAcquisitionOptions.LongRunningWebApiSessionKey = effectiveOptions.LongRunningWebApiSessionKey; + } + return new AcquireTokenResult( result.AccessToken, result.ExpiresOn, diff --git a/tests/Microsoft.Identity.Web.Test/TokenAcquirerTests.cs b/tests/Microsoft.Identity.Web.Test/TokenAcquirerTests.cs index 3fcab3b21..d9a1c776c 100644 --- a/tests/Microsoft.Identity.Web.Test/TokenAcquirerTests.cs +++ b/tests/Microsoft.Identity.Web.Test/TokenAcquirerTests.cs @@ -2,7 +2,9 @@ // Licensed under the MIT License. using System; +using System.Collections.Generic; using System.Reflection; +using System.Security.Claims; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Threading; @@ -102,6 +104,163 @@ public async Task GetTokenForAppAsync_WithBindingCertificate_ReturnsAcquireToken Assert.Equal(_bindingCertificate.Thumbprint, result.BindingCertificate.Thumbprint); } + [Fact] + public async Task GetTokenForUserAsync_WithAutoSessionKey_PropagatesGeneratedKeyBackToCaller() + { + // Arrange + const string autoGeneratedKey = "test-auto-generated-key"; + var authResult = CreateMockAuthenticationResult(); + var callerOptions = new AcquireTokenOptions + { + LongRunningWebApiSessionKey = AcquireTokenOptions.LongRunningWebApiSessionKeyAuto, + }; + + _tokenAcquisition.GetAuthenticationResultForUserAsync( + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(callInfo => + { + // simulate TokenAcquisition writing the auto-generated key back to the options + var options = callInfo.ArgAt(5); + if (options is not null) + { + options.LongRunningWebApiSessionKey = autoGeneratedKey; + } + + return Task.FromResult(authResult); + }); + + var tokenAcquirer = new TokenAcquirer(_tokenAcquisition, _authenticationScheme); + + // Act + await ((ITokenAcquirer)tokenAcquirer).GetTokenForUserAsync( + new[] { _scope }, + callerOptions, + user: null, + CancellationToken.None); + + // Assert + Assert.Equal(autoGeneratedKey, callerOptions.LongRunningWebApiSessionKey); + } + + [Fact] + public async Task GetTokenForUserAsync_WithExplicitSessionKey_PropagatesKeyBackToCaller() + { + // Arrange + const string explicitKey = "test-key"; + var authResult = CreateMockAuthenticationResult(); + var callerOptions = new AcquireTokenOptions + { + LongRunningWebApiSessionKey = explicitKey, + }; + + _tokenAcquisition.GetAuthenticationResultForUserAsync( + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(callInfo => + { + // simulate that TokenAcquisition does not modify the key when an explicit key is provided + return Task.FromResult(authResult); + }); + + var tokenAcquirer = new TokenAcquirer(_tokenAcquisition, _authenticationScheme); + + // Act + await ((ITokenAcquirer)tokenAcquirer).GetTokenForUserAsync( + new[] { _scope }, + callerOptions, + user: null, + CancellationToken.None); + + // Assert + Assert.Equal(explicitKey, callerOptions.LongRunningWebApiSessionKey); + } + + [Fact] + public async Task GetTokenForUserAsync_WithNoSessionKey_SessionKeyRemainsNull() + { + // Arrange + var authResult = CreateMockAuthenticationResult(); + var callerOptions = new AcquireTokenOptions + { + // LongRunningWebApiSessionKey is not set + }; + + _tokenAcquisition.GetAuthenticationResultForUserAsync( + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(authResult); + + var tokenAcquirer = new TokenAcquirer(_tokenAcquisition, _authenticationScheme); + + // Act + await ((ITokenAcquirer)tokenAcquirer).GetTokenForUserAsync( + new[] { _scope }, + callerOptions, + user: null, + CancellationToken.None); + + // Assert + Assert.Null(callerOptions.LongRunningWebApiSessionKey); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + public async Task GetTokenForUserAsync_WhenEffectiveKeyIsNullOrEmpty_DoesNotOverwriteCallerKey(string? effectiveKeyValue) + { + // Arrange + const string originalKey = "test-key"; + var authResult = CreateMockAuthenticationResult(); + var callerOptions = new AcquireTokenOptions + { + LongRunningWebApiSessionKey = originalKey, + }; + + _tokenAcquisition.GetAuthenticationResultForUserAsync( + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(callInfo => + { + // simulate TokenAcquisition setting the key to a null or empty value + var options = callInfo.ArgAt(5); + if (options is not null) + { + options.LongRunningWebApiSessionKey = effectiveKeyValue; + } + + return Task.FromResult(authResult); + }); + + var tokenAcquirer = new TokenAcquirer(_tokenAcquisition, _authenticationScheme); + + // Act + await ((ITokenAcquirer)tokenAcquirer).GetTokenForUserAsync( + new[] { _scope }, + callerOptions, + user: null, + CancellationToken.None); + + // Assert + Assert.Equal(originalKey, callerOptions.LongRunningWebApiSessionKey); + } + private AuthenticationResult CreateMockAuthenticationResult(X509Certificate2? bindingCertificate = null) { var authResult = new AuthenticationResult(