Skip to content

Adds --verify #195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Nov 3, 2024
34 changes: 29 additions & 5 deletions nbstripout/_nbstripout.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,27 +331,39 @@ def status(git_config, install_location=INSTALL_LOCATION_LOCAL, verbose=False):
return 1

def process_notebook(input_stream, output_stream, args, extra_keys, filename='input from stdin'):
any_change = False
if args.mode == 'zeppelin':
nb = json.load(input_stream, object_pairs_hook=collections.OrderedDict)
nb_str_orig = json.dumps(nb, indent=2)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason for writing to a JSON formatted string here and below? Could you also just compare the dicts directly? Not necessarily objecting, just trying to understand why you've implemented it this way :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was afraid of instances where the dict is different but the serialization is the same; for instance, if for some reason strip_zeppelin_output converts a list to a tuple. Both are equivalent in JSON but not as Python dicts. (It might also prevent some edge cases where copying vs. deep cloning might be an issue.)

Example

import json

d1 = {"a": 1, "b": [1,]}
d2 = {"a": 1, "b": (1,)}

print(json.dumps(d1) == json.dumps(d2))
# True
print(d2 == d1)
# False
print(json.dumps(d2))
# {"a": 1, "b": [1]}

(I am not 100% sure my concerns are well grounded, but I felt better doing it this way.)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are indeed correct that both strip_output and strip_zeppelin_output mutate the existing dict, so we'd probably need to create a deep copy for comparison purposes. I'd prefer that over the JSON serialization.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

nb_stripped = strip_zeppelin_output(nb)

nb_str_stripped = json.dumps(nb_stripped, indent=2)
if nb_str_orig != nb_str_stripped:
any_change = True

if args.dry_run:
output_stream.write(f'Dry run: would have stripped {filename}\n')
return
return any_change
if output_stream.seekable():
output_stream.seek(0)
output_stream.truncate()
json.dump(nb_stripped, output_stream, indent=2)
output_stream.write('\n')
output_stream.flush()
return
return any_change

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
nb = nbformat.read(input_stream, as_version=nbformat.NO_CONVERT)

nb_start_str = json.dumps(nb, indent=2)
nb = strip_output(nb, args.keep_output, args.keep_count, args.keep_id,
extra_keys, args.drop_empty_cells,
args.drop_tagged_cells.split(), args.strip_init_cells,
_parse_size(args.max_size))
nb_end_str = json.dumps(nb, indent=2)
if nb_start_str != nb_end_str:
any_change = True

if args.dry_run:
output_stream.write(f'Dry run: would have stripped {filename}\n')
Expand All @@ -363,7 +375,7 @@ def process_notebook(input_stream, output_stream, args, extra_keys, filename='in
warnings.simplefilter("ignore", category=UserWarning)
nbformat.write(nb, output_stream)
output_stream.flush()

return any_change

def main():
parser = ArgumentParser(epilog=__doc__, formatter_class=RawDescriptionHelpFormatter)
Expand All @@ -383,6 +395,8 @@ def main():
'repository and configuration summary if installed')
task.add_argument('--version', action='store_true',
help='Print version')
parser.add_argument("--verify", action="store_true",
help="Return a non-zero exit code if any files were changed, Implies --dry-run")
parser.add_argument('--keep-count', action='store_true',
help='Do not strip the execution count/prompt number')
parser.add_argument('--keep-output', action='store_true',
Expand Down Expand Up @@ -428,6 +442,9 @@ def main():
args = parser.parse_args()
git_config = ['git', 'config']

if args.verify and not args.dry_run:
args.dry_run = True

if args._system:
git_config.append('--system')
install_location = INSTALL_LOCATION_SYSTEM
Expand Down Expand Up @@ -483,14 +500,17 @@ def main():
input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') if sys.stdin else None
output_stream = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', newline='')

any_change = False
for filename in args.files:
if not (args.force or filename.endswith('.ipynb') or filename.endswith('.zpln')):
continue

try:
with io.open(filename, 'r+', encoding='utf8', newline='') as f:
out = output_stream if args.textconv or args.dry_run else f
process_notebook(f, out, args, extra_keys, filename)
if process_notebook(f, out, args, extra_keys, filename):
any_change = True

except nbformat.reader.NotJSONError:
print(f"No valid notebook detected in '{filename}'", file=sys.stderr)
raise SystemExit(1)
Expand All @@ -504,7 +524,11 @@ def main():

if not args.files and input_stream:
try:
process_notebook(input_stream, output_stream, args, extra_keys)
any_local_change = process_notebook(input_stream, output_stream, args, extra_keys)
any_change = any_change or any_local_change
except nbformat.reader.NotJSONError:
print('No valid notebook detected on stdin', file=sys.stderr)
raise SystemExit(1)

if args.verify and any_change:
raise SystemExit(1)
74 changes: 63 additions & 11 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from pathlib import Path
import re
import json
from subprocess import run, PIPE
# Note: typing.Pattern is deprecated, for removal in 3.13 in favour of re.Pattern introduced in 3.8
from typing import List, Union, Pattern
Expand Down Expand Up @@ -53,49 +54,100 @@ def nbstripout_exe():


@pytest.mark.parametrize("input_file, expected_file, args", TEST_CASES)
def test_end_to_end_stdin(input_file: str, expected_file: str, args: List[str]):
@pytest.mark.parametrize("verify", (True, False))
def test_end_to_end_stdin(input_file: str, expected_file: str, args: List[str], verify: bool):
with open(NOTEBOOKS_FOLDER / expected_file, mode="r") as f:
expected = f.read()
expected_str = json.dumps(json.loads(expected), indent=2)

with open(NOTEBOOKS_FOLDER / input_file, mode="r") as f:
pc = run([nbstripout_exe()] + args, stdin=f, stdout=PIPE, universal_newlines=True)
input_str = json.dumps(json.loads(f.read()), indent=2)

with open(NOTEBOOKS_FOLDER / input_file, mode="r") as f:
args = [nbstripout_exe()] + args
if verify:
args.append("--verify")
pc = run(args, stdin=f, stdout=PIPE, universal_newlines=True)
output = pc.stdout

assert output == expected
if verify:
# When using stdin, the --dry-run flag is disregarded.
if input_str != expected_str:
assert pc.returncode == 1
else:
assert pc.returncode == 0
else:
assert output == expected
assert pc.returncode == 0


@pytest.mark.parametrize("input_file, expected_file, args", TEST_CASES)
def test_end_to_end_file(input_file: str, expected_file: str, args: List[str], tmp_path):
@pytest.mark.parametrize("verify", (True, False))
def test_end_to_end_file(input_file: str, expected_file: str, args: List[str], tmp_path, verify: bool):
with open(NOTEBOOKS_FOLDER / expected_file, mode="r") as f:
expected = f.read()
expected_str = json.dumps(json.loads(expected), indent=2)

p = tmp_path / input_file
with open(NOTEBOOKS_FOLDER / input_file, mode="r") as f:
p.write_text(f.read())
pc = run([nbstripout_exe(), p] + args, stdout=PIPE, universal_newlines=True)

assert not pc.stdout and p.read_text() == expected
with open(NOTEBOOKS_FOLDER / input_file, mode="r") as f:
input_str = json.dumps(json.loads(f.read()), indent=2)

args = [nbstripout_exe(), p] + args
if verify:
args.append("--verify")
pc = run(args, stdout=PIPE, universal_newlines=True)

output = pc.stdout.strip()
if verify:
if expected_str != input_str.strip():
assert pc.returncode == 1

# Since verify implies --dry-run, we make sure the file is not modified
# In other words, that the output == input, INSTEAD of output == expected
output.strip() == input_str.strip()
else:
output_file_str = json.dumps(json.loads(p.read_text()), indent=2)
assert pc.returncode == 0
assert output_file_str == expected_str


@pytest.mark.parametrize("input_file, extra_args", DRY_RUN_CASES)
def test_dry_run_stdin(input_file: str, extra_args: List[str]):
@pytest.mark.parametrize("verify", (True, False))
def test_dry_run_stdin(input_file: str, extra_args: List[str], verify: bool):
expected = "Dry run: would have stripped input from stdin\n"

with open(NOTEBOOKS_FOLDER / input_file, mode="r") as f:
pc = run([nbstripout_exe(), "--dry-run"] + extra_args, stdin=f, stdout=PIPE, universal_newlines=True)
args = [nbstripout_exe(), "--dry-run"] + extra_args
if verify:
args.append("--verify")
pc = run(args, stdin=f, stdout=PIPE, universal_newlines=True)
output = pc.stdout
exit_code = pc.returncode

assert output == expected
if verify:
assert exit_code == 1
else:
assert exit_code == 0


@pytest.mark.parametrize("input_file, extra_args", DRY_RUN_CASES)
def test_dry_run_args(input_file: str, extra_args: List[str]):
@pytest.mark.parametrize("verify", (True, False))
def test_dry_run_args(input_file: str, extra_args: List[str], verify: bool):
expected_regex = re.compile(f"Dry run: would have stripped .*[/\\\\]{input_file}\n")

pc = run([nbstripout_exe(), str(NOTEBOOKS_FOLDER / input_file), "--dry-run", ] + extra_args, stdout=PIPE, universal_newlines=True)
args = [nbstripout_exe(), str(NOTEBOOKS_FOLDER / input_file), "--dry-run", ] + extra_args
if verify:
args.append("--verify")
pc = run(args, stdout=PIPE, universal_newlines=True)
output = pc.stdout
exit_code = pc.returncode

assert expected_regex.match(output)
if verify:
assert exit_code == 1


@pytest.mark.parametrize("input_file, expected_errs, extra_args", ERR_OUTPUT_CASES)
Expand Down