Skip to content

Commit

Permalink
Merge pull request #514 from erszcz/fix-intersection-typed-function-c…
Browse files Browse the repository at this point in the history
…alls

Fix intersection-typed function calls with union-typed arguments
  • Loading branch information
erszcz committed Mar 1, 2023
2 parents 182be65 + 792d13f commit ba4476f
Show file tree
Hide file tree
Showing 13 changed files with 204 additions and 134 deletions.
33 changes: 0 additions & 33 deletions priv/prelude/erlang.specs.erl
Original file line number Diff line number Diff line change
Expand Up @@ -18,39 +18,6 @@

-spec erlang:'--'(list(), list()) -> list().

%% The original spec is:
%%
%% -spec erlang:'++'(list(), term()) -> term().
%%
%% Now, this is funny:
%%
%% > [] ++ b.
%% b
%% > [a] ++ b.
%% [a|b]
%% > [a, b] ++ c.
%% [a,b|c]
%% > [a|b] ++ c.
%% ** exception error: bad argument
%% in operator ++/2
%% called as [a|b] ++ c
%% > [] ++ [a].
%% [a]
%% > [a,b] ++ [c].
%% [a,b,c]
%% > [a|b] ++ [c].
%% ** exception error: bad argument
%% in operator ++/2
%% called as [a|b] ++ [c]
%%
-spec erlang:'++'([], T) -> T;
([T1, ...], [T2]) -> [T1 | T2, ...];
([T1], [T2]) -> [T1 | T2];
([T1, ...], nonempty_improper_list(T2, T3)) -> nonempty_improper_list(T1 | T2, T3);
([T1], nonempty_improper_list(T2, T3)) -> nonempty_improper_list(T1 | T2, T3);
([T1, ...], T2) -> nonempty_improper_list(T1, T2).


%% Prior to OTP 24.1 the spec does not list `none' as valid `Args',
%% but the function accepts it and works properly.
-spec erlang:error(Reason, Args) -> no_return() when
Expand Down
11 changes: 5 additions & 6 deletions priv/prelude/filename.specs.erl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
(string() | atom() | deep_list()) -> string().

-type name() :: string() | atom() | deep_list().
-type bname() :: string() | atom() | deep_list() | binary().

-spec join([name()]) -> string();
([bname()]) -> binary().
([binary()]) -> binary().

-spec join(name(), name()) -> string();
(bname(), name()) -> binary();
(name(), bname()) -> binary();
(bname(), bname()) -> binary().
-spec join(name(), name()) -> string();
(binary(), name()) -> binary();
(name(), binary()) -> binary();
(binary(), binary()) -> binary().
4 changes: 0 additions & 4 deletions priv/prelude/lists.specs.erl
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@
-spec foldl(fun((T, Acc) -> Acc), Acc, [T]) -> Acc.
-spec foldr(fun((T, Acc) -> Acc), Acc, [T]) -> Acc.

%% Preserve the (non)empty property of the input list.
-spec map(fun((A) -> B), [A, ...]) -> [B, ...];
(fun((A) -> B), [A]) -> [B].

%% -spec mapfoldl(Fun, Acc0, List1) -> {List2, Acc1} when
%% Fun :: fun((A, AccIn) -> {B, AccOut}),
%% Acc0 :: term(),
Expand Down
2 changes: 1 addition & 1 deletion src/absform.erl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ normalize_record_field({typed_record_field,
-spec normalize_function_type_list(FunTypeList) -> FunTypeList when
FunTypeList :: gradualizer_type:af_function_type_list().
normalize_function_type_list(FunTypeList) ->
lists:map(fun normalize_function_type/1, FunTypeList).
?assert_type(lists:map(fun normalize_function_type/1, FunTypeList), nonempty_list()).

-spec normalize_function_type(BoundedFun | Fun) -> BoundedFun when
BoundedFun :: gradualizer_type:af_constrained_function_type(),
Expand Down
42 changes: 23 additions & 19 deletions src/gradualizer_cache.erl
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,19 @@

%% API
-export([start_link/1,
get_glb/3,
store_glb/4
]).
get/2,
store/3]).

%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).

-define(SERVER, ?MODULE).
-define(GLB_CACHE, gradualizer_glb_cache).
-define(SUB_CACHE, gradualizer_sub_cache).

-record(state, {}).

%% give shorter alias
-type type() :: gradualizer_type:abstract_type().

%%===================================================================
%% API
%%===================================================================
Expand All @@ -41,26 +38,32 @@ start_link(Opts) ->
%% GLB Cache
%%

-spec get_glb(module(), type(), type()) -> false | {type(), constraints:t()}.
get_glb(Module, T1, T2) ->
try ets:lookup(?GLB_CACHE, {Module, T1, T2}) of
-spec get(atom(), any()) -> none | {some, any()}.
get(glb, Key) -> get_(?GLB_CACHE, Key);
get(subtype, Key) -> get_(?SUB_CACHE, Key).

get_(Cache, Key) ->
try ets:lookup(Cache, Key) of
[] ->
false;
[{_, TyCs}] ->
TyCs
none;
[{_, Value}] ->
{some, Value}
catch error:badarg ->
%% cache not initialized
false
%% cache not initialized
none
end.

-spec store_glb(module(), type(), type(), {type(), constraints:t()}) -> ok.
store_glb(Module, T1, T2, TyCs) ->
-spec store(atom(), any(), any()) -> ok.
store(glb, Key, Value) -> store_(?GLB_CACHE, Key, Value);
store(subtype, Key, Value) -> store_(?SUB_CACHE, Key, Value).

store_(Cache, Key, Value) ->
try
ets:insert(?GLB_CACHE, {{Module, T1, T2}, TyCs}),
ets:insert(Cache, {Key, Value}),
ok
catch error:badarg ->
%% cache not initialized
ok
%% cache not initialized
ok
end.

%%===================================================================
Expand All @@ -69,6 +72,7 @@ store_glb(Module, T1, T2, TyCs) ->

init([_Opts]) ->
ets:new(?GLB_CACHE, [set, public, named_table]),
ets:new(?SUB_CACHE, [set, public, named_table]),
{ok, #state{}}.

handle_call(_Request, _From, State) ->
Expand Down
10 changes: 9 additions & 1 deletion src/gradualizer_lib.erl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
empty_tenv/0, create_tenv/3,
remove_pos_typed_record_field/1,
ensure_form_list/1,
zipn/1]).
zipn/1,
cartesian_product/1]).
-export_type([graph/1, tenv/0]).

-type type() :: gradualizer_type:abstract_type().
Expand Down Expand Up @@ -329,3 +330,10 @@ zipn([], Acc) ->
[ lists:reverse(Zipped) || Zipped <- Acc ];
zipn([L | Ls], Acc) ->
zipn(Ls, lists:zipwith(fun (Z, Zs) -> [Z | Zs] end, L, Acc)).

cartesian_product(ListOfLists) ->
lists:foldr(fun (L, []) ->
[ [E] || E <- L ];
(L2, Acc) ->
[ [E | L1] || L1 <- Acc, E <- L2 ]
end, [], ListOfLists).
122 changes: 84 additions & 38 deletions src/typechecker.erl
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,21 @@ compatible(Ty1, Ty2, Env) ->

-spec subtype(type(), type(), env()) -> compatible().
subtype(Ty1, Ty2, Env) ->
try compat(Ty1, Ty2, maps:new(), Env) of
{_Memoization, Constraints} ->
{true, Constraints}
catch
nomatch ->
false
Module = maps:get(module, Env#env.tenv),
case gradualizer_cache:get(?FUNCTION_NAME, {Module, Ty1, Ty2}) of
none ->
R = try compat(Ty1, Ty2, maps:new(), Env) of
{_Memoization, Constraints} ->
{true, Constraints}
catch
nomatch ->
false
end,
gradualizer_cache:store(?FUNCTION_NAME, {Module, Ty1, Ty2}, R),
R;
{some, R} ->
%% these two types have already been seen and calculated
R
end.

%% Check if at least one of the types in a list is a subtype of a type.
Expand Down Expand Up @@ -571,15 +580,15 @@ glb(T1, T2, A, Env) ->
true -> {type(none), constraints:empty()};
false ->
Module = maps:get(module, Env#env.tenv),
case gradualizer_cache:get_glb(Module, T1, T2) of
false ->
case gradualizer_cache:get(?FUNCTION_NAME, {Module, T1, T2}) of
none ->
Ty1 = normalize(T1, Env),
Ty2 = normalize(T2, Env),
{Ty, Cs} = glb_ty(Ty1, Ty2, A#{ {T1, T2} => 0 }, Env),
NormTy = normalize(Ty, Env),
gradualizer_cache:store_glb(Module, T1, T2, {NormTy, Cs}),
gradualizer_cache:store(?FUNCTION_NAME, {Module, T1, T2}, {NormTy, Cs}),
{NormTy, Cs};
TyCs ->
{some, TyCs} ->
%% these two types have already been seen and calculated
TyCs
end
Expand Down Expand Up @@ -2196,8 +2205,8 @@ type_check_call_ty(Env, {fun_ty, ArgsTy, ResTy, Cs}, Args, E) ->
P = element(2, E),
throw(argument_length_mismatch(P, arity(LenTy), arity(LenArgs)))
end;
type_check_call_ty(Env, {fun_ty_intersection, Tyss, Cs}, Args, E) ->
{ResTy, VarBinds, CsI} = type_check_call_ty_intersect(Env, Tyss, Args, E),
type_check_call_ty(Env, {fun_ty_intersection, ClauseTys, Cs}, Args, E) ->
{ResTy, VarBinds, CsI} = type_check_call_ty_intersect(Env, ClauseTys, Args, E),
{ResTy, VarBinds, constraints:combine(Cs, CsI)};
type_check_call_ty(Env, {fun_ty_union, Tyss, Cs}, Args, E) ->
{ResTy, VarBinds, CsI} = type_check_call_ty_union(Env, Tyss, Args, E),
Expand All @@ -2206,18 +2215,44 @@ type_check_call_ty(_Env, {type_error, _}, _Args, {Name, _P, FunTy}) ->
throw(type_error(Name, FunTy, type('fun'))).

-spec type_check_call_ty_intersect(env(), _, _, _) -> {type(), env(), constraints:t()}.
type_check_call_ty_intersect(Env, [], Args, {Name, P, FunTy}) ->
throw(type_error(call_intersect, P, Name, FunTy, infer_arg_types(Args, Env)));
type_check_call_ty_intersect(Env, [Ty | Tys], Args, E) ->
try
type_check_call_ty(Env, Ty, Args, E)
catch
Error when element(1,Error) == type_error ->
type_check_call_ty_intersect(Env, Tys, Args, E)
type_check_call_ty_intersect(Env, ClauseTys, Args, E = {Name, P, FunTy}) ->
check_call_arity(hd(ClauseTys), Args, E),
ArgTypes = infer_arg_types(Args, Env),
ArgExpandedUnions = lists:map(fun
(?type(union, Tys)) -> Tys;
(Ty) -> [Ty]
end, ArgTypes),
ArgTyCombinations = gradualizer_lib:cartesian_product(ArgExpandedUnions),
Matches = [ {Clause, Cs}
|| {fun_ty, ClauseParamTys, _, _} = Clause <- ClauseTys,
ArgTys <- ArgTyCombinations,
{true, Cs} <- [
lists:foldl(fun
(_, false) -> false;
({ArgTy, ParamTy}, {true, AccCs}) ->
case subtype(ArgTy, ParamTy, Env) of
false -> false;
{true, Cs} ->
{true, constraints:combine(Cs, AccCs)}
end
end,
{true, constraints:empty()},
lists:zip(ArgTys, ClauseParamTys))
] ],
NMatches = length(Matches),
NArgTyCombinations = length(ArgTyCombinations),
if
NMatches < NArgTyCombinations ->
throw(type_error(call_intersect, P, Name, FunTy, ArgTypes));
NMatches >= NArgTyCombinations ->
{MatchingClauses, Css} = lists:unzip(Matches),
{ResTys, Css1} = lists:unzip([ {ResTy, Cs} || {fun_ty, _, ResTy, Cs} <- MatchingClauses ]),
{lub(ResTys, Env), Env, constraints:combine(Css ++ Css1)}
end.

-spec infer_arg_types([expr()], env()) -> [type()].
infer_arg_types(Args, Env) ->
%% TODO: don't drop the constraints
lists:map(fun (Arg) ->
{ArgTy, _VB, _Cs} = type_check_expr(Env#env{infer = true}, Arg),
ArgTy
Expand Down Expand Up @@ -3335,20 +3370,23 @@ type_check_fun(Env, Expr, _Arity) ->
end.

-spec type_check_call_intersection(env(), type(), _, _, _, _) -> {env(), constraints:t()}.
type_check_call_intersection(Env, ResTy, OrigExpr, [Ty], Args, E) ->
type_check_call(Env, ResTy, OrigExpr, Ty, Args, E);
type_check_call_intersection(Env, ResTy, OrigExpr, Tys, Args, E) ->
type_check_call_intersection_(Env, ResTy, OrigExpr, Tys, Args, E).

-spec type_check_call_intersection_(env(), type(), _, _, _, _) -> {env(), constraints:t()}.
type_check_call_intersection_(Env, _ResTy, _, [], Args, {P, Name, FunTy}) ->
throw(type_error(call_intersect, P, Name, FunTy, infer_arg_types(Args, Env)));
type_check_call_intersection_(Env, ResTy, OrigExpr, [Ty | Tys], Args, E) ->
try
type_check_call(Env, ResTy, OrigExpr, Ty, Args, E)
catch
Error when element(1, Error) == type_error ->
type_check_call_intersection_(Env, ResTy, OrigExpr, Tys, Args, E)
type_check_call_intersection(Env, ResTy, OrigExpr, ClauseTys, Args, {P, Name, FunTy}) ->
{FunResTy, Env1, Cs} = type_check_call_ty_intersect(Env, ClauseTys, Args, {Name, P, FunTy}),
case subtype(FunResTy, ResTy, Env) of
{true, Cs1} ->
{union_var_binds([Env1], Env), constraints:combine([Cs, Cs1])};
false ->
throw(type_error(OrigExpr, FunResTy, ResTy))
end.

-spec check_call_arity(_, _, _) -> ok.
check_call_arity({fun_ty, ArgsTy, _FunResTy, _Cs}, Args, {P, Name, _}) ->
case length(ArgsTy) =:= length(Args) of
true -> ok;
false ->
LenTys = arity(length(ArgsTy)),
LenArgs = arity(length(Args)),
throw(type_error(call_arity, P, Name, LenTys, LenArgs))
end.

-spec type_check_call(env(), type(), _, _, _, _) -> {env(), constraints:t()}.
Expand Down Expand Up @@ -4838,6 +4876,11 @@ add_type_pat_union(Pat, ?type(union, UnionTys) = UnionTy, Env) ->
_SomeTysMatched ->
%% TODO: The constraints should be merged with *or* semantics
%% and var binds with intersection
%% TODO by erszcz: see tuple_union_arg:j/1 for a problem with this.
%% To solve this we might need to erase var binds gathered in the member patterns and
%% instead bind the vars to fresh type vars.
%% The type vars would have upper bounds of LUB(member var binds' types).
%% This is food for thought, it might or might not work.
{lub(PatTys, Env),
normalize(type(union, UBounds), Env),
union_var_binds(Envs, Env),
Expand Down Expand Up @@ -5180,10 +5223,13 @@ type_of_bin_element({bin_element, _P, Expr, _Size, Specifiers}, OccursAs) ->

%%% Helper functions

-spec type(map, any) -> type();
(tuple, any) -> type();
(atom(), [any()]) -> type().
type(Name, Args) ->
-spec type(atom(), any | [any()]) -> type().
type(map, any) -> type_(map, any);
type(tuple, any) -> type_(tuple, any);
type(Name, Args) -> type_(Name, Args).

-spec type_(_, _) -> type().
type_(Name, Args) ->
{type, erl_anno:new(0), Name, Args}.

%% Helper to create a type, typically a normalized type
Expand Down
7 changes: 7 additions & 0 deletions test/known_problems/should_fail/tuple_union_arg.erl
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,12 @@
(d, e) -> {d, e}.
i(V, U) -> {V, U}.

%% This passes, though it shouldn't, because V is inferred to be d | a,
%% and U is inferred to be b | e.
%% If that was the case, then the call to i/2 could sometimes succeed.
%% Not always, though! So it should already be considered an error.
%%
%% However, due to how the union is structured we know that when V=d, then U=b,
%% and that combination is certain to fail. The same holds for V=a, U=e.
-spec j({d, b} | {a, e}) -> {a, b} | {d, e}.
j({V, U}) -> i(V, U).
16 changes: 0 additions & 16 deletions test/known_problems/should_pass/intersection_should_pass.erl

This file was deleted.

Loading

0 comments on commit ba4476f

Please sign in to comment.