Skip to content

Commit 851e89a

Browse files
JelleZijlstraJukkaL
authored andcommitted
Add support for assert_type (#12584)
See python/cpython#30843. The implementation mostly follows that of cast(). It relies on `mypy.sametypes.is_same_type()`.
1 parent c4a9697 commit 851e89a

24 files changed

+175
-12
lines changed

mypy/checkexpr.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
get_proper_types, flatten_nested_unions, LITERAL_TYPE_NAMES,
2424
)
2525
from mypy.nodes import (
26-
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
26+
AssertTypeExpr, NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
2727
MemberExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr,
2828
OpExpr, UnaryExpr, IndexExpr, CastExpr, RevealExpr, TypeApplication, ListExpr,
2929
TupleExpr, DictExpr, LambdaExpr, SuperExpr, SliceExpr, Context, Expression,
@@ -3144,6 +3144,14 @@ def visit_cast_expr(self, expr: CastExpr) -> Type:
31443144
context=expr)
31453145
return target_type
31463146

3147+
def visit_assert_type_expr(self, expr: AssertTypeExpr) -> Type:
3148+
source_type = self.accept(expr.expr, type_context=AnyType(TypeOfAny.special_form),
3149+
allow_none_return=True, always_allow_any=True)
3150+
target_type = expr.type
3151+
if not is_same_type(source_type, target_type):
3152+
self.msg.assert_type_fail(source_type, target_type, expr)
3153+
return source_type
3154+
31473155
def visit_reveal_expr(self, expr: RevealExpr) -> Type:
31483156
"""Type check a reveal_type expression."""
31493157
if expr.kind == REVEAL_TYPE:

mypy/errorcodes.py

+3
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ def __str__(self) -> str:
113113
REDUNDANT_CAST: Final = ErrorCode(
114114
"redundant-cast", "Check that cast changes type of expression", "General"
115115
)
116+
ASSERT_TYPE: Final = ErrorCode(
117+
"assert-type", "Check that assert_type() call succeeds", "General"
118+
)
116119
COMPARISON_OVERLAP: Final = ErrorCode(
117120
"comparison-overlap", "Check that types in comparisons and 'in' expressions overlap", "General"
118121
)

mypy/literals.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
ConditionalExpr, EllipsisExpr, YieldFromExpr, YieldExpr, RevealExpr, SuperExpr,
99
TypeApplication, LambdaExpr, ListComprehension, SetComprehension, DictionaryComprehension,
1010
GeneratorExpr, BackquoteExpr, TypeVarExpr, TypeAliasExpr, NamedTupleExpr, EnumCallExpr,
11-
TypedDictExpr, NewTypeExpr, PromoteExpr, AwaitExpr, TempNode, AssignmentExpr, ParamSpecExpr
11+
TypedDictExpr, NewTypeExpr, PromoteExpr, AwaitExpr, TempNode, AssignmentExpr, ParamSpecExpr,
12+
AssertTypeExpr,
1213
)
1314
from mypy.visitor import ExpressionVisitor
1415

@@ -175,6 +176,9 @@ def visit_slice_expr(self, e: SliceExpr) -> None:
175176
def visit_cast_expr(self, e: CastExpr) -> None:
176177
return None
177178

179+
def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
180+
return None
181+
178182
def visit_conditional_expr(self, e: ConditionalExpr) -> None:
179183
return None
180184

mypy/messages.py

+5
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,11 @@ def redundant_cast(self, typ: Type, context: Context) -> None:
12131213
self.fail('Redundant cast to {}'.format(format_type(typ)), context,
12141214
code=codes.REDUNDANT_CAST)
12151215

1216+
def assert_type_fail(self, source_type: Type, target_type: Type, context: Context) -> None:
1217+
self.fail(f"Expression is of type {format_type(source_type)}, "
1218+
f"not {format_type(target_type)}", context,
1219+
code=codes.ASSERT_TYPE)
1220+
12161221
def unimported_type_becomes_any(self, prefix: str, typ: Type, ctx: Context) -> None:
12171222
self.fail("{} becomes {} due to an unfollowed import".format(prefix, format_type(typ)),
12181223
ctx, code=codes.NO_ANY_UNIMPORTED)

mypy/mixedtraverser.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Optional
22

33
from mypy.nodes import (
4-
Var, FuncItem, ClassDef, AssignmentStmt, ForStmt, WithStmt,
4+
AssertTypeExpr, Var, FuncItem, ClassDef, AssignmentStmt, ForStmt, WithStmt,
55
CastExpr, TypeApplication, TypeAliasExpr, TypeVarExpr, TypedDictExpr, NamedTupleExpr,
66
PromoteExpr, NewTypeExpr
77
)
@@ -79,6 +79,10 @@ def visit_cast_expr(self, o: CastExpr) -> None:
7979
super().visit_cast_expr(o)
8080
o.type.accept(self)
8181

82+
def visit_assert_type_expr(self, o: AssertTypeExpr) -> None:
83+
super().visit_assert_type_expr(o)
84+
o.type.accept(self)
85+
8286
def visit_type_application(self, o: TypeApplication) -> None:
8387
super().visit_type_application(o)
8488
for t in o.types:

mypy/nodes.py

+16
Original file line numberDiff line numberDiff line change
@@ -1938,6 +1938,22 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
19381938
return visitor.visit_cast_expr(self)
19391939

19401940

1941+
class AssertTypeExpr(Expression):
1942+
"""Represents a typing.assert_type(expr, type) call."""
1943+
__slots__ = ('expr', 'type')
1944+
1945+
expr: Expression
1946+
type: "mypy.types.Type"
1947+
1948+
def __init__(self, expr: Expression, typ: 'mypy.types.Type') -> None:
1949+
super().__init__()
1950+
self.expr = expr
1951+
self.type = typ
1952+
1953+
def accept(self, visitor: ExpressionVisitor[T]) -> T:
1954+
return visitor.visit_assert_type_expr(self)
1955+
1956+
19411957
class RevealExpr(Expression):
19421958
"""Reveal type expression reveal_type(expr) or reveal_locals() expression."""
19431959

mypy/semanal.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from typing_extensions import Final, TypeAlias as _TypeAlias
5757

5858
from mypy.nodes import (
59-
MypyFile, TypeInfo, Node, AssignmentStmt, FuncDef, OverloadedFuncDef,
59+
AssertTypeExpr, MypyFile, TypeInfo, Node, AssignmentStmt, FuncDef, OverloadedFuncDef,
6060
ClassDef, Var, GDEF, FuncItem, Import, Expression, Lvalue,
6161
ImportFrom, ImportAll, Block, LDEF, NameExpr, MemberExpr,
6262
IndexExpr, TupleExpr, ListExpr, ExpressionStmt, ReturnStmt,
@@ -99,7 +99,7 @@
9999
TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType,
100100
get_proper_type, get_proper_types, TypeAliasType, TypeVarLikeType, Parameters, ParamSpecType,
101101
PROTOCOL_NAMES, TYPE_ALIAS_NAMES, FINAL_TYPE_NAMES, FINAL_DECORATOR_NAMES, REVEAL_TYPE_NAMES,
102-
is_named_instance,
102+
ASSERT_TYPE_NAMES, is_named_instance,
103103
)
104104
from mypy.typeops import function_type, get_type_vars
105105
from mypy.type_visitor import TypeQuery
@@ -3898,6 +3898,19 @@ def visit_call_expr(self, expr: CallExpr) -> None:
38983898
expr.analyzed.line = expr.line
38993899
expr.analyzed.column = expr.column
39003900
expr.analyzed.accept(self)
3901+
elif refers_to_fullname(expr.callee, ASSERT_TYPE_NAMES):
3902+
if not self.check_fixed_args(expr, 2, 'assert_type'):
3903+
return
3904+
# Translate second argument to an unanalyzed type.
3905+
try:
3906+
target = self.expr_to_unanalyzed_type(expr.args[1])
3907+
except TypeTranslationError:
3908+
self.fail('assert_type() type is not a type', expr)
3909+
return
3910+
expr.analyzed = AssertTypeExpr(expr.args[0], target)
3911+
expr.analyzed.line = expr.line
3912+
expr.analyzed.column = expr.column
3913+
expr.analyzed.accept(self)
39013914
elif refers_to_fullname(expr.callee, REVEAL_TYPE_NAMES):
39023915
if not self.check_fixed_args(expr, 1, 'reveal_type'):
39033916
return
@@ -4201,6 +4214,12 @@ def visit_cast_expr(self, expr: CastExpr) -> None:
42014214
if analyzed is not None:
42024215
expr.type = analyzed
42034216

4217+
def visit_assert_type_expr(self, expr: AssertTypeExpr) -> None:
4218+
expr.expr.accept(self)
4219+
analyzed = self.anal_type(expr.type)
4220+
if analyzed is not None:
4221+
expr.type = analyzed
4222+
42044223
def visit_reveal_expr(self, expr: RevealExpr) -> None:
42054224
if expr.kind == REVEAL_TYPE:
42064225
if expr.expr is not None:

mypy/server/astmerge.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
MypyFile, SymbolTable, Block, AssignmentStmt, NameExpr, MemberExpr, RefExpr, TypeInfo,
5252
FuncDef, ClassDef, NamedTupleExpr, SymbolNode, Var, Statement, SuperExpr, NewTypeExpr,
5353
OverloadedFuncDef, LambdaExpr, TypedDictExpr, EnumCallExpr, FuncBase, TypeAliasExpr, CallExpr,
54-
CastExpr, TypeAlias,
54+
CastExpr, TypeAlias, AssertTypeExpr,
5555
MDEF
5656
)
5757
from mypy.traverser import TraverserVisitor
@@ -226,6 +226,10 @@ def visit_cast_expr(self, node: CastExpr) -> None:
226226
super().visit_cast_expr(node)
227227
self.fixup_type(node.type)
228228

229+
def visit_assert_type_expr(self, node: AssertTypeExpr) -> None:
230+
super().visit_assert_type_expr(node)
231+
self.fixup_type(node.type)
232+
229233
def visit_super_expr(self, node: SuperExpr) -> None:
230234
super().visit_super_expr(node)
231235
if node.info is not None:

mypy/server/deps.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a
8989
ComparisonExpr, GeneratorExpr, DictionaryComprehension, StarExpr, PrintStmt, ForStmt, WithStmt,
9090
TupleExpr, OperatorAssignmentStmt, DelStmt, YieldFromExpr, Decorator, Block,
9191
TypeInfo, FuncBase, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr, TypedDictExpr,
92-
LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr
92+
LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr,
93+
AssertTypeExpr,
9394
)
9495
from mypy.operators import (
9596
op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods
@@ -686,6 +687,10 @@ def visit_cast_expr(self, e: CastExpr) -> None:
686687
super().visit_cast_expr(e)
687688
self.add_type_dependencies(e.type)
688689

690+
def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
691+
super().visit_assert_type_expr(e)
692+
self.add_type_dependencies(e.type)
693+
689694
def visit_type_application(self, e: TypeApplication) -> None:
690695
super().visit_type_application(e)
691696
for typ in e.types:

mypy/server/subexpr.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
SliceExpr, CastExpr, RevealExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr,
88
IndexExpr, GeneratorExpr, ListComprehension, SetComprehension, DictionaryComprehension,
99
ConditionalExpr, TypeApplication, LambdaExpr, StarExpr, BackquoteExpr, AwaitExpr,
10-
AssignmentExpr,
10+
AssignmentExpr, AssertTypeExpr,
1111
)
1212
from mypy.traverser import TraverserVisitor
1313

@@ -99,6 +99,10 @@ def visit_cast_expr(self, e: CastExpr) -> None:
9999
self.add(e)
100100
super().visit_cast_expr(e)
101101

102+
def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
103+
self.add(e)
104+
super().visit_assert_type_expr(e)
105+
102106
def visit_reveal_expr(self, e: RevealExpr) -> None:
103107
self.add(e)
104108
super().visit_reveal_expr(e)

mypy/strconv.py

+3
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,9 @@ def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> str:
431431
def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> str:
432432
return self.dump([o.expr, o.type], o)
433433

434+
def visit_assert_type_expr(self, o: 'mypy.nodes.AssertTypeExpr') -> str:
435+
return self.dump([o.expr, o.type], o)
436+
434437
def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> str:
435438
if o.kind == mypy.nodes.REVEAL_TYPE:
436439
return self.dump([o.expr], o)

mypy/traverser.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
)
1010
from mypy.visitor import NodeVisitor
1111
from mypy.nodes import (
12-
Block, MypyFile, FuncBase, FuncItem, CallExpr, ClassDef, Decorator, FuncDef,
12+
AssertTypeExpr, Block, MypyFile, FuncBase, FuncItem, CallExpr, ClassDef, Decorator, FuncDef,
1313
ExpressionStmt, AssignmentStmt, OperatorAssignmentStmt, WhileStmt,
1414
ForStmt, ReturnStmt, AssertStmt, DelStmt, IfStmt, RaiseStmt,
1515
TryStmt, WithStmt, MatchStmt, NameExpr, MemberExpr, OpExpr, SliceExpr, CastExpr,
@@ -205,6 +205,9 @@ def visit_slice_expr(self, o: SliceExpr) -> None:
205205
def visit_cast_expr(self, o: CastExpr) -> None:
206206
o.expr.accept(self)
207207

208+
def visit_assert_type_expr(self, o: AssertTypeExpr) -> None:
209+
o.expr.accept(self)
210+
208211
def visit_reveal_expr(self, o: RevealExpr) -> None:
209212
if o.kind == REVEAL_TYPE:
210213
assert o.expr is not None

mypy/treetransform.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import List, Dict, cast, Optional, Iterable
77

88
from mypy.nodes import (
9-
MypyFile, Import, Node, ImportAll, ImportFrom, FuncItem, FuncDef,
9+
AssertTypeExpr, MypyFile, Import, Node, ImportAll, ImportFrom, FuncItem, FuncDef,
1010
OverloadedFuncDef, ClassDef, Decorator, Block, Var,
1111
OperatorAssignmentStmt, ExpressionStmt, AssignmentStmt, ReturnStmt,
1212
RaiseStmt, AssertStmt, DelStmt, BreakStmt, ContinueStmt,
@@ -407,6 +407,9 @@ def visit_cast_expr(self, node: CastExpr) -> CastExpr:
407407
return CastExpr(self.expr(node.expr),
408408
self.type(node.type))
409409

410+
def visit_assert_type_expr(self, node: AssertTypeExpr) -> AssertTypeExpr:
411+
return AssertTypeExpr(self.expr(node.expr), self.type(node.type))
412+
410413
def visit_reveal_expr(self, node: RevealExpr) -> RevealExpr:
411414
if node.kind == REVEAL_TYPE:
412415
assert node.expr is not None

mypy/types.py

+5
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@
132132
'typing_extensions.reveal_type',
133133
)
134134

135+
ASSERT_TYPE_NAMES: Final = (
136+
'typing.assert_type',
137+
'typing_extensions.assert_type',
138+
)
139+
135140
# Attributes that can optionally be defined in the body of a subclass of
136141
# enum.Enum but are removed from the class __dict__ by EnumMeta.
137142
ENUM_REMOVED_PROPS: Final = (

mypy/visitor.py

+7
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T:
8181
def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> T:
8282
pass
8383

84+
@abstractmethod
85+
def visit_assert_type_expr(self, o: 'mypy.nodes.AssertTypeExpr') -> T:
86+
pass
87+
8488
@abstractmethod
8589
def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> T:
8690
pass
@@ -523,6 +527,9 @@ def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T:
523527
def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> T:
524528
pass
525529

530+
def visit_assert_type_expr(self, o: 'mypy.nodes.AssertTypeExpr') -> T:
531+
pass
532+
526533
def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> T:
527534
pass
528535

mypyc/irbuild/expression.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
ConditionalExpr, ComparisonExpr, IntExpr, FloatExpr, ComplexExpr, StrExpr,
1212
BytesExpr, EllipsisExpr, ListExpr, TupleExpr, DictExpr, SetExpr, ListComprehension,
1313
SetComprehension, DictionaryComprehension, SliceExpr, GeneratorExpr, CastExpr, StarExpr,
14-
AssignmentExpr,
14+
AssignmentExpr, AssertTypeExpr,
1515
Var, RefExpr, MypyFile, TypeInfo, TypeApplication, LDEF, ARG_POS
1616
)
1717
from mypy.types import TupleType, Instance, TypeType, ProperType, get_proper_type
@@ -203,6 +203,9 @@ def transform_super_expr(builder: IRBuilder, o: SuperExpr) -> Value:
203203
def transform_call_expr(builder: IRBuilder, expr: CallExpr) -> Value:
204204
if isinstance(expr.analyzed, CastExpr):
205205
return translate_cast_expr(builder, expr.analyzed)
206+
elif isinstance(expr.analyzed, AssertTypeExpr):
207+
# Compile to a no-op.
208+
return builder.accept(expr.analyzed.expr)
206209

207210
callee = expr.callee
208211
if isinstance(callee, IndexExpr) and isinstance(callee.analyzed, TypeApplication):

mypyc/irbuild/visitor.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing_extensions import NoReturn
77

88
from mypy.nodes import (
9-
MypyFile, FuncDef, ReturnStmt, AssignmentStmt, OpExpr,
9+
AssertTypeExpr, MypyFile, FuncDef, ReturnStmt, AssignmentStmt, OpExpr,
1010
IntExpr, NameExpr, Var, IfStmt, UnaryExpr, ComparisonExpr, WhileStmt, CallExpr,
1111
IndexExpr, Block, ListExpr, ExpressionStmt, MemberExpr, ForStmt,
1212
BreakStmt, ContinueStmt, ConditionalExpr, OperatorAssignmentStmt, TupleExpr, ClassDef,
@@ -327,6 +327,9 @@ def visit_var(self, o: Var) -> None:
327327
def visit_cast_expr(self, o: CastExpr) -> Value:
328328
assert False, "CastExpr should have been handled in CallExpr"
329329

330+
def visit_assert_type_expr(self, o: AssertTypeExpr) -> Value:
331+
assert False, "AssertTypeExpr should have been handled in CallExpr"
332+
330333
def visit_star_expr(self, o: StarExpr) -> Value:
331334
assert False, "should have been handled in Tuple/List/Set/DictExpr or CallExpr"
332335

mypyc/test-data/irbuild-basic.test

+11
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,17 @@ L0:
876876
o = r3
877877
return 1
878878

879+
[case testAssertType]
880+
from typing import assert_type
881+
def f(x: int) -> None:
882+
y = assert_type(x, int)
883+
[out]
884+
def f(x):
885+
x, y :: int
886+
L0:
887+
y = x
888+
return 1
889+
879890
[case testDownCast]
880891
from typing import cast, List, Tuple
881892
class A: pass

test-data/unit/check-expressions.test

+12
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,18 @@ class B: pass
10361036
[out]
10371037
main:3: error: "A" not callable
10381038

1039+
-- assert_type()
1040+
1041+
[case testAssertType]
1042+
from typing import assert_type, Any
1043+
from typing_extensions import Literal
1044+
a: int = 1
1045+
returned = assert_type(a, int)
1046+
reveal_type(returned) # N: Revealed type is "builtins.int"
1047+
assert_type(a, str) # E: Expression is of type "int", not "str"
1048+
assert_type(a, Any) # E: Expression is of type "int", not "Any"
1049+
assert_type(a, Literal[1]) # E: Expression is of type "int", not "Literal[1]"
1050+
[builtins fixtures/tuple.pyi]
10391051

10401052
-- None return type
10411053
-- ----------------

test-data/unit/fixtures/typing-full.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from abc import abstractmethod, ABCMeta
1111
class GenericMeta(type): pass
1212

1313
def cast(t, o): ...
14+
def assert_type(o, t): ...
1415
overload = 0
1516
Any = 0
1617
Union = 0

test-data/unit/lib-stub/typing.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# the stubs under fixtures/.
1010

1111
cast = 0
12+
assert_type = 0
1213
overload = 0
1314
Any = 0
1415
Union = 0

0 commit comments

Comments
 (0)