Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions spec/websocket_spec.cr
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
require "http/web_socket"
require "./spec_helper"

private def connect_amqp(websocket : HTTP::WebSocket)
AMQP::Client::Connection.start(
AMQP::Client::WebSocketIO.new(websocket),
"guest",
"guest",
"/",
10u16, # channel max
4096u32, # frame max
0u16, # heartbeat
AMQP::Client::ConnectionInformation.new )
end

describe "Websocket support" do
it "should connect over websocket" do
with_http_server do |http, _|
Expand All @@ -23,4 +36,116 @@ describe "Websocket support" do
conn.close
end
end

describe "when 'Sec-WebSocket-Protocol'" do
{"amqp", "amqpish"}.each do |header|
describe "is set to '#{header}'" do
it "should accept amqp client" do
with_http_server do |http, _|
headers = ::HTTP::Headers{
"Sec-WebSocket-Protocol" => "amqp",
}
websocket = ::HTTP::WebSocket.new(http.addr.address, path: "", port: http.addr.port, headers: headers)
connect_amqp(websocket)
websocket.close
end
end

it "should not accept mqtt client" do
with_http_server do |http, s|
s.@config.default_user_only_loopback = false
headers = ::HTTP::Headers{
"Sec-WebSocket-Protocol" => header,
}
websocket = ::HTTP::WebSocket.new(http.addr.address, path: "", port: http.addr.port, headers: headers)

connect = MQTT::Protocol::Connect.new(
client_id: "client_id",
clean_session: false,
keepalive: 30u16,
username: "guest",
password: "guest".to_slice,
will: nil,
)

ch = Channel(Nil).new
websocket.on_binary do |bytes|
pkt = MQTT::Protocol::Packet.from_io(IO::Memory.new(bytes))
fail("received unexpected #{pkt}")
rescue
ch.close # close to signal "failure"
end
websocket.on_close do
ch.close
end
spawn { websocket.run }

websocket.stream { |io| connect.to_io(MQTT::Protocol::IO.new(io)) }

expect_raises(Channel::ClosedError) do
select
when ch.receive # this is closed = error
fail("received data?")
when timeout(1.second)
fail("Socket not closed?")
end
end
end
end
end
end

{"mqtt", "mqttish"}.each do |header|
describe "is set to '#{header}'" do
it "should accept mqtt client" do
with_http_server do |http, s|
s.@config.default_user_only_loopback = false
headers = ::HTTP::Headers{
"Sec-WebSocket-Protocol" => header,
}
websocket = ::HTTP::WebSocket.new(http.addr.address, path: "", port: http.addr.port, headers: headers)

connect = MQTT::Protocol::Connect.new(
client_id: "client_id",
clean_session: false,
keepalive: 30u16,
username: "guest",
password: "guest".to_slice,
will: nil,
)

ch = Channel(MQTT::Protocol::Packet).new
websocket.on_binary do |bytes|
ch.send MQTT::Protocol::Packet.from_io(IO::Memory.new(bytes))
websocket.close
end
spawn { websocket.run }

websocket.stream { |io| connect.to_io(MQTT::Protocol::IO.new(io)) }

select
when pkt = ch.receive
pkt.should be_a MQTT::Protocol::Connack
when timeout(1.second)
websocket.close
fail("no response?")
end
end
end

it "should not accept amqp client" do
with_http_server do |http, _|
headers = ::HTTP::Headers{
"Sec-WebSocket-Protocol" => "mqtt",
}
websocket = ::HTTP::WebSocket.new(http.addr.address, path: "", port: http.addr.port, headers: headers)
expect_raises(IO::Error) do
connect_amqp(websocket)
end
websocket.close
end
end
end
end
end
end
48 changes: 45 additions & 3 deletions src/lavinmq/http/handler/websocket.cr
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,65 @@ require "../../server"
module LavinMQ
# Acts as a proxy between websocket clients and the normal TCP servers
class WebsocketProxy
enum Protocol
MQTT
AMQP
end

def self.new(server : Server)
::HTTP::WebSocketHandler.new do |ws, ctx|
req = ctx.request
protocol, header_value = pick_protocol(req)

# Respond with the header value we used to decide protocol
if value = header_value
ctx.response.headers["Sec-WebSocket-Protocol"] = value
end

local_address = req.local_address.as?(Socket::IPAddress) ||
Socket::IPAddress.new("127.0.0.1", 0) # Fake when UNIXAddress
remote_address = req.remote_address.as?(Socket::IPAddress) ||
Socket::IPAddress.new("127.0.0.1", 0) # Fake when UNIXAddress
connection_info = ConnectionInfo.new(remote_address, local_address)
io = WebSocketIO.new(ws)
case req.path
when "/mqtt", "/ws/mqtt"

case protocol
in .mqtt?
Log.debug { "Protocol: mqtt" }
spawn server.handle_connection(io, connection_info, Server::Protocol::MQTT), name: "HandleWSconnection MQTT #{remote_address}"
else
in .amqp?
Log.debug { "Protocol: amqp" }
spawn server.handle_connection(io, connection_info, Server::Protocol::AMQP), name: "HandleWSconnection AMQP #{remote_address}"
end
end
end

# Returns Tuple(Protocol, String?) where the string value is the header value
# used if a header was used to decide protocol
# It accepts any Sec-WebSocket-Protocol starting with amqp or mqtt and fallbacks
# to request path then to AMQP.
private def self.pick_protocol(request : ::HTTP::Request) : {Protocol, String?}
if protocols = request.headers.get?("Sec-WebSocket-Protocol")
protocols.each do |protocol|
case value = protocol
# "amqp" is registered as amqp 1.0, but we accept any amqp value
# see https://www.iana.org/assignments/websocket/websocket.xml#subprotocol-name
when /^amqp/i then return {Protocol::AMQP, value}
# "mqtt" is registered as mqtt 5.0
# see https://www.iana.org/assignments/websocket/websocket.xml#subprotocol-name
when /^mqtt/i then return {Protocol::MQTT, value}
end
end
end

# Fallback to use path
case request.path
when "/mqtt", "/ws/mqtt"
return {Protocol::MQTT, nil}
end
# Default to AMQP
return {Protocol::AMQP, nil}
end
end

class WebSocketIO < IO
Expand Down
Loading