diff --git a/spec/std/http/server/handlers/websocket_handler_spec.cr b/spec/std/http/server/handlers/websocket_handler_spec.cr index db226502ce36..dc7e852e936e 100644 --- a/spec/std/http/server/handlers/websocket_handler_spec.cr +++ b/spec/std/http/server/handlers/websocket_handler_spec.cr @@ -22,9 +22,10 @@ describe HTTP::WebSocketHandler do io = IO::Memory.new headers = HTTP::Headers{ - "Upgrade" => "WS", - "Connection" => "Upgrade", - "Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==", + "Upgrade" => "WS", + "Connection" => "Upgrade", + "Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version" => "13", } request = HTTP::Request.new("GET", "/", headers: headers) response = HTTP::Server::Response.new(io) @@ -47,6 +48,7 @@ describe HTTP::WebSocketHandler do "Upgrade" => "websocket", "Connection" => {{connection}}, "Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version" => "13", } request = HTTP::Request.new("GET", "/", headers: headers) response = HTTP::Server::Response.new(io) @@ -63,16 +65,17 @@ describe HTTP::WebSocketHandler do response.close - io.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-Websocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n") + io.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n") end {% end %} it "gives upgrade response for case-insensitive 'WebSocket' upgrade request" do io = IO::Memory.new headers = HTTP::Headers{ - "Upgrade" => "WebSocket", - "Connection" => "Upgrade", - "Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==", + "Upgrade" => "WebSocket", + "Connection" => "Upgrade", + "Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version" => "13", } request = HTTP::Request.new("GET", "/", headers: headers) response = HTTP::Server::Response.new(io) @@ -89,6 +92,67 @@ describe HTTP::WebSocketHandler do response.close - io.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-Websocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n") + io.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n") + end + + it "returns bad request if Sec-WebSocket-Key is missing" do + io = IO::Memory.new + + headers = HTTP::Headers{ + "Upgrade" => "websocket", + "Connection" => "Upgrade", + "Sec-WebSocket-Version" => "13", + } + request = HTTP::Request.new("GET", "/", headers: headers) + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + + handler = HTTP::WebSocketHandler.new { } + handler.call context + + response.close + + io.to_s.should eq("HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n") + end + + it "returns upgrade required if Sec-WebSocket-Version is missing" do + io = IO::Memory.new + + headers = HTTP::Headers{ + "Upgrade" => "websocket", + "Connection" => "Upgrade", + "Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==", + } + request = HTTP::Request.new("GET", "/", headers: headers) + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + + handler = HTTP::WebSocketHandler.new { } + handler.call context + + response.close + + io.to_s.should eq("HTTP/1.1 426 Upgrade Required\r\nSec-WebSocket-Version: 13\r\nContent-Length: 0\r\n\r\n") + end + + it "returns upgrade required if Sec-WebSocket-Version is invalid" do + io = IO::Memory.new + + headers = HTTP::Headers{ + "Upgrade" => "websocket", + "Connection" => "Upgrade", + "Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version" => "12", + } + request = HTTP::Request.new("GET", "/", headers: headers) + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + + handler = HTTP::WebSocketHandler.new { } + handler.call context + + response.close + + io.to_s.should eq("HTTP/1.1 426 Upgrade Required\r\nSec-WebSocket-Version: 13\r\nContent-Length: 0\r\n\r\n") end end diff --git a/spec/std/http/web_socket_spec.cr b/spec/std/http/web_socket_spec.cr index 81cd6a460970..06209567448c 100644 --- a/spec/std/http/web_socket_spec.cr +++ b/spec/std/http/web_socket_spec.cr @@ -374,6 +374,61 @@ describe HTTP::WebSocket do ws2.run end + it "handshake fails if server does not switch protocols" do + port_chan = Channel(Int32).new + spawn do + http_ref = nil + http_server = http_ref = HTTP::Server.new(0) do |context| + context.response.status_code = 200 + end + + http_server.bind + port_chan.send(http_server.port) + http_server.listen + + http_ref.not_nil!.close + end + + listen_port = port_chan.receive + + expect_raises(Socket::Error, "Handshake got denied. Status code was 200.") do + HTTP::WebSocket::Protocol.new("127.0.0.1", port: listen_port, path: "/") + end + end + + it "handshake fails if server does not verify Sec-WebSocket-Key" do + port_chan = Channel(Int32).new + spawn do + http_ref = nil + has_been_called = false + + http_server = http_ref = HTTP::Server.new(0) do |context| + response = context.response + response.status_code = 101 + response.headers["Upgrade"] = "websocket" + response.headers["Connection"] = "Upgrade" + if has_been_called + response.headers["Sec-WebSocket-Accept"] = "foobar" + http_ref.not_nil!.close + else + has_been_called = true + end + end + + http_server.bind + port_chan.send(http_server.port) + http_server.listen + end + + listen_port = port_chan.receive + + 2.times do + expect_raises(Socket::Error, "Handshake got denied. Server did not verify WebSocket challenge.") do + HTTP::WebSocket::Protocol.new("127.0.0.1", port: listen_port, path: "/") + end + end + end + typeof(HTTP::WebSocket.new(URI.parse("ws://localhost"))) typeof(HTTP::WebSocket.new("localhost", "/")) typeof(HTTP::WebSocket.new("ws://localhost")) diff --git a/src/http/server/handlers/websocket_handler.cr b/src/http/server/handlers/websocket_handler.cr index 3d7ebc58b948..0e25ddd6a516 100644 --- a/src/http/server/handlers/websocket_handler.cr +++ b/src/http/server/handlers/websocket_handler.cr @@ -15,20 +15,28 @@ class HTTP::WebSocketHandler def call(context) if websocket_upgrade_request? context.request - key = context.request.headers["Sec-Websocket-Key"] + response = context.response + + version = context.request.headers["Sec-WebSocket-Version"]? + unless version == WebSocket::Protocol::VERSION + response.status_code = 426 + response.headers["Sec-WebSocket-Version"] = WebSocket::Protocol::VERSION + return + end - accept_code = - {% if flag?(:without_openssl) %} - Digest::SHA1.base64digest("#{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11") - {% else %} - Base64.strict_encode(OpenSSL::SHA1.hash("#{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) - {% end %} + key = context.request.headers["Sec-WebSocket-Key"]? + + unless key + response.status_code = 400 + return + end + + accept_code = WebSocket::Protocol.key_challenge(key) - response = context.response response.status_code = 101 response.headers["Upgrade"] = "websocket" response.headers["Connection"] = "Upgrade" - response.headers["Sec-Websocket-Accept"] = accept_code + response.headers["Sec-WebSocket-Accept"] = accept_code response.upgrade do |io| ws_session = WebSocket.new(io) @proc.call(ws_session, context) diff --git a/src/http/web_socket/protocol.cr b/src/http/web_socket/protocol.cr index 4a4dfb9a7ebb..8ed2c2e2df47 100644 --- a/src/http/web_socket/protocol.cr +++ b/src/http/web_socket/protocol.cr @@ -26,7 +26,7 @@ class HTTP::WebSocket::Protocol end MASK_BIT = 128_u8 - VERSION = 13 + VERSION = "13" record PacketInfo, opcode : Opcode, @@ -257,11 +257,13 @@ class HTTP::WebSocket::Protocol end {% end %} + random_key = Base64.strict_encode(StaticArray(UInt8, 16).new { rand(256).to_u8 }) + headers["Host"] = "#{host}:#{port}" headers["Connection"] = "Upgrade" headers["Upgrade"] = "websocket" - headers["Sec-WebSocket-Version"] = VERSION.to_s - headers["Sec-WebSocket-Key"] = Base64.strict_encode(StaticArray(UInt8, 16).new { rand(256).to_u8 }) + headers["Sec-WebSocket-Version"] = VERSION + headers["Sec-WebSocket-Key"] = random_key path = "/" if path.empty? handshake = HTTP::Request.new("GET", path, headers) @@ -269,7 +271,12 @@ class HTTP::WebSocket::Protocol socket.flush handshake_response = HTTP::Client::Response.from_io(socket) unless handshake_response.status_code == 101 - raise Socket::Error.new("Handshake got denied. Status code was #{handshake_response.status_code}") + raise Socket::Error.new("Handshake got denied. Status code was #{handshake_response.status_code}.") + end + + challenge_response = Protocol.key_challenge(random_key) + unless handshake_response.headers["Sec-WebSocket-Accept"]? == challenge_response + raise Socket::Error.new("Handshake got denied. Server did not verify WebSocket challenge.") end new(socket, masked: true) @@ -285,4 +292,12 @@ class HTTP::WebSocket::Protocol raise ArgumentError.new("No host or path specified which are required.") end + + def self.key_challenge(key) + {% if flag?(:without_openssl) %} + Digest::SHA1.base64digest("#{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + {% else %} + Base64.strict_encode(OpenSSL::SHA1.hash("#{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + {% end %} + end end