Skip to content

Commit

Permalink
Type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
kynan committed Jan 19, 2025
1 parent 9e4873d commit f58c9e0
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 42 deletions.
67 changes: 45 additions & 22 deletions nbstripout/_nbstripout.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
*.ipynb diff=ipynb
"""

from argparse import ArgumentParser, RawDescriptionHelpFormatter
from argparse import ArgumentParser, RawDescriptionHelpFormatter, Namespace
import collections
import copy
import io
Expand All @@ -118,6 +118,7 @@
from pathlib import PureWindowsPath
import re
from subprocess import call, check_call, check_output, CalledProcessError, STDOUT
from typing import List, Optional
import sys
import warnings

Expand All @@ -134,7 +135,7 @@
INSTALL_LOCATION_SYSTEM = 'system'


def _get_system_gitconfig_folder():
def _get_system_gitconfig_folder() -> str:
try:
git_config_output = check_output(
['git', 'config', '--system', '--list', '--show-origin'], universal_newlines=True, stderr=STDOUT
Expand All @@ -160,7 +161,9 @@ def _get_system_gitconfig_folder():
return path.abspath(path.dirname(system_gitconfig_file_path))


def _get_attrfile(git_config, install_location=INSTALL_LOCATION_LOCAL, attrfile=None):
def _get_attrfile(
git_config: str, install_location: str = INSTALL_LOCATION_LOCAL, attrfile: Optional[str] = None
) -> str:
if not attrfile:
if install_location == INSTALL_LOCATION_SYSTEM:
try:
Expand All @@ -185,7 +188,7 @@ def _get_attrfile(git_config, install_location=INSTALL_LOCATION_LOCAL, attrfile=
return attrfile


def _parse_size(num_str):
def _parse_size(num_str: str) -> int:
num_str = num_str.upper()
if num_str[-1].isdigit():
return int(num_str)
Expand All @@ -195,11 +198,15 @@ def _parse_size(num_str):
return int(num_str[:-1]) * (10**6)
elif num_str[-1] == 'G':
return int(num_str[:-1]) * (10**9)
else:
raise ValueError(f'Unknown size identifier {num_str[-1]}')
raise ValueError(f'Unknown size identifier {num_str[-1]}')


def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, attrfile=None):
def install(
git_config: str,
install_location: str = INSTALL_LOCATION_LOCAL,
python: Optional[str] = None,
attrfile: Optional[str] = None,
) -> int:
"""Install the git filter and set the git attributes."""
try:
filepath = f'"{PureWindowsPath(python or sys.executable).as_posix()}" -m nbstripout'
Expand Down Expand Up @@ -229,7 +236,7 @@ def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, at
diff_exists = '*.ipynb diff' in attrs

if filt_exists and diff_exists:
return
return 0

try:
with open(attrfile, 'a', newline='') as f:
Expand All @@ -242,6 +249,7 @@ def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, at
print('*.zpln filter=nbstripout', file=f)
if not diff_exists:
print('*.ipynb diff=ipynb', file=f)
return 0
except PermissionError:
print(f'Installation failed: could not write to {attrfile}', file=sys.stderr)

Expand All @@ -251,7 +259,7 @@ def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, at
return 1


def uninstall(git_config, install_location=INSTALL_LOCATION_LOCAL, attrfile=None):
def uninstall(git_config: str, install_location: str = INSTALL_LOCATION_LOCAL, attrfile: Optional[str] = None) -> int:
"""Uninstall the git filter and unset the git attributes."""
try:
call(git_config + ['--unset', 'filter.nbstripout.clean'], stdout=open(devnull, 'w'), stderr=STDOUT)
Expand All @@ -274,9 +282,10 @@ def uninstall(git_config, install_location=INSTALL_LOCATION_LOCAL, attrfile=None
f.seek(0)
f.write(''.join(lines))
f.truncate()
return 0


def status(git_config, install_location=INSTALL_LOCATION_LOCAL, verbose=False):
def status(git_config: str, install_location: str = INSTALL_LOCATION_LOCAL, verbose: bool = False) -> int:
"""Return 0 if nbstripout is installed in the current repo, 1 otherwise"""
try:
if install_location == INSTALL_LOCATION_SYSTEM:
Expand Down Expand Up @@ -342,22 +351,28 @@ def status(git_config, install_location=INSTALL_LOCATION_LOCAL, verbose=False):
return 1


def process_jupyter_notebook(input_stream, output_stream, args, extra_keys, filename='input from stdin'):
def process_jupyter_notebook(
input_stream: io.IOBase,
output_stream: io.IOBase,
args: Namespace,
extra_keys: List[str],
filename: str = 'input from stdin',
) -> bool:
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=UserWarning)
nb = nbformat.read(input_stream, as_version=nbformat.NO_CONVERT)

nb_orig = copy.deepcopy(nb)
nb_stripped = strip_output(
nb,
args.keep_output,
args.keep_count,
args.keep_id,
extra_keys,
args.drop_empty_cells,
args.drop_tagged_cells.split(),
args.strip_init_cells,
_parse_size(args.max_size),
nb=nb,
keep_output=args.keep_output,
keep_count=args.keep_count,
keep_id=args.keep_id,
extra_keys=extra_keys,
drop_empty_cells=args.drop_empty_cells,
drop_tagged_cells=args.drop_tagged_cells.split(),
strip_init_cells=args.strip_init_cells,
max_size=_parse_size(args.max_size),
)

any_change = nb_orig != nb_stripped
Expand All @@ -377,7 +392,13 @@ def process_jupyter_notebook(input_stream, output_stream, args, extra_keys, file
return any_change


def process_zeppelin_notebook(input_stream, output_stream, args, extra_keys, filename='input from stdin'):
def process_zeppelin_notebook(
input_stream: io.IOBase,
output_stream: io.IOBase,
args: Namespace,
extra_keys: List[str],
filename: str = 'input from stdin',
):
nb = json.load(input_stream, object_pairs_hook=collections.OrderedDict)
nb_orig = copy.deepcopy(nb)
nb_stripped = strip_zeppelin_output(nb)
Expand Down Expand Up @@ -569,7 +590,9 @@ def main():
try:
with io.open(filename, 'r+', encoding='utf8', newline='') as f:
out = output_stream if args.textconv or args.dry_run else f
if process_notebook(f, out, args, extra_keys, filename):
if process_notebook(
input_stream=f, output_stream=out, args=args, extra_keys=extra_keys, filename=filename
):
any_change = True

except nbformat.reader.NotJSONError:
Expand Down
43 changes: 23 additions & 20 deletions nbstripout/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from collections import defaultdict
import sys
from typing import Any, Callable, Iterator, List, Optional

from nbformat import NotebookNode

__all__ = ['pop_recursive', 'strip_output', 'strip_zeppelin_output', 'MetadataError']

Expand All @@ -8,7 +11,7 @@ class MetadataError(Exception):
pass


def pop_recursive(d, key, default=None):
def pop_recursive(d: dict, key: str, default: Optional[NotebookNode] = None) -> NotebookNode:
"""dict.pop(key) where `key` is a `.`-delimited list of nested keys.
>>> d = {'a': {'b': 1, 'c': 2}}
Expand All @@ -25,11 +28,11 @@ def pop_recursive(d, key, default=None):
return default
key_head, key_tail = key.split('.', maxsplit=1)
if key_head in d:
return pop_recursive(d[key_head], key_tail, default)
return pop_recursive(d[key_head], key=key_tail, default=default)
return default


def _cells(nb, conditionals):
def _cells(nb: NotebookNode, conditionals: Callable[[NotebookNode], bool]) -> Iterator[NotebookNode]:
"""Remove cells not satisfying any conditional in conditionals and yield all other cells."""
if hasattr(nb, 'nbformat') and nb.nbformat < 4:
for ws in nb.worksheets:
Expand All @@ -44,7 +47,7 @@ def _cells(nb, conditionals):
yield cell


def get_size(item):
def get_size(item: Any) -> int:
"""Recursively sums length of all strings in `item`"""
if isinstance(item, str):
return len(item)
Expand All @@ -56,7 +59,7 @@ def get_size(item):
return len(str(item))


def determine_keep_output(cell, default, strip_init_cells=False):
def determine_keep_output(cell: NotebookNode, default: bool, strip_init_cells: bool = False):
"""Given a cell, determine whether output should be kept
Based on whether the metadata has "init_cell": true,
Expand All @@ -80,29 +83,29 @@ def determine_keep_output(cell, default, strip_init_cells=False):
return default


def _zeppelin_cells(nb):
def _zeppelin_cells(nb: dict) -> Iterator[dict]:
for pg in nb['paragraphs']:
yield pg


def strip_zeppelin_output(nb):
def strip_zeppelin_output(nb: dict) -> dict:
for cell in _zeppelin_cells(nb):
if 'results' in cell:
cell['results'] = {}
return nb


def strip_output(
nb,
keep_output,
keep_count,
keep_id,
extra_keys=[],
drop_empty_cells=False,
drop_tagged_cells=[],
strip_init_cells=False,
max_size=0,
):
nb: NotebookNode,
keep_output: bool,
keep_count: bool,
keep_id: bool,
extra_keys: List[str] = [],
drop_empty_cells: bool = False,
drop_tagged_cells: List[str] = [],
strip_init_cells: bool = False,
max_size: int = 0,
) -> NotebookNode:
"""
Strip the outputs, execution count/prompt number and miscellaneous
metadata from a notebook object, unless specified to keep either the outputs
Expand All @@ -122,7 +125,7 @@ def strip_output(
keys[namespace].append(subkey)

for field in keys['metadata']:
pop_recursive(nb.metadata, field)
pop_recursive(nb.metadata, key=field)

conditionals = []
# Keep cells if they have any `source` line that contains non-whitespace
Expand All @@ -132,7 +135,7 @@ def strip_output(
conditionals.append(lambda c: tag_to_drop not in c.get('metadata', {}).get('tags', []))

for i, cell in enumerate(_cells(nb, conditionals)):
keep_output_this_cell = determine_keep_output(cell, keep_output, strip_init_cells)
keep_output_this_cell = determine_keep_output(cell=cell, default=keep_output, strip_init_cells=strip_init_cells)

# Remove the outputs, unless directed otherwise
if 'outputs' in cell:
Expand All @@ -157,5 +160,5 @@ def strip_output(
if 'id' in cell and not keep_id:
cell['id'] = str(i)
for field in keys['cell']:
pop_recursive(cell, field)
pop_recursive(cell, key=field)
return nb

0 comments on commit f58c9e0

Please sign in to comment.