diff --git a/EXAMPLES.md b/EXAMPLES.md index a6e796c..b3aba25 100644 --- a/EXAMPLES.md +++ b/EXAMPLES.md @@ -6,6 +6,7 @@ - [Organizations](#organizations) - [Extra parameters](#extra-parameters) - [Roles](#roles) +- [Multiple Custom Domain (MCD) Support](#multiple-custom-domain-mcd-support) - [Backchannel Logout](#backchannel-logout) - [Blazor Server](#blazor-server) @@ -314,6 +315,162 @@ public IActionResult Admin() } ``` +## Multiple Custom Domain (MCD) Support + +Multiple Custom Domains (MCD) lets you resolve the Auth0 domain per request while keeping a single SDK instance. This is useful when one application serves multiple custom domains (for example, `brand-1.my-app.com` and `brand-2.my-app.com`), each mapped to a different `Auth0` custom domain. + +`MCD` is enabled by providing a `DomainResolver` function instead of a static domain string, enabling you to dynamically define the `Auth0` custom domain at run-time. + +Resolver mode is intended for the custom domains of a single `Auth0` tenant. It is not a supported way to connect multiple `Auth0` tenants to one application. + +### Dynamic Domain Resolver + +Provide a resolver function to select the domain at runtime. The resolver should return the `Auth0 Custom Domain` (for example, `brand-1.custom-domain.com`). Returning `null` or an empty value throws `InvalidOperationException`. + +### Configure with a DomainResolver + +Call `WithCustomDomains()` and provide a `DomainResolver` to resolve the domain dynamically based on the incoming request. The domain can be derived from a subdomain, request header, query parameter, or any other request attribute: + +```csharp +services.AddAuth0WebAppAuthentication(options => +{ + options.Domain = Configuration["Auth0:Domain"]; + options.ClientId = Configuration["Auth0:ClientId"]; +}) +.WithCustomDomains(options => +{ + // Example: resolve from a custom header + options.DomainResolver = httpContext => + { + var tenant = httpContext.Request.Headers["X-Tenant-Domain"].FirstOrDefault(); + return Task.FromResult(tenant ?? "default-tenant.auth0.com"); + }; +}); +``` + +### Resolve domain from subdomain + +```csharp +services.AddAuth0WebAppAuthentication(options => +{ + options.Domain = Configuration["Auth0:Domain"]; + options.ClientId = Configuration["Auth0:ClientId"]; +}) +.WithCustomDomains(options => +{ + // e.g., "acme.myapp.com" -> "acme.auth0.com" + options.DomainResolver = httpContext => + { + var host = httpContext.Request.Host.Host; + var subdomain = host.Split('.')[0]; + return Task.FromResult($"{subdomain}.auth0.com"); + }; +}); +``` + +### Redirect URI requirements + +When using MCD, the `redirectUri` must be an **absolute URL**. In MCD deployments, you will typically resolve the redirect URI per request so each domain uses the correct callback URL: + +```csharp +var authenticationProperties = new LoginAuthenticationPropertiesBuilder() + // Resolve redirect URI based on the incoming request's host + .WithRedirectUri($"{HttpContext.Request.Scheme}://{HttpContext.Request.Host}/callback") + .Build(); + +await HttpContext.ChallengeAsync(Auth0Constants.AuthenticationScheme, authenticationProperties); +``` + +You must validate the host and scheme safely for your deployment to prevent open redirect attacks. + +### Legacy sessions and migration + +When moving from a static domain setup to a `DomainResolver`, existing sessions can continue to work if the resolver returns the same Auth0 custom domain that was used for those legacy sessions. + +If the resolver returns a different domain, the SDK treats the session as missing and requires the user to sign in again. This is intentional to keep sessions isolated per domain. + +### Security requirements + +When configuring the `DomainResolver`, you are responsible for ensuring that all resolved domains are trusted. Mis-configuring the domain resolver is a critical security risk that can lead to authentication bypass on the relying party (RP) or expose the application to Server-Side Request Forgery (SSRF). + +**Single tenant limitation:** +The `DomainResolver` is intended solely for multiple custom domains belonging to the same Auth0 tenant. It is not a supported mechanism for connecting multiple Auth0 tenants to a single application. + +**Secure proxy requirement:** +When using MCD, your application must be deployed behind a secure edge or reverse proxy (e.g., Cloudflare, Nginx, or AWS ALB). The proxy must be configured to sanitize and overwrite `Host` and `X-Forwarded-Host` headers before they reach your application. + +Without a trusted proxy layer to validate these headers, an attacker can manipulate the domain resolution process. This can result in malicious redirects, where users are sent to unauthorized or fraudulent endpoints during the login and logout flows. + +### Configuration Manager Cache + +You can control how OpenID Connect configuration managers are cached per domain with `ConfigurationManagerCache`. + +By default, the SDK uses an in-memory cache with: +- `maxSize: 100` entries +- No expiration (entries remain until evicted by size pressure) + +The cache is keyed by the OIDC metadata endpoint URL (e.g., `https://brand-1.custom-domain.com/.well-known/openid-configuration`). Each distinct domain resolved by `DomainResolver` occupies one cache entry. + +Most applications can keep the defaults, but you may want to adjust them in the following cases: +- Increase `maxSize` if one process may verify tokens for more than 100 distinct domains during its lifetime. +- Decrease `maxSize` if memory usage matters more than avoiding repeated OIDC discovery setup. +- Set `slidingExpiration` if you want entries that haven't been accessed within a given duration to be evicted automatically. +- Use `NullConfigurationManagerCache` to disable caching entirely (not recommended for production). + +Rule of thumb: set `maxSize` to cover the number of distinct domains a single process is expected to serve, with some headroom. + +#### MemoryConfigurationManagerCache (Default) + +```csharp +.WithCustomDomains(options => +{ + options.DomainResolver = httpContext => { /* ... */ }; + + options.ConfigurationManagerCache = new MemoryConfigurationManagerCache( + maxSize: 100, // Maximum number of domains to cache + slidingExpiration: TimeSpan.FromHours(1) // Optional: evict entries not accessed within 1 hour + ); +}); +``` + +#### NullConfigurationManagerCache + +Disables caching entirely — a new configuration manager is created on every request (not recommended for production): + +```csharp +.WithCustomDomains(options => +{ + options.DomainResolver = httpContext => { /* ... */ }; + options.ConfigurationManagerCache = new NullConfigurationManagerCache(); +}); +``` + +#### Custom Cache Implementation + +Implement `IConfigurationManagerCache` for custom caching strategies (e.g., a distributed cache): + +```csharp +public class MyCustomConfigurationManagerCache : IConfigurationManagerCache +{ + public IConfigurationManager GetOrCreate( + string metadataAddress, + Func> factory) + { + // Return a cached instance or call factory(metadataAddress) to create one + } + + public void Clear() { /* Evict all entries */ } + public void Dispose() { /* Clean up resources */ } +} + +// Usage +.WithCustomDomains(options => +{ + options.DomainResolver = httpContext => { /* ... */ }; + options.ConfigurationManagerCache = new MyCustomConfigurationManagerCache(); +}); +``` + ## Backchannel Logout Backchannel logout can be configured by calling `WithBackchannelLogout()` when calling `AddAuth0WebAppAuthentication`. diff --git a/README.md b/README.md index ab577fb..20b78a1 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ A library based on `Microsoft.AspNetCore.Authentication.OpenIdConnect` to make i ![Downloads](https://img.shields.io/nuget/dt/auth0.aspnetcore.authentication) [![License](https://img.shields.io/:license-MIT-blue.svg?style=flat)](https://opensource.org/licenses/MIT) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/auth0/auth0-aspnetcore-authentication) -![AzureDevOps](https://img.shields.io/azure-devops/build/Auth0SDK/Auth0.AspNetCore.Authentication/8) +[![Build and Test](https://github.com/auth0/auth0-aspnetcore-authentication/actions/workflows/build.yml/badge.svg)](https://github.com/auth0/auth0-aspnetcore-authentication/actions/workflows/build.yml) :books: [Documentation](#documentation) - :rocket: [Getting Started](#getting-started) - :computer: [API Reference](#api-reference) - :speech_balloon: [Feedback](#feedback) @@ -15,12 +15,12 @@ A library based on `Microsoft.AspNetCore.Authentication.OpenIdConnect` to make i - [Quickstart](https://auth0.com/docs/quickstart/webapp/aspnet-core) - our interactive guide for quickly adding login, logout and user information to an ASP.NET MVC application using Auth0. - [Sample App](https://github.com/auth0-samples/auth0-aspnetcore-mvc-samples/tree/master/Quickstart/Sample) - a full-fledged ASP.NET MVC application integrated with Auth0. - [Examples](https://github.com/auth0/auth0-aspnetcore-authentication/blob/main/EXAMPLES.md) - code samples for common ASP.NET MVC authentication scenario's. -- [Docs site](https://www.auth0.com/docs) - explore our docs site and learn more about +- [Docs site](https://www.auth0.com/docs) - explore our docs site and learn more about Auth0. ## Getting started ### Requirements -This library supports .NET 6.0 and above. +This library supports .NET 6.0, 7.0, 8.0, and 10.0. ### Installation @@ -114,6 +114,33 @@ For more code samples on how to integrate the **auth0-aspnetcore-authentication* > This SDK also works with Blazor Server, for more info see [the Blazor Server section in our examples](https://github.com/auth0/auth0-aspnetcore-authentication/blob/main/EXAMPLES.md#blazor-server). +## Multiple Custom Domain (MCD) Support + +Multiple Custom Domains (MCD) lets you resolve the Auth0 domain per request while keeping a single SDK instance. This is useful when one application serves multiple custom domains (for example, `brand-1.my-app.com` and `brand-2.my-app.com`), each mapped to a different `Auth0` custom domain. + +Resolver mode is intended for the custom domains of a single `Auth0` tenant. It is not a supported way to connect multiple `Auth0` tenants to one application. + +### Configuration + +```csharp +services.AddAuth0WebAppAuthentication(options => +{ + options.Domain = Configuration["Auth0:Domain"]; + options.ClientId = Configuration["Auth0:ClientId"]; +}) +.WithCustomDomains(options => +{ + // Example: resolve from a custom header + options.DomainResolver = httpContext => + { + var tenant = httpContext.Request.Headers["X-Tenant-Domain"].FirstOrDefault(); + return Task.FromResult(tenant ?? "default-tenant.auth0.com"); + }; +}); +``` + +For detailed configuration options, caching strategies, security requirements, and more examples, see the [Multiple Custom Domain (MCD) Examples](EXAMPLES.md#multiple-custom-domain-mcd-support). + ## API reference Explore public API's available in auth0-aspnetcore-authentication. @@ -152,4 +179,4 @@ Please do not report security vulnerabilities on the public GitHub issue tracker

Auth0 is an easy to implement, adaptable authentication and authorization platform. To learn more checkout Why Auth0?

-This project is licensed under the MIT license. See the LICENSE file for more info.

\ No newline at end of file +This project is licensed under the MIT license. See the LICENSE file for more info.

\ No newline at end of file diff --git a/src/Auth0.AspNetCore.Authentication/Auth0Constants.cs b/src/Auth0.AspNetCore.Authentication/Auth0Constants.cs index 0377961..8d11564 100644 --- a/src/Auth0.AspNetCore.Authentication/Auth0Constants.cs +++ b/src/Auth0.AspNetCore.Authentication/Auth0Constants.cs @@ -14,5 +14,10 @@ public class Auth0Constants /// The callback path to which Auth0 should redirect back, used when configuring OpenIdConnect /// internal static string DefaultCallbackPath = "/callback"; + + /// + /// Key used to store the resolved domain in the authentication properties. + /// + internal static readonly string ResolvedDomainKey = "auth0:resolved-domain"; } } diff --git a/src/Auth0.AspNetCore.Authentication/Auth0WebAppAuthenticationBuilder.cs b/src/Auth0.AspNetCore.Authentication/Auth0WebAppAuthenticationBuilder.cs index b30ad36..c90611b 100644 --- a/src/Auth0.AspNetCore.Authentication/Auth0WebAppAuthenticationBuilder.cs +++ b/src/Auth0.AspNetCore.Authentication/Auth0WebAppAuthenticationBuilder.cs @@ -1,10 +1,13 @@ using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Options; using Microsoft.IdentityModel.Protocols.OpenIdConnect; using System; using System.Threading.Tasks; using Auth0.AspNetCore.Authentication.BackchannelLogout; +using Auth0.AspNetCore.Authentication.CustomDomains; +using Microsoft.AspNetCore.Hosting; namespace Auth0.AspNetCore.Authentication { @@ -64,6 +67,49 @@ public Auth0WebAppAuthenticationBuilder WithBackchannelLogout() return this; } + /// + /// Configures support for multiple Auth0 custom domains with dynamic domain resolution. + /// + /// A delegate used to configure the + /// An instance of + public Auth0WebAppAuthenticationBuilder WithCustomDomains(Action configureOptions) + { + EnableCustomDomains(configureOptions); + return this; + } + + private void EnableCustomDomains(Action configureOptions) + { + var customDomainsOptions = new Auth0CustomDomainsOptions(); + configureOptions(customDomainsOptions); + + // Validate that DomainResolver is configured + if (customDomainsOptions.DomainResolver == null) + { + throw new InvalidOperationException( + $"DomainResolver must be configured when using {nameof(WithCustomDomains)}. " + + $"Set the {nameof(Auth0CustomDomainsOptions.DomainResolver)} property to provide a function that resolves the Auth0 domain for each request."); + } + + // Register the options for this authentication scheme + _services.Configure(_authenticationScheme, configureOptions); + + // Register HttpContextAccessor - required for domain resolution + _services.AddHttpContextAccessor(); + + // Register HttpClient - required for fetching OIDC configuration per domain + _services.AddHttpClient(); + + // Register the startup filter to resolve domain early in the request pipeline + _services.TryAddEnumerable( + ServiceDescriptor.Singleton( + _ => new Auth0CustomDomainStartupFilter(_authenticationScheme))); + + // Register the post-configure options to set up custom ConfigurationManager + _services.TryAddEnumerable( + ServiceDescriptor.Singleton, Auth0CustomDomainsOpenIdConnectPostConfigureOptions>()); + } + private void EnableWithAccessToken(Action configureOptions) { var auth0WithAccessTokensOptions = new Auth0WebAppWithAccessTokenOptions(); diff --git a/src/Auth0.AspNetCore.Authentication/AuthenticationBuilderExtensions.cs b/src/Auth0.AspNetCore.Authentication/AuthenticationBuilderExtensions.cs index 3ebff6a..6058bf6 100644 --- a/src/Auth0.AspNetCore.Authentication/AuthenticationBuilderExtensions.cs +++ b/src/Auth0.AspNetCore.Authentication/AuthenticationBuilderExtensions.cs @@ -156,7 +156,16 @@ private static Func CreateOnValidatePrinci { await VerifyBackchannelLogoutSupport(context.HttpContext, oidcOptions); - var issuer = $"https://{options.Domain}/"; + // Prefer issuer from the authenticated principal + var resolvedIssuer = context.HttpContext.User?.FindFirst("iss")?.Value; + + // Fall back to the domain resolved by StartupFilter (cached in HttpContext.Items) + if (string.IsNullOrWhiteSpace(resolvedIssuer)) + { + resolvedIssuer = context.HttpContext.GetResolvedDomain(); + } + + var issuer = Utils.ToAuthority(resolvedIssuer ?? $"https://{options.Domain}/"); var sid = context.Principal?.FindFirst("sid")?.Value; var isLoggedOut = await logoutTokenHandler.IsLoggedOutAsync(issuer, sid); @@ -196,7 +205,7 @@ private static async Task RefreshTokenIfNeccesary(CookieValidatePrincipalContext if (isExpired && !string.IsNullOrWhiteSpace(refreshToken)) { - var result = await RefreshTokens(options, refreshToken, oidcOptions.Backchannel); + var result = await RefreshTokens(context.HttpContext, options, refreshToken, oidcOptions.Backchannel); if (result != null) { @@ -239,10 +248,14 @@ private static async Task RefreshTokenIfNeccesary(CookieValidatePrincipalContext } } - private static async Task RefreshTokens(Auth0WebAppOptions options, string refreshToken, HttpClient httpClient) + private static async Task RefreshTokens(HttpContext httpContext, Auth0WebAppOptions options, string refreshToken, HttpClient httpClient) { var tokenClient = new TokenClient(httpClient); - return await tokenClient.Refresh(options, refreshToken); + + // Get the resolved domain from HttpContext if available (for multiple custom domains) + var resolvedDomain = httpContext.GetResolvedDomain(); + + return await tokenClient.Refresh(options, refreshToken, resolvedDomain); } private static async Task VerifyBackchannelLogoutSupport(HttpContext context, OpenIdConnectOptions oidcOptions) diff --git a/src/Auth0.AspNetCore.Authentication/BackchannelLogout/BackchannelLogoutHandler.cs b/src/Auth0.AspNetCore.Authentication/BackchannelLogout/BackchannelLogoutHandler.cs index 0cf44e1..4f8a8e3 100644 --- a/src/Auth0.AspNetCore.Authentication/BackchannelLogout/BackchannelLogoutHandler.cs +++ b/src/Auth0.AspNetCore.Authentication/BackchannelLogout/BackchannelLogoutHandler.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Security.Claims; using System.Threading.Tasks; +using Auth0.AspNetCore.Authentication.CustomDomains; using Microsoft.AspNetCore.Authentication.Cookies; using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Microsoft.AspNetCore.Http; @@ -18,7 +19,7 @@ public class BackchannelLogoutHandler private readonly ILogoutTokenHandler _tokenHandler; private readonly string _authenticationScheme; - public BackchannelLogoutHandler(ILogoutTokenHandler tokenHandler) + public BackchannelLogoutHandler(ILogoutTokenHandler tokenHandler) : this(tokenHandler, Auth0Constants.AuthenticationScheme) { } @@ -48,7 +49,17 @@ public async Task HandleRequestAsync(HttpContext context) .GetRequiredService>() .Get(_authenticationScheme); - var principal = await ValidateLogoutToken(logoutToken, oidcOptions, context); + var customDomainsOptions = context.RequestServices + .GetService>() + ?.Get(_authenticationScheme); + var isMcdEnabled = customDomainsOptions?.IsMultipleCustomDomainsEnabled == true; + + if (isMcdEnabled) + { + ValidateIssuerMatchesResolvedDomain(logoutToken, context); + } + + var principal = await ValidateLogoutToken(logoutToken, oidcOptions, context, isMcdEnabled); if (principal != null) { @@ -84,7 +95,46 @@ await context.WriteErrorAsync(400, "invalid_request", } } - private async Task ValidateLogoutToken(String token, OpenIdConnectOptions oidcOptions, HttpContext context) + /// + /// When MCD is enabled, extracts the issuer from the unverified token and validates it matches + /// the domain resolved for the current request. This check happens BEFORE full JWT + /// validation so we avoid fetching JWKS for tokens from the wrong tenant. + /// + private static void ValidateIssuerMatchesResolvedDomain(string token, HttpContext context) + { + var unverifiedIssuer = ExtractUnverifiedIssuer(token); + + var resolvedDomain = context.GetResolvedDomain(); + + if (string.IsNullOrWhiteSpace(resolvedDomain)) + { + throw new LogoutTokenValidationException( + "Unable to resolve domain for this request. Ensure DomainResolver is configured."); + } + + var normalizedIssuer = Utils.ToAuthority(unverifiedIssuer); + var normalizedResolved = Utils.ToAuthority(resolvedDomain); + + if (!string.Equals(normalizedIssuer, normalizedResolved)) + { + throw new LogoutTokenValidationException("Logout token issuer does not match the resolved domain."); + } + } + + /// + /// Reads the JWT without signature validation to extract the issuer claim. + /// Throws if the token is malformed. + /// Note: This does not validate the signature or any other JWT claims. + /// + private static string ExtractUnverifiedIssuer(string token) + { + var handler = new JwtSecurityTokenHandler(); + if (!handler.CanReadToken(token)) + throw new LogoutTokenValidationException("Logout token is malformed or not a valid JWT."); + return handler.ReadJwtToken(token).Issuer; + } + + private async Task ValidateLogoutToken(string token, OpenIdConnectOptions oidcOptions, HttpContext context, bool isMcdEnabled) { OpenIdConnectConfiguration? configuration = null; @@ -93,13 +143,19 @@ private async Task ValidateLogoutToken(String token, OpenIdConn configuration = await oidcOptions.ConfigurationManager.GetConfigurationAsync(context.RequestAborted); } + var validIssuer = isMcdEnabled + ? Utils.ToAuthority(context.GetResolvedDomain() + ?? throw new LogoutTokenValidationException( + "Unable to resolve domain for this request. Ensure DomainResolver is configured.")) + : oidcOptions.TokenValidationParameters.ValidIssuer; + var tokenValidationParameters = new TokenValidationParameters { ValidateAudience = true, ValidateIssuer = true, ValidateLifetime = true, RequireExpirationTime = true, - ValidIssuer = oidcOptions.TokenValidationParameters.ValidIssuer, + ValidIssuer = validIssuer, ValidAudience = oidcOptions.TokenValidationParameters.ValidAudience, }; diff --git a/src/Auth0.AspNetCore.Authentication/CustomDomains/Auth0CustomDomainStartupFilter.cs b/src/Auth0.AspNetCore.Authentication/CustomDomains/Auth0CustomDomainStartupFilter.cs new file mode 100644 index 0000000..a477ee9 --- /dev/null +++ b/src/Auth0.AspNetCore.Authentication/CustomDomains/Auth0CustomDomainStartupFilter.cs @@ -0,0 +1,67 @@ +using System; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +namespace Auth0.AspNetCore.Authentication.CustomDomains; + +/// +/// A startup filter that integrates Auth0 custom domain resolution into the ASP.NET Core pipeline. +/// +internal sealed class Auth0CustomDomainStartupFilter : IStartupFilter +{ + private readonly string _auth0SchemeName; + + /// + /// Initializes a new instance of the class. + /// + /// The name of the Auth0 authentication scheme. + public Auth0CustomDomainStartupFilter(string auth0SchemeName) + => _auth0SchemeName = auth0SchemeName; + + /// + /// Configures the middleware pipeline to resolve and cache the Auth0 custom domain for each request. + /// + /// The next middleware configuration action in the pipeline. + /// An action that configures the application builder. + /// + /// This method registers middleware that: + /// + /// Runs before authentication to pre-resolve the domain + /// Invokes the configured DomainResolver if present + /// Caches the resolved domain in HttpContext.Items for the request lifetime + /// Throws if DomainResolver returns null/empty to fail fast + /// + /// The resolved domain is stored with key + /// and used by OpenIdConnect configuration managers and token endpoints. + /// + public Action Configure(Action next) + { + return app => + { + app.Use(async (ctx, nxt) => + { + // Retrieve the Auth0 custom domain options for the specified scheme. + var monitor = ctx.RequestServices.GetRequiredService>(); + var customDomainsOptions = monitor.Get(_auth0SchemeName); + + // If a DomainResolver is defined, resolve the issuer and cache it in the HttpContext. + if (customDomainsOptions.DomainResolver is not null) + { + var issuer = await customDomainsOptions.DomainResolver(ctx).ConfigureAwait(false); + if (string.IsNullOrWhiteSpace(issuer)) + throw new InvalidOperationException("DomainResolver returned empty issuer."); + + ctx.Items[Auth0Constants.ResolvedDomainKey] = issuer; + } + + // Proceed to the next middleware in the pipeline. + await nxt(); + }); + + // Invoke the next middleware configuration action. + next(app); + }; + } +} diff --git a/src/Auth0.AspNetCore.Authentication/CustomDomains/Auth0CustomDomainsOpenIdConnectConfigurationManager.cs b/src/Auth0.AspNetCore.Authentication/CustomDomains/Auth0CustomDomainsOpenIdConnectConfigurationManager.cs new file mode 100644 index 0000000..e179af4 --- /dev/null +++ b/src/Auth0.AspNetCore.Authentication/CustomDomains/Auth0CustomDomainsOpenIdConnectConfigurationManager.cs @@ -0,0 +1,299 @@ +using System; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; + +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Http; +using Microsoft.IdentityModel.Protocols; +using Microsoft.IdentityModel.Protocols.OpenIdConnect; + +namespace Auth0.AspNetCore.Authentication.CustomDomains; + +/// +/// A custom implementation of that maintains +/// separate OpenID Connect configurations per Auth0 custom domain. +/// +/// +/// Resolves configurations dynamically based on the domain associated with each request, +/// enabling support for multiple Auth0 custom domains within a single application instance. +/// Each domain's configuration is cached independently using the provided . +/// Is registered as a singleton and maintain its cache throughout the application lifetime. +/// +internal sealed class Auth0CustomDomainsOpenIdConnectConfigurationManager : IConfigurationManager, IDisposable +{ + private readonly IHttpContextAccessor _httpContextAccessor; + private readonly Func> _domainResolver; + private readonly ISecureDataFormat _stateDataFormat; + private readonly HttpClient _httpClient; + private readonly IConfigurationManagerCache _cache; + private readonly bool _ownsCache; + private bool _disposed; + + /// + /// Initializes a new instance of the class. + /// + /// The HTTP context accessor for retrieving the current request context. + /// The function to resolve the Auth0 domain from the HTTP context. + /// The secure data format for protecting/unprotecting authentication state. + /// The HTTP client for retrieving OpenID Connect configurations. + /// + /// The cache for configuration managers. If null, a default is used. + /// + /// Thrown when any required parameter is null. + public Auth0CustomDomainsOpenIdConnectConfigurationManager( + IHttpContextAccessor httpContextAccessor, + Func> domainResolver, + ISecureDataFormat stateDataFormat, + HttpClient httpClient, + IConfigurationManagerCache? cache = null) + { + _httpContextAccessor = httpContextAccessor ?? throw new ArgumentNullException(nameof(httpContextAccessor)); + _domainResolver = domainResolver ?? throw new ArgumentNullException(nameof(domainResolver)); + _stateDataFormat = stateDataFormat ?? throw new ArgumentNullException(nameof(stateDataFormat)); + _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient)); + + _cache = cache ?? new MemoryConfigurationManagerCache(); + _ownsCache = cache == null; + } + + /// + /// Retrieves the OpenID Connect configuration for the current request's resolved domain. + /// + /// A cancellation token to observe while waiting for the task to complete. + /// The OpenID Connect configuration for the resolved domain. + /// Thrown when HttpContext is unavailable, or domain resolution fails. + /// Thrown when the manager has been disposed. + public async Task GetConfigurationAsync(CancellationToken cancel) + { + ThrowIfDisposed(); + + var httpContext = _httpContextAccessor.HttpContext; + + if (httpContext == null) + { + throw new InvalidOperationException( + "HttpContext is not available. Ensure this method is called within an active HTTP request context."); + } + + var authority = await ResolveAuthorityAsync(httpContext).ConfigureAwait(false); + var metadataAddress = $"{authority.TrimEnd('/')}/.well-known/openid-configuration"; + + var manager = _cache.GetOrCreate(metadataAddress, CreateConfigurationManager); + + return await manager.GetConfigurationAsync(cancel).ConfigureAwait(false); + } + + /// + /// Requests that all cached configurations be refreshed on their next access. + /// + /// + /// Clears the cache, forcing new configuration managers to be created on subsequent requests. + /// + public void RequestRefresh() + { + if (_disposed) + { + return; + } + + _cache.Clear(); + } + + /// + /// Resolves the Auth0 authority (issuer URL) for the current request. + /// + /// The current HTTP context. + /// The resolved authority URL. + /// Thrown when domain resolution fails. + internal async Task ResolveAuthorityAsync(HttpContext context) + { + var hasState = TryGetState(context, out var state); + + // In case of a callback request, extracts the issuer from the state parameter. + if (hasState && TryGetIssuerFromState(state, out var stateIssuer)) + { + var stateAuthority = Utils.ToAuthority(stateIssuer); + + // Cross-validate: if the StartupFilter already resolved a domain for this request, + // ensure it matches the domain stored in the encrypted state. A mismatch indicates + // the request arrived on a different domain than the one that initiated the flow. + if (context.Items[Auth0Constants.ResolvedDomainKey] is string middlewareDomain && + !string.IsNullOrWhiteSpace(middlewareDomain)) + { + var middlewareAuthority = Utils.ToAuthority(middlewareDomain); + if (!stateAuthority.Equals(middlewareAuthority, StringComparison.OrdinalIgnoreCase)) + { + throw new InvalidOperationException( + $"Domain mismatch: the callback request arrived on domain '{middlewareDomain}' " + + $"but the authentication transaction was initiated with domain '{stateIssuer}'. " + + "This may indicate a cross-domain replay or misconfigured routing."); + } + } + + return stateAuthority; + } + + // If the request carries a state parameter (i.e. it looks like a callback) but the domain + // could not be extracted from state, fail explicitly rather than falling back to the + // DomainResolver, which could return a different domain than the one that started the flow. + if (hasState) + { + throw new InvalidOperationException( + "The request contains a 'state' parameter but the resolved domain could not be " + + "extracted from it. This may indicate a tampered, expired, or malformed state. " + + "The authentication transaction cannot be safely completed."); + } + + // Check if the domain was already resolved earlier in the request pipeline + if (context.Items[Auth0Constants.ResolvedDomainKey] is string cachedDomain && + !string.IsNullOrWhiteSpace(cachedDomain)) + { + return Utils.ToAuthority(cachedDomain); + } + + // Invoke the domain resolver to determine the domain for this request + var resolved = await _domainResolver(context).ConfigureAwait(false); + + if (string.IsNullOrWhiteSpace(resolved)) + { + throw new InvalidOperationException( + "DomainResolver returned a null or empty value. " + + "Ensure the configured resolver returns a valid Auth0 domain."); + } + + // Cache the resolved domain for subsequent use in this request + context.Items[Auth0Constants.ResolvedDomainKey] = resolved; + return Utils.ToAuthority(resolved); + } + + /// + /// Attempts to extract the state parameter from the incoming request. + /// + /// The HTTP context. + /// The extracted state value, if found. + /// True if state was found; otherwise, false. + /// + /// Checks both query string (GET requests) and form data (POST requests). + /// + internal static bool TryGetState(HttpContext context, out string? state) + { + // Check query string first (most common for OAuth/OIDC callbacks) + if (context.Request.Query.TryGetValue("state", out var queryState) && + !string.IsNullOrWhiteSpace(queryState)) + { + state = queryState.ToString(); + return true; + } + + // Check form data for POST callbacks + if (context.Request.HasFormContentType && + context.Request.Form.TryGetValue("state", out var formState) && + !string.IsNullOrWhiteSpace(formState)) + { + state = formState.ToString(); + return true; + } + + state = null; + return false; + } + + /// + /// Attempts to extract the issuer (domain) from a protected state parameter. + /// + /// The protected state string. + /// The extracted issuer, if found. + /// True if the issuer was successfully extracted; otherwise, false. + /// + /// This method safely handles malformed or tampered state parameters by catching + /// deserialization exceptions. This is expected behavior for invalid/expired state. + /// + internal bool TryGetIssuerFromState(string? state, out string issuer) + { + issuer = string.Empty; + + if (string.IsNullOrWhiteSpace(state)) + { + return false; + } + + AuthenticationProperties? props; + try + { + props = _stateDataFormat.Unprotect(state); + } + catch (Exception ex) when (ex is System.Security.Cryptography.CryptographicException or + FormatException or + ArgumentException) + { + // State parameter is invalid, malformed, or has been tampered with + // This is expected in certain scenarios (e.g., expired/corrupted state) + return false; + } + + if (props?.Items == null) + { + return false; + } + + if (props.Items.TryGetValue(Auth0Constants.ResolvedDomainKey, out var value) && + !string.IsNullOrWhiteSpace(value)) + { + issuer = value; + return true; + } + + return false; + } + + /// + /// Creates a new configuration manager for a specific metadata address. + /// + /// The OpenID Connect metadata endpoint URL. + /// A configured instance of . + internal IConfigurationManager CreateConfigurationManager(string address) + { + var retriever = new HttpDocumentRetriever(_httpClient) + { + RequireHttps = address.StartsWith("https://", StringComparison.OrdinalIgnoreCase) + }; + + return new ConfigurationManager( + address, + new OpenIdConnectConfigurationRetriever(), + retriever); + } + + /// + /// Throws an if this instance has been disposed. + /// + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(GetType().FullName); + } + } + + /// + /// Releases all resources used by this instance. + /// + /// + /// Disposes the cache only if it was created internally (not provided by the user). + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + _disposed = true; + + if (_ownsCache) + { + _cache.Dispose(); + } + } +} \ No newline at end of file diff --git a/src/Auth0.AspNetCore.Authentication/CustomDomains/Auth0CustomDomainsOpenIdConnectPostConfigureOptions.cs b/src/Auth0.AspNetCore.Authentication/CustomDomains/Auth0CustomDomainsOpenIdConnectPostConfigureOptions.cs new file mode 100644 index 0000000..0da8146 --- /dev/null +++ b/src/Auth0.AspNetCore.Authentication/CustomDomains/Auth0CustomDomainsOpenIdConnectPostConfigureOptions.cs @@ -0,0 +1,106 @@ +using System; +using System.Net.Http; +using Microsoft.AspNetCore.Authentication.OpenIdConnect; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Options; + +namespace Auth0.AspNetCore.Authentication.CustomDomains +{ + /// + /// Post-configures to support Auth0 multiple custom domains. + /// + /// + /// This configurator sets up a custom + /// that maintains separate OpenID Connect configurations per domain, enabling dynamic issuer resolution + /// based on the current request context. + /// + internal sealed class Auth0CustomDomainsOpenIdConnectPostConfigureOptions : IPostConfigureOptions + { + private readonly IHttpContextAccessor _httpContextAccessor; + private readonly IOptionsMonitor _auth0CustomDomainsOptionsMonitor; + private readonly IHttpClientFactory? _httpClientFactory; + + /// + /// Initializes a new instance of the class. + /// + /// The HTTP context accessor for retrieving the current request context. + /// The options monitor for Auth0 custom domains configuration. + /// Optional HTTP client factory for creating HTTP clients. + /// If not provided, the OpenIdConnect backchannel will be used. + /// Thrown when required parameters are null. + public Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + IHttpContextAccessor httpContextAccessor, + IOptionsMonitor auth0CustomDomainsOptionsMonitor, + IHttpClientFactory? httpClientFactory = null) + { + ArgumentNullException.ThrowIfNull(httpContextAccessor); + ArgumentNullException.ThrowIfNull(auth0CustomDomainsOptionsMonitor); + + _httpContextAccessor = httpContextAccessor; + _auth0CustomDomainsOptionsMonitor = auth0CustomDomainsOptionsMonitor; + _httpClientFactory = httpClientFactory; + } + + /// + /// Post-configures the specified with Auth0 custom domains support. + /// + /// The name of the options instance being configured. + /// The options instance to configure. + /// Thrown when StateDataFormat is not configured. + public void PostConfigure(string? name, OpenIdConnectOptions options) + { + if (string.IsNullOrEmpty(name)) + { + return; + } + + var auth0CustomDomainsOptions = _auth0CustomDomainsOptionsMonitor.Get(name); + + if (!auth0CustomDomainsOptions.IsMultipleCustomDomainsEnabled) + { + return; + } + + // Ensure DomainResolver is configured + if (auth0CustomDomainsOptions.DomainResolver is null) + { + throw new InvalidOperationException( + $"DomainResolver must be configured when custom domains are enabled. " + + $"Set the {nameof(Auth0CustomDomainsOptions.DomainResolver)} property in the {nameof(Auth0CustomDomainsOptions)} configuration."); + } + + // Ensure we have a StateDataFormat for extracting the issuer on callback requests. + if (options.StateDataFormat is null) + { + throw new InvalidOperationException( + $"OpenIdConnectOptions.StateDataFormat is not configured. " + + $"This is required for Auth0 custom domains support. " + + $"Ensure the OpenIdConnect authentication scheme is properly configured."); + } + + if (options.Backchannel is null && _httpClientFactory is null) + { + throw new InvalidOperationException( + $"Either OpenIdConnectOptions.Backchannel or IHttpClientFactory must be configured. " + + $"Configure a Backchannel HttpClient on OpenIdConnectOptions or register IHttpClientFactory in the service collection."); + } + + var httpClient = options.Backchannel ?? _httpClientFactory!.CreateClient(); + + options.ConfigurationManager = new Auth0CustomDomainsOpenIdConnectConfigurationManager( + _httpContextAccessor, + auth0CustomDomainsOptions.DomainResolver, + options.StateDataFormat, + httpClient, + auth0CustomDomainsOptions.ConfigurationManagerCache); + + // The issuer varies per request, so we can't validate against a single static issuer string. + // Issuer validation will instead be performed via the OnTokenValidated event. + options.TokenValidationParameters.ValidateIssuer = false; + + // Since Domain Resolver is set, this value will be set dynamically, so we clear it here. + options.Authority = null; + + } + } +} diff --git a/src/Auth0.AspNetCore.Authentication/CustomDomains/Auth0CustomDomainsOptions.cs b/src/Auth0.AspNetCore.Authentication/CustomDomains/Auth0CustomDomainsOptions.cs new file mode 100644 index 0000000..40e0371 --- /dev/null +++ b/src/Auth0.AspNetCore.Authentication/CustomDomains/Auth0CustomDomainsOptions.cs @@ -0,0 +1,64 @@ +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; + +namespace Auth0.AspNetCore.Authentication.CustomDomains; + +/// +/// Options for configuring Auth0 custom domains support. +/// +public class Auth0CustomDomainsOptions +{ + /// + /// Resolves the Domain (issuer) for the current request. + /// + /// + /// This function is called for each authentication request to dynamically determine + /// which Auth0 custom domain should handle the request. The returned value should + /// be just the domain without protocol or paths. + /// + /// + /// + /// options.DomainResolver = async (context) => + /// { + /// var tenant = context.Request.Host.Host.Split('.').First(); + /// return $"{tenant}.auth0.com"; + /// }; + /// + /// + public Func>? DomainResolver { get; set; } + + /// + /// Cache implementation for OpenID Connect configuration managers. + /// + /// + /// + /// If not set, a default is used with + /// 100 entries and no expiration. + /// + /// + /// To customize the default cache settings: + /// + /// options.ConfigurationManagerCache = new MemoryConfigurationManagerCache( + /// maxSize: 50, + /// slidingExpiration: TimeSpan.FromHours(1) + /// ); + /// + /// + /// + /// To disable caching: + /// + /// options.ConfigurationManagerCache = new NullConfigurationManagerCache(); + /// + /// + /// + /// To provide a custom cache implementation, implement . + /// + /// + public IConfigurationManagerCache? ConfigurationManagerCache { get; set; } + + /// + /// Indicates whether multiple custom domains are enabled by checking if is set. + /// + internal bool IsMultipleCustomDomainsEnabled => DomainResolver != null; +} \ No newline at end of file diff --git a/src/Auth0.AspNetCore.Authentication/CustomDomains/IConfigurationManagerCache.cs b/src/Auth0.AspNetCore.Authentication/CustomDomains/IConfigurationManagerCache.cs new file mode 100644 index 0000000..b9ea370 --- /dev/null +++ b/src/Auth0.AspNetCore.Authentication/CustomDomains/IConfigurationManagerCache.cs @@ -0,0 +1,56 @@ +using System; +using Microsoft.IdentityModel.Protocols; +using Microsoft.IdentityModel.Protocols.OpenIdConnect; + +namespace Auth0.AspNetCore.Authentication.CustomDomains; + +/// +/// Abstraction for caching OpenID Connect configuration managers. +/// +/// +/// Implement this interface to provide custom caching behavior for configuration managers. +/// The SDK provides two built-in implementations: +/// +/// - Default in-memory cache using MemoryCache +/// - No-op cache that disables caching +/// +/// +/// +/// +/// // Custom implementation example +/// public class CustomCache : IConfigurationManagerCache +/// { +/// public IConfigurationManager<OpenIdConnectConfiguration> GetOrCreate( +/// string metadataAddress, +/// Func<string, IConfigurationManager<OpenIdConnectConfiguration>> factory) +/// { +/// // Your caching logic here +/// // returns a ConfigurationManager Instance from cache or create a new one; +/// } +/// +/// public void Clear() { /* Clear your cache */ } +/// public void Dispose() { /* Cleanup resources */ } +/// } +/// +/// +public interface IConfigurationManagerCache : IDisposable +{ + /// + /// Gets an existing configuration manager from the cache or creates a new one using the factory. + /// + /// The OIDC metadata endpoint URL, used as the cache key. + /// Factory function to create a new configuration manager if not cached. + /// The cached or newly created configuration manager. + IConfigurationManager GetOrCreate( + string metadataAddress, + Func> factory); + + /// + /// Clears all cached entries. + /// + /// + /// This method is called when is invoked + /// on the parent configuration manager. + /// + void Clear(); +} diff --git a/src/Auth0.AspNetCore.Authentication/CustomDomains/MemoryConfigurationManagerCache.cs b/src/Auth0.AspNetCore.Authentication/CustomDomains/MemoryConfigurationManagerCache.cs new file mode 100644 index 0000000..bf8e673 --- /dev/null +++ b/src/Auth0.AspNetCore.Authentication/CustomDomains/MemoryConfigurationManagerCache.cs @@ -0,0 +1,135 @@ +using System; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.IdentityModel.Protocols; +using Microsoft.IdentityModel.Protocols.OpenIdConnect; + +namespace Auth0.AspNetCore.Authentication.CustomDomains; + +/// +/// Default in-memory cache implementation using . +/// +/// +/// Default cache used when no custom is provided. +/// Supports configurable size limits and sliding expiration. +/// +/// +/// +/// // With default settings (100 entries, no expiration) +/// options.ConfigurationManagerCache = new MemoryConfigurationManagerCache(); +/// +/// // With custom settings +/// options.ConfigurationManagerCache = new MemoryConfigurationManagerCache( +/// maxSize: 50, +/// slidingExpiration: TimeSpan.FromHours(1) +/// ); +/// +/// +public sealed class MemoryConfigurationManagerCache : IConfigurationManagerCache +{ + /// + /// The default maximum number of entries in the cache. + /// + public const int DefaultMaxSize = 100; + + private readonly MemoryCache _cache; + private readonly TimeSpan? _slidingExpiration; + private readonly object _lock = new(); + private bool _disposed; + + /// + /// Initializes a new instance of the class. + /// + /// + /// Maximum number of configuration managers to cache. Default is 100. + /// + /// + /// Optional sliding expiration for cached entries. When set, entries that haven't + /// been accessed within this duration will be removed on the next cache operation. + /// + /// Thrown when maxSize is less than 1. + public MemoryConfigurationManagerCache( + int maxSize = DefaultMaxSize, + TimeSpan? slidingExpiration = null) + { + if (maxSize < 1) + { + throw new ArgumentOutOfRangeException(nameof(maxSize), "MaxSize must be at least 1."); + } + + _slidingExpiration = slidingExpiration; + _cache = new MemoryCache(new MemoryCacheOptions + { + SizeLimit = maxSize + }); + } + + /// + public IConfigurationManager GetOrCreate( + string metadataAddress, + Func> factory) + { + ThrowIfDisposed(); + + // Fast path: check cache without lock + if (_cache.TryGetValue(metadataAddress, out IConfigurationManager? cached) && cached != null) + { + return cached; + } + + // Slow path: acquire lock and double-check + lock (_lock) + { + // Double-check after acquiring lock + if (_cache.TryGetValue(metadataAddress, out cached) && cached != null) + { + return cached; + } + + var manager = factory(metadataAddress); + + var cacheOptions = new MemoryCacheEntryOptions { Size = 1 }; + if (_slidingExpiration.HasValue) + { + cacheOptions.SlidingExpiration = _slidingExpiration.Value; + } + + _cache.Set(metadataAddress, manager, cacheOptions); + + return manager; + } + } + + /// + public void Clear() + { + if (_disposed) + { + return; + } + + _cache.Compact(1.0); + } + + /// + /// Throws an if this instance has been disposed. + /// + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(GetType().FullName); + } + } + + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + _disposed = true; + _cache.Dispose(); + } +} diff --git a/src/Auth0.AspNetCore.Authentication/CustomDomains/NullConfigurationManagerCache.cs b/src/Auth0.AspNetCore.Authentication/CustomDomains/NullConfigurationManagerCache.cs new file mode 100644 index 0000000..97f9c48 --- /dev/null +++ b/src/Auth0.AspNetCore.Authentication/CustomDomains/NullConfigurationManagerCache.cs @@ -0,0 +1,51 @@ +using System; +using Microsoft.IdentityModel.Protocols; +using Microsoft.IdentityModel.Protocols.OpenIdConnect; + +namespace Auth0.AspNetCore.Authentication.CustomDomains; + +/// +/// A pass-through cache implementation that does not cache configuration managers. +/// +/// +/// Use this implementation when caching should be completely disabled. +/// Every call to will invoke the factory to create a new +/// configuration manager, which may impact performance but ensures fresh configuration manager instances. +/// +/// +/// +/// // Disable caching +/// options.ConfigurationManagerCache = new NullConfigurationManagerCache(); +/// +/// +public sealed class NullConfigurationManagerCache : IConfigurationManagerCache +{ + /// + /// + /// This implementation always invokes the factory and never caches the result. + /// + public IConfigurationManager GetOrCreate( + string metadataAddress, + Func> factory) + { + return factory(metadataAddress); + } + + /// + /// + /// This is a no-op since nothing is cached. + /// + public void Clear() + { + // No-op: nothing to clear + } + + /// + /// + /// This is a no-op since there are no resources to dispose. + /// + public void Dispose() + { + // No-op: nothing to dispose + } +} diff --git a/src/Auth0.AspNetCore.Authentication/Extensions.cs b/src/Auth0.AspNetCore.Authentication/Extensions.cs new file mode 100644 index 0000000..297fa23 --- /dev/null +++ b/src/Auth0.AspNetCore.Authentication/Extensions.cs @@ -0,0 +1,20 @@ +using Microsoft.AspNetCore.Http; + +namespace Auth0.AspNetCore.Authentication; + +internal static class Extensions +{ + /// + /// Retrieves the resolved domain from the collection. + /// + /// The current HTTP context. + /// + /// The resolved domain as a string if present; otherwise, null. + /// + internal static string? GetResolvedDomain(this HttpContext httpContext) + { + return httpContext.Items.TryGetValue(Auth0Constants.ResolvedDomainKey, out var domainObj) + ? domainObj as string + : null; + } +} \ No newline at end of file diff --git a/src/Auth0.AspNetCore.Authentication/OpenIdConnectEventsFactory.cs b/src/Auth0.AspNetCore.Authentication/OpenIdConnectEventsFactory.cs index 21b4697..609c3c8 100644 --- a/src/Auth0.AspNetCore.Authentication/OpenIdConnectEventsFactory.cs +++ b/src/Auth0.AspNetCore.Authentication/OpenIdConnectEventsFactory.cs @@ -1,10 +1,14 @@ using Microsoft.AspNetCore.Authentication.OpenIdConnect; +using Microsoft.AspNetCore.Http; using Microsoft.IdentityModel.Tokens; using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using Auth0.AspNetCore.Authentication.PushedAuthorizationRequest; +using Auth0.AspNetCore.Authentication.CustomDomains; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; namespace Auth0.AspNetCore.Authentication { @@ -14,13 +18,19 @@ internal static OpenIdConnectEvents Create(Auth0WebAppOptions auth0Options, Open { return new OpenIdConnectEvents { - OnRedirectToIdentityProvider = ProxyEvent(auth0Options.OpenIdConnectEvents?.OnRedirectToIdentityProvider, CreateOnRedirectToIdentityProvider(auth0Options, oidcOptions)), - OnRedirectToIdentityProviderForSignOut = ProxyEvent(auth0Options.OpenIdConnectEvents?.OnRedirectToIdentityProviderForSignOut, CreateOnRedirectToIdentityProviderForSignOut(auth0Options)), - OnTokenValidated = ProxyEvent(auth0Options.OpenIdConnectEvents?.OnTokenValidated, CreateOnTokenValidated(auth0Options)), + OnRedirectToIdentityProvider = + ProxyEvent(auth0Options.OpenIdConnectEvents?.OnRedirectToIdentityProvider, + CreateOnRedirectToIdentityProvider(auth0Options, oidcOptions)), + OnRedirectToIdentityProviderForSignOut = ProxyEvent( + auth0Options.OpenIdConnectEvents?.OnRedirectToIdentityProviderForSignOut, + CreateOnRedirectToIdentityProviderForSignOut(auth0Options)), + OnTokenValidated = ProxyEvent(auth0Options.OpenIdConnectEvents?.OnTokenValidated, + CreateOnTokenValidated(auth0Options)), OnAccessDenied = ProxyEvent(auth0Options.OpenIdConnectEvents?.OnAccessDenied), OnAuthenticationFailed = ProxyEvent(auth0Options.OpenIdConnectEvents?.OnAuthenticationFailed), - OnAuthorizationCodeReceived = ProxyEvent(auth0Options.OpenIdConnectEvents?.OnAuthorizationCodeReceived, CreateOnAuthorizationCodeReceived(auth0Options)), + OnAuthorizationCodeReceived = ProxyEvent(auth0Options.OpenIdConnectEvents?.OnAuthorizationCodeReceived, + CreateOnAuthorizationCodeReceived(auth0Options)), OnMessageReceived = ProxyEvent(auth0Options.OpenIdConnectEvents?.OnMessageReceived), OnRemoteFailure = ProxyEvent(auth0Options.OpenIdConnectEvents?.OnRemoteFailure), OnRemoteSignOut = ProxyEvent(auth0Options.OpenIdConnectEvents?.OnRemoteSignOut), @@ -47,10 +57,36 @@ private static Func ProxyEvent(Func? originalHandler, Func< }; } - private static Func CreateOnRedirectToIdentityProvider(Auth0WebAppOptions auth0Options, OpenIdConnectOptions oidcOptions) + private static Func CreateOnRedirectToIdentityProvider(Auth0WebAppOptions auth0Options, + OpenIdConnectOptions oidcOptions) { return async (context) => { + // Store the resolved domain in the authentication state (Properties.Items) so it can be validated + // when the token returns. The StartupFilter already resolved it and cached it in HttpContext.Items. + var customDomainsOptions = context.HttpContext.RequestServices + .GetService>() + ?.Get(context.Scheme.Name); + + if (customDomainsOptions is { IsMultipleCustomDomainsEnabled: true }) + { + var resolvedDomain = context.HttpContext.GetResolvedDomain(); + if (string.IsNullOrWhiteSpace(resolvedDomain)) + { + // Cannot proceed without a resolved domain — issuer validation would fail later anyway. + // Fail early with a clear message rather than allowing a round-trip to Auth0. + context.HandleResponse(); + context.Response.StatusCode = 500; + await context.Response.WriteAsync( + "Authentication configuration error: could not resolve the domain for this request. " + + "Ensure the DomainResolver is configured and the Auth0 middleware is registered in the pipeline."); + return; + } + + // Adds to the encrypted state parameter that will be available even in callbacks + context.Properties.Items[Auth0Constants.ResolvedDomainKey] = resolvedDomain; + } + // Set auth0Client querystring parameter for /authorize context.ProtocolMessage.SetParameter("auth0Client", Utils.CreateAgentString()); @@ -59,7 +95,8 @@ private static Func CreateOnRedirectToIdentityProvider(Au context.ProtocolMessage.SetParameter(extraParam.Key, extraParam.Value); } - if (!string.IsNullOrWhiteSpace(auth0Options.Organization) && !context.Properties.Items.ContainsKey(Auth0AuthenticationParameters.Organization)) + if (!string.IsNullOrWhiteSpace(auth0Options.Organization) && + !context.Properties.Items.ContainsKey(Auth0AuthenticationParameters.Organization)) { context.Properties.Items[Auth0AuthenticationParameters.Organization] = auth0Options.Organization; } @@ -71,11 +108,23 @@ private static Func CreateOnRedirectToIdentityProvider(Au }; } - private static Func CreateOnRedirectToIdentityProviderForSignOut(Auth0WebAppOptions auth0Options) + private static Func CreateOnRedirectToIdentityProviderForSignOut( + Auth0WebAppOptions auth0Options) { return (context) => { - var logoutUri = $"https://{auth0Options.Domain}/v2/logout?client_id={auth0Options.ClientId}"; + // Prefer issuer from the authenticated principal + var issuer = context.HttpContext.User?.FindFirst("iss")?.Value; + + // Fall back to the domain resolved by StartupFilter (cached in HttpContext.Items) + if (string.IsNullOrWhiteSpace(issuer)) + { + issuer = context.HttpContext.GetResolvedDomain(); + } + + var authority = Utils.ToAuthority(issuer ?? $"https://{auth0Options.Domain}"); + var logoutUri = $"{authority.TrimEnd('/')}/v2/logout?client_id={auth0Options.ClientId}"; + var postLogoutUri = context.Properties.RedirectUri; var parameters = GetExtraParameters(context.Properties.Items); @@ -123,29 +172,91 @@ private static Func CreateOnTokenValidated(Auth0Web context.Fail(ex.Message); } + // When the issuer is resolved per request, validate it against the issuer stored in the protected state. + // This is important because we would have skipped issuer validation in the case of Multiple Custom Domains. + var customDomainsOptions = context.HttpContext.RequestServices + .GetService>() + ?.Get(context.Scheme.Name); + + if (customDomainsOptions is { IsMultipleCustomDomainsEnabled: true }) + { + if (context.Properties?.Items == null || + !context.Properties.Items.TryGetValue(Auth0Constants.ResolvedDomainKey, out var expectedIssuer) || + string.IsNullOrWhiteSpace(expectedIssuer)) + { + // In multi-domain mode, static issuer validation is disabled (ValidateIssuer = false). + // The domain MUST be present in state to validate the token issuer. + // If it's missing, we cannot verify the token came from the expected authority. + context.Fail( + "Token validation failed: the resolved domain was not found in the authentication state. " + + "In multi-domain mode, the domain must be stored in state during the authorization request " + + "to validate the token issuer on callback."); + } + else + { + var tokenIssuer = context.SecurityToken.Issuer; + var expectedAuthority = Utils.ToAuthority(expectedIssuer); + + var ok = tokenIssuer.Equals(expectedAuthority, StringComparison.OrdinalIgnoreCase) || + tokenIssuer.Equals(expectedAuthority + "/", StringComparison.OrdinalIgnoreCase); + + if (!ok) + { + context.Fail( + $"Token issuer '{tokenIssuer}' does not match expected issuer '{expectedAuthority}'."); + } + } + } + return Task.CompletedTask; }; } - private static Func CreateOnAuthorizationCodeReceived(Auth0WebAppOptions auth0Options) + private static Func CreateOnAuthorizationCodeReceived( + Auth0WebAppOptions auth0Options) { - return (context) => + return async (context) => { if (auth0Options.ClientAssertionSecurityKey != null) { - context.TokenEndpointRequest?.SetParameter("client_assertion", new JwtTokenFactory(auth0Options.ClientAssertionSecurityKey, auth0Options.ClientAssertionSecurityKeyAlgorithm ?? SecurityAlgorithms.RsaSha256) - .GenerateToken(auth0Options.ClientId, $"https://{auth0Options.Domain}/", auth0Options.ClientId - )); + var issuer = context.Properties?.Items != null && + context.Properties.Items.TryGetValue(Auth0Constants.ResolvedDomainKey, + out var storedIssuer) + ? storedIssuer + : null; - context.TokenEndpointRequest?.SetParameter("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); - } + if (string.IsNullOrWhiteSpace(issuer)) + { + var resolvedDomain = context.HttpContext.GetResolvedDomain(); + var customDomainsOptions = context.HttpContext.RequestServices + .GetService>() + ?.Get(context.Scheme.Name); + if (string.IsNullOrWhiteSpace(resolvedDomain) && customDomainsOptions?.DomainResolver != null) + { + resolvedDomain = await customDomainsOptions.DomainResolver(context.HttpContext) + .ConfigureAwait(false); + } - return Task.CompletedTask; + resolvedDomain ??= auth0Options.Domain; + issuer = $"https://{resolvedDomain}/"; + } + + var audience = Utils.ToAuthority(issuer) + "/"; + context.TokenEndpointRequest?.SetParameter("client_assertion", + new JwtTokenFactory(auth0Options.ClientAssertionSecurityKey, + auth0Options.ClientAssertionSecurityKeyAlgorithm ?? SecurityAlgorithms.RsaSha256) + .GenerateToken(auth0Options.ClientId, audience, auth0Options.ClientId + )); + + context.TokenEndpointRequest?.SetParameter("client_assertion_type", + "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); + } }; } - private static IDictionary GetAuthorizeParameters(Auth0WebAppOptions auth0Options, IDictionary authSessionItems) + private static IDictionary GetAuthorizeParameters(Auth0WebAppOptions auth0Options, + IDictionary authSessionItems) { var parameters = new Dictionary(); @@ -190,13 +301,13 @@ private static Func CreateOnAuthorizatio { var parameters = new Dictionary(); - foreach (var (key, value) in authSessionItems.Where(item => item.Key.StartsWith($"{Auth0AuthenticationParameters.Prefix}:"))) + foreach (var (key, value) in authSessionItems.Where(item => + item.Key.StartsWith($"{Auth0AuthenticationParameters.Prefix}:"))) { parameters[key.Replace($"{Auth0AuthenticationParameters.Prefix}:", "")] = value; } return parameters; } - } -} +} \ No newline at end of file diff --git a/src/Auth0.AspNetCore.Authentication/TokenClient.cs b/src/Auth0.AspNetCore.Authentication/TokenClient.cs index 93a2d32..a3316fb 100644 --- a/src/Auth0.AspNetCore.Authentication/TokenClient.cs +++ b/src/Auth0.AspNetCore.Authentication/TokenClient.cs @@ -22,7 +22,7 @@ public TokenClient(HttpClient httpClient) _httpClient = httpClient; } - public async Task Refresh(Auth0WebAppOptions options, string refreshToken) + public async Task Refresh(Auth0WebAppOptions options, string refreshToken, string? domain = null) { var body = new Dictionary { { "grant_type", "refresh_token" }, @@ -30,11 +30,21 @@ public TokenClient(HttpClient httpClient) { "refresh_token", refreshToken } }; - ApplyClientAuthentication(options, body); + // Use provided domain for dynamic resolution, fallback to options.Domain + var tokenEndpointDomain = domain ?? options.Domain; + + if (string.IsNullOrWhiteSpace(tokenEndpointDomain)) + { + throw new InvalidOperationException( + "Cannot determine domain for token endpoint. " + + "Ensure Domain is set or domain resolution is properly configured."); + } + + ApplyClientAuthentication(options, body, tokenEndpointDomain); var requestContent = new FormUrlEncodedContent(body.Select(p => new KeyValuePair(p.Key, p.Value ?? ""))); - using (var request = new HttpRequestMessage(HttpMethod.Post, $"https://{options.Domain}/oauth/token") { Content = requestContent }) + using (var request = new HttpRequestMessage(HttpMethod.Post, $"https://{tokenEndpointDomain}/oauth/token") { Content = requestContent }) { using (var response = await _httpClient.SendAsync(request).ConfigureAwait(false)) { @@ -50,12 +60,12 @@ public TokenClient(HttpClient httpClient) } } - private void ApplyClientAuthentication(Auth0WebAppOptions options, Dictionary body) + private void ApplyClientAuthentication(Auth0WebAppOptions options, Dictionary body, string domain) { if (options.ClientAssertionSecurityKey != null) { body.Add("client_assertion", new JwtTokenFactory(options.ClientAssertionSecurityKey, options.ClientAssertionSecurityKeyAlgorithm ?? SecurityAlgorithms.RsaSha256) - .GenerateToken(options.ClientId, $"https://{options.Domain}/", options.ClientId + .GenerateToken(options.ClientId, $"https://{domain}/", options.ClientId )); body.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); diff --git a/src/Auth0.AspNetCore.Authentication/Utils.cs b/src/Auth0.AspNetCore.Authentication/Utils.cs index 5a1b881..44383a6 100644 --- a/src/Auth0.AspNetCore.Authentication/Utils.cs +++ b/src/Auth0.AspNetCore.Authentication/Utils.cs @@ -45,5 +45,25 @@ public static Func ProxyEvent(Func newHandler, Func + /// Normalizes the given issuer or authority string to a valid HTTPS authority URL. + /// Trims whitespace and trailing slashes. If the input already starts with "http://" or "https://", + /// it is returned as-is (after trimming). Otherwise, "https://" is prepended. + /// + /// The issuer or authority string to normalize. + /// A normalized authority URL string. + internal static string ToAuthority(string issuerOrAuthority) + { + var normalized = issuerOrAuthority.Trim().TrimEnd('/'); + + if (!normalized.StartsWith("http://", StringComparison.OrdinalIgnoreCase) && + !normalized.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + { + normalized = $"https://{normalized}"; + } + + return normalized + "/"; + } } } diff --git a/tests/Auth0.AspNetCore.Authentication.IntegrationTests/Auth0CustomDomainStartupFilterTests.cs b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/Auth0CustomDomainStartupFilterTests.cs new file mode 100644 index 0000000..afd7527 --- /dev/null +++ b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/Auth0CustomDomainStartupFilterTests.cs @@ -0,0 +1,333 @@ +using System; +using System.Threading.Tasks; +using Auth0.AspNetCore.Authentication.CustomDomains; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Moq; +using Xunit; + +namespace Auth0.AspNetCore.Authentication.IntegrationTests; + +public class Auth0CustomDomainStartupFilterTests +{ + [Fact] + public Task Configure_WithValidDomainResolver_StoresResolvedIssuerInHttpContext() + { + var auth0SchemeName = "Auth0"; + var expectedIssuer = "https://custom.domain.com"; + var startupFilter = new Auth0CustomDomainStartupFilter(auth0SchemeName); + var httpContext = new DefaultHttpContext(); + var serviceCollection = new ServiceCollection(); + var customDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = _ => Task.FromResult(expectedIssuer) + }; + var optionsMonitorMock = new Mock>(); + optionsMonitorMock.Setup(m => m.Get(auth0SchemeName)).Returns(customDomainsOptions); + serviceCollection.AddSingleton(optionsMonitorMock.Object); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + var nextCalled = false; + RequestDelegate next = _ => + { + nextCalled = true; + return Task.CompletedTask; + }; + + var configureAction = startupFilter.Configure(_ => { }); + var appBuilder = new Mock(); + appBuilder.Setup(a => a.Use(It.IsAny>())) + .Callback>(middleware => middleware(next)(httpContext).Wait()); + + configureAction(appBuilder.Object); + + Assert.Equal(expectedIssuer, httpContext.Items[Auth0Constants.ResolvedDomainKey]); + Assert.True(nextCalled); + return Task.CompletedTask; + } + + [Fact] + public Task Configure_WithNullDomainResolver_DoesNotStoreIssuerInHttpContext() + { + var auth0SchemeName = "Auth0"; + var startupFilter = new Auth0CustomDomainStartupFilter(auth0SchemeName); + var httpContext = new DefaultHttpContext(); + var serviceCollection = new ServiceCollection(); + var customDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = null + }; + var optionsMonitorMock = new Mock>(); + optionsMonitorMock.Setup(m => m.Get(auth0SchemeName)).Returns(customDomainsOptions); + serviceCollection.AddSingleton(optionsMonitorMock.Object); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + var nextCalled = false; + RequestDelegate next = _ => + { + nextCalled = true; + return Task.CompletedTask; + }; + + var configureAction = startupFilter.Configure(_ => { }); + var appBuilder = new Mock(); + appBuilder.Setup(a => a.Use(It.IsAny>())) + .Callback>(middleware => middleware(next)(httpContext).Wait()); + + configureAction(appBuilder.Object); + + Assert.False(httpContext.Items.ContainsKey(Auth0Constants.ResolvedDomainKey)); + Assert.True(nextCalled); + return Task.CompletedTask; + } + + [Fact] + public async Task Configure_WithDomainResolverReturningNull_ThrowsInvalidOperationException() + { + var auth0SchemeName = "Auth0"; + var startupFilter = new Auth0CustomDomainStartupFilter(auth0SchemeName); + var httpContext = new DefaultHttpContext(); + var serviceCollection = new ServiceCollection(); + var customDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = _ => Task.FromResult(null) + }; + var optionsMonitorMock = new Mock>(); + optionsMonitorMock.Setup(m => m.Get(auth0SchemeName)).Returns(customDomainsOptions); + serviceCollection.AddSingleton(optionsMonitorMock.Object); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + RequestDelegate next = _ => Task.CompletedTask; + + var configureAction = startupFilter.Configure(_ => { }); + var appBuilder = new Mock(); + Func capturedMiddleware = null; + appBuilder.Setup(a => a.Use(It.IsAny>())) + .Callback>(middleware => capturedMiddleware = middleware); + + configureAction(appBuilder.Object); + + await Assert.ThrowsAsync( + async () => await capturedMiddleware(next)(httpContext)); + } + + [Fact] + public async Task Configure_WithDomainResolverReturningEmptyString_ThrowsInvalidOperationException() + { + var auth0SchemeName = "Auth0"; + var startupFilter = new Auth0CustomDomainStartupFilter(auth0SchemeName); + var httpContext = new DefaultHttpContext(); + var serviceCollection = new ServiceCollection(); + var customDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = _ => Task.FromResult(string.Empty) + }; + var optionsMonitorMock = new Mock>(); + optionsMonitorMock.Setup(m => m.Get(auth0SchemeName)).Returns(customDomainsOptions); + serviceCollection.AddSingleton(optionsMonitorMock.Object); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + RequestDelegate next = _ => Task.CompletedTask; + + var configureAction = startupFilter.Configure(_ => { }); + var appBuilder = new Mock(); + Func capturedMiddleware = null; + appBuilder.Setup(a => a.Use(It.IsAny>())) + .Callback>(middleware => capturedMiddleware = middleware); + + configureAction(appBuilder.Object); + + await Assert.ThrowsAsync( + async () => await capturedMiddleware(next)(httpContext)); + } + + [Fact] + public async Task Configure_WithDomainResolverReturningWhitespace_ThrowsInvalidOperationException() + { + var auth0SchemeName = "Auth0"; + var startupFilter = new Auth0CustomDomainStartupFilter(auth0SchemeName); + var httpContext = new DefaultHttpContext(); + var serviceCollection = new ServiceCollection(); + var customDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = _ => Task.FromResult(" ") + }; + var optionsMonitorMock = new Mock>(); + optionsMonitorMock.Setup(m => m.Get(auth0SchemeName)).Returns(customDomainsOptions); + serviceCollection.AddSingleton(optionsMonitorMock.Object); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + RequestDelegate next = _ => Task.CompletedTask; + + var configureAction = startupFilter.Configure(_ => { }); + var appBuilder = new Mock(); + Func capturedMiddleware = null; + appBuilder.Setup(a => a.Use(It.IsAny>())) + .Callback>(middleware => capturedMiddleware = middleware); + + configureAction(appBuilder.Object); + + await Assert.ThrowsAsync( + async () => await capturedMiddleware(next)(httpContext)); + } + + [Fact] + public Task Configure_InvokesNextMiddlewareConfiguration() + { + var auth0SchemeName = "Auth0"; + var startupFilter = new Auth0CustomDomainStartupFilter(auth0SchemeName); + var nextActionCalled = false; + Action nextAction = _ => { nextActionCalled = true; }; + + var configureAction = startupFilter.Configure(nextAction); + var appBuilder = new Mock(); + appBuilder.Setup(a => a.Use(It.IsAny>())); + + configureAction(appBuilder.Object); + + Assert.True(nextActionCalled); + return Task.CompletedTask; + } + + [Fact] + public Task Configure_WithDifferentSchemeName_UsesCorrectScheme() + { + var auth0SchemeName = "CustomAuth0Scheme"; + var expectedIssuer = "https://custom.auth0.com"; + var startupFilter = new Auth0CustomDomainStartupFilter(auth0SchemeName); + var httpContext = new DefaultHttpContext(); + var serviceCollection = new ServiceCollection(); + var customDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = _ => Task.FromResult(expectedIssuer) + }; + var optionsMonitorMock = new Mock>(); + optionsMonitorMock.Setup(m => m.Get(auth0SchemeName)).Returns(customDomainsOptions); + serviceCollection.AddSingleton(optionsMonitorMock.Object); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + RequestDelegate next = _ => Task.CompletedTask; + + var configureAction = startupFilter.Configure(_ => { }); + var appBuilder = new Mock(); + appBuilder.Setup(a => a.Use(It.IsAny>())) + .Callback>(middleware => middleware(next)(httpContext).Wait()); + + configureAction(appBuilder.Object); + + optionsMonitorMock.Verify(m => m.Get(auth0SchemeName), Times.Once); + return Task.CompletedTask; + } + + [Fact] + public Task Configure_WithDomainResolverAccessingHttpContext_PassesCorrectContext() + { + var auth0SchemeName = "Auth0"; + var expectedPath = "/test-path"; + var capturedPath = string.Empty; + var startupFilter = new Auth0CustomDomainStartupFilter(auth0SchemeName); + var httpContext = new DefaultHttpContext(); + httpContext.Request.Path = expectedPath; + var serviceCollection = new ServiceCollection(); + var customDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = ctx => + { + capturedPath = ctx.Request.Path; + return Task.FromResult("https://domain.com"); + } + }; + var optionsMonitorMock = new Mock>(); + optionsMonitorMock.Setup(m => m.Get(auth0SchemeName)).Returns(customDomainsOptions); + serviceCollection.AddSingleton(optionsMonitorMock.Object); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + RequestDelegate next = _ => Task.CompletedTask; + + var configureAction = startupFilter.Configure(_ => { }); + var appBuilder = new Mock(); + appBuilder.Setup(a => a.Use(It.IsAny>())) + .Callback>(middleware => middleware(next)(httpContext).Wait()); + + configureAction(appBuilder.Object); + + Assert.Equal(expectedPath, capturedPath); + return Task.CompletedTask; + } + + [Fact] + public async Task Configure_WithDomainResolverThrowing_PropagatesException() + { + var auth0SchemeName = "Auth0"; + var expectedException = new InvalidOperationException("Domain resolution failed"); + var startupFilter = new Auth0CustomDomainStartupFilter(auth0SchemeName); + var httpContext = new DefaultHttpContext(); + var serviceCollection = new ServiceCollection(); + var customDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = _ => throw expectedException + }; + var optionsMonitorMock = new Mock>(); + optionsMonitorMock.Setup(m => m.Get(auth0SchemeName)).Returns(customDomainsOptions); + serviceCollection.AddSingleton(optionsMonitorMock.Object); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + RequestDelegate next = _ => Task.CompletedTask; + + var configureAction = startupFilter.Configure(_ => { }); + var appBuilder = new Mock(); + Func capturedMiddleware = null; + appBuilder.Setup(a => a.Use(It.IsAny>())) + .Callback>(middleware => capturedMiddleware = middleware); + + configureAction(appBuilder.Object); + + var exception = await Assert.ThrowsAsync( + async () => await capturedMiddleware(next)(httpContext)); + Assert.Equal("Domain resolution failed", exception.Message); + } + + [Fact] + public async Task Configure_WithConcurrentRequestsWithDifferentDomains_IsolatesResolvedDomains() + { + var auth0SchemeName = "Auth0"; + var startupFilter = new Auth0CustomDomainStartupFilter(auth0SchemeName); + + // Create two HTTP contexts with different host headers + var httpContext1 = new DefaultHttpContext(); + httpContext1.Request.Host = new HostString("tenant1.example.com"); + var serviceCollection1 = new ServiceCollection(); + + var httpContext2 = new DefaultHttpContext(); + httpContext2.Request.Host = new HostString("tenant2.example.com"); + var serviceCollection2 = new ServiceCollection(); + + // Domain resolver that uses the host to determine the domain + var customDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = ctx => Task.FromResult($"https://{ctx.Request.Host.Host}.auth0.com") + }; + + var optionsMonitorMock = new Mock>(); + optionsMonitorMock.Setup(m => m.Get(auth0SchemeName)).Returns(customDomainsOptions); + + serviceCollection1.AddSingleton(optionsMonitorMock.Object); + serviceCollection2.AddSingleton(optionsMonitorMock.Object); + + httpContext1.RequestServices = serviceCollection1.BuildServiceProvider(); + httpContext2.RequestServices = serviceCollection2.BuildServiceProvider(); + + RequestDelegate next = _ => Task.CompletedTask; + + var configureAction = startupFilter.Configure(_ => { }); + var appBuilder = new Mock(); + Func capturedMiddleware = null; + appBuilder.Setup(a => a.Use(It.IsAny>())) + .Callback>(middleware => capturedMiddleware = middleware); + + configureAction(appBuilder.Object); + + // Execute both requests concurrently + var task1 = capturedMiddleware(next)(httpContext1); + var task2 = capturedMiddleware(next)(httpContext2); + await Task.WhenAll(task1, task2); + + // Verify each context has its own resolved domain + Assert.Equal("https://tenant1.example.com.auth0.com", httpContext1.Items[Auth0Constants.ResolvedDomainKey]); + Assert.Equal("https://tenant2.example.com.auth0.com", httpContext2.Items[Auth0Constants.ResolvedDomainKey]); + } +} \ No newline at end of file diff --git a/tests/Auth0.AspNetCore.Authentication.IntegrationTests/Auth0CustomDomainsOpenIdConnectConfigurationManagerTests.cs b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/Auth0CustomDomainsOpenIdConnectConfigurationManagerTests.cs new file mode 100644 index 0000000..62da1c6 --- /dev/null +++ b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/Auth0CustomDomainsOpenIdConnectConfigurationManagerTests.cs @@ -0,0 +1,856 @@ +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Auth0.AspNetCore.Authentication.CustomDomains; +using FluentAssertions; +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Http; +using Microsoft.IdentityModel.Protocols; +using Microsoft.IdentityModel.Protocols.OpenIdConnect; +using Moq; +using Xunit; + +namespace Auth0.AspNetCore.Authentication.IntegrationTests; + +public class Auth0CustomDomainsOpenIdConnectConfigurationManagerTests +{ + private readonly Mock _httpContextAccessorMock; + private readonly Mock>> _domainResolverMock; + private readonly Mock> _stateDataFormatMock; + private readonly HttpClient _httpClient; + private readonly DefaultHttpContext _httpContext; + + public Auth0CustomDomainsOpenIdConnectConfigurationManagerTests() + { + _httpContextAccessorMock = new Mock(); + _domainResolverMock = new Mock>>(); + _stateDataFormatMock = new Mock>(); + _httpClient = new HttpClient(); + _httpContext = new DefaultHttpContext(); + } + + [Fact] + public void Constructor_WithNullHttpContextAccessor_ThrowsArgumentNullException() + { + Assert.Throws(() => new Auth0CustomDomainsOpenIdConnectConfigurationManager( + null!, + _domainResolverMock.Object, + _stateDataFormatMock.Object, + _httpClient)); + } + + [Fact] + public void Constructor_WithNullDomainResolver_ThrowsArgumentNullException() + { + Assert.Throws(() => new Auth0CustomDomainsOpenIdConnectConfigurationManager( + _httpContextAccessorMock.Object, + null!, + _stateDataFormatMock.Object, + _httpClient)); + } + + [Fact] + public void Constructor_WithNullStateDataFormat_ThrowsArgumentNullException() + { + Assert.Throws(() => new Auth0CustomDomainsOpenIdConnectConfigurationManager( + _httpContextAccessorMock.Object, + _domainResolverMock.Object, + null!, + _httpClient)); + } + + [Fact] + public void Constructor_WithNullHttpClient_ThrowsArgumentNullException() + { + Assert.Throws(() => new Auth0CustomDomainsOpenIdConnectConfigurationManager( + _httpContextAccessorMock.Object, + _domainResolverMock.Object, + _stateDataFormatMock.Object, + null!)); + } + + [Fact] + public async Task GetConfigurationAsync_WithNullHttpContext_ThrowsInvalidOperationException() + { + _httpContextAccessorMock.Setup(x => x.HttpContext).Returns((HttpContext)null!); + var manager = CreateManager(); + + var exception = + await Assert.ThrowsAsync(() => + manager.GetConfigurationAsync(CancellationToken.None)); + + Assert.Contains("HttpContext is not available", exception.Message); + } + + [Fact] + public async Task GetConfigurationAsync_WithValidHttpContext_ReturnsConfiguration() + { + var httpContext = new DefaultHttpContext(); + var httpContextAccessor = new Mock(); + httpContextAccessor.Setup(x => x.HttpContext).Returns(httpContext); + + var domainResolver = new Mock>>(); + domainResolver.Setup(x => x(It.IsAny())).ReturnsAsync("example.auth0.com"); + + var stateDataFormat = new Mock>(); + var httpClient = new HttpClient(); + + var manager = new Auth0CustomDomainsOpenIdConnectConfigurationManager( + httpContextAccessor.Object, + domainResolver.Object, + stateDataFormat.Object, + httpClient); + + var configuration = await manager.GetConfigurationAsync(CancellationToken.None); + + Assert.NotNull(configuration); + } + + [Fact] + public async Task GetConfigurationAsync_WithSameDomainMultipleTimes_ReusesCachedConfigurationManager() + { + var httpContext = new DefaultHttpContext(); + var httpContextAccessor = new Mock(); + httpContextAccessor.Setup(x => x.HttpContext).Returns(httpContext); + + var domainResolver = new Mock>>(); + domainResolver.Setup(x => x(It.IsAny())).ReturnsAsync("example.auth0.com"); + + var stateDataFormat = new Mock>(); + var httpClient = new HttpClient(); + + var manager = new Auth0CustomDomainsOpenIdConnectConfigurationManager( + httpContextAccessor.Object, + domainResolver.Object, + stateDataFormat.Object, + httpClient); + + var config1= await manager.GetConfigurationAsync(CancellationToken.None); + var config2 = await manager.GetConfigurationAsync(CancellationToken.None); + + domainResolver.Verify(x => x(It.IsAny()), Times.Once); + config1.Issuer.Should().NotBeNullOrWhiteSpace(); + config2.Issuer.Should().NotBeNullOrWhiteSpace(); + + config1.Issuer?.Should().Be(config2.Issuer); + } + + [Fact] + public async Task GetConfigurationAsync_WithDifferentDomains_CreatesSeparateConfigurationManagers() + { + var httpContext = new DefaultHttpContext(); + var httpContextAccessor = new Mock(); + httpContextAccessor.Setup(x => x.HttpContext).Returns(httpContext); + + var callCount = 0; + var domainResolver = new Mock>>(); + domainResolver.Setup(x => x(It.IsAny())) + .ReturnsAsync(() => callCount++ == 0 ? "domain1.auth0.com" : "domain2.auth0.com"); + + var stateDataFormat = new Mock>(); + var httpClient = new HttpClient(); + + var manager = new Auth0CustomDomainsOpenIdConnectConfigurationManager( + httpContextAccessor.Object, + domainResolver.Object, + stateDataFormat.Object, + httpClient); + + httpContext.Items.Clear(); + var config1 = await manager.GetConfigurationAsync(CancellationToken.None); + + httpContext.Items.Clear(); + var config2 = await manager.GetConfigurationAsync(CancellationToken.None); + + Assert.NotNull(config1); + Assert.True(config1.Issuer?.Contains("domain1.auth0.com")); + Assert.NotNull(config2); + Assert.True(config2.Issuer?.Contains("domain2.auth0.com")); + } + + [Fact] + public async Task ResolveAuthorityAsync_WithValidDomainResolver_ReturnsDomain() + { + var expectedDomain = "tenant.auth0.com"; + _domainResolverMock.Setup(x => x(_httpContext)).ReturnsAsync(expectedDomain); + var manager = CreateManager(); + + var authority = await manager.ResolveAuthorityAsync(_httpContext); + + Assert.Equal($"https://{expectedDomain}/", authority); + } + + [Fact] + public async Task ResolveAuthorityAsync_WithNullDomain_ThrowsInvalidOperationException() + { + _domainResolverMock.Setup(x => x(_httpContext)).ReturnsAsync((string)null!); + var manager = CreateManager(); + + var exception = + await Assert.ThrowsAsync(() => manager.ResolveAuthorityAsync(_httpContext)); + + Assert.Contains("DomainResolver returned a null or empty value", exception.Message); + } + + [Fact] + public async Task ResolveAuthorityAsync_WithEmptyDomain_ThrowsInvalidOperationException() + { + _domainResolverMock.Setup(x => x(_httpContext)).ReturnsAsync(string.Empty); + var manager = CreateManager(); + + var exception = + await Assert.ThrowsAsync(() => manager.ResolveAuthorityAsync(_httpContext)); + + Assert.Contains("DomainResolver returned a null or empty value", exception.Message); + } + + [Fact] + public async Task ResolveAuthorityAsync_WithWhitespaceDomain_ThrowsInvalidOperationException() + { + _domainResolverMock.Setup(x => x(_httpContext)).ReturnsAsync(" "); + var manager = CreateManager(); + + var exception = + await Assert.ThrowsAsync(() => manager.ResolveAuthorityAsync(_httpContext)); + + Assert.Contains("DomainResolver returned a null or empty value", exception.Message); + } + + [Fact] + public async Task ResolveAuthorityAsync_CachesDomainInHttpContextItems() + { + var expectedDomain = "tenant.auth0.com"; + _domainResolverMock.Setup(x => x(_httpContext)).ReturnsAsync(expectedDomain); + var manager = CreateManager(); + + await manager.ResolveAuthorityAsync(_httpContext); + + Assert.Equal(expectedDomain, _httpContext.Items[Auth0Constants.ResolvedDomainKey]); + } + + [Fact] + public async Task ResolveAuthorityAsync_UsesCachedDomainFromHttpContextItems() + { + var cachedDomain = "cached.auth0.com"; + _httpContext.Items[Auth0Constants.ResolvedDomainKey] = cachedDomain; + var manager = CreateManager(); + + var authority = await manager.ResolveAuthorityAsync(_httpContext); + + Assert.Equal($"https://{cachedDomain}/", authority); + _domainResolverMock.Verify(x => x(It.IsAny()), Times.Never); + } + + [Fact] + public async Task ResolveAuthorityAsync_WithStateParameter_ExtractsIssuerFromState() + { + var issuer = "state-domain.auth0.com"; + var state = "protected-state"; + _httpContext.Request.QueryString = new QueryString($"?state={state}"); + + var props = new AuthenticationProperties(); + props.Items[Auth0Constants.ResolvedDomainKey] = issuer; + _stateDataFormatMock.Setup(x => x.Unprotect(state)).Returns(props); + + var manager = CreateManager(); + + var authority = await manager.ResolveAuthorityAsync(_httpContext); + + Assert.Equal($"https://{issuer}/", authority); + _domainResolverMock.Verify(x => x(It.IsAny()), Times.Never); + } + + [Fact] + public void TryGetState_WithQueryStringState_ReturnsTrue() + { + var expectedState = "test-state"; + _httpContext.Request.QueryString = new QueryString($"?state={expectedState}"); + + var result = Auth0CustomDomainsOpenIdConnectConfigurationManager.TryGetState(_httpContext, out var state); + + Assert.True(result); + Assert.Equal(expectedState, state); + } + + [Fact] + public void TryGetState_WithFormState_ReturnsTrue() + { + var expectedState = "form-state"; + _httpContext.Request.ContentType = "application/x-www-form-urlencoded"; + var formCollection = new FormCollection(new Dictionary + { + { "state", expectedState } + }); + _httpContext.Request.Form = formCollection; + + var result = Auth0CustomDomainsOpenIdConnectConfigurationManager.TryGetState(_httpContext, out var state); + + Assert.True(result); + Assert.Equal(expectedState, state); + } + + [Fact] + public void TryGetState_WithNoState_ReturnsFalse() + { + var result = Auth0CustomDomainsOpenIdConnectConfigurationManager.TryGetState(_httpContext, out var state); + + Assert.False(result); + Assert.Null(state); + } + + [Fact] + public void TryGetState_WithEmptyQueryState_ReturnsFalse() + { + _httpContext.Request.QueryString = new QueryString("?state="); + + var result = Auth0CustomDomainsOpenIdConnectConfigurationManager.TryGetState(_httpContext, out var state); + + Assert.False(result); + Assert.Null(state); + } + + [Fact] + public void TryGetState_PrefersQueryStringOverForm() + { + var queryState = "query-state"; + var formState = "form-state"; + _httpContext.Request.QueryString = new QueryString($"?state={queryState}"); + _httpContext.Request.ContentType = "application/x-www-form-urlencoded"; + var formCollection = new FormCollection(new Dictionary + { + { "state", formState } + }); + _httpContext.Request.Form = formCollection; + + var result = Auth0CustomDomainsOpenIdConnectConfigurationManager.TryGetState(_httpContext, out var state); + + Assert.True(result); + Assert.Equal(queryState, state); + } + + [Fact] + public void TryGetIssuerFromState_WithValidState_ReturnsTrue() + { + var expectedIssuer = "tenant.auth0.com"; + var state = "protected-state"; + var props = new AuthenticationProperties(); + props.Items[Auth0Constants.ResolvedDomainKey] = expectedIssuer; + _stateDataFormatMock.Setup(x => x.Unprotect(state)).Returns(props); + var manager = CreateManager(); + + var result = manager.TryGetIssuerFromState(state, out var issuer); + + Assert.True(result); + Assert.Equal(expectedIssuer, issuer); + } + + [Fact] + public void TryGetIssuerFromState_WithNullState_ReturnsFalse() + { + var manager = CreateManager(); + + var result = manager.TryGetIssuerFromState(null, out var issuer); + + Assert.False(result); + Assert.Empty(issuer); + } + + [Fact] + public void TryGetIssuerFromState_WithEmptyState_ReturnsFalse() + { + var manager = CreateManager(); + + var result = manager.TryGetIssuerFromState(string.Empty, out var issuer); + + Assert.False(result); + Assert.Empty(issuer); + } + + [Fact] + public void TryGetIssuerFromState_WithWhitespaceState_ReturnsFalse() + { + var manager = CreateManager(); + + var result = manager.TryGetIssuerFromState(" ", out var issuer); + + Assert.False(result); + Assert.Empty(issuer); + } + + [Fact] + public void TryGetIssuerFromState_WithCryptographicException_ReturnsFalse() + { + var state = "invalid-state"; + _stateDataFormatMock.Setup(x => x.Unprotect(state)) + .Throws(); + var manager = CreateManager(); + + var result = manager.TryGetIssuerFromState(state, out var issuer); + + Assert.False(result); + Assert.Empty(issuer); + } + + [Fact] + public void TryGetIssuerFromState_WithFormatException_ReturnsFalse() + { + var state = "invalid-state"; + _stateDataFormatMock.Setup(x => x.Unprotect(state)).Throws(); + var manager = CreateManager(); + + var result = manager.TryGetIssuerFromState(state, out var issuer); + + Assert.False(result); + Assert.Empty(issuer); + } + + [Fact] + public void TryGetIssuerFromState_WithArgumentException_ReturnsFalse() + { + var state = "invalid-state"; + _stateDataFormatMock.Setup(x => x.Unprotect(state)).Throws(); + var manager = CreateManager(); + + var result = manager.TryGetIssuerFromState(state, out var issuer); + + Assert.False(result); + Assert.Empty(issuer); + } + + [Fact] + public void TryGetIssuerFromState_WithNullProperties_ReturnsFalse() + { + var state = "protected-state"; + _stateDataFormatMock.Setup(x => x.Unprotect(state)).Returns((AuthenticationProperties)null!); + var manager = CreateManager(); + + var result = manager.TryGetIssuerFromState(state, out var issuer); + + Assert.False(result); + Assert.Empty(issuer); + } + + [Fact] + public void TryGetIssuerFromState_WithMissingDomainKey_ReturnsFalse() + { + var state = "protected-state"; + var props = new AuthenticationProperties(); + _stateDataFormatMock.Setup(x => x.Unprotect(state)).Returns(props); + var manager = CreateManager(); + + var result = manager.TryGetIssuerFromState(state, out var issuer); + + Assert.False(result); + Assert.Empty(issuer); + } + + [Fact] + public void TryGetIssuerFromState_WithEmptyDomainValue_ReturnsFalse() + { + var state = "protected-state"; + var props = new AuthenticationProperties(); + props.Items[Auth0Constants.ResolvedDomainKey] = string.Empty; + _stateDataFormatMock.Setup(x => x.Unprotect(state)).Returns(props); + var manager = CreateManager(); + + var result = manager.TryGetIssuerFromState(state, out var issuer); + + Assert.False(result); + Assert.Empty(issuer); + } + + [Fact] + public void CreateConfigurationManager_WithHttpsAddress_RequiresHttps() + { + var address = "https://tenant.auth0.com/.well-known/openid-configuration"; + var manager = CreateManager(); + + var configManager = manager.CreateConfigurationManager(address); + + Assert.NotNull(configManager); + } + + [Fact] + public void CreateConfigurationManager_WithHttpAddress_DoesNotRequireHttps() + { + var address = "http://localhost/.well-known/openid-configuration"; + var manager = CreateManager(); + + var configManager = manager.CreateConfigurationManager(address); + + Assert.NotNull(configManager); + } + + [Fact] + public void RequestRefresh_CallsRequestRefreshOnAllCachedManagers() + { + _httpContextAccessorMock.Setup(x => x.HttpContext).Returns(_httpContext); + _domainResolverMock.Setup(x => x(_httpContext)).ReturnsAsync("tenant1.auth0.com"); + var manager = CreateManager(); + + manager.RequestRefresh(); + } + + [Fact] + public async Task ResolveAuthorityAsync_WithTrailingSlashInDomain_TrimsSlash() + { + var domainWithSlash = "tenant.auth0.com/"; + _domainResolverMock.Setup(x => x(_httpContext)).ReturnsAsync(domainWithSlash); + var manager = CreateManager(); + + var authority = await manager.ResolveAuthorityAsync(_httpContext); + + Assert.Equal("https://tenant.auth0.com/", authority); + } + + private Auth0CustomDomainsOpenIdConnectConfigurationManager CreateManager(IConfigurationManagerCache? cache = null) + { + return new Auth0CustomDomainsOpenIdConnectConfigurationManager( + _httpContextAccessorMock.Object, + _domainResolverMock.Object, + _stateDataFormatMock.Object, + _httpClient, + cache); + } + + [Fact] + public void Constructor_WithNullCache_CreatesDefaultCache() + { + var manager = CreateManager(cache: null); + + Assert.NotNull(manager); + } + + [Fact] + public void Constructor_WithCustomCache_UsesProvidedCache() + { + var customCache = new MemoryConfigurationManagerCache(maxSize: 50); + var manager = CreateManager(cache: customCache); + + Assert.NotNull(manager); + } + + [Fact] + public void Constructor_WithNullConfigurationManagerCache_UsesProvidedCache() + { + var nullCache = new NullConfigurationManagerCache(); + var manager = CreateManager(cache: nullCache); + + Assert.NotNull(manager); + } + + [Fact] + public async Task GetConfigurationAsync_WithNullCache_AlwaysCreatesNewManager() + { + var httpContext = new DefaultHttpContext(); + var httpContextAccessor = new Mock(); + httpContextAccessor.Setup(x => x.HttpContext).Returns(httpContext); + + var domainResolver = new Mock>>(); + domainResolver.Setup(x => x(It.IsAny())).ReturnsAsync("example.auth0.com"); + + var stateDataFormat = new Mock>(); + var httpClient = new HttpClient(); + var nullCache = new NullConfigurationManagerCache(); + + var manager = new Auth0CustomDomainsOpenIdConnectConfigurationManager( + httpContextAccessor.Object, + domainResolver.Object, + stateDataFormat.Object, + httpClient, + nullCache); + + httpContext.Items.Clear(); + var config1 = await manager.GetConfigurationAsync(CancellationToken.None); + + httpContext.Items.Clear(); + var config2 = await manager.GetConfigurationAsync(CancellationToken.None); + + // With NullConfigurationManagerCache, the domain resolver should be called each time + // since no caching is performed + domainResolver.Verify(x => x(It.IsAny()), Times.Exactly(2)); + Assert.NotNull(config1); + Assert.NotNull(config2); + } + + [Fact] + public async Task GetConfigurationAsync_WithCustomMemoryCache_ReusesCachedManager() + { + var httpContext = new DefaultHttpContext(); + var httpContextAccessor = new Mock(); + httpContextAccessor.Setup(x => x.HttpContext).Returns(httpContext); + + var domainResolver = new Mock>>(); + domainResolver.Setup(x => x(It.IsAny())).ReturnsAsync("example.auth0.com"); + + var stateDataFormat = new Mock>(); + var httpClient = new HttpClient(); + var memoryCache = new MemoryConfigurationManagerCache(maxSize: 10); + + var manager = new Auth0CustomDomainsOpenIdConnectConfigurationManager( + httpContextAccessor.Object, + domainResolver.Object, + stateDataFormat.Object, + httpClient, + memoryCache); + + var config1 = await manager.GetConfigurationAsync(CancellationToken.None); + var config2 = await manager.GetConfigurationAsync(CancellationToken.None); + + // With MemoryConfigurationManagerCache, the domain resolver should only be called once + // due to caching + domainResolver.Verify(x => x(It.IsAny()), Times.Once); + Assert.NotNull(config1); + Assert.NotNull(config2); + } + + [Fact] + public void Dispose_WithOwnedCache_DisposesCache() + { + // When no cache is provided, the manager creates and owns its own cache + var manager = CreateManager(cache: null); + + var exception = Record.Exception(() => manager.Dispose()); + + Assert.Null(exception); + } + + [Fact] + public void Dispose_WithProvidedCache_DoesNotDisposeCache() + { + // When a cache is provided externally, the manager should not dispose it + var customCache = new MemoryConfigurationManagerCache(maxSize: 50); + var manager = CreateManager(cache: customCache); + + manager.Dispose(); + + // The cache should still be usable after manager disposal + var mockConfigManager = new Mock>(); + var exception = Record.Exception(() => customCache.GetOrCreate("test", _ => mockConfigManager.Object)); + + Assert.Null(exception); + } + + [Fact] + public void Dispose_CanBeCalledMultipleTimes() + { + var manager = CreateManager(); + + var exception = Record.Exception(() => + { + manager.Dispose(); + manager.Dispose(); + manager.Dispose(); + }); + + Assert.Null(exception); + } + + [Fact] + public async Task GetConfigurationAsync_AfterDispose_ThrowsObjectDisposedException() + { + _httpContextAccessorMock.Setup(x => x.HttpContext).Returns(_httpContext); + _domainResolverMock.Setup(x => x(_httpContext)).ReturnsAsync("example.auth0.com"); + var manager = CreateManager(); + + manager.Dispose(); + + await Assert.ThrowsAsync(() => + manager.GetConfigurationAsync(CancellationToken.None)); + } + + [Fact] + public void RequestRefresh_AfterDispose_DoesNotThrow() + { + var manager = CreateManager(); + + manager.Dispose(); + + // RequestRefresh gracefully handles disposal by returning early + var exception = Record.Exception(() => manager.RequestRefresh()); + + Assert.Null(exception); + } + + [Fact] + public async Task GetConfigurationAsync_WithMemoryCacheAndSlidingExpiration_UsesCache() + { + var httpContext = new DefaultHttpContext(); + var httpContextAccessor = new Mock(); + httpContextAccessor.Setup(x => x.HttpContext).Returns(httpContext); + + var domainResolver = new Mock>>(); + domainResolver.Setup(x => x(It.IsAny())).ReturnsAsync("example.auth0.com"); + + var stateDataFormat = new Mock>(); + var httpClient = new HttpClient(); + var memoryCache = new MemoryConfigurationManagerCache(maxSize: 10, slidingExpiration: TimeSpan.FromHours(1)); + + var manager = new Auth0CustomDomainsOpenIdConnectConfigurationManager( + httpContextAccessor.Object, + domainResolver.Object, + stateDataFormat.Object, + httpClient, + memoryCache); + + var config = await manager.GetConfigurationAsync(CancellationToken.None); + + Assert.NotNull(config); + } + + [Fact] + public async Task ResolveAuthorityAsync_WithStateButNoIssuerInState_ThrowsInvalidOperationException() + { + // State parameter is present but doesn't contain the resolved domain key + var state = "protected-state"; + _httpContext.Request.QueryString = new QueryString($"?state={state}"); + + var props = new AuthenticationProperties(); // No ResolvedDomainKey + _stateDataFormatMock.Setup(x => x.Unprotect(state)).Returns(props); + + var manager = CreateManager(); + + var exception = await Assert.ThrowsAsync( + () => manager.ResolveAuthorityAsync(_httpContext)); + + Assert.Contains("state", exception.Message); + Assert.Contains("resolved domain could not be extracted", exception.Message); + } + + [Fact] + public async Task ResolveAuthorityAsync_WithStateButCorruptedState_ThrowsInvalidOperationException() + { + // State parameter is present but decryption throws (tampered state) + var state = "corrupted-state"; + _httpContext.Request.QueryString = new QueryString($"?state={state}"); + + _stateDataFormatMock.Setup(x => x.Unprotect(state)) + .Throws(); + + var manager = CreateManager(); + + var exception = await Assert.ThrowsAsync( + () => manager.ResolveAuthorityAsync(_httpContext)); + + Assert.Contains("state", exception.Message); + Assert.Contains("resolved domain could not be extracted", exception.Message); + } + + [Fact] + public async Task ResolveAuthorityAsync_WithStateButNullProperties_ThrowsInvalidOperationException() + { + // State parameter is present but Unprotect returns null + var state = "protected-state"; + _httpContext.Request.QueryString = new QueryString($"?state={state}"); + + _stateDataFormatMock.Setup(x => x.Unprotect(state)).Returns((AuthenticationProperties)null!); + + var manager = CreateManager(); + + var exception = await Assert.ThrowsAsync( + () => manager.ResolveAuthorityAsync(_httpContext)); + + Assert.Contains("state", exception.Message); + } + + [Fact] + public async Task ResolveAuthorityAsync_DoesNotFallBackToDomainResolver_WhenStatePresent() + { + // Verify that when state is present but issuer extraction fails, + // the domain resolver is NOT called (preventing domain swap attacks) + var state = "protected-state"; + _httpContext.Request.QueryString = new QueryString($"?state={state}"); + + _stateDataFormatMock.Setup(x => x.Unprotect(state)).Returns((AuthenticationProperties)null!); + _domainResolverMock.Setup(x => x(_httpContext)).ReturnsAsync("attacker-domain.auth0.com"); + + var manager = CreateManager(); + + await Assert.ThrowsAsync( + () => manager.ResolveAuthorityAsync(_httpContext)); + + _domainResolverMock.Verify(x => x(It.IsAny()), Times.Never); + } + + [Fact] + public async Task ResolveAuthorityAsync_CrossValidation_MatchingDomains_Succeeds() + { + var issuer = "tenant.auth0.com"; + var state = "protected-state"; + _httpContext.Request.QueryString = new QueryString($"?state={state}"); + + // StartupFilter already resolved the same domain + _httpContext.Items[Auth0Constants.ResolvedDomainKey] = issuer; + + var props = new AuthenticationProperties(); + props.Items[Auth0Constants.ResolvedDomainKey] = issuer; + _stateDataFormatMock.Setup(x => x.Unprotect(state)).Returns(props); + + var manager = CreateManager(); + + var authority = await manager.ResolveAuthorityAsync(_httpContext); + + Assert.Equal($"https://{issuer}/", authority); + } + + [Fact] + public async Task ResolveAuthorityAsync_CrossValidation_MismatchedDomains_ThrowsInvalidOperationException() + { + var stateIssuer = "tenant-a.auth0.com"; + var middlewareIssuer = "tenant-b.auth0.com"; + var state = "protected-state"; + _httpContext.Request.QueryString = new QueryString($"?state={state}"); + + // StartupFilter resolved a different domain than what's in state + _httpContext.Items[Auth0Constants.ResolvedDomainKey] = middlewareIssuer; + + var props = new AuthenticationProperties(); + props.Items[Auth0Constants.ResolvedDomainKey] = stateIssuer; + _stateDataFormatMock.Setup(x => x.Unprotect(state)).Returns(props); + + var manager = CreateManager(); + + var exception = await Assert.ThrowsAsync( + () => manager.ResolveAuthorityAsync(_httpContext)); + + Assert.Contains("Domain mismatch", exception.Message); + Assert.Contains("tenant-a.auth0.com", exception.Message); + Assert.Contains("tenant-b.auth0.com", exception.Message); + } + + [Fact] + public async Task ResolveAuthorityAsync_CrossValidation_NoMiddlewareDomain_StillSucceeds() + { + // When the StartupFilter hasn't cached a domain yet (e.g., no middleware ran), + // the cross-validation is skipped and the state domain is trusted. + var issuer = "tenant.auth0.com"; + var state = "protected-state"; + _httpContext.Request.QueryString = new QueryString($"?state={state}"); + + var props = new AuthenticationProperties(); + props.Items[Auth0Constants.ResolvedDomainKey] = issuer; + _stateDataFormatMock.Setup(x => x.Unprotect(state)).Returns(props); + + var manager = CreateManager(); + + var authority = await manager.ResolveAuthorityAsync(_httpContext); + + Assert.Equal($"https://{issuer}/", authority); + } + + [Fact] + public async Task ResolveAuthorityAsync_WithoutState_StillUsesResolver() + { + // Non-callback requests (no state parameter) should still work via DomainResolver + var expectedDomain = "tenant.auth0.com"; + _domainResolverMock.Setup(x => x(_httpContext)).ReturnsAsync(expectedDomain); + var manager = CreateManager(); + + var authority = await manager.ResolveAuthorityAsync(_httpContext); + + Assert.Equal($"https://{expectedDomain}/", authority); + _domainResolverMock.Verify(x => x(_httpContext), Times.Once); + } +} \ No newline at end of file diff --git a/tests/Auth0.AspNetCore.Authentication.IntegrationTests/Auth0CustomDomainsOpenIdConnectPostConfigureOptionsTests.cs b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/Auth0CustomDomainsOpenIdConnectPostConfigureOptionsTests.cs new file mode 100644 index 0000000..313336f --- /dev/null +++ b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/Auth0CustomDomainsOpenIdConnectPostConfigureOptionsTests.cs @@ -0,0 +1,414 @@ +using System; +using System.Net.Http; +using System.Threading.Tasks; +using Auth0.AspNetCore.Authentication.CustomDomains; +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Authentication.OpenIdConnect; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Options; +using Moq; +using Xunit; + +namespace Auth0.AspNetCore.Authentication.IntegrationTests; + +public class Auth0CustomDomainsOpenIdConnectPostConfigureOptionsTests +{ + [Fact] + public void Constructor_WithNullHttpContextAccessor_ThrowsArgumentNullException() + { + var auth0CustomDomainsOptionsMonitor = new Mock>(); + + var exception = Assert.Throws(() => + new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + null!, + auth0CustomDomainsOptionsMonitor.Object)); + + Assert.Equal("httpContextAccessor", exception.ParamName); + } + + [Fact] + public void Constructor_WithNullAuth0CustomDomainsOptionsMonitor_ThrowsArgumentNullException() + { + var httpContextAccessor = new Mock(); + + var exception = Assert.Throws(() => + new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + null!)); + + Assert.Equal("auth0CustomDomainsOptionsMonitor", exception.ParamName); + } + + [Fact] + public void Constructor_WithValidParameters_DoesNotThrow() + { + var httpContextAccessor = new Mock(); + var auth0CustomDomainsOptionsMonitor = new Mock>(); + + var exception = Record.Exception(() => + new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object)); + + Assert.Null(exception); + } + + [Fact] + public void Constructor_WithHttpClientFactory_DoesNotThrow() + { + var httpContextAccessor = new Mock(); + var auth0CustomDomainsOptionsMonitor = new Mock>(); + var httpClientFactory = new Mock(); + + var exception = Record.Exception(() => + new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object, + httpClientFactory.Object)); + + Assert.Null(exception); + } + + [Fact] + public void PostConfigure_WithNullName_DoesNotModifyOptions() + { + var httpContextAccessor = new Mock(); + var auth0CustomDomainsOptionsMonitor = new Mock>(); + var options = new OpenIdConnectOptions(); + var originalConfigurationManager = options.ConfigurationManager; + + var postConfigureOptions = new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object); + + postConfigureOptions.PostConfigure(null, options); + + Assert.Equal(originalConfigurationManager, options.ConfigurationManager); + } + + [Fact] + public void PostConfigure_WithEmptyName_DoesNotModifyOptions() + { + var httpContextAccessor = new Mock(); + var auth0CustomDomainsOptionsMonitor = new Mock>(); + var options = new OpenIdConnectOptions(); + var originalConfigurationManager = options.ConfigurationManager; + + var postConfigureOptions = new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object); + + postConfigureOptions.PostConfigure(string.Empty, options); + + Assert.Equal(originalConfigurationManager, options.ConfigurationManager); + } + + [Fact] + public void PostConfigure_WhenMultipleCustomDomainsDisabled_DoesNotModifyOptions() + { + var httpContextAccessor = new Mock(); + var auth0CustomDomainsOptions = new Auth0CustomDomainsOptions(); + var auth0CustomDomainsOptionsMonitor = new Mock>(); + auth0CustomDomainsOptionsMonitor.Setup(m => m.Get("TestScheme")).Returns(auth0CustomDomainsOptions); + + var options = new OpenIdConnectOptions(); + var originalConfigurationManager = options.ConfigurationManager; + + var postConfigureOptions = new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object); + + postConfigureOptions.PostConfigure("TestScheme", options); + + Assert.Equal(originalConfigurationManager, options.ConfigurationManager); + } + + [Fact] + public void PostConfigure_WhenStateDataFormatIsNull_ThrowsInvalidOperationException() + { + var httpContextAccessor = new Mock(); + + var auth0CustomDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = context => Task.FromResult(null), + }; + var auth0CustomDomainsOptionsMonitor = new Mock>(); + auth0CustomDomainsOptionsMonitor.Setup(m => m.Get("TestScheme")).Returns(auth0CustomDomainsOptions); + + var options = new OpenIdConnectOptions + { + StateDataFormat = null + }; + + var postConfigureOptions = new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object); + + var exception = Assert.Throws(() => + postConfigureOptions.PostConfigure("TestScheme", options)); + + Assert.Contains("StateDataFormat is not configured", exception.Message); + } + + [Fact] + public void PostConfigure_WithValidConfiguration_SetsConfigurationManager() + { + var httpContextAccessor = new Mock(); + var stateDataFormat = new Mock>(); + var auth0CustomDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = context => Task.FromResult(null), + }; + var auth0CustomDomainsOptionsMonitor = new Mock>(); + auth0CustomDomainsOptionsMonitor.Setup(m => m.Get("TestScheme")).Returns(auth0CustomDomainsOptions); + + var options = new OpenIdConnectOptions + { + StateDataFormat = stateDataFormat.Object, + Backchannel = new HttpClient() + }; + + var postConfigureOptions = new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object); + + postConfigureOptions.PostConfigure("TestScheme", options); + + Assert.IsType(options.ConfigurationManager); + } + + [Fact] + public void PostConfigure_WithValidConfiguration_DisablesIssuerValidation() + { + var httpContextAccessor = new Mock(); + var stateDataFormat = new Mock>(); + var auth0CustomDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = context => Task.FromResult(null), + }; + var auth0CustomDomainsOptionsMonitor = new Mock>(); + auth0CustomDomainsOptionsMonitor.Setup(m => m.Get("TestScheme")).Returns(auth0CustomDomainsOptions); + + var options = new OpenIdConnectOptions + { + StateDataFormat = stateDataFormat.Object, + Backchannel = new HttpClient() + }; + + var postConfigureOptions = new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object); + + postConfigureOptions.PostConfigure("TestScheme", options); + + Assert.False(options.TokenValidationParameters.ValidateIssuer); + } + + [Fact] + public void PostConfigure_WithValidConfiguration_ClearsAuthority() + { + var httpContextAccessor = new Mock(); + var stateDataFormat = new Mock>(); + var auth0CustomDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = context => Task.FromResult(null), + }; + var auth0CustomDomainsOptionsMonitor = new Mock>(); + auth0CustomDomainsOptionsMonitor.Setup(m => m.Get("TestScheme")).Returns(auth0CustomDomainsOptions); + + var options = new OpenIdConnectOptions + { + StateDataFormat = stateDataFormat.Object, + Authority = "https://example.auth0.com", + Backchannel = new HttpClient() + }; + + var postConfigureOptions = new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object); + + postConfigureOptions.PostConfigure("TestScheme", options); + + Assert.Null(options.Authority); + } + + [Fact] + public void PostConfigure_WithHttpClientFactory_UsesFactoryClient() + { + var httpContextAccessor = new Mock(); + var stateDataFormat = new Mock>(); + var httpClient = new HttpClient(); + var httpClientFactory = new Mock(); + httpClientFactory.Setup(f => f.CreateClient(string.Empty)).Returns(httpClient); + + var auth0CustomDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = context => Task.FromResult(null), + }; + var auth0CustomDomainsOptionsMonitor = new Mock>(); + auth0CustomDomainsOptionsMonitor.Setup(m => m.Get("TestScheme")).Returns(auth0CustomDomainsOptions); + + var options = new OpenIdConnectOptions + { + StateDataFormat = stateDataFormat.Object + }; + + var postConfigureOptions = new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object, + httpClientFactory.Object); + + postConfigureOptions.PostConfigure("TestScheme", options); + + httpClientFactory.Verify(f => f.CreateClient(string.Empty), Times.Once); + Assert.NotNull(options.ConfigurationManager); + } + + [Fact] + public void PostConfigure_WithBackchannel_UsesBackchannel() + { + var httpContextAccessor = new Mock(); + var stateDataFormat = new Mock>(); + var backchannel = new HttpClient(); + var httpClientFactory = new Mock(); + + var auth0CustomDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = context => Task.FromResult(null), + }; + var auth0CustomDomainsOptionsMonitor = new Mock>(); + auth0CustomDomainsOptionsMonitor.Setup(m => m.Get("TestScheme")).Returns(auth0CustomDomainsOptions); + + var options = new OpenIdConnectOptions + { + StateDataFormat = stateDataFormat.Object, + Backchannel = backchannel + }; + + var postConfigureOptions = new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object, + httpClientFactory.Object); + + postConfigureOptions.PostConfigure("TestScheme", options); + + httpClientFactory.Verify(f => f.CreateClient(It.IsAny()), Times.Never); + Assert.NotNull(options.ConfigurationManager); + } + + [Fact] + public void PostConfigure_WithoutBackchannelOrFactory_ThrowsInvalidOperationException() + { + var httpContextAccessor = new Mock(); + var stateDataFormat = new Mock>(); + + var auth0CustomDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = context => Task.FromResult(null), + }; + var auth0CustomDomainsOptionsMonitor = new Mock>(); + auth0CustomDomainsOptionsMonitor.Setup(m => m.Get("TestScheme")).Returns(auth0CustomDomainsOptions); + + var options = new OpenIdConnectOptions + { + StateDataFormat = stateDataFormat.Object + }; + + var postConfigureOptions = new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object); + + var exception = Assert.Throws(() => + postConfigureOptions.PostConfigure("TestScheme", options)); + + Assert.Contains("Either OpenIdConnectOptions.Backchannel or IHttpClientFactory must be configured", exception.Message); + } + + [Fact] + public void PostConfigure_WithCustomCache_UsesProvidedCache() + { + var httpContextAccessor = new Mock(); + var stateDataFormat = new Mock>(); + var customCache = new MemoryConfigurationManagerCache(maxSize: 50); + + var auth0CustomDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = context => Task.FromResult(null), + ConfigurationManagerCache = customCache + }; + var auth0CustomDomainsOptionsMonitor = new Mock>(); + auth0CustomDomainsOptionsMonitor.Setup(m => m.Get("TestScheme")).Returns(auth0CustomDomainsOptions); + + var options = new OpenIdConnectOptions + { + StateDataFormat = stateDataFormat.Object, + Backchannel = new HttpClient() + }; + + var postConfigureOptions = new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object); + + postConfigureOptions.PostConfigure("TestScheme", options); + + Assert.IsType(options.ConfigurationManager); + } + + [Fact] + public void PostConfigure_WithNullCache_UsesDefaultCache() + { + var httpContextAccessor = new Mock(); + var stateDataFormat = new Mock>(); + + var auth0CustomDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = context => Task.FromResult(null), + ConfigurationManagerCache = null + }; + var auth0CustomDomainsOptionsMonitor = new Mock>(); + auth0CustomDomainsOptionsMonitor.Setup(m => m.Get("TestScheme")).Returns(auth0CustomDomainsOptions); + + var options = new OpenIdConnectOptions + { + StateDataFormat = stateDataFormat.Object, + Backchannel = new HttpClient() + }; + + var postConfigureOptions = new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object); + + postConfigureOptions.PostConfigure("TestScheme", options); + + Assert.IsType(options.ConfigurationManager); + } + + [Fact] + public void PostConfigure_WithNullConfigurationManagerCache_DisablesCaching() + { + var httpContextAccessor = new Mock(); + var stateDataFormat = new Mock>(); + + var auth0CustomDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = context => Task.FromResult(null), + ConfigurationManagerCache = new NullConfigurationManagerCache() + }; + var auth0CustomDomainsOptionsMonitor = new Mock>(); + auth0CustomDomainsOptionsMonitor.Setup(m => m.Get("TestScheme")).Returns(auth0CustomDomainsOptions); + + var options = new OpenIdConnectOptions + { + StateDataFormat = stateDataFormat.Object, + Backchannel = new HttpClient() + }; + + var postConfigureOptions = new Auth0CustomDomainsOpenIdConnectPostConfigureOptions( + httpContextAccessor.Object, + auth0CustomDomainsOptionsMonitor.Object); + + postConfigureOptions.PostConfigure("TestScheme", options); + + Assert.IsType(options.ConfigurationManager); + } +} \ No newline at end of file diff --git a/tests/Auth0.AspNetCore.Authentication.IntegrationTests/BackchannelLogoutTests.cs b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/BackchannelLogoutTests.cs index 0d4229e..2129a45 100644 --- a/tests/Auth0.AspNetCore.Authentication.IntegrationTests/BackchannelLogoutTests.cs +++ b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/BackchannelLogoutTests.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; using System.IdentityModel.Tokens.Jwt; using System.Linq; @@ -31,16 +32,14 @@ public async Task Should_Return_405_If_Not_Post() .Build()) .Build(); - using var server = TestServerBuilder.CreateServer(opt => - { - opt.Backchannel = new HttpClient(mockHandler.Object); - }, null, false, false, false, null, true); + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, false, false, null, true); using var client = server.CreateClient(); var res = await client.SendAsync($"{TestServerBuilder.Host}/backchannel-logout"); res.StatusCode.Should().Be((HttpStatusCode)405); } - + [Fact] public async Task Should_return_400_when_not_form_urlencoded() { @@ -56,21 +55,19 @@ public async Task Should_return_400_when_not_form_urlencoded() .Build()) .Build(); - using var server = TestServerBuilder.CreateServer(opt => - { - opt.Backchannel = new HttpClient(mockHandler.Object); - }, null, false, false, false, null, true); + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, false, false, null, true); using var client = server.CreateClient(); var message = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); var response = await client.SendAsync(message); - + var content = await response.Content.ReadAsStringAsync(); var error = ApiError.Parse(content); response.StatusCode.Should().Be((HttpStatusCode)400); error.Message.Should().Be("Only application/x-www-form-urlencoded is allowed."); } - + [Fact] public async Task Should_return_400_when_no_logout_token() { @@ -86,23 +83,21 @@ public async Task Should_return_400_when_no_logout_token() .Build()) .Build(); - using var server = TestServerBuilder.CreateServer(opt => - { - opt.Backchannel = new HttpClient(mockHandler.Object); - }, null, false, false, false, null, true); + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, false, false, null, true); using var client = server.CreateClient(); var formData = new Dictionary { { "Foo", "Bar" } }; using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); req.Content = new FormUrlEncodedContent(formData); using var response = await client.SendAsync(req); - + var content = await response.Content.ReadAsStringAsync(); var error = ApiError.Parse(content); - + response.StatusCode.Should().Be((HttpStatusCode)400); error.Message.Should().Be("Missing logout_token."); } - + [Fact] public async Task Should_Validate_Signature_On_Backchannel_Logout() { @@ -118,31 +113,29 @@ public async Task Should_Validate_Signature_On_Backchannel_Logout() .Build()) .Build(); - using var server = TestServerBuilder.CreateServer(opt => - { - opt.Backchannel = new HttpClient(mockHandler.Object); - }, null, false, false, false, null, true); + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, false, false, null, true); using var client = server.CreateClient(); var logoutToken = new JwtTokenBuilder(1) .WithIssuer($"https://{domain}/") .WithAudience(clientId) .WithClaim(JwtRegisteredClaimNames.Sid, "sid") - .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }" ) + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") .SignWithRs256("Auth0.AspNetCore.Authentication.IntegrationTests.jwks2.json") .Build(); - + var formData = new Dictionary { { "logout_token", logoutToken } }; using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); req.Content = new FormUrlEncodedContent(formData); using var response = await client.SendAsync(req); - + var content = await response.Content.ReadAsStringAsync(); var error = ApiError.Parse(content); - + response.StatusCode.Should().Be((HttpStatusCode)400); error.Message.Should().Contain("Signature validation failed."); } - + [Fact] public async Task Should_Validate_Issuer_On_Backchannel_Logout() { @@ -158,29 +151,27 @@ public async Task Should_Validate_Issuer_On_Backchannel_Logout() .Build()) .Build(); - using var server = TestServerBuilder.CreateServer(opt => - { - opt.Backchannel = new HttpClient(mockHandler.Object); - }, null, false, false, false, null, true); + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, false, false, null, true); using var client = server.CreateClient(); var logoutToken = new JwtTokenBuilder(1) .WithIssuer($"https://bad_issuer/") .WithAudience(clientId) .WithClaim(JwtRegisteredClaimNames.Sid, "sid") - .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }" ) + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") .Build(); var formData = new Dictionary { { "logout_token", logoutToken } }; using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); req.Content = new FormUrlEncodedContent(formData); using var response = await client.SendAsync(req); - + var content = await response.Content.ReadAsStringAsync(); var error = ApiError.Parse(content); - + response.StatusCode.Should().Be((HttpStatusCode)400); error.Message.Should().Contain("Issuer validation failed."); } - + [Fact] public async Task Should_Validate_Audience_On_Backchannel_Logout() { @@ -196,30 +187,28 @@ public async Task Should_Validate_Audience_On_Backchannel_Logout() .Build()) .Build(); - using var server = TestServerBuilder.CreateServer(opt => - { - opt.Backchannel = new HttpClient(mockHandler.Object); - }, null, false, false, false, null, true); + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, false, false, null, true); using var client = server.CreateClient(); var logoutToken = new JwtTokenBuilder(1) .WithIssuer($"https://{domain}/") .WithAudience("bad_audience") .WithClaim(JwtRegisteredClaimNames.Sid, "sid") - .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }" ) + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") .Build(); - + var formData = new Dictionary { { "logout_token", logoutToken } }; using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); req.Content = new FormUrlEncodedContent(formData); using var response = await client.SendAsync(req); - + var content = await response.Content.ReadAsStringAsync(); var error = ApiError.Parse(content); - + response.StatusCode.Should().Be((HttpStatusCode)400); error.Message.Should().Contain("Audience validation failed."); } - + [Fact] public async Task Should_Validate_Sid_On_Backchannel_Logout() { @@ -235,29 +224,27 @@ public async Task Should_Validate_Sid_On_Backchannel_Logout() .Build()) .Build(); - using var server = TestServerBuilder.CreateServer(opt => - { - opt.Backchannel = new HttpClient(mockHandler.Object); - }, null, false, false, false, null, true); + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, false, false, null, true); using var client = server.CreateClient(); var logoutToken = new JwtTokenBuilder(1) .WithIssuer($"https://{domain}/") .WithAudience(clientId) .Build(); - + var formData = new Dictionary { { "logout_token", logoutToken } }; using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); req.Content = new FormUrlEncodedContent(formData); - + using var response = await client.SendAsync(req); - + var content = await response.Content.ReadAsStringAsync(); var error = ApiError.Parse(content); - + response.StatusCode.Should().Be((HttpStatusCode)400); error.Message.Should().Contain("Session Id (sid) claim must be a string present in the logout token."); } - + [Fact] public async Task Should_Validate_Nonce_On_Backchannel_Logout() { @@ -275,10 +262,8 @@ public async Task Should_Validate_Nonce_On_Backchannel_Logout() .Build()) .Build(); - using var server = TestServerBuilder.CreateServer(opt => - { - opt.Backchannel = new HttpClient(mockHandler.Object); - }, null, false, false, false, null, true); + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, false, false, null, true); using var client = server.CreateClient(); var logoutToken = new JwtTokenBuilder(1) .WithIssuer($"https://{domain}/") @@ -286,19 +271,19 @@ public async Task Should_Validate_Nonce_On_Backchannel_Logout() .WithClaim(JwtRegisteredClaimNames.Nonce, nonce) .WithClaim(JwtRegisteredClaimNames.Sid, "sid") .Build(); - + var formData = new Dictionary { { "logout_token", logoutToken } }; using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); req.Content = new FormUrlEncodedContent(formData); using var response = await client.SendAsync(req); - + var content = await response.Content.ReadAsStringAsync(); var error = ApiError.Parse(content); - + response.StatusCode.Should().Be((HttpStatusCode)400); error.Message.Should().Contain("Nonce (nonce) claim must not be present in the logout token."); } - + [Fact] public async Task Should_Validate_Events_When_Missing_On_Backchannel_Logout() { @@ -314,25 +299,23 @@ public async Task Should_Validate_Events_When_Missing_On_Backchannel_Logout() .Build()) .Build(); - using var server = TestServerBuilder.CreateServer(opt => - { - opt.Backchannel = new HttpClient(mockHandler.Object); - }, null, false, false, false, null, true); + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, false, false, null, true); using var client = server.CreateClient(); var logoutToken = new JwtTokenBuilder(1) .WithIssuer($"https://{domain}/") .WithAudience(clientId) .WithClaim(JwtRegisteredClaimNames.Sid, "sid") .Build(); - + var formData = new Dictionary { { "logout_token", logoutToken } }; using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); req.Content = new FormUrlEncodedContent(formData); using var response = await client.SendAsync(req); - + var content = await response.Content.ReadAsStringAsync(); var error = ApiError.Parse(content); - + response.StatusCode.Should().Be((HttpStatusCode)400); error.Message.Should().Contain("Events (events) claim must be present in the logout token."); } @@ -352,30 +335,30 @@ public async Task Should_Validate_Events_When_Missing_Property_Backchannel_Logou .Build()) .Build(); - using var server = TestServerBuilder.CreateServer(opt => - { - opt.Backchannel = new HttpClient(mockHandler.Object); - }, null, false, false, false, null, true); + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, false, false, null, true); using var client = server.CreateClient(); var logoutToken = new JwtTokenBuilder(1) .WithIssuer($"https://{domain}/") .WithAudience(clientId) .WithClaim(JwtRegisteredClaimNames.Sid, "sid") - .WithClaim("events", "{ \"foo\": {} }" ) + .WithClaim("events", "{ \"foo\": {} }") .Build(); - + var formData = new Dictionary { { "logout_token", logoutToken } }; using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); req.Content = new FormUrlEncodedContent(formData); using var response = await client.SendAsync(req); - + var content = await response.Content.ReadAsStringAsync(); var error = ApiError.Parse(content); - + response.StatusCode.Should().Be((HttpStatusCode)400); - error.Message.Should().Contain("Events (events) claim must contain a 'http://schemas.openid.net/event/backchannel-logout' property in the logout token."); + error.Message.Should() + .Contain( + "Events (events) claim must contain a 'http://schemas.openid.net/event/backchannel-logout' property in the logout token."); } - + [Fact] public async Task Should_Pass_Validation_On_Backchannel_Logout() { @@ -391,23 +374,21 @@ public async Task Should_Pass_Validation_On_Backchannel_Logout() .Build()) .Build(); - using var server = TestServerBuilder.CreateServer(opt => - { - opt.Backchannel = new HttpClient(mockHandler.Object); - }, null, false, false, false, null, true); + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, false, false, null, true); using var client = server.CreateClient(); var logoutToken = new JwtTokenBuilder(1) .WithIssuer($"https://{domain}/") .WithAudience(clientId) .WithClaim(JwtRegisteredClaimNames.Sid, "sid") - .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }" ) + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") .Build(); - + var formData = new Dictionary { { "logout_token", logoutToken } }; using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); req.Content = new FormUrlEncodedContent(formData); using var response = await client.SendAsync(req); - + response.StatusCode.Should().Be((HttpStatusCode)200); } @@ -430,10 +411,8 @@ public async Task Should_Logout_And_Clear_Cookie() .Build()) .Build(); - using var server = TestServerBuilder.CreateServer(opt => - { - opt.Backchannel = new HttpClient(mockHandler.Object); - }, null, false, true, false, null, true); + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, true, false, null, true); using var client = server.CreateClient(); var loginResponse = (await client.SendAsync($"{TestServerBuilder.Host}/{TestServerBuilder.Login}")); var setCookie = Assert.Single(loginResponse.Headers, h => h.Key == "Set-Cookie"); @@ -449,24 +428,26 @@ public async Task Should_Logout_And_Clear_Cookie() // - Send it to the `/oauth/token` endpoint var state = queryParameters["state"]; - var message = new HttpRequestMessage(HttpMethod.Get, $"{TestServerBuilder.Host}/{TestServerBuilder.Callback}?state={state}&nonce={nonce}&code=123"); + var message = new HttpRequestMessage(HttpMethod.Get, + $"{TestServerBuilder.Host}/{TestServerBuilder.Callback}?state={state}&nonce={nonce}&code=123"); // Pass along the Set-Cookies to ensure `Nonce` and `Correlation` cookies are set. var callbackResponse = (await client.SendAsync(message, setCookie.Value)); var callbackCookies = callbackResponse.Headers.GetValues("Set-Cookie").ToList(); - var protectedMessage = new HttpRequestMessage(HttpMethod.Get, $"{TestServerBuilder.Host}/{TestServerBuilder.Protected}"); + var protectedMessage = + new HttpRequestMessage(HttpMethod.Get, $"{TestServerBuilder.Host}/{TestServerBuilder.Protected}"); var protectedResponse = await client.SendAsync(protectedMessage, callbackCookies); - + // Accessing a protected endpoint before logging out should be OK. protectedResponse.StatusCode.Should().Be(HttpStatusCode.OK); protectedResponse.Headers.Location.Should().BeNull(); - + var logoutToken = new JwtTokenBuilder(1) .WithIssuer($"https://{domain}/") .WithAudience(clientId) .WithClaim(JwtRegisteredClaimNames.Sid, "sid") - .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }" ) + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") .Build(); var formData = new Dictionary { { "logout_token", logoutToken } }; @@ -474,84 +455,86 @@ public async Task Should_Logout_And_Clear_Cookie() req.Content = new FormUrlEncodedContent(formData); using var response = await client.SendAsync(req); - var protectedMessage2 = new HttpRequestMessage(HttpMethod.Get, $"{TestServerBuilder.Host}/{TestServerBuilder.Protected}"); + var protectedMessage2 = + new HttpRequestMessage(HttpMethod.Get, $"{TestServerBuilder.Host}/{TestServerBuilder.Protected}"); var protectedResponse2 = await client.SendAsync(protectedMessage2, callbackCookies); - + // Accessing a protected endpoint after logging out should redirect. protectedResponse2.StatusCode.Should().Be(HttpStatusCode.Found); protectedResponse2.Headers.Location.Should().NotBeNull(); protectedResponse2.Headers.Location!.AbsoluteUri.Should().Contain(TestServerBuilder.Login); } - + [Fact] public async Task Should_Not_Logout_When_Sid_Doesnt_Match() { var nonce = ""; - var configuration = TestConfiguration.GetConfiguration(); - var domain = configuration["Auth0:Domain"]; - var clientId = configuration["Auth0:ClientId"]; - var mockHandler = new OidcMockBuilder() - .MockOpenIdConfig() - .MockJwks() - .MockToken(() => new JwtTokenBuilder(1) - .WithIssuer($"https://{domain}/") - .WithAudience(clientId) - // ReSharper disable once AccessToModifiedClosure - .WithClaim(JwtRegisteredClaimNames.Nonce, nonce) - .WithClaim(JwtRegisteredClaimNames.Sid, "sid2") - .Build()) - .Build(); - - using var server = TestServerBuilder.CreateServer(opt => - { - opt.Backchannel = new HttpClient(mockHandler.Object); - }, null, false, true, false, null, true); - using var client = server.CreateClient(); - var loginResponse = (await client.SendAsync($"{TestServerBuilder.Host}/{TestServerBuilder.Login}")); - var setCookie = Assert.Single(loginResponse.Headers, h => h.Key == "Set-Cookie"); + var configuration = TestConfiguration.GetConfiguration(); + var domain = configuration["Auth0:Domain"]; + var clientId = configuration["Auth0:ClientId"]; + var mockHandler = new OidcMockBuilder() + .MockOpenIdConfig() + .MockJwks() + .MockToken(() => new JwtTokenBuilder(1) + .WithIssuer($"https://{domain}/") + .WithAudience(clientId) + // ReSharper disable once AccessToModifiedClosure + .WithClaim(JwtRegisteredClaimNames.Nonce, nonce) + .WithClaim(JwtRegisteredClaimNames.Sid, "sid2") + .Build()) + .Build(); - var queryParameters = UriUtils.GetQueryParams(loginResponse.Headers.Location); + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, true, false, null, true); + using var client = server.CreateClient(); + var loginResponse = (await client.SendAsync($"{TestServerBuilder.Host}/{TestServerBuilder.Login}")); + var setCookie = Assert.Single(loginResponse.Headers, h => h.Key == "Set-Cookie"); - // Keep track of the nonce as we need to: - // - Send it to the `/oauth/token` endpoint - // - Include it in the generated ID Token - nonce = queryParameters["nonce"]; + var queryParameters = UriUtils.GetQueryParams(loginResponse.Headers.Location); - // Keep track of the state as we need to: - // - Send it to the `/oauth/token` endpoint - var state = queryParameters["state"]; + // Keep track of the nonce as we need to: + // - Send it to the `/oauth/token` endpoint + // - Include it in the generated ID Token + nonce = queryParameters["nonce"]; - var message = new HttpRequestMessage(HttpMethod.Get, $"{TestServerBuilder.Host}/{TestServerBuilder.Callback}?state={state}&nonce={nonce}&code=123"); + // Keep track of the state as we need to: + // - Send it to the `/oauth/token` endpoint + var state = queryParameters["state"]; - // Pass along the Set-Cookies to ensure `Nonce` and `Correlation` cookies are set. - var callbackResponse = (await client.SendAsync(message, setCookie.Value)); - var callbackCookies = callbackResponse.Headers.GetValues("Set-Cookie").ToList(); + var message = new HttpRequestMessage(HttpMethod.Get, + $"{TestServerBuilder.Host}/{TestServerBuilder.Callback}?state={state}&nonce={nonce}&code=123"); - var protectedMessage = new HttpRequestMessage(HttpMethod.Get, $"{TestServerBuilder.Host}/{TestServerBuilder.Protected}"); - var protectedResponse = await client.SendAsync(protectedMessage, callbackCookies); + // Pass along the Set-Cookies to ensure `Nonce` and `Correlation` cookies are set. + var callbackResponse = (await client.SendAsync(message, setCookie.Value)); + var callbackCookies = callbackResponse.Headers.GetValues("Set-Cookie").ToList(); - // Accessing a protected endpoint before logging out should be OK. - protectedResponse.StatusCode.Should().Be(HttpStatusCode.OK); - protectedResponse.Headers.Location.Should().BeNull(); + var protectedMessage = + new HttpRequestMessage(HttpMethod.Get, $"{TestServerBuilder.Host}/{TestServerBuilder.Protected}"); + var protectedResponse = await client.SendAsync(protectedMessage, callbackCookies); - var logoutToken = new JwtTokenBuilder(1) - .WithIssuer($"https://{domain}/") - .WithAudience(clientId) - .WithClaim(JwtRegisteredClaimNames.Sid, "sid") - .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }" ) - .Build(); - - var formData = new Dictionary { { "logout_token", logoutToken } }; - using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); - req.Content = new FormUrlEncodedContent(formData); - using var response = await client.SendAsync(req); - - var protectedMessage2 = new HttpRequestMessage(HttpMethod.Get, $"{TestServerBuilder.Host}/{TestServerBuilder.Protected}"); - var protectedResponse2 = await client.SendAsync(protectedMessage2, callbackCookies); - - // Accessing a protected endpoint after logging out should be still be OK when the SID didn't match. - protectedResponse2.StatusCode.Should().Be(HttpStatusCode.OK); - protectedResponse2.Headers.Location.Should().BeNull(); + // Accessing a protected endpoint before logging out should be OK. + protectedResponse.StatusCode.Should().Be(HttpStatusCode.OK); + protectedResponse.Headers.Location.Should().BeNull(); + + var logoutToken = new JwtTokenBuilder(1) + .WithIssuer($"https://{domain}/") + .WithAudience(clientId) + .WithClaim(JwtRegisteredClaimNames.Sid, "sid") + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") + .Build(); + + var formData = new Dictionary { { "logout_token", logoutToken } }; + using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); + req.Content = new FormUrlEncodedContent(formData); + using var response = await client.SendAsync(req); + + var protectedMessage2 = + new HttpRequestMessage(HttpMethod.Get, $"{TestServerBuilder.Host}/{TestServerBuilder.Protected}"); + var protectedResponse2 = await client.SendAsync(protectedMessage2, callbackCookies); + + // Accessing a protected endpoint after logging out should be still be OK when the SID didn't match. + protectedResponse2.StatusCode.Should().Be(HttpStatusCode.OK); + protectedResponse2.Headers.Location.Should().BeNull(); } [Fact] @@ -572,10 +555,8 @@ public async Task Should_Support_Custom_Authentication_Scheme() .Build(); // Create a server with a custom authentication scheme - using var server = TestServerBuilder.CreateServerWithCustomScheme(customScheme, opt => - { - opt.Backchannel = new HttpClient(mockHandler.Object); - }, null, false, false, false, null, true); + using var server = TestServerBuilder.CreateServerWithCustomScheme(customScheme, + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, false, false, null, true); using var client = server.CreateClient(); // Create a valid logout token @@ -583,7 +564,7 @@ public async Task Should_Support_Custom_Authentication_Scheme() .WithIssuer($"https://{domain}/") .WithAudience(clientId) .WithClaim(JwtRegisteredClaimNames.Sid, "sid") - .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }" ) + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") .Build(); var formData = new Dictionary { { "logout_token", logoutToken } }; @@ -593,4 +574,372 @@ public async Task Should_Support_Custom_Authentication_Scheme() response.StatusCode.Should().Be((HttpStatusCode)200); } + + [Fact] + public async Task Should_Reject_Tokens_From_Different_Issuers_Multiple_Custom_Domains() + { + var configuration = TestConfiguration.GetConfiguration(); + var domain1 = "tenant1.auth0.com"; + var domain2 = "tenant2.auth0.com"; + var clientId = configuration["Auth0:ClientId"]; + + var mockHandler = new OidcMockBuilder() + .MockOpenIdConfig() + .MockJwks() + .MockToken(() => new JwtTokenBuilder(1) + .WithIssuer($"https://{domain1}/") + .WithAudience(clientId) + .Build()) + .Build(); + + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, null, false, false, false, null, true); + using var client = server.CreateClient(); + + // Test logout token from domain1 + var logoutToken1 = new JwtTokenBuilder(1) + .WithIssuer($"https://{domain1}/") + .WithAudience(clientId) + .WithClaim(JwtRegisteredClaimNames.Sid, "sid1") + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") + .Build(); + + var formData1 = new Dictionary { { "logout_token", logoutToken1 } }; + using var req1 = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); + req1.Content = new FormUrlEncodedContent(formData1); + using var response1 = await client.SendAsync(req1); + + // Should fail because issuer doesn't match the configured domain + response1.StatusCode.Should().Be((HttpStatusCode)400); + var content1 = await response1.Content.ReadAsStringAsync(); + var error1 = ApiError.Parse(content1); + error1.Message.Should().Contain("Issuer validation failed"); + + // Test logout token from domain2 + var logoutToken2 = new JwtTokenBuilder(1) + .WithIssuer($"https://{domain2}/") + .WithAudience(clientId) + .WithClaim(JwtRegisteredClaimNames.Sid, "sid2") + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") + .Build(); + + var formData2 = new Dictionary { { "logout_token", logoutToken2 } }; + using var req2 = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); + req2.Content = new FormUrlEncodedContent(formData2); + using var response2 = await client.SendAsync(req2); + + // Should also fail because issuer doesn't match the configured domain + response2.StatusCode.Should().Be((HttpStatusCode)400); + var content2 = await response2.Content.ReadAsStringAsync(); + var error2 = ApiError.Parse(content2); + error2.Message.Should().Contain("Issuer validation failed"); + } + + [Fact] + public async Task Should_Support_Backchannel_Logout_With_Multiple_Custom_Domains() + { + var configuration = TestConfiguration.GetConfiguration(); + var domain1 = "tenant1.auth0.com"; + var domain2 = "tenant2.auth0.com"; + var clientId = configuration["Auth0:ClientId"]; + + var mockHandler = new OidcMockBuilder() + .MockOpenIdConfig() + .MockJwks() + .MockToken(() => new JwtTokenBuilder(1) + .WithIssuer($"https://{domain1}/") + .WithAudience(clientId) + .Build()) + .Build(); + + // MCD-enabled server: DomainResolver always returns domain1 for this test + using var server = TestServerBuilder.CreateServer( + opt => + { + opt.Domain = domain1; + opt.Backchannel = new HttpClient(mockHandler.Object); + }, + null, false, false, false, null, + enableBackchannelLogout: true, + configureCustomDomains: opt => + { + opt.DomainResolver = _ => Task.FromResult(domain1); + }); + using var client = server.CreateClient(); + + // Logout token from domain1 — matches the resolved domain, should succeed + var logoutToken1 = new JwtTokenBuilder(1) + .WithIssuer($"https://{domain1}/") + .WithAudience(clientId) + .WithClaim(JwtRegisteredClaimNames.Sid, "sid1") + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") + .Build(); + + var formData1 = new Dictionary { { "logout_token", logoutToken1 } }; + using var req1 = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); + req1.Content = new FormUrlEncodedContent(formData1); + using var response1 = await client.SendAsync(req1); + + response1.StatusCode.Should().Be((HttpStatusCode)200); + + // Logout token from domain2 — does NOT match the resolved domain, should be rejected early + var logoutToken2 = new JwtTokenBuilder(1) + .WithIssuer($"https://{domain2}/") + .WithAudience(clientId) + .WithClaim(JwtRegisteredClaimNames.Sid, "sid2") + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") + .Build(); + + var formData2 = new Dictionary { { "logout_token", logoutToken2 } }; + using var req2 = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); + req2.Content = new FormUrlEncodedContent(formData2); + using var response2 = await client.SendAsync(req2); + + response2.StatusCode.Should().Be((HttpStatusCode)400); + var content2 = await response2.Content.ReadAsStringAsync(); + var error2 = ApiError.Parse(content2); + error2.Message.Should().Contain("Logout token issuer does not match the resolved domain"); + } + + [Fact] + public async Task Should_Not_Affect_Single_Domain_When_MCD_Not_Enabled() + { + var configuration = TestConfiguration.GetConfiguration(); + var domain = configuration["Auth0:Domain"]; + var clientId = configuration["Auth0:ClientId"]; + + var mockHandler = new OidcMockBuilder() + .MockOpenIdConfig() + .MockJwks() + .MockToken(() => new JwtTokenBuilder(1) + .WithIssuer($"https://{domain}/") + .WithAudience(clientId) + .Build()) + .Build(); + + // No configureCustomDomains — standard single-domain setup + using var server = TestServerBuilder.CreateServer( + opt => { opt.Backchannel = new HttpClient(mockHandler.Object); }, + null, false, false, false, null, true); + using var client = server.CreateClient(); + + var logoutToken = new JwtTokenBuilder(1) + .WithIssuer($"https://{domain}/") + .WithAudience(clientId) + .WithClaim(JwtRegisteredClaimNames.Sid, "sid") + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") + .Build(); + + var formData = new Dictionary { { "logout_token", logoutToken } }; + using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); + req.Content = new FormUrlEncodedContent(formData); + using var response = await client.SendAsync(req); + + response.StatusCode.Should().Be((HttpStatusCode)200); + } + + [Fact] + public async Task Should_Return_400_When_Logout_Token_Issuer_Does_Not_Match_Resolved_Domain() + { + var configuration = TestConfiguration.GetConfiguration(); + var resolvedDomain = "tenant1.auth0.com"; + var otherDomain = "tenant2.auth0.com"; + var clientId = configuration["Auth0:ClientId"]; + + var mockHandler = new OidcMockBuilder() + .MockOpenIdConfig() + .MockJwks() + .Build(); + + // DomainResolver returns resolvedDomain, but token will be from otherDomain + using var server = TestServerBuilder.CreateServer( + opt => + { + opt.Domain = resolvedDomain; + opt.Backchannel = new HttpClient(mockHandler.Object); + }, + null, false, false, false, null, + enableBackchannelLogout: true, + configureCustomDomains: opt => + { + opt.DomainResolver = _ => Task.FromResult(resolvedDomain); + }); + using var client = server.CreateClient(); + + // Token from otherDomain — issuer mismatch should be caught before JWT validation + var logoutToken = new JwtTokenBuilder(1) + .WithIssuer($"https://{otherDomain}/") + .WithAudience(clientId) + .WithClaim(JwtRegisteredClaimNames.Sid, "sid") + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") + .Build(); + + var formData = new Dictionary { { "logout_token", logoutToken } }; + using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); + req.Content = new FormUrlEncodedContent(formData); + using var response = await client.SendAsync(req); + + response.StatusCode.Should().Be((HttpStatusCode)400); + var content = await response.Content.ReadAsStringAsync(); + var error = ApiError.Parse(content); + error.Message.Should().Contain("Logout token issuer does not match the resolved domain"); + } + + [Fact] + public async Task Should_Logout_And_Clear_Cookie_With_Multiple_Custom_Domains() + { + var nonce = ""; + var configuration = TestConfiguration.GetConfiguration(); + var domain = "tenant1.auth0.com"; + var clientId = configuration["Auth0:ClientId"]; + + var mockHandler = new OidcMockBuilder() + .MockOpenIdConfig() + .MockJwks() + .MockToken(() => new JwtTokenBuilder(1) + .WithIssuer($"https://{domain}/") + .WithAudience(clientId) + .WithClaim(JwtRegisteredClaimNames.Sid, "sid") + .WithClaim(JwtRegisteredClaimNames.Nonce, nonce) + .Build()) + .Build(); + + using var server = TestServerBuilder.CreateServer( + opt => + { + opt.Domain = domain; + opt.Backchannel = new HttpClient(mockHandler.Object); + }, + null, false, true, false, null, + enableBackchannelLogout: true, + configureCustomDomains: opt => + { + opt.DomainResolver = _ => Task.FromResult(domain); + }); + using var client = server.CreateClient(); + + var loginResponse = await client.SendAsync($"{TestServerBuilder.Host}/{TestServerBuilder.Login}"); + var setCookie = Assert.Single(loginResponse.Headers, h => h.Key == "Set-Cookie"); + + var queryParameters = UriUtils.GetQueryParams(loginResponse.Headers.Location); + nonce = queryParameters["nonce"]; + var state = queryParameters["state"]; + + var callbackMessage = new HttpRequestMessage(HttpMethod.Get, + $"{TestServerBuilder.Host}/{TestServerBuilder.Callback}?state={state}&nonce={nonce}&code=123"); + var callbackResponse = await client.SendAsync(callbackMessage, setCookie.Value); + var callbackCookies = callbackResponse.Headers.GetValues("Set-Cookie").ToList(); + + var protectedMessage = new HttpRequestMessage(HttpMethod.Get, + $"{TestServerBuilder.Host}/{TestServerBuilder.Protected}"); + var protectedResponse = await client.SendAsync(protectedMessage, callbackCookies); + + // Accessing a protected endpoint before logging out should be OK. + protectedResponse.StatusCode.Should().Be(HttpStatusCode.OK); + protectedResponse.Headers.Location.Should().BeNull(); + + var logoutToken = new JwtTokenBuilder(1) + .WithIssuer($"https://{domain}/") + .WithAudience(clientId) + .WithClaim(JwtRegisteredClaimNames.Sid, "sid") + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") + .Build(); + + var formData = new Dictionary { { "logout_token", logoutToken } }; + using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); + req.Content = new FormUrlEncodedContent(formData); + using var response = await client.SendAsync(req); + + response.StatusCode.Should().Be((HttpStatusCode)200); + + var protectedMessage2 = new HttpRequestMessage(HttpMethod.Get, + $"{TestServerBuilder.Host}/{TestServerBuilder.Protected}"); + var protectedResponse2 = await client.SendAsync(protectedMessage2, callbackCookies); + + // Accessing a protected endpoint after backchannel logout should redirect to login. + protectedResponse2.StatusCode.Should().Be(HttpStatusCode.Found); + protectedResponse2.Headers.Location.Should().NotBeNull(); + protectedResponse2.Headers.Location!.AbsoluteUri.Should().Contain(TestServerBuilder.Login); + } + + [Fact] + public async Task Should_Return_400_When_Logout_Token_Is_Malformed_And_MCD_Enabled() + { + var configuration = TestConfiguration.GetConfiguration(); + var resolvedDomain = "tenant1.auth0.com"; + var clientId = configuration["Auth0:ClientId"]; + + var mockHandler = new OidcMockBuilder() + .MockOpenIdConfig() + .MockJwks() + .Build(); + + using var server = TestServerBuilder.CreateServer( + opt => + { + opt.Domain = resolvedDomain; + opt.Backchannel = new HttpClient(mockHandler.Object); + }, + null, false, false, false, null, + enableBackchannelLogout: true, + configureCustomDomains: opt => + { + opt.DomainResolver = _ => Task.FromResult(resolvedDomain); + }); + using var client = server.CreateClient(); + + var formData = new Dictionary { { "logout_token", "not.a.valid.jwt" } }; + using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); + req.Content = new FormUrlEncodedContent(formData); + using var response = await client.SendAsync(req); + + response.StatusCode.Should().Be((HttpStatusCode)400); + var content = await response.Content.ReadAsStringAsync(); + var error = ApiError.Parse(content); + error.Message.Should().Contain("Logout token is malformed or not a valid JWT"); + } + + [Fact] + public async Task Should_Return_500_When_Resolved_Domain_Not_Available_And_MCD_Enabled() + { + var configuration = TestConfiguration.GetConfiguration(); + var resolvedDomain = "tenant1.auth0.com"; + var clientId = configuration["Auth0:ClientId"]; + + var mockHandler = new OidcMockBuilder() + .MockOpenIdConfig() + .MockJwks() + .Build(); + + using var server = TestServerBuilder.CreateServer( + opt => + { + opt.Domain = resolvedDomain; + opt.Backchannel = new HttpClient(mockHandler.Object); + }, + null, false, false, false, null, + enableBackchannelLogout: true, + configureCustomDomains: opt => + { + opt.DomainResolver = _ => Task.FromResult(null); + }); + using var client = server.CreateClient(); + + var logoutToken = new JwtTokenBuilder(1) + .WithIssuer($"https://{resolvedDomain}/") + .WithAudience(clientId) + .WithClaim(JwtRegisteredClaimNames.Sid, "sid") + .WithClaim("events", "{ \"http://schemas.openid.net/event/backchannel-logout\": {} }") + .Build(); + + var formData = new Dictionary { { "logout_token", logoutToken } }; + using var req = new HttpRequestMessage(HttpMethod.Post, $"{TestServerBuilder.Host}/backchannel-logout"); + req.Content = new FormUrlEncodedContent(formData); + + // DomainResolver returning null causes the startup filter to throw InvalidOperationException + // before the backchannel logout handler is reached. The test host surfaces this as an exception. + var act = async () => await client.SendAsync(req); + await act.Should().ThrowAsync() + .WithMessage("DomainResolver returned empty issuer."); + } } \ No newline at end of file diff --git a/tests/Auth0.AspNetCore.Authentication.IntegrationTests/Infrastructure/TestServerBuilder.cs b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/Infrastructure/TestServerBuilder.cs index 80291e6..78b1fb1 100644 --- a/tests/Auth0.AspNetCore.Authentication.IntegrationTests/Infrastructure/TestServerBuilder.cs +++ b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/Infrastructure/TestServerBuilder.cs @@ -1,6 +1,7 @@ using System; using System.Text.Json; using Auth0.AspNetCore.Authentication.BackchannelLogout; +using Auth0.AspNetCore.Authentication.CustomDomains; using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authentication.Cookies; using Microsoft.AspNetCore.Builder; @@ -32,7 +33,7 @@ internal class TestServerBuilder /// Indicated whether or not the authenitcation should be mocked, useful because some tests require an authenticated user while others require no user to exist. /// Optional custom authentication scheme to use. /// The created TestServer instance. - public static TestServer CreateServer(Action configureOptions = null, Action configureWithAccessTokensOptions = null, bool mockAuthentication = false, bool useServiceCollectionExtension = false, bool addExtraProvider = false, Action configureAdditionalOptions = null, bool enableBackchannelLogout = false, string authenticationScheme = null) + public static TestServer CreateServer(Action configureOptions = null, Action configureWithAccessTokensOptions = null, bool mockAuthentication = false, bool useServiceCollectionExtension = false, bool addExtraProvider = false, Action configureAdditionalOptions = null, bool enableBackchannelLogout = false, string authenticationScheme = null, Action configureCustomDomains = null) { var configuration = TestConfiguration.GetConfiguration(); var host = new HostBuilder() @@ -140,12 +141,17 @@ await res.WriteAsync(JsonSerializer.Serialize(new { builder.WithAccessToken(configureWithAccessTokensOptions); } - + if (enableBackchannelLogout) { builder.WithBackchannelLogout(); } + if (configureCustomDomains != null) + { + builder.WithCustomDomains(configureCustomDomains); + } + services.AddControllersWithViews(); }) .ConfigureTestServices(services => diff --git a/tests/Auth0.AspNetCore.Authentication.IntegrationTests/MemoryConfigurationManagerCacheTests.cs b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/MemoryConfigurationManagerCacheTests.cs new file mode 100644 index 0000000..2de7fe6 --- /dev/null +++ b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/MemoryConfigurationManagerCacheTests.cs @@ -0,0 +1,169 @@ +using System; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Auth0.AspNetCore.Authentication.CustomDomains; +using Microsoft.IdentityModel.Protocols; +using Microsoft.IdentityModel.Protocols.OpenIdConnect; +using Moq; +using Xunit; + +namespace Auth0.AspNetCore.Authentication.IntegrationTests; + +public class MemoryConfigurationManagerCacheTests +{ + [Theory] + [InlineData(0)] + [InlineData(-1)] + public void Constructor_WithInvalidMaxSize_ThrowsArgumentOutOfRangeException(int maxSize) + { + var exception = Assert.Throws(() => + new MemoryConfigurationManagerCache(maxSize: maxSize)); + + Assert.Equal("maxSize", exception.ParamName); + } + + [Fact] + public void GetOrCreate_WithValidParameters_ReturnsConfigurationManager() + { + var cache = new MemoryConfigurationManagerCache(); + var mockManager = new Mock>(); + + var result = cache.GetOrCreate( + "https://example.auth0.com/.well-known/openid-configuration", + _ => mockManager.Object); + + Assert.Same(mockManager.Object, result); + } + + [Fact] + public void GetOrCreate_CalledTwiceWithSameKey_ReturnsSameInstance() + { + var cache = new MemoryConfigurationManagerCache(); + var callCount = 0; + var mockManager = new Mock>(); + + Func> factory = _ => + { + callCount++; + return mockManager.Object; + }; + + var metadataAddress = "https://example.auth0.com/.well-known/openid-configuration"; + var result1 = cache.GetOrCreate(metadataAddress, factory); + var result2 = cache.GetOrCreate(metadataAddress, factory); + + Assert.Same(result1, result2); + // Factory should only be called once + Assert.Equal(1, callCount); + } + + [Fact] + public void GetOrCreate_CalledWithDifferentKeys_ReturnsDifferentInstances() + { + var cache = new MemoryConfigurationManagerCache(); + var mockManager1 = new Mock>(); + var mockManager2 = new Mock>(); + var callCount = 0; + + Func> factory = address => + { + callCount++; + return address.Contains("domain1") ? mockManager1.Object : mockManager2.Object; + }; + + var result1 = cache.GetOrCreate("https://domain1.auth0.com/.well-known/openid-configuration", factory); + var result2 = cache.GetOrCreate("https://domain2.auth0.com/.well-known/openid-configuration", factory); + + Assert.NotSame(result1, result2); + Assert.Equal(2, callCount); + } + + [Fact] + public void GetOrCreate_AfterDispose_ThrowsObjectDisposedException() + { + var cache = new MemoryConfigurationManagerCache(); + var mockManager = new Mock>(); + + cache.Dispose(); + + Assert.Throws(() => + cache.GetOrCreate("https://example.auth0.com/.well-known/openid-configuration", _ => mockManager.Object)); + } + + [Fact] + public void Clear_DoesNotThrow() + { + var cache = new MemoryConfigurationManagerCache(); + var mockManager = new Mock>(); + + // Add an entry + cache.GetOrCreate("https://example.auth0.com/.well-known/openid-configuration", _ => mockManager.Object); + + // Should not throw + var exception = Record.Exception(() => cache.Clear()); + Assert.Null(exception); + } + + [Fact] + public void Clear_AfterDispose_DoesNotThrow() + { + var cache = new MemoryConfigurationManagerCache(); + cache.Dispose(); + + // Should not throw even after disposal + var exception = Record.Exception(() => cache.Clear()); + Assert.Null(exception); + } + + [Fact] + public void Dispose_CanBeCalledMultipleTimes() + { + var cache = new MemoryConfigurationManagerCache(); + + var exception = Record.Exception(() => + { + cache.Dispose(); + cache.Dispose(); + cache.Dispose(); + }); + + Assert.Null(exception); + } + + [Fact] + public void DefaultMaxSize_IsOneHundred() + { + Assert.Equal(100, MemoryConfigurationManagerCache.DefaultMaxSize); + } + + [Fact] + public async Task GetOrCreate_ConcurrentCallsWithSameKey_InvokesFactoryOnlyOnce() + { + // Arrange + var cache = new MemoryConfigurationManagerCache(); + var factoryCallCount = 0; + var mockManager = new Mock>(); + + Func> factory = _ => + { + Interlocked.Increment(ref factoryCallCount); + // Simulate slow factory to increase chance of race condition + Thread.Sleep(500); + return mockManager.Object; + }; + + var metadataAddress = "https://example.auth0.com/.well-known/openid-configuration"; + + // Simulates concurrent calls + var tasks = Enumerable.Range(0, 100) + .Select(_ => Task.Run(() => cache.GetOrCreate(metadataAddress, factory))) + .ToArray(); + + var results = await Task.WhenAll(tasks); + + // Assert - Factory should only be invoked once + Assert.Equal(1, factoryCallCount); + Assert.All(results, result => Assert.Same(mockManager.Object, result)); + } +} diff --git a/tests/Auth0.AspNetCore.Authentication.IntegrationTests/NullConfigurationManagerCacheTests.cs b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/NullConfigurationManagerCacheTests.cs new file mode 100644 index 0000000..1222b7e --- /dev/null +++ b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/NullConfigurationManagerCacheTests.cs @@ -0,0 +1,138 @@ +using System; +using Auth0.AspNetCore.Authentication.CustomDomains; +using Microsoft.IdentityModel.Protocols; +using Microsoft.IdentityModel.Protocols.OpenIdConnect; +using Moq; +using Xunit; + +namespace Auth0.AspNetCore.Authentication.IntegrationTests; + +public class NullConfigurationManagerCacheTests +{ + [Fact] + public void Constructor_CreatesInstance() + { + var cache = new NullConfigurationManagerCache(); + + Assert.NotNull(cache); + } + + [Fact] + public void GetOrCreate_AlwaysInvokesFactory() + { + var cache = new NullConfigurationManagerCache(); + var callCount = 0; + var mockManager = new Mock>(); + + Func> factory = _ => + { + callCount++; + return mockManager.Object; + }; + + var metadataAddress = "https://example.auth0.com/.well-known/openid-configuration"; + cache.GetOrCreate(metadataAddress, factory); + cache.GetOrCreate(metadataAddress, factory); + cache.GetOrCreate(metadataAddress, factory); + + // Factory should be called every time + Assert.Equal(3, callCount); + } + + [Fact] + public void GetOrCreate_ReturnsFactoryResult() + { + var cache = new NullConfigurationManagerCache(); + var mockManager = new Mock>(); + + var result = cache.GetOrCreate( + "https://example.auth0.com/.well-known/openid-configuration", + _ => mockManager.Object); + + Assert.Same(mockManager.Object, result); + } + + [Fact] + public void GetOrCreate_PassesMetadataAddressToFactory() + { + var cache = new NullConfigurationManagerCache(); + var mockManager = new Mock>(); + string? receivedAddress = null; + + var expectedAddress = "https://example.auth0.com/.well-known/openid-configuration"; + cache.GetOrCreate(expectedAddress, address => + { + receivedAddress = address; + return mockManager.Object; + }); + + Assert.Equal(expectedAddress, receivedAddress); + } + + [Fact] + public void Clear_DoesNotThrow() + { + var cache = new NullConfigurationManagerCache(); + + var exception = Record.Exception(() => cache.Clear()); + + Assert.Null(exception); + } + + [Fact] + public void Clear_CanBeCalledMultipleTimes() + { + var cache = new NullConfigurationManagerCache(); + + var exception = Record.Exception(() => + { + cache.Clear(); + cache.Clear(); + cache.Clear(); + }); + + Assert.Null(exception); + } + + [Fact] + public void Dispose_DoesNotThrow() + { + var cache = new NullConfigurationManagerCache(); + + var exception = Record.Exception(() => cache.Dispose()); + + Assert.Null(exception); + } + + [Fact] + public void Dispose_CanBeCalledMultipleTimes() + { + var cache = new NullConfigurationManagerCache(); + + var exception = Record.Exception(() => + { + cache.Dispose(); + cache.Dispose(); + cache.Dispose(); + }); + + Assert.Null(exception); + } + + [Fact] + public void GetOrCreate_AfterDispose_StillWorks() + { + // NullConfigurationManagerCache should still work after Dispose + // because it has no state to dispose + var cache = new NullConfigurationManagerCache(); + var mockManager = new Mock>(); + + cache.Dispose(); + + var result = cache.GetOrCreate( + "https://example.auth0.com/.well-known/openid-configuration", + _ => mockManager.Object); + + Assert.Same(mockManager.Object, result); + } +} diff --git a/tests/Auth0.AspNetCore.Authentication.IntegrationTests/OpenIdConnectEventsFactoryTests.cs b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/OpenIdConnectEventsFactoryTests.cs new file mode 100644 index 0000000..d80add2 --- /dev/null +++ b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/OpenIdConnectEventsFactoryTests.cs @@ -0,0 +1,235 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Net.Http; +using System.Threading.Tasks; +using Auth0.AspNetCore.Authentication.CustomDomains; +using Auth0.AspNetCore.Authentication.IntegrationTests.Builders; +using Auth0.AspNetCore.Authentication.IntegrationTests.Infrastructure; +using FluentAssertions; +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Authentication.OpenIdConnect; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Moq; +using Xunit; + +namespace Auth0.AspNetCore.Authentication.IntegrationTests; + +/// +/// Tests for the OnRedirectToIdentityProvider event handler in OpenIdConnectEventsFactory, +/// specifically covering the MCD (Multiple Custom Domains) domain-resolution logic. +/// +public class OpenIdConnectEventsFactoryTests +{ + // ------------------------------------------------------------------------- + // Unit-level tests: invoke the OIDC event delegate directly without a full + // TestServer. This lets us precisely control HttpContext.Items. + // ------------------------------------------------------------------------- + + [Fact] + public async Task OnRedirectToIdentityProvider_WhenMcdEnabled_AndDomainMissingFromHttpContextItems_Returns500() + { + // Arrange — MCD enabled but HttpContext.Items has no ResolvedDomainKey + var auth0Options = new Auth0WebAppOptions { Domain = "test.auth0.com", ClientId = "client1" }; + var oidcOptions = new OpenIdConnectOptions(); + + var customDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = _ => Task.FromResult("tenant.custom.com") + }; + var optionsMonitorMock = new Mock>(); + optionsMonitorMock.Setup(m => m.Get(It.IsAny())).Returns(customDomainsOptions); + + var services = new ServiceCollection(); + services.AddSingleton(optionsMonitorMock.Object); + var serviceProvider = services.BuildServiceProvider(); + + var httpContext = new DefaultHttpContext { RequestServices = serviceProvider }; + // Deliberately do NOT set Auth0Constants.ResolvedDomainKey in httpContext.Items + + var responseBody = new MemoryStream(); + httpContext.Response.Body = responseBody; + + var events = OpenIdConnectEventsFactory.Create(auth0Options, oidcOptions); + + var scheme = new AuthenticationScheme(Auth0Constants.AuthenticationScheme, null, typeof(OpenIdConnectHandler)); + var properties = new AuthenticationProperties(); + var redirectContext = new RedirectContext(httpContext, scheme, oidcOptions, properties) + { + ProtocolMessage = new Microsoft.IdentityModel.Protocols.OpenIdConnect.OpenIdConnectMessage() + }; + + // Act + await events.OnRedirectToIdentityProvider(redirectContext); + + // Assert — response must be 500, redirect must NOT have occurred + httpContext.Response.StatusCode.Should().Be(500); + + responseBody.Seek(0, SeekOrigin.Begin); + var body = await new StreamReader(responseBody).ReadToEndAsync(); + body.Should().Contain("Authentication configuration error"); + body.Should().Contain("could not resolve the domain"); + + // The properties items must NOT contain the resolved domain key (no domain was stored) + properties.Items.Should().NotContainKey(Auth0Constants.ResolvedDomainKey); + } + + [Fact] + public async Task OnRedirectToIdentityProvider_WhenMcdEnabled_AndDomainPresentInHttpContextItems_StoresDomainInState() + { + // Arrange — MCD enabled and HttpContext.Items contains the resolved domain + var auth0Options = new Auth0WebAppOptions { Domain = "test.auth0.com", ClientId = "client1" }; + var oidcOptions = new OpenIdConnectOptions(); + + var customDomainsOptions = new Auth0CustomDomainsOptions + { + DomainResolver = _ => Task.FromResult("tenant.custom.com") + }; + var optionsMonitorMock = new Mock>(); + optionsMonitorMock.Setup(m => m.Get(It.IsAny())).Returns(customDomainsOptions); + + var services = new ServiceCollection(); + services.AddSingleton(optionsMonitorMock.Object); + var serviceProvider = services.BuildServiceProvider(); + + var httpContext = new DefaultHttpContext { RequestServices = serviceProvider }; + httpContext.Items[Auth0Constants.ResolvedDomainKey] = "tenant.custom.com"; + + var events = OpenIdConnectEventsFactory.Create(auth0Options, oidcOptions); + + var scheme = new AuthenticationScheme(Auth0Constants.AuthenticationScheme, null, typeof(OpenIdConnectHandler)); + var properties = new AuthenticationProperties(); + var redirectContext = new RedirectContext(httpContext, scheme, oidcOptions, properties) + { + ProtocolMessage = new Microsoft.IdentityModel.Protocols.OpenIdConnect.OpenIdConnectMessage() + }; + + // Act + await events.OnRedirectToIdentityProvider(redirectContext); + + // Assert — domain must be stored in state, response must NOT be 500 + httpContext.Response.StatusCode.Should().NotBe(500); + properties.Items.Should().ContainKey(Auth0Constants.ResolvedDomainKey); + properties.Items[Auth0Constants.ResolvedDomainKey].Should().Be("tenant.custom.com"); + } + + [Fact] + public async Task OnRedirectToIdentityProvider_WhenMcdDisabled_DoesNotTouchResolvedDomainKey() + { + // Arrange — MCD disabled (no DomainResolver configured) + var auth0Options = new Auth0WebAppOptions { Domain = "test.auth0.com", ClientId = "client1" }; + var oidcOptions = new OpenIdConnectOptions(); + + var customDomainsOptions = new Auth0CustomDomainsOptions(); // DomainResolver = null → MCD disabled + var optionsMonitorMock = new Mock>(); + optionsMonitorMock.Setup(m => m.Get(It.IsAny())).Returns(customDomainsOptions); + + var services = new ServiceCollection(); + services.AddSingleton(optionsMonitorMock.Object); + var serviceProvider = services.BuildServiceProvider(); + + var httpContext = new DefaultHttpContext { RequestServices = serviceProvider }; + + var events = OpenIdConnectEventsFactory.Create(auth0Options, oidcOptions); + + var scheme = new AuthenticationScheme(Auth0Constants.AuthenticationScheme, null, typeof(OpenIdConnectHandler)); + var properties = new AuthenticationProperties(); + var redirectContext = new RedirectContext(httpContext, scheme, oidcOptions, properties) + { + ProtocolMessage = new Microsoft.IdentityModel.Protocols.OpenIdConnect.OpenIdConnectMessage() + }; + + // Act + await events.OnRedirectToIdentityProvider(redirectContext); + + // Assert — MCD logic not entered; no domain stored, response not affected + httpContext.Response.StatusCode.Should().NotBe(500); + properties.Items.Should().NotContainKey(Auth0Constants.ResolvedDomainKey); + } + + // ------------------------------------------------------------------------- + // Integration-level test: full TestServer with MCD enabled and a working + // DomainResolver + mocked OIDC discovery. + // ------------------------------------------------------------------------- + + [Fact] + public async Task Should_Redirect_To_CustomDomain_Authorize_WhenMcdEnabled() + { + // Arrange — mock OIDC discovery for the custom domain + var mockHandler = new OidcMockBuilder() + .MockOpenIdConfig() + .MockJwks() + .Build(); + + var customDomain = "tenant.custom.com"; + IDictionary capturedItems = null; + + using var server = TestServerBuilder.CreateServer( + configureOptions: opts => + { + opts.Backchannel = new HttpClient(mockHandler.Object); + opts.OpenIdConnectEvents = new OpenIdConnectEvents + { + OnRedirectToIdentityProvider = ctx => + { + // Runs AFTER the SDK's handler, so ResolvedDomainKey should already be set + capturedItems = new Dictionary(ctx.Properties.Items); + return Task.CompletedTask; + } + }; + }, + configureCustomDomains: cdOpts => + { + cdOpts.DomainResolver = _ => Task.FromResult(customDomain); + }); + + using var client = server.CreateClient(); + + // Act + var response = await client.GetAsync($"{TestServerBuilder.Host}/{TestServerBuilder.Login}"); + + // Assert — the MCD flow resolved the domain, fetched OIDC config for it, and redirected to + // the /authorize endpoint returned by the mocked discovery (wellknownconfig.json uses + // "tenant.eu.auth0.com" as the authorization_endpoint host). + response.StatusCode.Should().Be(HttpStatusCode.Redirect); + response.Headers.Location.Should().NotBeNull(); + response.Headers.Location.AbsolutePath.Should().Be("/authorize"); + + // Verify the resolved domain was stored in the authentication state + capturedItems.Should().NotBeNull(); + capturedItems.Should().ContainKey(Auth0Constants.ResolvedDomainKey); + capturedItems[Auth0Constants.ResolvedDomainKey].Should().Be(customDomain); + } + + [Fact] + public async Task Should_Throw_WhenDomainResolver_ReturnsEmpty() + { + // Arrange — DomainResolver returns empty, which should cause the startup filter to fail fast + var mockHandler = new OidcMockBuilder() + .MockOpenIdConfig() + .MockJwks() + .Build(); + + using var server = TestServerBuilder.CreateServer( + configureOptions: opts => + { + opts.Backchannel = new HttpClient(mockHandler.Object); + }, + configureCustomDomains: cdOpts => + { + cdOpts.DomainResolver = _ => Task.FromResult(string.Empty); + }); + + using var client = server.CreateClient(); + + // Act & Assert — the startup filter middleware throws InvalidOperationException + // before the OIDC redirect is ever reached + var ex = await Assert.ThrowsAsync( + () => client.GetAsync($"{TestServerBuilder.Host}/{TestServerBuilder.Login}")); + ex.Message.Should().Contain("DomainResolver returned empty issuer"); + } +} diff --git a/tests/Auth0.AspNetCore.Authentication.IntegrationTests/TokenClientTests.cs b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/TokenClientTests.cs index 4f3f2dc..9446b06 100644 --- a/tests/Auth0.AspNetCore.Authentication.IntegrationTests/TokenClientTests.cs +++ b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/TokenClientTests.cs @@ -30,9 +30,144 @@ public async Task Returns_Null_When_No_Success_StatusCode() var client = new TokenClient(new HttpClient(mockHandler.Object)); - var result = await client.Refresh(new Auth0WebAppOptions { Domain = "local.auth0.com" }, "123"); + var result = await client.Refresh(new Auth0WebAppOptions { Domain = "local.auth0.com", ClientId = "cid", ClientSecret = "secret" }, "123"); result.Should().BeNull(); } + + [Fact] + public async Task Refresh_WithCustomDomain_UsesCorrectTokenEndpoint() + { + var customDomain = "custom.auth0.com"; + var requestedDomain = string.Empty; + + var mockHandler = new Mock(); + mockHandler + .Protected() + .Setup>( + "SendAsync", + ItExpr.Is(req => + req.RequestUri != null && + req.RequestUri.Host == customDomain && + req.RequestUri.AbsolutePath == "/oauth/token" + ), + ItExpr.IsAny() + ) + .Callback((req, _) => + { + if (req.RequestUri != null) + requestedDomain = req.RequestUri.Host; + }) + .ReturnsAsync(new HttpResponseMessage() + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("{\"access_token\":\"new_token\",\"token_type\":\"Bearer\",\"expires_in\":86400}") + }); + + var client = new TokenClient(new HttpClient(mockHandler.Object)); + var result = await client.Refresh( + new Auth0WebAppOptions + { + Domain = "default.auth0.com", + ClientId = "cid", + ClientSecret = "secret" + }, + "refresh_123", + customDomain // Pass custom domain + ); + + result.Should().NotBeNull(); + result?.AccessToken.Should().Be("new_token"); + requestedDomain.Should().Be(customDomain); + } + + [Fact] + public async Task Refresh_WithoutCustomDomain_UsesDefaultDomain() + { + var defaultDomain = "default.auth0.com"; + var requestedDomain = string.Empty; + + var mockHandler = new Mock(); + mockHandler + .Protected() + .Setup>( + "SendAsync", + ItExpr.Is(req => + req.RequestUri != null && + req.RequestUri.Host == defaultDomain && + req.RequestUri.AbsolutePath == "/oauth/token" + ), + ItExpr.IsAny() + ) + .Callback((req, _) => + { + if (req.RequestUri != null) + requestedDomain = req.RequestUri.Host; + }) + .ReturnsAsync(new HttpResponseMessage() + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("{\"access_token\":\"new_token\",\"token_type\":\"Bearer\",\"expires_in\":86400}") + }); + + var client = new TokenClient(new HttpClient(mockHandler.Object)); + var result = await client.Refresh( + new Auth0WebAppOptions + { + Domain = defaultDomain, + ClientId = "cid", + ClientSecret = "secret" + }, + "refresh_123" + // No custom domain passed, should use default + ); + + result.Should().NotBeNull(); + result?.AccessToken.Should().Be("new_token"); + requestedDomain.Should().Be(defaultDomain); + } + + [Fact] + public async Task Refresh_WithNullDomain_ThrowsInvalidOperationException() + { + var mockHandler = new Mock(); + + var client = new TokenClient(new HttpClient(mockHandler.Object)); + + Func act = async () => await client.Refresh( + new Auth0WebAppOptions + { + Domain = null!, // Null domain + ClientId = "cid", + ClientSecret = "secret" + }, + "refresh_123" + ); + + await act.Should().ThrowAsync() + .WithMessage("Cannot determine domain for token endpoint*"); + } + + [Fact] + public async Task Refresh_WithEmptyCustomDomain_ThrowsInvalidOperationException() + { + var mockHandler = new Mock(); + + var client = new TokenClient(new HttpClient(mockHandler.Object)); + + Func act = async () => await client.Refresh( + new Auth0WebAppOptions + { + Domain = "default.auth0.com", + ClientId = "cid", + ClientSecret = "secret" + }, + "refresh_123", + string.Empty // Empty custom domain + ); + + await act.Should().ThrowAsync() + .WithMessage("Cannot determine domain for token endpoint*"); + } } } diff --git a/tests/Auth0.AspNetCore.Authentication.IntegrationTests/TokenValidationTests.cs b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/TokenValidationTests.cs index 6e6c6e4..b0d9c43 100644 --- a/tests/Auth0.AspNetCore.Authentication.IntegrationTests/TokenValidationTests.cs +++ b/tests/Auth0.AspNetCore.Authentication.IntegrationTests/TokenValidationTests.cs @@ -970,6 +970,58 @@ private string GenerateToken(int userId, string issuer, string audience, string return tokenHandler.WriteToken(token); } + [Fact] + public async Task Should_Throw_When_Custom_Domain_Issuer_Mismatch() + { + var nonce = ""; + var configuration = TestConfiguration.GetConfiguration(); + var domain = configuration["Auth0:Domain"]; + var clientId = configuration["Auth0:ClientId"]; + var customDomain = "custom.auth0.com"; + + // Mock handler expects the default domain but we'll use a different issuer in the token + var mockHandler = new OidcMockBuilder() + .MockOpenIdConfig() + .MockJwks() + .MockToken(() => GenerateToken(1, $"https://{customDomain}/", clientId, nonce, "1"), (me) => me.HasAuth0ClientHeader()) + .Build(); + + using (var server = TestServerBuilder.CreateServer(opt => + { + opt.Backchannel = new HttpClient(mockHandler.Object); + // Server is configured with default domain + })) + { + using (var client = server.CreateClient()) + { + var loginResponse = (await client.SendAsync($"{TestServerBuilder.Host}/{TestServerBuilder.Login}")); + var setCookie = Assert.Single(loginResponse.Headers, h => h.Key == "Set-Cookie"); + + var queryParameters = UriUtils.GetQueryParams(loginResponse.Headers.Location); + + nonce = queryParameters["nonce"]; + var state = queryParameters["state"]; + + var message = new HttpRequestMessage(HttpMethod.Get, $"{TestServerBuilder.Host}/{TestServerBuilder.Callback}?state={state}&nonce={nonce}&code=123"); + + Func act = async () => + { + await client.SendAsync(message, setCookie.Value); + }; + + var innerException = act + .Should() + .ThrowAsync() + .Result + .And.InnerException; + + innerException + .Should() + .BeOfType(); + } + } + } + }