Skip to content

Commit

Permalink
Merge pull request #347 from malthejorgensen/groupby-multiple-usages
Browse files Browse the repository at this point in the history
Add B031: Warn when using `groupby()` result multiple times
  • Loading branch information
Zac-HD authored Feb 9, 2023
2 parents 4087f49 + 05a26c7 commit a7c7ac9
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ It is therefore recommended to use a stacklevel of 2 or greater to provide more

**B030**: Except handlers should only be exception classes or tuples of exception classes.

**B031**: Using the generator returned from `itertools.groupby()` more than once will do nothing on the
second usage. Save the result to a list if the result is needed multiple times.

Opinionated warnings
~~~~~~~~~~~~~~~~~~~~

Expand Down
65 changes: 65 additions & 0 deletions bugbear.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ def children_in_scope(node):
yield from children_in_scope(child)


def walk_list(nodes):
for node in nodes:
yield from ast.walk(node)


def _typesafe_issubclass(cls, class_or_tuple):
try:
return issubclass(cls, class_or_tuple)
Expand Down Expand Up @@ -401,6 +406,7 @@ def visit_For(self, node):
self.check_for_b007(node)
self.check_for_b020(node)
self.check_for_b023(node)
self.check_for_b031(node)
self.generic_visit(node)

def visit_AsyncFor(self, node):
Expand Down Expand Up @@ -793,6 +799,56 @@ def check_for_b026(self, call: ast.Call):
):
self.errors.append(B026(starred.lineno, starred.col_offset))

def check_for_b031(self, loop_node): # noqa: C901
"""Check that `itertools.groupby` isn't iterated over more than once.
We emit a warning when the generator returned by `groupby()` is used
more than once inside a loop body or when it's used in a nested loop.
"""
# for <loop_node.target> in <loop_node.iter>: ...
if isinstance(loop_node.iter, ast.Call):
node = loop_node.iter
if (isinstance(node.func, ast.Name) and node.func.id in ("groupby",)) or (
isinstance(node.func, ast.Attribute)
and node.func.attr == "groupby"
and isinstance(node.func.value, ast.Name)
and node.func.value.id == "itertools"
):
# We have an invocation of groupby which is a simple unpacking
if isinstance(loop_node.target, ast.Tuple) and isinstance(
loop_node.target.elts[1], ast.Name
):
group_name = loop_node.target.elts[1].id
else:
# Ignore any `groupby()` invocation that isn't unpacked
return

num_usages = 0
for node in walk_list(loop_node.body):
# Handled nested loops
if isinstance(node, ast.For):
for nested_node in walk_list(node.body):
assert nested_node != node
if (
isinstance(nested_node, ast.Name)
and nested_node.id == group_name
):
self.errors.append(
B031(
nested_node.lineno,
nested_node.col_offset,
vars=(nested_node.id,),
)
)

# Handle multiple uses
if isinstance(node, ast.Name) and node.id == group_name:
num_usages += 1
if num_usages > 1:
self.errors.append(
B031(node.lineno, node.col_offset, vars=(node.id,))
)

def _get_assigned_names(self, loop_node):
loop_targets = (ast.For, ast.AsyncFor, ast.comprehension)
for node in children_in_scope(loop_node):
Expand Down Expand Up @@ -1558,8 +1614,17 @@ def visit_Lambda(self, node):
"anything. Add exceptions to handle."
)
)

B030 = Error(message="B030 Except handlers should only be names of exception classes")

B031 = Error(
message=(
"B031 Using the generator returned from `itertools.groupby()` more than once"
" will do nothing on the second usage. Save the result to a list, if the"
" result is needed multiple times."
)
)

# Warnings disabled by default.
B901 = Error(
message=(
Expand Down
64 changes: 64 additions & 0 deletions tests/b031.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
Should emit:
B030 - on lines 29, 33, 43
"""
import itertools
from itertools import groupby

shoppers = ["Jane", "Joe", "Sarah"]
items = [
("lettuce", "greens"),
("tomatoes", "greens"),
("cucumber", "greens"),
("chicken breast", "meats & fish"),
("salmon", "meats & fish"),
("ice cream", "frozen items"),
]

carts = {shopper: [] for shopper in shoppers}


def collect_shop_items(shopper, items):
# Imagine this an expensive database query or calculation that is
# advantageous to batch.
carts[shopper] += items


# Group by shopping section
for _section, section_items in groupby(items, key=lambda p: p[1]):
for shopper in shoppers:
collect_shop_items(shopper, section_items)

for _section, section_items in groupby(items, key=lambda p: p[1]):
collect_shop_items("Jane", section_items)
collect_shop_items("Joe", section_items)


for _section, section_items in groupby(items, key=lambda p: p[1]):
# This is ok
collect_shop_items("Jane", section_items)

for _section, section_items in itertools.groupby(items, key=lambda p: p[1]):
for shopper in shoppers:
collect_shop_items(shopper, section_items)

for group in groupby(items, key=lambda p: p[1]):
# This is bad, but not detected currently
collect_shop_items("Jane", group[1])
collect_shop_items("Joe", group[1])


# Make sure we ignore - but don't fail on more complicated invocations
for _key, (_value1, _value2) in groupby(
[("a", (1, 2)), ("b", (3, 4)), ("a", (5, 6))], key=lambda p: p[1]
):
collect_shop_items("Jane", group[1])
collect_shop_items("Joe", group[1])

# Make sure we ignore - but don't fail on more complicated invocations
for (_key1, _key2), (_value1, _value2) in groupby(
[(("a", "a"), (1, 2)), (("b", "b"), (3, 4)), (("a", "a"), (5, 6))],
key=lambda p: p[1],
):
collect_shop_items("Jane", group[1])
collect_shop_items("Joe", group[1])
12 changes: 12 additions & 0 deletions tests/test_bugbear.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
B028,
B029,
B030,
B031,
B901,
B902,
B903,
Expand Down Expand Up @@ -459,6 +460,17 @@ def test_b030(self):
)
self.assertEqual(errors, expected)

def test_b031(self):
filename = Path(__file__).absolute().parent / "b031.py"
bbc = BugBearChecker(filename=str(filename))
errors = list(bbc.run())
expected = self.errors(
B031(30, 36, vars=("section_items",)),
B031(34, 30, vars=("section_items",)),
B031(43, 36, vars=("section_items",)),
)
self.assertEqual(errors, expected)

@unittest.skipIf(sys.version_info < (3, 8), "not implemented for <3.8")
def test_b907(self):
filename = Path(__file__).absolute().parent / "b907.py"
Expand Down

0 comments on commit a7c7ac9

Please sign in to comment.