Skip to content

Commit 8aa3b34

Browse files
CLI: replace min dependencies also in pyproject. toml (#414)
* replace min dependencies also in `pyproject. toml` * add tests * chlog --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6df43c2 commit 8aa3b34

File tree

7 files changed

+71
-9
lines changed

7 files changed

+71
-9
lines changed

.github/workflows/ci-scripts.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ jobs:
5656
timeout-minutes: 20
5757
run: |
5858
set -e
59-
pip install -e . -U -r requirements/_tests.txt -f $TORCH_URL
59+
pip install -e ".[cli]" -U -r requirements/_tests.txt -f $TORCH_URL
6060
pip --version
6161
pip list
6262

.github/workflows/ci-testing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
timeout-minutes: 20
5151
run: |
5252
set -e
53-
pip install -e . -U -r requirements/_tests.txt -f $TORCH_URL
53+
pip install -e ".[cli]" -U -r requirements/_tests.txt -f $TORCH_URL
5454
pip --version
5555
pip list
5656

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414
- CI: add `force-check-all` input to MD link check workflow ([#408](https://github.com/Lightning-AI/utilities/pull/408))
1515

1616

17+
- CLI: replace min dependencies also in `pyproject.toml` ([#414](https://github.com/Lightning-AI/utilities/pull/414))
18+
19+
1720
### Changed
1821

1922
- CLI: switch from `fire` to `jsonargparse` ([#371](https://github.com/Lightning-AI/utilities/pull/371))

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ lint.per-file-ignores."src/**" = [
7070
lint.per-file-ignores."tests/**" = [
7171
"ANN001", # Missing type annotation for function argument
7272
"ANN101", # Missing type annotation for `self` in method
73-
"ANN201", # Missing return type annotation for public function
73+
"ANN201", # Missing return type annotation for public function
7474
"ANN202", # Missing return type annotation for private function
7575
"ANN204", # Missing return type annotation for special method
7676
"ANN401", # Dynamically typed expressions (typing.Any)

requirements/cli.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
jsonargparse[signatures] >=4.38.0
2+
tomlkit

src/lightning_utilities/cli/dependencies.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
import glob
66
import os.path
77
import re
8+
import warnings
89
from collections.abc import Sequence
910
from pprint import pprint
1011
from typing import Union
1112

1213
REQUIREMENT_ROOT = "requirements.txt"
1314
REQUIREMENT_FILES_ALL: list = glob.glob(os.path.join("requirements", "*.txt"))
1415
REQUIREMENT_FILES_ALL += glob.glob(os.path.join("requirements", "**", "*.txt"), recursive=True)
16+
REQUIREMENT_FILES_ALL += glob.glob(os.path.join("**", "pyproject.toml"))
1517
if os.path.isfile(REQUIREMENT_ROOT):
1618
REQUIREMENT_FILES_ALL += [REQUIREMENT_ROOT]
1719

@@ -43,23 +45,57 @@ def _prune_packages(req_file: str, packages: Sequence[str]) -> None:
4345
fp.writelines(lines)
4446

4547

46-
def _replace_min(fname: str) -> None:
48+
def _replace_min_txt(fname: str) -> None:
4749
with open(fname) as fopen:
4850
req = fopen.read().replace(">=", "==")
4951
with open(fname, "w") as fw:
5052
fw.write(req)
5153

5254

55+
def _replace_min_pyproject_toml(fname: str) -> None:
56+
"""Replace all `>=` with `==` in the standard pyproject.toml file in [project.dependencies]."""
57+
import tomlkit
58+
59+
# Load and parse the existing pyproject.toml
60+
with open(fname, encoding="utf-8") as f:
61+
content = f.read()
62+
doc = tomlkit.parse(content)
63+
64+
# todo: consider also replace extras in [dependency-groups] -> extras = [...]
65+
deps = doc.get("project", {}).get("dependencies")
66+
if not deps:
67+
return
68+
69+
# Replace '>=version' with '==version' in each dependency
70+
for i, req in enumerate(deps):
71+
# Simple string value
72+
deps[i] = req.replace(">=", "==")
73+
74+
# Dump back out, preserving layout
75+
with open(fname, "w", encoding="utf-8") as f:
76+
f.write(tomlkit.dumps(doc))
77+
78+
5379
def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL) -> None:
5480
"""Replace the min package version by fixed one."""
5581
if isinstance(req_files, str):
5682
req_files = [req_files]
5783
for fname in req_files:
58-
_replace_min(fname)
84+
if fname.endswith(".txt"):
85+
_replace_min_txt(fname)
86+
elif os.path.basename(fname) == "pyproject.toml":
87+
_replace_min_pyproject_toml(fname)
88+
else:
89+
warnings.warn(
90+
"Only *.txt with plain list of requirements or standard pyproject.toml are supported."
91+
f" File '{fname}' is not supported.",
92+
UserWarning,
93+
stacklevel=2,
94+
)
5995

6096

6197
def _replace_package_name(requirements: list[str], old_package: str, new_package: str) -> list[str]:
62-
"""Replace one package by another with same version in given requirement file.
98+
"""Replace one package by another with the same version in a given requirement file.
6399
64100
>>> _replace_package_name(["torch>=1.0 # comment", "torchvision>=0.2", "torchtext <0.3"], "torch", "pytorch")
65101
['pytorch>=1.0 # comment', 'torchvision>=0.2', 'torchtext <0.3']

tests/unittests/cli/test_dependencies.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
_PATH_ROOT = Path(__file__).parent.parent.parent
1010

1111

12-
def test_prune_packages(tmpdir):
12+
def test_prune_packages_txt(tmpdir):
1313
req_file = tmpdir / "requirements.txt"
1414
with open(req_file, "w") as fp:
1515
fp.writelines(["fire\n", "abc>=0.1\n"])
@@ -19,7 +19,7 @@ def test_prune_packages(tmpdir):
1919
assert lines == ["fire\n"]
2020

2121

22-
def test_oldest_packages(tmpdir):
22+
def test_oldest_packages_txt(tmpdir):
2323
req_file = tmpdir / "requirements.txt"
2424
with open(req_file, "w") as fp:
2525
fp.writelines(["fire>0.2\n", "abc>=0.1\n"])
@@ -29,11 +29,33 @@ def test_oldest_packages(tmpdir):
2929
assert lines == ["fire>0.2\n", "abc==0.1\n"]
3030

3131

32-
def test_replace_packages(tmpdir):
32+
def test_replace_packages_txt(tmpdir):
3333
req_file = tmpdir / "requirements.txt"
3434
with open(req_file, "w") as fp:
3535
fp.writelines(["torchvision>=0.2\n", "torch>=1.0 # comment\n", "torchtext <0.3\n"])
3636
replace_package_in_requirements(old_package="torch", new_package="pytorch", req_files=[str(req_file)])
3737
with open(req_file) as fp:
3838
lines = fp.readlines()
3939
assert lines == ["torchvision>=0.2\n", "pytorch>=1.0 # comment\n", "torchtext <0.3\n"]
40+
41+
42+
def test_oldest_packages_pyproject_toml(tmpdir):
43+
req_file = tmpdir / "pyproject.toml"
44+
with open(req_file, "w") as fp:
45+
fp.writelines([
46+
"[project]\n",
47+
"dependencies = [\n",
48+
' "fire>0.2",\n',
49+
' "abc>=0.1",\n',
50+
"]\n",
51+
])
52+
replace_oldest_version(req_files=[str(req_file)])
53+
with open(req_file) as fp:
54+
lines = fp.readlines()
55+
assert lines == [
56+
"[project]\n",
57+
"dependencies = [\n",
58+
' "fire>0.2",\n',
59+
' "abc==0.1",\n',
60+
"]\n",
61+
]

0 commit comments

Comments
 (0)