diff --git a/src/OpenTelemetry.Api/Context/Propagation/TraceContextPropagator.cs b/src/OpenTelemetry.Api/Context/Propagation/TraceContextPropagator.cs index 111f5caec09..7c9ce6fe3c9 100644 --- a/src/OpenTelemetry.Api/Context/Propagation/TraceContextPropagator.cs +++ b/src/OpenTelemetry.Api/Context/Propagation/TraceContextPropagator.cs @@ -56,15 +56,11 @@ public override PropagationContext Extract(PropagationContext context, T carr try { - var traceparentCollection = getter(carrier, TraceParent); - - // There must be a single traceparent - if (traceparentCollection == null || traceparentCollection.Count() != 1) + if (!TryGetSingleValue(getter(carrier, TraceParent), out var traceparent)) { return context; } - var traceparent = traceparentCollection.First(); var traceparentParsed = TryExtractTraceparent(traceparent, out var traceId, out var spanId, out var traceoptions); if (!traceparentParsed) @@ -73,10 +69,10 @@ public override PropagationContext Extract(PropagationContext context, T carr } string? tracestate = null; - var tracestateCollection = getter(carrier, TraceState); - if (tracestateCollection?.Any() ?? false) + TryExtractTracestate(getter(carrier, TraceState), out var extractedTracestate, out var hasTraceState); + if (hasTraceState) { - TryExtractTracestate([.. tracestateCollection], out tracestate); + tracestate = extractedTracestate; } return new PropagationContext( @@ -228,94 +224,322 @@ internal static bool TryExtractTraceparent(string traceparent, out ActivityTrace return true; } - internal static bool TryExtractTracestate(string[] tracestateCollection, out string tracestateResult) + internal static bool TryExtractTracestate(string[]? tracestateCollection, out string tracestateResult) + => TryExtractTracestate((IEnumerable?)tracestateCollection, out tracestateResult); + + internal static bool TryExtractTracestate(IEnumerable? tracestateCollection, out string tracestateResult) + => TryExtractTracestate(tracestateCollection, out tracestateResult, out _); + + private static bool TryExtractTracestate(IEnumerable? tracestateCollection, out string tracestateResult, out bool hasTraceState) { tracestateResult = string.Empty; + hasTraceState = false; + + if (tracestateCollection == null) + { + return true; + } - if (tracestateCollection != null) + if (tracestateCollection is IList list) { - var keySet = new HashSet(); - var result = new StringBuilder(); - for (var i = 0; i < tracestateCollection.Length; ++i) + if (list.Count == 0) { - var tracestate = tracestateCollection[i].AsSpan(); - var begin = 0; - while (begin < tracestate.Length) + return true; + } + + hasTraceState = true; + if (list.Count == 1) + { + return TryExtractSingleTracestate(list[0], out tracestateResult); + } + + return TryExtractMultipleTracestate(list, out tracestateResult); + } + + if (tracestateCollection is IReadOnlyList readOnlyList) + { + if (readOnlyList.Count == 0) + { + return true; + } + + hasTraceState = true; + if (readOnlyList.Count == 1) + { + return TryExtractSingleTracestate(readOnlyList[0], out tracestateResult); + } + + return TryExtractMultipleTracestate(readOnlyList, out tracestateResult); + } + + using var enumerator = tracestateCollection.GetEnumerator(); + if (!enumerator.MoveNext()) + { + return true; + } + + hasTraceState = true; + var singleTraceState = enumerator.Current; + if (!enumerator.MoveNext()) + { + return TryExtractSingleTracestate(singleTraceState, out tracestateResult); + } + + return TryExtractMultipleTracestate(EnumerateFrom(singleTraceState, enumerator), out tracestateResult); + } + + private static IEnumerable EnumerateFrom(string first, IEnumerator enumerator) + { + yield return first; + + do + { + yield return enumerator.Current; + } + while (enumerator.MoveNext()); + } + + private static bool TryExtractMultipleTracestate(IEnumerable tracestateCollection, out string tracestateResult) + { + var keySet = new HashSet(); + var result = new StringBuilder(); + + foreach (var tracestateEntry in tracestateCollection) + { + var tracestate = tracestateEntry.AsSpan(); + var begin = 0; + while (begin < tracestate.Length) + { + ReadOnlySpan listMember; + + var length = tracestate.Slice(begin).IndexOf(','); + if (length != -1) { - var length = tracestate.Slice(begin).IndexOf(','); - ReadOnlySpan listMember; - if (length != -1) - { - listMember = tracestate.Slice(begin, length).Trim(); - begin += length + 1; - } - else - { - listMember = tracestate.Slice(begin).Trim(); - begin = tracestate.Length; - } + listMember = tracestate.Slice(begin, length).Trim(); + begin += length + 1; + } + else + { + listMember = tracestate.Slice(begin).Trim(); + begin = tracestate.Length; + } - // https://github.com/w3c/trace-context/blob/master/spec/20-http_request_header_format.md#tracestate-header-field-values - if (listMember.IsEmpty) - { - // Empty and whitespace - only list members are allowed. - // Vendors MUST accept empty tracestate headers but SHOULD avoid sending them. - continue; - } + // https://github.com/w3c/trace-context/blob/master/spec/20-http_request_header_format.md#tracestate-header-field-values + if (listMember.IsEmpty) + { + // Empty and whitespace - only list members are allowed. + // Vendors MUST accept empty tracestate headers but SHOULD avoid sending them. + continue; + } - if (keySet.Count >= 32) - { - // https://github.com/w3c/trace-context/blob/master/spec/20-http_request_header_format.md#list - // test_tracestate_member_count_limit - return false; - } + if (keySet.Count >= 32) + { + // https://github.com/w3c/trace-context/blob/master/spec/20-http_request_header_format.md#list + // test_tracestate_member_count_limit + tracestateResult = string.Empty; + return false; + } - var keyLength = listMember.IndexOf('='); - if (keyLength == listMember.Length || keyLength == -1) - { - // Missing key or value in tracestate - return false; - } + var keyLength = listMember.IndexOf('='); + if (keyLength == listMember.Length || keyLength == -1) + { + // Missing key or value in tracestate + tracestateResult = string.Empty; + return false; + } - var key = listMember.Slice(0, keyLength); - if (!ValidateKey(key)) - { - // test_tracestate_key_illegal_characters in https://github.com/w3c/trace-context/blob/master/test/test.py - // test_tracestate_key_length_limit - // test_tracestate_key_illegal_vendor_format - return false; - } + var key = listMember.Slice(0, keyLength); + if (!ValidateKey(key)) + { + // test_tracestate_key_illegal_characters in https://github.com/w3c/trace-context/blob/master/test/test.py + // test_tracestate_key_length_limit + // test_tracestate_key_illegal_vendor_format + tracestateResult = string.Empty; + return false; + } - var value = listMember.Slice(keyLength + 1); - if (!ValidateValue(value)) - { - // test_tracestate_value_illegal_characters - return false; - } + var value = listMember.Slice(keyLength + 1); + if (!ValidateValue(value)) + { + // test_tracestate_value_illegal_characters + tracestateResult = string.Empty; + return false; + } - // ValidateKey() call above has ensured the key does not contain upper case letters. - if (!keySet.Add(key.ToString())) - { - // test_tracestate_duplicated_keys - return false; - } + // ValidateKey() call above has ensured the key does not contain upper case letters. + if (!keySet.Add(key.ToString())) + { + // test_tracestate_duplicated_keys + tracestateResult = string.Empty; + return false; + } - if (result.Length > 0) - { - result.Append(','); - } + if (result.Length > 0) + { + result.Append(','); + } #if NET - result.Append(listMember); + result.Append(listMember); #else - result.Append(listMember.ToString()); + result.Append(listMember.ToString()); #endif + } + } + + tracestateResult = result.ToString(); + return true; + } + + private static bool TryExtractSingleTracestate(string tracestate, out string tracestateResult) + { + tracestateResult = string.Empty; + + if (tracestate.Length == 0) + { + return true; + } + + var tracestateSpan = tracestate.AsSpan(); + + const int Limit = 32; + + Span memberStarts = stackalloc int[Limit]; + Span memberLengths = stackalloc int[Limit]; + Span keyLengths = stackalloc int[Limit]; + Span keyHashes = stackalloc int[Limit]; + + var memberCount = 0; + var totalLength = 0; + var normalized = false; + var begin = 0; + + while (begin < tracestateSpan.Length) + { + var end = begin; + while (end < tracestateSpan.Length && tracestateSpan[end] != ',') + { + end++; + } + + var memberStart = begin; + var memberEnd = end; + + while (memberStart < memberEnd && char.IsWhiteSpace(tracestateSpan[memberStart])) + { + memberStart++; + } + + while (memberEnd > memberStart && char.IsWhiteSpace(tracestateSpan[memberEnd - 1])) + { + memberEnd--; + } + + if (memberStart != begin || memberEnd != end) + { + normalized = true; + } + + var memberLength = memberEnd - memberStart; + if (memberLength > 0) + { + if (memberCount >= Limit) + { + return false; + } + + var listMember = tracestateSpan.Slice(memberStart, memberLength); + var keyLength = listMember.IndexOf('='); + if (keyLength == listMember.Length || keyLength == -1) + { + return false; + } + + var key = listMember.Slice(0, keyLength); + if (!ValidateKey(key)) + { + return false; + } + + var value = listMember.Slice(keyLength + 1); + if (!ValidateValue(value)) + { + return false; + } + + var useHashedDuplicateCheck = keyLength <= Limit; + var keyHash = 0; + if (useHashedDuplicateCheck) + { + keyHash = GetKeyHashCode(key); + for (var i = 0; i < memberCount; i++) + { + if (keyHashes[i] != keyHash || keyLengths[i] != keyLength) + { + continue; + } + + if (key.SequenceEqual(tracestateSpan.Slice(memberStarts[i], keyLength))) + { + return false; + } + } + } + else + { + for (var i = 0; i < memberCount; i++) + { + if (keyLengths[i] == keyLength && + key.SequenceEqual(tracestateSpan.Slice(memberStarts[i], keyLength))) + { + return false; + } + } } + + memberStarts[memberCount] = memberStart; + memberLengths[memberCount] = memberLength; + keyLengths[memberCount] = keyLength; + keyHashes[memberCount] = keyHash; + + memberCount++; + totalLength += memberLength; + } + else + { + normalized = true; + } + + begin = end + 1; + } + + if (!normalized && memberCount > 0 && totalLength + memberCount - 1 == tracestate.Length) + { + tracestateResult = tracestate; + return true; + } + + if (memberCount == 0) + { + return true; + } + + var result = new StringBuilder(totalLength + memberCount - 1); + for (var i = 0; i < memberCount; i++) + { + if (i > 0) + { + result.Append(','); } - tracestateResult = result.ToString(); +#if NET + result.Append(tracestateSpan.Slice(memberStarts[i], memberLengths[i])); +#else + result.Append(tracestate.Substring(memberStarts[i], memberLengths[i])); +#endif } + tracestateResult = result.ToString(); return true; } @@ -326,6 +550,72 @@ private static byte HexCharToByte(char c) ? (byte)(c - 'a' + 10) : throw new ArgumentOutOfRangeException(nameof(c), c, "Must be within: [0-9] or [a-f]"); + private static int GetKeyHashCode(ReadOnlySpan key) + { +#if NET + HashCode hash = default; + + for (var i = 0; i < key.Length; i++) + { + hash.Add(key[i]); + } + + return hash.ToHashCode(); +#else + unchecked + { + var hash = (int)2166136261; + for (var i = 0; i < key.Length; i++) + { + hash = (hash ^ key[i]) * 16777619; + } + + return hash; + } +#endif + } + + private static bool TryGetSingleValue(IEnumerable? values, out string value) + { + value = string.Empty; + + if (values == null) + { + return false; + } + + if (values is IList list) + { + if (list.Count != 1) + { + return false; + } + + value = list[0]; + return true; + } + + if (values is IReadOnlyList readOnlyList) + { + if (readOnlyList.Count != 1) + { + return false; + } + + value = readOnlyList[0]; + return true; + } + + using var enumerator = values.GetEnumerator(); + if (!enumerator.MoveNext()) + { + return false; + } + + value = enumerator.Current; + return !enumerator.MoveNext(); + } + private static bool ValidateKey(ReadOnlySpan key) { // This implementation follows Trace Context v1 which has W3C Recommendation. diff --git a/test/OpenTelemetry.Tests/Trace/Propagation/TraceContextPropagatorTests.cs b/test/OpenTelemetry.Tests/Trace/Propagation/TraceContextPropagatorTests.cs index 5e1364989bf..674af043a7f 100644 --- a/test/OpenTelemetry.Tests/Trace/Propagation/TraceContextPropagatorTests.cs +++ b/test/OpenTelemetry.Tests/Trace/Propagation/TraceContextPropagatorTests.cs @@ -1,6 +1,7 @@ // Copyright The OpenTelemetry Authors // SPDX-License-Identifier: Apache-2.0 +using System.Collections; using System.Diagnostics; using Xunit; @@ -13,6 +14,8 @@ public class TraceContextPropagatorTests private const string TraceId = "0af7651916cd43dd8448eb211c80319c"; private const string SpanId = "b9c7c989f97918e1"; + private static readonly IEnumerable Empty = []; + private static readonly Func, string, IEnumerable> Getter = static (headers, name) => headers.TryGetValue(name, out var value) ? [value] : []; @@ -121,6 +124,147 @@ public void TracestateToString() Assert.Equal("k1=v1,k2=v2,k3=v3", ctx.ActivityContext.TraceState); } + [Fact] + public void Extract_SupportsReadOnlyListCarrierValues() + { + var headers = new Dictionary + { + [TraceParent] = new([$"00-{TraceId}-{SpanId}-01"]), + [TraceState] = new(["k1=v1"]), + }; + + var target = new TraceContextPropagator(); + var actual = target.Extract(default, headers, static (carrier, name) => + carrier.TryGetValue(name, out var value) ? value : new ReadOnlyCarrierValues([])); + + Assert.Equal(ActivityTraceId.CreateFromString(TraceId.AsSpan()), actual.ActivityContext.TraceId); + Assert.Equal(ActivitySpanId.CreateFromString(SpanId.AsSpan()), actual.ActivityContext.SpanId); + Assert.Equal("k1=v1", actual.ActivityContext.TraceState); + } + + [Fact] + public void Extract_SupportsEnumerableCarrierValues() + { + var headers = new Dictionary + { + [TraceParent] = new([$"00-{TraceId}-{SpanId}-01"]), + [TraceState] = new([" k1=v1 , k2=v2 "]), + }; + + var target = new TraceContextPropagator(); + var actual = target.Extract(default, headers, static (carrier, name) => + carrier.TryGetValue(name, out var value) ? value : new EnumerableCarrierValues([])); + + Assert.Equal(ActivityTraceId.CreateFromString(TraceId.AsSpan()), actual.ActivityContext.TraceId); + Assert.Equal(ActivitySpanId.CreateFromString(SpanId.AsSpan()), actual.ActivityContext.SpanId); + Assert.Equal("k1=v1,k2=v2", actual.ActivityContext.TraceState); + } + + [Fact] + public void Extract_EnumeratesEnumerableTracestateValuesOnce() + { + var tracestateValues = new SingleUseEnumerableCarrierValues(" k1=v1 , k2=v2 "); + var headers = new Dictionary> + { + [TraceParent] = new EnumerableCarrierValues($"00-{TraceId}-{SpanId}-01"), + [TraceState] = tracestateValues, + }; + + var target = new TraceContextPropagator(); + var actual = target.Extract(default, headers, static (carrier, name) => + carrier.TryGetValue(name, out var value) ? value : Empty); + + Assert.Equal(ActivityTraceId.CreateFromString(TraceId.AsSpan()), actual.ActivityContext.TraceId); + Assert.Equal(ActivitySpanId.CreateFromString(SpanId.AsSpan()), actual.ActivityContext.SpanId); + Assert.Equal("k1=v1,k2=v2", actual.ActivityContext.TraceState); + Assert.Equal(1, tracestateValues.EnumerationCount); + } + + [Fact] + public void Extract_IgnoresMultipleEnumerableTraceparentValues() + { + var headers = new Dictionary + { + [TraceParent] = new([$"00-{TraceId}-{SpanId}-01", $"00-{TraceId}-{SpanId}-00"]), + }; + + var target = new TraceContextPropagator(); + var context = target.Extract(default, headers, static (carrier, name) => + carrier.TryGetValue(name, out var value) ? value : new EnumerableCarrierValues([])); + + Assert.False(context.ActivityContext.IsValid()); + } + + [Fact] + public void Extract_IgnoresEmptyEnumerableTracestateValues() + { + var headers = new Dictionary + { + [TraceParent] = new([$"00-{TraceId}-{SpanId}-01"]), + [TraceState] = new([]), + }; + + var target = new TraceContextPropagator(); + var context = target.Extract(default, headers, static (carrier, name) => + carrier.TryGetValue(name, out var value) ? value : new EnumerableCarrierValues([])); + + Assert.Equal(ActivityTraceId.CreateFromString(TraceId.AsSpan()), context.ActivityContext.TraceId); + Assert.Null(context.ActivityContext.TraceState); + } + + [Fact] + public void TryExtractTracestate_SingleHeaderReturnsOriginalString() + { + Assert.True(TraceContextPropagator.TryExtractTracestate(["k1=v1,k2=v2"], out var actual)); + Assert.Equal("k1=v1,k2=v2", actual); + } + + [Fact] + public void TryExtractTracestate_SingleHeaderReturnsEmptyForWhitespaceOnly() + { + Assert.True(TraceContextPropagator.TryExtractTracestate([" , "], out var actual)); + Assert.Empty(actual); + } + + [Fact] + public void TryExtractTracestate_SingleHeaderRejectsTooManyMembers() + { + var tracestate = string.Join(",", Enumerable.Range(1, 33).Select(static i => $"k{i:D2}=v{i:D2}")); + + Assert.False(TraceContextPropagator.TryExtractTracestate([tracestate], out _)); + } + + [Fact] + public void TryExtractTracestate_SingleHeaderRejectsDuplicateLongKeys() + { + var key = new string('a', 33); + + Assert.False(TraceContextPropagator.TryExtractTracestate([$"{key}=1,{key}=2"], out _)); + } + + [Fact] + public async Task Extract_DoesNotHangWhenLaterKeyAppearsInsideEarlierValue() + { + // Regression test for GHSA-8785-wc3w-h8q6 + const string tracestate = "foo1=foo2,foo2=1"; + + var deadline = TimeSpan.FromSeconds(1); + + var extractionTask = Task.Run(() => CallTraceContextPropagator(tracestate)); + var completedTask = await Task.WhenAny(extractionTask, Task.Delay(deadline)); + + Assert.True(extractionTask.IsCompleted, $"The task did not complete within {deadline}."); + Assert.Same(extractionTask, completedTask); + Assert.Equal(tracestate, await extractionTask); + } + + [Fact] + public void TryExtractTracestate_NullCollectionReturnsEmpty() + { + Assert.True(TraceContextPropagator.TryExtractTracestate((IEnumerable?)null, out var actual)); + Assert.Empty(actual); + } + [Fact] public void Inject_NoTracestate() { @@ -334,4 +478,54 @@ private static string CallTraceContextPropagator(string[] tracestate) Assert.NotNull(traceState); return traceState; } + + private sealed class ReadOnlyCarrierValues(params string[] values) : IReadOnlyList + { + public int Count => values.Length; + + public string this[int index] => values[index]; + + public IEnumerator GetEnumerator() + { + foreach (var value in values) + { + yield return value; + } + } + + IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator(); + } + + private sealed class EnumerableCarrierValues(params string[] values) : IEnumerable + { + public IEnumerator GetEnumerator() + { + foreach (var value in values) + { + yield return value; + } + } + + IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator(); + } + + private sealed class SingleUseEnumerableCarrierValues(params string[] values) : IEnumerable + { + public int EnumerationCount { get; private set; } + + public IEnumerator GetEnumerator() + { + if (this.EnumerationCount++ > 0) + { + throw new InvalidOperationException("Sequence was enumerated multiple times."); + } + + foreach (var value in values) + { + yield return value; + } + } + + IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator(); + } }