Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Improve license_header tool by only traversing files under revision c…
Browse files Browse the repository at this point in the history
…ontrol
  • Loading branch information
larroy committed Jan 8, 2019
1 parent 96439e6 commit bd2f3bd
Showing 1 changed file with 98 additions and 53 deletions.
151 changes: 98 additions & 53 deletions tools/license_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from itertools import chain
import logging
import sys
import subprocess

# the default apache license
_LICENSE = """Licensed to the Apache Software Foundation (ASF) under one
Expand Down Expand Up @@ -90,9 +91,23 @@
# Previous license header, which will be removed
_OLD_LICENSE = re.compile('.*Copyright.*by Contributors')

def _has_license(lines):

def get_mxnet_root():
curpath = os.path.abspath(os.path.dirname(__file__))
def is_mxnet_root(path: str) -> bool:
return os.path.exists(os.path.join(path, ".mxnet_root"))
while not is_mxnet_root(curpath):
parent = os.path.abspath(os.path.join(curpath, os.pardir))
if parent == curpath:
raise RuntimeError("Got to the root and couldn't find a parent folder with .mxnet_root")
curpath = parent
return curpath


def _lines_have_license(lines):
return any([any([p in l for p in _LICENSE_PATTERNS]) for l in lines])


def _get_license(comment_mark):
if comment_mark == '*':
body = '/*\n'
Expand All @@ -111,65 +126,88 @@ def _get_license(comment_mark):
body += '\n'
return body

def _valid_file(fname, verbose=False):

def should_have_license(fname):
if any([l in fname for l in _WHITE_LIST]):
if verbose:
logging.info('skip ' + fname + ', it matches the white list')
logging.debug('skip ' + fname + ', it matches the white list')
return False
_, ext = os.path.splitext(fname)
if ext not in _LANGS:
if verbose:
logging.info('skip ' + fname + ', unknown file extension')
logging.debug('skip ' + fname + ', unknown file extension')
return False
return True

def process_file(fname, action, verbose=True):
if not _valid_file(fname, verbose):

def file_has_license(fname):
if not should_have_license(fname):
return True
try:
with open(fname, 'r', encoding="utf-8") as f:
lines = f.readlines()
if not lines:
if not lines or _lines_have_license(lines):
return True
if _has_license(lines):
return True
elif action == 'check':
else:
logging.error("File %s doesn't have a license", fname)
return False
_, ext = os.path.splitext(fname)
with open(fname, 'w', encoding="utf-8") as f:
# shebang line
if lines[0].startswith('#!'):
f.write(lines[0].rstrip()+'\n\n')
del lines[0]
f.write(_get_license(_LANGS[ext]))
for l in lines:
f.write(l.rstrip()+'\n')
logging.info('added license header to ' + fname)
except UnicodeError:
return True
return True

def process_folder(root, action):
excepts = []
for root, _, files in os.walk(root):
for f in files:
fname = os.path.normpath(os.path.join(root, f))
if not process_file(fname, action):
excepts.append(fname)
if action == 'check' and excepts:
logging.warning('The following files do not contain a valid license, '+
'you can use `tools/license_header.py add [file]` to add'+
'them automatically: ')
for x in excepts:
logging.warning(x)
return False
return True

if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(format='%(asctime)-15s %(message)s')
def file_add_license(fname):
if not should_have_license(fname):
return
with open(fname, 'r', encoding="utf-8") as f:
lines = f.readlines()
if _lines_have_license(lines):
return
_, ext = os.path.splitext(fname)
with open(fname, 'w', encoding="utf-8") as f:
# shebang line
if lines[0].startswith('#!'):
f.write(lines[0].rstrip()+'\n\n')
del lines[0]
f.write(_get_license(_LANGS[ext]))
for l in lines:
f.write(l.rstrip()+'\n')
logging.info('added license header to ' + fname)
return


def under_git():
return subprocess.run(['git', 'rev-parse', 'HEAD'],
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0


def git_files():
return list(map(os.fsdecode,
subprocess.check_output('git ls-tree -r master --name-only -z'.split()).split(b'\0')))


def file_generator(path: str):
for (dirpath, dirnames, files) in os.walk(path):
for file in files:
yield os.path.abspath(os.path.join(dirpath, file))


def foreach(fn, iterable):
for x in iterable:
fn(x)


def script_name():
""":returns: script name with leading paths removed"""
return os.path.split(sys.argv[0])[1]


def main():
logging.basicConfig(
format='{}: %(levelname)s %(message)s'.format(script_name()),
level=os.environ.get("LOGLEVEL", "INFO"))

parser = argparse.ArgumentParser(
description='Add or check source license header')

parser.add_argument(
'action', nargs=1, type=str,
choices=['add', 'check'], default='add',
Expand All @@ -180,19 +218,26 @@ def process_folder(root, action):
help='Files to add license header to')

args = parser.parse_args()
files = list(chain(*args.file))
action = args.action[0]
has_license = True
if len(files) > 0:
for file in files:
has_license = process_file(file, action)
if action == 'check' and not has_license:
logging.warn("{} doesn't have a license".format(file))
has_license = False
else:
has_license = process_folder(os.path.join(os.path.dirname(__file__), '..'), action)
if not has_license:
sys.exit(1)
files = list(chain(*args.file))
if not files and action =='check':
if under_git():
logging.info("Git detected: Using files under version control")
files = git_files()
else:
logging.info("Using files under mxnet sources root")
files = file_generator(get_mxnet_root())

if action == 'check':
if not all(map(file_has_license, files)):
return 1
else:
logging.info("All known and whitelisted files have license")
return 0
else:
sys.exit(0)
assert action == 'add'
foreach(file_add_license, files)
return 0

if __name__ == '__main__':
sys.exit(main())

0 comments on commit bd2f3bd

Please sign in to comment.