Skip to content

Commit fcd6c77

Browse files
authored
server: Rewrite the WebSocket procol handling (#1752)
1 parent b04b286 commit fcd6c77

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+1560
-825
lines changed

samples/GraphQL.Samples.SG.Subscription/GraphQL.Samples.SG.Subscription.csproj

+4-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
<ProjectReference Include="..\..\src\GraphQL\GraphQL.csproj" />
1313
</ItemGroup>
1414

15-
<ItemGroup>
16-
<ProjectReference Include="..\..\src\GraphQL.Server.SourceGenerators\GraphQL.Server.SourceGenerators.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
17-
</ItemGroup>
15+
<ItemGroup>
16+
<ProjectReference Include="..\..\src\GraphQL.Server.SourceGenerators\GraphQL.Server.SourceGenerators.csproj"
17+
OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
18+
</ItemGroup>
1819

1920
</Project>

src/GraphQL.Server.SourceGenerators/GraphQL.Server.SourceGenerators.csproj

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818
</PropertyGroup>
1919

2020
<ItemGroup>
21-
<PackageReference Include="Polyfill" Version="2.6.5">
21+
<PackageReference Include="Polyfill" Version="3.0.0">
2222
<PrivateAssets>all</PrivateAssets>
2323
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
2424
</PackageReference>
2525
</ItemGroup>
2626

2727
<ItemGroup>
28-
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.8.0" PrivateAssets="all" />
28+
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.9.2" PrivateAssets="all" />
2929
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.4" PrivateAssets="all" />
3030
<PackageReference Include="System.Text.Json" Version="8.0.2" PrivateAssets="all" GeneratePathProperty="true" />
3131
<PackageReference Include="Scriban" Version="5.9.1" IncludeAssets="Build" />

src/GraphQL.Server.SourceGenerators/TypeHelper.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using System;
2-
using System.Collections.Generic;
1+
using System.Collections.Generic;
32

43
using Microsoft.CodeAnalysis;
54
using System.Linq;

src/GraphQL.Server/GraphQL.Server.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
<ItemGroup>
1818
<FrameworkReference Include="Microsoft.AspNetCore.App" />
1919
<PackageReference Include="Microsoft.Extensions.Options" Version="8.0.2" />
20+
<PackageReference Include="Microsoft.Extensions.Telemetry" Version="8.2.0" />
2021
<PackageReference Include="System.IO.Pipelines" Version="8.0.0" />
2122
<PackageReference Include="System.Net.WebSockets" Version="4.3.0" />
2223
</ItemGroup>

src/GraphQL.Server/GraphQLWSTransport.cs

+19-8
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public class GraphQLWSTransport : IGraphQLTransport
1919
/// Due to historical reasons this actually is the protocol name used
2020
/// by the newer protocol.
2121
/// </summary>
22-
public static string SubProtocol = "graphql-transport-ws";
22+
public const string GraphQLTransportWSProtocol = "graphql-transport-ws";
2323

2424
public IEndpointConventionBuilder Map(string pattern, IEndpointRouteBuilder routes,
2525
GraphQLRequestDelegate requestDelegate)
@@ -36,10 +36,12 @@ private async Task HandleProtocol(
3636
WebSocket webSocket,
3737
GraphQLRequestDelegate requestPipeline)
3838
{
39-
var connection = new GraphQLWSConnection(webSocket, requestPipeline, httpContext);
40-
await connection.Connect(httpContext.RequestAborted);
41-
}
39+
var handler = new WebSocketTransportHandler(
40+
requestPipeline,
41+
httpContext);
4242

43+
await handler.Handle(webSocket);
44+
}
4345

4446
private RequestDelegate ProcessRequest(GraphQLRequestDelegate pipeline)
4547
{
@@ -56,19 +58,28 @@ await httpContext.Response.WriteAsJsonAsync(new ProblemDetails
5658
return;
5759
}
5860

59-
if (httpContext.WebSockets.WebSocketRequestedProtocols?.Contains(SubProtocol) == false)
61+
if (httpContext.WebSockets.WebSocketRequestedProtocols?.Contains(EchoProtocol.Protocol) == true)
62+
{
63+
using WebSocket echoWebSocket = await httpContext.WebSockets
64+
.AcceptWebSocketAsync(EchoProtocol.Protocol);
65+
66+
await EchoProtocol.Run(echoWebSocket);
67+
return;
68+
}
69+
70+
if (httpContext.WebSockets.WebSocketRequestedProtocols?.Contains(GraphQLTransportWSProtocol) == false)
6071
{
6172
httpContext.Response.StatusCode = StatusCodes.Status400BadRequest;
6273
await httpContext.Response.WriteAsJsonAsync(new ProblemDetails
6374
{
64-
Detail = $"Request does not contain sub-protocol '{SubProtocol}'."
75+
Detail = $"Request does not contain sub-protocol '{GraphQLTransportWSProtocol}'."
6576
});
6677

6778
return;
6879
}
6980

70-
WebSocket webSocket = await httpContext.WebSockets
71-
.AcceptWebSocketAsync(SubProtocol);
81+
using WebSocket webSocket = await httpContext.WebSockets
82+
.AcceptWebSocketAsync(GraphQLTransportWSProtocol);
7283

7384
await HandleProtocol(httpContext, webSocket, pipeline);
7485
};

src/GraphQL.Server/WebSockets/ClientMethods.cs

-28
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using System.Net.WebSockets;
2+
using System.Text.Json;
3+
4+
namespace Tanka.GraphQL.Server.WebSockets;
5+
6+
public static class EchoProtocol
7+
{
8+
public const string Protocol = "echo-ws";
9+
10+
public static async Task Run(WebSocket webSocket)
11+
{
12+
var channel = new WebSocketChannel(webSocket, new JsonSerializerOptions(JsonSerializerDefaults.Web));
13+
var echo = Echo(channel);
14+
15+
await Task.WhenAll(channel.Run(), echo);
16+
}
17+
18+
private static async Task Echo(WebSocketChannel channel)
19+
{
20+
while (await channel.Reader.WaitToReadAsync())
21+
{
22+
if (channel.Reader.TryRead(out var message))
23+
await channel.Writer.WriteAsync(message);
24+
}
25+
26+
channel.Complete();
27+
}
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using Microsoft.Extensions.Logging;
2+
3+
using Tanka.GraphQL.Server.WebSockets.Results;
4+
5+
namespace Tanka.GraphQL.Server.WebSockets;
6+
7+
public class GraphQLTransportWSProtocol(
8+
SubscriptionManager subscriptions,
9+
ILoggerFactory loggerFactory)
10+
{
11+
public bool ConnectionInitReceived = false;
12+
13+
public IMessageResult Accept(MessageBase message)
14+
{
15+
if (!ConnectionInitReceived)
16+
return new ConnectionAckResult(
17+
this,
18+
loggerFactory.CreateLogger<ConnectionAckResult>()
19+
);
20+
21+
return message.Type switch
22+
{
23+
MessageTypes.ConnectionInit => new WebSocketCloseResult(
24+
CloseCode.TooManyInitialisationRequests,
25+
loggerFactory.CreateLogger<WebSocketCloseResult>()),
26+
MessageTypes.Ping => new PongResult(loggerFactory.CreateLogger<PongResult>()),
27+
MessageTypes.Subscribe => new Results.SubscribeResult(
28+
subscriptions,
29+
loggerFactory.CreateLogger<Results.SubscribeResult>()),
30+
MessageTypes.Complete => new Results.CompleteSubscriptionResult(
31+
subscriptions,
32+
loggerFactory.CreateLogger<Results.CompleteSubscriptionResult>()),
33+
_ => new UnknownMessageResult(loggerFactory.CreateLogger<UnknownMessageResult>())
34+
};
35+
}
36+
}

src/GraphQL.Server/WebSockets/GraphQLWSConnection.cs

-148
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
namespace Tanka.GraphQL.Server.WebSockets;
2+
3+
public interface IMessageContext
4+
{
5+
Task Write<T>(T message) where T: MessageBase;
6+
7+
Task Close(Exception? error = default);
8+
9+
MessageBase Message { get; }
10+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
namespace Tanka.GraphQL.Server.WebSockets;
2+
3+
public interface IMessageResult
4+
{
5+
Task Execute(IMessageContext context);
6+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
namespace Tanka.GraphQL.Server.WebSockets;
2+
3+
public class MessageContext(
4+
WebSocketChannel channel,
5+
MessageBase contextMessage,
6+
GraphQLRequestDelegate requestPipeline) : IMessageContext
7+
{
8+
public async Task Write<T>(T message) where T: MessageBase
9+
{
10+
await channel.Writer.WriteAsync(message);
11+
}
12+
13+
public Task Close(Exception? error = default)
14+
{
15+
channel.Complete(error);
16+
return Task.CompletedTask;
17+
}
18+
19+
public MessageBase Message => contextMessage;
20+
21+
public GraphQLRequestDelegate RequestPipeline => requestPipeline;
22+
23+
}

0 commit comments

Comments
 (0)