Skip to content

Commit 2a1e89e

Browse files
authored
Merge pull request #221 from trailofbits/plurals
Handle plurals correctly in output
2 parents ebe7381 + cd5a7d2 commit 2a1e89e

File tree

2 files changed

+73
-7
lines changed

2 files changed

+73
-7
lines changed

pip_audit/_cli.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,7 @@ def _fatal(msg: str) -> NoReturn:
136136
sys.exit(1)
137137

138138

139-
def audit() -> None:
140-
"""
141-
The primary entrypoint for `pip-audit`.
142-
"""
139+
def _parser() -> argparse.ArgumentParser:
143140
parser = argparse.ArgumentParser(
144141
prog="pip-audit",
145142
description="audit the Python environment for dependencies with known vulnerabilities",
@@ -241,8 +238,20 @@ def audit() -> None:
241238
action="store_true",
242239
help="automatically upgrade dependencies with known vulnerabilities",
243240
)
241+
return parser
242+
243+
244+
def _parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
245+
return parser.parse_args()
246+
247+
248+
def audit() -> None:
249+
"""
250+
The primary entrypoint for `pip-audit`.
251+
"""
252+
parser = _parser()
253+
args = _parse_args(parser)
244254

245-
args = parser.parse_args()
246255
if args.verbose:
247256
logging.root.setLevel("DEBUG")
248257

@@ -307,10 +316,17 @@ def audit() -> None:
307316
# TODO(ww): Refine this: we should always output if our output format is an SBOM
308317
# or other manifest format (like the default JSON format).
309318
if vuln_count > 0:
310-
summary_msg = f"Found {vuln_count} known vulnerabilities in {pkg_count} packages"
319+
summary_msg = (
320+
f"Found {vuln_count} known "
321+
f"{'vulnerability' if vuln_count == 1 else 'vulnerabilities'} "
322+
f"in {pkg_count} {'package' if pkg_count == 1 else 'packages'}"
323+
)
311324
if args.fix:
312325
summary_msg += (
313-
f" and fixed {fixed_vuln_count} vulnerabilities in {fixed_pkg_count} packages"
326+
f" and fixed {fixed_vuln_count} "
327+
f"{'vulnerability' if fixed_vuln_count == 1 else 'vulnerabilities'} "
328+
f"in {fixed_pkg_count} "
329+
f"{'package' if fixed_pkg_count == 1 else 'packages'}"
314330
)
315331
print(summary_msg, file=sys.stderr)
316332
print(formatter.format(result))

test/test_cli.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import pretend # type: ignore
2+
import pytest
3+
4+
import pip_audit._cli
5+
6+
7+
@pytest.mark.parametrize(
8+
"args, vuln_count, pkg_count, expected",
9+
[
10+
([], 1, 1, "Found 1 known vulnerability in 1 package"),
11+
([], 2, 1, "Found 2 known vulnerabilities in 1 package"),
12+
([], 2, 2, "Found 2 known vulnerabilities in 2 packages"),
13+
(["--fix"], 1, 1, "fixed 1 vulnerability in 1 package"),
14+
(["--fix"], 2, 1, "fixed 2 vulnerabilities in 1 package"),
15+
(["--fix"], 2, 2, "fixed 2 vulnerabilities in 2 packages"),
16+
],
17+
)
18+
def test_plurals(capsys, monkeypatch, args, vuln_count, pkg_count, expected):
19+
dummysource = pretend.stub(fix=lambda a: None)
20+
monkeypatch.setattr(pip_audit._cli, "PipSource", lambda *a, **kw: dummysource)
21+
22+
parser = pip_audit._cli._parser()
23+
monkeypatch.setattr(pip_audit._cli, "_parse_args", lambda x: parser.parse_args(args))
24+
25+
result = [
26+
(
27+
pretend.stub(
28+
is_skipped=lambda: False,
29+
name="something" + str(i),
30+
canonical_name="something" + str(i),
31+
version=1,
32+
),
33+
[pretend.stub(fix_versions=[2], id="foo")] * (vuln_count // pkg_count),
34+
)
35+
for i in range(pkg_count)
36+
]
37+
38+
auditor = pretend.stub(audit=lambda a: result)
39+
monkeypatch.setattr(pip_audit._cli, "Auditor", lambda *a, **kw: auditor)
40+
41+
resolve_fix_versions = [pretend.stub(is_skipped=lambda: False, dep=spec) for spec, _ in result]
42+
monkeypatch.setattr(pip_audit._cli, "resolve_fix_versions", lambda *a: resolve_fix_versions)
43+
44+
try:
45+
pip_audit._cli.audit()
46+
except SystemExit:
47+
pass
48+
49+
captured = capsys.readouterr()
50+
assert expected in captured.err

0 commit comments

Comments
 (0)