Skip to content

Commit

Permalink
permessage-deflate support
Browse files Browse the repository at this point in the history
  • Loading branch information
Mixfair committed Jun 27, 2022
1 parent 9fc906c commit 503e4f8
Showing 1 changed file with 43 additions and 14 deletions.
57 changes: 43 additions & 14 deletions src/WebSockets.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module WebSockets

using Base64, LoggingExtras, UUIDs, Sockets, Random
using Base64, LoggingExtras, UUIDs, Sockets, Random, CodecZlib
using MbedTLS: digest, MD_SHA1, SSLContext
using ..IOExtras, ..Streams, ..ConnectionPool, ..Messages, ..Conditions, ..Servers
import ..open
Expand Down Expand Up @@ -55,7 +55,7 @@ FrameFlags(final::Bool, opcode::OpCode, masked::Bool, len::Integer; rsv1::Bool=f
)

Base.show(io::IO, x::FrameFlags) =
print(io, "FrameFlags(", "final=", x.final, ", ", "opcode=", x.opcode, ", ", "masked=", x.masked, ", ", "len=", x.len, ")")
print(io, "FrameFlags(", "final=", x.final, ", deflate=", x.rsv1, ", ", "opcode=", x.opcode, ", ", "masked=", x.masked, ", ", "len=", x.len, ")")

primitive type Mask 32 end
Base.UInt32(x::Mask) = Base.bitcast(UInt32, x)
Expand Down Expand Up @@ -91,6 +91,27 @@ function mask!(bytes::Vector{UInt8}, mask)
return
end

function compress(data::T) where T <: AbstractVector{UInt8}
compressed = transcode(DeflateCompressor, data)
return vcat(compressed, 0x00)
end

function compress(data::String)
compressed = transcode(DeflateCompressor, data)
return String(vcat(compressed, 0x00))
end

function decompress(data::T) where T <: AbstractVector{UInt8}
decompressed = transcode(DeflateDecompressor, vcat(data, [0x00, 0x00, 0xff, 0xff, 0x03, 0x00]))
return decompressed
end

function decompress(data::String)
decompressed = transcode(DeflateDecompressor, vcat(Vector{UInt8}(data), [0x00, 0x00, 0xff, 0xff, 0x03, 0x00]))
return String(decompressed)
end


# send method Frame constructor
function Frame(final::Bool, opcode::OpCode, client::Bool, payload::AbstractVector{UInt8}; rsv1::Bool=false, rsv2::Bool=false, rsv3::Bool=false)
len, extlen = wslength(length(payload))
Expand Down Expand Up @@ -293,12 +314,13 @@ mutable struct WebSocket
writebuffer::Vector{UInt8}
readclosed::Bool
writeclosed::Bool
deflate::Bool
end

const DEFAULT_MAX_FRAG = 1024

WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG) =
WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false)
WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, deflate::Bool=false) =
WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false, deflate)

"""
WebSockets.isclosed(ws) -> Bool
Expand Down Expand Up @@ -347,7 +369,7 @@ WebSockets.open(url) do ws
end
```
"""
function open(f::Function, url; suppress_close_error::Bool=false, verbose=false, headers=[], maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, kw...)
function open(f::Function, url; suppress_close_error::Bool=false, verbose=false, headers=[], maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, deflate=false, kw...)
key = base64encode(rand(Random.RandomDevice(), UInt8, 16))
headers = [
"Upgrade" => "websocket",
Expand All @@ -363,13 +385,14 @@ function open(f::Function, url; suppress_close_error::Bool=false, verbose=false,
if header(http, "Sec-WebSocket-Accept") != hashedkey(key)
throw(WebSocketError("Invalid Sec-WebSocket-Accept\n" * "$(http.message)"))
end
deflate = occursin("permessage-deflate", header(http, "Sec-Websocket-Extensions"))
# later stream logic checks to see if the HTTP message is "complete"
# by seeing if ntoread is 0, which is typemax(Int) for websockets by default
# so set it to 0 so it's correctly viewed as "complete" once we're done
# doing websocket things
http.ntoread = 0
io = http.stream
ws = WebSocket(io, http.message.request, http.message; maxframesize, maxfragmentation)
ws = WebSocket(io, http.message.request, http.message; maxframesize, maxfragmentation, deflate)
@debugv 2 "$(ws.id): WebSocket opened"
try
f(ws)
Expand Down Expand Up @@ -416,7 +439,8 @@ function listen end
listen(f, args...; kw...) = Servers.listen(http -> upgrade(f, http; kw...), args...; kw...)
listen!(f, args...; kw...) = Servers.listen!(http -> upgrade(f, http; kw...), args...; kw...)

function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, kw...)
function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int),
maxfragmentation::Integer=DEFAULT_MAX_FRAG, deflate=false, kw...)
@debugv 2 "Server websocket upgrade requested"
isupgrade(http.message) || handshakeerror()
if !hasheader(http, "Sec-WebSocket-Version", "13")
Expand All @@ -430,10 +454,11 @@ function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=f
setheader(http, "Connection" => "Upgrade")
key = header(http, "Sec-WebSocket-Key")
setheader(http, "Sec-WebSocket-Accept" => hashedkey(key))
deflate && setheader(http, "Sec-Websocket-Extensions" => "permessage-deflate; client_no_context_takeover")
startwrite(http)
io = http.stream
req = http.message
ws = WebSocket(io, req, req.response; client=false, maxframesize, maxfragmentation)
ws = WebSocket(io, req, req.response; client=false, maxframesize, maxfragmentation, deflate)
@debugv 2 "$(ws.id): WebSocket upgraded; connection established"
try
f(ws)
Expand Down Expand Up @@ -507,7 +532,7 @@ function Sockets.send(ws::WebSocket, x)
# so we can appropriately set the FIN bit for the last fragmented frame
nextstate = iterate(x, st)
while true
n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item)))
n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item); rsv1 = first ? ws.deflate : false))
first = false
nextstate === nothing && break
item, st = nextstate
Expand All @@ -516,7 +541,8 @@ function Sockets.send(ws::WebSocket, x)
else
# single binary or text frame for message
@label write_single_frame
return writeframe(ws.io, Frame(true, opcode(x), ws.client, payload(ws, x)))
pl = ws.deflate ? compress(payload(ws, x)) : payload(ws, x)
return writeframe(ws.io, Frame(true, opcode(x), ws.client, pl; rsv1=ws.deflate))
end
end

Expand Down Expand Up @@ -603,7 +629,7 @@ end
@noinline utf8check(x) = isvalid(x) || throw(WebSocketError(CloseFrameBody(1007, "Invalid UTF-8")))

function checkreadframe!(ws::WebSocket, frame::Frame)
if frame.flags.rsv1 || frame.flags.rsv2 || frame.flags.rsv3
if frame.flags.rsv2 || frame.flags.rsv3
throw(WebSocketError(CloseFrameBody(1002, "Reserved bits set in control frame")))
end
opcode = frame.flags.opcode
Expand All @@ -624,8 +650,6 @@ function checkreadframe!(ws::WebSocket, frame::Frame)
elseif opcode == PONG
control_len_check(frame.flags.len)
return false
elseif frame.flags.final && frame.flags.opcode == TEXT && frame.payload isa String
utf8check(frame.payload)
end
return frame.flags.final
end
Expand Down Expand Up @@ -659,7 +683,11 @@ function receive(ws::WebSocket)
@debugv 2 "$(ws.id): Received frame: $frame"
done = checkreadframe!(ws, frame)
# common case of reading single non-control frame
done && return frame.payload
if done
payload = ws.deflate ? decompress(frame.payload) : frame.payload
payload isa String && utf8check(payload)
return payload
end
opcode = frame.flags.opcode
iscontrol(opcode) && return receive(ws)
# if we're here, we're reading a fragmented message
Expand All @@ -674,6 +702,7 @@ function receive(ws::WebSocket)
end
done && break
end
payload = ws.deflate ? decompress(payload) : payload
payload isa String && utf8check(payload)
@debugv 2 "Read message: $(payload[1:min(1024, sizeof(payload))])"
return payload
Expand Down

0 comments on commit 503e4f8

Please sign in to comment.