Skip to content
80 changes: 72 additions & 8 deletions spec/std/http/server/handlers/websocket_handler_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ describe HTTP::WebSocketHandler do
io = IO::Memory.new

headers = HTTP::Headers{
"Upgrade" => "WS",
"Connection" => "Upgrade",
"Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==",
"Upgrade" => "WS",
"Connection" => "Upgrade",
"Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version" => "13",
}
request = HTTP::Request.new("GET", "/", headers: headers)
response = HTTP::Server::Response.new(io)
Expand All @@ -47,6 +48,7 @@ describe HTTP::WebSocketHandler do
"Upgrade" => "websocket",
"Connection" => {{connection}},
"Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version" => "13",
}
request = HTTP::Request.new("GET", "/", headers: headers)
response = HTTP::Server::Response.new(io)
Expand All @@ -63,16 +65,17 @@ describe HTTP::WebSocketHandler do

response.close

io.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-Websocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n")
io.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n")
end
{% end %}

it "gives upgrade response for case-insensitive 'WebSocket' upgrade request" do
io = IO::Memory.new
headers = HTTP::Headers{
"Upgrade" => "WebSocket",
"Connection" => "Upgrade",
"Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==",
"Upgrade" => "WebSocket",
"Connection" => "Upgrade",
"Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version" => "13",
}
request = HTTP::Request.new("GET", "/", headers: headers)
response = HTTP::Server::Response.new(io)
Expand All @@ -89,6 +92,67 @@ describe HTTP::WebSocketHandler do

response.close

io.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-Websocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n")
io.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n")
end

it "returns bad request if Sec-WebSocket-Key is missing" do
io = IO::Memory.new

headers = HTTP::Headers{
"Upgrade" => "websocket",
"Connection" => "Upgrade",
"Sec-WebSocket-Version" => "13",
}
request = HTTP::Request.new("GET", "/", headers: headers)
response = HTTP::Server::Response.new(io)
context = HTTP::Server::Context.new(request, response)

handler = HTTP::WebSocketHandler.new { }
handler.call context

response.close

io.to_s.should eq("HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n")
end

it "returns upgrade required if Sec-WebSocket-Version is missing" do
io = IO::Memory.new

headers = HTTP::Headers{
"Upgrade" => "websocket",
"Connection" => "Upgrade",
"Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==",
}
request = HTTP::Request.new("GET", "/", headers: headers)
response = HTTP::Server::Response.new(io)
context = HTTP::Server::Context.new(request, response)

handler = HTTP::WebSocketHandler.new { }
handler.call context

response.close

io.to_s.should eq("HTTP/1.1 426 Upgrade Required\r\nSec-WebSocket-Version: 13\r\nContent-Length: 0\r\n\r\n")
end

it "returns upgrade required if Sec-WebSocket-Version is invalid" do
io = IO::Memory.new

headers = HTTP::Headers{
"Upgrade" => "websocket",
"Connection" => "Upgrade",
"Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version" => "12",
}
request = HTTP::Request.new("GET", "/", headers: headers)
response = HTTP::Server::Response.new(io)
context = HTTP::Server::Context.new(request, response)

handler = HTTP::WebSocketHandler.new { }
handler.call context

response.close

io.to_s.should eq("HTTP/1.1 426 Upgrade Required\r\nSec-WebSocket-Version: 13\r\nContent-Length: 0\r\n\r\n")
end
end
54 changes: 54 additions & 0 deletions spec/std/http/web_socket_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,60 @@ describe HTTP::WebSocket do
ws2.run
end

it "handshake fails if server does not switch protocols" do
http_server = HTTP::Server.new do |context|
context.response.status_code = 200
end

address = http_server.bind_unused_port
spawn http_server.not_nil!.listen # TODO: Remove .not_nil! when #6037 is fixed

expect_raises(Socket::Error, "Handshake got denied. Status code was 200.") do
HTTP::WebSocket::Protocol.new(address.address, port: address.port, path: "/")
end
ensure
# http_server.try &.close # TODO: Uncomment when #5958 is fixed
end

describe "handshake fails if server does not verify Sec-WebSocket-Key" do
it "Sec-WebSocket-Accept missing" do
http_server = HTTP::Server.new do |context|
response = context.response
response.status_code = 101
response.headers["Upgrade"] = "websocket"
response.headers["Connection"] = "Upgrade"
end

address = http_server.bind_unused_port
spawn http_server.not_nil!.listen # TODO: Remove .not_nil! when #6037 is fixed

expect_raises(Socket::Error, "Handshake got denied. Server did not verify WebSocket challenge.") do
HTTP::WebSocket::Protocol.new(address.address, port: address.port, path: "/")
end
ensure
# http_server.try &.close # TODO: Uncomment when #5958 is fixed
end

it "Sec-WebSocket-Accept incorrect" do
http_server = HTTP::Server.new do |context|
response = context.response
response.status_code = 101
response.headers["Upgrade"] = "websocket"
response.headers["Connection"] = "Upgrade"
response.headers["Sec-WebSocket-Accept"] = "foobar"
end

address = http_server.bind_unused_port
spawn http_server.not_nil!.listen # TODO: Remove .not_nil! when #6037 is fixed

expect_raises(Socket::Error, "Handshake got denied. Server did not verify WebSocket challenge.") do
HTTP::WebSocket::Protocol.new(address.address, port: address.port, path: "/")
end
ensure
# http_server.try &.close # TODO: Uncomment when #5958 is fixed
end
end

typeof(HTTP::WebSocket.new(URI.parse("ws://localhost")))
typeof(HTTP::WebSocket.new("localhost", "/"))
typeof(HTTP::WebSocket.new("ws://localhost"))
Expand Down
26 changes: 17 additions & 9 deletions src/http/server/handlers/websocket_handler.cr
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,28 @@ class HTTP::WebSocketHandler

def call(context)
if websocket_upgrade_request? context.request
key = context.request.headers["Sec-Websocket-Key"]
response = context.response

version = context.request.headers["Sec-WebSocket-Version"]?
unless version == WebSocket::Protocol::VERSION
response.status_code = 426
response.headers["Sec-WebSocket-Version"] = WebSocket::Protocol::VERSION
return
end

accept_code =
{% if flag?(:without_openssl) %}
Digest::SHA1.base64digest("#{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
{% else %}
Base64.strict_encode(OpenSSL::SHA1.hash("#{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
{% end %}
key = context.request.headers["Sec-WebSocket-Key"]?

unless key
response.status_code = 400
return
end

accept_code = WebSocket::Protocol.key_challenge(key)

response = context.response
response.status_code = 101
response.headers["Upgrade"] = "websocket"
response.headers["Connection"] = "Upgrade"
response.headers["Sec-Websocket-Accept"] = accept_code
response.headers["Sec-WebSocket-Accept"] = accept_code
response.upgrade do |io|
ws_session = WebSocket.new(io)
@proc.call(ws_session, context)
Expand Down
65 changes: 42 additions & 23 deletions src/http/web_socket/protocol.cr
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class HTTP::WebSocket::Protocol
end

MASK_BIT = 128_u8
VERSION = 13
VERSION = "13"

record PacketInfo,
opcode : Opcode,
Expand Down Expand Up @@ -245,31 +245,42 @@ class HTTP::WebSocket::Protocol
port = port || (tls ? 443 : 80)

socket = TCPSocket.new(host, port)

{% if !flag?(:without_openssl) %}
if tls
if tls.is_a?(Bool) # true, but we want to get rid of the union
context = OpenSSL::SSL::Context::Client.new
else
context = tls
begin
{% if !flag?(:without_openssl) %}
if tls
if tls.is_a?(Bool) # true, but we want to get rid of the union
context = OpenSSL::SSL::Context::Client.new
else
context = tls
end
socket = OpenSSL::SSL::Socket::Client.new(socket, context: context, sync_close: true, hostname: host)
end
socket = OpenSSL::SSL::Socket::Client.new(socket, context: context, sync_close: true, hostname: host)
{% end %}

random_key = Base64.strict_encode(StaticArray(UInt8, 16).new { rand(256).to_u8 })

headers["Host"] = "#{host}:#{port}"
headers["Connection"] = "Upgrade"
headers["Upgrade"] = "websocket"
headers["Sec-WebSocket-Version"] = VERSION
headers["Sec-WebSocket-Key"] = random_key

path = "/" if path.empty?
handshake = HTTP::Request.new("GET", path, headers)
handshake.to_io(socket)
socket.flush
handshake_response = HTTP::Client::Response.from_io(socket)
unless handshake_response.status_code == 101
raise Socket::Error.new("Handshake got denied. Status code was #{handshake_response.status_code}.")
end
{% end %}

headers["Host"] = "#{host}:#{port}"
headers["Connection"] = "Upgrade"
headers["Upgrade"] = "websocket"
headers["Sec-WebSocket-Version"] = VERSION.to_s
headers["Sec-WebSocket-Key"] = Base64.strict_encode(StaticArray(UInt8, 16).new { rand(256).to_u8 })

path = "/" if path.empty?
handshake = HTTP::Request.new("GET", path, headers)
handshake.to_io(socket)
socket.flush
handshake_response = HTTP::Client::Response.from_io(socket)
unless handshake_response.status_code == 101
raise Socket::Error.new("Handshake got denied. Status code was #{handshake_response.status_code}")
challenge_response = Protocol.key_challenge(random_key)
unless handshake_response.headers["Sec-WebSocket-Accept"]? == challenge_response
raise Socket::Error.new("Handshake got denied. Server did not verify WebSocket challenge.")
end
rescue exc
socket.close
raise exc
end

new(socket, masked: true)
Expand All @@ -285,4 +296,12 @@ class HTTP::WebSocket::Protocol

raise ArgumentError.new("No host or path specified which are required.")
end

def self.key_challenge(key)
{% if flag?(:without_openssl) %}
Digest::SHA1.base64digest("#{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
{% else %}
Base64.strict_encode(OpenSSL::SHA1.hash("#{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
{% end %}
end
end