diff --git a/src/OpenTelemetry.Sampler.AWS/AWSXRaySamplerClient.cs b/src/OpenTelemetry.Sampler.AWS/AWSXRaySamplerClient.cs index a99b090261..7c94c844a6 100644 --- a/src/OpenTelemetry.Sampler.AWS/AWSXRaySamplerClient.cs +++ b/src/OpenTelemetry.Sampler.AWS/AWSXRaySamplerClient.cs @@ -124,17 +124,26 @@ private void Dispose(bool disposing) private async Task DoRequestAsync(string endpoint, HttpRequestMessage request) { + // 1 MB is well above any legitimate X-Ray sampling rules/targets + // response while still protecting against unbounded reads. + const int maxResponseSizeInBytes = 1024 * 1024; + try { - var response = await this.httpClient.SendAsync(request).ConfigureAwait(false); + // Use ResponseHeadersRead so the response body is streamed rather + // than buffered entirely in memory, allowing LimitedStream to + // enforce the cap during download. + using var response = await this.httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead).ConfigureAwait(false); if (!response.IsSuccessStatusCode) { AWSSamplerEventSource.Log.FailedToGetSuccessResponse(endpoint, response.StatusCode.ToString()); return string.Empty; } - var responseString = await response.Content.ReadAsStringAsync().ConfigureAwait(false); - return responseString; + var stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false); + using var limitedStream = new LimitedStream(stream, maxResponseSizeInBytes); + using var reader = new StreamReader(limitedStream); + return await reader.ReadToEndAsync().ConfigureAwait(false); } catch (Exception ex) { diff --git a/src/OpenTelemetry.Sampler.AWS/CHANGELOG.md b/src/OpenTelemetry.Sampler.AWS/CHANGELOG.md index dad2c485bf..492a8de2d7 100644 --- a/src/OpenTelemetry.Sampler.AWS/CHANGELOG.md +++ b/src/OpenTelemetry.Sampler.AWS/CHANGELOG.md @@ -4,6 +4,8 @@ * Updated OpenTelemetry core component version(s) to `1.15.2`. ([#4080](https://github.com/open-telemetry/opentelemetry-dotnet-contrib/pull/4080)) +* Limit the max size read for response body getting the sampling rules to 1MB. + ([#4100](https://github.com/open-telemetry/opentelemetry-dotnet-contrib/pull/4100)) ## 0.1.0-alpha.7 diff --git a/src/OpenTelemetry.Sampler.AWS/LimitedStream.cs b/src/OpenTelemetry.Sampler.AWS/LimitedStream.cs new file mode 100644 index 0000000000..92337dfe89 --- /dev/null +++ b/src/OpenTelemetry.Sampler.AWS/LimitedStream.cs @@ -0,0 +1,109 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +namespace OpenTelemetry.Sampler.AWS; + +/// +/// A read-only stream wrapper that throws +/// if the underlying stream exceeds a configured maximum number of bytes. +/// This protects against denial-of-service when reading from untrusted HTTP responses. +/// +internal sealed class LimitedStream : Stream +{ + private readonly Stream innerStream; + private readonly long maxBytes; + private long totalBytesRead; + + public LimitedStream(Stream innerStream, long maxBytes) + { + this.innerStream = innerStream ?? throw new ArgumentNullException(nameof(innerStream)); + + if (maxBytes <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxBytes), maxBytes, "Value must be greater than zero."); + } + + this.maxBytes = maxBytes; + } + + public override bool CanRead => this.innerStream.CanRead; + + public override bool CanSeek => false; + + public override bool CanWrite => false; + + public override long Length => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + var remaining = this.maxBytes - this.totalBytesRead; + if (remaining <= 0) + { + // Allowance exhausted - signal EOF so callers stop reading. + return 0; + } + + var clampedCount = (int)Math.Min(count, remaining); + var bytesRead = this.innerStream.Read(buffer, offset, clampedCount); + this.totalBytesRead += bytesRead; + return bytesRead; + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { +#if NET + return await this.ReadAsync(buffer.AsMemory(offset, count), cancellationToken).ConfigureAwait(false); +#else + var remaining = this.maxBytes - this.totalBytesRead; + if (remaining <= 0) + { + return 0; + } + + var clampedCount = (int)Math.Min(count, remaining); + var bytesRead = await this.innerStream.ReadAsync(buffer, offset, clampedCount, cancellationToken).ConfigureAwait(false); + this.totalBytesRead += bytesRead; + return bytesRead; +#endif + } + +#if NET + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + var remaining = this.maxBytes - this.totalBytesRead; + if (remaining <= 0) + { + return 0; + } + + var clampedLength = (int)Math.Min(buffer.Length, remaining); + var bytesRead = await this.innerStream.ReadAsync(buffer[..clampedLength], cancellationToken).ConfigureAwait(false); + this.totalBytesRead += bytesRead; + return bytesRead; + } +#endif + + public override void Flush() => this.innerStream.Flush(); + + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + + public override void SetLength(long value) => throw new NotSupportedException(); + + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + protected override void Dispose(bool disposing) + { + if (disposing) + { + this.innerStream.Dispose(); + } + + base.Dispose(disposing); + } +} diff --git a/test/OpenTelemetry.Sampler.AWS.Tests/TestAWSXRaySamplerClient.cs b/test/OpenTelemetry.Sampler.AWS.Tests/TestAWSXRaySamplerClient.cs index 4523149e28..88a4f15638 100644 --- a/test/OpenTelemetry.Sampler.AWS.Tests/TestAWSXRaySamplerClient.cs +++ b/test/OpenTelemetry.Sampler.AWS.Tests/TestAWSXRaySamplerClient.cs @@ -163,6 +163,18 @@ public async Task TestGetSamplingTargetsWithMalformed() Assert.Null(targetsResponse); } + [Fact] + public async Task TestGetSamplingRulesWithOversizedResponse() + { + // Create a response larger than the 1 MB limit enforced by DoRequestAsync. + var oversizedPayload = new string('x', (1024 * 1024) + 1); + this.requestHandler.SetResponse("/GetSamplingRules", oversizedPayload); + + var rules = await this.client.GetSamplingRules(); + + Assert.Empty(rules); + } + private void CreateResponse(string endpoint, string filePath) { var mockResponse = File.ReadAllText(filePath); diff --git a/test/OpenTelemetry.Sampler.AWS.Tests/TestLimitedStreamReader.cs b/test/OpenTelemetry.Sampler.AWS.Tests/TestLimitedStreamReader.cs new file mode 100644 index 0000000000..42a0d4084a --- /dev/null +++ b/test/OpenTelemetry.Sampler.AWS.Tests/TestLimitedStreamReader.cs @@ -0,0 +1,118 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +using System.Text; +using Xunit; + +namespace OpenTelemetry.Sampler.AWS.Tests; + +public class TestLimitedStreamReader +{ + [Fact] + public async Task ReadWithinLimitSucceeds() + { + var data = Encoding.UTF8.GetBytes("hello"); + using var inner = new MemoryStream(data); + using var limited = new LimitedStream(inner, maxBytes: 1024); + using var reader = new StreamReader(limited); + + var result = await reader.ReadToEndAsync(); + + Assert.Equal("hello", result); + } + + [Fact] + public async Task ReadExactlyAtLimitSucceeds() + { + var data = Encoding.UTF8.GetBytes("12345"); + using var inner = new MemoryStream(data); + using var limited = new LimitedStream(inner, maxBytes: 5); + using var reader = new StreamReader(limited); + + var result = await reader.ReadToEndAsync(); + + Assert.Equal("12345", result); + } + + [Fact] + public async Task ReadExceedingLimitTruncates() + { + var data = Encoding.UTF8.GetBytes(new string('x', 2048)); + using var inner = new MemoryStream(data); + using var limited = new LimitedStream(inner, maxBytes: 1024); + using var reader = new StreamReader(limited); + + var result = await reader.ReadToEndAsync(); + + Assert.Equal(1024, result.Length); + } + + [Fact] + public void SyncReadClampsThenReturnsZero() + { + var data = Encoding.UTF8.GetBytes(new string('x', 2048)); + using var inner = new MemoryStream(data); + using var limited = new LimitedStream(inner, maxBytes: 1024); + + var buffer = new byte[2048]; + + // First read is clamped to 1024 bytes. + var bytesRead = limited.Read(buffer, 0, buffer.Length); + Assert.Equal(1024, bytesRead); + + // Second read returns 0 (EOF) because the allowance is exhausted. + bytesRead = limited.Read(buffer, 0, buffer.Length); + Assert.Equal(0, bytesRead); + } + + [Fact] + public void SyncReadClampsToRemainingAllowance() + { + var data = Encoding.UTF8.GetBytes(new string('x', 200)); + using var inner = new MemoryStream(data); + using var limited = new LimitedStream(inner, maxBytes: 100); + + var buffer = new byte[200]; + var bytesRead = limited.Read(buffer, 0, buffer.Length); + + Assert.Equal(100, bytesRead); + } + + [Fact] + public void CannotWrite() + { + using var inner = new MemoryStream(); + using var limited = new LimitedStream(inner, maxBytes: 1024); + + Assert.False(limited.CanWrite); + Assert.Throws( + () => limited.Write(new byte[1], 0, 1)); + } + + [Fact] + public void CannotSeek() + { + using var inner = new MemoryStream(); + using var limited = new LimitedStream(inner, maxBytes: 1024); + + Assert.False(limited.CanSeek); + Assert.Throws( + () => limited.Seek(0, SeekOrigin.Begin)); + } + + [Fact] + public void ThrowsOnNullInnerStream() + { + Assert.Throws(() => new LimitedStream(null!, maxBytes: 1024)); + } + + [Theory] + [InlineData(0)] + [InlineData(-1)] + [InlineData(-100)] + public void ThrowsOnInvalidMaxBytes(long maxBytes) + { + using var inner = new MemoryStream(); + Assert.Throws(() => new LimitedStream(inner, maxBytes)); + } +}