Skip to content

Commit 5b56460

Browse files
committed
Add support for functools.partial
Fixes python#1484 This is currently the most popular mypy issue that does not need a PEP. I'm sure there's stuff missing, but this should handle most cases.
1 parent 790e8a7 commit 5b56460

File tree

6 files changed

+251
-24
lines changed

6 files changed

+251
-24
lines changed

Diff for: mypy/checkexpr.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -1216,14 +1216,14 @@ def apply_function_plugin(
12161216
assert callback is not None # Assume that caller ensures this
12171217
return callback(
12181218
FunctionContext(
1219-
formal_arg_types,
1220-
formal_arg_kinds,
1221-
callee.arg_names,
1222-
formal_arg_names,
1223-
callee.ret_type,
1224-
formal_arg_exprs,
1225-
context,
1226-
self.chk,
1219+
arg_types=formal_arg_types,
1220+
arg_kinds=formal_arg_kinds,
1221+
callee_arg_names=callee.arg_names,
1222+
arg_names=formal_arg_names,
1223+
default_return_type=callee.ret_type,
1224+
args=formal_arg_exprs,
1225+
context=context,
1226+
api=self.chk,
12271227
)
12281228
)
12291229
else:
@@ -1233,15 +1233,15 @@ def apply_function_plugin(
12331233
object_type = get_proper_type(object_type)
12341234
return method_callback(
12351235
MethodContext(
1236-
object_type,
1237-
formal_arg_types,
1238-
formal_arg_kinds,
1239-
callee.arg_names,
1240-
formal_arg_names,
1241-
callee.ret_type,
1242-
formal_arg_exprs,
1243-
context,
1244-
self.chk,
1236+
type=object_type,
1237+
arg_types=formal_arg_types,
1238+
arg_kinds=formal_arg_kinds,
1239+
callee_arg_names=callee.arg_names,
1240+
arg_names=formal_arg_names,
1241+
default_return_type=callee.ret_type,
1242+
args=formal_arg_exprs,
1243+
context=context,
1244+
api=self.chk,
12451245
)
12461246
)
12471247

Diff for: mypy/plugins/default.py

+8
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
4747
return ctypes.array_constructor_callback
4848
elif fullname == "functools.singledispatch":
4949
return singledispatch.create_singledispatch_function_callback
50+
elif fullname == "functools.partial":
51+
import mypy.plugins.functools
52+
53+
return mypy.plugins.functools.partial_new_callback
5054

5155
return None
5256

@@ -118,6 +122,10 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
118122
return singledispatch.singledispatch_register_callback
119123
elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
120124
return singledispatch.call_singledispatch_function_after_register_argument
125+
elif fullname == "functools.partial.__call__":
126+
import mypy.plugins.functools
127+
128+
return mypy.plugins.functools.partial_call_callback
121129
return None
122130

123131
def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:

Diff for: mypy/plugins/functools.py

+130-2
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,21 @@
44

55
from typing import Final, NamedTuple
66

7+
import mypy.checker
78
import mypy.plugin
8-
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var
9+
from mypy.argmap import map_actuals_to_formals
10+
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var
911
from mypy.plugins.common import add_method_to_class
10-
from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type
12+
from mypy.types import (
13+
AnyType,
14+
CallableType,
15+
Instance,
16+
Type,
17+
TypeOfAny,
18+
UnboundType,
19+
UninhabitedType,
20+
get_proper_type,
21+
)
1122

1223
functools_total_ordering_makers: Final = {"functools.total_ordering"}
1324

@@ -102,3 +113,120 @@ def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo |
102113
comparison_methods[name] = None
103114

104115
return comparison_methods
116+
117+
118+
def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
119+
"""Infer a more precise return type for functools.partial"""
120+
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
121+
return ctx.default_return_type
122+
if len(ctx.arg_types) != 3: # fn, *args, **kwargs
123+
return ctx.default_return_type
124+
if len(ctx.arg_types[0]) != 1:
125+
return ctx.default_return_type
126+
127+
fn_type = get_proper_type(ctx.arg_types[0][0])
128+
if not isinstance(fn_type, CallableType):
129+
return ctx.default_return_type
130+
131+
defaulted = fn_type.copy_modified(
132+
arg_kinds=[
133+
(
134+
ArgKind.ARG_OPT
135+
if k == ArgKind.ARG_POS
136+
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
137+
)
138+
for k in fn_type.arg_kinds
139+
]
140+
)
141+
142+
actual_args = [a for param in ctx.args[1:] for a in param]
143+
actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param]
144+
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param]
145+
actual_types = [a for param in ctx.arg_types[1:] for a in param]
146+
147+
_, bound = ctx.api.expr_checker.check_call(
148+
callee=defaulted,
149+
args=actual_args,
150+
arg_kinds=actual_arg_kinds,
151+
arg_names=actual_arg_names,
152+
context=ctx.context,
153+
)
154+
bound = get_proper_type(bound)
155+
if not isinstance(bound, CallableType):
156+
return ctx.default_return_type
157+
158+
formal_to_actual = map_actuals_to_formals(
159+
actual_kinds=actual_arg_kinds,
160+
actual_names=actual_arg_names,
161+
formal_kinds=fn_type.arg_kinds,
162+
formal_names=fn_type.arg_names,
163+
actual_arg_type=lambda i: actual_types[i],
164+
)
165+
166+
partial_kinds = []
167+
partial_types = []
168+
partial_names = []
169+
# We need to fully apply any positional arguments (they cannot be respecified)
170+
# However, keyword arguments can be respecified, so just give them a default
171+
for i, actuals in enumerate(formal_to_actual):
172+
arg_type = bound.arg_types[i]
173+
if isinstance(get_proper_type(arg_type), UninhabitedType):
174+
arg_type = fn_type.arg_types[i] # bit of a hack
175+
176+
if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2):
177+
partial_kinds.append(fn_type.arg_kinds[i])
178+
partial_types.append(arg_type)
179+
partial_names.append(fn_type.arg_names[i])
180+
elif actuals:
181+
if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals):
182+
continue
183+
kind = actual_arg_kinds[actuals[0]]
184+
if kind == ArgKind.ARG_NAMED:
185+
kind = ArgKind.ARG_NAMED_OPT
186+
partial_kinds.append(kind)
187+
partial_types.append(arg_type)
188+
partial_names.append(fn_type.arg_names[i])
189+
190+
ret_type = bound.ret_type
191+
if isinstance(get_proper_type(ret_type), UninhabitedType):
192+
ret_type = fn_type.ret_type # same kind of hack as above
193+
194+
partially_applied = fn_type.copy_modified(
195+
arg_types=partial_types,
196+
arg_kinds=partial_kinds,
197+
arg_names=partial_names,
198+
ret_type=ret_type,
199+
)
200+
201+
ret = ctx.api.named_generic_type("functools.partial", [ret_type])
202+
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
203+
return ret
204+
205+
206+
def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
207+
"""Infer a more precise return type for functools.partial.__call__."""
208+
if (
209+
not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals
210+
or not isinstance(ctx.type, Instance)
211+
or ctx.type.type.fullname != "functools.partial"
212+
or not ctx.type.extra_attrs
213+
or "__mypy_partial" not in ctx.type.extra_attrs.attrs
214+
):
215+
return ctx.default_return_type
216+
217+
partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"]
218+
if len(ctx.arg_types) != 2: # *args, **kwargs
219+
return ctx.default_return_type
220+
221+
args = [a for param in ctx.args for a in param]
222+
arg_kinds = [a for param in ctx.arg_kinds for a in param]
223+
arg_names = [a for param in ctx.arg_names for a in param]
224+
225+
result = ctx.api.expr_checker.check_call(
226+
callee=partial_type,
227+
args=args,
228+
arg_kinds=arg_kinds,
229+
arg_names=arg_names,
230+
context=ctx.context,
231+
)
232+
return result[0]

Diff for: mypy/types.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1495,13 +1495,14 @@ def copy_modified(
14951495
last_known_value: Bogus[LiteralType | None] = _dummy,
14961496
) -> Instance:
14971497
new = Instance(
1498-
self.type,
1499-
args if args is not _dummy else self.args,
1500-
self.line,
1501-
self.column,
1498+
typ=self.type,
1499+
args=args if args is not _dummy else self.args,
1500+
line=self.line,
1501+
column=self.column,
15021502
last_known_value=(
15031503
last_known_value if last_known_value is not _dummy else self.last_known_value
15041504
),
1505+
extra_attrs=self.extra_attrs,
15051506
)
15061507
# We intentionally don't copy the extra_attrs here, so they will be erased.
15071508
new.can_be_true = self.can_be_true

Diff for: test-data/unit/check-functools.test

+86
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,89 @@ def f(d: D[C]) -> None:
144144

145145
d: D[int] # E: Type argument "int" of "D" must be a subtype of "C"
146146
[builtins fixtures/dict.pyi]
147+
148+
[case testFunctoolsPartialBasic]
149+
from typing import Callable
150+
import functools
151+
152+
def foo(a: int, b: str, c: int = 5) -> int: ... # N: "foo" defined here
153+
154+
p1 = functools.partial(foo)
155+
p1(1, "a", 3) # OK
156+
p1(1, "a", c=3) # OK
157+
p1(1, b="a", c=3) # OK
158+
159+
def takes_callable_int(f: Callable[..., int]) -> None: ...
160+
def takes_callable_str(f: Callable[..., str]) -> None: ...
161+
takes_callable_int(p1)
162+
takes_callable_str(p1) # E: Argument 1 to "takes_callable_str" has incompatible type "partial[int]"; expected "Callable[..., str]" \
163+
# N: "partial[int].__call__" has type "Callable[[VarArg(Any), KwArg(Any)], int]"
164+
165+
p2 = functools.partial(foo, 1)
166+
p2("a") # OK
167+
p2("a", 3) # OK
168+
p2("a", c=3) # OK
169+
p2(1, 3) # E: Argument 1 to "foo" has incompatible type "int"; expected "str"
170+
p2(1, "a", 3) # E: Too many arguments for "foo" \
171+
# E: Argument 1 to "foo" has incompatible type "int"; expected "str" \
172+
# E: Argument 2 to "foo" has incompatible type "str"; expected "int"
173+
p2(a=1, b="a", c=3) # E: Unexpected keyword argument "a" for "foo"
174+
175+
p3 = functools.partial(foo, b="a")
176+
p3(1) # OK
177+
p3(1, c=3) # OK
178+
p3(a=1) # OK
179+
p3(1, b="a", c=3) # OK, keywords can be clobbered
180+
p3(1, 3) # E: Too many positional arguments for "foo" \
181+
# E: Argument 2 to "foo" has incompatible type "int"; expected "str"
182+
183+
functools.partial(foo, "a") # E: Argument 1 to "foo" has incompatible type "str"; expected "int"
184+
functools.partial(foo, b=1) # E: Argument 1 to "foo" has incompatible type "int"; expected "str"
185+
functools.partial(1) # E: Argument 1 to "partial" has incompatible type "int"; expected "Callable[..., Never]"
186+
[builtins fixtures/dict.pyi]
187+
188+
[case testFunctoolsPartialStar]
189+
import functools
190+
191+
def foo(a: int, b: str, *args: int, d: str, **kwargs: int) -> int: ...
192+
193+
p1 = functools.partial(foo, 1, d="a", x=9)
194+
p1("a", 2, 3, 4) # OK
195+
p1("a", 2, 3, 4, d="a") # OK
196+
p1("a", 2, 3, 4, "a") # E: Argument 5 to "foo" has incompatible type "str"; expected "int"
197+
p1("a", 2, 3, 4, x="a") # E: Argument "x" to "foo" has incompatible type "str"; expected "int"
198+
199+
p2 = functools.partial(foo, 1, "a")
200+
p2(2, 3, 4, d="a") # OK
201+
p2("a") # E: Missing named argument "d" for "foo" \
202+
# E: Argument 1 to "foo" has incompatible type "str"; expected "int"
203+
p2(2, 3, 4) # E: Missing named argument "d" for "foo"
204+
205+
functools.partial(foo, 1, "a", "b", "c", d="a") # E: Argument 3 to "foo" has incompatible type "str"; expected "int" \
206+
# E: Argument 4 to "foo" has incompatible type "str"; expected "int"
207+
208+
[builtins fixtures/dict.pyi]
209+
210+
[case testFunctoolsPartialGeneric]
211+
from typing import TypeVar
212+
import functools
213+
214+
T = TypeVar("T")
215+
U = TypeVar("U")
216+
217+
def foo(a: T, b: T) -> T: ...
218+
219+
p1 = functools.partial(foo, 1)
220+
reveal_type(p1(2)) # N: Revealed type is "builtins.int"
221+
p1("a") # E: Argument 1 to "foo" has incompatible type "str"; expected "int"
222+
223+
p2 = functools.partial(foo, "a")
224+
p2(1) # E: Argument 1 to "foo" has incompatible type "int"; expected "str"
225+
reveal_type(p2("a")) # N: Revealed type is "builtins.str"
226+
227+
def bar(a: T, b: U) -> U: ...
228+
229+
p3 = functools.partial(bar, 1)
230+
reveal_type(p3(2)) # N: Revealed type is "builtins.int"
231+
reveal_type(p3("a")) # N: Revealed type is "builtins.str"
232+
[builtins fixtures/dict.pyi]

Diff for: test-data/unit/lib-stub/functools.pyi

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Generic, TypeVar, Callable, Any, Mapping, overload
1+
from typing import Generic, TypeVar, Callable, Any, Mapping, Self, overload
22

33
_T = TypeVar("_T")
44

@@ -33,3 +33,7 @@ class cached_property(Generic[_T]):
3333
def __get__(self, instance: object, owner: type[Any] | None = ...) -> _T: ...
3434
def __set_name__(self, owner: type[Any], name: str) -> None: ...
3535
def __class_getitem__(cls, item: Any) -> Any: ...
36+
37+
class partial(Generic[_T]):
38+
def __new__(cls, __func: Callable[..., _T], *args: Any, **kwargs: Any) -> Self: ...
39+
def __call__(__self, *args: Any, **kwargs: Any) -> _T: ...

0 commit comments

Comments
 (0)