Skip to content

Commit 57848ea

Browse files
committed
Add support for functools.partial
Fixes python#1484 This is currently the most popular mypy issue that does not need a PEP. I'm sure there's stuff missing, but this should handle most cases.
1 parent 790e8a7 commit 57848ea

File tree

5 files changed

+235
-20
lines changed

5 files changed

+235
-20
lines changed

Diff for: mypy/checkexpr.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -1216,14 +1216,14 @@ def apply_function_plugin(
12161216
assert callback is not None # Assume that caller ensures this
12171217
return callback(
12181218
FunctionContext(
1219-
formal_arg_types,
1220-
formal_arg_kinds,
1221-
callee.arg_names,
1222-
formal_arg_names,
1223-
callee.ret_type,
1224-
formal_arg_exprs,
1225-
context,
1226-
self.chk,
1219+
arg_types=formal_arg_types,
1220+
arg_kinds=formal_arg_kinds,
1221+
callee_arg_names=callee.arg_names,
1222+
arg_names=formal_arg_names,
1223+
default_return_type=callee.ret_type,
1224+
args=formal_arg_exprs,
1225+
context=context,
1226+
api=self.chk,
12271227
)
12281228
)
12291229
else:
@@ -1233,15 +1233,15 @@ def apply_function_plugin(
12331233
object_type = get_proper_type(object_type)
12341234
return method_callback(
12351235
MethodContext(
1236-
object_type,
1237-
formal_arg_types,
1238-
formal_arg_kinds,
1239-
callee.arg_names,
1240-
formal_arg_names,
1241-
callee.ret_type,
1242-
formal_arg_exprs,
1243-
context,
1244-
self.chk,
1236+
type=object_type,
1237+
arg_types=formal_arg_types,
1238+
arg_kinds=formal_arg_kinds,
1239+
callee_arg_names=callee.arg_names,
1240+
arg_names=formal_arg_names,
1241+
default_return_type=callee.ret_type,
1242+
args=formal_arg_exprs,
1243+
context=context,
1244+
api=self.chk,
12451245
)
12461246
)
12471247

Diff for: mypy/plugins/default.py

+8
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
4747
return ctypes.array_constructor_callback
4848
elif fullname == "functools.singledispatch":
4949
return singledispatch.create_singledispatch_function_callback
50+
elif fullname == "functools.partial":
51+
import mypy.plugins.functools
52+
53+
return mypy.plugins.functools.partial_new_callback
5054

5155
return None
5256

@@ -118,6 +122,10 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
118122
return singledispatch.singledispatch_register_callback
119123
elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
120124
return singledispatch.call_singledispatch_function_after_register_argument
125+
elif fullname == "functools.partial.__call__":
126+
import mypy.plugins.functools
127+
128+
return mypy.plugins.functools.partial_call_callback
121129
return None
122130

123131
def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:

Diff for: mypy/plugins/functools.py

+126-2
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,22 @@
33
from __future__ import annotations
44

55
from typing import Final, NamedTuple
6+
from mypy.argmap import map_actuals_to_formals
67

8+
import mypy.checker
79
import mypy.plugin
8-
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var
10+
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var
911
from mypy.plugins.common import add_method_to_class
10-
from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type
12+
from mypy.types import (
13+
AnyType,
14+
CallableType,
15+
Instance,
16+
Type,
17+
TypeOfAny,
18+
UnboundType,
19+
UninhabitedType,
20+
get_proper_type,
21+
)
1122

1223
functools_total_ordering_makers: Final = {"functools.total_ordering"}
1324

@@ -102,3 +113,116 @@ def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo |
102113
comparison_methods[name] = None
103114

104115
return comparison_methods
116+
117+
118+
def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
119+
"""Infer a more precise return type for functools.partial"""
120+
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
121+
return ctx.default_return_type
122+
if len(ctx.arg_types) != 3: # fn, *args, **kwargs
123+
return ctx.default_return_type
124+
if len(ctx.arg_types[0]) != 1:
125+
return ctx.default_return_type
126+
127+
fn_type = ctx.arg_types[0][0]
128+
if not isinstance(fn_type, CallableType):
129+
return ctx.default_return_type
130+
131+
defaulted = fn_type.copy_modified(
132+
arg_kinds=[
133+
(
134+
ArgKind.ARG_OPT
135+
if k == ArgKind.ARG_POS
136+
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
137+
)
138+
for k in fn_type.arg_kinds
139+
]
140+
)
141+
142+
actual_args = [a for param in ctx.args[1:] for a in param]
143+
actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param]
144+
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param]
145+
actual_types = [a for param in ctx.arg_types[1:] for a in param]
146+
147+
_, bound = ctx.api.expr_checker.check_call(
148+
callee=defaulted,
149+
args=actual_args,
150+
arg_kinds=actual_arg_kinds,
151+
arg_names=actual_arg_names,
152+
context=ctx.context,
153+
)
154+
if not isinstance(bound, CallableType):
155+
return ctx.default_return_type
156+
157+
formal_to_actual = map_actuals_to_formals(
158+
actual_kinds=actual_arg_kinds,
159+
actual_names=actual_arg_names,
160+
formal_kinds=fn_type.arg_kinds,
161+
formal_names=fn_type.arg_names,
162+
actual_arg_type=lambda i: actual_types[i],
163+
)
164+
165+
partial_kinds = []
166+
partial_types = []
167+
partial_names = []
168+
# We need to fully apply any positional arguments (they cannot be respecified)
169+
# However, keyword arguments can be respecified, so just give them a default
170+
for i, actuals in enumerate(formal_to_actual):
171+
arg_type = bound.arg_types[i]
172+
if isinstance(arg_type, UninhabitedType):
173+
arg_type = fn_type.arg_types[i] # bit of a hack
174+
175+
if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2):
176+
partial_kinds.append(fn_type.arg_kinds[i])
177+
partial_types.append(arg_type)
178+
partial_names.append(fn_type.arg_names[i])
179+
elif actuals:
180+
if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals):
181+
continue
182+
kind = actual_arg_kinds[actuals[0]]
183+
if kind == ArgKind.ARG_NAMED:
184+
kind = ArgKind.ARG_NAMED_OPT
185+
partial_kinds.append(kind)
186+
partial_types.append(arg_type)
187+
partial_names.append(fn_type.arg_names[i])
188+
189+
ret_type = bound.ret_type
190+
if isinstance(ret_type, UninhabitedType):
191+
ret_type = fn_type.ret_type # same kind of hack as above
192+
193+
partially_applied = fn_type.copy_modified(
194+
arg_types=partial_types, arg_kinds=partial_kinds, arg_names=partial_names, ret_type=ret_type
195+
)
196+
197+
ret = ctx.api.named_generic_type("functools.partial", [])
198+
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
199+
return ret
200+
201+
202+
def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
203+
"""Infer a more precise return type for functools.partial.__call__."""
204+
if (
205+
not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals
206+
or not isinstance(ctx.type, Instance)
207+
or ctx.type.type.fullname != "functools.partial"
208+
or not ctx.type.extra_attrs
209+
or "__mypy_partial" not in ctx.type.extra_attrs.attrs
210+
):
211+
return ctx.default_return_type
212+
213+
partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"]
214+
if len(ctx.arg_types) != 2: # *args, **kwargs
215+
return ctx.default_return_type
216+
217+
args = [a for param in ctx.args for a in param]
218+
arg_kinds = [a for param in ctx.arg_kinds for a in param]
219+
arg_names = [a for param in ctx.arg_names for a in param]
220+
221+
result = ctx.api.expr_checker.check_call(
222+
callee=partial_type,
223+
args=args,
224+
arg_kinds=arg_kinds,
225+
arg_names=arg_names,
226+
context=ctx.context,
227+
)
228+
return result[0]

Diff for: test-data/unit/check-functools.test

+79
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,82 @@ def f(d: D[C]) -> None:
144144

145145
d: D[int] # E: Type argument "int" of "D" must be a subtype of "C"
146146
[builtins fixtures/dict.pyi]
147+
148+
[case testFunctoolsPartialBasic]
149+
import functools
150+
151+
def foo(a: int, b: str, c: int = 5) -> int: ... # N: "foo" defined here
152+
153+
p1 = functools.partial(foo)
154+
p1(1, "a", 3) # OK
155+
p1(1, "a", c=3) # OK
156+
p1(1, b="a", c=3) # OK
157+
158+
p2 = functools.partial(foo, 1)
159+
p2("a") # OK
160+
p2("a", 3) # OK
161+
p2("a", c=3) # OK
162+
p2(1, 3) # E: Argument 1 to "foo" has incompatible type "int"; expected "str"
163+
p2(1, "a", 3) # E: Too many arguments for "foo" \
164+
# E: Argument 1 to "foo" has incompatible type "int"; expected "str" \
165+
# E: Argument 2 to "foo" has incompatible type "str"; expected "int"
166+
p2(a=1, b="a", c=3) # E: Unexpected keyword argument "a" for "foo"
167+
168+
p3 = functools.partial(foo, b="a")
169+
p3(1) # OK
170+
p3(1, c=3) # OK
171+
p3(a=1) # OK
172+
p3(1, b="a", c=3) # OK, keywords can be clobbered
173+
p3(1, 3) # E: Too many positional arguments for "foo" \
174+
# E: Argument 2 to "foo" has incompatible type "int"; expected "str"
175+
176+
functools.partial(foo, "a") # E: Argument 1 to "foo" has incompatible type "str"; expected "int"
177+
functools.partial(foo, b=1) # E: Argument 1 to "foo" has incompatible type "int"; expected "str"
178+
functools.partial(1) # E: Argument 1 to "partial" has incompatible type "int"; expected "Callable[..., Never]"
179+
[builtins fixtures/dict.pyi]
180+
181+
[case testFunctoolsPartialStar]
182+
import functools
183+
184+
def foo(a: int, b: str, *args: int, d: str, **kwargs: int) -> int: ...
185+
186+
p1 = functools.partial(foo, 1, d="a", x=9)
187+
p1("a", 2, 3, 4) # OK
188+
p1("a", 2, 3, 4, d="a") # OK
189+
p1("a", 2, 3, 4, "a") # E: Argument 5 to "foo" has incompatible type "str"; expected "int"
190+
p1("a", 2, 3, 4, x="a") # E: Argument "x" to "foo" has incompatible type "str"; expected "int"
191+
192+
p2 = functools.partial(foo, 1, "a")
193+
p2(2, 3, 4, d="a") # OK
194+
p2("a") # E: Missing named argument "d" for "foo" \
195+
# E: Argument 1 to "foo" has incompatible type "str"; expected "int"
196+
p2(2, 3, 4) # E: Missing named argument "d" for "foo"
197+
198+
functools.partial(foo, 1, "a", "b", "c", d="a") # E: Argument 3 to "foo" has incompatible type "str"; expected "int" \
199+
# E: Argument 4 to "foo" has incompatible type "str"; expected "int"
200+
201+
[builtins fixtures/dict.pyi]
202+
203+
[case testFunctoolsPartialGeneric]
204+
from typing import TypeVar
205+
import functools
206+
207+
T = TypeVar("T")
208+
U = TypeVar("U")
209+
210+
def foo(a: T, b: T) -> T: ...
211+
212+
p1 = functools.partial(foo, 1)
213+
reveal_type(p1(2)) # N: Revealed type is "builtins.int"
214+
p1("a") # E: Argument 1 to "foo" has incompatible type "str"; expected "int"
215+
216+
p2 = functools.partial(foo, "a")
217+
p2(1) # E: Argument 1 to "foo" has incompatible type "int"; expected "str"
218+
reveal_type(p2("a")) # N: Revealed type is "builtins.str"
219+
220+
def bar(a: T, b: U) -> U: ...
221+
222+
p3 = functools.partial(bar, 1)
223+
reveal_type(p3(2)) # N: Revealed type is "builtins.int"
224+
reveal_type(p3("a")) # N: Revealed type is "builtins.str"
225+
[builtins fixtures/dict.pyi]

Diff for: test-data/unit/lib-stub/functools.pyi

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Generic, TypeVar, Callable, Any, Mapping, overload
1+
from typing import Generic, TypeVar, Callable, Any, Mapping, Self, overload
22

33
_T = TypeVar("_T")
44

@@ -33,3 +33,7 @@ class cached_property(Generic[_T]):
3333
def __get__(self, instance: object, owner: type[Any] | None = ...) -> _T: ...
3434
def __set_name__(self, owner: type[Any], name: str) -> None: ...
3535
def __class_getitem__(cls, item: Any) -> Any: ...
36+
37+
class partial(Generic[_T]):
38+
def __new__(cls, __func: Callable[..., _T], *args: Any, **kwargs: Any) -> Self: ...
39+
def __call__(__self, *args: Any, **kwargs: Any) -> _T: ...

0 commit comments

Comments
 (0)