Skip to content

Commit

Permalink
fix: empty allowed_class skips check
Browse files Browse the repository at this point in the history
  • Loading branch information
guilatrova committed May 20, 2023
1 parent bed0fcc commit e3efc35
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
21 changes: 18 additions & 3 deletions src/tests/analyzers_classdefs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,28 @@ def test_inherit_from_allowed_exceptions():
)
)

assert_non_pickable = partial(assert_violation, codes.ALLOWED_BASE_EXCEPTION[0])
asset_non_inherit = partial(assert_violation, codes.ALLOWED_BASE_EXCEPTION[0])
msg_base = codes.ALLOWED_BASE_EXCEPTION[1]
allowed_msg = ", ".join(allowed_base_exceptions)

violations = analyzer.check(tree, "filename")

assert len(violations) == 2

assert_non_pickable(msg_base.format("InvalidBase", allowed_msg), 19, 0, violations[0])
assert_non_pickable(msg_base.format("MultiInvalidBase", allowed_msg), 23, 0, violations[1])
asset_non_inherit(msg_base.format("InvalidBase", allowed_msg), 19, 0, violations[0])
asset_non_inherit(msg_base.format("MultiInvalidBase", allowed_msg), 23, 0, violations[1])


def test_inherit_from_allowed_exceptions_undefined():
tree = read_sample("class_base_allowed")
analyzer = analyzers.classdefs.InheritFromBaseAnalyzer(
GlobalSettings(
include_experimental=False,
exclude_dirs=[],
ignore_violations=[],
allowed_base_exceptions=set(),
)
)

violations = analyzer.check(tree, "filename")
assert len(violations) == 0
5 changes: 4 additions & 1 deletion src/tryceratops/analyzers/classdefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,14 @@ class InheritFromBaseAnalyzer(BaseAnalyzer):

@visit_error_handler
def visit_ClassDef(self, node: ast.ClassDef) -> t.Any:
settings = t.cast(GlobalSettings, self._settings)
if not settings.allowed_base_exceptions:
return self.generic_visit(node)

is_exc = any([base for base in node.bases if getattr(base, "id", None) == "Exception"])
if is_exc is False:
return self.generic_visit(node)

settings = t.cast(GlobalSettings, self._settings)
if node.name not in settings.allowed_base_exceptions:
self._mark_violation(
node,
Expand Down

0 comments on commit e3efc35

Please sign in to comment.