|
5 | 5 | import glob |
6 | 6 | import os.path |
7 | 7 | import re |
| 8 | +import warnings |
8 | 9 | from collections.abc import Sequence |
9 | 10 | from pprint import pprint |
10 | 11 | from typing import Union |
11 | 12 |
|
12 | 13 | REQUIREMENT_ROOT = "requirements.txt" |
13 | 14 | REQUIREMENT_FILES_ALL: list = glob.glob(os.path.join("requirements", "*.txt")) |
14 | 15 | REQUIREMENT_FILES_ALL += glob.glob(os.path.join("requirements", "**", "*.txt"), recursive=True) |
| 16 | +REQUIREMENT_FILES_ALL += glob.glob(os.path.join("**", "pyproject.toml")) |
15 | 17 | if os.path.isfile(REQUIREMENT_ROOT): |
16 | 18 | REQUIREMENT_FILES_ALL += [REQUIREMENT_ROOT] |
17 | 19 |
|
@@ -43,23 +45,57 @@ def _prune_packages(req_file: str, packages: Sequence[str]) -> None: |
43 | 45 | fp.writelines(lines) |
44 | 46 |
|
45 | 47 |
|
46 | | -def _replace_min(fname: str) -> None: |
| 48 | +def _replace_min_txt(fname: str) -> None: |
47 | 49 | with open(fname) as fopen: |
48 | 50 | req = fopen.read().replace(">=", "==") |
49 | 51 | with open(fname, "w") as fw: |
50 | 52 | fw.write(req) |
51 | 53 |
|
52 | 54 |
|
| 55 | +def _replace_min_pyproject_toml(fname: str) -> None: |
| 56 | + """Replace all `>=` with `==` in the standard pyproject.toml file in [project.dependencies].""" |
| 57 | + import tomlkit |
| 58 | + |
| 59 | + # Load and parse the existing pyproject.toml |
| 60 | + with open(fname, encoding="utf-8") as f: |
| 61 | + content = f.read() |
| 62 | + doc = tomlkit.parse(content) |
| 63 | + |
| 64 | + # todo: consider also replace extras in [dependency-groups] -> extras = [...] |
| 65 | + deps = doc.get("project", {}).get("dependencies") |
| 66 | + if not deps: |
| 67 | + return |
| 68 | + |
| 69 | + # Replace '>=version' with '==version' in each dependency |
| 70 | + for i, req in enumerate(deps): |
| 71 | + # Simple string value |
| 72 | + deps[i] = req.replace(">=", "==") |
| 73 | + |
| 74 | + # Dump back out, preserving layout |
| 75 | + with open(fname, "w", encoding="utf-8") as f: |
| 76 | + f.write(tomlkit.dumps(doc)) |
| 77 | + |
| 78 | + |
53 | 79 | def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL) -> None: |
54 | 80 | """Replace the min package version by fixed one.""" |
55 | 81 | if isinstance(req_files, str): |
56 | 82 | req_files = [req_files] |
57 | 83 | for fname in req_files: |
58 | | - _replace_min(fname) |
| 84 | + if fname.endswith(".txt"): |
| 85 | + _replace_min_txt(fname) |
| 86 | + elif os.path.basename(fname) == "pyproject.toml": |
| 87 | + _replace_min_pyproject_toml(fname) |
| 88 | + else: |
| 89 | + warnings.warn( |
| 90 | + "Only *.txt with plain list of requirements or standard pyproject.toml are supported." |
| 91 | + f" File '{fname}' is not supported.", |
| 92 | + UserWarning, |
| 93 | + stacklevel=2, |
| 94 | + ) |
59 | 95 |
|
60 | 96 |
|
61 | 97 | def _replace_package_name(requirements: list[str], old_package: str, new_package: str) -> list[str]: |
62 | | - """Replace one package by another with same version in given requirement file. |
| 98 | + """Replace one package by another with the same version in a given requirement file. |
63 | 99 |
|
64 | 100 | >>> _replace_package_name(["torch>=1.0 # comment", "torchvision>=0.2", "torchtext <0.3"], "torch", "pytorch") |
65 | 101 | ['pytorch>=1.0 # comment', 'torchvision>=0.2', 'torchtext <0.3'] |
|
0 commit comments