Skip to content

Commit

Permalink
Fixed checking of variable assignments involving tuple unpacking
Browse files Browse the repository at this point in the history
This also unified all variable checking across different assignment types (annotation assignment, augmented assignment and any other kind of assignment)

Fixes #486.
  • Loading branch information
agronholm committed Nov 2, 2024
1 parent 9a73eb0 commit 889ad53
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 139 deletions.
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ This library adheres to
- Dropped Python 3.8 support
- Changed the signature of ``typeguard_ignore()`` to be compatible with
``typing.no_type_check()`` (PR by @jolaf)
- Fixed checking of variable assignments involving tuple unpacking
(`#486 <https://github.com/agronholm/typeguard/pull/486>`_)

**4.4.0** (2024-10-27)

Expand Down
87 changes: 41 additions & 46 deletions src/typeguard/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import sys
import warnings
from collections.abc import Sequence
from typing import Any, Callable, NoReturn, TypeVar, Union, overload

from . import _suppression
Expand Down Expand Up @@ -242,59 +243,53 @@ def check_yield_type(


def check_variable_assignment(
value: object, varname: str, annotation: Any, memo: TypeCheckMemo
value: Any, targets: Sequence[list[tuple[str, Any]]], memo: TypeCheckMemo
) -> Any:
if _suppression.type_checks_suppressed:
return value

try:
check_type_internal(value, annotation, memo)
except TypeCheckError as exc:
qualname = qualified_name(value, add_class_prefix=True)
exc.append_path_element(f"value assigned to {varname} ({qualname})")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise

return value

value_to_return = value
for target in targets:
star_variable_index = next(
(i for i, (varname, _) in enumerate(target) if varname.startswith("*")),
None,
)
if star_variable_index is not None:
value_to_return = list(value)
remaining_vars = len(target) - 1 - star_variable_index
end_index = len(value_to_return) - remaining_vars
values_to_check = (
value_to_return[:star_variable_index]
+ [value_to_return[star_variable_index:end_index]]
+ value_to_return[end_index:]
)
elif len(target) > 1:
values_to_check = value_to_return = []
iterator = iter(value)
for _ in target:
try:
values_to_check.append(next(iterator))
except StopIteration:
raise ValueError(
f"not enough values to unpack (expected {len(target)}, got "
f"{len(values_to_check)})"
) from None

def check_multi_variable_assignment(
value: Any, targets: list[dict[str, Any]], memo: TypeCheckMemo
) -> Any:
if max(len(target) for target in targets) == 1:
iterated_values = [value]
else:
iterated_values = list(value)

if not _suppression.type_checks_suppressed:
for expected_types in targets:
value_index = 0
for ann_index, (varname, expected_type) in enumerate(
expected_types.items()
):
if varname.startswith("*"):
varname = varname[1:]
keys_left = len(expected_types) - 1 - ann_index
next_value_index = len(iterated_values) - keys_left
obj: object = iterated_values[value_index:next_value_index]
value_index = next_value_index
else:
values_to_check = [value]

for val, (varname, annotation) in zip(values_to_check, target):
try:
check_type_internal(val, annotation, memo)
except TypeCheckError as exc:
qualname = qualified_name(val, add_class_prefix=True)
exc.append_path_element(f"value assigned to {varname} ({qualname})")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
obj = iterated_values[value_index]
value_index += 1
raise

try:
check_type_internal(obj, expected_type, memo)
except TypeCheckError as exc:
qualname = qualified_name(obj, add_class_prefix=True)
exc.append_path_element(f"value assigned to {varname} ({qualname})")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise

return iterated_values[0] if len(iterated_values) == 1 else iterated_values
return value_to_return


def warn_on_error(exc: TypeCheckError, memo: TypeCheckMemo) -> None:
Expand Down
146 changes: 75 additions & 71 deletions src/typeguard/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
If,
Import,
ImportFrom,
Index,
List,
Load,
LShift,
Expand Down Expand Up @@ -389,9 +388,7 @@ def visit_BinOp(self, node: BinOp) -> Any:
union_name = self.transformer._get_import("typing", "Union")
return Subscript(
value=union_name,
slice=Index(
Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load()
),
slice=Tuple(elts=[node.left, node.right], ctx=Load()),
ctx=Load(),
)

Expand All @@ -410,24 +407,18 @@ def visit_Subscript(self, node: Subscript) -> Any:
# The subscript of typing(_extensions).Literal can be any arbitrary string, so
# don't try to evaluate it as code
if node.slice:
if isinstance(node.slice, Index):
# Python 3.8
slice_value = node.slice.value # type: ignore[attr-defined]
else:
slice_value = node.slice

if isinstance(slice_value, Tuple):
if isinstance(node.slice, Tuple):
if self._memo.name_matches(node.value, *annotated_names):
# Only treat the first argument to typing.Annotated as a potential
# forward reference
items = cast(
typing.List[expr],
[self.visit(slice_value.elts[0])] + slice_value.elts[1:],
[self.visit(node.slice.elts[0])] + node.slice.elts[1:],
)
else:
items = cast(
typing.List[expr],
[self.visit(item) for item in slice_value.elts],
[self.visit(item) for item in node.slice.elts],
)

# If this is a Union and any of the items is Any, erase the entire
Expand All @@ -450,7 +441,7 @@ def visit_Subscript(self, node: Subscript) -> Any:
if item is None:
items[index] = self.transformer._get_import("typing", "Any")

slice_value.elts = items
node.slice.elts = items
else:
self.generic_visit(node)

Expand Down Expand Up @@ -542,18 +533,10 @@ def _use_memo(
return_annotation, *generator_names
):
if isinstance(return_annotation, Subscript):
annotation_slice = return_annotation.slice

# Python < 3.9
if isinstance(annotation_slice, Index):
annotation_slice = (
annotation_slice.value # type: ignore[attr-defined]
)

if isinstance(annotation_slice, Tuple):
items = annotation_slice.elts
if isinstance(return_annotation.slice, Tuple):
items = return_annotation.slice.elts
else:
items = [annotation_slice]
items = [return_annotation.slice]

if len(items) > 0:
new_memo.yield_annotation = self._convert_annotation(
Expand Down Expand Up @@ -743,7 +726,7 @@ def visit_FunctionDef(
annotation_ = self._convert_annotation(node.args.vararg.annotation)
if annotation_:
container = Name("tuple", ctx=Load())
subscript_slice: Tuple | Index = Tuple(
subscript_slice = Tuple(
[
annotation_,
Constant(Ellipsis),
Expand Down Expand Up @@ -1024,12 +1007,25 @@ def visit_AnnAssign(self, node: AnnAssign) -> Any:
func_name = self._get_import(
"typeguard._functions", "check_variable_assignment"
)
targets_arg = List(
[
List(
[
Tuple(
[Constant(node.target.id), annotation],
ctx=Load(),
)
],
ctx=Load(),
)
],
ctx=Load(),
)
node.value = Call(
func_name,
[
node.value,
Constant(node.target.id),
annotation,
targets_arg,
self._memo.get_memo_name(),
],
[],
Expand All @@ -1047,7 +1043,7 @@ def visit_Assign(self, node: Assign) -> Any:

# Only instrument function-local assignments
if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)):
targets: list[dict[Constant, expr | None]] = []
preliminary_targets: list[list[tuple[Constant, expr | None]]] = []
check_required = False
for target in node.targets:
elts: Sequence[expr]
Expand All @@ -1058,63 +1054,63 @@ def visit_Assign(self, node: Assign) -> Any:
else:
continue

annotations_: dict[Constant, expr | None] = {}
annotations_: list[tuple[Constant, expr | None]] = []
for exp in elts:
prefix = ""
if isinstance(exp, Starred):
exp = exp.value
prefix = "*"

path: list[str] = []
while isinstance(exp, Attribute):
path.insert(0, exp.attr)
exp = exp.value

if isinstance(exp, Name):
self._memo.ignored_names.add(exp.id)
name = prefix + exp.id
if not path:
self._memo.ignored_names.add(exp.id)

path.insert(0, exp.id)
name = prefix + ".".join(path)
annotation = self._memo.variable_annotations.get(exp.id)
if annotation:
annotations_[Constant(name)] = annotation
annotations_.append((Constant(name), annotation))
check_required = True
else:
annotations_[Constant(name)] = None
annotations_.append((Constant(name), None))

targets.append(annotations_)
preliminary_targets.append(annotations_)

if check_required:
# Replace missing annotations with typing.Any
for item in targets:
for key, expression in item.items():
targets: list[list[tuple[Constant, expr]]] = []
for items in preliminary_targets:
target_list: list[tuple[Constant, expr]] = []
targets.append(target_list)
for key, expression in items:
if expression is None:
item[key] = self._get_import("typing", "Any")
target_list.append((key, self._get_import("typing", "Any")))
else:
target_list.append((key, expression))

if len(targets) == 1 and len(targets[0]) == 1:
func_name = self._get_import(
"typeguard._functions", "check_variable_assignment"
)
target_varname = next(iter(targets[0]))
node.value = Call(
func_name,
[
node.value,
target_varname,
targets[0][target_varname],
self._memo.get_memo_name(),
],
[],
)
elif targets:
func_name = self._get_import(
"typeguard._functions", "check_multi_variable_assignment"
)
targets_arg = List(
[
Dict(keys=list(target), values=list(target.values()))
for target in targets
],
ctx=Load(),
)
node.value = Call(
func_name,
[node.value, targets_arg, self._memo.get_memo_name()],
[],
)
func_name = self._get_import(
"typeguard._functions", "check_variable_assignment"
)
targets_arg = List(
[
List(
[Tuple([name, ann], ctx=Load()) for name, ann in target],
ctx=Load(),
)
for target in targets
],
ctx=Load(),
)
node.value = Call(
func_name,
[node.value, targets_arg, self._memo.get_memo_name()],
[],
)

return node

Expand Down Expand Up @@ -1175,12 +1171,20 @@ def visit_AugAssign(self, node: AugAssign) -> Any:
operator_call = Call(
operator_func, [Name(node.target.id, ctx=Load()), node.value], []
)
targets_arg = List(
[
List(
[Tuple([Constant(node.target.id), annotation], ctx=Load())],
ctx=Load(),
)
],
ctx=Load(),
)
check_call = Call(
self._get_import("typeguard._functions", "check_variable_assignment"),
[
operator_call,
Constant(node.target.id),
annotation,
targets_arg,
self._memo.get_memo_name(),
],
[],
Expand Down
7 changes: 2 additions & 5 deletions src/typeguard/_union_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
from ast import (
BinOp,
BitOr,
Index,
Load,
Name,
NodeTransformer,
Subscript,
Tuple,
fix_missing_locations,
parse,
)
from ast import Tuple as ASTTuple
from types import CodeType
from typing import Any

Expand All @@ -30,9 +29,7 @@ def visit_BinOp(self, node: BinOp) -> Any:
if isinstance(node.op, BitOr):
return Subscript(
value=self.union_name,
slice=Index(
ASTTuple(elts=[node.left, node.right], ctx=Load()), ctx=Load()
),
slice=Tuple(elts=[node.left, node.right], ctx=Load()),
ctx=Load(),
)

Expand Down
Loading

0 comments on commit 889ad53

Please sign in to comment.