Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
extract_args_from_signature)
from ._validators import (
validate_options,
validate_datetime,
validate_duration,
validate_file_destination,
validate_client_parameters,
validate_required_parameter)
Expand Down Expand Up @@ -62,17 +64,86 @@ class BatchArgumentTree(object):
"""Dependency tree parser for arguments of complex objects"""

_class_name = re.compile(r"<(.*?)>") # Strip model name from class docstring
_underscore_case = re.compile('(?!^)([A-Z]+)') # Convert from CamelCase to underscore_case

def __init__(self):
def __init__(self, validator):
self._arg_tree = {}
self._request_param = {}
self._custom_validator = validator
self.done = False

def __iter__(self):
"""Iterate over arguments"""
for arg, details in self._arg_tree.items():
yield arg, details

def queue_argument(self, name=None, path=None, root=None, options=None, dependencies=None):
def _get_children(self, group):
"""Find all the arguments under to a specific complex argument group.
:param str group: The namespace of the complex parameter.
:returns: The names of the related arugments.
"""
return [arg for arg, value in self._arg_tree.items() if value['path'].startswith(group)]

def _get_siblings(self, group):
"""Find all the arguments at the same level of a specific complex argument group.
:param str group: The namespace of the complex parameter.
:returns: The names of the related arugments.
"""
return [arg for arg, value in self._arg_tree.items() if value['path'] == group]

def _parse(self, namespace, path, required):
"""Parse dependency tree to list all required command line arguments based on
current inputs.
:param namespace: The namespace container all current argument inputs
:param path: The current complex object path
:param required: Whether the args in this object path are required
"""
required_args = []
children = self._get_children(path)
if not required:
if not any([getattr(namespace, n) for n in children]):
return []
siblings = self._get_siblings(path)
if not siblings:
raise ValueError("Invalid argmuent dependency tree") # TODO
dependencies = self._arg_tree[siblings[0]]['dependencies']
for child_arg in children:
if child_arg in required_args:
continue
details = self._arg_tree[child_arg]
if '.'.join([details['path'], details['root']]) in dependencies:
required_args.append(child_arg)
elif details['path'] in dependencies:
required_args.extend(self._parse(namespace, details['path'], True))
elif details['path'] == path:
continue
else:
required_args.extend(self._parse(namespace, details['path'], False))
return set(required_args)

def set_request_param(self, name, model):
"""Set the name of the parameter that will be serialized for the
request body.
:param str name: The name of the parameter
:param str model: The name of the class
"""
self._request_param['name'] = name
self._request_param['model'] = model.split('.')[-1]

def deserialize_json(self, client, kwargs, json_obj):
"""Deserialize the contents of a JSON file into the request body
parameter.
:param client: An Azure Batch SDK client
:param dict kwargs: The request kwargs
:param dict json_obj: The loaded JSON content
"""
kwargs[self._request_param['name']] = client._deserialize( #pylint:disable=W0212
self._request_param['model'], json_obj)
if kwargs[self._request_param['name']] is None:
message = "Failed to deserialized JSON file into object {}"
raise ValueError(message.format(self._request_param['model']))

def queue_argument(self, name=None, path=None, root=None, options=None, type=None, dependencies=None):
"""Add pending command line argument
:param str name: The name of the command line argument.
:param str path: The complex object path to the parameter.
Expand All @@ -85,6 +156,7 @@ def queue_argument(self, name=None, path=None, root=None, options=None, dependen
'path': path,
'root': root,
'options': options,
'type': type,
'dependencies': [".".join([path, arg]) for arg in dependencies]
}

Expand All @@ -100,6 +172,10 @@ def compile_args(self):
objects.
"""
for name, details in self._arg_tree.items():
if details['type'] == 'bool':
details['options']['action'] = 'store_true'
elif details['type'].startswith('['):
details['options']['nargs'] = '+'
yield (name, CliCommandArgument(dest=name, **details['options']))

def existing(self, name):
Expand All @@ -117,6 +193,12 @@ def class_name(self, type_str):
"""
return self._class_name.findall(type_str)[0]

def operations_name(self, class_str):
"""Convert the operations class name into Python case.
:param str class_str: The class name.
"""
return self._underscore_case.sub(r'_\1', class_str[:-10]).lower()

def full_name(self, arg_details):
"""Create a full path to the complex object parameter of a
given argument.
Expand All @@ -130,8 +212,11 @@ def group_title(self, path):
:param str path: The complex object path of the argument.
:returns: str
"""
group_name = path.split('.')[-1]
return " ".join([n.title() for n in group_name.split('_')])
group_path = path.split('.')
group_title = ' : '.join(group_path)
for group in group_path:
group_title = group_title.replace(group, " ".join([n.title() for n in group.split('_')]), 1)
return group_title

def arg_name(self, name):
"""Convert snake case argument name to a command line name.
Expand Down Expand Up @@ -174,6 +259,37 @@ def find_return_type(self, model):
if return_type:
return re.sub(r"\n\s*", "", return_type.group(1))

def parse_mutually_exclusive(self, namespace, required, params):
"""Validate whether two or more mutually exclusive arguments or
argument groups have been set correctly.
:param bool required: Whether one of the parameters must be set.
:param list params: List of namespace paths for mutually exclusive
request properties.
"""
argtree = self._arg_tree.items()
ex_arg_names = [a for a, v in argtree if self.full_name(v) in params]
ex_args = [getattr(namespace, a) for a, v in argtree if a in ex_arg_names]
ex_args = list(filter(None, ex_args))
ex_group_names = []
ex_groups = []
for arg_group in params:
child_args = self._get_children(arg_group)
if child_args:
ex_group_names.append(self.group_title(arg_group))
if any([getattr(namespace, arg) for arg in child_args]):
ex_groups.append(ex_group_names[-1])

message = None
if not ex_groups and not ex_args and required:
message = "One of the following arguments, or argument groups are required: \n"
elif len(ex_groups) > 1 or len(ex_args) > 1 or (ex_groups and ex_args):
message = ("The follow arguments or argument groups are mutually "
"exclusive and cannot be combined: \n")
if message:
missing = [self.arg_name(n) for n in ex_arg_names] + ex_group_names
message += '\n'.join(missing)
raise ValueError(message)

def parse(self, namespace):
"""Parse all arguments in the namespace to validate whether all required
arguments have been set.
Expand All @@ -182,49 +298,38 @@ def parse(self, namespace):
"""
try:
if namespace.json_file:
if not os.path.isfile(namespace.json_file):
try:
with open(namespace.json_file) as file_handle:
namespace.json_file = json.load(file_handle)
except EnvironmentError:
raise ValueError("Cannot access JSON request file: " + namespace.json_file)
except ValueError as err:
raise ValueError("Invalid JSON file: {}".format(err))
for name in self._arg_tree:
if getattr(namespace, name):
raise ValueError("--json-file cannot be combined with " + self.arg_name(name))
return
except AttributeError:
pass

for name, details in self._arg_tree.items():
if not getattr(namespace, name):
continue
dependencies = details['dependencies']
siblings = [arg for arg, value in self._arg_tree.items() \
if self.full_name(value) in dependencies]
for arg in self.find_complex_dependencies(dependencies):
siblings.append(arg)
for arg in siblings:
if not getattr(namespace, arg):
arg_name = self.arg_name(name)
arg_group = self.group_title(self._arg_tree[name]['path'])
required_arg = self.arg_name(arg)
message = "When using {} of {}, the following is also required: {}".format(
arg_name, arg_group, required_arg)
raise ValueError(message)
if self._custom_validator:
try:
self._custom_validator(namespace, self)
except TypeError:
raise ValueError("Custom validator must be a function that takes two arguments.")

required_args = self._parse(namespace, self._request_param['name'], True)
missing_args = [n for n in required_args if not getattr(namespace, n)]
if missing_args:
message = "The following additional arguments are required:\n"
message += "\n".join([self.arg_name(m) for m in missing_args])
raise ValueError(message)
self.done = True

def find_complex_dependencies(self, dependencies):
"""Recursive generator to find required argments from dependent
complect objects.
:param list dependencies: A list of the dependencies of the current object.
:returns: The names of the required arguments.
"""
cmplx_args = [arg for arg, value in self._arg_tree.items() if value['path'] in dependencies]
for arg in cmplx_args:
for sub_arg in self.find_complex_dependencies(self._arg_tree[arg]['dependencies']):
yield sub_arg
yield arg

class AzureDataPlaneCommand(object):

def __init__(self, module_name, name, operation, factory, transform_result, #pylint:disable=too-many-arguments
table_transformer, flatten, ignore):
table_transformer, flatten, ignore, validator):

if not isinstance(operation, string_types):
raise ValueError("Operation must be a string. Got '{}'".format(operation))
Expand All @@ -233,7 +338,7 @@ def __init__(self, module_name, name, operation, factory, transform_result, #pyl
self.ignore = list(IGNORE_PARAMETERS) # Parameters to ignore
if ignore:
self.ignore.extend(ignore)
self.parser = BatchArgumentTree()
self.parser = BatchArgumentTree(validator)

# The name of the request options parameter
self._options_param = self._format_options_name(operation)
Expand All @@ -243,8 +348,6 @@ def __init__(self, module_name, name, operation, factory, transform_result, #pyl
self._options_attrs = []
# The loaded options model to populate for the request
self._options_model = None
# The parameter and model of the request body
self._request_body = None

def _execute_command(kwargs):
from msrest.paging import Paged
Expand All @@ -261,10 +364,7 @@ def _execute_command(kwargs):
if json_file:
with open(json_file) as file_handle:
json_obj = json.load(file_handle)
kwargs[self._request_body['name']] = client._deserialize( #pylint:disable=W0212
self._request_body['model'], json_obj)
if kwargs[self._request_body['name']] is None:
raise ValueError("JSON file '{}' is not in correct format.".format(json_file))
self.parser.deserialize_json(client, kwargs, json_obj)
for arg, _ in self.parser:
del kwargs[arg]
else:
Expand Down Expand Up @@ -389,7 +489,8 @@ def _format_options_name(self, operation):
"""
operation = operation.split('#')[-1]
op_class, op_function = operation.split('.')
return "{}_{}_options".format(op_class[:-10].lower(), op_function)
op_class = self.parser.operations_name(op_class)
return "{}_{}_options".format(op_class, op_function)

def _should_flatten(self, param):
"""Check whether the current parameter object should be flattened.
Expand Down Expand Up @@ -421,6 +522,8 @@ def _build_prefix(self, arg, param, path):
:param str path: Request parameter namespace.
"""
prefix_list = path.split('.')
if len(prefix_list) == 1:
return arg
resolved_name = prefix_list[0] + "_" + param
if arg == resolved_name:
return arg
Expand Down Expand Up @@ -452,13 +555,14 @@ def _process_options(self):
help=docstring,
arg_group=self._options_group))

def _resolve_conflict(self, arg, param, path, options, dependencies, conflicting):
def _resolve_conflict(self, arg, param, path, options, typestr, dependencies, conflicting):
"""Resolve conflicting command line arguments.
:param str arg: Name of the command line argument.
:param str param: Original request parameter name.
:param str path: Request parameter namespace.
:param dict options: The kwargs to be used to instantiate CliCommandArgument.
:param list dependencies: A list of complete paths to other parameters that.
:param list dependencies: A list of complete paths to other parameters that are required
if this parameter is set.
:param list conflicting: A list of the argument names that have already conflicted.
"""
if self.parser.existing(arg):
Expand All @@ -469,13 +573,16 @@ def _resolve_conflict(self, arg, param, path, options, dependencies, conflicting
self.parser.queue_argument(**existing)
new = self._build_prefix(arg, param, path)
options['options_list'] = [self.parser.arg_name(new)]
self._resolve_conflict(new, param, path, options, dependencies, conflicting)
self._resolve_conflict(new, param, path, options, typestr, dependencies, conflicting)
elif arg in conflicting:
new = self._build_prefix(arg, param, path)
options['options_list'] = [self.parser.arg_name(new)]
self._resolve_conflict(new, param, path, options, dependencies, conflicting)
if new in conflicting:
self.parser.queue_argument(arg, path, param, options, typestr, dependencies)
else:
options['options_list'] = [self.parser.arg_name(new)]
self._resolve_conflict(new, param, path, options, typestr, dependencies, conflicting)
else:
self.parser.queue_argument(arg, path, param, options, dependencies)
self.parser.queue_argument(arg, path, param, options, typestr, dependencies)

def _flatten_object(self, path, param_model, conflict_names=[]): #pylint: disable=W0102
"""Flatten a complex parameter object into command line arguments.
Expand All @@ -496,18 +603,14 @@ def _flatten_object(self, path, param_model, conflict_names=[]): #pylint: disabl
options['validator'] = lambda ns: validate_required_parameter(ns, self.parser)
options['default'] = None # Extract details from signature

if details['type'] == 'bool':
options['action'] = 'store_true'
self._resolve_conflict(param_attr, param_attr, path, options,
required_attrs, conflict_names)
elif details['type'] in BASIC_TYPES:
if details['type'] in BASIC_TYPES:
self._resolve_conflict(param_attr, param_attr, path, options,
required_attrs, conflict_names)
details['type'], required_attrs, conflict_names)
else:
attr_model = self._load_model(details['type'])
if not hasattr(attr_model, '_attribute_map'): # Must be an enum
self._resolve_conflict(param_attr, param_attr, path, options,
required_attrs, conflict_names)
details['type'], required_attrs, conflict_names)
else:
self._flatten_object('.'.join([path, param_attr]), attr_model)

Expand All @@ -524,7 +627,7 @@ def _load_transformed_arguments(self, operation, handler):
yield option_arg
elif arg_type.startswith(":class:"): # TODO: could add handling for enums
param_type = self.parser.class_name(arg_type)
self._request_body = {'name': arg[0], 'model': param_type.split('.')[-1]}
self.parser.set_request_param(arg[0], param_type)
param_model = self._load_model(param_type)
self._flatten_object(arg[0], param_model)
for flattened_arg in self.parser.compile_args():
Expand Down Expand Up @@ -555,12 +658,11 @@ def _load_transformed_arguments(self, operation, handler):


def cli_data_plane_command(name, operation, client_factory, transform=None,
table_transformer=None, flatten=FLATTEN, ignore=None):
table_transformer=None, flatten=FLATTEN, ignore=None, validator=None):
""" Registers an Azure CLI Batch Data Plane command. These commands must respond to a
challenge from the service when they make requests. """

command = AzureDataPlaneCommand(__name__, name, operation, client_factory,
transform, table_transformer, flatten, ignore)
transform, table_transformer, flatten, ignore, validator)

# add parameters required to create a batch client
group_name = 'Batch Account'
Expand Down
Loading