From cbb7e845791131f1b6f54b887e6eace84c38734d Mon Sep 17 00:00:00 2001 From: Florian Rathgeber Date: Sat, 18 Jan 2025 21:03:14 +0100 Subject: [PATCH] Type annotations --- nbstripout/_nbstripout.py | 67 ++++++++++++++++++++++++++------------- nbstripout/_utils.py | 43 +++++++++++++------------ 2 files changed, 68 insertions(+), 42 deletions(-) diff --git a/nbstripout/_nbstripout.py b/nbstripout/_nbstripout.py index 7c329c2..823179a 100644 --- a/nbstripout/_nbstripout.py +++ b/nbstripout/_nbstripout.py @@ -109,7 +109,7 @@ *.ipynb diff=ipynb """ -from argparse import ArgumentParser, RawDescriptionHelpFormatter +from argparse import ArgumentParser, RawDescriptionHelpFormatter, Namespace import collections import copy import io @@ -118,6 +118,7 @@ from pathlib import PureWindowsPath import re from subprocess import call, check_call, check_output, CalledProcessError, STDOUT +from typing import Optional import sys import warnings @@ -134,7 +135,7 @@ INSTALL_LOCATION_SYSTEM = 'system' -def _get_system_gitconfig_folder(): +def _get_system_gitconfig_folder() -> str: try: git_config_output = check_output( ['git', 'config', '--system', '--list', '--show-origin'], universal_newlines=True, stderr=STDOUT @@ -160,7 +161,9 @@ def _get_system_gitconfig_folder(): return path.abspath(path.dirname(system_gitconfig_file_path)) -def _get_attrfile(git_config, install_location=INSTALL_LOCATION_LOCAL, attrfile=None): +def _get_attrfile( + git_config: str, install_location: str = INSTALL_LOCATION_LOCAL, attrfile: Optional[str] = None +) -> str: if not attrfile: if install_location == INSTALL_LOCATION_SYSTEM: try: @@ -185,7 +188,7 @@ def _get_attrfile(git_config, install_location=INSTALL_LOCATION_LOCAL, attrfile= return attrfile -def _parse_size(num_str): +def _parse_size(num_str: str) -> int: num_str = num_str.upper() if num_str[-1].isdigit(): return int(num_str) @@ -195,11 +198,15 @@ def _parse_size(num_str): return int(num_str[:-1]) * (10**6) elif num_str[-1] == 'G': return int(num_str[:-1]) * (10**9) - else: - raise ValueError(f'Unknown size identifier {num_str[-1]}') + raise ValueError(f'Unknown size identifier {num_str[-1]}') -def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, attrfile=None): +def install( + git_config: str, + install_location: str = INSTALL_LOCATION_LOCAL, + python: Optional[str] = None, + attrfile: Optional[str] = None, +) -> int: """Install the git filter and set the git attributes.""" try: filepath = f'"{PureWindowsPath(python or sys.executable).as_posix()}" -m nbstripout' @@ -229,7 +236,7 @@ def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, at diff_exists = '*.ipynb diff' in attrs if filt_exists and diff_exists: - return + return 0 try: with open(attrfile, 'a', newline='') as f: @@ -242,6 +249,7 @@ def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, at print('*.zpln filter=nbstripout', file=f) if not diff_exists: print('*.ipynb diff=ipynb', file=f) + return 0 except PermissionError: print(f'Installation failed: could not write to {attrfile}', file=sys.stderr) @@ -251,7 +259,7 @@ def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, at return 1 -def uninstall(git_config, install_location=INSTALL_LOCATION_LOCAL, attrfile=None): +def uninstall(git_config: str, install_location: str = INSTALL_LOCATION_LOCAL, attrfile: Optional[str] = None) -> int: """Uninstall the git filter and unset the git attributes.""" try: call(git_config + ['--unset', 'filter.nbstripout.clean'], stdout=open(devnull, 'w'), stderr=STDOUT) @@ -274,9 +282,10 @@ def uninstall(git_config, install_location=INSTALL_LOCATION_LOCAL, attrfile=None f.seek(0) f.write(''.join(lines)) f.truncate() + return 0 -def status(git_config, install_location=INSTALL_LOCATION_LOCAL, verbose=False): +def status(git_config: str, install_location: str = INSTALL_LOCATION_LOCAL, verbose: bool = False) -> int: """Return 0 if nbstripout is installed in the current repo, 1 otherwise""" try: if install_location == INSTALL_LOCATION_SYSTEM: @@ -342,22 +351,28 @@ def status(git_config, install_location=INSTALL_LOCATION_LOCAL, verbose=False): return 1 -def process_jupyter_notebook(input_stream, output_stream, args, extra_keys, filename='input from stdin'): +def process_jupyter_notebook( + input_stream: io.IOBase, + output_stream: io.IOBase, + args: Namespace, + extra_keys: list[str], + filename: str = 'input from stdin', +) -> bool: with warnings.catch_warnings(): warnings.simplefilter('ignore', category=UserWarning) nb = nbformat.read(input_stream, as_version=nbformat.NO_CONVERT) nb_orig = copy.deepcopy(nb) nb_stripped = strip_output( - nb, - args.keep_output, - args.keep_count, - args.keep_id, - extra_keys, - args.drop_empty_cells, - args.drop_tagged_cells.split(), - args.strip_init_cells, - _parse_size(args.max_size), + nb=nb, + keep_output=args.keep_output, + keep_count=args.keep_count, + keep_id=args.keep_id, + extra_keys=extra_keys, + drop_empty_cells=args.drop_empty_cells, + drop_tagged_cells=args.drop_tagged_cells.split(), + strip_init_cells=args.strip_init_cells, + max_size=_parse_size(args.max_size), ) any_change = nb_orig != nb_stripped @@ -377,7 +392,13 @@ def process_jupyter_notebook(input_stream, output_stream, args, extra_keys, file return any_change -def process_zeppelin_notebook(input_stream, output_stream, args, extra_keys, filename='input from stdin'): +def process_zeppelin_notebook( + input_stream: io.IOBase, + output_stream: io.IOBase, + args: Namespace, + extra_keys: list[str], + filename: str = 'input from stdin', +): nb = json.load(input_stream, object_pairs_hook=collections.OrderedDict) nb_orig = copy.deepcopy(nb) nb_stripped = strip_zeppelin_output(nb) @@ -569,7 +590,9 @@ def main(): try: with io.open(filename, 'r+', encoding='utf8', newline='') as f: out = output_stream if args.textconv or args.dry_run else f - if process_notebook(f, out, args, extra_keys, filename): + if process_notebook( + input_stream=f, output_stream=out, args=args, extra_keys=extra_keys, filename=filename + ): any_change = True except nbformat.reader.NotJSONError: diff --git a/nbstripout/_utils.py b/nbstripout/_utils.py index 3b6e14e..db7f252 100644 --- a/nbstripout/_utils.py +++ b/nbstripout/_utils.py @@ -1,5 +1,8 @@ from collections import defaultdict import sys +from typing import Any, Callable, Iterator, Optional + +from nbformat import NotebookNode __all__ = ['pop_recursive', 'strip_output', 'strip_zeppelin_output', 'MetadataError'] @@ -8,7 +11,7 @@ class MetadataError(Exception): pass -def pop_recursive(d, key, default=None): +def pop_recursive(d: dict, key: str, default: Optional[NotebookNode] = None) -> NotebookNode: """dict.pop(key) where `key` is a `.`-delimited list of nested keys. >>> d = {'a': {'b': 1, 'c': 2}} @@ -25,11 +28,11 @@ def pop_recursive(d, key, default=None): return default key_head, key_tail = key.split('.', maxsplit=1) if key_head in d: - return pop_recursive(d[key_head], key_tail, default) + return pop_recursive(d[key_head], key=key_tail, default=default) return default -def _cells(nb, conditionals): +def _cells(nb: NotebookNode, conditionals: Callable[[NotebookNode], bool]) -> Iterator[NotebookNode]: """Remove cells not satisfying any conditional in conditionals and yield all other cells.""" if hasattr(nb, 'nbformat') and nb.nbformat < 4: for ws in nb.worksheets: @@ -44,7 +47,7 @@ def _cells(nb, conditionals): yield cell -def get_size(item): +def get_size(item: Any) -> int: """Recursively sums length of all strings in `item`""" if isinstance(item, str): return len(item) @@ -56,7 +59,7 @@ def get_size(item): return len(str(item)) -def determine_keep_output(cell, default, strip_init_cells=False): +def determine_keep_output(cell: NotebookNode, default: bool, strip_init_cells: bool = False): """Given a cell, determine whether output should be kept Based on whether the metadata has "init_cell": true, @@ -80,12 +83,12 @@ def determine_keep_output(cell, default, strip_init_cells=False): return default -def _zeppelin_cells(nb): +def _zeppelin_cells(nb: dict) -> Iterator[dict]: for pg in nb['paragraphs']: yield pg -def strip_zeppelin_output(nb): +def strip_zeppelin_output(nb: dict) -> dict: for cell in _zeppelin_cells(nb): if 'results' in cell: cell['results'] = {} @@ -93,16 +96,16 @@ def strip_zeppelin_output(nb): def strip_output( - nb, - keep_output, - keep_count, - keep_id, - extra_keys=[], - drop_empty_cells=False, - drop_tagged_cells=[], - strip_init_cells=False, - max_size=0, -): + nb: NotebookNode, + keep_output: bool, + keep_count: bool, + keep_id: bool, + extra_keys: list[str] = [], + drop_empty_cells: bool = False, + drop_tagged_cells: list[str] = [], + strip_init_cells: bool = False, + max_size: int = 0, +) -> NotebookNode: """ Strip the outputs, execution count/prompt number and miscellaneous metadata from a notebook object, unless specified to keep either the outputs @@ -122,7 +125,7 @@ def strip_output( keys[namespace].append(subkey) for field in keys['metadata']: - pop_recursive(nb.metadata, field) + pop_recursive(nb.metadata, key=field) conditionals = [] # Keep cells if they have any `source` line that contains non-whitespace @@ -132,7 +135,7 @@ def strip_output( conditionals.append(lambda c: tag_to_drop not in c.get('metadata', {}).get('tags', [])) for i, cell in enumerate(_cells(nb, conditionals)): - keep_output_this_cell = determine_keep_output(cell, keep_output, strip_init_cells) + keep_output_this_cell = determine_keep_output(cell=cell, default=keep_output, strip_init_cells=strip_init_cells) # Remove the outputs, unless directed otherwise if 'outputs' in cell: @@ -157,5 +160,5 @@ def strip_output( if 'id' in cell and not keep_id: cell['id'] = str(i) for field in keys['cell']: - pop_recursive(cell, field) + pop_recursive(cell, key=field) return nb