Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
support annotation in python 3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
Ning Shang committed Sep 13, 2020
1 parent 0a21a90 commit 1dde669
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 29 deletions.
31 changes: 16 additions & 15 deletions tools/nni_annotation/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import ast
import astor

from .utils import ast_Num, ast_Str

# pylint: disable=unidiomatic-typecheck

Expand Down Expand Up @@ -37,13 +38,13 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
for call in value.elts:
assert type(call) is ast.Call, 'Element in layer_choice should be function call'
call_name = astor.to_source(call).strip()
call_funcs_keys.append(ast.Str(s=call_name))
call_funcs_keys.append(ast_Str(s=call_name))
call_funcs_values.append(call.func)
assert not call.args, 'Number of args without keyword should be zero'
kw_args = []
kw_values = []
for kw in call.keywords:
kw_args.append(ast.Str(s=kw.arg))
kw_args.append(ast_Str(s=kw.arg))
kw_values.append(kw.value)
call_kwargs_values.append(ast.Dict(keys=kw_args, values=kw_values))
call_funcs = ast.Dict(keys=call_funcs_keys, values=call_funcs_values)
Expand All @@ -57,12 +58,12 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
elif k.id == 'optional_inputs':
assert not fields['optional_inputs'], 'Duplicated field: optional_inputs'
assert type(value) is ast.List, 'Value of optional_inputs should be a list'
var_names = [ast.Str(s=astor.to_source(var).strip()) for var in value.elts]
var_names = [ast_Str(s=astor.to_source(var).strip()) for var in value.elts]
optional_inputs = ast.Dict(keys=var_names, values=value.elts)
fields['optional_inputs'] = True
elif k.id == 'optional_input_size':
assert not fields['optional_input_size'], 'Duplicated field: optional_input_size'
assert type(value) is ast.Num or type(value) is ast.List, \
assert type(value) is ast_Num or type(value) is ast.List, \
'Value of optional_input_size should be a number or list'
optional_input_size = value
fields['optional_input_size'] = True
Expand All @@ -79,8 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
mutable_layer_id = 'mutable_layer_' + str(mutable_layer_cnt)
mutable_layer_cnt += 1
target_call_attr = ast.Attribute(value=ast.Name(id='nni', ctx=ast.Load()), attr='mutable_layer', ctx=ast.Load())
target_call_args = [ast.Str(s=mutable_id),
ast.Str(s=mutable_layer_id),
target_call_args = [ast_Str(s=mutable_id),
ast_Str(s=mutable_layer_id),
call_funcs,
call_kwargs]
if fields['fixed_inputs']:
Expand All @@ -93,8 +94,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
target_call_args.append(optional_input_size)
else:
target_call_args.append(ast.Dict(keys=[], values=[]))
target_call_args.append(ast.Num(n=0))
target_call_args.append(ast.Str(s=nas_mode))
target_call_args.append(ast_Num(n=0))
target_call_args.append(ast_Str(s=nas_mode))
if nas_mode in ['enas_mode', 'oneshot_mode', 'darts_mode']:
target_call_args.append(ast.Name(id='tensorflow'))
target_call = ast.Call(func=target_call_attr, args=target_call_args, keywords=[])
Expand Down Expand Up @@ -151,7 +152,7 @@ def parse_nni_variable(code):
assert arg.func.value.id == 'nni', 'nni.variable value is not a NNI function'

name_str = astor.to_source(name).strip()
keyword_arg = ast.keyword(arg='name', value=ast.Str(s=name_str))
keyword_arg = ast.keyword(arg='name', value=ast_Str(s=name_str))
arg.keywords.append(keyword_arg)
if arg.func.attr == 'choice':
convert_args_to_dict(arg)
Expand All @@ -169,7 +170,7 @@ def parse_nni_function(code):
convert_args_to_dict(call, with_lambda=True)

name_str = astor.to_source(name).strip()
call.keywords[0].value = ast.Str(s=name_str)
call.keywords[0].value = ast_Str(s=name_str)

return call, funcs

Expand All @@ -180,12 +181,12 @@ def convert_args_to_dict(call, with_lambda=False):
"""
keys, values = list(), list()
for arg in call.args:
if type(arg) in [ast.Str, ast.Num]:
if type(arg) in [ast_Str, ast_Num]:
arg_value = arg
else:
# if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value))
arg_value = ast_Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg
keys.append(arg_value)
values.append(arg)
Expand All @@ -209,7 +210,7 @@ def test_variable_equal(node1, node2):
return False
if isinstance(node1, ast.AST):
for k, v in vars(node1).items():
if k in ('lineno', 'col_offset', 'ctx'):
if k in ('lineno', 'col_offset', 'ctx', 'end_lineno', 'end_col_offset'):
continue
if not test_variable_equal(v, getattr(node2, k)):
return False
Expand Down Expand Up @@ -282,7 +283,7 @@ def visit(self, node):
annotation = self.stack[-1]

# this is a standalone string, may be an annotation
if type(node) is ast.Expr and type(node.value) is ast.Str:
if type(node) is ast.Expr and type(node.value) is ast_Str:
# must not annotate an annotation string
assert annotation is None, 'Annotating an annotation'
return self._visit_string(node)
Expand All @@ -306,7 +307,7 @@ def _visit_string(self, node):
if string.startswith('@nni.training_update'):
expr = parse_annotation(string[1:])
call_node = expr.value
call_node.args.insert(0, ast.Str(s=self.nas_mode))
call_node.args.insert(0, ast_Str(s=self.nas_mode))
return expr

if string.startswith('@nni.report_intermediate_result') \
Expand Down
10 changes: 6 additions & 4 deletions tools/nni_annotation/search_space_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import astor

from .utils import ast_Num, ast_Str

# pylint: disable=unidiomatic-typecheck


Expand Down Expand Up @@ -44,7 +46,7 @@ def generate_mutable_layer_search_space(self, args):
self.search_space[key]['_value'][mutable_layer] = {
'layer_choice': [k.s for k in args[2].keys],
'optional_inputs': [k.s for k in args[5].keys],
'optional_input_size': args[6].n if isinstance(args[6], ast.Num) else [args[6].elts[0].n, args[6].elts[1].n]
'optional_input_size': args[6].n if isinstance(args[6], ast_Num) else [args[6].elts[0].n, args[6].elts[1].n]
}

def visit_Call(self, node): # pylint: disable=invalid-name
Expand Down Expand Up @@ -73,7 +75,7 @@ def visit_Call(self, node): # pylint: disable=invalid-name
# there is a `name` argument
assert len(node.keywords) == 1, 'Smart parameter has keyword argument other than "name"'
assert node.keywords[0].arg == 'name', 'Smart paramater\'s keyword argument is not "name"'
assert type(node.keywords[0].value) is ast.Str, 'Smart parameter\'s name must be string literal'
assert type(node.keywords[0].value) is ast_Str, 'Smart parameter\'s name must be string literal'
name = node.keywords[0].value.s
specified_name = True
else:
Expand All @@ -86,7 +88,7 @@ def visit_Call(self, node): # pylint: disable=invalid-name
# we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
assert len(node.args) == 1, 'Smart parameter has arguments other than dict'
# check if it is a number or a string and get its value accordingly
args = [key.n if type(key) is ast.Num else key.s for key in node.args[0].keys]
args = [key.n if type(key) is ast_Num else key.s for key in node.args[0].keys]
else:
# arguments of other functions must be literal number
assert all(isinstance(ast.literal_eval(astor.to_source(arg)), numbers.Real) for arg in node.args), \
Expand All @@ -95,7 +97,7 @@ def visit_Call(self, node): # pylint: disable=invalid-name

key = self.module_name + '/' + name + '/' + func
# store key in ast.Call
node.keywords.append(ast.keyword(arg='key', value=ast.Str(s=key)))
node.keywords.append(ast.keyword(arg='key', value=ast_Str(s=key)))

if func == 'function_choice':
func = 'choice'
Expand Down
22 changes: 12 additions & 10 deletions tools/nni_annotation/specific_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import astor
from nni_cmd.common_utils import print_warning

from .utils import ast_Num, ast_Str

# pylint: disable=unidiomatic-typecheck

para_cfg = None
Expand Down Expand Up @@ -134,7 +136,7 @@ def parse_nni_variable(code):
assert arg.func.value.id == 'nni', 'nni.variable value is not a NNI function'

name_str = astor.to_source(name).strip()
keyword_arg = ast.keyword(arg='name', value=ast.Str(s=name_str))
keyword_arg = ast.keyword(arg='name', value=ast_Str(s=name_str))
arg.keywords.append(keyword_arg)
if arg.func.attr == 'choice':
convert_args_to_dict(arg)
Expand All @@ -152,7 +154,7 @@ def parse_nni_function(code):
convert_args_to_dict(call, with_lambda=True)

name_str = astor.to_source(name).strip()
call.keywords[0].value = ast.Str(s=name_str)
call.keywords[0].value = ast_Str(s=name_str)

return call, funcs

Expand All @@ -163,12 +165,12 @@ def convert_args_to_dict(call, with_lambda=False):
"""
keys, values = list(), list()
for arg in call.args:
if type(arg) in [ast.Str, ast.Num]:
if type(arg) in [ast_Str, ast_Num]:
arg_value = arg
else:
# if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value))
arg_value = ast_Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg
keys.append(arg_value)
values.append(arg)
Expand All @@ -192,7 +194,7 @@ def test_variable_equal(node1, node2):
return False
if isinstance(node1, ast.AST):
for k, v in vars(node1).items():
if k in ('lineno', 'col_offset', 'ctx'):
if k in ('lineno', 'col_offset', 'ctx', 'end_lineno', 'end_col_offset'):
continue
if not test_variable_equal(v, getattr(node2, k)):
return False
Expand Down Expand Up @@ -264,7 +266,7 @@ def visit(self, node):
annotation = self.stack[-1]

# this is a standalone string, may be an annotation
if type(node) is ast.Expr and type(node.value) is ast.Str:
if type(node) is ast.Expr and type(node.value) is ast_Str:
# must not annotate an annotation string
assert annotation is None, 'Annotating an annotation'
return self._visit_string(node)
Expand All @@ -290,23 +292,23 @@ def _visit_string(self, node):
"Please remove this line in the trial code."
print_warning(deprecated_message)
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Get next parameter here...')], keywords=[]))
args=[ast_Str(s='Get next parameter here...')], keywords=[]))

if string.startswith('@nni.training_update'):
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Training update here...')], keywords=[]))
args=[ast_Str(s='Training update here...')], keywords=[]))

if string.startswith('@nni.report_intermediate_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
args=[ast_Str(s='nni.report_intermediate_result: '), arg], keywords=[]))

if string.startswith('@nni.report_final_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))
args=[ast_Str(s='nni.report_final_result: '), arg], keywords=[]))

if string.startswith('@nni.mutable_layers'):
return parse_annotation_mutable_layers(string[1:], node.lineno)
Expand Down
15 changes: 15 additions & 0 deletions tools/nni_annotation/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import ast
from sys import version_info


if version_info >= (3, 8):
ast_Num = ast_Str = ast_Bytes = ast_NameConstant = ast_Ellipsis = ast.Constant
else:
ast_Num = ast.Num
ast_Str = ast.Str
ast_Bytes = ast.Bytes
ast_NameConstant = ast.NameConstant
ast_Ellipsis = ast.Ellipsis

0 comments on commit 1dde669

Please sign in to comment.