diff --git a/NATS.Client.sln b/NATS.Client.sln index bbe8f8e36..1170e089a 100644 --- a/NATS.Client.sln +++ b/NATS.Client.sln @@ -49,6 +49,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Example.Core.PublishModel", EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "NATS.Client.Core.MemoryTests", "tests\NATS.Client.Core.MemoryTests\NATS.Client.Core.MemoryTests.csproj", "{B26DE6AC-A4D5-4427-8453-EE3514E4B513}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Example.Core.PublishHeaders", "sandbox\Example.Core.PublishHeaders\Example.Core.PublishHeaders.csproj", "{B0C82F24-BDEC-4420-A02A-F74E2423D755}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Example.Core.SubscribeHeaders", "sandbox\Example.Core.SubscribeHeaders\Example.Core.SubscribeHeaders.csproj", "{A96660DB-DAEB-4C57-8096-F236AC4FA927}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -115,6 +119,14 @@ Global {B26DE6AC-A4D5-4427-8453-EE3514E4B513}.Debug|Any CPU.Build.0 = Debug|Any CPU {B26DE6AC-A4D5-4427-8453-EE3514E4B513}.Release|Any CPU.ActiveCfg = Release|Any CPU {B26DE6AC-A4D5-4427-8453-EE3514E4B513}.Release|Any CPU.Build.0 = Release|Any CPU + {B0C82F24-BDEC-4420-A02A-F74E2423D755}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B0C82F24-BDEC-4420-A02A-F74E2423D755}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B0C82F24-BDEC-4420-A02A-F74E2423D755}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B0C82F24-BDEC-4420-A02A-F74E2423D755}.Release|Any CPU.Build.0 = Release|Any CPU + {A96660DB-DAEB-4C57-8096-F236AC4FA927}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A96660DB-DAEB-4C57-8096-F236AC4FA927}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A96660DB-DAEB-4C57-8096-F236AC4FA927}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A96660DB-DAEB-4C57-8096-F236AC4FA927}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -136,6 +148,8 @@ Global {C85BA135-3C21-4027-BE5A-849E1011DD0A} = {95A69671-16CA-4133-981C-CC381B7AAA30} {29F96D05-D02F-4610-A8FB-3527BF83C4A5} = {95A69671-16CA-4133-981C-CC381B7AAA30} {B26DE6AC-A4D5-4427-8453-EE3514E4B513} = {C526E8AB-739A-48D7-8FC4-048978C9B650} + {B0C82F24-BDEC-4420-A02A-F74E2423D755} = {95A69671-16CA-4133-981C-CC381B7AAA30} + {A96660DB-DAEB-4C57-8096-F236AC4FA927} = {95A69671-16CA-4133-981C-CC381B7AAA30} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {8CBB7278-D093-448E-B3DE-B5991209A1AA} diff --git a/NATS.Client.sln.DotSettings b/NATS.Client.sln.DotSettings new file mode 100644 index 000000000..bbcea8912 --- /dev/null +++ b/NATS.Client.sln.DotSettings @@ -0,0 +1,6 @@ + + ASCII + CR + LF + True + True \ No newline at end of file diff --git a/sandbox/Example.Core.PublishHeaders/Example.Core.PublishHeaders.csproj b/sandbox/Example.Core.PublishHeaders/Example.Core.PublishHeaders.csproj new file mode 100644 index 000000000..f6dc3f6fe --- /dev/null +++ b/sandbox/Example.Core.PublishHeaders/Example.Core.PublishHeaders.csproj @@ -0,0 +1,14 @@ + + + + Exe + net6.0 + enable + enable + + + + + + + diff --git a/sandbox/Example.Core.PublishHeaders/Program.cs b/sandbox/Example.Core.PublishHeaders/Program.cs new file mode 100644 index 000000000..a795cef4c --- /dev/null +++ b/sandbox/Example.Core.PublishHeaders/Program.cs @@ -0,0 +1,31 @@ +// > nats sub bar.* +using Microsoft.Extensions.Logging; +using NATS.Client.Core; + +var subject = "bar.xyz"; +var options = NatsOptions.Default with { LoggerFactory = new MinimumConsoleLoggerFactory(LogLevel.Error) }; + +Print("[CON] Connecting...\n"); + +await using var connection = new NatsConnection(options); + +for (int i = 0; i < 10; i++) +{ + Print($"[PUB] Publishing to subject ({i}) '{subject}'...\n"); + await connection.PublishAsync( + subject, + new Bar { Id = i, Name = "Baz" }, + new NatsPubOpts { Headers = new NatsHeaders { ["XFoo"] = $"bar{i}" } }); +} + +void Print(string message) +{ + Console.Write($"{DateTime.Now:HH:mm:ss} {message}"); +} + +public record Bar +{ + public int Id { get; set; } + + public string? Name { get; set; } +} diff --git a/sandbox/Example.Core.SubscribeHeaders/Example.Core.SubscribeHeaders.csproj b/sandbox/Example.Core.SubscribeHeaders/Example.Core.SubscribeHeaders.csproj new file mode 100644 index 000000000..f6dc3f6fe --- /dev/null +++ b/sandbox/Example.Core.SubscribeHeaders/Example.Core.SubscribeHeaders.csproj @@ -0,0 +1,14 @@ + + + + Exe + net6.0 + enable + enable + + + + + + + diff --git a/sandbox/Example.Core.SubscribeHeaders/Program.cs b/sandbox/Example.Core.SubscribeHeaders/Program.cs new file mode 100644 index 000000000..940595f55 --- /dev/null +++ b/sandbox/Example.Core.SubscribeHeaders/Program.cs @@ -0,0 +1,41 @@ +// > nats pub bar.xyz --count=10 "my_message_{{ Count }}" -H X-Foo:Baz + +using System.Text; +using Microsoft.Extensions.Logging; +using NATS.Client.Core; + +var subject = "bar.*"; +var options = NatsOptions.Default with { LoggerFactory = new MinimumConsoleLoggerFactory(LogLevel.Error) }; + +Print("[CON] Connecting...\n"); + +await using var connection = new NatsConnection(options); + +Print($"[SUB] Subscribing to subject '{subject}'...\n"); + +NatsSub sub = await connection.SubscribeAsync(subject); + +await foreach (var msg in sub.Msgs.ReadAllAsync()) +{ + Print($"[RCV] {msg.Subject}: {Encoding.UTF8.GetString(msg.Data.Span)}\n"); + if (msg.Headers != null) + { + foreach (var (key, values) in msg.Headers) + { + foreach (var value in values) + Print($" {key}: {value}\n"); + } + } +} + +void Print(string message) +{ + Console.Write($"{DateTime.Now:HH:mm:ss} {message}"); +} + +public record Bar +{ + public int Id { get; set; } + + public string? Name { get; set; } +} diff --git a/src/NATS.Client.Core/Commands/CommandConstants.cs b/src/NATS.Client.Core/Commands/CommandConstants.cs index 18e3a1611..bdd62fca6 100644 --- a/src/NATS.Client.Core/Commands/CommandConstants.cs +++ b/src/NATS.Client.Core/Commands/CommandConstants.cs @@ -13,6 +13,9 @@ internal static class CommandConstants // string.Join(",", Encoding.ASCII.GetBytes("PUB ")) public static ReadOnlySpan PubWithPadding => new byte[] { 80, 85, 66, 32 }; + // string.Join(",", Encoding.ASCII.GetBytes("HPUB ")) + public static ReadOnlySpan HPubWithPadding => new byte[] { 72, 80, 85, 66, 32 }; + // string.Join(",", Encoding.ASCII.GetBytes("SUB ")) public static ReadOnlySpan SubWithPadding => new byte[] { 83, 85, 66, 32 }; @@ -24,4 +27,7 @@ internal static class CommandConstants // string.Join(",", Encoding.ASCII.GetBytes("PONG\r\n")) public static ReadOnlySpan PongNewLine => new byte[] { 80, 79, 78, 71, 13, 10 }; + + // string.Join(",", Encoding.ASCII.GetBytes("NATS/1.0\r\n")) + public static ReadOnlySpan NatsHeaders10NewLine => new byte[] { 78, 65, 84, 83, 47, 49, 46, 48, 13, 10 }; } diff --git a/src/NATS.Client.Core/Commands/ProtocolWriter.cs b/src/NATS.Client.Core/Commands/ProtocolWriter.cs index 2c3541daa..f0f75e835 100644 --- a/src/NATS.Client.Core/Commands/ProtocolWriter.cs +++ b/src/NATS.Client.Core/Commands/ProtocolWriter.cs @@ -1,5 +1,6 @@ using System.Buffers; using System.Buffers.Text; +using System.Text; using System.Text.Json; using System.Text.Json.Serialization; using NATS.Client.Core.Internal; @@ -12,6 +13,9 @@ internal sealed class ProtocolWriter private const int NewLineLength = 2; // \r\n private readonly FixedArrayBufferWriter _writer; // where T : IBufferWriter + private readonly FixedArrayBufferWriter _bufferHeaders = new(); + private readonly FixedArrayBufferWriter _bufferPayload = new(); + private readonly HeaderWriter _headerWriter = new(Encoding.UTF8); public ProtocolWriter(FixedArrayBufferWriter writer) { @@ -46,159 +50,63 @@ public void WritePong() } // https://docs.nats.io/reference/reference-protocols/nats-protocol#pub - // PUB [reply-to] <#bytes>\r\n[payload] - // To omit the payload, set the payload size to 0, but the second CRLF is still required. - public void WritePublish(string subject, string? replyTo, ReadOnlySequence payload) + // PUB [reply-to] <#bytes>\r\n[payload]\r\n + public void WritePublish(string subject, string? replyTo, NatsHeaders? headers, ReadOnlySequence payload) { - var offset = 0; - var maxLength = CommandConstants.PubWithPadding.Length - + subject.Length + 1 // with space padding - + (replyTo == null ? 0 : replyTo.Length + 1) - + MaxIntStringLength - + NewLineLength - + (int)payload.Length - + NewLineLength; - - var writableSpan = _writer.GetSpan(maxLength); - - CommandConstants.PubWithPadding.CopyTo(writableSpan); - offset += CommandConstants.PubWithPadding.Length; + // We use a separate buffer to write the headers so that we can calculate the + // size before we write to the output buffer '_writer'. + if (headers != null) + { + _bufferHeaders.Reset(); + _headerWriter.Write(_bufferHeaders, headers); + } - subject.WriteASCIIBytes(writableSpan.Slice(offset)); - offset += subject.Length; - writableSpan.Slice(offset)[0] = (byte)' '; - offset += 1; + // Start writing the message to buffer: + // PUP / HPUB + _writer.WriteSpan(headers == null ? CommandConstants.PubWithPadding : CommandConstants.HPubWithPadding); + _writer.WriteASCIIAndSpace(subject); if (replyTo != null) { - replyTo.WriteASCIIBytes(writableSpan.Slice(offset)); - offset += replyTo.Length; - writableSpan.Slice(offset)[0] = (byte)' '; - offset += 1; + _writer.WriteASCIIAndSpace(replyTo); } - if (!Utf8Formatter.TryFormat(payload.Length, writableSpan.Slice(offset), out var written)) + if (headers == null) { - throw new NatsException("Can not format integer."); + _writer.WriteNumber(payload.Length); } - - offset += written; - - CommandConstants.NewLine.CopyTo(writableSpan.Slice(offset)); - offset += CommandConstants.NewLine.Length; - - if (payload.Length != 0) + else { - payload.CopyTo(writableSpan.Slice(offset)); - offset += (int)payload.Length; + var headersLength = _bufferHeaders.WrittenSpan.Length; + _writer.WriteNumber(CommandConstants.NatsHeaders10NewLine.Length + headersLength); + _writer.WriteSpace(); + var total = CommandConstants.NatsHeaders10NewLine.Length + headersLength + payload.Length; + _writer.WriteNumber(total); } - CommandConstants.NewLine.CopyTo(writableSpan.Slice(offset)); - offset += CommandConstants.NewLine.Length; - - _writer.Advance(offset); - } - - public void WritePublish(string subject, string? replyTo, T? value, INatsSerializer serializer) - { - var offset = 0; - var maxLengthWithoutPayload = CommandConstants.PubWithPadding.Length - + subject.Length + 1 - + (replyTo == null ? 0 : replyTo.Length + 1) - + MaxIntStringLength - + NewLineLength; - - var writableSpan = _writer.GetSpan(maxLengthWithoutPayload); - - CommandConstants.PubWithPadding.CopyTo(writableSpan); - offset += CommandConstants.PubWithPadding.Length; - - subject.WriteASCIIBytes(writableSpan.Slice(offset)); - offset += subject.Length; - writableSpan.Slice(offset)[0] = (byte)' '; - offset += 1; + // End of message first line + _writer.WriteNewLine(); - if (replyTo != null) + if (headers != null) { - replyTo.WriteASCIIBytes(writableSpan.Slice(offset)); - offset += replyTo.Length; - writableSpan.Slice(offset)[0] = (byte)' '; - offset += 1; + _writer.WriteSpan(CommandConstants.NatsHeaders10NewLine); + _writer.WriteSpan(_bufferHeaders.WrittenSpan); } - // Advance for written. - _writer.Advance(offset); - - // preallocate range for write #bytes(write after serialized) - var preallocatedRange = _writer.PreAllocate(MaxIntStringLength); - offset += MaxIntStringLength; - - CommandConstants.NewLine.CopyTo(writableSpan.Slice(offset)); - _writer.Advance(CommandConstants.NewLine.Length); - - var payloadLength = serializer.Serialize(_writer, value); - var payloadLengthSpan = _writer.GetSpanInPreAllocated(preallocatedRange); - payloadLengthSpan.Fill((byte)' '); - if (!Utf8Formatter.TryFormat(payloadLength, payloadLengthSpan, out var written)) + if (payload.Length != 0) { - throw new NatsException("Can not format integer."); + _writer.WriteSequence(payload); } - WriteConstant(CommandConstants.NewLine); + _writer.WriteNewLine(); } - public void WritePublish(string subject, ReadOnlyMemory inboxPrefix, int id, T? value, INatsSerializer serializer) + public void WritePublish(string subject, string? replyTo, NatsHeaders? headers, T? value, INatsSerializer serializer) { - Span idBytes = stackalloc byte[10]; - if (Utf8Formatter.TryFormat(id, idBytes, out var written)) - { - idBytes = idBytes.Slice(0, written); - } - - var offset = 0; - var maxLengthWithoutPayload = CommandConstants.PubWithPadding.Length - + subject.Length + 1 - + (inboxPrefix.Length + idBytes.Length + 1) // with space - + MaxIntStringLength - + NewLineLength; - - var writableSpan = _writer.GetSpan(maxLengthWithoutPayload); - - CommandConstants.PubWithPadding.CopyTo(writableSpan); - offset += CommandConstants.PubWithPadding.Length; - - subject.WriteASCIIBytes(writableSpan.Slice(offset)); - offset += subject.Length; - writableSpan.Slice(offset)[0] = (byte)' '; - offset += 1; - - // build reply-to - inboxPrefix.Span.CopyTo(writableSpan.Slice(offset)); - offset += inboxPrefix.Length; - idBytes.CopyTo(writableSpan.Slice(offset)); - offset += idBytes.Length; - writableSpan.Slice(offset)[0] = (byte)' '; - offset += 1; - - // Advance for written. - _writer.Advance(offset); - - // preallocate range for write #bytes(write after serialized) - var preallocatedRange = _writer.PreAllocate(MaxIntStringLength); - offset += MaxIntStringLength; - - CommandConstants.NewLine.CopyTo(writableSpan.Slice(offset)); - _writer.Advance(CommandConstants.NewLine.Length); - - var payloadLength = serializer.Serialize(_writer, value); - var payloadLengthSpan = _writer.GetSpanInPreAllocated(preallocatedRange); - payloadLengthSpan.Fill((byte)' '); - if (!Utf8Formatter.TryFormat(payloadLength, payloadLengthSpan, out written)) - { - throw new NatsException("Can not format integer."); - } - - WriteConstant(CommandConstants.NewLine); + _bufferPayload.Reset(); + serializer.Serialize(_bufferPayload, value); + var payload = new ReadOnlySequence(_bufferPayload.WrittenMemory); + WritePublish(subject, replyTo, headers, payload); } // https://docs.nats.io/reference/reference-protocols/nats-protocol#sub diff --git a/src/NATS.Client.Core/Commands/PublishCommand.cs b/src/NATS.Client.Core/Commands/PublishCommand.cs index c56fc6afd..0d345dcd5 100644 --- a/src/NATS.Client.Core/Commands/PublishCommand.cs +++ b/src/NATS.Client.Core/Commands/PublishCommand.cs @@ -7,6 +7,7 @@ internal sealed class AsyncPublishCommand : AsyncCommandBase Create(ObjectPool pool, CancellationTimer timer, string subject, string? replyTo, T? value, INatsSerializer serializer) + public static AsyncPublishCommand Create(ObjectPool pool, CancellationTimer timer, string subject, string? replyTo, NatsHeaders? headers, T? value, INatsSerializer serializer) { if (!TryRent(pool, out var result)) { @@ -23,6 +24,7 @@ public static AsyncPublishCommand Create(ObjectPool pool, CancellationTimer t result._subject = subject; result._replyTo = replyTo; + result._headers = headers; result._value = value; result._serializer = serializer; result.SetCancellationTimer(timer); @@ -32,12 +34,13 @@ public static AsyncPublishCommand Create(ObjectPool pool, CancellationTimer t public override void Write(ProtocolWriter writer) { - writer.WritePublish(_subject!, _replyTo, _value, _serializer!); + writer.WritePublish(_subject!, _replyTo, _headers, _value, _serializer!); } protected override void Reset() { _subject = default; + _headers = default; _value = default; _serializer = null; } @@ -46,13 +49,15 @@ protected override void Reset() internal sealed class AsyncPublishBytesCommand : AsyncCommandBase { private string? _subject; - private ReadOnlySequence _value; + private string? _replyTo; + private NatsHeaders? _headers; + private ReadOnlySequence _payload; private AsyncPublishBytesCommand() { } - public static AsyncPublishBytesCommand Create(ObjectPool pool, CancellationTimer timer, string subject, ReadOnlySequence value) + public static AsyncPublishBytesCommand Create(ObjectPool pool, CancellationTimer timer, string subject, string? replyTo, NatsHeaders? headers, ReadOnlySequence payload) { if (!TryRent(pool, out var result)) { @@ -60,7 +65,9 @@ public static AsyncPublishBytesCommand Create(ObjectPool pool, CancellationTimer } result._subject = subject; - result._value = value; + result._replyTo = replyTo; + result._headers = headers; + result._payload = payload; result.SetCancellationTimer(timer); return result; @@ -68,12 +75,14 @@ public static AsyncPublishBytesCommand Create(ObjectPool pool, CancellationTimer public override void Write(ProtocolWriter writer) { - writer.WritePublish(_subject!, null, _value); + writer.WritePublish(_subject!, _replyTo, _headers, _payload); } protected override void Reset() { _subject = default; - _value = default; + _replyTo = default; + _headers = default; + _payload = default; } } diff --git a/src/NATS.Client.Core/Internal/BufferExtensions.cs b/src/NATS.Client.Core/Internal/BufferExtensions.cs new file mode 100644 index 000000000..c618dc8b4 --- /dev/null +++ b/src/NATS.Client.Core/Internal/BufferExtensions.cs @@ -0,0 +1,68 @@ +// Adapted from https://github.com/dotnet/aspnetcore/blob/v6.0.18/src/Shared/ServerInfrastructure/BufferExtensions.cs + +#nullable enable + +using System.Buffers; +using System.Runtime.CompilerServices; + +namespace NATS.Client.Core.Internal; + +internal static class BufferExtensions +{ + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ReadOnlySpan ToSpan(this ReadOnlySequence buffer) + { + if (buffer.IsSingleSegment) + { + return buffer.FirstSpan; + } + + return buffer.ToArray(); + } + + /// + /// Returns position of first occurrence of item in the + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static SequencePosition? PositionOfAny(in this ReadOnlySequence source, T value0, T value1) + where T : IEquatable + { + if (source.IsSingleSegment) + { + int index = source.First.Span.IndexOfAny(value0, value1); + if (index != -1) + { + return source.GetPosition(index); + } + + return null; + } + else + { + return PositionOfAnyMultiSegment(source, value0, value1); + } + } + + private static SequencePosition? PositionOfAnyMultiSegment(in ReadOnlySequence source, T value0, T value1) + where T : IEquatable + { + SequencePosition position = source.Start; + SequencePosition result = position; + while (source.TryGet(ref position, out ReadOnlyMemory memory)) + { + int index = memory.Span.IndexOfAny(value0, value1); + if (index != -1) + { + return source.GetPosition(index, result); + } + else if (position.GetObject() == null) + { + break; + } + + result = position; + } + + return null; + } +} diff --git a/src/NATS.Client.Core/Internal/BufferWriterExtensions.cs b/src/NATS.Client.Core/Internal/BufferWriterExtensions.cs new file mode 100644 index 000000000..797cca547 --- /dev/null +++ b/src/NATS.Client.Core/Internal/BufferWriterExtensions.cs @@ -0,0 +1,66 @@ +using System.Buffers; +using System.Buffers.Text; +using System.Runtime.CompilerServices; +using System.Text; +using NATS.Client.Core.Commands; + +namespace NATS.Client.Core.Internal; + +internal static class BufferWriterExtensions +{ + private const int MaxIntStringLength = 10; // int.MaxValue.ToString().Length + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void WriteNewLine(this FixedArrayBufferWriter writer) + { + var span = writer.GetSpan(CommandConstants.NewLine.Length); + CommandConstants.NewLine.CopyTo(span); + writer.Advance(CommandConstants.NewLine.Length); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void WriteNumber(this FixedArrayBufferWriter writer, long number) + { + var span = writer.GetSpan(MaxIntStringLength); + if (!Utf8Formatter.TryFormat(number, span, out var writtenLength)) + { + throw new NatsException("Can not format integer."); + } + + writer.Advance(writtenLength); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void WriteSpace(this FixedArrayBufferWriter writer) + { + var span = writer.GetSpan(1); + span[0] = (byte)' '; + writer.Advance(1); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void WriteSpan(this FixedArrayBufferWriter writer, ReadOnlySpan span) + { + var writerSpan = writer.GetSpan(span.Length); + span.CopyTo(writerSpan); + writer.Advance(span.Length); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void WriteSequence(this FixedArrayBufferWriter writer, ReadOnlySequence sequence) + { + var len = (int)sequence.Length; + var span = writer.GetSpan(len); + sequence.CopyTo(span); + writer.Advance(len); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void WriteASCIIAndSpace(this FixedArrayBufferWriter writer, string ascii) + { + var span = writer.GetSpan(ascii.Length + 1); + ascii.WriteASCIIBytes(span); + span[ascii.Length] = (byte)' '; + writer.Advance(ascii.Length + 1); + } +} diff --git a/src/NATS.Client.Core/Internal/ClientOptions.cs b/src/NATS.Client.Core/Internal/ClientOptions.cs index fdd547e62..a3dc7d9ab 100644 --- a/src/NATS.Client.Core/Internal/ClientOptions.cs +++ b/src/NATS.Client.Core/Internal/ClientOptions.cs @@ -12,6 +12,7 @@ public ClientOptions(NatsOptions options) Name = options.Name; Echo = options.Echo; Verbose = options.Verbose; + Headers = options.Headers; Username = options.AuthOptions.Username; Password = options.AuthOptions.Password; AuthToken = options.AuthOptions.Token; diff --git a/src/NATS.Client.Core/Internal/DebuggingExtensions.cs b/src/NATS.Client.Core/Internal/DebuggingExtensions.cs new file mode 100644 index 000000000..4c5845eb8 --- /dev/null +++ b/src/NATS.Client.Core/Internal/DebuggingExtensions.cs @@ -0,0 +1,65 @@ +#if DEBUG + +using System.Buffers; +using System.Diagnostics; +using System.Text; + +namespace NATS.Client.Core.Internal; + +internal static class DebuggingExtensions +{ + public static string Dump(this ReadOnlySequence buffer) + { + var sb = new StringBuilder(); + foreach (var readOnlyMemory in buffer) + { + sb.Append(Dump(readOnlyMemory.Span)); + } + + return sb.ToString(); + } + + public static string Dump(this ReadOnlySpan span) + { + var sb = new StringBuilder(); + foreach (char b in span) + { + switch (b) + { + case >= ' ' and <= '~': + sb.Append(b); + break; + case '\r': + sb.Append('␍'); + break; + case '\n': + sb.Append('␊'); + break; + default: + sb.Append('.'); + break; + } + } + + return sb.ToString(); + } + + public static string Dump(this NatsHeaders? headers) + { + if (headers == null) + return ""; + + var sb = new StringBuilder(); + foreach (var (key, stringValues) in headers) + { + foreach (var value in stringValues) + { + sb.AppendLine($"{key}: {value}"); + } + } + + return sb.ToString(); + } +} + +#endif diff --git a/src/NATS.Client.Core/Internal/FastQueue.cs b/src/NATS.Client.Core/Internal/FastQueue.cs deleted file mode 100644 index d10941757..000000000 --- a/src/NATS.Client.Core/Internal/FastQueue.cs +++ /dev/null @@ -1,97 +0,0 @@ -using System.Runtime.CompilerServices; - -namespace NATS.Client.Core.Internal; - -// fixed size queue. -internal sealed class FastQueue -{ - private T[] _array; - private int _head; - private int _tail; - private int _size; - - public FastQueue(int capacity) - { - if (capacity < 0) - throw new ArgumentOutOfRangeException("capacity"); - _array = new T[capacity]; - _head = _tail = _size = 0; - } - - public int Count - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _size; } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void Enqueue(T item) - { - if (_size == _array.Length) - { - ThrowForFullQueue(); - } - - _array[_tail] = item; - MoveNext(ref _tail); - _size++; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public T Dequeue() - { - if (_size == 0) - ThrowForEmptyQueue(); - - var head = _head; - var array = _array; - var removed = array[head]; - array[head] = default!; - MoveNext(ref _head); - _size--; - return removed; - } - - public void EnsureNewCapacity(int capacity) - { - var newarray = new T[capacity]; - if (_size > 0) - { - if (_head < _tail) - { - Array.Copy(_array, _head, newarray, 0, _size); - } - else - { - Array.Copy(_array, _head, newarray, 0, _array.Length - _head); - Array.Copy(_array, 0, newarray, _array.Length - _head, _tail); - } - } - - _array = newarray; - _head = 0; - _tail = _size == capacity ? 0 : _size; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void MoveNext(ref int index) - { - var tmp = index + 1; - if (tmp == _array.Length) - { - tmp = 0; - } - - index = tmp; - } - - private void ThrowForEmptyQueue() - { - throw new InvalidOperationException("Queue is empty."); - } - - private void ThrowForFullQueue() - { - throw new InvalidOperationException("Queue is full."); - } -} diff --git a/src/NATS.Client.Core/Internal/FixedArrayBufferWriter.cs b/src/NATS.Client.Core/Internal/FixedArrayBufferWriter.cs index 316d21bd7..bb74aa9eb 100644 --- a/src/NATS.Client.Core/Internal/FixedArrayBufferWriter.cs +++ b/src/NATS.Client.Core/Internal/FixedArrayBufferWriter.cs @@ -17,6 +17,8 @@ public FixedArrayBufferWriter(int capacity = 65535) public ReadOnlyMemory WrittenMemory => _buffer.AsMemory(0, _written); + public ReadOnlySpan WrittenSpan => _buffer.AsSpan(0, _written); + public int WrittenCount => _written; [MethodImpl(MethodImplOptions.AggressiveInlining)] diff --git a/src/NATS.Client.Core/Internal/HeaderParser.cs b/src/NATS.Client.Core/Internal/HeaderParser.cs new file mode 100644 index 000000000..27afc3f7f --- /dev/null +++ b/src/NATS.Client.Core/Internal/HeaderParser.cs @@ -0,0 +1,335 @@ +// Adapted from https://github.com/dotnet/aspnetcore/blob/v6.0.18/src/Servers/Kestrel/Core/src/Internal/Http/HttpParser.cs + +using System.Buffers; +using System.Diagnostics; +using System.Text; +using Microsoft.Extensions.Primitives; + +namespace NATS.Client.Core.Internal; + +internal class HeaderParser +{ + private const byte ByteCR = (byte)'\r'; + private const byte ByteLF = (byte)'\n'; + private const byte ByteColon = (byte)':'; + private const byte ByteSpace = (byte)' '; + private const byte ByteTab = (byte)'\t'; + + private readonly Encoding _encoding; + + public HeaderParser(Encoding encoding) + { + _encoding = encoding; + } + + public bool ParseHeaders(in SequenceReader reader, NatsHeaders headers) + { + while (!reader.End) + { + var span = reader.UnreadSpan; + while (span.Length > 0) + { + var ch1 = (byte)0; + var ch2 = (byte)0; + var readAhead = 0; + + // Fast path, we're still looking at the same span + if (span.Length >= 2) + { + ch1 = span[0]; + ch2 = span[1]; + } + + // Possibly split across spans + else if (reader.TryRead(out ch1)) + { + // Note if we read ahead by 1 or 2 bytes + readAhead = reader.TryRead(out ch2) ? 2 : 1; + } + + if (ch1 == ByteCR) + { + // Check for final CRLF. + if (ch2 == ByteLF) + { + // If we got 2 bytes from the span directly so skip ahead 2 so that + // the reader's state matches what we expect + if (readAhead == 0) + { + reader.Advance(2); + } + + // Double CRLF found, so end of headers. + return true; + } + else if (readAhead == 1) + { + // Didn't read 2 bytes, reset the reader so we don't consume anything + reader.Rewind(1); + return false; + } + + // Headers don't end in CRLF line. + Debug.Assert(readAhead == 0 || readAhead == 2, "readAhead == 0 || readAhead == 2"); + + throw new NatsException($"Protocol error: invalid headers, no ending CRLFCRLF"); + } + + var length = 0; + + // We only need to look for the end if we didn't read ahead; otherwise there isn't enough in + // in the span to contain a header. + if (readAhead == 0) + { + length = span.IndexOfAny(ByteCR, ByteLF); + + // If not found length with be -1; casting to uint will turn it to uint.MaxValue + // which will be larger than any possible span.Length. This also serves to eliminate + // the bounds check for the next lookup of span[length] + if ((uint)length < (uint)span.Length) + { + // Early memory read to hide latency + var expectedCR = span[length]; + + // Correctly has a CR, move to next + length++; + + if (expectedCR != ByteCR) + { + // Sequence needs to be CRLF not LF first. + RejectRequestHeader(span[..length]); + } + + if ((uint)length < (uint)span.Length) + { + // Early memory read to hide latency + var expectedLF = span[length]; + + // Correctly has a LF, move to next + length++; + + if (expectedLF != ByteLF || + length < 5 || + + // Exclude the CRLF from the headerLine and parse the header name:value pair + !TryTakeSingleHeader(span[..(length - 2)], headers)) + { + // Sequence needs to be CRLF and not contain an inner CR not part of terminator. + // Less than min possible headerSpan of 5 bytes a:b\r\n + // Not parsable as a valid name:value header pair. + RejectRequestHeader(span[..length]); + } + + // Read the header successfully, skip the reader forward past the headerSpan. + span = span.Slice(length); + reader.Advance(length); + } + else + { + // No enough data, set length to 0. + length = 0; + } + } + } + + // End found in current span + if (length > 0) + { + continue; + } + + // We moved the reader to look ahead 2 bytes so rewind the reader + if (readAhead > 0) + { + reader.Rewind(readAhead); + } + + length = ParseMultiSpanHeader(reader, headers); + if (length < 0) + { + // Not there + return false; + } + + reader.Advance(length); + + // As we crossed spans set the current span to default + // so we move to the next span on the next iteration + span = default; + } + } + + return false; + } + + private int ParseMultiSpanHeader(in SequenceReader reader, NatsHeaders headers) + { + var currentSlice = reader.UnreadSequence; + var lineEndPosition = currentSlice.PositionOfAny(ByteCR, ByteLF); + + if (lineEndPosition == null) + { + // Not there. + return -1; + } + + SequencePosition lineEnd; + ReadOnlySpan headerSpan; + if (currentSlice.Slice(reader.Position, lineEndPosition.Value).Length == currentSlice.Length - 1) + { + // No enough data, so CRLF can't currently be there. + // However, we need to check the found char is CR and not LF + + // Advance 1 to include CR/LF in lineEnd + lineEnd = currentSlice.GetPosition(1, lineEndPosition.Value); + headerSpan = currentSlice.Slice(reader.Position, lineEnd).ToSpan(); + if (headerSpan[^1] != ByteCR) + { + RejectRequestHeader(headerSpan); + } + + return -1; + } + + // Advance 2 to include CR{LF?} in lineEnd + lineEnd = currentSlice.GetPosition(2, lineEndPosition.Value); + headerSpan = currentSlice.Slice(reader.Position, lineEnd).ToSpan(); + + if (headerSpan.Length < 5) + { + // Less than min possible headerSpan is 5 bytes a:b\r\n + RejectRequestHeader(headerSpan); + } + + if (headerSpan[^2] != ByteCR) + { + // Sequence needs to be CRLF not LF first. + RejectRequestHeader(headerSpan[..^1]); + } + + if (headerSpan[^1] != ByteLF || + + // Exclude the CRLF from the headerLine and parse the header name:value pair + !TryTakeSingleHeader(headerSpan[..^2], headers)) + { + // Sequence needs to be CRLF and not contain an inner CR not part of terminator. + // Not parsable as a valid name:value header pair. + RejectRequestHeader(headerSpan); + } + + return headerSpan.Length; + } + + private bool TryTakeSingleHeader(ReadOnlySpan headerLine, NatsHeaders headers) + { + // We are looking for a colon to terminate the header name. + // However, the header name cannot contain a space or tab so look for all three + // and see which is found first. + var nameEnd = headerLine.IndexOfAny(ByteColon, ByteSpace, ByteTab); + + // If not found length with be -1; casting to uint will turn it to uint.MaxValue + // which will be larger than any possible headerLine.Length. This also serves to eliminate + // the bounds check for the next lookup of headerLine[nameEnd] + if ((uint)nameEnd >= (uint)headerLine.Length) + { + // Colon not found. + return false; + } + + // Early memory read to hide latency + var expectedColon = headerLine[nameEnd]; + if (nameEnd == 0) + { + // Header name is empty. + return false; + } + + if (expectedColon != ByteColon) + { + // Header name space or tab. + return false; + } + + // Skip colon to get to the value start. + var valueStart = nameEnd + 1; + + // Generally there will only be one space, so we will check it directly + if ((uint)valueStart < (uint)headerLine.Length) + { + var ch = headerLine[valueStart]; + if (ch == ByteSpace || ch == ByteTab) + { + // Ignore first whitespace. + valueStart++; + + // More header chars? + if ((uint)valueStart < (uint)headerLine.Length) + { + ch = headerLine[valueStart]; + + // Do a fast check; as we now expect non-space, before moving into loop. + if (ch <= ByteSpace && (ch == ByteSpace || ch == ByteTab)) + { + valueStart++; + + // Is more whitespace, so we will loop to find the end. This is the slow path. + for (; valueStart < headerLine.Length; valueStart++) + { + ch = headerLine[valueStart]; + if (ch != ByteTab && ch != ByteSpace) + { + // Non-whitespace char found, valueStart is now start of value. + break; + } + } + } + } + } + } + + var valueEnd = headerLine.Length - 1; + + // Ignore end whitespace. Generally there will no spaces + // so we will check the first before moving to a loop. + if (valueEnd > valueStart) + { + var ch = headerLine[valueEnd]; + + // Do a fast check; as we now expect non-space, before moving into loop. + if (ch <= ByteSpace && (ch == ByteSpace || ch == ByteTab)) + { + // Is whitespace so move to loop + valueEnd--; + for (; valueEnd > valueStart; valueEnd--) + { + ch = headerLine[valueEnd]; + if (ch != ByteTab && ch != ByteSpace) + { + // Non-whitespace char found, valueEnd is now start of value. + break; + } + } + } + } + + // Range end is exclusive, so add 1 to valueEnd + valueEnd++; + var key = _encoding.GetString(headerLine[..nameEnd]); + var value = _encoding.GetString(headerLine[valueStart..valueEnd]); + if (headers.TryGetValue(key, out var existing)) + { + headers[key] = StringValues.Concat(existing, value); + } + else + { + headers[key] = value; + } + + return true; + } + + [StackTraceHidden] + private void RejectRequestHeader(ReadOnlySpan headerLine) + => throw new NatsException( + $"Protocol error: invalid request header line '{_encoding.GetString(headerLine)}'"); +} diff --git a/src/NATS.Client.Core/Internal/HeaderWriter.cs b/src/NATS.Client.Core/Internal/HeaderWriter.cs new file mode 100644 index 000000000..c76b806eb --- /dev/null +++ b/src/NATS.Client.Core/Internal/HeaderWriter.cs @@ -0,0 +1,100 @@ +using System.Buffers; +using System.Text; + +namespace NATS.Client.Core.Internal; + +internal class HeaderWriter +{ + private const byte ByteCr = (byte)'\r'; + private const byte ByteLf = (byte)'\n'; + private const byte ByteColon = (byte)':'; + private const byte ByteSpace = (byte)' '; + private const byte ByteDel = 127; + private readonly Encoding _encoding; + + public HeaderWriter(Encoding encoding) => _encoding = encoding; + + private static ReadOnlySpan CrLf => new[] { ByteCr, ByteLf }; + + private static ReadOnlySpan ColonSpace => new[] { ByteColon, ByteSpace }; + + internal int Write(in FixedArrayBufferWriter bufferWriter, NatsHeaders headers) + { + var initialCount = bufferWriter.WrittenCount; + foreach (var kv in headers) + { + foreach (var value in kv.Value) + { + if (value != null) + { + // write key + var keyLength = _encoding.GetByteCount(kv.Key); + var keySpan = bufferWriter.GetSpan(keyLength); + _encoding.GetBytes(kv.Key, keySpan); + if (!ValidateKey(keySpan.Slice(0, keyLength))) + { + throw new NatsException( + $"Invalid header key '{kv.Key}': contains colon, space, or other non-printable ASCII characters"); + } + + bufferWriter.Advance(keyLength); + bufferWriter.Write(ColonSpace); + + // write values + var valueLength = _encoding.GetByteCount(value); + var valueSpan = bufferWriter.GetSpan(valueLength); + _encoding.GetBytes(value, valueSpan); + if (!ValidateValue(valueSpan.Slice(0, valueLength))) + { + throw new NatsException($"Invalid header value for key '{kv.Key}': contains CRLF"); + } + + bufferWriter.Advance(valueLength); + bufferWriter.Write(CrLf); + } + } + } + + // Even empty header needs to terminate. + // We will send NATS/1.0 version line + // even if there are no headers. + bufferWriter.Write(CrLf); + + return bufferWriter.WrittenCount - initialCount; + } + + // cannot contain ASCII Bytes <=32, 58, or 127 + private static bool ValidateKey(ReadOnlySpan span) + { + foreach (var b in span) + { + if (b <= ByteSpace || b == ByteColon || b >= ByteDel) + { + return false; + } + } + + return true; + } + + // cannot contain CRLF + private static bool ValidateValue(ReadOnlySpan span) + { + while (true) + { + var pos = span.IndexOf(ByteCr); + if (pos == -1 || pos == span.Length - 1) + { + return true; + } + + pos += 1; + if (span[pos] == ByteLf) + { + return false; + } + + span = span[pos..]; + } + } +} diff --git a/src/NATS.Client.Core/Internal/StringUtils.cs b/src/NATS.Client.Core/Internal/StringExtensions.cs similarity index 93% rename from src/NATS.Client.Core/Internal/StringUtils.cs rename to src/NATS.Client.Core/Internal/StringExtensions.cs index dbd4add75..b2992e547 100644 --- a/src/NATS.Client.Core/Internal/StringUtils.cs +++ b/src/NATS.Client.Core/Internal/StringExtensions.cs @@ -1,6 +1,6 @@ namespace NATS.Client.Core.Internal; -internal static class StringUtils +internal static class StringExtensions { /// /// Allocation free ASCII buffer writer. diff --git a/src/NATS.Client.Core/MessagePublisher.cs b/src/NATS.Client.Core/MessagePublisher.cs deleted file mode 100644 index d894057c8..000000000 --- a/src/NATS.Client.Core/MessagePublisher.cs +++ /dev/null @@ -1,266 +0,0 @@ -using System.Buffers; -using System.Collections.Concurrent; -using Microsoft.Extensions.Logging; -using NATS.Client.Core.Internal; - -namespace NATS.Client.Core; - -// TODO: Clean up message publisher. -internal delegate Task PublishMessage(string subject, string? replyTo, NatsOptions options, ReadOnlySequence buffer, object?[] callbacks); - -internal static class MessagePublisher -{ - // To avoid boxing, cache generic type and invoke it. - private static readonly Func CreatePublisherValue = CreatePublisher; - private static readonly ConcurrentDictionary PublisherCache = new(); - - public static Task PublishAsync(string subject, string? replyTo, Type type, NatsOptions options, in ReadOnlySequence buffer, object?[] callbacks) - { - return PublisherCache.GetOrAdd(type, CreatePublisherValue).Invoke(subject, replyTo, options, buffer, callbacks); - } - - private static PublishMessage CreatePublisher(Type type) - { - if (type == typeof(byte[])) - { - return new ByteArrayMessagePublisher().Publish; - } - else if (type == typeof(ReadOnlyMemory)) - { - return new ReadOnlyMemoryMessagePublisher().PublishAsync; - } - - var publisher = typeof(MessagePublisher<>).MakeGenericType(type)!; - var instance = Activator.CreateInstance(publisher)!; - return (PublishMessage)Delegate.CreateDelegate(typeof(PublishMessage), instance, "Publish", false); - } -} - -internal sealed class MessagePublisher -{ - public void Publish(NatsOptions options, in ReadOnlySequence buffer, object?[] callbacks) - { - T? value; - try - { - value = options!.Serializer.Deserialize(buffer); - } - catch (Exception ex) - { - try - { - options!.LoggerFactory.CreateLogger>().LogError(ex, "Deserialize error during receive subscribed message. Type:{0}", typeof(T).Name); - } - catch - { - } - - return; - } - - try - { - if (!options.UseThreadPoolCallback) - { - foreach (var callback in callbacks!) - { - if (callback != null) - { - try - { - ((Action)callback).Invoke(value); - } - catch (Exception ex) - { - options!.LoggerFactory.CreateLogger>().LogError(ex, "Error occured during publish callback."); - } - } - } - } - else - { - foreach (var callback in callbacks!) - { - if (callback != null) - { - var item = ThreadPoolWorkItem.Create((Action)callback, value, options!.LoggerFactory); - ThreadPool.UnsafeQueueUserWorkItem(item, preferLocal: false); - } - } - } - } - catch (Exception ex) - { - try - { - options!.LoggerFactory.CreateLogger>().LogError(ex, "Error occured during publish callback."); - } - catch - { - } - } - } -} - -internal sealed class ByteArrayMessagePublisher -{ -#pragma warning disable CA1822 -#pragma warning disable VSTHRD200 -#pragma warning disable CS1998 - public async Task Publish(string subject, string? replyTo, NatsOptions? options, ReadOnlySequence buffer, object?[] callbacks) -#pragma warning restore CS1998 -#pragma warning restore VSTHRD200 -#pragma warning restore CA1822 - { - byte[] value; - try - { - if (buffer.IsEmpty) - { - value = Array.Empty(); - } - else - { - value = buffer.ToArray(); - } - } - catch (Exception ex) - { - try - { - options!.LoggerFactory.CreateLogger().LogError(ex, "Deserialize error during receive subscribed message."); - } - catch - { - } - - return; - } - - try - { - if (options is { UseThreadPoolCallback: false }) - { - foreach (var callback in callbacks!) - { - if (callback != null) - { - try - { - ((Action)callback).Invoke(value); - } - catch (Exception ex) - { - options!.LoggerFactory.CreateLogger().LogError(ex, "Error occured during publish callback."); - } - } - } - } - else - { - foreach (var callback in callbacks!) - { - if (callback != null) - { - var item = ThreadPoolWorkItem.Create((Action)callback, value, options!.LoggerFactory); - ThreadPool.UnsafeQueueUserWorkItem(item, preferLocal: false); - } - } - } - } - catch (Exception ex) - { - try - { - options!.LoggerFactory.CreateLogger().LogError(ex, "Error occured during publish callback."); - } - catch - { - } - } - } -} - -internal sealed class ReadOnlyMemoryMessagePublisher -{ - public async Task PublishAsync(string subject, string? replyTo, NatsOptions? options, ReadOnlySequence buffer, object?[] callbacks) - { - ReadOnlyMemory value; - try - { - if (buffer.IsEmpty) - { - value = Array.Empty(); - } - else - { - value = buffer.ToArray(); - } - } - catch (Exception ex) - { - try - { - options!.LoggerFactory.CreateLogger().LogError(ex, "Deserialize error during receive subscribed message."); - } - catch - { - } - - return; - } - - try - { - if (options is { UseThreadPoolCallback: false }) - { - foreach (var callback in callbacks!) - { - if (callback != null) - { - try - { - if (callback is NatsSubBase natsSub) - { - await natsSub.ReceiveAsync(subject, replyTo, buffer).ConfigureAwait(false); - } - else if (callback is Action> action) - { - action.Invoke(value); - } - else - { - throw new NatsException($"Unexpected internal handler type: {callback.GetType().Name}"); - } - } - catch (Exception ex) - { - options!.LoggerFactory.CreateLogger().LogError(ex, "Error occured during publish callback."); - } - } - } - } - else - { - foreach (var callback in callbacks!) - { - if (callback != null) - { - var item = ThreadPoolWorkItem>.Create((Action>)callback, value, options!.LoggerFactory); - ThreadPool.UnsafeQueueUserWorkItem(item, preferLocal: false); - } - } - } - } - catch (Exception ex) - { - try - { - options!.LoggerFactory.CreateLogger().LogError(ex, "Error occured during publish callback."); - } - catch - { - } - } - } -} diff --git a/src/NATS.Client.Core/NatsConnection.Publish.cs b/src/NATS.Client.Core/NatsConnection.Publish.cs index 4a8c7c80f..07bf1f959 100644 --- a/src/NATS.Client.Core/NatsConnection.Publish.cs +++ b/src/NATS.Client.Core/NatsConnection.Publish.cs @@ -8,9 +8,14 @@ public partial class NatsConnection /// public ValueTask PublishAsync(string subject, ReadOnlySequence payload = default, in NatsPubOpts? opts = default, CancellationToken cancellationToken = default) { + var replyTo = opts?.ReplyTo; + var headers = opts?.Headers; + + headers?.SetReadOnly(); + if (ConnectionState == NatsConnectionState.Open) { - var command = AsyncPublishBytesCommand.Create(_pool, GetCommandTimer(cancellationToken), subject, payload); + var command = AsyncPublishBytesCommand.Create(_pool, GetCommandTimer(cancellationToken), subject, replyTo, headers, payload); if (TryEnqueueCommand(command)) { return command.AsValueTask(); @@ -22,9 +27,9 @@ public ValueTask PublishAsync(string subject, ReadOnlySequence payload = d } else { - return WithConnectAsync(subject, payload, cancellationToken, static (self, k, v, token) => + return WithConnectAsync(subject, replyTo, headers, payload, cancellationToken, static (self, s, r, h, p, token) => { - var command = AsyncPublishBytesCommand.Create(self._pool, self.GetCommandTimer(token), k, v); + var command = AsyncPublishBytesCommand.Create(self._pool, self.GetCommandTimer(token), s, r, h, p); return self.EnqueueAndAwaitCommandAsync(command); }); } @@ -40,10 +45,14 @@ public ValueTask PublishAsync(NatsMsg msg, CancellationToken cancellationToken = public ValueTask PublishAsync(string subject, T data, in NatsPubOpts? opts = default, CancellationToken cancellationToken = default) { var replyTo = opts?.ReplyTo; + var serializer = opts?.Serializer ?? Options.Serializer; + var headers = opts?.Headers; + + headers?.SetReadOnly(); if (ConnectionState == NatsConnectionState.Open) { - var command = AsyncPublishCommand.Create(_pool, GetCommandTimer(cancellationToken), subject, replyTo, data, opts?.Serializer ?? Options.Serializer); + var command = AsyncPublishCommand.Create(_pool, GetCommandTimer(cancellationToken), subject, replyTo, headers, data, serializer); if (TryEnqueueCommand(command)) { return command.AsValueTask(); @@ -55,9 +64,9 @@ public ValueTask PublishAsync(string subject, T data, in NatsPubOpts? opts = } else { - return WithConnectAsync(subject, replyTo, data, cancellationToken, static (self, s, r, v, token) => + return WithConnectAsync(subject, replyTo, headers, data, serializer, cancellationToken, static (self, s, r, h, v, ser, token) => { - var command = AsyncPublishCommand.Create(self._pool, self.GetCommandTimer(token), s, r, v, self.Options.Serializer); + var command = AsyncPublishCommand.Create(self._pool, self.GetCommandTimer(token), s, r, h, v, ser); return self.EnqueueAndAwaitCommandAsync(command); }); } diff --git a/src/NATS.Client.Core/NatsConnection.cs b/src/NATS.Client.Core/NatsConnection.cs index 75cb7ec3f..af53645bf 100644 --- a/src/NATS.Client.Core/NatsConnection.cs +++ b/src/NATS.Client.Core/NatsConnection.cs @@ -74,6 +74,7 @@ public NatsConnection(NatsOptions options) InboxPrefix = Encoding.ASCII.GetBytes($"{options.InboxPrefix}{Guid.NewGuid()}."); _logger = options.LoggerFactory.CreateLogger(); _clientOptions = new ClientOptions(Options); + HeaderParser = new HeaderParser(options.HeaderEncoding); } // events @@ -89,6 +90,8 @@ public NatsConnection(NatsOptions options) public ServerInfo? ServerInfo { get; internal set; } // server info is set when received INFO + internal HeaderParser HeaderParser { get; } + /// /// Connect socket and write CONNECT command to nats server. /// @@ -162,9 +165,9 @@ internal void EnqueuePing(AsyncPingCommand pingCommand) pingCommand.SetCanceled(); } - internal ValueTask PublishToClientHandlersAsync(string subject, string? replyTo, int sid, in ReadOnlySequence buffer) + internal ValueTask PublishToClientHandlersAsync(string subject, string? replyTo, int sid, in ReadOnlySequence? headersBuffer, in ReadOnlySequence payloadBuffer) { - return _subscriptionManager.PublishToClientHandlersAsync(subject, replyTo, sid, buffer); + return _subscriptionManager.PublishToClientHandlersAsync(subject, replyTo, sid, headersBuffer, payloadBuffer); } internal void ResetPongCount() @@ -730,6 +733,18 @@ private async ValueTask WithConnectAsync(T1 item1, T2 item2, T3 await coreAsync(this, item1, item2, item3, item4).ConfigureAwait(false); } + private async ValueTask WithConnectAsync(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5, Func coreAsync) + { + await ConnectAsync().ConfigureAwait(false); + await coreAsync(this, item1, item2, item3, item4, item5).ConfigureAwait(false); + } + + private async ValueTask WithConnectAsync(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5, T6 item6, Func coreAsync) + { + await ConnectAsync().ConfigureAwait(false); + await coreAsync(this, item1, item2, item3, item4, item5, item6).ConfigureAwait(false); + } + private async ValueTask WithConnectAsync(Func> coreAsync) { await ConnectAsync().ConfigureAwait(false); diff --git a/src/NATS.Client.Core/NatsHeaders.cs b/src/NATS.Client.Core/NatsHeaders.cs index 82d84b3f4..ff2b4591c 100644 --- a/src/NATS.Client.Core/NatsHeaders.cs +++ b/src/NATS.Client.Core/NatsHeaders.cs @@ -21,6 +21,8 @@ public class NatsHeaders : IDictionary private static readonly IEnumerator> EmptyIEnumeratorType = default(Enumerator); private static readonly IEnumerator EmptyIEnumerator = default(Enumerator); + private int _readonly = 0; + /// /// Initializes a new instance of . /// @@ -117,7 +119,7 @@ StringValues IDictionary.this[string key] /// Gets a value that indicates whether the is in read-only mode. /// /// true if the is in read-only mode; otherwise, false. - public bool IsReadOnly { get; set; } + public bool IsReadOnly => Volatile.Read(ref _readonly) == 1; /// /// Gets the collection of HTTP header names in this instance. @@ -331,6 +333,8 @@ IEnumerator IEnumerable.GetEnumerator() return Store.GetEnumerator(); } + internal void SetReadOnly() => Interlocked.Exchange(ref _readonly, 1); + private void ThrowIfReadOnly() { if (IsReadOnly) diff --git a/src/NATS.Client.Core/NatsMsg.cs b/src/NATS.Client.Core/NatsMsg.cs index 34921197b..f3b84413e 100644 --- a/src/NATS.Client.Core/NatsMsg.cs +++ b/src/NATS.Client.Core/NatsMsg.cs @@ -13,12 +13,8 @@ public abstract record NatsMsgBase(string Subject) public string? ReplyTo { get; init; } - // TODO: Implement headers in NatsMsg - // public NatsHeaders? Headers - // { - // get => throw new NotImplementedException(); - // set => throw new NotImplementedException(); - // } + public NatsHeaders? Headers { get; init; } + public ValueTask ReplyAsync(ReadOnlySequence data = default, in NatsPubOpts? opts = default, CancellationToken cancellationToken = default) { CheckReplyPreconditions(); diff --git a/src/NATS.Client.Core/NatsOptions.cs b/src/NATS.Client.Core/NatsOptions.cs index 776a4ed42..047c05b42 100644 --- a/src/NATS.Client.Core/NatsOptions.cs +++ b/src/NATS.Client.Core/NatsOptions.cs @@ -1,3 +1,4 @@ +using System.Text; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -10,6 +11,7 @@ namespace NATS.Client.Core; /// /// /// +/// /// /// /// @@ -29,12 +31,14 @@ namespace NATS.Client.Core; /// /// /// +/// public sealed record NatsOptions ( string Url, string Name, bool Echo, bool Verbose, + bool Headers, NatsAuthOptions AuthOptions, TlsOptions TlsOptions, INatsSerializer Serializer, @@ -53,13 +57,15 @@ public sealed record NatsOptions TimeSpan RequestTimeout, TimeSpan CommandTimeout, TimeSpan SubscriptionCleanUpInterval, - int? WriterCommandBufferLimit) + int? WriterCommandBufferLimit, + Encoding HeaderEncoding) { public static readonly NatsOptions Default = new( Url: "nats://localhost:4222", Name: "NATS .Net Client", Echo: true, Verbose: false, + Headers: true, AuthOptions: NatsAuthOptions.Default, TlsOptions: TlsOptions.Default, Serializer: JsonNatsSerializer.Default, @@ -78,7 +84,8 @@ public sealed record NatsOptions RequestTimeout: TimeSpan.FromMinutes(1), CommandTimeout: TimeSpan.FromMinutes(1), SubscriptionCleanUpInterval: TimeSpan.FromMinutes(5), - WriterCommandBufferLimit: null); + WriterCommandBufferLimit: null, + HeaderEncoding: Encoding.ASCII); internal NatsUri[] GetSeedUris() { diff --git a/src/NATS.Client.Core/NatsPubOpts.cs b/src/NATS.Client.Core/NatsPubOpts.cs index ffcef27c8..dbac18df0 100644 --- a/src/NATS.Client.Core/NatsPubOpts.cs +++ b/src/NATS.Client.Core/NatsPubOpts.cs @@ -4,11 +4,7 @@ public readonly record struct NatsPubOpts { public string? ReplyTo { get; init; } - // TODO: Implement headers in NatsPubOpts - // public NatsHeaders? Headers - // { - // get => throw new NotImplementedException(); - // init => throw new NotImplementedException(); - // } + public NatsHeaders? Headers { get; init; } + public INatsSerializer? Serializer { get; init; } } diff --git a/src/NATS.Client.Core/NatsReadProtocolProcessor.cs b/src/NATS.Client.Core/NatsReadProtocolProcessor.cs index d776ebbde..a2a788caf 100644 --- a/src/NATS.Client.Core/NatsReadProtocolProcessor.cs +++ b/src/NATS.Client.Core/NatsReadProtocolProcessor.cs @@ -1,6 +1,7 @@ using System.Buffers; using System.Buffers.Text; using System.Collections.Concurrent; +using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; @@ -200,7 +201,7 @@ private async Task ReadLoopAsync() buffer = buffer.Slice(buffer.GetPosition(3, positionBeforePayload.Value)); } - await _connection.PublishToClientHandlersAsync(subject, replyTo, sid, ReadOnlySequence.Empty).ConfigureAwait(false); + await _connection.PublishToClientHandlersAsync(subject, replyTo, sid, null, ReadOnlySequence.Empty).ConfigureAwait(false); } else { @@ -221,9 +222,62 @@ private async Task ReadLoopAsync() buffer = buffer.Slice(buffer.GetPosition(2, payloadSlice.End)); // payload + \r\n - await _connection.PublishToClientHandlersAsync(subject, replyTo, sid, payloadSlice).ConfigureAwait(false); + await _connection.PublishToClientHandlersAsync(subject, replyTo, sid, null, payloadSlice).ConfigureAwait(false); } } + else if (code == ServerOpCodes.HMsg) + { + // https://docs.nats.io/reference/reference-protocols/nats-protocol#hmsg + // HMSG [reply-to] <#header bytes> <#total bytes>\r\n[headers]\r\n\r\n[payload]\r\n + + // Find the end of 'HMSG' first message line + var positionBeforeNatsHeader = buffer.PositionOf((byte)'\n'); + if (positionBeforeNatsHeader == null) + { + _socketReader.AdvanceTo(buffer.Start); + buffer = await _socketReader.ReadUntilReceiveNewLineAsync().ConfigureAwait(false); + positionBeforeNatsHeader = buffer.PositionOf((byte)'\n')!; + } + + var msgHeader = buffer.Slice(0, positionBeforeNatsHeader.Value); + var (subject, sid, replyTo, headersLength, totalLength) = ParseHMessageHeader(msgHeader); + var payloadLength = totalLength - headersLength; + Debug.Assert(payloadLength >= 0, "Protocol error: illogical header and total lengths"); + + var headerBegin = buffer.GetPosition(1, positionBeforeNatsHeader.Value); + var totalSlice = buffer.Slice(headerBegin); + + // Read rest of the message if it's not already in the buffer + if (totalSlice.Length < totalLength + 2) + { + _socketReader.AdvanceTo(headerBegin); + + // Read headers + payload + \r\n + var size = totalLength - (int)totalSlice.Length + 2; + buffer = await _socketReader.ReadAtLeastAsync(size).ConfigureAwait(false); + totalSlice = buffer.Slice(0, totalLength); + } + else + { + totalSlice = totalSlice.Slice(0, totalLength); + } + + // Prepare buffer for the next message by removing 'headers + payload + \r\n' from it + buffer = buffer.Slice(buffer.GetPosition(2, totalSlice.End)); + + var versionLength = CommandConstants.NatsHeaders10NewLine.Length; + var versionSlice = totalSlice.Slice(0, versionLength); + if (!versionSlice.ToSpan().SequenceEqual(CommandConstants.NatsHeaders10NewLine)) + { + throw new NatsException("Protocol error: header version mismatch"); + } + + var headerSlice = totalSlice.Slice(versionLength, headersLength - versionLength); + var payloadSlice = totalSlice.Slice(headersLength, payloadLength); + + await _connection.PublishToClientHandlersAsync(subject, replyTo, sid, headerSlice, payloadSlice) + .ConfigureAwait(false); + } else { buffer = await DispatchCommandAsync(code, buffer).ConfigureAwait(false); @@ -436,11 +490,65 @@ private async ValueTask> DispatchCommandAsync(int code, R return ParseMessageHeader(buffer); } + // https://docs.nats.io/reference/reference-protocols/nats-protocol#hmsg + // HMSG [reply-to] <#header bytes> <#total bytes>\r\n[headers]\r\n\r\n[payload]\r\n + private (string subject, int sid, string? replyTo, int headersLength, int totalLength) ParseHMessageHeader(ReadOnlySpan msgHeader) + { + // 'HMSG' literal + Split(msgHeader, out _, out msgHeader); + + Split(msgHeader, out var subjectBytes, out msgHeader); + Split(msgHeader, out var sidBytes, out msgHeader); + Split(msgHeader, out var replyToOrHeaderLenBytes, out msgHeader); + Split(msgHeader, out var headerLenOrTotalLenBytes, out msgHeader); + + var subject = Encoding.ASCII.GetString(subjectBytes); + var sid = GetInt32(sidBytes); + + // We don't have the optional reply-to field + if (msgHeader.Length == 0) + { + var headersLength = GetInt32(replyToOrHeaderLenBytes); + var totalLen = GetInt32(headerLenOrTotalLenBytes); + return (subject, sid, null, headersLength, totalLen); + } + + // There is more data because of the reply-to field + else + { + var replyToBytes = replyToOrHeaderLenBytes; + var replyTo = Encoding.ASCII.GetString(replyToBytes); + + var headerLen = GetInt32(headerLenOrTotalLenBytes); + + var lastBytes = msgHeader; + var totalLen = GetInt32(lastBytes); + + return (subject, sid, replyTo, headerLen, totalLen); + } + } + + private (string subject, int sid, string? replyTo, int headersLength, int totalLength) ParseHMessageHeader(in ReadOnlySequence msgHeader) + { + if (msgHeader.IsSingleSegment) + { + return ParseHMessageHeader(msgHeader.FirstSpan); + } + + // header parsing use Slice frequently so ReadOnlySequence is high cost, should use Span. + // msgheader is not too long, ok to use stackalloc. + // TODO: Fix possible stack overflow + Span buffer = stackalloc byte[(int)msgHeader.Length]; + msgHeader.CopyTo(buffer); + return ParseHMessageHeader(buffer); + } + internal static class ServerOpCodes { // All sent by server commands as int(first 4 characters(includes space, newline)). public const int Info = 1330007625; // Encoding.ASCII.GetBytes("INFO") |> MemoryMarshal.Read public const int Msg = 541545293; // Encoding.ASCII.GetBytes("MSG ") |> MemoryMarshal.Read + public const int HMsg = 1196641608; // Encoding.ASCII.GetBytes("HMSG") |> MemoryMarshal.Read public const int Ping = 1196312912; // Encoding.ASCII.GetBytes("PING") |> MemoryMarshal.Read public const int Pong = 1196314448; // Encoding.ASCII.GetBytes("PONG") |> MemoryMarshal.Read public const int Ok = 223039275; // Encoding.ASCII.GetBytes("+OK\r") |> MemoryMarshal.Read diff --git a/src/NATS.Client.Core/NatsSub.cs b/src/NATS.Client.Core/NatsSub.cs index b7be73839..77cd2d1b3 100644 --- a/src/NATS.Client.Core/NatsSub.cs +++ b/src/NATS.Client.Core/NatsSub.cs @@ -1,5 +1,7 @@ using System.Buffers; +using System.Text; using System.Threading.Channels; +using NATS.Client.Core.Internal; namespace NATS.Client.Core; @@ -29,7 +31,7 @@ public virtual ValueTask DisposeAsync() return Manager.RemoveAsync(Sid); } - internal abstract ValueTask ReceiveAsync(string subject, string? replyTo, ReadOnlySequence buffer); + internal abstract ValueTask ReceiveAsync(string subject, string? replyTo, in ReadOnlySequence? headersBuffer, in ReadOnlySequence payloadBuffer); } public sealed class NatsSub : NatsSubBase @@ -55,12 +57,25 @@ public override ValueTask DisposeAsync() return base.DisposeAsync(); } - internal override ValueTask ReceiveAsync(string subject, string? replyTo, ReadOnlySequence buffer) + internal override ValueTask ReceiveAsync(string subject, string? replyTo, in ReadOnlySequence? headersBuffer, in ReadOnlySequence payloadBuffer) { - return _msgs.Writer.WriteAsync(new NatsMsg(subject, buffer.ToArray()) + NatsHeaders? natsHeaders = null; + if (headersBuffer != null) + { + natsHeaders = new NatsHeaders(); + if (!Connection.HeaderParser.ParseHeaders(new SequenceReader(headersBuffer.Value), natsHeaders)) + { + throw new NatsException("Error parsing headers"); + } + + natsHeaders.SetReadOnly(); + } + + return _msgs.Writer.WriteAsync(new NatsMsg(subject, payloadBuffer.ToArray()) { Connection = Connection, ReplyTo = replyTo, + Headers = natsHeaders, }); } } @@ -88,14 +103,28 @@ public override ValueTask DisposeAsync() return base.DisposeAsync(); } - internal override ValueTask ReceiveAsync(string subject, string? replyTo, ReadOnlySequence buffer) + internal override ValueTask ReceiveAsync(string subject, string? replyTo, in ReadOnlySequence? headersBuffer, in ReadOnlySequence payloadBuffer) { var serializer = Serializer; - var data = serializer.Deserialize(buffer); + var data = serializer.Deserialize(payloadBuffer); + + NatsHeaders? natsHeaders = null; + if (headersBuffer != null) + { + natsHeaders = new NatsHeaders(); + if (!Connection.HeaderParser.ParseHeaders(new SequenceReader(headersBuffer.Value), natsHeaders)) + { + throw new NatsException("Error parsing headers"); + } + + natsHeaders.SetReadOnly(); + } + return _msgs.Writer.WriteAsync(new NatsMsg(subject, data!) { Connection = Connection, ReplyTo = replyTo, + Headers = natsHeaders, }); } } diff --git a/src/NATS.Client.Core/SubscriptionManager.cs b/src/NATS.Client.Core/SubscriptionManager.cs index 3049a73d0..50d5b8b97 100644 --- a/src/NATS.Client.Core/SubscriptionManager.cs +++ b/src/NATS.Client.Core/SubscriptionManager.cs @@ -56,7 +56,7 @@ public async ValueTask> AddAsync(string subject, string? queueGrou return sub; } - public ValueTask PublishToClientHandlersAsync(string subject, string? replyTo, int sid, in ReadOnlySequence buffer) + public ValueTask PublishToClientHandlersAsync(string subject, string? replyTo, int sid, in ReadOnlySequence? headersBuffer, in ReadOnlySequence payloadBuffer) { int? orphanSid = null; lock (_gate) @@ -65,7 +65,7 @@ public ValueTask PublishToClientHandlersAsync(string subject, string? replyTo, i { if (subRef.TryGetTarget(out var sub)) { - return sub.ReceiveAsync(subject, replyTo, buffer); + return sub.ReceiveAsync(subject, replyTo, headersBuffer, payloadBuffer); } else { diff --git a/tests/NATS.Client.Core.Tests/NatsConnectionTest.Auth.cs b/tests/NATS.Client.Core.Tests/NatsConnectionTest.Auth.cs index 7714d4235..5275bed9d 100644 --- a/tests/NATS.Client.Core.Tests/NatsConnectionTest.Auth.cs +++ b/tests/NATS.Client.Core.Tests/NatsConnectionTest.Auth.cs @@ -117,7 +117,7 @@ public async Task UserCredentialAuthTest(string name, string serverConfig, NatsO var signalComplete2 = new WaitSignal(); var natsSub = await subConnection.SubscribeAsync(subject); - natsSub.Register(x => + var register = natsSub.Register(x => { _output.WriteLine($"Received: {x}"); if (x.Data == 1) @@ -148,34 +148,35 @@ public async Task UserCredentialAuthTest(string name, string serverConfig, NatsO _output.WriteLine("AUTHENTICATED RE-CONNECTION"); await pubConnection.PublishAsync(subject, 2); await signalComplete2; + + await natsSub.DisposeAsync(); + await register; } } internal static class NatsMsgTestUtils { - internal static NatsSub? Register(this NatsSub? sub, Action> action) + internal static Task Register(this NatsSub? sub, Action> action) { - if (sub == null) return null; - Task.Run(async () => + if (sub == null) return Task.CompletedTask; + return Task.Run(async () => { await foreach (var natsMsg in sub.Msgs.ReadAllAsync()) { action(natsMsg); } }); - return sub; } - internal static NatsSub? Register(this NatsSub? sub, Action action) + internal static Task Register(this NatsSub? sub, Action action) { - if (sub == null) return null; - Task.Run(async () => + if (sub == null) return Task.CompletedTask; + return Task.Run(async () => { await foreach (var natsMsg in sub.Msgs.ReadAllAsync()) { action(natsMsg); } }); - return sub; } } diff --git a/tests/NATS.Client.Core.Tests/NatsConnectionTest.Headers.cs b/tests/NATS.Client.Core.Tests/NatsConnectionTest.Headers.cs new file mode 100644 index 000000000..e00a3424e --- /dev/null +++ b/tests/NATS.Client.Core.Tests/NatsConnectionTest.Headers.cs @@ -0,0 +1,76 @@ +namespace NATS.Client.Core.Tests; + +public abstract partial class NatsConnectionTest +{ + [Fact] + public async Task HeaderParsingTest() + { + await using var server = new NatsServer(_output, _transportType); + + await using var nats = server.CreateClientConnection(); + + var sync = 0; + var signal1 = new WaitSignal>(); + var signal2 = new WaitSignal>(); + var sub = await nats.SubscribeAsync("foo"); + var reg = sub.Register(m => + { + if (m.Data < 10) + { + Interlocked.Exchange(ref sync, m.Data); + return; + } + + if (m.Data == 100) + signal1.Pulse(m); + if (m.Data == 200) + signal2.Pulse(m); + }); + + await Retry.Until( + "subscription is active", + () => Volatile.Read(ref sync) == 1, + async () => await nats.PublishAsync("foo", 1)); + + var headers = new NatsHeaders + { + ["Test-Header-Key"] = "test-header-value", + ["Multi"] = new[] { "multi-value-0", "multi-value-1" }, + }; + Assert.False(headers.IsReadOnly); + + // Send with headers + await nats.PublishAsync("foo", 100, new NatsPubOpts { Headers = headers }); + + Assert.True(headers.IsReadOnly); + Assert.Throws(() => + { + headers["should-not-set"] = "value"; + }); + + var msg1 = await signal1; + Assert.Equal(100, msg1.Data); + Assert.NotNull(msg1.Headers); + Assert.Equal(2, msg1.Headers!.Count); + + Assert.True(msg1.Headers!.ContainsKey("Test-Header-Key")); + Assert.Single(msg1.Headers["Test-Header-Key"].ToArray()); + Assert.Equal("test-header-value", msg1.Headers["Test-Header-Key"]); + + Assert.True(msg1.Headers!.ContainsKey("Multi")); + Assert.Equal(2, msg1.Headers["Multi"].Count); + Assert.Equal("multi-value-0", msg1.Headers["Multi"][0]); + Assert.Equal("multi-value-1", msg1.Headers["Multi"][1]); + + // Send empty headers + await nats.PublishAsync("foo", 200, new NatsPubOpts { Headers = new NatsHeaders() }); + + var msg2 = await signal2; + Assert.Equal(200, msg2.Data); + Assert.NotNull(msg2.Headers); + Assert.Empty(msg2.Headers!); + + await sub.DisposeAsync(); + await reg; + } +} diff --git a/tests/NATS.Client.Core.Tests/NatsConnectionTest.QueueGroups.cs b/tests/NATS.Client.Core.Tests/NatsConnectionTest.QueueGroups.cs index 22a01c1cb..6a785ff50 100644 --- a/tests/NATS.Client.Core.Tests/NatsConnectionTest.QueueGroups.cs +++ b/tests/NATS.Client.Core.Tests/NatsConnectionTest.QueueGroups.cs @@ -5,7 +5,8 @@ public abstract partial class NatsConnectionTest [Fact] public async Task QueueGroupsTest() { - const int messageCount = 20; + // Use high enough count to create some distribution among subscribers. + const int messageCount = 100; await using var server = new NatsServer(_output, _transportType); @@ -21,12 +22,19 @@ public async Task QueueGroupsTest() cts.Token.Register(() => signal.Pulse()); var count = 0; + var sync1 = 0; var messages1 = new List(); var reader1 = Task.Run( async () => { await foreach (var msg in sub1.Msgs.ReadAllAsync(cts.Token)) { + if (msg.Subject == "foo.sync") + { + Interlocked.Exchange(ref sync1, 1); + continue; + } + Assert.Equal($"foo.xyz{msg.Data}", msg.Subject); lock (messages1) messages1.Add(msg.Data); var total = Interlocked.Increment(ref count); @@ -35,12 +43,19 @@ public async Task QueueGroupsTest() }, cts.Token); + var sync2 = 0; var messages2 = new List(); var reader2 = Task.Run( async () => { await foreach (var msg in sub2.Msgs.ReadAllAsync(cts.Token)) { + if (msg.Subject == "foo.sync") + { + Interlocked.Exchange(ref sync2, 1); + continue; + } + Assert.Equal($"foo.xyz{msg.Data}", msg.Subject); lock (messages2) messages2.Add(msg.Data); var total = Interlocked.Increment(ref count); @@ -49,8 +64,10 @@ public async Task QueueGroupsTest() }, cts.Token); - await conn1.PingAsync(); - await conn2.PingAsync(); + await Retry.Until( + "subscriptions are active", + () => Volatile.Read(ref sync1) + Volatile.Read(ref sync2) == 2, + async () => await conn3.PublishAsync("foo.sync", 0)); for (int i = 0; i < messageCount; i++) { diff --git a/tests/NATS.Client.Core.Tests/NatsConnectionTest.Sharding.cs b/tests/NATS.Client.Core.Tests/NatsConnectionTest.Sharding.cs index 1b835541a..230036f19 100644 --- a/tests/NATS.Client.Core.Tests/NatsConnectionTest.Sharding.cs +++ b/tests/NATS.Client.Core.Tests/NatsConnectionTest.Sharding.cs @@ -35,9 +35,12 @@ public async Task ShardingConnectionTest() var l1 = new List(); var l2 = new List(); var l3 = new List(); - (await shardedConnection.GetCommand("foo").SubscribeAsync()).Register(msg => l1.Add(msg.Data)); - (await shardedConnection.GetCommand("bar").SubscribeAsync()).Register(msg => l2.Add(msg.Data)); - (await shardedConnection.GetCommand("baz").SubscribeAsync()).Register(msg => l3.Add(msg.Data)); + var sub1 = await shardedConnection.GetCommand("foo").SubscribeAsync(); + var reg1 = sub1.Register(msg => l1.Add(msg.Data)); + var sub2 = await shardedConnection.GetCommand("bar").SubscribeAsync(); + var reg2 = sub2.Register(msg => l2.Add(msg.Data)); + var sub3 = await shardedConnection.GetCommand("baz").SubscribeAsync(); + var reg3 = sub3.Register(msg => l3.Add(msg.Data)); await shardedConnection.GetCommand("foo").PublishAsync(10); await shardedConnection.GetCommand("bar").PublishAsync(20); @@ -54,5 +57,12 @@ public async Task ShardingConnectionTest() var r = await shardedConnection.GetCommand("foobarbaz").RequestAsync(100); r.ShouldBe(10000); + + await sub1.DisposeAsync(); + await reg1; + await sub2.DisposeAsync(); + await reg2; + await sub3.DisposeAsync(); + await reg3; } } diff --git a/tests/NATS.Client.Core.Tests/NatsConnectionTest.cs b/tests/NATS.Client.Core.Tests/NatsConnectionTest.cs index d3b714faa..6916c930d 100644 --- a/tests/NATS.Client.Core.Tests/NatsConnectionTest.cs +++ b/tests/NATS.Client.Core.Tests/NatsConnectionTest.cs @@ -28,8 +28,8 @@ public async Task SimplePubSubTest() var signalComplete = new WaitSignal(); var list = new List(); - await using var sub = await subConnection.SubscribeAsync(subject); - sub.Register(x => + var sub = await subConnection.SubscribeAsync(subject); + var register = sub.Register(x => { _output.WriteLine($"Received: {x.Data}"); list.Add(x.Data); @@ -46,6 +46,8 @@ public async Task SimplePubSubTest() } await signalComplete; + await sub.DisposeAsync(); + await register; list.ShouldEqual(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); } @@ -67,7 +69,8 @@ public async Task EncodingTest() var actual = new List(); var signalComplete = new WaitSignal(); - await using var d = (await subConnection.SubscribeAsync(key)).Register(x => + var sub = await subConnection.SubscribeAsync(key); + var register = sub.Register(x => { actual.Add(x.Data); if (x.Data.Id == 30) @@ -83,6 +86,8 @@ public async Task EncodingTest() await pubConnection.PublishAsync(key, three); await signalComplete; + await sub.DisposeAsync(); + await register; actual.ShouldEqual(new[] { one, two, three }); } @@ -102,14 +107,25 @@ public async Task RequestTest(int minSize) var subject = Guid.NewGuid().ToString(); var text = new StringBuilder(minSize).Insert(0, "a", minSize).ToString(); + var sync = 0; await using var replyHandle = await subConnection.ReplyAsync(subject, x => { + if (x < 10) + { + Interlocked.Exchange(ref sync, x); + return "sync"; + } + if (x == 100) throw new Exception(); return text + x; }); - await Task.Delay(1000); + await Retry.Until( + "reply handle is ready", + () => Volatile.Read(ref sync) == 1, + async () => await pubConnection.PublishAsync(subject, 1, new NatsPubOpts { ReplyTo = "ignore" }), + retryDelay: TimeSpan.FromSeconds(1)); var v = await pubConnection.RequestAsync(subject, 9999); v.Should().Be(text + 9999); @@ -137,7 +153,7 @@ public async Task ReconnectSingleTest() ServerDisposeReturnsPorts = false, }; await using var server = new NatsServer(_output, _transportType, options); - var key = Guid.NewGuid().ToString(); + var subject = Guid.NewGuid().ToString(); await using var subConnection = server.CreateClientConnection(); await using var pubConnection = server.CreateClientConnection(); @@ -145,10 +161,18 @@ public async Task ReconnectSingleTest() await pubConnection.ConnectAsync(); // wait open var list = new List(); + var sync = 0; var waitForReceive300 = new WaitSignal(); var waitForReceiveFinish = new WaitSignal(); - var d = (await subConnection.SubscribeAsync(key)).Register(x => + var sub = await subConnection.SubscribeAsync(subject); + var reg = sub.Register(x => { + if (x.Data < 10) + { + Interlocked.Exchange(ref sync, x.Data); + return; + } + _output.WriteLine("RECEIVED: " + x.Data); list.Add(x.Data); if (x.Data == 300) @@ -161,11 +185,15 @@ public async Task ReconnectSingleTest() waitForReceiveFinish.Pulse(); } }); - await subConnection.PingAsync(); // wait for subscribe complete - await pubConnection.PublishAsync(key, 100); - await pubConnection.PublishAsync(key, 200); - await pubConnection.PublishAsync(key, 300); + await Retry.Until( + "subscription is active (1)", + () => Volatile.Read(ref sync) == 1, + async () => await pubConnection.PublishAsync(subject, 1)); + + await pubConnection.PublishAsync(subject, 100); + await pubConnection.PublishAsync(subject, 200); + await pubConnection.PublishAsync(subject, 300); _output.WriteLine("TRY WAIT RECEIVE 300"); await waitForReceive300; @@ -184,11 +212,19 @@ public async Task ReconnectSingleTest() await subConnection.ConnectAsync(); // wait open again await pubConnection.ConnectAsync(); // wait open again + await Retry.Until( + "subscription is active (2)", + () => Volatile.Read(ref sync) == 2, + async () => await pubConnection.PublishAsync(subject, 2)); + _output.WriteLine("RECONNECT COMPLETE, PUBLISH 400 and 500"); - await pubConnection.PublishAsync(key, 400); - await pubConnection.PublishAsync(key, 500); + await pubConnection.PublishAsync(subject, 400); + await pubConnection.PublishAsync(subject, 500); await waitForReceiveFinish; + await sub.DisposeAsync(); + await reg; + list.ShouldEqual(100, 200, 300, 400, 500); } @@ -198,7 +234,7 @@ public async Task ReconnectClusterTest() await using var cluster = new NatsCluster(_output, _transportType); await Task.Delay(TimeSpan.FromSeconds(5)); // wait for cluster completely connected. - var key = Guid.NewGuid().ToString(); + var subject = Guid.NewGuid().ToString(); await using var connection1 = cluster.Server1.CreateClientConnection(); await using var connection2 = cluster.Server2.CreateClientConnection(); @@ -220,10 +256,18 @@ public async Task ReconnectClusterTest() connection3.ServerInfo!.ClientConnectUrls!.Select(x => new NatsUri(x, true).Port).Distinct().Count().ShouldBe(3); var list = new List(); + var sync = 0; var waitForReceive300 = new WaitSignal(); var waitForReceiveFinish = new WaitSignal(); - var d = (await connection1.SubscribeAsync(key)).Register(x => + var sub = await connection1.SubscribeAsync(subject); + var reg = sub.Register(x => { + if (x.Data < 10) + { + Interlocked.Exchange(ref sync, x.Data); + return; + } + _output.WriteLine("RECEIVED: " + x.Data); list.Add(x.Data); if (x.Data == 300) @@ -236,11 +280,16 @@ public async Task ReconnectClusterTest() waitForReceiveFinish.Pulse(); } }); - await connection1.PingAsync(); // wait for subscribe complete - await connection2.PublishAsync(key, 100); - await connection2.PublishAsync(key, 200); - await connection2.PublishAsync(key, 300); + await Retry.Until( + "subscription is active (1)", + () => Volatile.Read(ref sync) == 1, + async () => await connection2.PublishAsync(subject, 1), + retryDelay: TimeSpan.FromSeconds(.5)); + + await connection2.PublishAsync(subject, 100); + await connection2.PublishAsync(subject, 200); + await connection2.PublishAsync(subject, 300); await waitForReceive300; var disconnectSignal = connection1.ConnectionDisconnectedAsAwaitable(); // register disconnect before kill @@ -249,14 +298,25 @@ public async Task ReconnectClusterTest() await cluster.Server1.DisposeAsync(); // process kill await disconnectSignal; + Net.WaitForTcpPortToClose(cluster.Server1.ConnectionPort); + await connection1.ConnectAsync(); // wait for reconnect complete. - connection1.ServerInfo!.Port.Should() - .BeOneOf(cluster.Server2.Options.ServerPort, cluster.Server3.Options.ServerPort); + connection1.ServerInfo!.Port.Should().BeOneOf(cluster.Server2.ConnectionPort, cluster.Server3.ConnectionPort); + + await Retry.Until( + "subscription is active (2)", + () => Volatile.Read(ref sync) == 2, + async () => await connection2.PublishAsync(subject, 2), + retryDelay: TimeSpan.FromSeconds(.5)); - await connection2.PublishAsync(key, 400); - await connection2.PublishAsync(key, 500); + await connection2.PublishAsync(subject, 400); + await connection2.PublishAsync(subject, 500); await waitForReceiveFinish; + + await sub.DisposeAsync(); + await reg; + list.ShouldEqual(100, 200, 300, 400, 500); } } diff --git a/tests/NATS.Client.Core.Tests/NatsHeaderTest.cs b/tests/NATS.Client.Core.Tests/NatsHeaderTest.cs new file mode 100644 index 000000000..63a9b2b20 --- /dev/null +++ b/tests/NATS.Client.Core.Tests/NatsHeaderTest.cs @@ -0,0 +1,69 @@ +using System.Buffers; +using System.Text; + +namespace NATS.Client.Core.Tests; + +public class NatsHeaderTest +{ + private readonly ITestOutputHelper _output; + + public NatsHeaderTest(ITestOutputHelper output) => _output = output; + + [Fact] + public void WriterTests() + { + var headers = new NatsHeaders + { + ["k1"] = "v1", + ["k2"] = new[] { "v2-0", "v2-1" }, + ["a-long-header-key"] = "value", + ["key"] = "a-long-header-value", + }; + var writer = new HeaderWriter(Encoding.UTF8); + var buffer = new FixedArrayBufferWriter(); + var written = writer.Write(buffer, headers); + + var text = "k1: v1\r\nk2: v2-0\r\nk2: v2-1\r\na-long-header-key: value\r\nkey: a-long-header-value\r\n\r\n"; + var expected = new Span(Encoding.UTF8.GetBytes(text)); + + Assert.Equal(expected.Length, written); + Assert.True(expected.SequenceEqual(buffer.WrittenSpan)); + +#if DEBUG + _output.WriteLine($"Buffer:\n{buffer.WrittenSpan.Dump()}"); +#endif + } + + [Fact] + public void ParserTests() + { + var parser = new HeaderParser(Encoding.UTF8); + var text = "k1: v1\r\nk2: v2-0\r\nk2: v2-1\r\na-long-header-key: value\r\nkey: a-long-header-value\r\n\r\n"; + var input = new SequenceReader(new ReadOnlySequence(Encoding.UTF8.GetBytes(text))); + var headers = new NatsHeaders(); + parser.ParseHeaders(input, headers); + +#if DEBUG + _output.WriteLine($"Headers:\n{headers.Dump()}"); +#endif + + Assert.Equal(4, headers.Count); + + Assert.True(headers.ContainsKey("k1")); + Assert.Single(headers["k1"].ToArray()); + Assert.Equal("v1", headers["k1"]); + + Assert.True(headers.ContainsKey("k2")); + Assert.Equal(2, headers["k2"].ToArray().Length); + Assert.Equal("v2-0", headers["k2"][0]); + Assert.Equal("v2-1", headers["k2"][1]); + + Assert.True(headers.ContainsKey("a-long-header-key")); + Assert.Single(headers["a-long-header-key"].ToArray()); + Assert.Equal("value", headers["a-long-header-key"]); + + Assert.True(headers.ContainsKey("key")); + Assert.Single(headers["key"].ToArray()); + Assert.Equal("a-long-header-value", headers["key"]); + } +} diff --git a/tests/NATS.Client.Core.Tests/ProtocolTest.cs b/tests/NATS.Client.Core.Tests/ProtocolTest.cs new file mode 100644 index 000000000..3b948646c --- /dev/null +++ b/tests/NATS.Client.Core.Tests/ProtocolTest.cs @@ -0,0 +1,156 @@ +namespace NATS.Client.Core.Tests; + +public class ProtocolTest +{ + private readonly ITestOutputHelper _output; + + public ProtocolTest(ITestOutputHelper output) => _output = output; + + [Fact] + public async Task Subscription_with_same_subject() + { + await using var server = new NatsServer(_output, TransportType.Tcp); + var nats1 = server.CreateClientConnection(); + var (nats2, proxy) = server.CreateProxiedClientConnection(); + + var sub1 = await nats2.SubscribeAsync("foo.bar"); + var sub2 = await nats2.SubscribeAsync("foo.bar"); + var sub3 = await nats2.SubscribeAsync("foo.baz"); + + var sync1 = 0; + var sync2 = 0; + var sync3 = 0; + var count = new WaitSignal(3); + + var reg1 = sub1.Register(m => + { + if (m.Data == 0) + { + Interlocked.Exchange(ref sync1, 1); + return; + } + + count.Pulse(m.Subject == "foo.bar" ? null : new Exception($"Subject mismatch {m.Subject}")); + }); + + var reg2 = sub2.Register(m => + { + if (m.Data == 0) + { + Interlocked.Exchange(ref sync2, 1); + return; + } + + count.Pulse(m.Subject == "foo.bar" ? null : new Exception($"Subject mismatch {m.Subject}")); + }); + + var reg3 = sub3.Register(m => + { + if (m.Data == 0) + { + Interlocked.Exchange(ref sync3, 1); + return; + } + + count.Pulse(m.Subject == "foo.baz" ? null : new Exception($"Subject mismatch {m.Subject}")); + }); + + // Since subscription and publishing are sent through different connections there is + // a race where one or more subscriptions are made after the publishing happens. + // So, we make sure subscribers are accepted by the server before we send any test data. + await Retry.Until( + "all subscriptions are active", + () => Volatile.Read(ref sync1) + Volatile.Read(ref sync2) + Volatile.Read(ref sync3) == 3, + async () => + { + await nats1.PublishAsync("foo.bar", 0); + await nats1.PublishAsync("foo.baz", 0); + }); + + await nats1.PublishAsync("foo.bar", 1); + await nats1.PublishAsync("foo.baz", 1); + + // Wait until we received all test data + await count; + + var frames = proxy.ClientFrames.OrderBy(f => f.Message).ToList(); + + foreach (var frame in frames) + { + _output.WriteLine($"[PROXY] {frame}"); + } + + Assert.Equal(3, frames.Count); + Assert.StartsWith("SUB foo.bar", frames[0].Message); + Assert.StartsWith("SUB foo.bar", frames[1].Message); + Assert.StartsWith("SUB foo.baz", frames[2].Message); + Assert.False(frames[0].Message.Equals(frames[1].Message), "Should have different SIDs"); + + await sub1.DisposeAsync(); + await reg1; + await sub2.DisposeAsync(); + await reg2; + await sub3.DisposeAsync(); + await reg3; + await nats1.DisposeAsync(); + await nats2.DisposeAsync(); + proxy.Dispose(); + } + + [Fact] + public async Task Publish_empty_message_for_notifications() + { + await using var server = new NatsServer(_output, TransportType.Tcp); + var (nats, proxy) = server.CreateProxiedClientConnection(); + + var sync = 0; + var signal1 = new WaitSignal(); + var signal2 = new WaitSignal(); + var sub = await nats.SubscribeAsync("foo.*"); + var reg = sub.Register(m => + { + switch (m.Subject) + { + case "foo.sync": + Interlocked.Exchange(ref sync, 1); + break; + case "foo.signal1": + signal1.Pulse(m); + break; + case "foo.signal2": + signal2.Pulse(m); + break; + } + }); + + await Retry.Until( + "subscription is active", + () => Volatile.Read(ref sync) == 1, + async () => await nats.PublishAsync("foo.sync"), + retryDelay: TimeSpan.FromSeconds(1)); + + // PUB notifications + await nats.PublishAsync("foo.signal1"); + var msg1 = await signal1; + Assert.Equal(0, msg1.Data.Length); + Assert.Null(msg1.Headers); + var pubFrame1 = proxy.Frames.First(f => f.Message.StartsWith("PUB foo.signal1")); + Assert.Equal("PUB foo.signal1 0␍␊", pubFrame1.Message); + var msgFrame1 = proxy.Frames.First(f => f.Message.StartsWith("MSG foo.signal1")); + Assert.Matches(@"^MSG foo.signal1 \w+ 0␍␊$", msgFrame1.Message); + + // HPUB notifications + await nats.PublishAsync("foo.signal2", opts: new NatsPubOpts { Headers = new NatsHeaders() }); + var msg2 = await signal2; + Assert.Equal(0, msg2.Data.Length); + Assert.NotNull(msg2.Headers); + Assert.Empty(msg2.Headers!); + var pubFrame2 = proxy.Frames.First(f => f.Message.StartsWith("HPUB foo.signal2")); + Assert.Equal("HPUB foo.signal2 12 12␍␊NATS/1.0␍␊␍␊", pubFrame2.Message); + var msgFrame2 = proxy.Frames.First(f => f.Message.StartsWith("HMSG foo.signal2")); + Assert.Matches(@"^HMSG foo.signal2 \w+ 12 12␍␊NATS/1.0␍␊␍␊$", msgFrame2.Message); + + await sub.DisposeAsync(); + await reg; + } +} diff --git a/tests/NATS.Client.Core.Tests/SubscriptionTest.cs b/tests/NATS.Client.Core.Tests/SubscriptionTest.cs index 1037e7245..fcaf5cd5a 100644 --- a/tests/NATS.Client.Core.Tests/SubscriptionTest.cs +++ b/tests/NATS.Client.Core.Tests/SubscriptionTest.cs @@ -1,5 +1,3 @@ -using System.Diagnostics; - namespace NATS.Client.Core.Tests; public class SubscriptionTest @@ -8,94 +6,6 @@ public class SubscriptionTest public SubscriptionTest(ITestOutputHelper output) => _output = output; - [Fact] - public async Task Subscription_with_same_subject() - { - await using var server = new NatsServer(_output, TransportType.Tcp); - var nats1 = server.CreateClientConnection(); - var (nats2, proxy) = server.CreateProxiedClientConnection(); - - var sub1 = await nats2.SubscribeAsync("foo.bar"); - var sub2 = await nats2.SubscribeAsync("foo.bar"); - var sub3 = await nats2.SubscribeAsync("foo.baz"); - - var sync1 = 0; - var sync2 = 0; - var sync3 = 0; - var count = new WaitSignal(3); - - sub1.Register(m => - { - if (m.Data == 0) - { - Interlocked.Exchange(ref sync1, 1); - return; - } - - count.Pulse(m.Subject == "foo.bar" ? null : new Exception($"Subject mismatch {m.Subject}")); - }); - - sub2.Register(m => - { - if (m.Data == 0) - { - Interlocked.Exchange(ref sync2, 1); - return; - } - - count.Pulse(m.Subject == "foo.bar" ? null : new Exception($"Subject mismatch {m.Subject}")); - }); - - sub3.Register(m => - { - if (m.Data == 0) - { - Interlocked.Exchange(ref sync3, 1); - return; - } - - count.Pulse(m.Subject == "foo.baz" ? null : new Exception($"Subject mismatch {m.Subject}")); - }); - - // Since subscription and publishing are sent through different connections there is - // a race where one or more subscriptions are made after the publishing happens. - // So, we make sure subscribers are accepted by the server before we send any test data. - await RetryUntil( - "all subscriptions are active", - () => Volatile.Read(ref sync1) + Volatile.Read(ref sync2) + Volatile.Read(ref sync3) == 3, - async () => - { - await nats1.PublishAsync("foo.bar", 0); - await nats1.PublishAsync("foo.baz", 0); - }); - - await nats1.PublishAsync("foo.bar", 1); - await nats1.PublishAsync("foo.baz", 1); - - // Wait until we received all test data - await count; - - var frames = proxy.ClientFrames.OrderBy(f => f.Message).ToList(); - - foreach (var frame in frames) - { - _output.WriteLine($"[PROXY] {frame}"); - } - - Assert.Equal(3, frames.Count); - Assert.StartsWith("SUB foo.bar", frames[0].Message); - Assert.StartsWith("SUB foo.bar", frames[1].Message); - Assert.StartsWith("SUB foo.baz", frames[2].Message); - Assert.False(frames[0].Message.Equals(frames[1].Message), "Should have different SIDs"); - - await sub1.DisposeAsync(); - await sub2.DisposeAsync(); - await sub3.DisposeAsync(); - await nats1.DisposeAsync(); - await nats2.DisposeAsync(); - proxy.Dispose(); - } - [Fact] public async Task Subscription_periodic_cleanup_test() { @@ -107,7 +17,7 @@ async Task Isolator() { var sub = await nats.SubscribeAsync("foo"); - await RetryUntil( + await Retry.Until( "unsubscribed", () => proxy.ClientFrames.Count(f => f.Message.StartsWith("SUB")) == 1); @@ -119,7 +29,7 @@ await RetryUntil( GC.Collect(); - await RetryUntil( + await Retry.Until( "unsubscribe message received", () => proxy.ClientFrames.Count(f => f.Message.StartsWith("UNSUB")) == 1); } @@ -137,7 +47,7 @@ async Task Isolator() { var sub = await nats.SubscribeAsync("foo"); - await RetryUntil("unsubscribed", () => proxy.ClientFrames.Count(f => f.Message.StartsWith("SUB")) == 1); + await Retry.Until("unsubscribed", () => proxy.ClientFrames.Count(f => f.Message.StartsWith("SUB")) == 1); // subscription object will be eligible for GC after next statement Assert.Equal("foo", sub.Subject); @@ -148,25 +58,9 @@ async Task Isolator() GC.Collect(); // Publish should trigger UNSUB since NatsSub object should be collected by now. - await RetryUntil( + await Retry.Until( "unsubscribe message received", () => proxy.ClientFrames.Count(f => f.Message.StartsWith("UNSUB")) == 1, async () => await nats.PublishAsync("foo", 1)); } - - private async Task RetryUntil(string reason, Func condition, Func? action = null, TimeSpan? timeout = null) - { - timeout ??= TimeSpan.FromSeconds(10); - var stopwatch = Stopwatch.StartNew(); - while (stopwatch.Elapsed < timeout) - { - if (action != null) - await action(); - if (condition()) - return; - await Task.Delay(50); - } - - throw new TimeoutException($"Took too long ({timeout}) waiting for {reason}"); - } } diff --git a/tests/NATS.Client.Core.Tests/_NatsServer.cs b/tests/NATS.Client.Core.Tests/_NatsServer.cs index a4132fd55..2d18abb10 100644 --- a/tests/NATS.Client.Core.Tests/_NatsServer.cs +++ b/tests/NATS.Client.Core.Tests/_NatsServer.cs @@ -8,10 +8,23 @@ namespace NATS.Client.Core.Tests; +public static class ServerVersions +{ +#pragma warning disable SA1310 +#pragma warning disable SA1401 + + // Changed INFO port reporting for WS connections (nats-server #4255) + public static Version V2_9_19 = new("2.9.19"); + +#pragma warning restore SA1401 +#pragma warning restore SA1310 +} + public class NatsServer : IAsyncDisposable { private static readonly string Ext = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? ".exe" : string.Empty; private static readonly string NatsServerPath = $"nats-server{Ext}"; + private static readonly Version Version; private readonly CancellationTokenSource _cancellationTokenSource = new(); private readonly string? _configFileName; @@ -21,6 +34,25 @@ public class NatsServer : IAsyncDisposable private readonly TransportType _transportType; private int _disposed; + static NatsServer() + { + var process = new Process + { + StartInfo = new ProcessStartInfo + { + FileName = NatsServerPath, + Arguments = "-v", + RedirectStandardOutput = true, + UseShellExecute = false, + }, + }; + process.Start(); + process.WaitForExit(); + var output = process.StandardOutput.ReadToEnd(); + var value = Regex.Match(output, @"v(\d+\.\d+\.\d+)").Groups[1].Value; + Version = new Version(value); + } + public NatsServer() : this(new NullOutputHelper(), TransportType.Tcp) { @@ -91,6 +123,21 @@ public NatsServer(ITestOutputHelper outputHelper, TransportType transportType, N _ => throw new ArgumentOutOfRangeException(), }; + public int ConnectionPort + { + get + { + if (_transportType == TransportType.WebSocket && ServerVersions.V2_9_19 <= Version) + { + return Options.WebSocketPort!.Value; + } + else + { + return Options.ServerPort; + } + } + } + public async ValueTask DisposeAsync() { if (Interlocked.Increment(ref _disposed) != 1) @@ -334,7 +381,16 @@ public IReadOnlyList Frames private bool NatsProtoDump(int client, string origin, TextReader sr, TextWriter sw) { - var message = sr.ReadLine(); + string? message; + try + { + message = sr.ReadLine(); + } + catch + { + return false; + } + if (message == null) return false; if (Regex.IsMatch(message, @"^(INFO|CONNECT|PING|PONG|UNSUB|SUB|\+OK|-ERR)")) @@ -369,14 +425,11 @@ private bool NatsProtoDump(int client, string origin, TextReader sr, TextWriter case >= ' ' and <= '~': sb.Append(c); break; - case '\t': - sb.Append("\\t"); - break; case '\n': - sb.Append("\\n"); + sb.Append('␊'); break; case '\r': - sb.Append("\\r"); + sb.Append('␍'); break; default: sb.Append('.'); @@ -389,7 +442,7 @@ private bool NatsProtoDump(int client, string origin, TextReader sr, TextWriter sw.Flush(); if (client > 0) - AddFrame(new Frame(client, origin, Message: $"{message}\\r\\n{sb}")); + AddFrame(new Frame(client, origin, Message: $"{message}␍␊{sb}")); return true; } diff --git a/tests/NATS.Client.Core.Tests/_Utils.cs b/tests/NATS.Client.Core.Tests/_Utils.cs new file mode 100644 index 000000000..e1095c807 --- /dev/null +++ b/tests/NATS.Client.Core.Tests/_Utils.cs @@ -0,0 +1,45 @@ +using System.Diagnostics; +using System.Net; +using System.Net.Sockets; + +namespace NATS.Client.Core.Tests; + +public static class Retry +{ + public static async Task Until(string reason, Func condition, Func? action = null, TimeSpan? timeout = null, TimeSpan? retryDelay = null) + { + timeout ??= TimeSpan.FromSeconds(10); + var delay1 = retryDelay ?? TimeSpan.FromSeconds(.1); + + var stopwatch = Stopwatch.StartNew(); + while (stopwatch.Elapsed < timeout) + { + if (action != null) + await action(); + if (condition()) + return; + await Task.Delay(delay1); + } + + throw new TimeoutException($"Took too long ({timeout}) waiting until {reason}"); + } +} + +public static class Net +{ + public static void WaitForTcpPortToClose(int port) + { + while (true) + { + try + { + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + socket.Connect(IPAddress.Loopback, port); + } + catch (SocketException) + { + return; + } + } + } +} diff --git a/tests/NATS.Client.Core.Tests/_WaitSignal.cs b/tests/NATS.Client.Core.Tests/_WaitSignal.cs index f6a835c95..8ff2acbbf 100644 --- a/tests/NATS.Client.Core.Tests/_WaitSignal.cs +++ b/tests/NATS.Client.Core.Tests/_WaitSignal.cs @@ -85,3 +85,53 @@ public TaskAwaiter GetAwaiter() return _tcs.Task.WaitAsync(_timeout).GetAwaiter(); } } + +public class WaitSignal +{ + private TimeSpan _timeout; + private int _count; + private TaskCompletionSource _tcs; + + public WaitSignal() + : this(TimeSpan.FromSeconds(10)) + { + } + + public WaitSignal(int count) + : this(TimeSpan.FromSeconds(10), count) + { + } + + public WaitSignal(TimeSpan timeout, int count = 1) + { + _timeout = timeout; + _count = count; + _tcs = new TaskCompletionSource(); + } + + public TimeSpan Timeout => _timeout; + + public Task Task => _tcs.Task; + + public void Pulse(T result, Exception? exception = null) + { + if (exception == null) + { + if (Interlocked.Decrement(ref _count) > 0) + { + return; + } + + _tcs.TrySetResult(result); + } + else + { + _tcs.TrySetException(exception); + } + } + + public TaskAwaiter GetAwaiter() + { + return _tcs.Task.WaitAsync(_timeout).GetAwaiter(); + } +}