Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Buffers;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
Expand Down Expand Up @@ -48,10 +49,18 @@ public HttpResponseAdapterFeature(IHttpResponseBodyFeature httpResponseBody)

void IHttpResponseBodyFeature.DisableBuffering()
{
if (_state == StreamState.NotStarted)
_responseBodyFeature.DisableBuffering();
_state = StreamState.NotBuffering;

// If anything is already buffered, we'll use a custom pipe that will
// clear out the buffer the next time flush is called since this method
// is not async
if (_bufferedStream is { })
{
_pipeWriter = new FlushingBufferedPipeWriter(this, _responseBodyFeature.Writer);
}
else
{
_state = StreamState.NotBuffering;
_responseBodyFeature.DisableBuffering();
_pipeWriter = _responseBodyFeature.Writer;
}
}
Expand Down Expand Up @@ -100,22 +109,35 @@ private async ValueTask FlushInternalAsync()
await _pipeWriter.FlushAsync();
}

if (_state is StreamState.Buffering && _bufferedStream is not null && !SuppressContent)
if (_state is StreamState.Buffering)
{
await DrainStreamAsync(default);
}
}

private async ValueTask DrainStreamAsync(CancellationToken token)
{
if (_bufferedStream is null)
{
return;
}

if (!SuppressContent)
{
if (_filter is { } filter)
{
await _bufferedStream.DrainBufferAsync(filter);
await _bufferedStream.DrainBufferAsync(filter, token);
await filter.DisposeAsync();
_filter = null;
}
else
{
await _bufferedStream.DrainBufferAsync(_responseBodyFeature.Stream);
await _bufferedStream.DrainBufferAsync(_responseBodyFeature.Stream, token);
}

await _bufferedStream.DisposeAsync();
_bufferedStream = null;
}

await _bufferedStream.DisposeAsync();
_bufferedStream = null;
}

Stream IHttpResponseBodyFeature.Stream => this;
Expand Down Expand Up @@ -286,4 +308,92 @@ public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationTo

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
=> CurrentStream.WriteAsync(buffer, offset, count, cancellationToken);

/// <summary>
/// A <see cref="PipeWriter"/> that can flush any existing buffered items before writing next sequence of bytes
/// Intended to be used if <see cref="IHttpResponseBodyFeature.DisableBuffering"/> is called and data has been buffered
/// to ensure that the final output will be ordered correctly (since we can't asynchronously write the data in that call).
/// </summary>
/// <remarks>
/// Calls to <see cref="Advance(int)"/>, <see cref="GetSpan(int)"/>, <see cref="GetMemory(int)"/> must be called
/// in a group without calling <see cref="FlushAsync(CancellationToken)"/>. If not, then the call to <see cref="Advance(int)"/>
/// will potentially advance the inner pipe rather than the buffer.
/// </remarks>
private sealed class FlushingBufferedPipeWriter : PipeWriter
{
private readonly PipeWriter _other;

private HttpResponseAdapterFeature? _feature;
private ArrayBufferWriter<byte>? _buffer;

public FlushingBufferedPipeWriter(HttpResponseAdapterFeature feature, PipeWriter other)
{
_feature = feature;
_other = other;
}

public override void CancelPendingFlush() => _other.CancelPendingFlush();

public override void Complete(Exception? exception = null) => _other.Complete(exception);

public override async ValueTask<FlushResult> FlushAsync(CancellationToken cancellationToken = default)
{
await FlushExistingDataAsync(cancellationToken);

return await _other.FlushAsync(cancellationToken);
}

private async ValueTask FlushExistingDataAsync(CancellationToken cancellationToken)
{
if (_feature is { })
{
await _feature.DrainStreamAsync(cancellationToken);
_feature = null;
}

if (_buffer is { })
{
await _other.WriteAsync(_buffer.WrittenMemory, cancellationToken);
_buffer = null;
}
}

public bool IsBuffered => _feature is { };

public override void Advance(int bytes)
{
if (_buffer is { })
{
_buffer.Advance(bytes);
}
else
{
_other.Advance(bytes);
}
}

public override Memory<byte> GetMemory(int sizeHint = 0)
{
if (IsBuffered)
{
return (_buffer ??= new()).GetMemory(sizeHint);
}
else
{
return _other.GetMemory(sizeHint);
}
}

public override Span<byte> GetSpan(int sizeHint = 0)
{
if (IsBuffered)
{
return (_buffer ??= new()).GetSpan(sizeHint);
}
else
{
return _other.GetSpan(sizeHint);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.SystemWebAdapters.Features;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -282,14 +283,109 @@ public async Task BufferedOutputIsEnabled()
Assert.Equal("True", result);
}

private static Task<string> RunAsync(Action<HttpContext> action, Action<IEndpointConventionBuilder>? builder = null)
[Fact]
public async Task BufferingCanBeDisabled()
{
await RunAsync(middleware: (ctx, next) =>
{
ctx.Features.GetRequired<IHttpResponseBufferingFeature>().EnableBuffering(1024, default);
Assert.True(ctx.Features.GetRequired<IHttpResponseBufferingFeature>().IsEnabled);
ctx.Features.GetRequired<IHttpResponseBodyFeature>().DisableBuffering();
Assert.False(ctx.Features.GetRequired<IHttpResponseBufferingFeature>().IsEnabled);

return next(ctx);
});
}

[Fact]
public async Task BufferingCanBeDisabledWithSuppressContent()
{
var result = await RunAsync(middleware: async (ctx, next) =>
{
ctx.Features.GetRequired<IHttpResponseBufferingFeature>().EnableBuffering(1024, default);
Assert.True(ctx.Features.GetRequired<IHttpResponseBufferingFeature>().IsEnabled);

await ctx.Response.WriteAsync("before ");

ctx.Features.GetRequired<IHttpResponseContentFeature>().SuppressContent = true;

ctx.Features.GetRequired<IHttpResponseBodyFeature>().DisableBuffering();
Assert.False(ctx.Features.GetRequired<IHttpResponseBufferingFeature>().IsEnabled);

await ctx.Response.WriteAsync("after");

await next(ctx);
});

Assert.Equal("after", result);
}

[Fact]
public async Task BufferingCanBeDisabledAndFlushes()
{
var result = await RunAsync(middleware: async (ctx, next) =>
{
ctx.Features.GetRequired<IHttpResponseBufferingFeature>().EnableBuffering(1024, default);
Assert.True(ctx.Features.GetRequired<IHttpResponseBufferingFeature>().IsEnabled);

await ctx.Response.WriteAsync("before ");

ctx.Features.GetRequired<IHttpResponseBodyFeature>().DisableBuffering();
Assert.False(ctx.Features.GetRequired<IHttpResponseBufferingFeature>().IsEnabled);

await ctx.Response.WriteAsync("after");

await next(ctx);
});

Assert.Equal("before after", result);
}

[Fact]
public async Task BufferingCannotBeEnabledIfWritingHasBegun()
{
await RunAsync(middleware: async (ctx, next) =>
{
await ctx.Response.WriteAsync("start");

Assert.Throws<InvalidOperationException>(() =>
{
ctx.Features.GetRequired<IHttpResponseBufferingFeature>().EnableBuffering(1024, default);
});

await next(ctx);
});
}

[Fact]
public async Task BufferingCanBeDisabledAndFlushesUsingPipe()
{
var result = await RunAsync(middleware: async (ctx, next) =>
{
ctx.Features.GetRequired<IHttpResponseBufferingFeature>().EnableBuffering(1024, default);
Assert.True(ctx.Features.GetRequired<IHttpResponseBufferingFeature>().IsEnabled);

await ctx.Response.BodyWriter.WriteAsync(Encoding.UTF8.GetBytes("before "));

ctx.Features.GetRequired<IHttpResponseBodyFeature>().DisableBuffering();
Assert.False(ctx.Features.GetRequired<IHttpResponseBufferingFeature>().IsEnabled);

await ctx.Response.BodyWriter.WriteAsync(Encoding.UTF8.GetBytes("after"));

await next(ctx);
});

Assert.Equal("before after", result);
}

private static Task<string> RunAsync(Action<HttpContext> action, Action<IEndpointConventionBuilder>? builder = null, Func<Http.HttpContext, RequestDelegate, Task>? middleware = null)
=> RunAsync(ctx =>
{
action(ctx);
return Task.CompletedTask;
}, builder);
}, builder, middleware);

private static async Task<string> RunAsync(Func<HttpContext, Task> action, Action<IEndpointConventionBuilder>? builder = null)
private static async Task<string> RunAsync(Func<HttpContext, Task>? endpointAction = null, Action<IEndpointConventionBuilder>? builder = null, Func<Http.HttpContext, RequestDelegate, Task>? middleware = null)
{
builder ??= _ => { };

Expand All @@ -310,9 +406,21 @@ private static async Task<string> RunAsync(Func<HttpContext, Task> action, Actio
{
app.UseRouting();
app.UseSystemWebAdapters();

if (middleware is { })
{
app.Use(middleware);
}

app.UseEndpoints(endpoints =>
{
builder(endpoints.Map("/", (HttpContextCore context) => action(context)));
builder(endpoints.Map("/", (HttpContextCore context) =>
{
if (endpointAction is { })
{
endpointAction(context);
}
}));
});
});
})
Expand Down