Skip to content

Commit

Permalink
Merge pull request #29 from sanmiguel/mc-fix-chunked-payload-decoding
Browse files Browse the repository at this point in the history
fix chunked payload decoding
  • Loading branch information
sanmiguel committed Jul 30, 2015
2 parents e47e2fa + 1b443ba commit de2ba03
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 100 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
SHELL := /bin/bash
PROJECT = websocket_client

TEST_DEPS = cowboy recon
CT_SUITES = wc
TEST_DEPS = cowboy recon proper
CT_SUITES = wc wsc_lib

include erlang.mk
128 changes: 66 additions & 62 deletions src/websocket_client.erl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
-callback onconnect(websocket_req:req(), state()) ->
% Simple client: only server-initiated pings will be
% automatically responded to.
{ok, state()}
{ok, state()}
% Keepalive client: will automatically initiate a ping to the server
% every keepalive() ms.
| {ok, state(), keepalive()}
Expand Down Expand Up @@ -340,7 +340,7 @@ handle_info({TransError, _Socket, Reason},
}=Context) ->
ok = websocket_close(WSReq, Handler, HState0, {TransError, Reason}),
{stop, {socket_error, Reason}, Context};
handle_info({Trans, Socket, Data},
handle_info({Trans, _Socket, Data},
handshaking,
#context{
transport=#transport{ name=Trans },
Expand All @@ -367,28 +367,62 @@ handle_info({Trans, Socket, Data},
Result
end,
WSReq2 = websocket_req:keepalive(KeepAlive, WSReq1),
%% TODO This is nasty and hopefully there's a nicer way
case Remaining of
<<>> -> ok;
_ -> self() ! {Trans, Socket, Remaining}
end,
{next_state, connected, Context#context{
wsreq=WSReq2,
handler={Handler, HState2},
buffer = <<>>}}
handle_websocket_frame(Remaining, Context#context{
wsreq=WSReq2,
handler={Handler, HState2},
buffer= <<>>})
end;
handle_info({Trans, _Socket, Data},
connected,
#context{
transport=#transport{ name=Trans },
handler={Handler, HState0},
transport=#transport{ name=Trans }
}=Context) ->
handle_websocket_frame(Data, Context);
handle_info(Msg, State,
#context{
wsreq=WSReq,
handler={Handler, HState0},
buffer=Buffer
}=Context) ->
try Handler:websocket_info(Msg, WSReq, HState0) of
HandlerResponse ->
case handle_response(HandlerResponse, Handler, WSReq) of
{ok, WSReqN, HStateN} ->
{next_state, State, Context#context{
handler={Handler, HStateN},
wsreq=WSReqN,
buffer=Buffer}};
{close, Reason, WSReqN, Handler, HStateN} ->
{stop, Reason, Context#context{
wsreq=WSReqN,
handler={Handler, HStateN}}}
end
catch Class:Reason ->
%% TODO Maybe a function_clause catch here to allow
%% not having to have a catch-all clause in websocket_info CB?
error_logger:error_msg(
"** Websocket client ~p terminating in ~p/~p~n"
" for the reason ~p:~p~n"
"** Last message was ~p~n"
"** Handler state was ~p~n"
"** Stacktrace: ~p~n~n",
[Handler, websocket_info, 3, Class, Reason, Msg, HState0,
erlang:get_stacktrace()]),
websocket_close(WSReq, Handler, HState0, Reason),
{stop, Reason, Context}
end.

% Recursively handle all frames that are in the buffer;
% If the last frame is incomplete, leave it in the buffer and wait for more.
handle_websocket_frame(Data, #context{}=Context) ->
#context{
handler={Handler, HState0},
wsreq=WSReq,
buffer=Buffer} = Context,
Result =
case websocket_req:remaining(WSReq) of
undefined ->
wsc_lib:decode_frame(WSReq, << Buffer/binary, Data/binary >>);
wsc_lib:decode_frame(WSReq, << Buffer/binary, Data/binary >>); %% TODO ??
Remaining ->
wsc_lib:decode_frame(WSReq, websocket_req:opcode(WSReq), Remaining, Data, Buffer)
end,
Expand All @@ -401,16 +435,18 @@ handle_info({Trans, _Socket, Data},
try
HandlerResponse = Handler:websocket_handle(Message, WSReqN, HState0),
WSReqN2 = websocket_req:remaining(undefined, WSReqN),
case handle_response(HandlerResponse, Handler, BufferN, WSReqN2) of
{ok, WSReqN2, HStateN2, BufferN2} ->
case BufferN2 of
<<>> -> ok;
_ -> self() ! {Trans, _Socket, BufferN2}
end,
{next_state, connected, Context#context{
handler={Handler, HStateN2},
wsreq=WSReqN2,
buffer= <<>>}};
case handle_response(HandlerResponse, Handler, WSReqN2) of
{ok, WSReqN2, HStateN2} ->
Context2 = Context#context{
handler = {Handler, HStateN2},
wsreq = WSReqN2,
buffer = <<>>},
case BufferN of
<<>> ->
{next_state, connected, Context2};
_ ->
handle_websocket_frame(BufferN, Context2)
end;
{close, Error, WSReqN2, Handler, HStateN2} ->
{stop, Error, Context#context{
wsreq=WSReqN2,
Expand All @@ -435,55 +471,23 @@ handle_info({Trans, _Socket, Data},
{close, _Reason, WSReqN} ->
{next_state, disconnected, Context#context{wsreq=WSReqN,
buffer= <<>>}}
end;
handle_info(Msg, State,
#context{
wsreq=WSReq,
handler={Handler, HState0},
buffer=Buffer
}=Context) ->
try Handler:websocket_info(Msg, WSReq, HState0) of
HandlerResponse ->
case handle_response(HandlerResponse, Handler, Buffer, WSReq) of
{ok, WSReqN, HStateN, BufferN} ->
{next_state, State, Context#context{
handler={Handler, HStateN},
wsreq=WSReqN,
buffer=BufferN}};
{close, Reason, WSReqN, Handler, HStateN} ->
{stop, Reason, Context#context{
wsreq=WSReqN,
handler={Handler, HStateN}}}
end
catch Class:Reason ->
%% TODO Maybe a function_clause catch here to allow
%% not having to have a catch-all clause in websocket_info CB?
error_logger:error_msg(
"** Websocket client ~p terminating in ~p/~p~n"
" for the reason ~p:~p~n"
"** Last message was ~p~n"
"** Handler state was ~p~n"
"** Stacktrace: ~p~n~n",
[Handler, websocket_info, 3, Class, Reason, Msg, HState0,
erlang:get_stacktrace()]),
websocket_close(WSReq, Handler, HState0, Reason),
{stop, Reason, Context}
end.


-spec code_change(OldVsn :: term(), state_name(), #context{}, Extra :: any()) ->
{ok, state_name(), #context{}}.
code_change(_OldVsn, StateName, Context, _Extra) ->
{ok, StateName, Context}.

%% @doc Handles return values from the callback module
handle_response({ok, HandlerState}, _Handler, Buffer, WSReq) ->
{ok, WSReq, HandlerState, Buffer};
handle_response({reply, Frame, HandlerState}, Handler, Buffer, WSReq) ->
handle_response({ok, HandlerState}, _Handler, WSReq) ->
{ok, WSReq, HandlerState};
handle_response({reply, Frame, HandlerState}, Handler, WSReq) ->
case encode_and_send(Frame, WSReq) of
ok -> {ok, WSReq, HandlerState, Buffer};
ok -> {ok, WSReq, HandlerState};
Reason -> {close, Reason, WSReq, Handler, HandlerState}
end;
handle_response({close, Payload, HandlerState}, Handler, _, WSReq) ->
handle_response({close, Payload, HandlerState}, Handler, WSReq) ->
encode_and_send({close, Payload}, WSReq),
{close, normal, WSReq, Handler, HandlerState}.

Expand Down
65 changes: 29 additions & 36 deletions src/wsc_lib.erl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
%% Purely functional aspects of websocket client comms.
%%
%% Herein live all the functions for pure data processing.
-compile([export_all]).
-include("websocket_req.hrl").

-export([create_auth_header/3]).
Expand Down Expand Up @@ -75,43 +76,35 @@ consume_response(Status, Response, HeaderAcc) ->
{ok, Status, HeaderAcc, Body}
end.

%% @doc Start or continue continuation payload with length less than 126 bytes
decode_frame(WSReq, << 0:4, Opcode:4, 0:1, Len:7, Rest/bits >>)
when Len < 126 ->
WSReq1 = set_continuation_if_empty(WSReq, Opcode),
WSReq2 = websocket_req:fin(0, WSReq1),
decode_frame(WSReq2, Opcode, Len, Rest, <<>>);
%% @doc Start or continue continuation payload with length a 2 byte int
decode_frame(WSReq, << 0:4, Opcode:4, 0:1, 126:7, Len:16, Rest/bits >>)
when Len > 125, Opcode < 8 ->
WSReq1 = set_continuation_if_empty(WSReq, Opcode),
WSReq2 = websocket_req:fin(0, WSReq1),
decode_frame(WSReq2, Opcode, Len, Rest, <<>>);
%% @doc Start or continue continuation payload with length a 64 bit int
decode_frame(WSReq, << 0:4, Opcode:4, 0:1, 127:7, 0:1, Len:63, Rest/bits >>)
when Len > 16#ffff, Opcode < 8 ->
WSReq1 = set_continuation_if_empty(WSReq, Opcode),
WSReq2 = websocket_req:fin(0, WSReq1),
decode_frame(WSReq2, Opcode, Len, Rest, <<>>);
%% @doc Length is less 126 bytes
decode_frame(WSReq, << 1:1, 0:3, Opcode:4, 0:1, Len:7, Rest/bits >>)
unpack_frame(<< Fin:1, RSV:3, OpCode:4, Mask:1, Len:7, Payload/bits >>)
when Len < 126 ->
WSReq1 = websocket_req:fin(1, WSReq),
decode_frame(WSReq1, Opcode, Len, Rest, <<>>);
%% @doc Length is a 2 byte integer
decode_frame(WSReq, << 1:1, 0:3, Opcode:4, 0:1, 126:7, Len:16, Rest/bits >>)
when Len > 125, Opcode < 8 ->
WSReq1 = websocket_req:fin(1, WSReq),
decode_frame(WSReq1, Opcode, Len, Rest, <<>>);
%% @doc Length is a 64 bit integer
decode_frame(WSReq, << 1:1, 0:3, Opcode:4, 0:1, 127:7, 0:1, Len:63, Rest/bits >>)
when Len > 16#ffff, Opcode < 8 ->
WSReq1 = websocket_req:fin(1, WSReq),
decode_frame(WSReq1, Opcode, Len, Rest, <<>>);
%% @doc Need more data to read length properly
decode_frame(WSReq, Data) ->
{recv, WSReq, Data}.
{ok, Fin, RSV, OpCode, Len, unmask_frame(Mask, Len, Payload)};
unpack_frame(<< Fin:1, RSV:3, OpCode:4, Mask:1, 126:7, Len:16, Payload/bits>>)
when Len > 125, OpCode < 8 ->
{ok, Fin, RSV, OpCode, Len, unmask_frame(Mask, Len, Payload)};
unpack_frame(<< Fin:1, RSV:3, OpCode:4, Mask:1, 127:7, 0:1, Len:63, Payload/bits>>)
when Len > 16#ffff, OpCode < 8 ->
{ok, Fin, RSV, OpCode, Len, unmask_frame(Mask, Len, Payload)};
unpack_frame(Data) ->
{incomplete, Data}.

%% @doc Start or continue continuation payload with length less than 126 bytes
decode_frame(WSReq, Frame) when is_binary(Frame) ->
case unpack_frame(Frame) of
{incomplete, Data} -> {recv, WSReq, Data};
{ok, 0, 0, OpCode, Len, Payload} ->
WSReq1 = set_continuation_if_empty(WSReq, OpCode),
WSReq2 = websocket_req:fin(0, WSReq1),
decode_frame(WSReq2, OpCode, Len, Payload, <<>>);
{ok, 1, 0, OpCode, Len, Payload} ->
WSReq1 = websocket_req:fin(1, WSReq),
decode_frame(WSReq1, OpCode, Len, Payload, <<>>)
end.

unmask_frame(0, _, Payload) -> Payload;
unmask_frame(1, Len, << Mask:32, Rest/bits >>) ->
<< Payload:Len/binary, NextFrame/bits >> = Rest,
<< (mask_payload(Mask, Payload))/bits, NextFrame/binary >>.

-spec decode_frame(websocket_req:req(),
Opcode :: websocket_req:opcode(),
Expand Down Expand Up @@ -185,7 +178,7 @@ encode_frame(Type) when is_atom(Type) ->

%% @doc The payload is masked using a masking key byte by byte.
%% Can do it in 4 byte chunks to save time until there is left than 4 bytes left
mask_payload(MaskingKey, Payload) ->
mask_payload(MaskingKey, Payload) when is_integer(MaskingKey), is_binary(Payload) ->
mask_payload(MaskingKey, Payload, <<>>).
mask_payload(_, <<>>, Acc) ->
Acc;
Expand Down
16 changes: 16 additions & 0 deletions test/ws_client.erl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
start_link/0,
start_link/1,
start_link/2,
socket/1,
socket/2,
send_text/2,
send_binary/2,
send_ping/2,
Expand Down Expand Up @@ -43,6 +45,16 @@ start_link(Url, KeepAlive) ->
stop(Pid) ->
Pid ! stop.

socket(Pid) ->
socket(Pid, 5000).
socket(Pid, Timeout) ->
Pid ! {gimme_socket, self()},
receive
{hav_socket, Pid, Sock} -> {ok, Sock}
after Timeout ->
{error, timeout}
end.

send_text(Pid, Msg) ->
websocket_client:cast(Pid, {text, Msg}).

Expand Down Expand Up @@ -96,6 +108,10 @@ websocket_info({recv, From}, _, State = #state{buffer = [Top|Rest]}) ->
ct:pal("Sending buffer hd to: ~p {Buffer: ~p}~n", [From, [Top|Rest]]),
From ! Top,
{ok, State#state{buffer = Rest}};
websocket_info({gimme_socket, Whom}, WSReq, State) ->
Sock = websocket_req:socket(WSReq),
Whom ! {hav_socket, self(), Sock},
{ok, State};
websocket_info(stop, _, State) ->
{close, <<>>, State}.

Expand Down
Loading

0 comments on commit de2ba03

Please sign in to comment.