diff --git a/NATS.Client.sln b/NATS.Client.sln
index bbe8f8e36..1170e089a 100644
--- a/NATS.Client.sln
+++ b/NATS.Client.sln
@@ -49,6 +49,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Example.Core.PublishModel",
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "NATS.Client.Core.MemoryTests", "tests\NATS.Client.Core.MemoryTests\NATS.Client.Core.MemoryTests.csproj", "{B26DE6AC-A4D5-4427-8453-EE3514E4B513}"
EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Example.Core.PublishHeaders", "sandbox\Example.Core.PublishHeaders\Example.Core.PublishHeaders.csproj", "{B0C82F24-BDEC-4420-A02A-F74E2423D755}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Example.Core.SubscribeHeaders", "sandbox\Example.Core.SubscribeHeaders\Example.Core.SubscribeHeaders.csproj", "{A96660DB-DAEB-4C57-8096-F236AC4FA927}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -115,6 +119,14 @@ Global
{B26DE6AC-A4D5-4427-8453-EE3514E4B513}.Debug|Any CPU.Build.0 = Debug|Any CPU
{B26DE6AC-A4D5-4427-8453-EE3514E4B513}.Release|Any CPU.ActiveCfg = Release|Any CPU
{B26DE6AC-A4D5-4427-8453-EE3514E4B513}.Release|Any CPU.Build.0 = Release|Any CPU
+ {B0C82F24-BDEC-4420-A02A-F74E2423D755}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {B0C82F24-BDEC-4420-A02A-F74E2423D755}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {B0C82F24-BDEC-4420-A02A-F74E2423D755}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {B0C82F24-BDEC-4420-A02A-F74E2423D755}.Release|Any CPU.Build.0 = Release|Any CPU
+ {A96660DB-DAEB-4C57-8096-F236AC4FA927}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {A96660DB-DAEB-4C57-8096-F236AC4FA927}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {A96660DB-DAEB-4C57-8096-F236AC4FA927}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {A96660DB-DAEB-4C57-8096-F236AC4FA927}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -136,6 +148,8 @@ Global
{C85BA135-3C21-4027-BE5A-849E1011DD0A} = {95A69671-16CA-4133-981C-CC381B7AAA30}
{29F96D05-D02F-4610-A8FB-3527BF83C4A5} = {95A69671-16CA-4133-981C-CC381B7AAA30}
{B26DE6AC-A4D5-4427-8453-EE3514E4B513} = {C526E8AB-739A-48D7-8FC4-048978C9B650}
+ {B0C82F24-BDEC-4420-A02A-F74E2423D755} = {95A69671-16CA-4133-981C-CC381B7AAA30}
+ {A96660DB-DAEB-4C57-8096-F236AC4FA927} = {95A69671-16CA-4133-981C-CC381B7AAA30}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {8CBB7278-D093-448E-B3DE-B5991209A1AA}
diff --git a/NATS.Client.sln.DotSettings b/NATS.Client.sln.DotSettings
new file mode 100644
index 000000000..bbcea8912
--- /dev/null
+++ b/NATS.Client.sln.DotSettings
@@ -0,0 +1,6 @@
+
+ ASCII
+ CR
+ LF
+ True
+ True
\ No newline at end of file
diff --git a/sandbox/Example.Core.PublishHeaders/Example.Core.PublishHeaders.csproj b/sandbox/Example.Core.PublishHeaders/Example.Core.PublishHeaders.csproj
new file mode 100644
index 000000000..f6dc3f6fe
--- /dev/null
+++ b/sandbox/Example.Core.PublishHeaders/Example.Core.PublishHeaders.csproj
@@ -0,0 +1,14 @@
+
+
+
+ Exe
+ net6.0
+ enable
+ enable
+
+
+
+
+
+
+
diff --git a/sandbox/Example.Core.PublishHeaders/Program.cs b/sandbox/Example.Core.PublishHeaders/Program.cs
new file mode 100644
index 000000000..a795cef4c
--- /dev/null
+++ b/sandbox/Example.Core.PublishHeaders/Program.cs
@@ -0,0 +1,31 @@
+// > nats sub bar.*
+using Microsoft.Extensions.Logging;
+using NATS.Client.Core;
+
+var subject = "bar.xyz";
+var options = NatsOptions.Default with { LoggerFactory = new MinimumConsoleLoggerFactory(LogLevel.Error) };
+
+Print("[CON] Connecting...\n");
+
+await using var connection = new NatsConnection(options);
+
+for (int i = 0; i < 10; i++)
+{
+ Print($"[PUB] Publishing to subject ({i}) '{subject}'...\n");
+ await connection.PublishAsync(
+ subject,
+ new Bar { Id = i, Name = "Baz" },
+ new NatsPubOpts { Headers = new NatsHeaders { ["XFoo"] = $"bar{i}" } });
+}
+
+void Print(string message)
+{
+ Console.Write($"{DateTime.Now:HH:mm:ss} {message}");
+}
+
+public record Bar
+{
+ public int Id { get; set; }
+
+ public string? Name { get; set; }
+}
diff --git a/sandbox/Example.Core.SubscribeHeaders/Example.Core.SubscribeHeaders.csproj b/sandbox/Example.Core.SubscribeHeaders/Example.Core.SubscribeHeaders.csproj
new file mode 100644
index 000000000..f6dc3f6fe
--- /dev/null
+++ b/sandbox/Example.Core.SubscribeHeaders/Example.Core.SubscribeHeaders.csproj
@@ -0,0 +1,14 @@
+
+
+
+ Exe
+ net6.0
+ enable
+ enable
+
+
+
+
+
+
+
diff --git a/sandbox/Example.Core.SubscribeHeaders/Program.cs b/sandbox/Example.Core.SubscribeHeaders/Program.cs
new file mode 100644
index 000000000..940595f55
--- /dev/null
+++ b/sandbox/Example.Core.SubscribeHeaders/Program.cs
@@ -0,0 +1,41 @@
+// > nats pub bar.xyz --count=10 "my_message_{{ Count }}" -H X-Foo:Baz
+
+using System.Text;
+using Microsoft.Extensions.Logging;
+using NATS.Client.Core;
+
+var subject = "bar.*";
+var options = NatsOptions.Default with { LoggerFactory = new MinimumConsoleLoggerFactory(LogLevel.Error) };
+
+Print("[CON] Connecting...\n");
+
+await using var connection = new NatsConnection(options);
+
+Print($"[SUB] Subscribing to subject '{subject}'...\n");
+
+NatsSub sub = await connection.SubscribeAsync(subject);
+
+await foreach (var msg in sub.Msgs.ReadAllAsync())
+{
+ Print($"[RCV] {msg.Subject}: {Encoding.UTF8.GetString(msg.Data.Span)}\n");
+ if (msg.Headers != null)
+ {
+ foreach (var (key, values) in msg.Headers)
+ {
+ foreach (var value in values)
+ Print($" {key}: {value}\n");
+ }
+ }
+}
+
+void Print(string message)
+{
+ Console.Write($"{DateTime.Now:HH:mm:ss} {message}");
+}
+
+public record Bar
+{
+ public int Id { get; set; }
+
+ public string? Name { get; set; }
+}
diff --git a/src/NATS.Client.Core/Commands/CommandConstants.cs b/src/NATS.Client.Core/Commands/CommandConstants.cs
index 18e3a1611..bdd62fca6 100644
--- a/src/NATS.Client.Core/Commands/CommandConstants.cs
+++ b/src/NATS.Client.Core/Commands/CommandConstants.cs
@@ -13,6 +13,9 @@ internal static class CommandConstants
// string.Join(",", Encoding.ASCII.GetBytes("PUB "))
public static ReadOnlySpan PubWithPadding => new byte[] { 80, 85, 66, 32 };
+ // string.Join(",", Encoding.ASCII.GetBytes("HPUB "))
+ public static ReadOnlySpan HPubWithPadding => new byte[] { 72, 80, 85, 66, 32 };
+
// string.Join(",", Encoding.ASCII.GetBytes("SUB "))
public static ReadOnlySpan SubWithPadding => new byte[] { 83, 85, 66, 32 };
@@ -24,4 +27,7 @@ internal static class CommandConstants
// string.Join(",", Encoding.ASCII.GetBytes("PONG\r\n"))
public static ReadOnlySpan PongNewLine => new byte[] { 80, 79, 78, 71, 13, 10 };
+
+ // string.Join(",", Encoding.ASCII.GetBytes("NATS/1.0\r\n"))
+ public static ReadOnlySpan NatsHeaders10NewLine => new byte[] { 78, 65, 84, 83, 47, 49, 46, 48, 13, 10 };
}
diff --git a/src/NATS.Client.Core/Commands/ProtocolWriter.cs b/src/NATS.Client.Core/Commands/ProtocolWriter.cs
index 2c3541daa..f0f75e835 100644
--- a/src/NATS.Client.Core/Commands/ProtocolWriter.cs
+++ b/src/NATS.Client.Core/Commands/ProtocolWriter.cs
@@ -1,5 +1,6 @@
using System.Buffers;
using System.Buffers.Text;
+using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using NATS.Client.Core.Internal;
@@ -12,6 +13,9 @@ internal sealed class ProtocolWriter
private const int NewLineLength = 2; // \r\n
private readonly FixedArrayBufferWriter _writer; // where T : IBufferWriter
+ private readonly FixedArrayBufferWriter _bufferHeaders = new();
+ private readonly FixedArrayBufferWriter _bufferPayload = new();
+ private readonly HeaderWriter _headerWriter = new(Encoding.UTF8);
public ProtocolWriter(FixedArrayBufferWriter writer)
{
@@ -46,159 +50,63 @@ public void WritePong()
}
// https://docs.nats.io/reference/reference-protocols/nats-protocol#pub
- // PUB [reply-to] <#bytes>\r\n[payload]
- // To omit the payload, set the payload size to 0, but the second CRLF is still required.
- public void WritePublish(string subject, string? replyTo, ReadOnlySequence payload)
+ // PUB [reply-to] <#bytes>\r\n[payload]\r\n
+ public void WritePublish(string subject, string? replyTo, NatsHeaders? headers, ReadOnlySequence payload)
{
- var offset = 0;
- var maxLength = CommandConstants.PubWithPadding.Length
- + subject.Length + 1 // with space padding
- + (replyTo == null ? 0 : replyTo.Length + 1)
- + MaxIntStringLength
- + NewLineLength
- + (int)payload.Length
- + NewLineLength;
-
- var writableSpan = _writer.GetSpan(maxLength);
-
- CommandConstants.PubWithPadding.CopyTo(writableSpan);
- offset += CommandConstants.PubWithPadding.Length;
+ // We use a separate buffer to write the headers so that we can calculate the
+ // size before we write to the output buffer '_writer'.
+ if (headers != null)
+ {
+ _bufferHeaders.Reset();
+ _headerWriter.Write(_bufferHeaders, headers);
+ }
- subject.WriteASCIIBytes(writableSpan.Slice(offset));
- offset += subject.Length;
- writableSpan.Slice(offset)[0] = (byte)' ';
- offset += 1;
+ // Start writing the message to buffer:
+ // PUP / HPUB
+ _writer.WriteSpan(headers == null ? CommandConstants.PubWithPadding : CommandConstants.HPubWithPadding);
+ _writer.WriteASCIIAndSpace(subject);
if (replyTo != null)
{
- replyTo.WriteASCIIBytes(writableSpan.Slice(offset));
- offset += replyTo.Length;
- writableSpan.Slice(offset)[0] = (byte)' ';
- offset += 1;
+ _writer.WriteASCIIAndSpace(replyTo);
}
- if (!Utf8Formatter.TryFormat(payload.Length, writableSpan.Slice(offset), out var written))
+ if (headers == null)
{
- throw new NatsException("Can not format integer.");
+ _writer.WriteNumber(payload.Length);
}
-
- offset += written;
-
- CommandConstants.NewLine.CopyTo(writableSpan.Slice(offset));
- offset += CommandConstants.NewLine.Length;
-
- if (payload.Length != 0)
+ else
{
- payload.CopyTo(writableSpan.Slice(offset));
- offset += (int)payload.Length;
+ var headersLength = _bufferHeaders.WrittenSpan.Length;
+ _writer.WriteNumber(CommandConstants.NatsHeaders10NewLine.Length + headersLength);
+ _writer.WriteSpace();
+ var total = CommandConstants.NatsHeaders10NewLine.Length + headersLength + payload.Length;
+ _writer.WriteNumber(total);
}
- CommandConstants.NewLine.CopyTo(writableSpan.Slice(offset));
- offset += CommandConstants.NewLine.Length;
-
- _writer.Advance(offset);
- }
-
- public void WritePublish(string subject, string? replyTo, T? value, INatsSerializer serializer)
- {
- var offset = 0;
- var maxLengthWithoutPayload = CommandConstants.PubWithPadding.Length
- + subject.Length + 1
- + (replyTo == null ? 0 : replyTo.Length + 1)
- + MaxIntStringLength
- + NewLineLength;
-
- var writableSpan = _writer.GetSpan(maxLengthWithoutPayload);
-
- CommandConstants.PubWithPadding.CopyTo(writableSpan);
- offset += CommandConstants.PubWithPadding.Length;
-
- subject.WriteASCIIBytes(writableSpan.Slice(offset));
- offset += subject.Length;
- writableSpan.Slice(offset)[0] = (byte)' ';
- offset += 1;
+ // End of message first line
+ _writer.WriteNewLine();
- if (replyTo != null)
+ if (headers != null)
{
- replyTo.WriteASCIIBytes(writableSpan.Slice(offset));
- offset += replyTo.Length;
- writableSpan.Slice(offset)[0] = (byte)' ';
- offset += 1;
+ _writer.WriteSpan(CommandConstants.NatsHeaders10NewLine);
+ _writer.WriteSpan(_bufferHeaders.WrittenSpan);
}
- // Advance for written.
- _writer.Advance(offset);
-
- // preallocate range for write #bytes(write after serialized)
- var preallocatedRange = _writer.PreAllocate(MaxIntStringLength);
- offset += MaxIntStringLength;
-
- CommandConstants.NewLine.CopyTo(writableSpan.Slice(offset));
- _writer.Advance(CommandConstants.NewLine.Length);
-
- var payloadLength = serializer.Serialize(_writer, value);
- var payloadLengthSpan = _writer.GetSpanInPreAllocated(preallocatedRange);
- payloadLengthSpan.Fill((byte)' ');
- if (!Utf8Formatter.TryFormat(payloadLength, payloadLengthSpan, out var written))
+ if (payload.Length != 0)
{
- throw new NatsException("Can not format integer.");
+ _writer.WriteSequence(payload);
}
- WriteConstant(CommandConstants.NewLine);
+ _writer.WriteNewLine();
}
- public void WritePublish(string subject, ReadOnlyMemory inboxPrefix, int id, T? value, INatsSerializer serializer)
+ public void WritePublish(string subject, string? replyTo, NatsHeaders? headers, T? value, INatsSerializer serializer)
{
- Span idBytes = stackalloc byte[10];
- if (Utf8Formatter.TryFormat(id, idBytes, out var written))
- {
- idBytes = idBytes.Slice(0, written);
- }
-
- var offset = 0;
- var maxLengthWithoutPayload = CommandConstants.PubWithPadding.Length
- + subject.Length + 1
- + (inboxPrefix.Length + idBytes.Length + 1) // with space
- + MaxIntStringLength
- + NewLineLength;
-
- var writableSpan = _writer.GetSpan(maxLengthWithoutPayload);
-
- CommandConstants.PubWithPadding.CopyTo(writableSpan);
- offset += CommandConstants.PubWithPadding.Length;
-
- subject.WriteASCIIBytes(writableSpan.Slice(offset));
- offset += subject.Length;
- writableSpan.Slice(offset)[0] = (byte)' ';
- offset += 1;
-
- // build reply-to
- inboxPrefix.Span.CopyTo(writableSpan.Slice(offset));
- offset += inboxPrefix.Length;
- idBytes.CopyTo(writableSpan.Slice(offset));
- offset += idBytes.Length;
- writableSpan.Slice(offset)[0] = (byte)' ';
- offset += 1;
-
- // Advance for written.
- _writer.Advance(offset);
-
- // preallocate range for write #bytes(write after serialized)
- var preallocatedRange = _writer.PreAllocate(MaxIntStringLength);
- offset += MaxIntStringLength;
-
- CommandConstants.NewLine.CopyTo(writableSpan.Slice(offset));
- _writer.Advance(CommandConstants.NewLine.Length);
-
- var payloadLength = serializer.Serialize(_writer, value);
- var payloadLengthSpan = _writer.GetSpanInPreAllocated(preallocatedRange);
- payloadLengthSpan.Fill((byte)' ');
- if (!Utf8Formatter.TryFormat(payloadLength, payloadLengthSpan, out written))
- {
- throw new NatsException("Can not format integer.");
- }
-
- WriteConstant(CommandConstants.NewLine);
+ _bufferPayload.Reset();
+ serializer.Serialize(_bufferPayload, value);
+ var payload = new ReadOnlySequence(_bufferPayload.WrittenMemory);
+ WritePublish(subject, replyTo, headers, payload);
}
// https://docs.nats.io/reference/reference-protocols/nats-protocol#sub
diff --git a/src/NATS.Client.Core/Commands/PublishCommand.cs b/src/NATS.Client.Core/Commands/PublishCommand.cs
index c56fc6afd..0d345dcd5 100644
--- a/src/NATS.Client.Core/Commands/PublishCommand.cs
+++ b/src/NATS.Client.Core/Commands/PublishCommand.cs
@@ -7,6 +7,7 @@ internal sealed class AsyncPublishCommand : AsyncCommandBase Create(ObjectPool pool, CancellationTimer timer, string subject, string? replyTo, T? value, INatsSerializer serializer)
+ public static AsyncPublishCommand Create(ObjectPool pool, CancellationTimer timer, string subject, string? replyTo, NatsHeaders? headers, T? value, INatsSerializer serializer)
{
if (!TryRent(pool, out var result))
{
@@ -23,6 +24,7 @@ public static AsyncPublishCommand Create(ObjectPool pool, CancellationTimer t
result._subject = subject;
result._replyTo = replyTo;
+ result._headers = headers;
result._value = value;
result._serializer = serializer;
result.SetCancellationTimer(timer);
@@ -32,12 +34,13 @@ public static AsyncPublishCommand Create(ObjectPool pool, CancellationTimer t
public override void Write(ProtocolWriter writer)
{
- writer.WritePublish(_subject!, _replyTo, _value, _serializer!);
+ writer.WritePublish(_subject!, _replyTo, _headers, _value, _serializer!);
}
protected override void Reset()
{
_subject = default;
+ _headers = default;
_value = default;
_serializer = null;
}
@@ -46,13 +49,15 @@ protected override void Reset()
internal sealed class AsyncPublishBytesCommand : AsyncCommandBase
{
private string? _subject;
- private ReadOnlySequence _value;
+ private string? _replyTo;
+ private NatsHeaders? _headers;
+ private ReadOnlySequence _payload;
private AsyncPublishBytesCommand()
{
}
- public static AsyncPublishBytesCommand Create(ObjectPool pool, CancellationTimer timer, string subject, ReadOnlySequence value)
+ public static AsyncPublishBytesCommand Create(ObjectPool pool, CancellationTimer timer, string subject, string? replyTo, NatsHeaders? headers, ReadOnlySequence payload)
{
if (!TryRent(pool, out var result))
{
@@ -60,7 +65,9 @@ public static AsyncPublishBytesCommand Create(ObjectPool pool, CancellationTimer
}
result._subject = subject;
- result._value = value;
+ result._replyTo = replyTo;
+ result._headers = headers;
+ result._payload = payload;
result.SetCancellationTimer(timer);
return result;
@@ -68,12 +75,14 @@ public static AsyncPublishBytesCommand Create(ObjectPool pool, CancellationTimer
public override void Write(ProtocolWriter writer)
{
- writer.WritePublish(_subject!, null, _value);
+ writer.WritePublish(_subject!, _replyTo, _headers, _payload);
}
protected override void Reset()
{
_subject = default;
- _value = default;
+ _replyTo = default;
+ _headers = default;
+ _payload = default;
}
}
diff --git a/src/NATS.Client.Core/Internal/BufferExtensions.cs b/src/NATS.Client.Core/Internal/BufferExtensions.cs
new file mode 100644
index 000000000..c618dc8b4
--- /dev/null
+++ b/src/NATS.Client.Core/Internal/BufferExtensions.cs
@@ -0,0 +1,68 @@
+// Adapted from https://github.com/dotnet/aspnetcore/blob/v6.0.18/src/Shared/ServerInfrastructure/BufferExtensions.cs
+
+#nullable enable
+
+using System.Buffers;
+using System.Runtime.CompilerServices;
+
+namespace NATS.Client.Core.Internal;
+
+internal static class BufferExtensions
+{
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static ReadOnlySpan ToSpan(this ReadOnlySequence buffer)
+ {
+ if (buffer.IsSingleSegment)
+ {
+ return buffer.FirstSpan;
+ }
+
+ return buffer.ToArray();
+ }
+
+ ///
+ /// Returns position of first occurrence of item in the
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static SequencePosition? PositionOfAny(in this ReadOnlySequence source, T value0, T value1)
+ where T : IEquatable
+ {
+ if (source.IsSingleSegment)
+ {
+ int index = source.First.Span.IndexOfAny(value0, value1);
+ if (index != -1)
+ {
+ return source.GetPosition(index);
+ }
+
+ return null;
+ }
+ else
+ {
+ return PositionOfAnyMultiSegment(source, value0, value1);
+ }
+ }
+
+ private static SequencePosition? PositionOfAnyMultiSegment(in ReadOnlySequence source, T value0, T value1)
+ where T : IEquatable
+ {
+ SequencePosition position = source.Start;
+ SequencePosition result = position;
+ while (source.TryGet(ref position, out ReadOnlyMemory memory))
+ {
+ int index = memory.Span.IndexOfAny(value0, value1);
+ if (index != -1)
+ {
+ return source.GetPosition(index, result);
+ }
+ else if (position.GetObject() == null)
+ {
+ break;
+ }
+
+ result = position;
+ }
+
+ return null;
+ }
+}
diff --git a/src/NATS.Client.Core/Internal/BufferWriterExtensions.cs b/src/NATS.Client.Core/Internal/BufferWriterExtensions.cs
new file mode 100644
index 000000000..797cca547
--- /dev/null
+++ b/src/NATS.Client.Core/Internal/BufferWriterExtensions.cs
@@ -0,0 +1,66 @@
+using System.Buffers;
+using System.Buffers.Text;
+using System.Runtime.CompilerServices;
+using System.Text;
+using NATS.Client.Core.Commands;
+
+namespace NATS.Client.Core.Internal;
+
+internal static class BufferWriterExtensions
+{
+ private const int MaxIntStringLength = 10; // int.MaxValue.ToString().Length
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void WriteNewLine(this FixedArrayBufferWriter writer)
+ {
+ var span = writer.GetSpan(CommandConstants.NewLine.Length);
+ CommandConstants.NewLine.CopyTo(span);
+ writer.Advance(CommandConstants.NewLine.Length);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void WriteNumber(this FixedArrayBufferWriter writer, long number)
+ {
+ var span = writer.GetSpan(MaxIntStringLength);
+ if (!Utf8Formatter.TryFormat(number, span, out var writtenLength))
+ {
+ throw new NatsException("Can not format integer.");
+ }
+
+ writer.Advance(writtenLength);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void WriteSpace(this FixedArrayBufferWriter writer)
+ {
+ var span = writer.GetSpan(1);
+ span[0] = (byte)' ';
+ writer.Advance(1);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void WriteSpan(this FixedArrayBufferWriter writer, ReadOnlySpan span)
+ {
+ var writerSpan = writer.GetSpan(span.Length);
+ span.CopyTo(writerSpan);
+ writer.Advance(span.Length);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void WriteSequence(this FixedArrayBufferWriter writer, ReadOnlySequence sequence)
+ {
+ var len = (int)sequence.Length;
+ var span = writer.GetSpan(len);
+ sequence.CopyTo(span);
+ writer.Advance(len);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void WriteASCIIAndSpace(this FixedArrayBufferWriter writer, string ascii)
+ {
+ var span = writer.GetSpan(ascii.Length + 1);
+ ascii.WriteASCIIBytes(span);
+ span[ascii.Length] = (byte)' ';
+ writer.Advance(ascii.Length + 1);
+ }
+}
diff --git a/src/NATS.Client.Core/Internal/ClientOptions.cs b/src/NATS.Client.Core/Internal/ClientOptions.cs
index fdd547e62..a3dc7d9ab 100644
--- a/src/NATS.Client.Core/Internal/ClientOptions.cs
+++ b/src/NATS.Client.Core/Internal/ClientOptions.cs
@@ -12,6 +12,7 @@ public ClientOptions(NatsOptions options)
Name = options.Name;
Echo = options.Echo;
Verbose = options.Verbose;
+ Headers = options.Headers;
Username = options.AuthOptions.Username;
Password = options.AuthOptions.Password;
AuthToken = options.AuthOptions.Token;
diff --git a/src/NATS.Client.Core/Internal/DebuggingExtensions.cs b/src/NATS.Client.Core/Internal/DebuggingExtensions.cs
new file mode 100644
index 000000000..4c5845eb8
--- /dev/null
+++ b/src/NATS.Client.Core/Internal/DebuggingExtensions.cs
@@ -0,0 +1,65 @@
+#if DEBUG
+
+using System.Buffers;
+using System.Diagnostics;
+using System.Text;
+
+namespace NATS.Client.Core.Internal;
+
+internal static class DebuggingExtensions
+{
+ public static string Dump(this ReadOnlySequence buffer)
+ {
+ var sb = new StringBuilder();
+ foreach (var readOnlyMemory in buffer)
+ {
+ sb.Append(Dump(readOnlyMemory.Span));
+ }
+
+ return sb.ToString();
+ }
+
+ public static string Dump(this ReadOnlySpan span)
+ {
+ var sb = new StringBuilder();
+ foreach (char b in span)
+ {
+ switch (b)
+ {
+ case >= ' ' and <= '~':
+ sb.Append(b);
+ break;
+ case '\r':
+ sb.Append('␍');
+ break;
+ case '\n':
+ sb.Append('␊');
+ break;
+ default:
+ sb.Append('.');
+ break;
+ }
+ }
+
+ return sb.ToString();
+ }
+
+ public static string Dump(this NatsHeaders? headers)
+ {
+ if (headers == null)
+ return "";
+
+ var sb = new StringBuilder();
+ foreach (var (key, stringValues) in headers)
+ {
+ foreach (var value in stringValues)
+ {
+ sb.AppendLine($"{key}: {value}");
+ }
+ }
+
+ return sb.ToString();
+ }
+}
+
+#endif
diff --git a/src/NATS.Client.Core/Internal/FastQueue.cs b/src/NATS.Client.Core/Internal/FastQueue.cs
deleted file mode 100644
index d10941757..000000000
--- a/src/NATS.Client.Core/Internal/FastQueue.cs
+++ /dev/null
@@ -1,97 +0,0 @@
-using System.Runtime.CompilerServices;
-
-namespace NATS.Client.Core.Internal;
-
-// fixed size queue.
-internal sealed class FastQueue
-{
- private T[] _array;
- private int _head;
- private int _tail;
- private int _size;
-
- public FastQueue(int capacity)
- {
- if (capacity < 0)
- throw new ArgumentOutOfRangeException("capacity");
- _array = new T[capacity];
- _head = _tail = _size = 0;
- }
-
- public int Count
- {
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- get { return _size; }
- }
-
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public void Enqueue(T item)
- {
- if (_size == _array.Length)
- {
- ThrowForFullQueue();
- }
-
- _array[_tail] = item;
- MoveNext(ref _tail);
- _size++;
- }
-
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public T Dequeue()
- {
- if (_size == 0)
- ThrowForEmptyQueue();
-
- var head = _head;
- var array = _array;
- var removed = array[head];
- array[head] = default!;
- MoveNext(ref _head);
- _size--;
- return removed;
- }
-
- public void EnsureNewCapacity(int capacity)
- {
- var newarray = new T[capacity];
- if (_size > 0)
- {
- if (_head < _tail)
- {
- Array.Copy(_array, _head, newarray, 0, _size);
- }
- else
- {
- Array.Copy(_array, _head, newarray, 0, _array.Length - _head);
- Array.Copy(_array, 0, newarray, _array.Length - _head, _tail);
- }
- }
-
- _array = newarray;
- _head = 0;
- _tail = _size == capacity ? 0 : _size;
- }
-
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- private void MoveNext(ref int index)
- {
- var tmp = index + 1;
- if (tmp == _array.Length)
- {
- tmp = 0;
- }
-
- index = tmp;
- }
-
- private void ThrowForEmptyQueue()
- {
- throw new InvalidOperationException("Queue is empty.");
- }
-
- private void ThrowForFullQueue()
- {
- throw new InvalidOperationException("Queue is full.");
- }
-}
diff --git a/src/NATS.Client.Core/Internal/FixedArrayBufferWriter.cs b/src/NATS.Client.Core/Internal/FixedArrayBufferWriter.cs
index 316d21bd7..bb74aa9eb 100644
--- a/src/NATS.Client.Core/Internal/FixedArrayBufferWriter.cs
+++ b/src/NATS.Client.Core/Internal/FixedArrayBufferWriter.cs
@@ -17,6 +17,8 @@ public FixedArrayBufferWriter(int capacity = 65535)
public ReadOnlyMemory WrittenMemory => _buffer.AsMemory(0, _written);
+ public ReadOnlySpan WrittenSpan => _buffer.AsSpan(0, _written);
+
public int WrittenCount => _written;
[MethodImpl(MethodImplOptions.AggressiveInlining)]
diff --git a/src/NATS.Client.Core/Internal/HeaderParser.cs b/src/NATS.Client.Core/Internal/HeaderParser.cs
new file mode 100644
index 000000000..27afc3f7f
--- /dev/null
+++ b/src/NATS.Client.Core/Internal/HeaderParser.cs
@@ -0,0 +1,335 @@
+// Adapted from https://github.com/dotnet/aspnetcore/blob/v6.0.18/src/Servers/Kestrel/Core/src/Internal/Http/HttpParser.cs
+
+using System.Buffers;
+using System.Diagnostics;
+using System.Text;
+using Microsoft.Extensions.Primitives;
+
+namespace NATS.Client.Core.Internal;
+
+internal class HeaderParser
+{
+ private const byte ByteCR = (byte)'\r';
+ private const byte ByteLF = (byte)'\n';
+ private const byte ByteColon = (byte)':';
+ private const byte ByteSpace = (byte)' ';
+ private const byte ByteTab = (byte)'\t';
+
+ private readonly Encoding _encoding;
+
+ public HeaderParser(Encoding encoding)
+ {
+ _encoding = encoding;
+ }
+
+ public bool ParseHeaders(in SequenceReader reader, NatsHeaders headers)
+ {
+ while (!reader.End)
+ {
+ var span = reader.UnreadSpan;
+ while (span.Length > 0)
+ {
+ var ch1 = (byte)0;
+ var ch2 = (byte)0;
+ var readAhead = 0;
+
+ // Fast path, we're still looking at the same span
+ if (span.Length >= 2)
+ {
+ ch1 = span[0];
+ ch2 = span[1];
+ }
+
+ // Possibly split across spans
+ else if (reader.TryRead(out ch1))
+ {
+ // Note if we read ahead by 1 or 2 bytes
+ readAhead = reader.TryRead(out ch2) ? 2 : 1;
+ }
+
+ if (ch1 == ByteCR)
+ {
+ // Check for final CRLF.
+ if (ch2 == ByteLF)
+ {
+ // If we got 2 bytes from the span directly so skip ahead 2 so that
+ // the reader's state matches what we expect
+ if (readAhead == 0)
+ {
+ reader.Advance(2);
+ }
+
+ // Double CRLF found, so end of headers.
+ return true;
+ }
+ else if (readAhead == 1)
+ {
+ // Didn't read 2 bytes, reset the reader so we don't consume anything
+ reader.Rewind(1);
+ return false;
+ }
+
+ // Headers don't end in CRLF line.
+ Debug.Assert(readAhead == 0 || readAhead == 2, "readAhead == 0 || readAhead == 2");
+
+ throw new NatsException($"Protocol error: invalid headers, no ending CRLFCRLF");
+ }
+
+ var length = 0;
+
+ // We only need to look for the end if we didn't read ahead; otherwise there isn't enough in
+ // in the span to contain a header.
+ if (readAhead == 0)
+ {
+ length = span.IndexOfAny(ByteCR, ByteLF);
+
+ // If not found length with be -1; casting to uint will turn it to uint.MaxValue
+ // which will be larger than any possible span.Length. This also serves to eliminate
+ // the bounds check for the next lookup of span[length]
+ if ((uint)length < (uint)span.Length)
+ {
+ // Early memory read to hide latency
+ var expectedCR = span[length];
+
+ // Correctly has a CR, move to next
+ length++;
+
+ if (expectedCR != ByteCR)
+ {
+ // Sequence needs to be CRLF not LF first.
+ RejectRequestHeader(span[..length]);
+ }
+
+ if ((uint)length < (uint)span.Length)
+ {
+ // Early memory read to hide latency
+ var expectedLF = span[length];
+
+ // Correctly has a LF, move to next
+ length++;
+
+ if (expectedLF != ByteLF ||
+ length < 5 ||
+
+ // Exclude the CRLF from the headerLine and parse the header name:value pair
+ !TryTakeSingleHeader(span[..(length - 2)], headers))
+ {
+ // Sequence needs to be CRLF and not contain an inner CR not part of terminator.
+ // Less than min possible headerSpan of 5 bytes a:b\r\n
+ // Not parsable as a valid name:value header pair.
+ RejectRequestHeader(span[..length]);
+ }
+
+ // Read the header successfully, skip the reader forward past the headerSpan.
+ span = span.Slice(length);
+ reader.Advance(length);
+ }
+ else
+ {
+ // No enough data, set length to 0.
+ length = 0;
+ }
+ }
+ }
+
+ // End found in current span
+ if (length > 0)
+ {
+ continue;
+ }
+
+ // We moved the reader to look ahead 2 bytes so rewind the reader
+ if (readAhead > 0)
+ {
+ reader.Rewind(readAhead);
+ }
+
+ length = ParseMultiSpanHeader(reader, headers);
+ if (length < 0)
+ {
+ // Not there
+ return false;
+ }
+
+ reader.Advance(length);
+
+ // As we crossed spans set the current span to default
+ // so we move to the next span on the next iteration
+ span = default;
+ }
+ }
+
+ return false;
+ }
+
+ private int ParseMultiSpanHeader(in SequenceReader reader, NatsHeaders headers)
+ {
+ var currentSlice = reader.UnreadSequence;
+ var lineEndPosition = currentSlice.PositionOfAny(ByteCR, ByteLF);
+
+ if (lineEndPosition == null)
+ {
+ // Not there.
+ return -1;
+ }
+
+ SequencePosition lineEnd;
+ ReadOnlySpan headerSpan;
+ if (currentSlice.Slice(reader.Position, lineEndPosition.Value).Length == currentSlice.Length - 1)
+ {
+ // No enough data, so CRLF can't currently be there.
+ // However, we need to check the found char is CR and not LF
+
+ // Advance 1 to include CR/LF in lineEnd
+ lineEnd = currentSlice.GetPosition(1, lineEndPosition.Value);
+ headerSpan = currentSlice.Slice(reader.Position, lineEnd).ToSpan();
+ if (headerSpan[^1] != ByteCR)
+ {
+ RejectRequestHeader(headerSpan);
+ }
+
+ return -1;
+ }
+
+ // Advance 2 to include CR{LF?} in lineEnd
+ lineEnd = currentSlice.GetPosition(2, lineEndPosition.Value);
+ headerSpan = currentSlice.Slice(reader.Position, lineEnd).ToSpan();
+
+ if (headerSpan.Length < 5)
+ {
+ // Less than min possible headerSpan is 5 bytes a:b\r\n
+ RejectRequestHeader(headerSpan);
+ }
+
+ if (headerSpan[^2] != ByteCR)
+ {
+ // Sequence needs to be CRLF not LF first.
+ RejectRequestHeader(headerSpan[..^1]);
+ }
+
+ if (headerSpan[^1] != ByteLF ||
+
+ // Exclude the CRLF from the headerLine and parse the header name:value pair
+ !TryTakeSingleHeader(headerSpan[..^2], headers))
+ {
+ // Sequence needs to be CRLF and not contain an inner CR not part of terminator.
+ // Not parsable as a valid name:value header pair.
+ RejectRequestHeader(headerSpan);
+ }
+
+ return headerSpan.Length;
+ }
+
+ private bool TryTakeSingleHeader(ReadOnlySpan headerLine, NatsHeaders headers)
+ {
+ // We are looking for a colon to terminate the header name.
+ // However, the header name cannot contain a space or tab so look for all three
+ // and see which is found first.
+ var nameEnd = headerLine.IndexOfAny(ByteColon, ByteSpace, ByteTab);
+
+ // If not found length with be -1; casting to uint will turn it to uint.MaxValue
+ // which will be larger than any possible headerLine.Length. This also serves to eliminate
+ // the bounds check for the next lookup of headerLine[nameEnd]
+ if ((uint)nameEnd >= (uint)headerLine.Length)
+ {
+ // Colon not found.
+ return false;
+ }
+
+ // Early memory read to hide latency
+ var expectedColon = headerLine[nameEnd];
+ if (nameEnd == 0)
+ {
+ // Header name is empty.
+ return false;
+ }
+
+ if (expectedColon != ByteColon)
+ {
+ // Header name space or tab.
+ return false;
+ }
+
+ // Skip colon to get to the value start.
+ var valueStart = nameEnd + 1;
+
+ // Generally there will only be one space, so we will check it directly
+ if ((uint)valueStart < (uint)headerLine.Length)
+ {
+ var ch = headerLine[valueStart];
+ if (ch == ByteSpace || ch == ByteTab)
+ {
+ // Ignore first whitespace.
+ valueStart++;
+
+ // More header chars?
+ if ((uint)valueStart < (uint)headerLine.Length)
+ {
+ ch = headerLine[valueStart];
+
+ // Do a fast check; as we now expect non-space, before moving into loop.
+ if (ch <= ByteSpace && (ch == ByteSpace || ch == ByteTab))
+ {
+ valueStart++;
+
+ // Is more whitespace, so we will loop to find the end. This is the slow path.
+ for (; valueStart < headerLine.Length; valueStart++)
+ {
+ ch = headerLine[valueStart];
+ if (ch != ByteTab && ch != ByteSpace)
+ {
+ // Non-whitespace char found, valueStart is now start of value.
+ break;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ var valueEnd = headerLine.Length - 1;
+
+ // Ignore end whitespace. Generally there will no spaces
+ // so we will check the first before moving to a loop.
+ if (valueEnd > valueStart)
+ {
+ var ch = headerLine[valueEnd];
+
+ // Do a fast check; as we now expect non-space, before moving into loop.
+ if (ch <= ByteSpace && (ch == ByteSpace || ch == ByteTab))
+ {
+ // Is whitespace so move to loop
+ valueEnd--;
+ for (; valueEnd > valueStart; valueEnd--)
+ {
+ ch = headerLine[valueEnd];
+ if (ch != ByteTab && ch != ByteSpace)
+ {
+ // Non-whitespace char found, valueEnd is now start of value.
+ break;
+ }
+ }
+ }
+ }
+
+ // Range end is exclusive, so add 1 to valueEnd
+ valueEnd++;
+ var key = _encoding.GetString(headerLine[..nameEnd]);
+ var value = _encoding.GetString(headerLine[valueStart..valueEnd]);
+ if (headers.TryGetValue(key, out var existing))
+ {
+ headers[key] = StringValues.Concat(existing, value);
+ }
+ else
+ {
+ headers[key] = value;
+ }
+
+ return true;
+ }
+
+ [StackTraceHidden]
+ private void RejectRequestHeader(ReadOnlySpan headerLine)
+ => throw new NatsException(
+ $"Protocol error: invalid request header line '{_encoding.GetString(headerLine)}'");
+}
diff --git a/src/NATS.Client.Core/Internal/HeaderWriter.cs b/src/NATS.Client.Core/Internal/HeaderWriter.cs
new file mode 100644
index 000000000..c76b806eb
--- /dev/null
+++ b/src/NATS.Client.Core/Internal/HeaderWriter.cs
@@ -0,0 +1,100 @@
+using System.Buffers;
+using System.Text;
+
+namespace NATS.Client.Core.Internal;
+
+internal class HeaderWriter
+{
+ private const byte ByteCr = (byte)'\r';
+ private const byte ByteLf = (byte)'\n';
+ private const byte ByteColon = (byte)':';
+ private const byte ByteSpace = (byte)' ';
+ private const byte ByteDel = 127;
+ private readonly Encoding _encoding;
+
+ public HeaderWriter(Encoding encoding) => _encoding = encoding;
+
+ private static ReadOnlySpan CrLf => new[] { ByteCr, ByteLf };
+
+ private static ReadOnlySpan ColonSpace => new[] { ByteColon, ByteSpace };
+
+ internal int Write(in FixedArrayBufferWriter bufferWriter, NatsHeaders headers)
+ {
+ var initialCount = bufferWriter.WrittenCount;
+ foreach (var kv in headers)
+ {
+ foreach (var value in kv.Value)
+ {
+ if (value != null)
+ {
+ // write key
+ var keyLength = _encoding.GetByteCount(kv.Key);
+ var keySpan = bufferWriter.GetSpan(keyLength);
+ _encoding.GetBytes(kv.Key, keySpan);
+ if (!ValidateKey(keySpan.Slice(0, keyLength)))
+ {
+ throw new NatsException(
+ $"Invalid header key '{kv.Key}': contains colon, space, or other non-printable ASCII characters");
+ }
+
+ bufferWriter.Advance(keyLength);
+ bufferWriter.Write(ColonSpace);
+
+ // write values
+ var valueLength = _encoding.GetByteCount(value);
+ var valueSpan = bufferWriter.GetSpan(valueLength);
+ _encoding.GetBytes(value, valueSpan);
+ if (!ValidateValue(valueSpan.Slice(0, valueLength)))
+ {
+ throw new NatsException($"Invalid header value for key '{kv.Key}': contains CRLF");
+ }
+
+ bufferWriter.Advance(valueLength);
+ bufferWriter.Write(CrLf);
+ }
+ }
+ }
+
+ // Even empty header needs to terminate.
+ // We will send NATS/1.0 version line
+ // even if there are no headers.
+ bufferWriter.Write(CrLf);
+
+ return bufferWriter.WrittenCount - initialCount;
+ }
+
+ // cannot contain ASCII Bytes <=32, 58, or 127
+ private static bool ValidateKey(ReadOnlySpan span)
+ {
+ foreach (var b in span)
+ {
+ if (b <= ByteSpace || b == ByteColon || b >= ByteDel)
+ {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ // cannot contain CRLF
+ private static bool ValidateValue(ReadOnlySpan span)
+ {
+ while (true)
+ {
+ var pos = span.IndexOf(ByteCr);
+ if (pos == -1 || pos == span.Length - 1)
+ {
+ return true;
+ }
+
+ pos += 1;
+ if (span[pos] == ByteLf)
+ {
+ return false;
+ }
+
+ span = span[pos..];
+ }
+ }
+}
diff --git a/src/NATS.Client.Core/Internal/StringUtils.cs b/src/NATS.Client.Core/Internal/StringExtensions.cs
similarity index 93%
rename from src/NATS.Client.Core/Internal/StringUtils.cs
rename to src/NATS.Client.Core/Internal/StringExtensions.cs
index dbd4add75..b2992e547 100644
--- a/src/NATS.Client.Core/Internal/StringUtils.cs
+++ b/src/NATS.Client.Core/Internal/StringExtensions.cs
@@ -1,6 +1,6 @@
namespace NATS.Client.Core.Internal;
-internal static class StringUtils
+internal static class StringExtensions
{
///
/// Allocation free ASCII buffer writer.
diff --git a/src/NATS.Client.Core/MessagePublisher.cs b/src/NATS.Client.Core/MessagePublisher.cs
deleted file mode 100644
index d894057c8..000000000
--- a/src/NATS.Client.Core/MessagePublisher.cs
+++ /dev/null
@@ -1,266 +0,0 @@
-using System.Buffers;
-using System.Collections.Concurrent;
-using Microsoft.Extensions.Logging;
-using NATS.Client.Core.Internal;
-
-namespace NATS.Client.Core;
-
-// TODO: Clean up message publisher.
-internal delegate Task PublishMessage(string subject, string? replyTo, NatsOptions options, ReadOnlySequence buffer, object?[] callbacks);
-
-internal static class MessagePublisher
-{
- // To avoid boxing, cache generic type and invoke it.
- private static readonly Func CreatePublisherValue = CreatePublisher;
- private static readonly ConcurrentDictionary PublisherCache = new();
-
- public static Task PublishAsync(string subject, string? replyTo, Type type, NatsOptions options, in ReadOnlySequence buffer, object?[] callbacks)
- {
- return PublisherCache.GetOrAdd(type, CreatePublisherValue).Invoke(subject, replyTo, options, buffer, callbacks);
- }
-
- private static PublishMessage CreatePublisher(Type type)
- {
- if (type == typeof(byte[]))
- {
- return new ByteArrayMessagePublisher().Publish;
- }
- else if (type == typeof(ReadOnlyMemory))
- {
- return new ReadOnlyMemoryMessagePublisher().PublishAsync;
- }
-
- var publisher = typeof(MessagePublisher<>).MakeGenericType(type)!;
- var instance = Activator.CreateInstance(publisher)!;
- return (PublishMessage)Delegate.CreateDelegate(typeof(PublishMessage), instance, "Publish", false);
- }
-}
-
-internal sealed class MessagePublisher
-{
- public void Publish(NatsOptions options, in ReadOnlySequence buffer, object?[] callbacks)
- {
- T? value;
- try
- {
- value = options!.Serializer.Deserialize(buffer);
- }
- catch (Exception ex)
- {
- try
- {
- options!.LoggerFactory.CreateLogger>().LogError(ex, "Deserialize error during receive subscribed message. Type:{0}", typeof(T).Name);
- }
- catch
- {
- }
-
- return;
- }
-
- try
- {
- if (!options.UseThreadPoolCallback)
- {
- foreach (var callback in callbacks!)
- {
- if (callback != null)
- {
- try
- {
- ((Action)callback).Invoke(value);
- }
- catch (Exception ex)
- {
- options!.LoggerFactory.CreateLogger>().LogError(ex, "Error occured during publish callback.");
- }
- }
- }
- }
- else
- {
- foreach (var callback in callbacks!)
- {
- if (callback != null)
- {
- var item = ThreadPoolWorkItem.Create((Action)callback, value, options!.LoggerFactory);
- ThreadPool.UnsafeQueueUserWorkItem(item, preferLocal: false);
- }
- }
- }
- }
- catch (Exception ex)
- {
- try
- {
- options!.LoggerFactory.CreateLogger>().LogError(ex, "Error occured during publish callback.");
- }
- catch
- {
- }
- }
- }
-}
-
-internal sealed class ByteArrayMessagePublisher
-{
-#pragma warning disable CA1822
-#pragma warning disable VSTHRD200
-#pragma warning disable CS1998
- public async Task Publish(string subject, string? replyTo, NatsOptions? options, ReadOnlySequence buffer, object?[] callbacks)
-#pragma warning restore CS1998
-#pragma warning restore VSTHRD200
-#pragma warning restore CA1822
- {
- byte[] value;
- try
- {
- if (buffer.IsEmpty)
- {
- value = Array.Empty();
- }
- else
- {
- value = buffer.ToArray();
- }
- }
- catch (Exception ex)
- {
- try
- {
- options!.LoggerFactory.CreateLogger().LogError(ex, "Deserialize error during receive subscribed message.");
- }
- catch
- {
- }
-
- return;
- }
-
- try
- {
- if (options is { UseThreadPoolCallback: false })
- {
- foreach (var callback in callbacks!)
- {
- if (callback != null)
- {
- try
- {
- ((Action)callback).Invoke(value);
- }
- catch (Exception ex)
- {
- options!.LoggerFactory.CreateLogger().LogError(ex, "Error occured during publish callback.");
- }
- }
- }
- }
- else
- {
- foreach (var callback in callbacks!)
- {
- if (callback != null)
- {
- var item = ThreadPoolWorkItem.Create((Action)callback, value, options!.LoggerFactory);
- ThreadPool.UnsafeQueueUserWorkItem(item, preferLocal: false);
- }
- }
- }
- }
- catch (Exception ex)
- {
- try
- {
- options!.LoggerFactory.CreateLogger().LogError(ex, "Error occured during publish callback.");
- }
- catch
- {
- }
- }
- }
-}
-
-internal sealed class ReadOnlyMemoryMessagePublisher
-{
- public async Task PublishAsync(string subject, string? replyTo, NatsOptions? options, ReadOnlySequence buffer, object?[] callbacks)
- {
- ReadOnlyMemory value;
- try
- {
- if (buffer.IsEmpty)
- {
- value = Array.Empty();
- }
- else
- {
- value = buffer.ToArray();
- }
- }
- catch (Exception ex)
- {
- try
- {
- options!.LoggerFactory.CreateLogger().LogError(ex, "Deserialize error during receive subscribed message.");
- }
- catch
- {
- }
-
- return;
- }
-
- try
- {
- if (options is { UseThreadPoolCallback: false })
- {
- foreach (var callback in callbacks!)
- {
- if (callback != null)
- {
- try
- {
- if (callback is NatsSubBase natsSub)
- {
- await natsSub.ReceiveAsync(subject, replyTo, buffer).ConfigureAwait(false);
- }
- else if (callback is Action> action)
- {
- action.Invoke(value);
- }
- else
- {
- throw new NatsException($"Unexpected internal handler type: {callback.GetType().Name}");
- }
- }
- catch (Exception ex)
- {
- options!.LoggerFactory.CreateLogger().LogError(ex, "Error occured during publish callback.");
- }
- }
- }
- }
- else
- {
- foreach (var callback in callbacks!)
- {
- if (callback != null)
- {
- var item = ThreadPoolWorkItem>.Create((Action>)callback, value, options!.LoggerFactory);
- ThreadPool.UnsafeQueueUserWorkItem(item, preferLocal: false);
- }
- }
- }
- }
- catch (Exception ex)
- {
- try
- {
- options!.LoggerFactory.CreateLogger().LogError(ex, "Error occured during publish callback.");
- }
- catch
- {
- }
- }
- }
-}
diff --git a/src/NATS.Client.Core/NatsConnection.Publish.cs b/src/NATS.Client.Core/NatsConnection.Publish.cs
index 4a8c7c80f..07bf1f959 100644
--- a/src/NATS.Client.Core/NatsConnection.Publish.cs
+++ b/src/NATS.Client.Core/NatsConnection.Publish.cs
@@ -8,9 +8,14 @@ public partial class NatsConnection
///
public ValueTask PublishAsync(string subject, ReadOnlySequence payload = default, in NatsPubOpts? opts = default, CancellationToken cancellationToken = default)
{
+ var replyTo = opts?.ReplyTo;
+ var headers = opts?.Headers;
+
+ headers?.SetReadOnly();
+
if (ConnectionState == NatsConnectionState.Open)
{
- var command = AsyncPublishBytesCommand.Create(_pool, GetCommandTimer(cancellationToken), subject, payload);
+ var command = AsyncPublishBytesCommand.Create(_pool, GetCommandTimer(cancellationToken), subject, replyTo, headers, payload);
if (TryEnqueueCommand(command))
{
return command.AsValueTask();
@@ -22,9 +27,9 @@ public ValueTask PublishAsync(string subject, ReadOnlySequence payload = d
}
else
{
- return WithConnectAsync(subject, payload, cancellationToken, static (self, k, v, token) =>
+ return WithConnectAsync(subject, replyTo, headers, payload, cancellationToken, static (self, s, r, h, p, token) =>
{
- var command = AsyncPublishBytesCommand.Create(self._pool, self.GetCommandTimer(token), k, v);
+ var command = AsyncPublishBytesCommand.Create(self._pool, self.GetCommandTimer(token), s, r, h, p);
return self.EnqueueAndAwaitCommandAsync(command);
});
}
@@ -40,10 +45,14 @@ public ValueTask PublishAsync(NatsMsg msg, CancellationToken cancellationToken =
public ValueTask PublishAsync(string subject, T data, in NatsPubOpts? opts = default, CancellationToken cancellationToken = default)
{
var replyTo = opts?.ReplyTo;
+ var serializer = opts?.Serializer ?? Options.Serializer;
+ var headers = opts?.Headers;
+
+ headers?.SetReadOnly();
if (ConnectionState == NatsConnectionState.Open)
{
- var command = AsyncPublishCommand.Create(_pool, GetCommandTimer(cancellationToken), subject, replyTo, data, opts?.Serializer ?? Options.Serializer);
+ var command = AsyncPublishCommand.Create(_pool, GetCommandTimer(cancellationToken), subject, replyTo, headers, data, serializer);
if (TryEnqueueCommand(command))
{
return command.AsValueTask();
@@ -55,9 +64,9 @@ public ValueTask PublishAsync(string subject, T data, in NatsPubOpts? opts =
}
else
{
- return WithConnectAsync(subject, replyTo, data, cancellationToken, static (self, s, r, v, token) =>
+ return WithConnectAsync(subject, replyTo, headers, data, serializer, cancellationToken, static (self, s, r, h, v, ser, token) =>
{
- var command = AsyncPublishCommand.Create(self._pool, self.GetCommandTimer(token), s, r, v, self.Options.Serializer);
+ var command = AsyncPublishCommand.Create(self._pool, self.GetCommandTimer(token), s, r, h, v, ser);
return self.EnqueueAndAwaitCommandAsync(command);
});
}
diff --git a/src/NATS.Client.Core/NatsConnection.cs b/src/NATS.Client.Core/NatsConnection.cs
index 75cb7ec3f..af53645bf 100644
--- a/src/NATS.Client.Core/NatsConnection.cs
+++ b/src/NATS.Client.Core/NatsConnection.cs
@@ -74,6 +74,7 @@ public NatsConnection(NatsOptions options)
InboxPrefix = Encoding.ASCII.GetBytes($"{options.InboxPrefix}{Guid.NewGuid()}.");
_logger = options.LoggerFactory.CreateLogger();
_clientOptions = new ClientOptions(Options);
+ HeaderParser = new HeaderParser(options.HeaderEncoding);
}
// events
@@ -89,6 +90,8 @@ public NatsConnection(NatsOptions options)
public ServerInfo? ServerInfo { get; internal set; } // server info is set when received INFO
+ internal HeaderParser HeaderParser { get; }
+
///
/// Connect socket and write CONNECT command to nats server.
///
@@ -162,9 +165,9 @@ internal void EnqueuePing(AsyncPingCommand pingCommand)
pingCommand.SetCanceled();
}
- internal ValueTask PublishToClientHandlersAsync(string subject, string? replyTo, int sid, in ReadOnlySequence buffer)
+ internal ValueTask PublishToClientHandlersAsync(string subject, string? replyTo, int sid, in ReadOnlySequence? headersBuffer, in ReadOnlySequence payloadBuffer)
{
- return _subscriptionManager.PublishToClientHandlersAsync(subject, replyTo, sid, buffer);
+ return _subscriptionManager.PublishToClientHandlersAsync(subject, replyTo, sid, headersBuffer, payloadBuffer);
}
internal void ResetPongCount()
@@ -730,6 +733,18 @@ private async ValueTask WithConnectAsync(T1 item1, T2 item2, T3
await coreAsync(this, item1, item2, item3, item4).ConfigureAwait(false);
}
+ private async ValueTask WithConnectAsync(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5, Func coreAsync)
+ {
+ await ConnectAsync().ConfigureAwait(false);
+ await coreAsync(this, item1, item2, item3, item4, item5).ConfigureAwait(false);
+ }
+
+ private async ValueTask WithConnectAsync(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5, T6 item6, Func coreAsync)
+ {
+ await ConnectAsync().ConfigureAwait(false);
+ await coreAsync(this, item1, item2, item3, item4, item5, item6).ConfigureAwait(false);
+ }
+
private async ValueTask WithConnectAsync(Func> coreAsync)
{
await ConnectAsync().ConfigureAwait(false);
diff --git a/src/NATS.Client.Core/NatsHeaders.cs b/src/NATS.Client.Core/NatsHeaders.cs
index 82d84b3f4..ff2b4591c 100644
--- a/src/NATS.Client.Core/NatsHeaders.cs
+++ b/src/NATS.Client.Core/NatsHeaders.cs
@@ -21,6 +21,8 @@ public class NatsHeaders : IDictionary
private static readonly IEnumerator> EmptyIEnumeratorType = default(Enumerator);
private static readonly IEnumerator EmptyIEnumerator = default(Enumerator);
+ private int _readonly = 0;
+
///
/// Initializes a new instance of .
///
@@ -117,7 +119,7 @@ StringValues IDictionary.this[string key]
/// Gets a value that indicates whether the is in read-only mode.
///
/// true if the is in read-only mode; otherwise, false.
- public bool IsReadOnly { get; set; }
+ public bool IsReadOnly => Volatile.Read(ref _readonly) == 1;
///
/// Gets the collection of HTTP header names in this instance.
@@ -331,6 +333,8 @@ IEnumerator IEnumerable.GetEnumerator()
return Store.GetEnumerator();
}
+ internal void SetReadOnly() => Interlocked.Exchange(ref _readonly, 1);
+
private void ThrowIfReadOnly()
{
if (IsReadOnly)
diff --git a/src/NATS.Client.Core/NatsMsg.cs b/src/NATS.Client.Core/NatsMsg.cs
index 34921197b..f3b84413e 100644
--- a/src/NATS.Client.Core/NatsMsg.cs
+++ b/src/NATS.Client.Core/NatsMsg.cs
@@ -13,12 +13,8 @@ public abstract record NatsMsgBase(string Subject)
public string? ReplyTo { get; init; }
- // TODO: Implement headers in NatsMsg
- // public NatsHeaders? Headers
- // {
- // get => throw new NotImplementedException();
- // set => throw new NotImplementedException();
- // }
+ public NatsHeaders? Headers { get; init; }
+
public ValueTask ReplyAsync(ReadOnlySequence data = default, in NatsPubOpts? opts = default, CancellationToken cancellationToken = default)
{
CheckReplyPreconditions();
diff --git a/src/NATS.Client.Core/NatsOptions.cs b/src/NATS.Client.Core/NatsOptions.cs
index 776a4ed42..047c05b42 100644
--- a/src/NATS.Client.Core/NatsOptions.cs
+++ b/src/NATS.Client.Core/NatsOptions.cs
@@ -1,3 +1,4 @@
+using System.Text;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
@@ -10,6 +11,7 @@ namespace NATS.Client.Core;
///
///
///
+///
///
///
///
@@ -29,12 +31,14 @@ namespace NATS.Client.Core;
///
///
///
+///
public sealed record NatsOptions
(
string Url,
string Name,
bool Echo,
bool Verbose,
+ bool Headers,
NatsAuthOptions AuthOptions,
TlsOptions TlsOptions,
INatsSerializer Serializer,
@@ -53,13 +57,15 @@ public sealed record NatsOptions
TimeSpan RequestTimeout,
TimeSpan CommandTimeout,
TimeSpan SubscriptionCleanUpInterval,
- int? WriterCommandBufferLimit)
+ int? WriterCommandBufferLimit,
+ Encoding HeaderEncoding)
{
public static readonly NatsOptions Default = new(
Url: "nats://localhost:4222",
Name: "NATS .Net Client",
Echo: true,
Verbose: false,
+ Headers: true,
AuthOptions: NatsAuthOptions.Default,
TlsOptions: TlsOptions.Default,
Serializer: JsonNatsSerializer.Default,
@@ -78,7 +84,8 @@ public sealed record NatsOptions
RequestTimeout: TimeSpan.FromMinutes(1),
CommandTimeout: TimeSpan.FromMinutes(1),
SubscriptionCleanUpInterval: TimeSpan.FromMinutes(5),
- WriterCommandBufferLimit: null);
+ WriterCommandBufferLimit: null,
+ HeaderEncoding: Encoding.ASCII);
internal NatsUri[] GetSeedUris()
{
diff --git a/src/NATS.Client.Core/NatsPubOpts.cs b/src/NATS.Client.Core/NatsPubOpts.cs
index ffcef27c8..dbac18df0 100644
--- a/src/NATS.Client.Core/NatsPubOpts.cs
+++ b/src/NATS.Client.Core/NatsPubOpts.cs
@@ -4,11 +4,7 @@ public readonly record struct NatsPubOpts
{
public string? ReplyTo { get; init; }
- // TODO: Implement headers in NatsPubOpts
- // public NatsHeaders? Headers
- // {
- // get => throw new NotImplementedException();
- // init => throw new NotImplementedException();
- // }
+ public NatsHeaders? Headers { get; init; }
+
public INatsSerializer? Serializer { get; init; }
}
diff --git a/src/NATS.Client.Core/NatsReadProtocolProcessor.cs b/src/NATS.Client.Core/NatsReadProtocolProcessor.cs
index d776ebbde..a2a788caf 100644
--- a/src/NATS.Client.Core/NatsReadProtocolProcessor.cs
+++ b/src/NATS.Client.Core/NatsReadProtocolProcessor.cs
@@ -1,6 +1,7 @@
using System.Buffers;
using System.Buffers.Text;
using System.Collections.Concurrent;
+using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
@@ -200,7 +201,7 @@ private async Task ReadLoopAsync()
buffer = buffer.Slice(buffer.GetPosition(3, positionBeforePayload.Value));
}
- await _connection.PublishToClientHandlersAsync(subject, replyTo, sid, ReadOnlySequence.Empty).ConfigureAwait(false);
+ await _connection.PublishToClientHandlersAsync(subject, replyTo, sid, null, ReadOnlySequence.Empty).ConfigureAwait(false);
}
else
{
@@ -221,9 +222,62 @@ private async Task ReadLoopAsync()
buffer = buffer.Slice(buffer.GetPosition(2, payloadSlice.End)); // payload + \r\n
- await _connection.PublishToClientHandlersAsync(subject, replyTo, sid, payloadSlice).ConfigureAwait(false);
+ await _connection.PublishToClientHandlersAsync(subject, replyTo, sid, null, payloadSlice).ConfigureAwait(false);
}
}
+ else if (code == ServerOpCodes.HMsg)
+ {
+ // https://docs.nats.io/reference/reference-protocols/nats-protocol#hmsg
+ // HMSG [reply-to] <#header bytes> <#total bytes>\r\n[headers]\r\n\r\n[payload]\r\n
+
+ // Find the end of 'HMSG' first message line
+ var positionBeforeNatsHeader = buffer.PositionOf((byte)'\n');
+ if (positionBeforeNatsHeader == null)
+ {
+ _socketReader.AdvanceTo(buffer.Start);
+ buffer = await _socketReader.ReadUntilReceiveNewLineAsync().ConfigureAwait(false);
+ positionBeforeNatsHeader = buffer.PositionOf((byte)'\n')!;
+ }
+
+ var msgHeader = buffer.Slice(0, positionBeforeNatsHeader.Value);
+ var (subject, sid, replyTo, headersLength, totalLength) = ParseHMessageHeader(msgHeader);
+ var payloadLength = totalLength - headersLength;
+ Debug.Assert(payloadLength >= 0, "Protocol error: illogical header and total lengths");
+
+ var headerBegin = buffer.GetPosition(1, positionBeforeNatsHeader.Value);
+ var totalSlice = buffer.Slice(headerBegin);
+
+ // Read rest of the message if it's not already in the buffer
+ if (totalSlice.Length < totalLength + 2)
+ {
+ _socketReader.AdvanceTo(headerBegin);
+
+ // Read headers + payload + \r\n
+ var size = totalLength - (int)totalSlice.Length + 2;
+ buffer = await _socketReader.ReadAtLeastAsync(size).ConfigureAwait(false);
+ totalSlice = buffer.Slice(0, totalLength);
+ }
+ else
+ {
+ totalSlice = totalSlice.Slice(0, totalLength);
+ }
+
+ // Prepare buffer for the next message by removing 'headers + payload + \r\n' from it
+ buffer = buffer.Slice(buffer.GetPosition(2, totalSlice.End));
+
+ var versionLength = CommandConstants.NatsHeaders10NewLine.Length;
+ var versionSlice = totalSlice.Slice(0, versionLength);
+ if (!versionSlice.ToSpan().SequenceEqual(CommandConstants.NatsHeaders10NewLine))
+ {
+ throw new NatsException("Protocol error: header version mismatch");
+ }
+
+ var headerSlice = totalSlice.Slice(versionLength, headersLength - versionLength);
+ var payloadSlice = totalSlice.Slice(headersLength, payloadLength);
+
+ await _connection.PublishToClientHandlersAsync(subject, replyTo, sid, headerSlice, payloadSlice)
+ .ConfigureAwait(false);
+ }
else
{
buffer = await DispatchCommandAsync(code, buffer).ConfigureAwait(false);
@@ -436,11 +490,65 @@ private async ValueTask> DispatchCommandAsync(int code, R
return ParseMessageHeader(buffer);
}
+ // https://docs.nats.io/reference/reference-protocols/nats-protocol#hmsg
+ // HMSG [reply-to] <#header bytes> <#total bytes>\r\n[headers]\r\n\r\n[payload]\r\n
+ private (string subject, int sid, string? replyTo, int headersLength, int totalLength) ParseHMessageHeader(ReadOnlySpan msgHeader)
+ {
+ // 'HMSG' literal
+ Split(msgHeader, out _, out msgHeader);
+
+ Split(msgHeader, out var subjectBytes, out msgHeader);
+ Split(msgHeader, out var sidBytes, out msgHeader);
+ Split(msgHeader, out var replyToOrHeaderLenBytes, out msgHeader);
+ Split(msgHeader, out var headerLenOrTotalLenBytes, out msgHeader);
+
+ var subject = Encoding.ASCII.GetString(subjectBytes);
+ var sid = GetInt32(sidBytes);
+
+ // We don't have the optional reply-to field
+ if (msgHeader.Length == 0)
+ {
+ var headersLength = GetInt32(replyToOrHeaderLenBytes);
+ var totalLen = GetInt32(headerLenOrTotalLenBytes);
+ return (subject, sid, null, headersLength, totalLen);
+ }
+
+ // There is more data because of the reply-to field
+ else
+ {
+ var replyToBytes = replyToOrHeaderLenBytes;
+ var replyTo = Encoding.ASCII.GetString(replyToBytes);
+
+ var headerLen = GetInt32(headerLenOrTotalLenBytes);
+
+ var lastBytes = msgHeader;
+ var totalLen = GetInt32(lastBytes);
+
+ return (subject, sid, replyTo, headerLen, totalLen);
+ }
+ }
+
+ private (string subject, int sid, string? replyTo, int headersLength, int totalLength) ParseHMessageHeader(in ReadOnlySequence msgHeader)
+ {
+ if (msgHeader.IsSingleSegment)
+ {
+ return ParseHMessageHeader(msgHeader.FirstSpan);
+ }
+
+ // header parsing use Slice frequently so ReadOnlySequence is high cost, should use Span.
+ // msgheader is not too long, ok to use stackalloc.
+ // TODO: Fix possible stack overflow
+ Span buffer = stackalloc byte[(int)msgHeader.Length];
+ msgHeader.CopyTo(buffer);
+ return ParseHMessageHeader(buffer);
+ }
+
internal static class ServerOpCodes
{
// All sent by server commands as int(first 4 characters(includes space, newline)).
public const int Info = 1330007625; // Encoding.ASCII.GetBytes("INFO") |> MemoryMarshal.Read
public const int Msg = 541545293; // Encoding.ASCII.GetBytes("MSG ") |> MemoryMarshal.Read
+ public const int HMsg = 1196641608; // Encoding.ASCII.GetBytes("HMSG") |> MemoryMarshal.Read
public const int Ping = 1196312912; // Encoding.ASCII.GetBytes("PING") |> MemoryMarshal.Read
public const int Pong = 1196314448; // Encoding.ASCII.GetBytes("PONG") |> MemoryMarshal.Read
public const int Ok = 223039275; // Encoding.ASCII.GetBytes("+OK\r") |> MemoryMarshal.Read
diff --git a/src/NATS.Client.Core/NatsSub.cs b/src/NATS.Client.Core/NatsSub.cs
index b7be73839..77cd2d1b3 100644
--- a/src/NATS.Client.Core/NatsSub.cs
+++ b/src/NATS.Client.Core/NatsSub.cs
@@ -1,5 +1,7 @@
using System.Buffers;
+using System.Text;
using System.Threading.Channels;
+using NATS.Client.Core.Internal;
namespace NATS.Client.Core;
@@ -29,7 +31,7 @@ public virtual ValueTask DisposeAsync()
return Manager.RemoveAsync(Sid);
}
- internal abstract ValueTask ReceiveAsync(string subject, string? replyTo, ReadOnlySequence buffer);
+ internal abstract ValueTask ReceiveAsync(string subject, string? replyTo, in ReadOnlySequence? headersBuffer, in ReadOnlySequence payloadBuffer);
}
public sealed class NatsSub : NatsSubBase
@@ -55,12 +57,25 @@ public override ValueTask DisposeAsync()
return base.DisposeAsync();
}
- internal override ValueTask ReceiveAsync(string subject, string? replyTo, ReadOnlySequence buffer)
+ internal override ValueTask ReceiveAsync(string subject, string? replyTo, in ReadOnlySequence? headersBuffer, in ReadOnlySequence payloadBuffer)
{
- return _msgs.Writer.WriteAsync(new NatsMsg(subject, buffer.ToArray())
+ NatsHeaders? natsHeaders = null;
+ if (headersBuffer != null)
+ {
+ natsHeaders = new NatsHeaders();
+ if (!Connection.HeaderParser.ParseHeaders(new SequenceReader(headersBuffer.Value), natsHeaders))
+ {
+ throw new NatsException("Error parsing headers");
+ }
+
+ natsHeaders.SetReadOnly();
+ }
+
+ return _msgs.Writer.WriteAsync(new NatsMsg(subject, payloadBuffer.ToArray())
{
Connection = Connection,
ReplyTo = replyTo,
+ Headers = natsHeaders,
});
}
}
@@ -88,14 +103,28 @@ public override ValueTask DisposeAsync()
return base.DisposeAsync();
}
- internal override ValueTask ReceiveAsync(string subject, string? replyTo, ReadOnlySequence buffer)
+ internal override ValueTask ReceiveAsync(string subject, string? replyTo, in ReadOnlySequence? headersBuffer, in ReadOnlySequence payloadBuffer)
{
var serializer = Serializer;
- var data = serializer.Deserialize(buffer);
+ var data = serializer.Deserialize(payloadBuffer);
+
+ NatsHeaders? natsHeaders = null;
+ if (headersBuffer != null)
+ {
+ natsHeaders = new NatsHeaders();
+ if (!Connection.HeaderParser.ParseHeaders(new SequenceReader(headersBuffer.Value), natsHeaders))
+ {
+ throw new NatsException("Error parsing headers");
+ }
+
+ natsHeaders.SetReadOnly();
+ }
+
return _msgs.Writer.WriteAsync(new NatsMsg(subject, data!)
{
Connection = Connection,
ReplyTo = replyTo,
+ Headers = natsHeaders,
});
}
}
diff --git a/src/NATS.Client.Core/SubscriptionManager.cs b/src/NATS.Client.Core/SubscriptionManager.cs
index 3049a73d0..50d5b8b97 100644
--- a/src/NATS.Client.Core/SubscriptionManager.cs
+++ b/src/NATS.Client.Core/SubscriptionManager.cs
@@ -56,7 +56,7 @@ public async ValueTask> AddAsync(string subject, string? queueGrou
return sub;
}
- public ValueTask PublishToClientHandlersAsync(string subject, string? replyTo, int sid, in ReadOnlySequence buffer)
+ public ValueTask PublishToClientHandlersAsync(string subject, string? replyTo, int sid, in ReadOnlySequence? headersBuffer, in ReadOnlySequence payloadBuffer)
{
int? orphanSid = null;
lock (_gate)
@@ -65,7 +65,7 @@ public ValueTask PublishToClientHandlersAsync(string subject, string? replyTo, i
{
if (subRef.TryGetTarget(out var sub))
{
- return sub.ReceiveAsync(subject, replyTo, buffer);
+ return sub.ReceiveAsync(subject, replyTo, headersBuffer, payloadBuffer);
}
else
{
diff --git a/tests/NATS.Client.Core.Tests/NatsConnectionTest.Auth.cs b/tests/NATS.Client.Core.Tests/NatsConnectionTest.Auth.cs
index 7714d4235..5275bed9d 100644
--- a/tests/NATS.Client.Core.Tests/NatsConnectionTest.Auth.cs
+++ b/tests/NATS.Client.Core.Tests/NatsConnectionTest.Auth.cs
@@ -117,7 +117,7 @@ public async Task UserCredentialAuthTest(string name, string serverConfig, NatsO
var signalComplete2 = new WaitSignal();
var natsSub = await subConnection.SubscribeAsync(subject);
- natsSub.Register(x =>
+ var register = natsSub.Register(x =>
{
_output.WriteLine($"Received: {x}");
if (x.Data == 1)
@@ -148,34 +148,35 @@ public async Task UserCredentialAuthTest(string name, string serverConfig, NatsO
_output.WriteLine("AUTHENTICATED RE-CONNECTION");
await pubConnection.PublishAsync(subject, 2);
await signalComplete2;
+
+ await natsSub.DisposeAsync();
+ await register;
}
}
internal static class NatsMsgTestUtils
{
- internal static NatsSub? Register(this NatsSub? sub, Action> action)
+ internal static Task Register(this NatsSub? sub, Action> action)
{
- if (sub == null) return null;
- Task.Run(async () =>
+ if (sub == null) return Task.CompletedTask;
+ return Task.Run(async () =>
{
await foreach (var natsMsg in sub.Msgs.ReadAllAsync())
{
action(natsMsg);
}
});
- return sub;
}
- internal static NatsSub? Register(this NatsSub? sub, Action action)
+ internal static Task Register(this NatsSub? sub, Action action)
{
- if (sub == null) return null;
- Task.Run(async () =>
+ if (sub == null) return Task.CompletedTask;
+ return Task.Run(async () =>
{
await foreach (var natsMsg in sub.Msgs.ReadAllAsync())
{
action(natsMsg);
}
});
- return sub;
}
}
diff --git a/tests/NATS.Client.Core.Tests/NatsConnectionTest.Headers.cs b/tests/NATS.Client.Core.Tests/NatsConnectionTest.Headers.cs
new file mode 100644
index 000000000..e00a3424e
--- /dev/null
+++ b/tests/NATS.Client.Core.Tests/NatsConnectionTest.Headers.cs
@@ -0,0 +1,76 @@
+namespace NATS.Client.Core.Tests;
+
+public abstract partial class NatsConnectionTest
+{
+ [Fact]
+ public async Task HeaderParsingTest()
+ {
+ await using var server = new NatsServer(_output, _transportType);
+
+ await using var nats = server.CreateClientConnection();
+
+ var sync = 0;
+ var signal1 = new WaitSignal>();
+ var signal2 = new WaitSignal>();
+ var sub = await nats.SubscribeAsync("foo");
+ var reg = sub.Register(m =>
+ {
+ if (m.Data < 10)
+ {
+ Interlocked.Exchange(ref sync, m.Data);
+ return;
+ }
+
+ if (m.Data == 100)
+ signal1.Pulse(m);
+ if (m.Data == 200)
+ signal2.Pulse(m);
+ });
+
+ await Retry.Until(
+ "subscription is active",
+ () => Volatile.Read(ref sync) == 1,
+ async () => await nats.PublishAsync("foo", 1));
+
+ var headers = new NatsHeaders
+ {
+ ["Test-Header-Key"] = "test-header-value",
+ ["Multi"] = new[] { "multi-value-0", "multi-value-1" },
+ };
+ Assert.False(headers.IsReadOnly);
+
+ // Send with headers
+ await nats.PublishAsync("foo", 100, new NatsPubOpts { Headers = headers });
+
+ Assert.True(headers.IsReadOnly);
+ Assert.Throws(() =>
+ {
+ headers["should-not-set"] = "value";
+ });
+
+ var msg1 = await signal1;
+ Assert.Equal(100, msg1.Data);
+ Assert.NotNull(msg1.Headers);
+ Assert.Equal(2, msg1.Headers!.Count);
+
+ Assert.True(msg1.Headers!.ContainsKey("Test-Header-Key"));
+ Assert.Single(msg1.Headers["Test-Header-Key"].ToArray());
+ Assert.Equal("test-header-value", msg1.Headers["Test-Header-Key"]);
+
+ Assert.True(msg1.Headers!.ContainsKey("Multi"));
+ Assert.Equal(2, msg1.Headers["Multi"].Count);
+ Assert.Equal("multi-value-0", msg1.Headers["Multi"][0]);
+ Assert.Equal("multi-value-1", msg1.Headers["Multi"][1]);
+
+ // Send empty headers
+ await nats.PublishAsync("foo", 200, new NatsPubOpts { Headers = new NatsHeaders() });
+
+ var msg2 = await signal2;
+ Assert.Equal(200, msg2.Data);
+ Assert.NotNull(msg2.Headers);
+ Assert.Empty(msg2.Headers!);
+
+ await sub.DisposeAsync();
+ await reg;
+ }
+}
diff --git a/tests/NATS.Client.Core.Tests/NatsConnectionTest.QueueGroups.cs b/tests/NATS.Client.Core.Tests/NatsConnectionTest.QueueGroups.cs
index 22a01c1cb..6a785ff50 100644
--- a/tests/NATS.Client.Core.Tests/NatsConnectionTest.QueueGroups.cs
+++ b/tests/NATS.Client.Core.Tests/NatsConnectionTest.QueueGroups.cs
@@ -5,7 +5,8 @@ public abstract partial class NatsConnectionTest
[Fact]
public async Task QueueGroupsTest()
{
- const int messageCount = 20;
+ // Use high enough count to create some distribution among subscribers.
+ const int messageCount = 100;
await using var server = new NatsServer(_output, _transportType);
@@ -21,12 +22,19 @@ public async Task QueueGroupsTest()
cts.Token.Register(() => signal.Pulse());
var count = 0;
+ var sync1 = 0;
var messages1 = new List();
var reader1 = Task.Run(
async () =>
{
await foreach (var msg in sub1.Msgs.ReadAllAsync(cts.Token))
{
+ if (msg.Subject == "foo.sync")
+ {
+ Interlocked.Exchange(ref sync1, 1);
+ continue;
+ }
+
Assert.Equal($"foo.xyz{msg.Data}", msg.Subject);
lock (messages1) messages1.Add(msg.Data);
var total = Interlocked.Increment(ref count);
@@ -35,12 +43,19 @@ public async Task QueueGroupsTest()
},
cts.Token);
+ var sync2 = 0;
var messages2 = new List();
var reader2 = Task.Run(
async () =>
{
await foreach (var msg in sub2.Msgs.ReadAllAsync(cts.Token))
{
+ if (msg.Subject == "foo.sync")
+ {
+ Interlocked.Exchange(ref sync2, 1);
+ continue;
+ }
+
Assert.Equal($"foo.xyz{msg.Data}", msg.Subject);
lock (messages2) messages2.Add(msg.Data);
var total = Interlocked.Increment(ref count);
@@ -49,8 +64,10 @@ public async Task QueueGroupsTest()
},
cts.Token);
- await conn1.PingAsync();
- await conn2.PingAsync();
+ await Retry.Until(
+ "subscriptions are active",
+ () => Volatile.Read(ref sync1) + Volatile.Read(ref sync2) == 2,
+ async () => await conn3.PublishAsync("foo.sync", 0));
for (int i = 0; i < messageCount; i++)
{
diff --git a/tests/NATS.Client.Core.Tests/NatsConnectionTest.Sharding.cs b/tests/NATS.Client.Core.Tests/NatsConnectionTest.Sharding.cs
index 1b835541a..230036f19 100644
--- a/tests/NATS.Client.Core.Tests/NatsConnectionTest.Sharding.cs
+++ b/tests/NATS.Client.Core.Tests/NatsConnectionTest.Sharding.cs
@@ -35,9 +35,12 @@ public async Task ShardingConnectionTest()
var l1 = new List();
var l2 = new List();
var l3 = new List();
- (await shardedConnection.GetCommand("foo").SubscribeAsync()).Register(msg => l1.Add(msg.Data));
- (await shardedConnection.GetCommand("bar").SubscribeAsync()).Register(msg => l2.Add(msg.Data));
- (await shardedConnection.GetCommand("baz").SubscribeAsync()).Register(msg => l3.Add(msg.Data));
+ var sub1 = await shardedConnection.GetCommand("foo").SubscribeAsync();
+ var reg1 = sub1.Register(msg => l1.Add(msg.Data));
+ var sub2 = await shardedConnection.GetCommand("bar").SubscribeAsync();
+ var reg2 = sub2.Register(msg => l2.Add(msg.Data));
+ var sub3 = await shardedConnection.GetCommand("baz").SubscribeAsync();
+ var reg3 = sub3.Register(msg => l3.Add(msg.Data));
await shardedConnection.GetCommand("foo").PublishAsync(10);
await shardedConnection.GetCommand("bar").PublishAsync(20);
@@ -54,5 +57,12 @@ public async Task ShardingConnectionTest()
var r = await shardedConnection.GetCommand("foobarbaz").RequestAsync(100);
r.ShouldBe(10000);
+
+ await sub1.DisposeAsync();
+ await reg1;
+ await sub2.DisposeAsync();
+ await reg2;
+ await sub3.DisposeAsync();
+ await reg3;
}
}
diff --git a/tests/NATS.Client.Core.Tests/NatsConnectionTest.cs b/tests/NATS.Client.Core.Tests/NatsConnectionTest.cs
index d3b714faa..6916c930d 100644
--- a/tests/NATS.Client.Core.Tests/NatsConnectionTest.cs
+++ b/tests/NATS.Client.Core.Tests/NatsConnectionTest.cs
@@ -28,8 +28,8 @@ public async Task SimplePubSubTest()
var signalComplete = new WaitSignal();
var list = new List();
- await using var sub = await subConnection.SubscribeAsync(subject);
- sub.Register(x =>
+ var sub = await subConnection.SubscribeAsync(subject);
+ var register = sub.Register(x =>
{
_output.WriteLine($"Received: {x.Data}");
list.Add(x.Data);
@@ -46,6 +46,8 @@ public async Task SimplePubSubTest()
}
await signalComplete;
+ await sub.DisposeAsync();
+ await register;
list.ShouldEqual(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
}
@@ -67,7 +69,8 @@ public async Task EncodingTest()
var actual = new List();
var signalComplete = new WaitSignal();
- await using var d = (await subConnection.SubscribeAsync(key)).Register(x =>
+ var sub = await subConnection.SubscribeAsync(key);
+ var register = sub.Register(x =>
{
actual.Add(x.Data);
if (x.Data.Id == 30)
@@ -83,6 +86,8 @@ public async Task EncodingTest()
await pubConnection.PublishAsync(key, three);
await signalComplete;
+ await sub.DisposeAsync();
+ await register;
actual.ShouldEqual(new[] { one, two, three });
}
@@ -102,14 +107,25 @@ public async Task RequestTest(int minSize)
var subject = Guid.NewGuid().ToString();
var text = new StringBuilder(minSize).Insert(0, "a", minSize).ToString();
+ var sync = 0;
await using var replyHandle = await subConnection.ReplyAsync(subject, x =>
{
+ if (x < 10)
+ {
+ Interlocked.Exchange(ref sync, x);
+ return "sync";
+ }
+
if (x == 100)
throw new Exception();
return text + x;
});
- await Task.Delay(1000);
+ await Retry.Until(
+ "reply handle is ready",
+ () => Volatile.Read(ref sync) == 1,
+ async () => await pubConnection.PublishAsync(subject, 1, new NatsPubOpts { ReplyTo = "ignore" }),
+ retryDelay: TimeSpan.FromSeconds(1));
var v = await pubConnection.RequestAsync(subject, 9999);
v.Should().Be(text + 9999);
@@ -137,7 +153,7 @@ public async Task ReconnectSingleTest()
ServerDisposeReturnsPorts = false,
};
await using var server = new NatsServer(_output, _transportType, options);
- var key = Guid.NewGuid().ToString();
+ var subject = Guid.NewGuid().ToString();
await using var subConnection = server.CreateClientConnection();
await using var pubConnection = server.CreateClientConnection();
@@ -145,10 +161,18 @@ public async Task ReconnectSingleTest()
await pubConnection.ConnectAsync(); // wait open
var list = new List();
+ var sync = 0;
var waitForReceive300 = new WaitSignal();
var waitForReceiveFinish = new WaitSignal();
- var d = (await subConnection.SubscribeAsync(key)).Register(x =>
+ var sub = await subConnection.SubscribeAsync(subject);
+ var reg = sub.Register(x =>
{
+ if (x.Data < 10)
+ {
+ Interlocked.Exchange(ref sync, x.Data);
+ return;
+ }
+
_output.WriteLine("RECEIVED: " + x.Data);
list.Add(x.Data);
if (x.Data == 300)
@@ -161,11 +185,15 @@ public async Task ReconnectSingleTest()
waitForReceiveFinish.Pulse();
}
});
- await subConnection.PingAsync(); // wait for subscribe complete
- await pubConnection.PublishAsync(key, 100);
- await pubConnection.PublishAsync(key, 200);
- await pubConnection.PublishAsync(key, 300);
+ await Retry.Until(
+ "subscription is active (1)",
+ () => Volatile.Read(ref sync) == 1,
+ async () => await pubConnection.PublishAsync(subject, 1));
+
+ await pubConnection.PublishAsync(subject, 100);
+ await pubConnection.PublishAsync(subject, 200);
+ await pubConnection.PublishAsync(subject, 300);
_output.WriteLine("TRY WAIT RECEIVE 300");
await waitForReceive300;
@@ -184,11 +212,19 @@ public async Task ReconnectSingleTest()
await subConnection.ConnectAsync(); // wait open again
await pubConnection.ConnectAsync(); // wait open again
+ await Retry.Until(
+ "subscription is active (2)",
+ () => Volatile.Read(ref sync) == 2,
+ async () => await pubConnection.PublishAsync(subject, 2));
+
_output.WriteLine("RECONNECT COMPLETE, PUBLISH 400 and 500");
- await pubConnection.PublishAsync(key, 400);
- await pubConnection.PublishAsync(key, 500);
+ await pubConnection.PublishAsync(subject, 400);
+ await pubConnection.PublishAsync(subject, 500);
await waitForReceiveFinish;
+ await sub.DisposeAsync();
+ await reg;
+
list.ShouldEqual(100, 200, 300, 400, 500);
}
@@ -198,7 +234,7 @@ public async Task ReconnectClusterTest()
await using var cluster = new NatsCluster(_output, _transportType);
await Task.Delay(TimeSpan.FromSeconds(5)); // wait for cluster completely connected.
- var key = Guid.NewGuid().ToString();
+ var subject = Guid.NewGuid().ToString();
await using var connection1 = cluster.Server1.CreateClientConnection();
await using var connection2 = cluster.Server2.CreateClientConnection();
@@ -220,10 +256,18 @@ public async Task ReconnectClusterTest()
connection3.ServerInfo!.ClientConnectUrls!.Select(x => new NatsUri(x, true).Port).Distinct().Count().ShouldBe(3);
var list = new List();
+ var sync = 0;
var waitForReceive300 = new WaitSignal();
var waitForReceiveFinish = new WaitSignal();
- var d = (await connection1.SubscribeAsync(key)).Register(x =>
+ var sub = await connection1.SubscribeAsync(subject);
+ var reg = sub.Register(x =>
{
+ if (x.Data < 10)
+ {
+ Interlocked.Exchange(ref sync, x.Data);
+ return;
+ }
+
_output.WriteLine("RECEIVED: " + x.Data);
list.Add(x.Data);
if (x.Data == 300)
@@ -236,11 +280,16 @@ public async Task ReconnectClusterTest()
waitForReceiveFinish.Pulse();
}
});
- await connection1.PingAsync(); // wait for subscribe complete
- await connection2.PublishAsync(key, 100);
- await connection2.PublishAsync(key, 200);
- await connection2.PublishAsync(key, 300);
+ await Retry.Until(
+ "subscription is active (1)",
+ () => Volatile.Read(ref sync) == 1,
+ async () => await connection2.PublishAsync(subject, 1),
+ retryDelay: TimeSpan.FromSeconds(.5));
+
+ await connection2.PublishAsync(subject, 100);
+ await connection2.PublishAsync(subject, 200);
+ await connection2.PublishAsync(subject, 300);
await waitForReceive300;
var disconnectSignal = connection1.ConnectionDisconnectedAsAwaitable(); // register disconnect before kill
@@ -249,14 +298,25 @@ public async Task ReconnectClusterTest()
await cluster.Server1.DisposeAsync(); // process kill
await disconnectSignal;
+ Net.WaitForTcpPortToClose(cluster.Server1.ConnectionPort);
+
await connection1.ConnectAsync(); // wait for reconnect complete.
- connection1.ServerInfo!.Port.Should()
- .BeOneOf(cluster.Server2.Options.ServerPort, cluster.Server3.Options.ServerPort);
+ connection1.ServerInfo!.Port.Should().BeOneOf(cluster.Server2.ConnectionPort, cluster.Server3.ConnectionPort);
+
+ await Retry.Until(
+ "subscription is active (2)",
+ () => Volatile.Read(ref sync) == 2,
+ async () => await connection2.PublishAsync(subject, 2),
+ retryDelay: TimeSpan.FromSeconds(.5));
- await connection2.PublishAsync(key, 400);
- await connection2.PublishAsync(key, 500);
+ await connection2.PublishAsync(subject, 400);
+ await connection2.PublishAsync(subject, 500);
await waitForReceiveFinish;
+
+ await sub.DisposeAsync();
+ await reg;
+
list.ShouldEqual(100, 200, 300, 400, 500);
}
}
diff --git a/tests/NATS.Client.Core.Tests/NatsHeaderTest.cs b/tests/NATS.Client.Core.Tests/NatsHeaderTest.cs
new file mode 100644
index 000000000..63a9b2b20
--- /dev/null
+++ b/tests/NATS.Client.Core.Tests/NatsHeaderTest.cs
@@ -0,0 +1,69 @@
+using System.Buffers;
+using System.Text;
+
+namespace NATS.Client.Core.Tests;
+
+public class NatsHeaderTest
+{
+ private readonly ITestOutputHelper _output;
+
+ public NatsHeaderTest(ITestOutputHelper output) => _output = output;
+
+ [Fact]
+ public void WriterTests()
+ {
+ var headers = new NatsHeaders
+ {
+ ["k1"] = "v1",
+ ["k2"] = new[] { "v2-0", "v2-1" },
+ ["a-long-header-key"] = "value",
+ ["key"] = "a-long-header-value",
+ };
+ var writer = new HeaderWriter(Encoding.UTF8);
+ var buffer = new FixedArrayBufferWriter();
+ var written = writer.Write(buffer, headers);
+
+ var text = "k1: v1\r\nk2: v2-0\r\nk2: v2-1\r\na-long-header-key: value\r\nkey: a-long-header-value\r\n\r\n";
+ var expected = new Span(Encoding.UTF8.GetBytes(text));
+
+ Assert.Equal(expected.Length, written);
+ Assert.True(expected.SequenceEqual(buffer.WrittenSpan));
+
+#if DEBUG
+ _output.WriteLine($"Buffer:\n{buffer.WrittenSpan.Dump()}");
+#endif
+ }
+
+ [Fact]
+ public void ParserTests()
+ {
+ var parser = new HeaderParser(Encoding.UTF8);
+ var text = "k1: v1\r\nk2: v2-0\r\nk2: v2-1\r\na-long-header-key: value\r\nkey: a-long-header-value\r\n\r\n";
+ var input = new SequenceReader(new ReadOnlySequence(Encoding.UTF8.GetBytes(text)));
+ var headers = new NatsHeaders();
+ parser.ParseHeaders(input, headers);
+
+#if DEBUG
+ _output.WriteLine($"Headers:\n{headers.Dump()}");
+#endif
+
+ Assert.Equal(4, headers.Count);
+
+ Assert.True(headers.ContainsKey("k1"));
+ Assert.Single(headers["k1"].ToArray());
+ Assert.Equal("v1", headers["k1"]);
+
+ Assert.True(headers.ContainsKey("k2"));
+ Assert.Equal(2, headers["k2"].ToArray().Length);
+ Assert.Equal("v2-0", headers["k2"][0]);
+ Assert.Equal("v2-1", headers["k2"][1]);
+
+ Assert.True(headers.ContainsKey("a-long-header-key"));
+ Assert.Single(headers["a-long-header-key"].ToArray());
+ Assert.Equal("value", headers["a-long-header-key"]);
+
+ Assert.True(headers.ContainsKey("key"));
+ Assert.Single(headers["key"].ToArray());
+ Assert.Equal("a-long-header-value", headers["key"]);
+ }
+}
diff --git a/tests/NATS.Client.Core.Tests/ProtocolTest.cs b/tests/NATS.Client.Core.Tests/ProtocolTest.cs
new file mode 100644
index 000000000..3b948646c
--- /dev/null
+++ b/tests/NATS.Client.Core.Tests/ProtocolTest.cs
@@ -0,0 +1,156 @@
+namespace NATS.Client.Core.Tests;
+
+public class ProtocolTest
+{
+ private readonly ITestOutputHelper _output;
+
+ public ProtocolTest(ITestOutputHelper output) => _output = output;
+
+ [Fact]
+ public async Task Subscription_with_same_subject()
+ {
+ await using var server = new NatsServer(_output, TransportType.Tcp);
+ var nats1 = server.CreateClientConnection();
+ var (nats2, proxy) = server.CreateProxiedClientConnection();
+
+ var sub1 = await nats2.SubscribeAsync("foo.bar");
+ var sub2 = await nats2.SubscribeAsync("foo.bar");
+ var sub3 = await nats2.SubscribeAsync("foo.baz");
+
+ var sync1 = 0;
+ var sync2 = 0;
+ var sync3 = 0;
+ var count = new WaitSignal(3);
+
+ var reg1 = sub1.Register(m =>
+ {
+ if (m.Data == 0)
+ {
+ Interlocked.Exchange(ref sync1, 1);
+ return;
+ }
+
+ count.Pulse(m.Subject == "foo.bar" ? null : new Exception($"Subject mismatch {m.Subject}"));
+ });
+
+ var reg2 = sub2.Register(m =>
+ {
+ if (m.Data == 0)
+ {
+ Interlocked.Exchange(ref sync2, 1);
+ return;
+ }
+
+ count.Pulse(m.Subject == "foo.bar" ? null : new Exception($"Subject mismatch {m.Subject}"));
+ });
+
+ var reg3 = sub3.Register(m =>
+ {
+ if (m.Data == 0)
+ {
+ Interlocked.Exchange(ref sync3, 1);
+ return;
+ }
+
+ count.Pulse(m.Subject == "foo.baz" ? null : new Exception($"Subject mismatch {m.Subject}"));
+ });
+
+ // Since subscription and publishing are sent through different connections there is
+ // a race where one or more subscriptions are made after the publishing happens.
+ // So, we make sure subscribers are accepted by the server before we send any test data.
+ await Retry.Until(
+ "all subscriptions are active",
+ () => Volatile.Read(ref sync1) + Volatile.Read(ref sync2) + Volatile.Read(ref sync3) == 3,
+ async () =>
+ {
+ await nats1.PublishAsync("foo.bar", 0);
+ await nats1.PublishAsync("foo.baz", 0);
+ });
+
+ await nats1.PublishAsync("foo.bar", 1);
+ await nats1.PublishAsync("foo.baz", 1);
+
+ // Wait until we received all test data
+ await count;
+
+ var frames = proxy.ClientFrames.OrderBy(f => f.Message).ToList();
+
+ foreach (var frame in frames)
+ {
+ _output.WriteLine($"[PROXY] {frame}");
+ }
+
+ Assert.Equal(3, frames.Count);
+ Assert.StartsWith("SUB foo.bar", frames[0].Message);
+ Assert.StartsWith("SUB foo.bar", frames[1].Message);
+ Assert.StartsWith("SUB foo.baz", frames[2].Message);
+ Assert.False(frames[0].Message.Equals(frames[1].Message), "Should have different SIDs");
+
+ await sub1.DisposeAsync();
+ await reg1;
+ await sub2.DisposeAsync();
+ await reg2;
+ await sub3.DisposeAsync();
+ await reg3;
+ await nats1.DisposeAsync();
+ await nats2.DisposeAsync();
+ proxy.Dispose();
+ }
+
+ [Fact]
+ public async Task Publish_empty_message_for_notifications()
+ {
+ await using var server = new NatsServer(_output, TransportType.Tcp);
+ var (nats, proxy) = server.CreateProxiedClientConnection();
+
+ var sync = 0;
+ var signal1 = new WaitSignal();
+ var signal2 = new WaitSignal();
+ var sub = await nats.SubscribeAsync("foo.*");
+ var reg = sub.Register(m =>
+ {
+ switch (m.Subject)
+ {
+ case "foo.sync":
+ Interlocked.Exchange(ref sync, 1);
+ break;
+ case "foo.signal1":
+ signal1.Pulse(m);
+ break;
+ case "foo.signal2":
+ signal2.Pulse(m);
+ break;
+ }
+ });
+
+ await Retry.Until(
+ "subscription is active",
+ () => Volatile.Read(ref sync) == 1,
+ async () => await nats.PublishAsync("foo.sync"),
+ retryDelay: TimeSpan.FromSeconds(1));
+
+ // PUB notifications
+ await nats.PublishAsync("foo.signal1");
+ var msg1 = await signal1;
+ Assert.Equal(0, msg1.Data.Length);
+ Assert.Null(msg1.Headers);
+ var pubFrame1 = proxy.Frames.First(f => f.Message.StartsWith("PUB foo.signal1"));
+ Assert.Equal("PUB foo.signal1 0␍␊", pubFrame1.Message);
+ var msgFrame1 = proxy.Frames.First(f => f.Message.StartsWith("MSG foo.signal1"));
+ Assert.Matches(@"^MSG foo.signal1 \w+ 0␍␊$", msgFrame1.Message);
+
+ // HPUB notifications
+ await nats.PublishAsync("foo.signal2", opts: new NatsPubOpts { Headers = new NatsHeaders() });
+ var msg2 = await signal2;
+ Assert.Equal(0, msg2.Data.Length);
+ Assert.NotNull(msg2.Headers);
+ Assert.Empty(msg2.Headers!);
+ var pubFrame2 = proxy.Frames.First(f => f.Message.StartsWith("HPUB foo.signal2"));
+ Assert.Equal("HPUB foo.signal2 12 12␍␊NATS/1.0␍␊␍␊", pubFrame2.Message);
+ var msgFrame2 = proxy.Frames.First(f => f.Message.StartsWith("HMSG foo.signal2"));
+ Assert.Matches(@"^HMSG foo.signal2 \w+ 12 12␍␊NATS/1.0␍␊␍␊$", msgFrame2.Message);
+
+ await sub.DisposeAsync();
+ await reg;
+ }
+}
diff --git a/tests/NATS.Client.Core.Tests/SubscriptionTest.cs b/tests/NATS.Client.Core.Tests/SubscriptionTest.cs
index 1037e7245..fcaf5cd5a 100644
--- a/tests/NATS.Client.Core.Tests/SubscriptionTest.cs
+++ b/tests/NATS.Client.Core.Tests/SubscriptionTest.cs
@@ -1,5 +1,3 @@
-using System.Diagnostics;
-
namespace NATS.Client.Core.Tests;
public class SubscriptionTest
@@ -8,94 +6,6 @@ public class SubscriptionTest
public SubscriptionTest(ITestOutputHelper output) => _output = output;
- [Fact]
- public async Task Subscription_with_same_subject()
- {
- await using var server = new NatsServer(_output, TransportType.Tcp);
- var nats1 = server.CreateClientConnection();
- var (nats2, proxy) = server.CreateProxiedClientConnection();
-
- var sub1 = await nats2.SubscribeAsync("foo.bar");
- var sub2 = await nats2.SubscribeAsync("foo.bar");
- var sub3 = await nats2.SubscribeAsync("foo.baz");
-
- var sync1 = 0;
- var sync2 = 0;
- var sync3 = 0;
- var count = new WaitSignal(3);
-
- sub1.Register(m =>
- {
- if (m.Data == 0)
- {
- Interlocked.Exchange(ref sync1, 1);
- return;
- }
-
- count.Pulse(m.Subject == "foo.bar" ? null : new Exception($"Subject mismatch {m.Subject}"));
- });
-
- sub2.Register(m =>
- {
- if (m.Data == 0)
- {
- Interlocked.Exchange(ref sync2, 1);
- return;
- }
-
- count.Pulse(m.Subject == "foo.bar" ? null : new Exception($"Subject mismatch {m.Subject}"));
- });
-
- sub3.Register(m =>
- {
- if (m.Data == 0)
- {
- Interlocked.Exchange(ref sync3, 1);
- return;
- }
-
- count.Pulse(m.Subject == "foo.baz" ? null : new Exception($"Subject mismatch {m.Subject}"));
- });
-
- // Since subscription and publishing are sent through different connections there is
- // a race where one or more subscriptions are made after the publishing happens.
- // So, we make sure subscribers are accepted by the server before we send any test data.
- await RetryUntil(
- "all subscriptions are active",
- () => Volatile.Read(ref sync1) + Volatile.Read(ref sync2) + Volatile.Read(ref sync3) == 3,
- async () =>
- {
- await nats1.PublishAsync("foo.bar", 0);
- await nats1.PublishAsync("foo.baz", 0);
- });
-
- await nats1.PublishAsync("foo.bar", 1);
- await nats1.PublishAsync("foo.baz", 1);
-
- // Wait until we received all test data
- await count;
-
- var frames = proxy.ClientFrames.OrderBy(f => f.Message).ToList();
-
- foreach (var frame in frames)
- {
- _output.WriteLine($"[PROXY] {frame}");
- }
-
- Assert.Equal(3, frames.Count);
- Assert.StartsWith("SUB foo.bar", frames[0].Message);
- Assert.StartsWith("SUB foo.bar", frames[1].Message);
- Assert.StartsWith("SUB foo.baz", frames[2].Message);
- Assert.False(frames[0].Message.Equals(frames[1].Message), "Should have different SIDs");
-
- await sub1.DisposeAsync();
- await sub2.DisposeAsync();
- await sub3.DisposeAsync();
- await nats1.DisposeAsync();
- await nats2.DisposeAsync();
- proxy.Dispose();
- }
-
[Fact]
public async Task Subscription_periodic_cleanup_test()
{
@@ -107,7 +17,7 @@ async Task Isolator()
{
var sub = await nats.SubscribeAsync("foo");
- await RetryUntil(
+ await Retry.Until(
"unsubscribed",
() => proxy.ClientFrames.Count(f => f.Message.StartsWith("SUB")) == 1);
@@ -119,7 +29,7 @@ await RetryUntil(
GC.Collect();
- await RetryUntil(
+ await Retry.Until(
"unsubscribe message received",
() => proxy.ClientFrames.Count(f => f.Message.StartsWith("UNSUB")) == 1);
}
@@ -137,7 +47,7 @@ async Task Isolator()
{
var sub = await nats.SubscribeAsync("foo");
- await RetryUntil("unsubscribed", () => proxy.ClientFrames.Count(f => f.Message.StartsWith("SUB")) == 1);
+ await Retry.Until("unsubscribed", () => proxy.ClientFrames.Count(f => f.Message.StartsWith("SUB")) == 1);
// subscription object will be eligible for GC after next statement
Assert.Equal("foo", sub.Subject);
@@ -148,25 +58,9 @@ async Task Isolator()
GC.Collect();
// Publish should trigger UNSUB since NatsSub object should be collected by now.
- await RetryUntil(
+ await Retry.Until(
"unsubscribe message received",
() => proxy.ClientFrames.Count(f => f.Message.StartsWith("UNSUB")) == 1,
async () => await nats.PublishAsync("foo", 1));
}
-
- private async Task RetryUntil(string reason, Func condition, Func? action = null, TimeSpan? timeout = null)
- {
- timeout ??= TimeSpan.FromSeconds(10);
- var stopwatch = Stopwatch.StartNew();
- while (stopwatch.Elapsed < timeout)
- {
- if (action != null)
- await action();
- if (condition())
- return;
- await Task.Delay(50);
- }
-
- throw new TimeoutException($"Took too long ({timeout}) waiting for {reason}");
- }
}
diff --git a/tests/NATS.Client.Core.Tests/_NatsServer.cs b/tests/NATS.Client.Core.Tests/_NatsServer.cs
index a4132fd55..2d18abb10 100644
--- a/tests/NATS.Client.Core.Tests/_NatsServer.cs
+++ b/tests/NATS.Client.Core.Tests/_NatsServer.cs
@@ -8,10 +8,23 @@
namespace NATS.Client.Core.Tests;
+public static class ServerVersions
+{
+#pragma warning disable SA1310
+#pragma warning disable SA1401
+
+ // Changed INFO port reporting for WS connections (nats-server #4255)
+ public static Version V2_9_19 = new("2.9.19");
+
+#pragma warning restore SA1401
+#pragma warning restore SA1310
+}
+
public class NatsServer : IAsyncDisposable
{
private static readonly string Ext = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? ".exe" : string.Empty;
private static readonly string NatsServerPath = $"nats-server{Ext}";
+ private static readonly Version Version;
private readonly CancellationTokenSource _cancellationTokenSource = new();
private readonly string? _configFileName;
@@ -21,6 +34,25 @@ public class NatsServer : IAsyncDisposable
private readonly TransportType _transportType;
private int _disposed;
+ static NatsServer()
+ {
+ var process = new Process
+ {
+ StartInfo = new ProcessStartInfo
+ {
+ FileName = NatsServerPath,
+ Arguments = "-v",
+ RedirectStandardOutput = true,
+ UseShellExecute = false,
+ },
+ };
+ process.Start();
+ process.WaitForExit();
+ var output = process.StandardOutput.ReadToEnd();
+ var value = Regex.Match(output, @"v(\d+\.\d+\.\d+)").Groups[1].Value;
+ Version = new Version(value);
+ }
+
public NatsServer()
: this(new NullOutputHelper(), TransportType.Tcp)
{
@@ -91,6 +123,21 @@ public NatsServer(ITestOutputHelper outputHelper, TransportType transportType, N
_ => throw new ArgumentOutOfRangeException(),
};
+ public int ConnectionPort
+ {
+ get
+ {
+ if (_transportType == TransportType.WebSocket && ServerVersions.V2_9_19 <= Version)
+ {
+ return Options.WebSocketPort!.Value;
+ }
+ else
+ {
+ return Options.ServerPort;
+ }
+ }
+ }
+
public async ValueTask DisposeAsync()
{
if (Interlocked.Increment(ref _disposed) != 1)
@@ -334,7 +381,16 @@ public IReadOnlyList Frames
private bool NatsProtoDump(int client, string origin, TextReader sr, TextWriter sw)
{
- var message = sr.ReadLine();
+ string? message;
+ try
+ {
+ message = sr.ReadLine();
+ }
+ catch
+ {
+ return false;
+ }
+
if (message == null) return false;
if (Regex.IsMatch(message, @"^(INFO|CONNECT|PING|PONG|UNSUB|SUB|\+OK|-ERR)"))
@@ -369,14 +425,11 @@ private bool NatsProtoDump(int client, string origin, TextReader sr, TextWriter
case >= ' ' and <= '~':
sb.Append(c);
break;
- case '\t':
- sb.Append("\\t");
- break;
case '\n':
- sb.Append("\\n");
+ sb.Append('␊');
break;
case '\r':
- sb.Append("\\r");
+ sb.Append('␍');
break;
default:
sb.Append('.');
@@ -389,7 +442,7 @@ private bool NatsProtoDump(int client, string origin, TextReader sr, TextWriter
sw.Flush();
if (client > 0)
- AddFrame(new Frame(client, origin, Message: $"{message}\\r\\n{sb}"));
+ AddFrame(new Frame(client, origin, Message: $"{message}␍␊{sb}"));
return true;
}
diff --git a/tests/NATS.Client.Core.Tests/_Utils.cs b/tests/NATS.Client.Core.Tests/_Utils.cs
new file mode 100644
index 000000000..e1095c807
--- /dev/null
+++ b/tests/NATS.Client.Core.Tests/_Utils.cs
@@ -0,0 +1,45 @@
+using System.Diagnostics;
+using System.Net;
+using System.Net.Sockets;
+
+namespace NATS.Client.Core.Tests;
+
+public static class Retry
+{
+ public static async Task Until(string reason, Func condition, Func? action = null, TimeSpan? timeout = null, TimeSpan? retryDelay = null)
+ {
+ timeout ??= TimeSpan.FromSeconds(10);
+ var delay1 = retryDelay ?? TimeSpan.FromSeconds(.1);
+
+ var stopwatch = Stopwatch.StartNew();
+ while (stopwatch.Elapsed < timeout)
+ {
+ if (action != null)
+ await action();
+ if (condition())
+ return;
+ await Task.Delay(delay1);
+ }
+
+ throw new TimeoutException($"Took too long ({timeout}) waiting until {reason}");
+ }
+}
+
+public static class Net
+{
+ public static void WaitForTcpPortToClose(int port)
+ {
+ while (true)
+ {
+ try
+ {
+ using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+ socket.Connect(IPAddress.Loopback, port);
+ }
+ catch (SocketException)
+ {
+ return;
+ }
+ }
+ }
+}
diff --git a/tests/NATS.Client.Core.Tests/_WaitSignal.cs b/tests/NATS.Client.Core.Tests/_WaitSignal.cs
index f6a835c95..8ff2acbbf 100644
--- a/tests/NATS.Client.Core.Tests/_WaitSignal.cs
+++ b/tests/NATS.Client.Core.Tests/_WaitSignal.cs
@@ -85,3 +85,53 @@ public TaskAwaiter GetAwaiter()
return _tcs.Task.WaitAsync(_timeout).GetAwaiter();
}
}
+
+public class WaitSignal
+{
+ private TimeSpan _timeout;
+ private int _count;
+ private TaskCompletionSource _tcs;
+
+ public WaitSignal()
+ : this(TimeSpan.FromSeconds(10))
+ {
+ }
+
+ public WaitSignal(int count)
+ : this(TimeSpan.FromSeconds(10), count)
+ {
+ }
+
+ public WaitSignal(TimeSpan timeout, int count = 1)
+ {
+ _timeout = timeout;
+ _count = count;
+ _tcs = new TaskCompletionSource();
+ }
+
+ public TimeSpan Timeout => _timeout;
+
+ public Task Task => _tcs.Task;
+
+ public void Pulse(T result, Exception? exception = null)
+ {
+ if (exception == null)
+ {
+ if (Interlocked.Decrement(ref _count) > 0)
+ {
+ return;
+ }
+
+ _tcs.TrySetResult(result);
+ }
+ else
+ {
+ _tcs.TrySetException(exception);
+ }
+ }
+
+ public TaskAwaiter GetAwaiter()
+ {
+ return _tcs.Task.WaitAsync(_timeout).GetAwaiter();
+ }
+}