diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/Subscriptions/Protocols/GraphQLOverWebSocket/GraphQLOverWebSocketProtocolHandler.cs b/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/Subscriptions/Protocols/GraphQLOverWebSocket/GraphQLOverWebSocketProtocolHandler.cs index 798ff20b98c..16f5ab98f2c 100644 --- a/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/Subscriptions/Protocols/GraphQLOverWebSocket/GraphQLOverWebSocketProtocolHandler.cs +++ b/src/HotChocolate/AspNetCore/src/AspNetCore.Pipeline/Subscriptions/Protocols/GraphQLOverWebSocket/GraphQLOverWebSocketProtocolHandler.cs @@ -15,7 +15,9 @@ namespace HotChocolate.AspNetCore.Subscriptions.Protocols.GraphQLOverWebSocket; internal sealed class GraphQLOverWebSocketProtocolHandler( ISocketSessionInterceptor interceptor, - IWebSocketPayloadFormatter formatter) + IWebSocketPayloadFormatter formatter, + IDocumentCache documentCache, + IDocumentHashProvider documentHashProvider) : IGraphQLOverWebSocketProtocolHandler { public string Name => GraphQL_Transport_WS; @@ -296,7 +298,7 @@ public ValueTask OnConnectionInitTimeoutAsync( CancellationToken cancellationToken) => session.Connection.CloseConnectionInitTimeoutAsync(cancellationToken); - private static bool TryParseSubscribeMessage( + private bool TryParseSubscribeMessage( JsonElement messageElement, [NotNullWhen(true)] out SubscribeMessage? message) { @@ -317,7 +319,10 @@ private static bool TryParseSubscribeMessage( var id = idProp.GetString()!; var requestData = JsonMarshal.GetRawUtf8Value(payloadProp); - var request = Parse(requestData); + var request = Parse( + requestData, + cache: documentCache, + hashProvider: documentHashProvider); if (request.Length == 0) { diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore/Extensions/HotChocolateAspNetCoreServiceCollectionExtensions.Subscriptions.cs b/src/HotChocolate/AspNetCore/src/AspNetCore/Extensions/HotChocolateAspNetCoreServiceCollectionExtensions.Subscriptions.cs index 4e25af4f8c7..db0faa580b8 100644 --- a/src/HotChocolate/AspNetCore/src/AspNetCore/Extensions/HotChocolateAspNetCoreServiceCollectionExtensions.Subscriptions.cs +++ b/src/HotChocolate/AspNetCore/src/AspNetCore/Extensions/HotChocolateAspNetCoreServiceCollectionExtensions.Subscriptions.cs @@ -5,6 +5,7 @@ using HotChocolate.AspNetCore.Subscriptions.Protocols.Apollo; using HotChocolate.AspNetCore.Subscriptions.Protocols.GraphQLOverWebSocket; using HotChocolate.Execution.Configuration; +using HotChocolate.Language; // ReSharper disable once CheckNamespace namespace Microsoft.Extensions.DependencyInjection; @@ -91,7 +92,9 @@ private static IRequestExecutorBuilder AddGraphQLOverWebSocketProtocol( s => s.AddSingleton( sp => new GraphQLOverWebSocketProtocolHandler( sp.GetRequiredService(), - sp.GetRequiredService()))); + sp.GetRequiredService(), + sp.GetRequiredService(), + sp.GetRequiredService()))); /// /// Adds a custom WebSocket payload formatter to the DI. diff --git a/src/HotChocolate/AspNetCore/test/AspNetCore.Tests/Subscriptions/GraphQLOverWebSocket/WebSocketProtocolTests.cs b/src/HotChocolate/AspNetCore/test/AspNetCore.Tests/Subscriptions/GraphQLOverWebSocket/WebSocketProtocolTests.cs index fd89d3a7107..655d8d1f026 100644 --- a/src/HotChocolate/AspNetCore/test/AspNetCore.Tests/Subscriptions/GraphQLOverWebSocket/WebSocketProtocolTests.cs +++ b/src/HotChocolate/AspNetCore/test/AspNetCore.Tests/Subscriptions/GraphQLOverWebSocket/WebSocketProtocolTests.cs @@ -1,4 +1,5 @@ using System.Diagnostics; +using System.Text; using System.Text.Json; using System.Text.Json.Serialization; using HotChocolate.AspNetCore.Formatters; @@ -6,6 +7,9 @@ using HotChocolate.AspNetCore.Subscriptions.Protocols.GraphQLOverWebSocket; using HotChocolate.AspNetCore.Tests.Utilities; using HotChocolate.AspNetCore.Tests.Utilities.Subscriptions.GraphQLOverWebSocket; +using HotChocolate.Execution; +using HotChocolate.Language; +using HotChocolate.PersistedOperations; using HotChocolate.Subscriptions.Diagnostics; using HotChocolate.Text.Json; using HotChocolate.Transport.Formatters; @@ -284,6 +288,69 @@ await testServer.SendPostRequestAsync( }); } + [Fact] + public Task Subscribe_With_PersistedQuery_Extension_Only_Works() + => TryTest( + async ct => + { + // arrange + var storage = new OperationStorage(); + var hashProvider = new MD5DocumentHashProvider(HashFormat.Base64); + const string query = "subscription { onReview(episode: NEW_HOPE) { stars } }"; + var hash = hashProvider.ComputeHash(Encoding.UTF8.GetBytes(query)).Value; + storage.AddOperation(hash, query); + + using var testServer = CreateStarWarsServer( + configureServices: services => services + .AddGraphQLServer() + .AddMD5DocumentHashProvider(HashFormat.Base64) + .ConfigureSchemaServices(c => c.AddSingleton(storage)), + output: output); + var client = CreateWebSocketClient(testServer); + using var webSocket = await ConnectToServerAsync(client, ct); + + var subscribeMessage = JsonSerializer.Serialize( + new + { + type = "subscribe", + id = "abc", + payload = new + { + extensions = new Dictionary + { + ["persistedQuery"] = new Dictionary + { + ["version"] = 1, + [hashProvider.Name] = hash + } + } + } + }); + + // act + await webSocket.SendMessageAsync(subscribeMessage, ct); + + await testServer.SendPostRequestAsync( + new ClientQueryRequest + { + Query = + """ + mutation { + createReview(episode: NEW_HOPE review: { + commentary: "foo" + stars: 5 + }) { + stars + } + } + """ + }); + + // assert + var message = await WaitForMessage(webSocket, Messages.Next, ct); + Assert.NotNull(message); + }); + [Fact] public Task Subscribe_Id_Not_Unique() { @@ -956,6 +1023,31 @@ await testServer.SendPostRequestAsync( Assert.False(messageOnReview.TryGetProperty("commentary", out _)); }); + private sealed class OperationStorage : IOperationDocumentStorage + { + private readonly Dictionary _cache = + new(StringComparer.Ordinal); + + public ValueTask TryReadAsync( + OperationDocumentId documentId, + CancellationToken cancellationToken = default) + => _cache.TryGetValue(documentId.Value, out var value) + ? new ValueTask(value) + : new ValueTask(default(IOperationDocument)); + + public ValueTask SaveAsync( + OperationDocumentId documentId, + IOperationDocument document, + CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + + public void AddOperation(string key, string sourceText) + { + var doc = new OperationDocument(Utf8GraphQLParser.Parse(sourceText)); + _cache.Add(key, doc); + } + } + private class AuthInterceptor : DefaultSocketSessionInterceptor { public override ValueTask OnConnectAsync( diff --git a/src/HotChocolate/Fusion-vnext/src/Fusion.AspNetCore/DependencyInjection/FusionServerServiceCollectionExtensions.cs b/src/HotChocolate/Fusion-vnext/src/Fusion.AspNetCore/DependencyInjection/FusionServerServiceCollectionExtensions.cs index d7dc7fe86e8..c71c7429bde 100644 --- a/src/HotChocolate/Fusion-vnext/src/Fusion.AspNetCore/DependencyInjection/FusionServerServiceCollectionExtensions.cs +++ b/src/HotChocolate/Fusion-vnext/src/Fusion.AspNetCore/DependencyInjection/FusionServerServiceCollectionExtensions.cs @@ -105,5 +105,7 @@ private static IFusionGatewayBuilder AddGraphQLOverWebSocketProtocol( (_, s) => s.AddSingleton( sp => new GraphQLOverWebSocketProtocolHandler( sp.GetRequiredService(), - sp.GetRequiredService()))); + sp.GetRequiredService(), + sp.GetRequiredService(), + sp.GetRequiredService()))); }