diff --git a/src/command_modules/azure-cli-batch/azure/cli/command_modules/batch/_command_type.py b/src/command_modules/azure-cli-batch/azure/cli/command_modules/batch/_command_type.py index 1db27993901..5df34c2f459 100644 --- a/src/command_modules/azure-cli-batch/azure/cli/command_modules/batch/_command_type.py +++ b/src/command_modules/azure-cli-batch/azure/cli/command_modules/batch/_command_type.py @@ -26,6 +26,8 @@ extract_args_from_signature) from ._validators import ( validate_options, + validate_datetime, + validate_duration, validate_file_destination, validate_client_parameters, validate_required_parameter) @@ -62,9 +64,12 @@ class BatchArgumentTree(object): """Dependency tree parser for arguments of complex objects""" _class_name = re.compile(r"<(.*?)>") # Strip model name from class docstring + _underscore_case = re.compile('(?!^)([A-Z]+)') # Convert from CamelCase to underscore_case - def __init__(self): + def __init__(self, validator): self._arg_tree = {} + self._request_param = {} + self._custom_validator = validator self.done = False def __iter__(self): @@ -72,7 +77,73 @@ def __iter__(self): for arg, details in self._arg_tree.items(): yield arg, details - def queue_argument(self, name=None, path=None, root=None, options=None, dependencies=None): + def _get_children(self, group): + """Find all the arguments under to a specific complex argument group. + :param str group: The namespace of the complex parameter. + :returns: The names of the related arugments. + """ + return [arg for arg, value in self._arg_tree.items() if value['path'].startswith(group)] + + def _get_siblings(self, group): + """Find all the arguments at the same level of a specific complex argument group. + :param str group: The namespace of the complex parameter. + :returns: The names of the related arugments. + """ + return [arg for arg, value in self._arg_tree.items() if value['path'] == group] + + def _parse(self, namespace, path, required): + """Parse dependency tree to list all required command line arguments based on + current inputs. + :param namespace: The namespace container all current argument inputs + :param path: The current complex object path + :param required: Whether the args in this object path are required + """ + required_args = [] + children = self._get_children(path) + if not required: + if not any([getattr(namespace, n) for n in children]): + return [] + siblings = self._get_siblings(path) + if not siblings: + raise ValueError("Invalid argmuent dependency tree") # TODO + dependencies = self._arg_tree[siblings[0]]['dependencies'] + for child_arg in children: + if child_arg in required_args: + continue + details = self._arg_tree[child_arg] + if '.'.join([details['path'], details['root']]) in dependencies: + required_args.append(child_arg) + elif details['path'] in dependencies: + required_args.extend(self._parse(namespace, details['path'], True)) + elif details['path'] == path: + continue + else: + required_args.extend(self._parse(namespace, details['path'], False)) + return set(required_args) + + def set_request_param(self, name, model): + """Set the name of the parameter that will be serialized for the + request body. + :param str name: The name of the parameter + :param str model: The name of the class + """ + self._request_param['name'] = name + self._request_param['model'] = model.split('.')[-1] + + def deserialize_json(self, client, kwargs, json_obj): + """Deserialize the contents of a JSON file into the request body + parameter. + :param client: An Azure Batch SDK client + :param dict kwargs: The request kwargs + :param dict json_obj: The loaded JSON content + """ + kwargs[self._request_param['name']] = client._deserialize( #pylint:disable=W0212 + self._request_param['model'], json_obj) + if kwargs[self._request_param['name']] is None: + message = "Failed to deserialized JSON file into object {}" + raise ValueError(message.format(self._request_param['model'])) + + def queue_argument(self, name=None, path=None, root=None, options=None, type=None, dependencies=None): """Add pending command line argument :param str name: The name of the command line argument. :param str path: The complex object path to the parameter. @@ -85,6 +156,7 @@ def queue_argument(self, name=None, path=None, root=None, options=None, dependen 'path': path, 'root': root, 'options': options, + 'type': type, 'dependencies': [".".join([path, arg]) for arg in dependencies] } @@ -100,6 +172,10 @@ def compile_args(self): objects. """ for name, details in self._arg_tree.items(): + if details['type'] == 'bool': + details['options']['action'] = 'store_true' + elif details['type'].startswith('['): + details['options']['nargs'] = '+' yield (name, CliCommandArgument(dest=name, **details['options'])) def existing(self, name): @@ -117,6 +193,12 @@ def class_name(self, type_str): """ return self._class_name.findall(type_str)[0] + def operations_name(self, class_str): + """Convert the operations class name into Python case. + :param str class_str: The class name. + """ + return self._underscore_case.sub(r'_\1', class_str[:-10]).lower() + def full_name(self, arg_details): """Create a full path to the complex object parameter of a given argument. @@ -130,8 +212,11 @@ def group_title(self, path): :param str path: The complex object path of the argument. :returns: str """ - group_name = path.split('.')[-1] - return " ".join([n.title() for n in group_name.split('_')]) + group_path = path.split('.') + group_title = ' : '.join(group_path) + for group in group_path: + group_title = group_title.replace(group, " ".join([n.title() for n in group.split('_')]), 1) + return group_title def arg_name(self, name): """Convert snake case argument name to a command line name. @@ -174,6 +259,37 @@ def find_return_type(self, model): if return_type: return re.sub(r"\n\s*", "", return_type.group(1)) + def parse_mutually_exclusive(self, namespace, required, params): + """Validate whether two or more mutually exclusive arguments or + argument groups have been set correctly. + :param bool required: Whether one of the parameters must be set. + :param list params: List of namespace paths for mutually exclusive + request properties. + """ + argtree = self._arg_tree.items() + ex_arg_names = [a for a, v in argtree if self.full_name(v) in params] + ex_args = [getattr(namespace, a) for a, v in argtree if a in ex_arg_names] + ex_args = list(filter(None, ex_args)) + ex_group_names = [] + ex_groups = [] + for arg_group in params: + child_args = self._get_children(arg_group) + if child_args: + ex_group_names.append(self.group_title(arg_group)) + if any([getattr(namespace, arg) for arg in child_args]): + ex_groups.append(ex_group_names[-1]) + + message = None + if not ex_groups and not ex_args and required: + message = "One of the following arguments, or argument groups are required: \n" + elif len(ex_groups) > 1 or len(ex_args) > 1 or (ex_groups and ex_args): + message = ("The follow arguments or argument groups are mutually " + "exclusive and cannot be combined: \n") + if message: + missing = [self.arg_name(n) for n in ex_arg_names] + ex_group_names + message += '\n'.join(missing) + raise ValueError(message) + def parse(self, namespace): """Parse all arguments in the namespace to validate whether all required arguments have been set. @@ -182,49 +298,38 @@ def parse(self, namespace): """ try: if namespace.json_file: - if not os.path.isfile(namespace.json_file): + try: + with open(namespace.json_file) as file_handle: + namespace.json_file = json.load(file_handle) + except EnvironmentError: raise ValueError("Cannot access JSON request file: " + namespace.json_file) + except ValueError as err: + raise ValueError("Invalid JSON file: {}".format(err)) for name in self._arg_tree: if getattr(namespace, name): raise ValueError("--json-file cannot be combined with " + self.arg_name(name)) return except AttributeError: pass - - for name, details in self._arg_tree.items(): - if not getattr(namespace, name): - continue - dependencies = details['dependencies'] - siblings = [arg for arg, value in self._arg_tree.items() \ - if self.full_name(value) in dependencies] - for arg in self.find_complex_dependencies(dependencies): - siblings.append(arg) - for arg in siblings: - if not getattr(namespace, arg): - arg_name = self.arg_name(name) - arg_group = self.group_title(self._arg_tree[name]['path']) - required_arg = self.arg_name(arg) - message = "When using {} of {}, the following is also required: {}".format( - arg_name, arg_group, required_arg) - raise ValueError(message) + if self._custom_validator: + try: + self._custom_validator(namespace, self) + except TypeError: + raise ValueError("Custom validator must be a function that takes two arguments.") + + required_args = self._parse(namespace, self._request_param['name'], True) + missing_args = [n for n in required_args if not getattr(namespace, n)] + if missing_args: + message = "The following additional arguments are required:\n" + message += "\n".join([self.arg_name(m) for m in missing_args]) + raise ValueError(message) self.done = True - def find_complex_dependencies(self, dependencies): - """Recursive generator to find required argments from dependent - complect objects. - :param list dependencies: A list of the dependencies of the current object. - :returns: The names of the required arguments. - """ - cmplx_args = [arg for arg, value in self._arg_tree.items() if value['path'] in dependencies] - for arg in cmplx_args: - for sub_arg in self.find_complex_dependencies(self._arg_tree[arg]['dependencies']): - yield sub_arg - yield arg class AzureDataPlaneCommand(object): def __init__(self, module_name, name, operation, factory, transform_result, #pylint:disable=too-many-arguments - table_transformer, flatten, ignore): + table_transformer, flatten, ignore, validator): if not isinstance(operation, string_types): raise ValueError("Operation must be a string. Got '{}'".format(operation)) @@ -233,7 +338,7 @@ def __init__(self, module_name, name, operation, factory, transform_result, #pyl self.ignore = list(IGNORE_PARAMETERS) # Parameters to ignore if ignore: self.ignore.extend(ignore) - self.parser = BatchArgumentTree() + self.parser = BatchArgumentTree(validator) # The name of the request options parameter self._options_param = self._format_options_name(operation) @@ -243,8 +348,6 @@ def __init__(self, module_name, name, operation, factory, transform_result, #pyl self._options_attrs = [] # The loaded options model to populate for the request self._options_model = None - # The parameter and model of the request body - self._request_body = None def _execute_command(kwargs): from msrest.paging import Paged @@ -261,10 +364,7 @@ def _execute_command(kwargs): if json_file: with open(json_file) as file_handle: json_obj = json.load(file_handle) - kwargs[self._request_body['name']] = client._deserialize( #pylint:disable=W0212 - self._request_body['model'], json_obj) - if kwargs[self._request_body['name']] is None: - raise ValueError("JSON file '{}' is not in correct format.".format(json_file)) + self.parser.deserialize_json(client, kwargs, json_obj) for arg, _ in self.parser: del kwargs[arg] else: @@ -389,7 +489,8 @@ def _format_options_name(self, operation): """ operation = operation.split('#')[-1] op_class, op_function = operation.split('.') - return "{}_{}_options".format(op_class[:-10].lower(), op_function) + op_class = self.parser.operations_name(op_class) + return "{}_{}_options".format(op_class, op_function) def _should_flatten(self, param): """Check whether the current parameter object should be flattened. @@ -421,6 +522,8 @@ def _build_prefix(self, arg, param, path): :param str path: Request parameter namespace. """ prefix_list = path.split('.') + if len(prefix_list) == 1: + return arg resolved_name = prefix_list[0] + "_" + param if arg == resolved_name: return arg @@ -452,13 +555,14 @@ def _process_options(self): help=docstring, arg_group=self._options_group)) - def _resolve_conflict(self, arg, param, path, options, dependencies, conflicting): + def _resolve_conflict(self, arg, param, path, options, typestr, dependencies, conflicting): """Resolve conflicting command line arguments. :param str arg: Name of the command line argument. :param str param: Original request parameter name. :param str path: Request parameter namespace. :param dict options: The kwargs to be used to instantiate CliCommandArgument. - :param list dependencies: A list of complete paths to other parameters that. + :param list dependencies: A list of complete paths to other parameters that are required + if this parameter is set. :param list conflicting: A list of the argument names that have already conflicted. """ if self.parser.existing(arg): @@ -469,13 +573,16 @@ def _resolve_conflict(self, arg, param, path, options, dependencies, conflicting self.parser.queue_argument(**existing) new = self._build_prefix(arg, param, path) options['options_list'] = [self.parser.arg_name(new)] - self._resolve_conflict(new, param, path, options, dependencies, conflicting) + self._resolve_conflict(new, param, path, options, typestr, dependencies, conflicting) elif arg in conflicting: new = self._build_prefix(arg, param, path) - options['options_list'] = [self.parser.arg_name(new)] - self._resolve_conflict(new, param, path, options, dependencies, conflicting) + if new in conflicting: + self.parser.queue_argument(arg, path, param, options, typestr, dependencies) + else: + options['options_list'] = [self.parser.arg_name(new)] + self._resolve_conflict(new, param, path, options, typestr, dependencies, conflicting) else: - self.parser.queue_argument(arg, path, param, options, dependencies) + self.parser.queue_argument(arg, path, param, options, typestr, dependencies) def _flatten_object(self, path, param_model, conflict_names=[]): #pylint: disable=W0102 """Flatten a complex parameter object into command line arguments. @@ -496,18 +603,14 @@ def _flatten_object(self, path, param_model, conflict_names=[]): #pylint: disabl options['validator'] = lambda ns: validate_required_parameter(ns, self.parser) options['default'] = None # Extract details from signature - if details['type'] == 'bool': - options['action'] = 'store_true' - self._resolve_conflict(param_attr, param_attr, path, options, - required_attrs, conflict_names) - elif details['type'] in BASIC_TYPES: + if details['type'] in BASIC_TYPES: self._resolve_conflict(param_attr, param_attr, path, options, - required_attrs, conflict_names) + details['type'], required_attrs, conflict_names) else: attr_model = self._load_model(details['type']) if not hasattr(attr_model, '_attribute_map'): # Must be an enum self._resolve_conflict(param_attr, param_attr, path, options, - required_attrs, conflict_names) + details['type'], required_attrs, conflict_names) else: self._flatten_object('.'.join([path, param_attr]), attr_model) @@ -524,7 +627,7 @@ def _load_transformed_arguments(self, operation, handler): yield option_arg elif arg_type.startswith(":class:"): # TODO: could add handling for enums param_type = self.parser.class_name(arg_type) - self._request_body = {'name': arg[0], 'model': param_type.split('.')[-1]} + self.parser.set_request_param(arg[0], param_type) param_model = self._load_model(param_type) self._flatten_object(arg[0], param_model) for flattened_arg in self.parser.compile_args(): @@ -555,12 +658,11 @@ def _load_transformed_arguments(self, operation, handler): def cli_data_plane_command(name, operation, client_factory, transform=None, - table_transformer=None, flatten=FLATTEN, ignore=None): + table_transformer=None, flatten=FLATTEN, ignore=None, validator=None): """ Registers an Azure CLI Batch Data Plane command. These commands must respond to a challenge from the service when they make requests. """ - command = AzureDataPlaneCommand(__name__, name, operation, client_factory, - transform, table_transformer, flatten, ignore) + transform, table_transformer, flatten, ignore, validator) # add parameters required to create a batch client group_name = 'Batch Account' diff --git a/src/command_modules/azure-cli-batch/azure/cli/command_modules/batch/_validators.py b/src/command_modules/azure-cli-batch/azure/cli/command_modules/batch/_validators.py index c3cddd9635d..d67802df4e4 100644 --- a/src/command_modules/azure-cli-batch/azure/cli/command_modules/batch/_validators.py +++ b/src/command_modules/azure-cli-batch/azure/cli/command_modules/batch/_validators.py @@ -11,6 +11,9 @@ from urlparse import urlsplit # pylint: disable=import-error from datetime import datetime +from msrest.serialization import Deserializer +from msrest.exceptions import DeserializationError + from azure.mgmt.batch import BatchManagementClient from azure.mgmt.storage import StorageManagementClient @@ -25,13 +28,38 @@ def datetime_type(string): date_format = '%Y-%m-%dT%H:%MZ' return datetime.strptime(string, date_format) + +def validate_datetime(namespace, name, parser): + """Validate the correct format of a datetime string and deserialize.""" + date_str = getattr(namespace, name) + if date_str and isinstance(date_str, str): + try: + date_obj = Deserializer.deserialize_iso(date_str) + except DeserializationError: + message = "Argument {} is not a valid ISO-8601 datetime format" + raise ValueError(message.format(name)) + else: + setattr(namespace, name, date_obj) + validate_required_parameter(namespace, parser) + + +def validate_duration(name, value): + """Validate the correct format of a timespan string and deserilize.""" + try: + value = Deserializer.deserialize_duration(value) + except DeserializationError: + message = "Argument {} is not in a valid ISO-8601 duration format" + raise ValueError(message.format(name)) + else: + return value + + def validate_metadata(namespace): if namespace.metadata: namespace.metadata = dict(x.split('=', 1) for x in namespace.metadata) # COMMAND NAMESPACE VALIDATORS - def validate_required_parameter(ns, parser): """Validates required parameters in Batch complex objects""" if not parser.done: @@ -71,7 +99,7 @@ def validate_options(namespace): start = namespace.start_range end = namespace.end_range except AttributeError: - return + pass else: namespace.ocp_range = None del namespace.start_range @@ -80,6 +108,18 @@ def validate_options(namespace): start = start if start else 0 end = end if end else "" namespace.ocp_range = "bytes={}-{}".format(start, end) + for date_arg in ['if_modified_since', 'if_unmodified_since']: # TODO: Should we also try RFC-1123? + try: + date_str = getattr(namespace, date_arg) + if date_str and isinstance(date_str, str): + date_obj = Deserializer.deserialize_iso(date_str) + except AttributeError: + pass + except DeserializationError: + message = "Argument {} is not a valid ISO-8601 datetime format" + raise ValueError(message.format(date_arg)) + else: + setattr(namespace, name, date_obj) def validate_file_destination(namespace): @@ -135,3 +175,12 @@ def validate_client_parameters(namespace): raise ValueError("Need specifiy batch account in command line or enviroment variable.") if not namespace.account_endpoint: raise ValueError("Need specifiy batch endpoint in command line or enviroment variable.") + +# CUSTOM REQUEST VALIDATORS + +def validate_pool_settings(namespace, parser): + """Custom parsing to enfore that either PaaS or IaaS instances are configured + in the add pool request body. + """ + groups = ['pool.cloud_service_configuration', 'pool.virtual_machine_configuration'] + parser.parse_mutually_exclusive(namespace, True, groups) \ No newline at end of file diff --git a/src/command_modules/azure-cli-batch/azure/cli/command_modules/batch/commands.py b/src/command_modules/azure-cli-batch/azure/cli/command_modules/batch/commands.py index f07d446c59c..426d036cf1c 100644 --- a/src/command_modules/azure-cli-batch/azure/cli/command_modules/batch/commands.py +++ b/src/command_modules/azure-cli-batch/azure/cli/command_modules/batch/commands.py @@ -7,6 +7,8 @@ from ._client_factory import batch_client_factory, batch_data_service_factory from ._command_type import cli_data_plane_command, cli_custom_data_plane_command +from ._validators import ( + validate_pool_settings) data_path = 'azure.batch.operations.{}_operations#{}' custom_path = 'azure.cli.command_modules.batch.custom#{}' @@ -62,8 +64,8 @@ factory = lambda args: batch_data_service_factory(**args).pool cli_data_plane_command('batch pool usage-metrics list', data_path.format('pool', 'PoolOperations.list_pool_usage_metrics'), factory) cli_data_plane_command('batch pool all-stats show', data_path.format('pool', 'PoolOperations.get_all_pools_lifetime_statistics'), factory) -#cli_data_plane_command('batch pool create', data_path.format('pool', 'PoolOperations.add'), factory) -cli_custom_data_plane_command('batch pool create', custom_path.format('create_pool'), factory) +cli_data_plane_command('batch pool create', data_path.format('pool', 'PoolOperations.add'), factory, validator=validate_pool_settings) +#cli_custom_data_plane_command('batch pool create', custom_path.format('create_pool'), factory) cli_data_plane_command('batch pool list', data_path.format('pool', 'PoolOperations.list'), factory) cli_data_plane_command('batch pool delete', data_path.format('pool', 'PoolOperations.delete'), factory) cli_data_plane_command('batch pool show', data_path.format('pool', 'PoolOperations.get'), factory) @@ -81,13 +83,13 @@ factory = lambda args: batch_data_service_factory(**args).job cli_data_plane_command('batch job all-stats show', data_path.format('job', 'JobOperations.get_all_jobs_lifetime_statistics'), factory) -#cli_data_plane_command('batch job create', data_path.format('job', 'JobOperations.add'), factory)#, ignore=["job.job_release_task", "job.job_preparation_task", "job.job_manager_task"]) +cli_data_plane_command('batch job create', data_path.format('job', 'JobOperations.add'), factory)#, ignore=["job.job_release_task", "job.job_preparation_task", "job.job_manager_task"]) #cli_data_plane_command('batch job list', data_path.format('job', 'JobOperations.list'), factory) cli_data_plane_command('batch job delete', data_path.format('job', 'JobOperations.delete'), factory) cli_data_plane_command('batch job show', data_path.format('job', 'JobOperations.get'), factory) #cli_data_plane_command('batch job set', data_path.format('job', 'JobOperations.patch'), factory) #cli_data_plane_command('batch job reset', data_path.format('job', 'JobOperations.update'), factory) -cli_custom_data_plane_command('batch job create', custom_path.format('create_job'), factory) +#cli_custom_data_plane_command('batch job create', custom_path.format('create_job'), factory) cli_custom_data_plane_command('batch job list', custom_path.format('list_job'), factory) cli_custom_data_plane_command('batch job set', custom_path.format('patch_job'), factory) cli_custom_data_plane_command('batch job reset', custom_path.format('update_job'), factory) @@ -97,8 +99,8 @@ cli_data_plane_command('batch job prep-release-status list', data_path.format('job', 'JobOperations.list_preparation_and_release_task_status'), factory) factory = lambda args: batch_data_service_factory(**args).job_schedule -#cli_data_plane_command('batch job-schedule create', data_path.format('job_schedule', 'JobScheduleOperations.add'), factory) -cli_custom_data_plane_command('batch job-schedule create', custom_path.format('create_job_schedule'), factory) +cli_data_plane_command('batch job-schedule create', data_path.format('job_schedule', 'JobScheduleOperations.add'), factory) +#cli_custom_data_plane_command('batch job-schedule create', custom_path.format('create_job_schedule'), factory) cli_data_plane_command('batch job-schedule delete', data_path.format('job_schedule', 'JobScheduleOperations.delete'), factory) cli_data_plane_command('batch job-schedule show', data_path.format('job_schedule', 'JobScheduleOperations.get'), factory) #cli_data_plane_command('batch job-schedule set', data_path.format('job_schedule', 'JobScheduleOperations.patch'), factory)