Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix chunked payload decoding #29

Merged
merged 6 commits into from
Jul 30, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
SHELL := /bin/bash
PROJECT = websocket_client

TEST_DEPS = cowboy recon
CT_SUITES = wc
TEST_DEPS = cowboy recon proper
CT_SUITES = wc wsc_lib

include erlang.mk
128 changes: 66 additions & 62 deletions src/websocket_client.erl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
-callback onconnect(websocket_req:req(), state()) ->
% Simple client: only server-initiated pings will be
% automatically responded to.
{ok, state()}
{ok, state()}
% Keepalive client: will automatically initiate a ping to the server
% every keepalive() ms.
| {ok, state(), keepalive()}
Expand Down Expand Up @@ -340,7 +340,7 @@ handle_info({TransError, _Socket, Reason},
}=Context) ->
ok = websocket_close(WSReq, Handler, HState0, {TransError, Reason}),
{stop, {socket_error, Reason}, Context};
handle_info({Trans, Socket, Data},
handle_info({Trans, _Socket, Data},
handshaking,
#context{
transport=#transport{ name=Trans },
Expand All @@ -367,28 +367,62 @@ handle_info({Trans, Socket, Data},
Result
end,
WSReq2 = websocket_req:keepalive(KeepAlive, WSReq1),
%% TODO This is nasty and hopefully there's a nicer way
case Remaining of
<<>> -> ok;
_ -> self() ! {Trans, Socket, Remaining}
end,
{next_state, connected, Context#context{
wsreq=WSReq2,
handler={Handler, HState2},
buffer = <<>>}}
handle_websocket_frame(Remaining, Context#context{
wsreq=WSReq2,
handler={Handler, HState2},
buffer= <<>>})
end;
handle_info({Trans, _Socket, Data},
connected,
#context{
transport=#transport{ name=Trans },
handler={Handler, HState0},
transport=#transport{ name=Trans }
}=Context) ->
handle_websocket_frame(Data, Context);
handle_info(Msg, State,
#context{
wsreq=WSReq,
handler={Handler, HState0},
buffer=Buffer
}=Context) ->
try Handler:websocket_info(Msg, WSReq, HState0) of
HandlerResponse ->
case handle_response(HandlerResponse, Handler, WSReq) of
{ok, WSReqN, HStateN} ->
{next_state, State, Context#context{
handler={Handler, HStateN},
wsreq=WSReqN,
buffer=Buffer}};
{close, Reason, WSReqN, Handler, HStateN} ->
{stop, Reason, Context#context{
wsreq=WSReqN,
handler={Handler, HStateN}}}
end
catch Class:Reason ->
%% TODO Maybe a function_clause catch here to allow
%% not having to have a catch-all clause in websocket_info CB?
error_logger:error_msg(
"** Websocket client ~p terminating in ~p/~p~n"
" for the reason ~p:~p~n"
"** Last message was ~p~n"
"** Handler state was ~p~n"
"** Stacktrace: ~p~n~n",
[Handler, websocket_info, 3, Class, Reason, Msg, HState0,
erlang:get_stacktrace()]),
websocket_close(WSReq, Handler, HState0, Reason),
{stop, Reason, Context}
end.

% Recursively handle all frames that are in the buffer;
% If the last frame is incomplete, leave it in the buffer and wait for more.
handle_websocket_frame(Data, #context{}=Context) ->
#context{
handler={Handler, HState0},
wsreq=WSReq,
buffer=Buffer} = Context,
Result =
case websocket_req:remaining(WSReq) of
undefined ->
wsc_lib:decode_frame(WSReq, << Buffer/binary, Data/binary >>);
wsc_lib:decode_frame(WSReq, << Buffer/binary, Data/binary >>); %% TODO ??
Remaining ->
wsc_lib:decode_frame(WSReq, websocket_req:opcode(WSReq), Remaining, Data, Buffer)
end,
Expand All @@ -401,16 +435,18 @@ handle_info({Trans, _Socket, Data},
try
HandlerResponse = Handler:websocket_handle(Message, WSReqN, HState0),
WSReqN2 = websocket_req:remaining(undefined, WSReqN),
case handle_response(HandlerResponse, Handler, BufferN, WSReqN2) of
{ok, WSReqN2, HStateN2, BufferN2} ->
case BufferN2 of
<<>> -> ok;
_ -> self() ! {Trans, _Socket, BufferN2}
end,
{next_state, connected, Context#context{
handler={Handler, HStateN2},
wsreq=WSReqN2,
buffer= <<>>}};
case handle_response(HandlerResponse, Handler, WSReqN2) of
{ok, WSReqN2, HStateN2} ->
Context2 = Context#context{
handler = {Handler, HStateN2},
wsreq = WSReqN2,
buffer = <<>>},
case BufferN of
<<>> ->
{next_state, connected, Context2};
_ ->
handle_websocket_frame(BufferN, Context2)
end;
{close, Error, WSReqN2, Handler, HStateN2} ->
{stop, Error, Context#context{
wsreq=WSReqN2,
Expand All @@ -435,55 +471,23 @@ handle_info({Trans, _Socket, Data},
{close, _Reason, WSReqN} ->
{next_state, disconnected, Context#context{wsreq=WSReqN,
buffer= <<>>}}
end;
handle_info(Msg, State,
#context{
wsreq=WSReq,
handler={Handler, HState0},
buffer=Buffer
}=Context) ->
try Handler:websocket_info(Msg, WSReq, HState0) of
HandlerResponse ->
case handle_response(HandlerResponse, Handler, Buffer, WSReq) of
{ok, WSReqN, HStateN, BufferN} ->
{next_state, State, Context#context{
handler={Handler, HStateN},
wsreq=WSReqN,
buffer=BufferN}};
{close, Reason, WSReqN, Handler, HStateN} ->
{stop, Reason, Context#context{
wsreq=WSReqN,
handler={Handler, HStateN}}}
end
catch Class:Reason ->
%% TODO Maybe a function_clause catch here to allow
%% not having to have a catch-all clause in websocket_info CB?
error_logger:error_msg(
"** Websocket client ~p terminating in ~p/~p~n"
" for the reason ~p:~p~n"
"** Last message was ~p~n"
"** Handler state was ~p~n"
"** Stacktrace: ~p~n~n",
[Handler, websocket_info, 3, Class, Reason, Msg, HState0,
erlang:get_stacktrace()]),
websocket_close(WSReq, Handler, HState0, Reason),
{stop, Reason, Context}
end.


-spec code_change(OldVsn :: term(), state_name(), #context{}, Extra :: any()) ->
{ok, state_name(), #context{}}.
code_change(_OldVsn, StateName, Context, _Extra) ->
{ok, StateName, Context}.

%% @doc Handles return values from the callback module
handle_response({ok, HandlerState}, _Handler, Buffer, WSReq) ->
{ok, WSReq, HandlerState, Buffer};
handle_response({reply, Frame, HandlerState}, Handler, Buffer, WSReq) ->
handle_response({ok, HandlerState}, _Handler, WSReq) ->
{ok, WSReq, HandlerState};
handle_response({reply, Frame, HandlerState}, Handler, WSReq) ->
case encode_and_send(Frame, WSReq) of
ok -> {ok, WSReq, HandlerState, Buffer};
ok -> {ok, WSReq, HandlerState};
Reason -> {close, Reason, WSReq, Handler, HandlerState}
end;
handle_response({close, Payload, HandlerState}, Handler, _, WSReq) ->
handle_response({close, Payload, HandlerState}, Handler, WSReq) ->
encode_and_send({close, Payload}, WSReq),
{close, normal, WSReq, Handler, HandlerState}.

Expand Down
65 changes: 29 additions & 36 deletions src/wsc_lib.erl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
%% Purely functional aspects of websocket client comms.
%%
%% Herein live all the functions for pure data processing.
-compile([export_all]).
-include("websocket_req.hrl").

-export([create_auth_header/3]).
Expand Down Expand Up @@ -75,43 +76,35 @@ consume_response(Status, Response, HeaderAcc) ->
{ok, Status, HeaderAcc, Body}
end.

%% @doc Start or continue continuation payload with length less than 126 bytes
decode_frame(WSReq, << 0:4, Opcode:4, 0:1, Len:7, Rest/bits >>)
when Len < 126 ->
WSReq1 = set_continuation_if_empty(WSReq, Opcode),
WSReq2 = websocket_req:fin(0, WSReq1),
decode_frame(WSReq2, Opcode, Len, Rest, <<>>);
%% @doc Start or continue continuation payload with length a 2 byte int
decode_frame(WSReq, << 0:4, Opcode:4, 0:1, 126:7, Len:16, Rest/bits >>)
when Len > 125, Opcode < 8 ->
WSReq1 = set_continuation_if_empty(WSReq, Opcode),
WSReq2 = websocket_req:fin(0, WSReq1),
decode_frame(WSReq2, Opcode, Len, Rest, <<>>);
%% @doc Start or continue continuation payload with length a 64 bit int
decode_frame(WSReq, << 0:4, Opcode:4, 0:1, 127:7, 0:1, Len:63, Rest/bits >>)
when Len > 16#ffff, Opcode < 8 ->
WSReq1 = set_continuation_if_empty(WSReq, Opcode),
WSReq2 = websocket_req:fin(0, WSReq1),
decode_frame(WSReq2, Opcode, Len, Rest, <<>>);
%% @doc Length is less 126 bytes
decode_frame(WSReq, << 1:1, 0:3, Opcode:4, 0:1, Len:7, Rest/bits >>)
unpack_frame(<< Fin:1, RSV:3, OpCode:4, Mask:1, Len:7, Payload/bits >>)
when Len < 126 ->
WSReq1 = websocket_req:fin(1, WSReq),
decode_frame(WSReq1, Opcode, Len, Rest, <<>>);
%% @doc Length is a 2 byte integer
decode_frame(WSReq, << 1:1, 0:3, Opcode:4, 0:1, 126:7, Len:16, Rest/bits >>)
when Len > 125, Opcode < 8 ->
WSReq1 = websocket_req:fin(1, WSReq),
decode_frame(WSReq1, Opcode, Len, Rest, <<>>);
%% @doc Length is a 64 bit integer
decode_frame(WSReq, << 1:1, 0:3, Opcode:4, 0:1, 127:7, 0:1, Len:63, Rest/bits >>)
when Len > 16#ffff, Opcode < 8 ->
WSReq1 = websocket_req:fin(1, WSReq),
decode_frame(WSReq1, Opcode, Len, Rest, <<>>);
%% @doc Need more data to read length properly
decode_frame(WSReq, Data) ->
{recv, WSReq, Data}.
{ok, Fin, RSV, OpCode, Len, unmask_frame(Mask, Len, Payload)};
unpack_frame(<< Fin:1, RSV:3, OpCode:4, Mask:1, 126:7, Len:16, Payload/bits>>)
when Len > 125, OpCode < 8 ->
{ok, Fin, RSV, OpCode, Len, unmask_frame(Mask, Len, Payload)};
unpack_frame(<< Fin:1, RSV:3, OpCode:4, Mask:1, 127:7, 0:1, Len:63, Payload/bits>>)
when Len > 16#ffff, OpCode < 8 ->
{ok, Fin, RSV, OpCode, Len, unmask_frame(Mask, Len, Payload)};
unpack_frame(Data) ->
{incomplete, Data}.

%% @doc Start or continue continuation payload with length less than 126 bytes
decode_frame(WSReq, Frame) when is_binary(Frame) ->
case unpack_frame(Frame) of
{incomplete, Data} -> {recv, WSReq, Data};
{ok, 0, 0, OpCode, Len, Payload} ->
WSReq1 = set_continuation_if_empty(WSReq, OpCode),
WSReq2 = websocket_req:fin(0, WSReq1),
decode_frame(WSReq2, OpCode, Len, Payload, <<>>);
{ok, 1, 0, OpCode, Len, Payload} ->
WSReq1 = websocket_req:fin(1, WSReq),
decode_frame(WSReq1, OpCode, Len, Payload, <<>>)
end.

unmask_frame(0, _, Payload) -> Payload;
unmask_frame(1, Len, << Mask:32, Rest/bits >>) ->
<< Payload:Len/binary, NextFrame/bits >> = Rest,
<< (mask_payload(Mask, Payload))/bits, NextFrame/binary >>.

-spec decode_frame(websocket_req:req(),
Opcode :: websocket_req:opcode(),
Expand Down Expand Up @@ -185,7 +178,7 @@ encode_frame(Type) when is_atom(Type) ->

%% @doc The payload is masked using a masking key byte by byte.
%% Can do it in 4 byte chunks to save time until there is left than 4 bytes left
mask_payload(MaskingKey, Payload) ->
mask_payload(MaskingKey, Payload) when is_integer(MaskingKey), is_binary(Payload) ->
mask_payload(MaskingKey, Payload, <<>>).
mask_payload(_, <<>>, Acc) ->
Acc;
Expand Down
16 changes: 16 additions & 0 deletions test/ws_client.erl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
start_link/0,
start_link/1,
start_link/2,
socket/1,
socket/2,
send_text/2,
send_binary/2,
send_ping/2,
Expand Down Expand Up @@ -43,6 +45,16 @@ start_link(Url, KeepAlive) ->
stop(Pid) ->
Pid ! stop.

socket(Pid) ->
socket(Pid, 5000).
socket(Pid, Timeout) ->
Pid ! {gimme_socket, self()},
receive
{hav_socket, Pid, Sock} -> {ok, Sock}
after Timeout ->
{error, timeout}
end.

send_text(Pid, Msg) ->
websocket_client:cast(Pid, {text, Msg}).

Expand Down Expand Up @@ -96,6 +108,10 @@ websocket_info({recv, From}, _, State = #state{buffer = [Top|Rest]}) ->
ct:pal("Sending buffer hd to: ~p {Buffer: ~p}~n", [From, [Top|Rest]]),
From ! Top,
{ok, State#state{buffer = Rest}};
websocket_info({gimme_socket, Whom}, WSReq, State) ->
Sock = websocket_req:socket(WSReq),
Whom ! {hav_socket, self(), Sock},
{ok, State};
websocket_info(stop, _, State) ->
{close, <<>>, State}.

Expand Down
Loading