diff --git a/src/tests/fixers_test.py b/src/tests/fixers_test.py index 13decf4..79f7dd8 100644 --- a/src/tests/fixers_test.py +++ b/src/tests/fixers_test.py @@ -2,14 +2,13 @@ from typing import List, Tuple from tryceratops.fixers import VerboseReraiseFixer -from tryceratops.violations import codes -from tryceratops.violations.violations import Violation +from tryceratops.violations import VerboseReraiseViolation, codes from .analyzer_helpers import read_sample_lines -def create_violation(code: Tuple[str, str], line: int): - return Violation(code[0], line, 0, "", "") +def create_verbose_reraise_violation(code: Tuple[str, str], line: int): + return VerboseReraiseViolation(code[0], line, 0, "", "", None, "ex") def assert_ast_is_valid(results: List[str]): @@ -29,10 +28,9 @@ def test_verbose_fixer(): lines = read_sample_lines("except_verbose_reraise") expected_modified_line = 20 expected_modified_offset = expected_modified_line - 1 - violation = create_violation(codes.VERBOSE_RERAISE, expected_modified_line) + violation = create_verbose_reraise_violation(codes.VERBOSE_RERAISE, expected_modified_line) results = fixer.perform_fix(lines, violation) - print(f"result: '{results[expected_modified_offset]}'") assert_ast_is_valid(results) assert_unmodified_lines(lines, results, expected_modified_offset) diff --git a/src/tryceratops/fixers/exception_block.py b/src/tryceratops/fixers/exception_block.py index 6375bf0..8848e66 100644 --- a/src/tryceratops/fixers/exception_block.py +++ b/src/tryceratops/fixers/exception_block.py @@ -1,11 +1,10 @@ import re from abc import ABC, abstractmethod from collections import defaultdict -from typing import Iterable, List, Tuple +from typing import Generic, Iterable, List, Tuple, TypeVar from tryceratops.violations import Violation, codes - -GroupedViolations = dict[str, List[Violation]] +from tryceratops.violations.violations import VerboseReraiseViolation class FileFixerHandler: @@ -28,7 +27,11 @@ def __exit__(self, exc_type, exc_value, exc_traceback): self.file.close() -class BaseFixer(ABC): +ViolationType = TypeVar("ViolationType", bound=Violation) +GroupedViolations = dict[str, List[ViolationType]] + + +class BaseFixer(Generic[ViolationType], ABC): violation_code: Tuple[str, str] fixes_made = 0 @@ -44,7 +47,7 @@ def _group_violations_by_filename(self, violations: List[Violation]) -> GroupedV return group - def _process_group(self, filename: str, violations: List[Violation]): + def _process_group(self, filename: str, violations: List[ViolationType]): with FileFixerHandler(filename) as file: for violation in violations: file_lines = file.read_lines() @@ -54,7 +57,7 @@ def _process_group(self, filename: str, violations: List[Violation]): self.fixes_made += 1 @abstractmethod - def perform_fix(self, lines: List[str], violation: Violation) -> List[str]: + def perform_fix(self, lines: List[str], violation: ViolationType) -> List[str]: pass def fix(self, violations: List[Violation]): @@ -65,14 +68,14 @@ def fix(self, violations: List[Violation]): self._process_group(filename, file_violations) -class VerboseReraiseFixer(BaseFixer): +class VerboseReraiseFixer(BaseFixer[VerboseReraiseViolation]): violation_code = codes.VERBOSE_RERAISE - def perform_fix(self, lines: List[str], violation: Violation) -> List[str]: + def perform_fix(self, lines: List[str], violation: VerboseReraiseViolation) -> List[str]: all_lines = lines[:] guilty_line = all_lines[violation.line - 1] - new_line = re.sub(r"raise.*", "raise", guilty_line) + new_line = re.sub(rf"raise {violation.exception_name}", "raise", guilty_line) all_lines[violation.line - 1] = new_line return all_lines