diff --git a/pre_commit_hooks/check_docstring_first.py b/pre_commit_hooks/check_docstring_first.py index 875c0fba..6852bc0f 100644 --- a/pre_commit_hooks/check_docstring_first.py +++ b/pre_commit_hooks/check_docstring_first.py @@ -11,7 +11,7 @@ )) -def check_docstring_first(src: bytes, filename: str = '') -> int: +def check_docstring_first(src: bytes, filename: str) -> int: """Returns nonzero if the source has what looks like a docstring that is not at the beginning of the source. @@ -20,9 +20,14 @@ def check_docstring_first(src: bytes, filename: str = '') -> int: """ found_docstring_line = None found_code_line = None + assignment_lines = set() tok_gen = tokenize_tokenize(io.BytesIO(src).readline) - for tok_type, _, (sline, scol), _, _ in tok_gen: + for tok_type, string, (sline, scol), _, _ in tok_gen: + # Save all lines with top-level attribute assignments + if scol == 2 and tok_type == tokenize.OP and string == '=': + assignment_lines.add(sline) + # Looks like a docstring! if tok_type == tokenize.STRING and scol == 0: if found_docstring_line is not None: @@ -31,7 +36,10 @@ def check_docstring_first(src: bytes, filename: str = '') -> int: f'(first docstring on line {found_docstring_line}).', ) return 1 - elif found_code_line is not None: + elif ( + found_code_line is not None and + sline > 0 and sline - 1 not in assignment_lines + ): print( f'{filename}:{sline} Module docstring appears after code ' f'(code seen on line {found_code_line}).', @@ -55,6 +63,6 @@ def main(argv: Optional[Sequence[str]] = None) -> int: for filename in args.filenames: with open(filename, 'rb') as f: contents = f.read() - retv |= check_docstring_first(contents, filename=filename) + retv |= check_docstring_first(contents, filename) return retv diff --git a/tests/check_docstring_first_test.py b/tests/check_docstring_first_test.py index ed5c08ef..2da82b6d 100644 --- a/tests/check_docstring_first_test.py +++ b/tests/check_docstring_first_test.py @@ -38,6 +38,13 @@ ), # String literals in expressions are ok. (b'x = "foo"\n', 0, ''), + # Attribute docstrings are ok. + ( + b'x = "foo"\n' + b'"""x holds the foo"""', + 0, + '', + ), ) @@ -48,7 +55,7 @@ @all_tests def test_unit(capsys, contents, expected, expected_out): - assert check_docstring_first(contents) == expected + assert check_docstring_first(contents, '') == expected assert capsys.readouterr()[0] == expected_out.format(filename='')