Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auth #99

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open

Auth #99

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ encoding-test: all
(cd test; ct_run -suite latin_SUITE utf8_SUITE utf8_to_latindb_SUITE latin_to_utf8db_SUITE -pa ../ebin $(CRYPTO_PATH))

test: all
(cd test; ct_run -suite environment_SUITE basics_SUITE conn_mgr_SUITE -pa ../ebin $(CRYPTO_PATH))
(cd test; ct_run -suite emysql_auth_SUITE environment_SUITE basics_SUITE conn_mgr_SUITE -cover emysql.cover -pa ../ebin $(CRYPTO_PATH))

test20: all
(cd test; ct_run -suite pool_SUITE -pa ../ebin $(CRYPTO_PATH))
Expand Down
7 changes: 6 additions & 1 deletion include/emysql.hrl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

-record(pool, {pool_id, size, user, password, host, port, database, encoding, available=queue:new(), locked=gb_trees:empty(), waiting=queue:new(), start_cmds=[], conn_test_period=0}).
-record(emysql_connection, {id, pool_id, encoding, socket, version, thread_id, caps, language, prepared=gb_trees:empty(), locked_at, alive=true, test_period=0, last_test_time=0, monitor_ref}).
-record(greeting, {protocol_version, server_version, thread_id, salt1, salt2, caps, caps_high, language, status, seq_num, plugin}).
-record(greeting, {protocol_version, server_version, thread_id, salt1, salt2, caps, caps_high, language, status, seq_num, plugin, server_caps, salt}).
-record(field, {seq_num, catalog, db, table, org_table, name, org_name, type, default, charset_nr, length, flags, decimals, decoder}).
-record(packet, {size, seq_num, data}).
-record(ok_packet, {seq_num, affected_rows, insert_id, status, warning_count, msg}).
Expand All @@ -49,6 +49,11 @@
-define(SECURE_CONNECTION, 32768).
-define(CONNECT_WITH_DB, 8).
-define(CONN_TEST_PERIOD, 28000).
-define(CLIENT_PLUGIN_AUTH, 16#00080000).
-define(MAX_PACKET_SIZE, 16#01000000).
-define(CLIENT_COMPRESS, 16#00000020).
-define(CLIENT_SSL, 16#00000800).
-define(CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA, 16#00200000).

%% MYSQL COMMANDS
-define(COM_SLEEP, 16#00).
Expand Down
214 changes: 112 additions & 102 deletions src/emysql_auth.erl
Original file line number Diff line number Diff line change
Expand Up @@ -30,108 +30,118 @@
-include("emysql.hrl").

do_handshake(Sock, User, Password) ->
%-% io:format("~p handshake: recv_greeting~n", [self()]),
Greeting = recv_greeting(Sock),
%-% io:format("~p handshake: auth~n", [self()]),
case auth(Sock, Greeting#greeting.seq_num+1, User, Password,
Greeting#greeting.salt1, Greeting#greeting.salt2, Greeting#greeting.plugin) of
OK when is_record(OK, ok_packet) ->
%-% io:format("~p handshake: ok~n", [self()]),
ok;
Err when is_record(Err, error_packet) ->
%-% io:format("~p handshake: FAIL ~p -> EXIT ~n~n", [self(), Err]),
exit({failed_to_authenticate, Err});
Other ->
%-% io:format("~p handshake: UNEXPECTED ~p -> EXIT ~n~n", [self(), Other]),
exit({unexpected_packet, Other})
end,
Greeting.

recv_greeting(Sock) ->
%-% io:format("~p recv_greeting~n", [self()]),
{GreetingPacket,Unparsed} = emysql_tcp:recv_packet(Sock, emysql_app:default_timeout(), <<>>),
%-% io:format("~p recv_greeting ... received ...~n", [self()]),
case GreetingPacket#packet.data of
<<255, _/binary>> ->
% io:format("error: ", []),
{{#error_packet{
code = Code,
msg = Msg
},_}, _Rest} = emysql_tcp:response(Sock, emysql_app:default_timeout(), GreetingPacket, Unparsed),
% io:format("exit: ~p~n-------------~p~n", [Code, Msg]),
exit({Code, Msg});
<<ProtocolVersion:8/integer, Rest1/binary>> ->
% io:format("prl v: ~p~n-------------~p~n", [ProtocolVersion, Rest1]),
{ServerVersion, Rest2} = emysql_util:asciz(Rest1),
% io:format("srv v: ~p~n-------------~p~n", [ServerVersion, Rest2]),
<<ThreadID:32/little, Rest3/binary>> = Rest2,
% io:format("tread id: ~p~n-------------~p~n", [ThreadID, Rest3]),
{Salt, Rest4} = emysql_util:asciz(Rest3),
% io:format("salt: ~p~n-------------~p~n", [Salt, Rest4]),
<<ServerCaps:16/little, Rest5/binary>> = Rest4,
% io:format("caps: ~p~n-------------~p~n", [ServerCaps, Rest5]),
<<ServerLanguage:8/little,
ServerStatus:16/little,
ServerCapsHigh:16/little,
ScrambleLength:8/little,
_:10/binary-unit:8,
Rest6/binary>> = Rest5,
% io:format("lang: ~p, status: ~p, caps hi: ~p, salt len: ~p~n-------------~p ~n", [ServerLanguage, ServerStatus, ServerCapsHigh, ScrambleLength, Rest6]),
Salt2Length = case ScrambleLength of 0 -> 13; _-> ScrambleLength - 8 end,
<<Salt2Bin:Salt2Length/binary-unit:8, Plugin/binary>> = Rest6,
{Salt2, <<>>} = emysql_util:asciz(Salt2Bin),
% io:format("salt 2: ~p~n", [Salt2]),
% io:format("plugin: ~p~n", [Plugin]),
#greeting{
protocol_version = ProtocolVersion,
server_version = ServerVersion,
thread_id = ThreadID,
salt1 = Salt,
salt2 = Salt2,
caps = ServerCaps,
caps_high = ServerCapsHigh,
language = ServerLanguage,
status = ServerStatus,
seq_num = GreetingPacket#packet.seq_num,
plugin = Plugin
};
What ->
%-% io:format("~p recv_greeting FAILED: ~p~n", [self(), What]),
exit({greeting_failed, What})
end.

parse_server_version(Version) ->
[A,B,C] = string:tokens(Version, "."),
{list_to_integer(A), list_to_integer(B), list_to_integer(C)}.

auth(Sock, SeqNum, User, Password, Salt1, Salt2, Plugin) ->
ScrambleBuff = if
is_list(Password) orelse is_binary(Password) ->
case Plugin of
?MYSQL_OLD_PASSWORD ->
password_old(Password, Salt1 ++ Salt2); % untested
_ ->
password_new(Password, Salt1 ++ Salt2)
end;
true ->
<<>>
end,
DBCaps = 0,
DatabaseB = <<>>,
Caps = ?LONG_PASSWORD bor ?CLIENT_LOCAL_FILE bor ?LONG_FLAG bor ?TRANSACTIONS bor
?CLIENT_MULTI_STATEMENTS bor ?CLIENT_MULTI_RESULTS bor
?PROTOCOL_41 bor ?SECURE_CONNECTION bor DBCaps,
Maxsize = ?MAXPACKETBYTES,
UserB = unicode:characters_to_binary(User),
PasswordL = size(ScrambleBuff),
Packet = <<Caps:32/little, Maxsize:32/little, 8:8, 0:23/integer-unit:8, UserB/binary, 0:8, PasswordL:8, ScrambleBuff/binary, DatabaseB/binary>>,
case emysql_tcp:send_and_recv_packet(Sock, Packet, SeqNum) of
#eof_packet{seq_num = SeqNum1} ->
AuthOld = password_old(Password, Salt1),
emysql_tcp:send_and_recv_packet(Sock, <<AuthOld/binary, 0:8>>, SeqNum1+1);
Result ->
Result
end.
{InitialPacket, _Unparsed} = emysql_tcp:recv_packet(Sock, emysql_app:default_timeout(), <<>>),

% plain handshake
{InitialParsed, ResponsePacket, ClientCaps} = handshake_response_packet(InitialPacket, User, Password),
ok = case emysql_tcp:send_and_recv_packet(Sock, ResponsePacket, InitialParsed#greeting.seq_num+1, ClientCaps) of
OK when is_record(OK, ok_packet) ->
ok;
Err when is_record(Err, error_packet) ->
exit({failed_to_authenticate, Err});
Other ->
exit({unexpected_packet, Other})
end,
% io:format("INITIAL: ~p~n", [InitialPacket]),
% io:format("RESPONSE: ~p~n", [ResponsePacket]),

InitialParsed.

initial_handshake_packet(Packet) ->
Data = initial_handshake_packet_data(Packet#packet.data),
Data#greeting{ seq_num = Packet#packet.seq_num }.

initial_handshake_packet_data(<<10:8/little, Rest/binary>>) ->
{ServerVersion, Rest1} = emysql_util:asciz(Rest),
<<ThreadID:32/little, Salt1Bin:8/binary-unit:8, 0:8, ServerCapsLow:16/little, Rest2/binary>> = Rest1,
Salt1 = binary_to_list(Salt1Bin),

P = #greeting{
protocol_version = 10,
server_version = ServerVersion,
thread_id = ThreadID,
salt1 = Salt1,
caps = ServerCapsLow,
server_caps = ServerCapsLow,
salt = Salt1
},

case Rest2 of
<<>> -> P;

<<ServerLanguage:8/little,ServerStatus:16/little,ServerCapsHigh:16/little,ScrambleLength:8/little,0:80,Rest3/binary>> ->
<<ServerCaps:32/little>> = <<ServerCapsLow:16/little, ServerCapsHigh:16/little>>,
P2 = P#greeting{ language=ServerLanguage, status=ServerStatus, caps_high=ServerCapsHigh, server_caps=ServerCaps },
case ServerCaps band ?SECURE_CONNECTION of
0 -> P2;
?SECURE_CONNECTION ->
Salt2Length = max(13, ScrambleLength-8) - 1,
<<Salt2Bin:Salt2Length/binary-unit:8, 0:8, Rest4/binary>> = Rest3,
Salt2 = binary_to_list(Salt2Bin),
{Plugin, _} = emysql_util:asciz(Rest4),
P2#greeting{ salt2=Salt2, plugin=Plugin, salt=Salt1 ++ Salt2 }
end
end.

client_caps(ServerCaps) ->
L = [?LONG_FLAG, ?TRANSACTIONS, ?SECURE_CONNECTION, ?CLIENT_MULTI_RESULTS, ?PROTOCOL_41, ?CLIENT_PLUGIN_AUTH, ?CLIENT_MULTI_STATEMENTS],
ClientCaps = lists:foldl(fun(A, B) -> A bor B end, 0, L),
Caps = ClientCaps band ServerCaps,

LongPassword = if ServerCaps band ?PROTOCOL_41 == 0 -> [?LONG_PASSWORD]; true -> [] end,
L2 = LongPassword,
lists:foldl(fun(A, B) -> A bor B end, Caps, L2).

handshake_response_packet(InitialPacket, User, Password) ->
InitialParsed = initial_handshake_packet(InitialPacket),
ClientCaps = client_caps(InitialParsed#greeting.server_caps),
ResponsePacket = handshake_response_packet_data(User, Password, InitialParsed#greeting.salt, ClientCaps, InitialParsed#greeting.plugin),
{InitialParsed, ResponsePacket, ClientCaps}.

handshake_response_packet_data(User, Password, Salt, ClientCaps, Plugin) ->
case ClientCaps band ?PROTOCOL_41 of
0 ->
handshake_response_320(User, Password, Salt, ClientCaps);
?PROTOCOL_41 ->
handshake_response_41(User, Password, Salt, ClientCaps, Plugin)
end.

handshake_response_320(User, Password, Salt, ClientCaps) ->
AuthResponse = password_old(Password, Salt),
UserB = list_to_binary(User),
<<ClientCaps:16/little, 0:24, UserB/binary, 0:8, AuthResponse/binary, 0:8>>.

handshake_response_41(User, Password, Salt, ClientCaps, Plugin) ->
AuthResponse = password_new(Password, Salt),
UserB = unicode:characters_to_binary(User),
Database = <<>>,

Common = <<ClientCaps:32/little, ?MAX_PACKET_SIZE:32/little, 8:8, 0:23/integer-unit:8, UserB/binary, 0:8>>,

AuthResponseB = if
ClientCaps band ?CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA > 0 ->
<<(emysql_util:lenenc_integer(size(AuthResponse)))/binary, AuthResponse/binary>>;
ClientCaps band ?SECURE_CONNECTION > 0 ->
<<(size(AuthResponse)):8, AuthResponse/binary>>;
true ->
<<AuthResponse/binary, 0:8>>
end,

DatabaseB = if
ClientCaps band ?CONNECT_WITH_DB > 0 ->
<<Database/binary, 0:8>>;
true ->
<<>>
end,

AuthPluginB = if
ClientCaps band ?CLIENT_PLUGIN_AUTH > 0 ->
<<(list_to_binary(Plugin))/binary, 0:8>>;
true ->
<<>>
end,

<<Common/binary, AuthResponseB/binary, DatabaseB/binary, AuthPluginB/binary>>.

password_new([], _Salt) ->
<<>>;
Expand Down
31 changes: 17 additions & 14 deletions src/emysql_conn.erl
Original file line number Diff line number Diff line change
Expand Up @@ -40,31 +40,31 @@ set_database(_, undefined) -> ok;
set_database(_, Empty) when Empty == ""; Empty == <<>> -> ok;
set_database(Connection, Database) ->
Packet = <<?COM_QUERY, "use `", (iolist_to_binary(Database))/binary, "`">>, % todo: utf8?
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0).
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0, Connection#emysql_connection.caps).

set_encoding(_, undefined) ->
ok;
set_encoding(Connection, Encoding) ->
Packet = <<?COM_QUERY, "set names '", (erlang:atom_to_binary(Encoding, utf8))/binary, "'">>,
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0).
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0, Connection#emysql_connection.caps).

execute(Connection, Query, []) when is_list(Query) ->
%-% io:format("~p execute list: ~p using connection: ~p~n", [self(), iolist_to_binary(Query), Connection#emysql_connection.id]),
Packet = <<?COM_QUERY, (emysql_util:to_binary(Query))/binary>>,
% Packet = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0);
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0, Connection#emysql_connection.caps);

execute(Connection, Query, []) when is_binary(Query) ->
%-% io:format("~p execute binary: ~p using connection: ~p~n", [self(), Query, Connection#emysql_connection.id]),
Packet = <<?COM_QUERY, Query/binary>>,
% Packet = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0);
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0, Connection#emysql_connection.caps);

execute(Connection, StmtName, []) when is_atom(StmtName) ->
prepare_statement(Connection, StmtName),
StmtNameBin = atom_to_binary(StmtName, utf8),
Packet = <<?COM_QUERY, "EXECUTE ", StmtNameBin/binary>>,
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0);
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0, Connection#emysql_connection.caps);

execute(Connection, Query, Args) when (is_list(Query) orelse is_binary(Query)) andalso is_list(Args) ->
StmtName = "stmt_"++integer_to_list(erlang:phash2(Query)),
Expand All @@ -74,7 +74,7 @@ execute(Connection, Query, Args) when (is_list(Query) orelse is_binary(Query)) a
OK when is_record(OK, ok_packet) ->
ParamNamesBin = list_to_binary(string:join([[$@ | integer_to_list(I)] || I <- lists:seq(1, length(Args))], ", ")), % todo: utf8?
Packet = <<?COM_QUERY, "EXECUTE ", (list_to_binary(StmtName))/binary, " USING ", ParamNamesBin/binary>>, % todo: utf8?
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0);
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0, Connection#emysql_connection.caps);
Error ->
Error
end,
Expand All @@ -88,7 +88,7 @@ execute(Connection, StmtName, Args) when is_atom(StmtName), is_list(Args) ->
ParamNamesBin = list_to_binary(string:join([[$@ | integer_to_list(I)] || I <- lists:seq(1, length(Args))], ", ")), % todo: utf8?
StmtNameBin = atom_to_binary(StmtName, utf8),
Packet = <<?COM_QUERY, "EXECUTE ", StmtNameBin/binary, " USING ", ParamNamesBin/binary>>,
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0);
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0, Connection#emysql_connection.caps);
Error ->
Error
end.
Expand All @@ -98,7 +98,7 @@ prepare(Connection, Name, Statement) when is_atom(Name) ->
prepare(Connection, Name, Statement) ->
StatementBin = emysql_util:encode(Statement, binary),
Packet = <<?COM_QUERY, "PREPARE ", (list_to_binary(Name))/binary, " FROM ", StatementBin/binary>>, % todo: utf8?
case emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0) of
case emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0, Connection#emysql_connection.caps) of
OK when is_record(OK, ok_packet) ->
ok;
Err when is_record(Err, error_packet) ->
Expand All @@ -109,7 +109,7 @@ unprepare(Connection, Name) when is_atom(Name)->
unprepare(Connection, atom_to_list(Name));
unprepare(Connection, Name) ->
Packet = <<?COM_QUERY, "DEALLOCATE PREPARE ", (list_to_binary(Name))/binary>>, % todo: utf8?
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0).
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0, Connection#emysql_connection.caps).

open_n_connections(PoolId, N) ->
case emysql_conn_mgr:find_pool(PoolId, emysql_conn_mgr:pools()) of
Expand Down Expand Up @@ -180,7 +180,10 @@ open_connection(#pool{pool_id=PoolId, host=Host, port=Port, user=User,
},
%%-% io:format("~p open connection: ... set db ...~n", [self()]),
ok = set_database_or_die(Connection, Database),
ok = set_encoding_or_die(Connection, Encoding),
ok = case Caps band ?PROTOCOL_41 of
?PROTOCOL_41 -> set_encoding_or_die(Connection, Encoding);
_ -> ok
end,
ok = run_startcmds_or_die(Connection, StartCmds),
ok = give_manager_control(Sock),
Connection;
Expand Down Expand Up @@ -216,11 +219,11 @@ set_database_or_die(#emysql_connection { socket = Socket } = Connection, Databas
exit({failed_to_set_database, Err1#error_packet.msg})
end.

run_startcmds_or_die(#emysql_connection{socket=Socket}, StartCmds) ->
run_startcmds_or_die(#emysql_connection{socket=Socket, caps=Caps}, StartCmds) ->
lists:foreach(
fun(Cmd) ->
Packet = <<?COM_QUERY, Cmd/binary>>,
case emysql_tcp:send_and_recv_packet(Socket, Packet, 0) of
case emysql_tcp:send_and_recv_packet(Socket, Packet, 0, Caps) of
OK when OK =:= ok orelse is_record(OK, ok_packet) ->
ok;
#error_packet{msg=Msg} ->
Expand Down Expand Up @@ -281,7 +284,7 @@ close_connection(Conn) ->
ok = gen_tcp:close(Conn#emysql_connection.socket).

test_connection(Conn, StayLocked) ->
case catch emysql_tcp:send_and_recv_packet(Conn#emysql_connection.socket, <<?COM_PING>>, 0) of
case catch emysql_tcp:send_and_recv_packet(Conn#emysql_connection.socket, <<?COM_PING>>, 0, Conn#emysql_connection.caps) of
{'EXIT', _} ->
case reset_connection(emysql_conn_mgr:pools(), Conn, StayLocked) of
NewConn when is_record(NewConn, emysql_connection) ->
Expand Down Expand Up @@ -313,7 +316,7 @@ now_seconds() ->
set_params(_, _, [], Result) -> Result;
set_params(Connection, Num, Values, _) ->
Packet = set_params_packet(Num, Values),
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0).
emysql_tcp:send_and_recv_packet(Connection#emysql_connection.socket, Packet, 0, Connection#emysql_connection.caps).

set_params_packet(NumStart, Values) ->
BinValues = [emysql_util:encode(Val, binary) || Val <- Values],
Expand Down
Loading