Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

On python<3.9, get rid of ast.Index before passing AST to PyiVisitor #113

Merged
merged 7 commits into from
Jan 20, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 28 additions & 43 deletions pyi.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,21 @@ def run_check(self, plugin, **kwargs):
return super().run_check(plugin, **kwargs)


class LegacyNormalizer(ast.NodeTransformer):
"""Transform AST to be consistent across Python versions."""

if sys.version_info < (3, 9):

def visit_Index(self, node: ast.Index) -> ast.expr:
"""Index nodes no longer exist in Python 3.9.

For example, consider the AST representing Union[str, int].
Before 3.9: Subscript(value=Name(id='Union'), slice=Index(value=Tuple(...)))
3.9 and newer: Subscript(value=Name(id='Union'), slice=Tuple(...))
"""
return node.value


@dataclass
class PyiVisitor(ast.NodeVisitor):
filename: Path = Path("(none)")
Expand Down Expand Up @@ -449,15 +464,10 @@ def _check_for_multiple_literals(self, members: Sequence[ast.expr]) -> None:
new_literal_members: list[ast.expr] = []

for literal in literals_in_union:
if sys.version_info >= (3, 9):
contents = literal
if isinstance(literal, ast.Tuple):
new_literal_members.extend(literal.elts)
else:
contents = literal.value

if isinstance(contents, ast.Tuple):
new_literal_members.extend(contents.elts)
else:
new_literal_members.append(contents)
new_literal_members.append(literal)

new_literal_slice = unparse(ast.Tuple(new_literal_members)).strip("()")

Expand Down Expand Up @@ -501,21 +511,10 @@ def visit_Subscript(self, node: ast.Subscript) -> None:
self.visit(node.slice)
return

# Union[str, int] parses differently depending on python versions:
# Before 3.9: Subscript(value=Name(id='Union'), slice=Index(value=Tuple(...)))
# 3.9 and newer: Subscript(value=Name(id='Union'), slice=Tuple(...))
if sys.version_info >= (3, 9):
if isinstance(node.slice, ast.Tuple):
self._visit_slice_tuple(node.slice, value_id)
else:
self.visit(node.slice)
if isinstance(node.slice, ast.Tuple):
self._visit_slice_tuple(node.slice, value_id)
else:
if isinstance(node.slice, ast.Index) and isinstance(
node.slice.value, ast.Tuple
):
self._visit_slice_tuple(node.slice.value, value_id)
else:
self.visit(node.slice)
self.visit(node.slice)

def _visit_slice_tuple(self, node: ast.Tuple, parent: str | None) -> None:
if parent == "Union":
Expand Down Expand Up @@ -574,15 +573,10 @@ def _check_subscript_version_check(self, node: ast.Compare) -> None:
version_info = node.left
if isinstance(version_info, ast.Subscript):
slc = version_info.slice
if isinstance(slc, (ast.Index, ast.Num)):
# Python 3.9 flattens the AST and removes Index, so simulate that here
slice_num = slc if isinstance(slc, ast.Num) else slc.value
# TODO: ast.Num works, but is deprecated
if isinstance(slc, ast.Num):
# anything other than the integer 0 doesn't make much sense
if (
isinstance(slice_num, ast.Num)
and isinstance(slice_num.n, int)
and slice_num.n == 0
):
if isinstance(slc.n, int) and slc.n == 0:
must_be_single = True
else:
self.error(node, Y003)
Expand Down Expand Up @@ -760,19 +754,10 @@ def _check_class_method_for_bad_typevars(

cls_typevar: str

# see comment in visit_Subscript
if sys.version_info >= (3, 9):
if isinstance(first_arg_annotation.slice, ast.Name):
cls_typevar = first_arg_annotation.slice.id
else:
return
if isinstance(first_arg_annotation.slice, ast.Name):
cls_typevar = first_arg_annotation.slice.id
else:
if isinstance(first_arg_annotation.slice, ast.Index) and isinstance(
first_arg_annotation.slice.value, ast.Name
):
cls_typevar = first_arg_annotation.slice.value.id
else:
return
return

if not isinstance(first_arg_annotation.value, ast.Name):
return
Expand Down Expand Up @@ -878,7 +863,7 @@ def run(self):
path = Path(self.filename)
if path.suffix == ".pyi":
visitor = PyiVisitor(filename=path)
for error in visitor.run(self.tree):
for error in visitor.run(LegacyNormalizer().visit(self.tree)):
yield error

@classmethod
Expand Down