diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/Extensions/EndpointRouteBuilderExtensions.cs b/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/Extensions/EndpointRouteBuilderExtensions.cs index f4d876322de..36c9138d52e 100644 --- a/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/Extensions/EndpointRouteBuilderExtensions.cs +++ b/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/Extensions/EndpointRouteBuilderExtensions.cs @@ -84,8 +84,9 @@ public static GraphQLEndpointConventionBuilder MapGraphQL( var schemaNameOrDefault = schemaName ?? ISchemaDefinition.DefaultName; var pattern = Parse(path + "/{**slug}"); var requestPipeline = endpointRouteBuilder.CreateApplicationBuilder(); + var services = endpointRouteBuilder.ServiceProvider; requestPipeline.MapGraphQL(path, schemaNameOrDefault); - var serverOptions = endpointRouteBuilder.ServiceProvider.GetRequiredService>().Get(schemaNameOrDefault); + var serverOptions = services.GetServerOptions(schemaNameOrDefault); return new GraphQLEndpointConventionBuilder( endpointRouteBuilder @@ -123,19 +124,22 @@ public static IApplicationBuilder MapGraphQL( path = path.ToString().TrimEnd('/'); - var executorProvider = applicationBuilder.ApplicationServices.GetRequiredService(); - var executorEvents = applicationBuilder.ApplicationServices.GetRequiredService(); - var formOptions = applicationBuilder.ApplicationServices.GetRequiredService>(); + var services = applicationBuilder.ApplicationServices; + var executorProvider = services.GetRequiredService(); + var executorEvents = services.GetRequiredService(); + var formOptions = services.GetRequiredService>(); var executor = new HttpRequestExecutorProxy(executorProvider, executorEvents, schemaName); - var serverOptions = applicationBuilder.ApplicationServices.GetRequiredService>().Get(schemaName); + var serverOptions = services.GetServerOptions(schemaName); applicationBuilder .Use(MiddlewareFactory.CreateCancellationMiddleware()) + .Use(MiddlewareFactory.CreateConcurrencyGateMiddleware(serverOptions.MaxConcurrentRequests)) .Use(MiddlewareFactory.CreateWebSocketSubscriptionMiddleware(executor, serverOptions)) .Use(MiddlewareFactory.CreateHttpPostMiddleware(executor, serverOptions)) .Use(MiddlewareFactory.CreateHttpMultipartMiddleware(executor, serverOptions, formOptions)) .Use(MiddlewareFactory.CreateHttpGetMiddleware(executor, serverOptions)) - .Use(MiddlewareFactory.CreateHttpGetSchemaMiddleware(executor, serverOptions, path, MiddlewareRoutingType.Integrated)) + .Use(MiddlewareFactory.CreateHttpGetSchemaMiddleware( + executor, serverOptions, path, MiddlewareRoutingType.Integrated)) .UseNitroApp(path, serverOptions.Tool) .Use(_ => context => { @@ -203,14 +207,16 @@ public static GraphQLHttpEndpointConventionBuilder MapGraphQLHttp( var requestPipeline = endpointRouteBuilder.CreateApplicationBuilder(); var schemaNameOrDefault = schemaName ?? ISchemaDefinition.DefaultName; - var executorProvider = endpointRouteBuilder.ServiceProvider.GetRequiredService(); - var executorEvents = endpointRouteBuilder.ServiceProvider.GetRequiredService(); - var formOptions = endpointRouteBuilder.ServiceProvider.GetRequiredService>(); + var services = endpointRouteBuilder.ServiceProvider; + var executorProvider = services.GetRequiredService(); + var executorEvents = services.GetRequiredService(); + var formOptions = services.GetRequiredService>(); var executor = new HttpRequestExecutorProxy(executorProvider, executorEvents, schemaNameOrDefault); - var serverOptions = endpointRouteBuilder.ServiceProvider.GetRequiredService>().Get(schemaNameOrDefault); + var serverOptions = services.GetServerOptions(schemaNameOrDefault); requestPipeline .Use(MiddlewareFactory.CreateCancellationMiddleware()) + .Use(MiddlewareFactory.CreateConcurrencyGateMiddleware(serverOptions.MaxConcurrentRequests)) .Use(MiddlewareFactory.CreateHttpPostMiddleware(executor, serverOptions)) .Use(MiddlewareFactory.CreateHttpMultipartMiddleware(executor, serverOptions, formOptions)) .Use(MiddlewareFactory.CreateHttpGetMiddleware(executor, serverOptions)) @@ -283,10 +289,11 @@ public static GraphQLWebSocketEndpointConventionBuilder MapGraphQLWebSocket( var requestPipeline = endpointRouteBuilder.CreateApplicationBuilder(); var schemaNameOrDefault = schemaName ?? ISchemaDefinition.DefaultName; - var executorProvider = endpointRouteBuilder.ServiceProvider.GetRequiredService(); - var executorEvents = endpointRouteBuilder.ServiceProvider.GetRequiredService(); + var services = endpointRouteBuilder.ServiceProvider; + var executorProvider = services.GetRequiredService(); + var executorEvents = services.GetRequiredService(); var executor = new HttpRequestExecutorProxy(executorProvider, executorEvents, schemaNameOrDefault); - var serverOptions = endpointRouteBuilder.ServiceProvider.GetRequiredService>().Get(schemaNameOrDefault); + var serverOptions = services.GetServerOptions(schemaNameOrDefault); requestPipeline .Use(MiddlewareFactory.CreateCancellationMiddleware()) @@ -362,14 +369,17 @@ public static IEndpointConventionBuilder MapGraphQLSchema( var requestPipeline = endpointRouteBuilder.CreateApplicationBuilder(); var schemaNameOrDefault = schemaName ?? ISchemaDefinition.DefaultName; - var executorProvider = endpointRouteBuilder.ServiceProvider.GetRequiredService(); - var executorEvents = endpointRouteBuilder.ServiceProvider.GetRequiredService(); + var services = endpointRouteBuilder.ServiceProvider; + var executorProvider = services.GetRequiredService(); + var executorEvents = services.GetRequiredService(); var executor = new HttpRequestExecutorProxy(executorProvider, executorEvents, schemaNameOrDefault); - var serverOptions = endpointRouteBuilder.ServiceProvider.GetRequiredService>().Get(schemaNameOrDefault); + var serverOptions = services.GetServerOptions(schemaNameOrDefault); requestPipeline .Use(MiddlewareFactory.CreateCancellationMiddleware()) - .Use(MiddlewareFactory.CreateHttpGetSchemaMiddleware(executor, serverOptions, PathString.Empty, MiddlewareRoutingType.Explicit)) + .Use(MiddlewareFactory.CreateConcurrencyGateMiddleware(serverOptions.MaxConcurrentRequests)) + .Use(MiddlewareFactory.CreateHttpGetSchemaMiddleware( + executor, serverOptions, PathString.Empty, MiddlewareRoutingType.Explicit)) .Use(_ => context => { context.Response.StatusCode = 404; @@ -647,6 +657,9 @@ public static NitroAppEndpointConventionBuilder WithOptions( return builder; } + private static GraphQLServerOptions GetServerOptions(this IServiceProvider services, string schemaName) + => services.GetRequiredService>().Get(schemaName); + private static void TryResolveSchemaName(IServiceProvider services, ref string? schemaName) { if (schemaName is null diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/MiddlewareFactory.cs b/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/MiddlewareFactory.cs index 7c23d6a3785..62a1cb6bf59 100644 --- a/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/MiddlewareFactory.cs +++ b/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/MiddlewareFactory.cs @@ -29,6 +29,37 @@ internal static Func CreateCancellationMiddlew }; } + internal static Func CreateConcurrencyGateMiddleware( + int? maxConcurrentRequests) + { + if (maxConcurrentRequests is null or <= 0) + { + return next => next; + } + + var semaphore = new SemaphoreSlim(maxConcurrentRequests.Value, maxConcurrentRequests.Value); + + return next => async context => + { + if (context.WebSockets.IsWebSocketRequest) + { + await next(context); + return; + } + + await semaphore.WaitAsync(context.RequestAborted); + + try + { + await next(context); + } + finally + { + semaphore.Release(); + } + }; + } + internal static Func CreateWebSocketSubscriptionMiddleware( HttpRequestExecutorProxy executor, GraphQLServerOptions serverOptions) diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/Options/GraphQLServerOptions.cs b/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/Options/GraphQLServerOptions.cs index 80b7b65b5a1..48754d152fe 100644 --- a/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/Options/GraphQLServerOptions.cs +++ b/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/Options/GraphQLServerOptions.cs @@ -64,6 +64,12 @@ public sealed class GraphQLServerOptions /// public int MaxBatchSize { get; set; } = 1024; + /// + /// Gets or sets the maximum number of concurrent GraphQL requests that can be + /// processed simultaneously. A value of null means unlimited. Defaults to 64. + /// + public int? MaxConcurrentRequests { get; set; } = 64; + internal GraphQLServerOptions Clone() => new() { @@ -81,6 +87,7 @@ internal GraphQLServerOptions Clone() EnforceMultipartRequestsPreflightHeader = EnforceMultipartRequestsPreflightHeader, EnableSchemaRequests = EnableSchemaRequests, Batching = Batching, - MaxBatchSize = MaxBatchSize + MaxBatchSize = MaxBatchSize, + MaxConcurrentRequests = MaxConcurrentRequests }; }