diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ee186be21094..5ee909c3b8ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -128,6 +128,13 @@ repos: name: Update Dockerfile dependency graph entry: tools/update-dockerfile-graph.sh language: script + - id: enforce-import-regex-instead-of-re + name: Enforce import regex as re + entry: python tools/enforce_regex_import.py + language: python + types: [python] + pass_filenames: false + additional_dependencies: [regex] # forbid directly import triton - id: forbid-direct-triton-import name: "Forbid direct 'import triton'" diff --git a/tools/enforce_regex_import.py b/tools/enforce_regex_import.py new file mode 100644 index 000000000000..b55c4a94eac8 --- /dev/null +++ b/tools/enforce_regex_import.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import subprocess +from pathlib import Path + +import regex as re + +FORBIDDEN_PATTERNS = re.compile( + r'^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)') +ALLOWED_PATTERNS = [ + re.compile(r'^\s*import\s+regex\s+as\s+re\s*$'), + re.compile(r'^\s*import\s+regex\s*$'), +] + + +def get_staged_python_files() -> list[str]: + try: + result = subprocess.run( + ['git', 'diff', '--cached', '--name-only', '--diff-filter=AM'], + capture_output=True, + text=True, + check=True) + files = result.stdout.strip().split( + '\n') if result.stdout.strip() else [] + return [f for f in files if f.endswith('.py')] + except subprocess.CalledProcessError: + return [] + + +def is_forbidden_import(line: str) -> bool: + line = line.strip() + return bool( + FORBIDDEN_PATTERNS.match(line) + and not any(pattern.match(line) for pattern in ALLOWED_PATTERNS)) + + +def check_file(filepath: str) -> list[tuple[int, str]]: + violations = [] + try: + with open(filepath, encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + if is_forbidden_import(line): + violations.append((line_num, line.strip())) + except (OSError, UnicodeDecodeError): + pass + return violations + + +def main() -> int: + files = get_staged_python_files() + if not files: + return 0 + + total_violations = 0 + + for filepath in files: + if not Path(filepath).exists(): + continue + + violations = check_file(filepath) + if violations: + print(f"\nāŒ {filepath}:") + for line_num, line in violations: + print(f" Line {line_num}: {line}") + total_violations += 1 + + if total_violations > 0: + print(f"\nšŸ’” Found {total_violations} violation(s).") + print("āŒ Please replace 'import re' with 'import regex as re'") + print( + " Also replace 'from re import ...' with 'from regex import ...'" + ) # noqa: E501 + print("āœ… Allowed imports:") + print(" - import regex as re") + print(" - import regex") # noqa: E501 + return 1 + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())