Skip to content

Commit d8d6812

Browse files
nicoddemusasottile
authored andcommitted
Merge pull request #8540 from hauntsaninja/assert310
(cherry picked from commit af31c60)
1 parent a506148 commit d8d6812

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ Sankt Petersbug
273273
Segev Finer
274274
Serhii Mozghovyi
275275
Seth Junot
276+
Shantanu Jain
276277
Shubham Adep
277278
Simon Gomizelj
278279
Simon Kerr

changelog/8539.bugfix.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed assertion rewriting on Python 3.10.

src/_pytest/assertion/rewrite.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -673,12 +673,9 @@ def run(self, mod: ast.Module) -> None:
673673
if not mod.body:
674674
# Nothing to do.
675675
return
676-
# Insert some special imports at the top of the module but after any
677-
# docstrings and __future__ imports.
678-
aliases = [
679-
ast.alias("builtins", "@py_builtins"),
680-
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
681-
]
676+
677+
# We'll insert some special imports at the top of the module, but after any
678+
# docstrings and __future__ imports, so first figure out where that is.
682679
doc = getattr(mod, "docstring", None)
683680
expect_docstring = doc is None
684681
if doc is not None and self.is_rewrite_disabled(doc):
@@ -710,10 +707,27 @@ def run(self, mod: ast.Module) -> None:
710707
lineno = item.decorator_list[0].lineno
711708
else:
712709
lineno = item.lineno
710+
# Now actually insert the special imports.
711+
if sys.version_info >= (3, 10):
712+
aliases = [
713+
ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
714+
ast.alias(
715+
"_pytest.assertion.rewrite",
716+
"@pytest_ar",
717+
lineno=lineno,
718+
col_offset=0,
719+
),
720+
]
721+
else:
722+
aliases = [
723+
ast.alias("builtins", "@py_builtins"),
724+
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
725+
]
713726
imports = [
714727
ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
715728
]
716729
mod.body[pos:pos] = imports
730+
717731
# Collect asserts.
718732
nodes: List[ast.AST] = [mod]
719733
while nodes:

0 commit comments

Comments
 (0)