Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ public interface IOperationManager
/// </returns>
bool Enqueue(string sessionId, GraphQLRequest request);

/// <summary>
/// Enqueues a batch of requests for execution with the operation manager.
/// </summary>
/// <param name="sessionId">
/// The operation sessionId given by the client. The sessionId must be unique within the
/// <see cref="ISocketSession"/>.
/// </param>
/// <param name="requests">
/// The GraphQL requests that shall be executed as a batch.
/// </param>
/// <returns>
/// Returns <c>true</c> if the <paramref name="requests"/>
/// were accepted and registered for execution.
/// </returns>
bool EnqueueBatch(string sessionId, GraphQLRequest[] requests);

/// <summary>
/// Completes a request that was previously enqueued with the operation manager.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,11 @@ public interface IOperationSession : IDisposable
/// <param name="request">The graphql request.</param>
/// <param name="cancellationToken">The cancellation token.</param>
void BeginExecute(GraphQLRequest request, CancellationToken cancellationToken);

/// <summary>
/// Starts executing a batch of operations.
/// </summary>
/// <param name="requests">The graphql requests to execute as a batch.</param>
/// <param name="cancellationToken">The cancellation token.</param>
void BeginExecuteBatch(GraphQLRequest[] requests, CancellationToken cancellationToken);
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,39 @@ public bool Enqueue(string sessionId, GraphQLRequest request)
return false;
}

/// <inheritdoc />
public bool EnqueueBatch(string sessionId, GraphQLRequest[] requests)
{
ArgumentException.ThrowIfNullOrEmpty(sessionId);
ArgumentNullException.ThrowIfNull(requests);
ObjectDisposedException.ThrowIf(_disposed, this);

IOperationSession? session = null;
_lock.EnterWriteLock();

try
{
if (!_subs.ContainsKey(sessionId))
{
session = _createSession(sessionId);
_subs.Add(sessionId, session);
}
}
finally
{
_lock.ExitWriteLock();
}

if (session is not null)
{
session.Completed += (_, _) => Complete(sessionId);
session.BeginExecuteBatch(requests, _cancellationToken);
return true;
}

return false;
}

/// <inheritdoc />
public bool Complete(string sessionId)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ public OperationSession(
public void BeginExecute(GraphQLRequest request, CancellationToken cancellationToken)
=> _ = SendResultsAsync(request, cancellationToken);

public void BeginExecuteBatch(GraphQLRequest[] requests, CancellationToken cancellationToken)
=> _ = SendBatchResultsAsync(requests, cancellationToken);

private async Task SendResultsAsync(GraphQLRequest request, CancellationToken cancellationToken)
{
using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _ct);
Expand Down Expand Up @@ -129,6 +132,78 @@ await _session.Protocol.SendErrorMessageAsync(
}
}

private async Task SendBatchResultsAsync(GraphQLRequest[] requests, CancellationToken cancellationToken)
{
using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _ct);
var ct = cts.Token;
var completeTry = false;

try
{
var operationRequests = new IOperationRequest[requests.Length];

for (var i = 0; i < requests.Length; i++)
{
var requestBuilder = CreateRequestBuilder(requests[i]);
await _interceptor.OnRequestAsync(_session, Id, requestBuilder, ct);
operationRequests[i] = requestBuilder.Build();
}

var batch = new OperationRequestBatch(operationRequests);
await using var responseStream = await _executorSession.ExecuteBatchAsync(batch, ct);

await foreach (var item in responseStream.ReadResultsAsync().WithCancellation(ct))
{
try
{
// use the original cancellation token here to keep the websocket open for other streams.
await SendResultMessageAsync(item, cancellationToken);
}
finally
{
await item.DisposeAsync();
}
}

completeTry = true;

if (!ct.IsCancellationRequested)
{
await _session.Protocol.SendCompleteMessageAsync(_session, Id, ct);
}
}
catch (OperationCanceledException) when (ct.IsCancellationRequested)
{
// the operation was canceled so we do nothing
}
catch (Exception ex)
{
if (!completeTry)
{
await TrySendErrorMessageAsync(ex, ct);
}
}
finally
{
try
{
await _interceptor.OnCompleteAsync(_session, Id, cancellationToken);
}
catch
{
// we will just ignore any user exceptions here so we can graciously close
// the subscription out.
}

Complete();

foreach (var request in requests)
{
request.Dispose();
}
}
}

private static OperationRequestBuilder CreateRequestBuilder(GraphQLRequest request)
{
var requestBuilder = new OperationRequestBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using HotChocolate.AspNetCore.Formatters;
using HotChocolate.Buffers;
using HotChocolate.Language;
using HotChocolate.AspNetCore.Utilities;
using HotChocolate.Text.Json;
using static HotChocolate.AspNetCore.Subscriptions.Protocols.GraphQLOverWebSocket.MessageProperties;
using static HotChocolate.AspNetCore.Subscriptions.Protocols.MessageUtilities;
Expand All @@ -17,7 +18,8 @@ internal sealed class GraphQLOverWebSocketProtocolHandler(
ISocketSessionInterceptor interceptor,
IWebSocketPayloadFormatter formatter,
IDocumentCache documentCache,
IDocumentHashProvider documentHashProvider)
IDocumentHashProvider documentHashProvider,
ParserOptions parserOptions)
: IGraphQLOverWebSocketProtocolHandler
{
public string Name => GraphQL_Transport_WS;
Expand Down Expand Up @@ -134,13 +136,36 @@ await SendConnectionAcceptMessage(
{
try
{
if (!TryParseSubscribeMessage(root, out var subscribeMessage))
if (!TryParseSubscribeMessage(root, out var subscribeId, out var requests))
{
await connection.CloseInvalidSubscribeMessageAsync(cancellationToken);
return;
}

if (!session.Operations.Enqueue(subscribeMessage.Id, subscribeMessage.Payload))
bool success;

if (requests.Length == 1)
{
success = session.Operations.Enqueue(subscribeId, requests[0]);
}
else
{
var options = session.Connection.Features.Get<GraphQLServerOptions>();

if (options?.Batching.HasFlag(AllowedBatching.RequestBatching) == false)
{
throw new GraphQLRequestException(ErrorHelper.InvalidRequest());
}

if (options?.MaxBatchSize > 0 && requests.Length > options.MaxBatchSize)
{
throw new GraphQLRequestException(ErrorHelper.BatchSizeExceeded(options.MaxBatchSize));
}

success = session.Operations.EnqueueBatch(subscribeId, requests);
}

if (!success)
{
await connection.CloseSubscriptionIdNotUniqueAsync(cancellationToken);
}
Expand Down Expand Up @@ -300,37 +325,41 @@ public ValueTask OnConnectionInitTimeoutAsync(

private bool TryParseSubscribeMessage(
JsonElement messageElement,
[NotNullWhen(true)] out SubscribeMessage? message)
[NotNullWhen(true)] out string? id,
[NotNullWhen(true)] out GraphQLRequest[]? requests)
{
if (!messageElement.TryGetProperty(Id, out var idProp)
|| idProp.ValueKind is not JsonValueKind.String
|| string.IsNullOrEmpty(idProp.GetString()))
{
message = null;
id = null;
requests = null;
return false;
}

if (!messageElement.TryGetProperty(Payload, out var payloadProp)
|| payloadProp.ValueKind is not JsonValueKind.Object)
|| payloadProp.ValueKind is not (JsonValueKind.Object or JsonValueKind.Array))
{
message = null;
id = null;
requests = null;
return false;
}

var id = idProp.GetString()!;
id = idProp.GetString()!;
var requestData = JsonMarshal.GetRawUtf8Value(payloadProp);
var request = Parse(
requests = Parse(
requestData,
cache: documentCache,
hashProvider: documentHashProvider);
parserOptions,
documentCache,
documentHashProvider);

if (request.Length == 0)
if (requests.Length == 0)
{
message = null;
id = null;
requests = null;
return false;
}

message = new SubscribeMessage(id, request[0]);
return true;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ public void Dispose()
public static async Task AcceptAsync(
HttpContext context,
ExecutorSession executorSession,
GraphQLSocketOptions socketOptions)
GraphQLServerOptions serverOptions)
{
using var connection = new WebSocketConnection(context, executorSession);
connection.Features.Set(serverOptions);
using var cts = CancellationTokenSource.CreateLinkedTokenSource(
context.RequestAborted,
connection.ApplicationStopping);
Expand All @@ -59,7 +60,7 @@ public static async Task AcceptAsync(
if (protocol is not null)
{
using var session = new WebSocketSession(connection, protocol, executorSession);
var options = socketOptions;
var options = serverOptions.Sockets;

try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ private async Task HandleWebSocketSessionAsync(HttpContext context)
{
try
{
await WebSocketSession.AcceptAsync(context, session, options.Sockets);
await WebSocketSession.AcceptAsync(context, session, options);
}
catch (Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ private static IRequestExecutorBuilder AddGraphQLOverWebSocketProtocol(
sp.GetRequiredService<ISocketSessionInterceptor>(),
sp.GetRequiredService<IWebSocketPayloadFormatter>(),
sp.GetRequiredService<IDocumentCache>(),
sp.GetRequiredService<IDocumentHashProvider>())));
sp.GetRequiredService<IDocumentHashProvider>(),
sp.GetRequiredService<ParserOptions>())));

/// <summary>
/// Adds a custom WebSocket payload formatter to the DI.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Collections.Immutable;
using System.Text.Json;
using HotChocolate.Transport.Serialization;

Expand All @@ -6,19 +7,36 @@ namespace HotChocolate.Transport;
/// <summary>
/// Represents a GraphQL batch request that can be sent over a WebSocket or HTTP connection.
/// </summary>
/// <param name="requests">
/// A list of operation requests to execute.
/// </param>
public readonly struct OperationBatchRequest(
IReadOnlyList<IOperationRequest> requests)
public readonly struct OperationBatchRequest
: IRequestBody
, IEquatable<OperationBatchRequest>
{
/// <summary>
/// Gets the list of operation requests to execute.
/// </summary>
public IReadOnlyList<IOperationRequest> Requests { get; } =
requests ?? throw new ArgumentNullException(nameof(requests));
public ImmutableArray<IOperationRequest> Requests { get; }

/// <summary>
/// Initializes a new instance of <see cref="OperationBatchRequest"/> with the specified
/// immutable array of operation requests.
/// </summary>
/// <param name="requests">
/// The requests of this batch.
/// </param>
/// <exception cref="ArgumentException">
/// Thrown if <paramref name="requests"/> is default or empty.
/// </exception>
public OperationBatchRequest(ImmutableArray<IOperationRequest> requests)
{
if (requests.IsDefaultOrEmpty)
{
throw new ArgumentException(
"The batch request must contain at least one operation.",
nameof(requests));
}

Requests = requests;
}

/// <summary>
/// Writes the request to the specified <paramref name="writer"/>.
Expand Down Expand Up @@ -50,12 +68,12 @@ public void WriteTo(Utf8JsonWriter writer)
/// </returns>
public bool Equals(OperationBatchRequest other)
{
if (Requests.Count != other.Requests.Count)
if (Requests.Length != other.Requests.Length)
{
return false;
}

for (var i = 0; i < Requests.Count; i++)
for (var i = 0; i < Requests.Length; i++)
{
if (!Requests[i].Equals(other.Requests[i]))
{
Expand Down
Loading
Loading