Skip to content

Commit

Permalink
getting closer
Browse files Browse the repository at this point in the history
  • Loading branch information
BrennanConroy committed May 13, 2021
1 parent cb6d83a commit 11e2816
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 175 deletions.
12 changes: 11 additions & 1 deletion src/Middleware/WebSockets/src/ExtendedWebSocketAcceptContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ public class ExtendedWebSocketAcceptContext : WebSocketAcceptContext
/// <summary>
///
/// </summary>
public WebSocketCreationOptions? WebSocketOptions { get; set; }
public bool DangerousEnableCompression { get; set; }

/// <summary>
///
/// </summary>
public bool ServerContextTakeover { get; set; } = true;

/// <summary>
///
/// </summary>
public int ServerMaxWindowBits { get; set; } = 15;
}
}
73 changes: 60 additions & 13 deletions src/Middleware/WebSockets/src/HandshakeHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,11 @@ public static string CreateResponseKey(string requestKey)
return Convert.ToBase64String(hashedBytes);
}

// https://datatracker.ietf.org/doc/html/rfc7692#section-7.1
public static bool ParseDeflateOptions(ReadOnlySpan<char> extension, WebSocketDeflateOptions options, [NotNullWhen(true)] out string? response)
{
bool hasServerMaxWindowBits = false;
bool hasClientMaxWindowBits = false;
response = null;
var builder = new StringBuilder(WebSocketDeflateConstants.MaxExtensionLength);
builder.Append(WebSocketDeflateConstants.Extension);
Expand All @@ -91,40 +94,58 @@ public static bool ParseDeflateOptions(ReadOnlySpan<char> extension, WebSocketDe
{
if (value.SequenceEqual(WebSocketDeflateConstants.ClientNoContextTakeover))
{
// REVIEW: If someone specifies true for server options, do we allow client to override this?
options.ClientContextTakeover = false;
builder.Append("; ").Append(WebSocketDeflateConstants.ClientNoContextTakeover);
}
else if (value.SequenceEqual(WebSocketDeflateConstants.ServerNoContextTakeover))
{
options.ServerContextTakeover = false;
builder.Append("; ").Append(WebSocketDeflateConstants.ServerNoContextTakeover);
// REVIEW: Do we want to reject it?
// Client requests no context takeover but options passed in specified context takeover, so reject the negotiate offer
if (options.ServerContextTakeover)
{
return false;
}
}
else if (value.StartsWith(WebSocketDeflateConstants.ClientMaxWindowBits))
{
var clientMaxWindowBits = ParseWindowBits(value, WebSocketDeflateConstants.ClientMaxWindowBits);
if (clientMaxWindowBits > options.ClientMaxWindowBits)
// 8 is a valid value according to the spec, but our zlib implementation does not support it
if (clientMaxWindowBits == 8)
{
return false;
}
// if client didn't send a value for ClientMaxWindowBits use the value the server set
options.ClientMaxWindowBits = clientMaxWindowBits ?? options.ClientMaxWindowBits;

// https://tools.ietf.org/html/rfc7692#section-7.1.2.2
// the server may either ignore this
// value or use this value to avoid allocating an unnecessarily big LZ77
// sliding window by including the "client_max_window_bits" extension
// parameter in the corresponding extension negotiation response to the
// offer with a value equal to or smaller than the received value.
options.ClientMaxWindowBits = Math.Min(clientMaxWindowBits ?? 15, options.ClientMaxWindowBits);

// If a received extension negotiation offer doesn't have the
// "client_max_window_bits" extension parameter, the corresponding
// extension negotiation response to the offer MUST NOT include the
// "client_max_window_bits" extension parameter.
builder.Append("; ").Append(WebSocketDeflateConstants.ClientMaxWindowBits).Append('=')
.Append(options.ClientMaxWindowBits.ToString(CultureInfo.InvariantCulture));
}
else if (value.StartsWith(WebSocketDeflateConstants.ServerMaxWindowBits))
{
hasServerMaxWindowBits = true;
var serverMaxWindowBits = ParseWindowBits(value, WebSocketDeflateConstants.ServerMaxWindowBits);
if (serverMaxWindowBits > options.ServerMaxWindowBits)
// 8 is a valid value according to the spec, but our zlib implementation does not support it
if (serverMaxWindowBits == 8)
{
return false;
}
// if client didn't send a value for ServerMaxWindowBits use the value the server set
options.ServerMaxWindowBits = serverMaxWindowBits ?? options.ServerMaxWindowBits;

builder.Append("; ")
.Append(WebSocketDeflateConstants.ServerMaxWindowBits).Append('=')
.Append(options.ServerMaxWindowBits.ToString(CultureInfo.InvariantCulture));
// https://tools.ietf.org/html/rfc7692#section-7.1.2.1
// A server accepts an extension negotiation offer with this parameter
// by including the "server_max_window_bits" extension parameter in the
// extension negotiation response to send back to the client with the
// same or smaller value as the offer.
options.ServerMaxWindowBits = Math.Min(serverMaxWindowBits ?? 15, options.ServerMaxWindowBits);
}

static int? ParseWindowBits(ReadOnlySpan<char> value, string propertyName)
Expand All @@ -138,7 +159,7 @@ public static bool ParseDeflateOptions(ReadOnlySpan<char> extension, WebSocketDe
}

if (!int.TryParse(value[(startIndex + 1)..], NumberStyles.Integer, CultureInfo.InvariantCulture, out int windowBits) ||
windowBits < 9 ||
windowBits < 8 ||
windowBits > 15)
{
throw new WebSocketException(WebSocketError.HeaderError, $"invalid {propertyName} used: {value[(startIndex + 1)..].ToString()}");
Expand All @@ -155,6 +176,32 @@ public static bool ParseDeflateOptions(ReadOnlySpan<char> extension, WebSocketDe
extension = extension[(end + 1)..];
}

if (!options.ServerContextTakeover)
{
builder.Append("; ").Append(WebSocketDeflateConstants.ServerNoContextTakeover);
}

if (hasServerMaxWindowBits || options.ServerMaxWindowBits != 15)
{
builder.Append("; ")
.Append(WebSocketDeflateConstants.ServerMaxWindowBits).Append('=')
.Append(options.ServerMaxWindowBits.ToString(CultureInfo.InvariantCulture));
}

// https://tools.ietf.org/html/rfc7692#section-7.1.2.2
// If a received extension negotiation offer doesn't have the
// "client_max_window_bits" extension parameter, the corresponding
// extension negotiation response to the offer MUST NOT include the
// "client_max_window_bits" extension parameter.
//
// Absence of this extension parameter in an extension negotiation
// response indicates that the server can receive messages compressed
// using an LZ77 sliding window of up to 32,768 bytes.
if (!hasClientMaxWindowBits)
{
options.ClientMaxWindowBits = 15;
}

response = builder.ToString();

return true;
Expand Down
8 changes: 6 additions & 2 deletions src/Middleware/WebSockets/src/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#nullable enable
Microsoft.AspNetCore.Builder.WebSocketOptions.AllowedOrigins.get -> System.Collections.Generic.IList<string!>!
Microsoft.AspNetCore.WebSockets.ExtendedWebSocketAcceptContext.WebSocketOptions.get -> System.Net.WebSockets.WebSocketCreationOptions?
Microsoft.AspNetCore.WebSockets.ExtendedWebSocketAcceptContext.WebSocketOptions.set -> void
Microsoft.AspNetCore.WebSockets.ExtendedWebSocketAcceptContext.DangerousEnableCompression.get -> bool
Microsoft.AspNetCore.WebSockets.ExtendedWebSocketAcceptContext.DangerousEnableCompression.set -> void
Microsoft.AspNetCore.WebSockets.ExtendedWebSocketAcceptContext.ServerContextTakeover.get -> bool
Microsoft.AspNetCore.WebSockets.ExtendedWebSocketAcceptContext.ServerContextTakeover.set -> void
Microsoft.AspNetCore.WebSockets.ExtendedWebSocketAcceptContext.ServerMaxWindowBits.get -> int
Microsoft.AspNetCore.WebSockets.ExtendedWebSocketAcceptContext.ServerMaxWindowBits.set -> void
Microsoft.AspNetCore.WebSockets.WebSocketMiddleware.Invoke(Microsoft.AspNetCore.Http.HttpContext! context) -> System.Threading.Tasks.Task!
~Microsoft.AspNetCore.WebSockets.WebSocketMiddleware.WebSocketMiddleware(Microsoft.AspNetCore.Http.RequestDelegate! next, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.Builder.WebSocketOptions!>! options, Microsoft.Extensions.Logging.ILoggerFactory! loggerFactory) -> void
override Microsoft.AspNetCore.WebSockets.ExtendedWebSocketAcceptContext.SubProtocol.get -> string?
Expand Down
96 changes: 28 additions & 68 deletions src/Middleware/WebSockets/src/WebSocketMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,88 +144,61 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
}

TimeSpan keepAliveInterval = _options.KeepAliveInterval;
WebSocketCreationOptions? creationOptions = null;
bool enableCompression = false;
bool serverContextTakeover = true;
int serverMaxWindowBits = 15;
if (acceptContext is ExtendedWebSocketAcceptContext advancedAcceptContext)
{
if (advancedAcceptContext.KeepAliveInterval.HasValue)
{
keepAliveInterval = advancedAcceptContext.KeepAliveInterval.Value;
}
if (advancedAcceptContext.WebSocketOptions is not null)
{
creationOptions = advancedAcceptContext.WebSocketOptions;
}
enableCompression = advancedAcceptContext.DangerousEnableCompression;
serverContextTakeover = advancedAcceptContext.ServerContextTakeover;
serverMaxWindowBits = advancedAcceptContext.ServerMaxWindowBits;
}

string key = _context.Request.Headers.SecWebSocketKey;

HandshakeHelpers.GenerateResponseHeaders(key, subProtocol, _context.Response.Headers);

WebSocketDeflateOptions? deflateOptions = null;
var ext = _context.Request.Headers.SecWebSocketExtensions;
if (ext.Count != 0)
if (enableCompression)
{
// loop over each extension offer, extensions can have multiple offers we can accept any
foreach (var extension in _context.Request.Headers.GetCommaSeparatedValues(HeaderNames.SecWebSocketExtensions))
var ext = _context.Request.Headers.SecWebSocketExtensions;
if (ext.Count != 0)
{
if (extension.TrimStart().StartsWith("permessage-deflate", StringComparison.Ordinal)
&& creationOptions?.DangerousDeflateOptions is not null)
// loop over each extension offer, extensions can have multiple offers we can accept any
foreach (var extension in _context.Request.Headers.GetCommaSeparatedValues(HeaderNames.SecWebSocketExtensions))
{
// We do not want to modify the users options
deflateOptions = CloneWebSocketDeflateOptions(creationOptions.DangerousDeflateOptions);
if (HandshakeHelpers.ParseDeflateOptions(extension, deflateOptions!, out var response))
if (extension.TrimStart().StartsWith("permessage-deflate", StringComparison.Ordinal))
{
if (CompareDeflateOptions(deflateOptions, creationOptions.DangerousDeflateOptions))
deflateOptions = new WebSocketDeflateOptions()
{
ServerContextTakeover = serverContextTakeover,
ServerMaxWindowBits = serverMaxWindowBits
};
if (HandshakeHelpers.ParseDeflateOptions(extension, deflateOptions, out var response))
{
// avoids allocating a new WebSocketCreationOptions below when checking if we have new deflate options to apply
deflateOptions = null;
// If more extension types are added, this would need to be a header append
// and we wouldn't want to break out of the loop
_context.Response.Headers.SecWebSocketExtensions = response;
break;
}
// If more extension types are added, this would need to be a header append
// and we wouldn't want to break out of the loop
_context.Response.Headers.SecWebSocketExtensions = response;
break;
}
}
}
}

Stream opaqueTransport = await _upgradeFeature.UpgradeAsync(); // Sets status code to 101

WebSocketCreationOptions? options = creationOptions;
if (options is null)
{
options = new WebSocketCreationOptions()
{
IsServer = true,
KeepAliveInterval = keepAliveInterval,
SubProtocol = subProtocol,
DangerousDeflateOptions = deflateOptions,
};
}
else if (deflateOptions is not null)
{
// use a new options instance so we don't modify the users options
options = new WebSocketCreationOptions()
{
IsServer = true,
KeepAliveInterval = creationOptions!.KeepAliveInterval,
SubProtocol = creationOptions!.SubProtocol,
DangerousDeflateOptions = deflateOptions,
};
}
else if (options.IsServer == false)
return WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
{
// use a new options instance so we don't modify the users options
options = new WebSocketCreationOptions()
{
IsServer = true,
KeepAliveInterval = creationOptions!.KeepAliveInterval,
SubProtocol = creationOptions!.SubProtocol,
DangerousDeflateOptions = deflateOptions,
};
}

return WebSocket.CreateFromStream(opaqueTransport, options);
IsServer = true,
KeepAliveInterval = keepAliveInterval,
SubProtocol = subProtocol,
DangerousDeflateOptions = deflateOptions
});
}

private static WebSocketDeflateOptions? CloneWebSocketDeflateOptions([NotNullIfNotNull("options")] WebSocketDeflateOptions? options)
Expand All @@ -244,19 +217,6 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
};
}

private static bool CompareDeflateOptions(WebSocketDeflateOptions? lhs, WebSocketDeflateOptions? rhs)
{
if (lhs is null || rhs is null)
{
return lhs == rhs;
}

return lhs.ClientContextTakeover == rhs.ClientContextTakeover &&
lhs.ClientMaxWindowBits == rhs.ClientMaxWindowBits &&
lhs.ServerContextTakeover == rhs.ServerContextTakeover &&
lhs.ServerMaxWindowBits == rhs.ServerMaxWindowBits;
}

public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders)
{
if (!HttpMethods.IsGet(method))
Expand Down
Loading

0 comments on commit 11e2816

Please sign in to comment.