Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ repos:
name: Update Dockerfile dependency graph
entry: tools/update-dockerfile-graph.sh
language: script
- id: enforce-import-regex-instead-of-re
name: Enforce import regex as re
entry: python tools/enforce_regex_import.py
language: python
types: [python]
pass_filenames: false
# forbid directly import triton
- id: forbid-direct-triton-import
name: "Forbid direct 'import triton'"
Expand Down
81 changes: 81 additions & 0 deletions tools/enforce_regex_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import re
import subprocess
from pathlib import Path

FORBIDDEN_PATTERNS = re.compile(
r'^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)')
ALLOWED_PATTERNS = [
re.compile(r'^\s*import\s+regex\s+as\s+re\s*$'),
re.compile(r'^\s*import\s+regex\s*$'),
]


def get_staged_python_files() -> list[str]:
try:
result = subprocess.run(
['git', 'diff', '--cached', '--name-only', '--diff-filter=AM'],
capture_output=True,
text=True,
check=True)
files = result.stdout.strip().split(
'\n') if result.stdout.strip() else []
return [f for f in files if f.endswith('.py')]
except subprocess.CalledProcessError:
return []


def is_forbidden_import(line: str) -> bool:
line = line.strip()
return FORBIDDEN_PATTERNS.match(line) and not any(
pattern.match(line) for pattern in ALLOWED_PATTERNS)


def check_file(filepath: str) -> list[tuple[int, str]]:
violations = []
try:
with open(filepath, encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
if is_forbidden_import(line):
violations.append((line_num, line.strip()))
except (OSError, UnicodeDecodeError):
pass
return violations


def main() -> int:
files = get_staged_python_files()
if not files:
return 0

total_violations = 0

for filepath in files:
if not Path(filepath).exists():
continue

violations = check_file(filepath)
if violations:
print(f"\n❌ {filepath}:")
for line_num, line in violations:
print(f" Line {line_num}: {line}")
total_violations += 1

if total_violations > 0:
print(f"\n💡 Found {total_violations} violation(s).")
print("❌ Please replace 'import re' with 'import regex as re'")
print(
" Also replace 'from re import ...' with 'from regex import ...'"
) # noqa: E501
print("✅ Allowed imports:")
print(" - import regex as re")
print(" - import regex") # noqa: E501
return 1

return 0


if __name__ == "__main__":
raise SystemExit(main())
Loading