Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Improve license_header tool by only traversing files under revision c… #13803

Merged
merged 2 commits into from
Jan 11, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 98 additions & 53 deletions tools/license_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from itertools import chain
import logging
import sys
import subprocess

# the default apache license
_LICENSE = """Licensed to the Apache Software Foundation (ASF) under one
Expand Down Expand Up @@ -92,9 +93,23 @@
# Previous license header, which will be removed
_OLD_LICENSE = re.compile('.*Copyright.*by Contributors')

def _has_license(lines):

def get_mxnet_root():
curpath = os.path.abspath(os.path.dirname(__file__))
def is_mxnet_root(path: str) -> bool:
return os.path.exists(os.path.join(path, ".mxnet_root"))
while not is_mxnet_root(curpath):
parent = os.path.abspath(os.path.join(curpath, os.pardir))
if parent == curpath:
raise RuntimeError("Got to the root and couldn't find a parent folder with .mxnet_root")
curpath = parent
return curpath


def _lines_have_license(lines):
return any([any([p in l for p in _LICENSE_PATTERNS]) for l in lines])


def _get_license(comment_mark):
if comment_mark == '*':
body = '/*\n'
Expand All @@ -113,65 +128,88 @@ def _get_license(comment_mark):
body += '\n'
return body

def _valid_file(fname, verbose=False):

def should_have_license(fname):
if any([l in fname for l in _WHITE_LIST]):
if verbose:
logging.info('skip ' + fname + ', it matches the white list')
logging.debug('skip ' + fname + ', it matches the white list')
return False
_, ext = os.path.splitext(fname)
if ext not in _LANGS:
if verbose:
logging.info('skip ' + fname + ', unknown file extension')
logging.debug('skip ' + fname + ', unknown file extension')
return False
return True

def process_file(fname, action, verbose=True):
if not _valid_file(fname, verbose):

def file_has_license(fname):
if not should_have_license(fname):
return True
try:
with open(fname, 'r', encoding="utf-8") as f:
lines = f.readlines()
if not lines:
if not lines or _lines_have_license(lines):
return True
if _has_license(lines):
return True
elif action == 'check':
else:
logging.error("File %s doesn't have a license", fname)
return False
_, ext = os.path.splitext(fname)
with open(fname, 'w', encoding="utf-8") as f:
# shebang line
if lines[0].startswith('#!'):
f.write(lines[0].rstrip()+'\n\n')
del lines[0]
f.write(_get_license(_LANGS[ext]))
for l in lines:
f.write(l.rstrip()+'\n')
logging.info('added license header to ' + fname)
except UnicodeError:
return True
return True

def process_folder(root, action):
excepts = []
for root, _, files in os.walk(root):
for f in files:
fname = os.path.normpath(os.path.join(root, f))
if not process_file(fname, action):
excepts.append(fname)
if action == 'check' and excepts:
logging.warning('The following files do not contain a valid license, '+
'you can use `tools/license_header.py add [file]` to add'+
'them automatically: ')
for x in excepts:
logging.warning(x)
return False
return True

if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(format='%(asctime)-15s %(message)s')
def file_add_license(fname):
if not should_have_license(fname):
return
with open(fname, 'r', encoding="utf-8") as f:
lines = f.readlines()
if _lines_have_license(lines):
return
_, ext = os.path.splitext(fname)
with open(fname, 'w', encoding="utf-8") as f:
# shebang line
if lines[0].startswith('#!'):
f.write(lines[0].rstrip()+'\n\n')
del lines[0]
f.write(_get_license(_LANGS[ext]))
for l in lines:
f.write(l.rstrip()+'\n')
logging.info('added license header to ' + fname)
return


def under_git():
return subprocess.run(['git', 'rev-parse', 'HEAD'],
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0


def git_files():
return list(map(os.fsdecode,
subprocess.check_output('git ls-tree -r HEAD --name-only -z'.split()).split(b'\0')))


def file_generator(path: str):
for (dirpath, dirnames, files) in os.walk(path):
for file in files:
yield os.path.abspath(os.path.join(dirpath, file))


def foreach(fn, iterable):
for x in iterable:
fn(x)


def script_name():
""":returns: script name with leading paths removed"""
return os.path.split(sys.argv[0])[1]


def main():
logging.basicConfig(
format='{}: %(levelname)s %(message)s'.format(script_name()),
level=os.environ.get("LOGLEVEL", "INFO"))

parser = argparse.ArgumentParser(
description='Add or check source license header')

parser.add_argument(
'action', nargs=1, type=str,
choices=['add', 'check'], default='add',
Expand All @@ -182,19 +220,26 @@ def process_folder(root, action):
help='Files to add license header to')

args = parser.parse_args()
files = list(chain(*args.file))
action = args.action[0]
has_license = True
if len(files) > 0:
for file in files:
has_license = process_file(file, action)
if action == 'check' and not has_license:
logging.warn("{} doesn't have a license".format(file))
has_license = False
else:
has_license = process_folder(os.path.join(os.path.dirname(__file__), '..'), action)
if not has_license:
sys.exit(1)
files = list(chain(*args.file))
if not files and action =='check':
if under_git():
logging.info("Git detected: Using files under version control")
files = git_files()
else:
logging.info("Using files under mxnet sources root")
files = file_generator(get_mxnet_root())

if action == 'check':
if not all(map(file_has_license, files)):
return 1
else:
logging.info("All known and whitelisted files have license")
return 0
else:
sys.exit(0)
assert action == 'add'
foreach(file_add_license, files)
return 0

if __name__ == '__main__':
sys.exit(main())