diff --git a/.gitignore b/.gitignore index af1e99aeed..a00d15ddba 100644 --- a/.gitignore +++ b/.gitignore @@ -307,6 +307,11 @@ env.bak/ venv.bak/ venv-update-reproducible-requirements/ +env.*/ +venv.*/ +.env.*/ +.venv.*/ + # Spyder project settings .spyderproject .spyproject diff --git a/appveyor.yml b/appveyor.yml index 212e43fc32..ec23fb2502 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -189,9 +189,17 @@ for: - "pytest -n 4 tests/functional" # Runs only in Linux, logging Public ECR when running canary and cred is available + - sh: " + if [[ -n $BY_CANARY ]] && [[ -n $DOCKER_USER ]] && [[ -n $DOCKER_PASS ]]; + then echo Logging in Docker Hub; echo $DOCKER_PASS | docker login --username $DOCKER_USER --password-stdin registry-1.docker.io; + fi" - sh: " if [[ -n $BY_CANARY ]]; - then echo Logging in Public ECR; aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws; + then echo Logging in Public ECR; aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws; + fi" + - sh: " + if [[ -n $BY_CANARY ]] || [[ -n $PRIVATE ]]; + then echo Logging in Public ECR; aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws; fi" - sh: "pytest -vv tests/integration" diff --git a/installer/pyinstaller/hook-samcli.py b/installer/pyinstaller/hook-samcli.py index 1855250020..a8e0ace3e8 100644 --- a/installer/pyinstaller/hook-samcli.py +++ b/installer/pyinstaller/hook-samcli.py @@ -3,6 +3,25 @@ hiddenimports = SAM_CLI_HIDDEN_IMPORTS +hiddenimports = [ + "cookiecutter.extensions", + "jinja2_time", + "text_unidecode", + "samtranslator", + "samcli.commands.init", + "samcli.commands.validate.validate", + "samcli.commands.build", + "samcli.commands.local.local", + "samcli.commands.package", + "samcli.commands.deploy", + "samcli.commands.logs", + "samcli.commands.publish", + # the default hidden import 'pkg_resources.py2_warn' is added + # since PyInstaller 4.0. + "pkg_resources.py2_warn", + "aws_lambda_builders.workflows", + "configparser", +] datas = ( hooks.collect_data_files("samcli") + hooks.collect_data_files("samtranslator") diff --git a/requirements/base.txt b/requirements/base.txt index 1fb4e5965f..38614f9bfa 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -15,3 +15,8 @@ serverlessrepo==0.1.10 aws_lambda_builders==1.7.0 tomlkit==0.7.2 watchdog==2.1.2 + +# Needed for supporting Protocol in Python 3.6 +typing_extensions==3.10.0.0 +# Needed for supporting the dataclasses decorator in Python 3.6 +dataclasses==0.8; python_version < '3.7' diff --git a/samcli/cli/command.py b/samcli/cli/command.py index 0741cda84f..c0465db511 100644 --- a/samcli/cli/command.py +++ b/samcli/cli/command.py @@ -22,6 +22,8 @@ "samcli.commands.delete", "samcli.commands.logs", "samcli.commands.publish", + "samcli.commands.traces", + "samcli.commands.sync", "samcli.commands.pipeline.pipeline", # We intentionally do not expose the `bootstrap` command for now.
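The two pins added to requirements/base.txt above backport newer typing features to Python 3.6. A minimal sketch of what they enable; the names here are illustrative and not part of this PR:

from dataclasses import dataclass       # stdlib from Python 3.7; the "dataclasses==0.8" pin covers 3.6
from typing_extensions import Protocol  # typing.Protocol itself only landed in Python 3.8


class SupportsSync(Protocol):
    # any object with a matching sync() method satisfies this interface structurally
    def sync(self) -> None: ...


@dataclass
class StackSyncer:
    stack_name: str

    def sync(self) -> None:
        print(f"syncing {self.stack_name}")


def run(syncer: SupportsSync) -> None:
    syncer.sync()


run(StackSyncer("sam-app"))
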
We might open it up later # "samcli.commands.bootstrap", diff --git a/samcli/commands/_utils/options.py b/samcli/commands/_utils/options.py index 70ca59657c..67f6364662 100644 --- a/samcli/commands/_utils/options.py +++ b/samcli/commands/_utils/options.py @@ -5,21 +5,69 @@ import os import logging from functools import partial +import types import click from click.types import FuncParamType from samcli.commands._utils.template import get_template_data, TemplateNotFoundException -from samcli.cli.types import CfnParameterOverridesType, CfnMetadataType, CfnTags, SigningProfilesOptionType +from samcli.cli.types import ( + CfnParameterOverridesType, + CfnMetadataType, + CfnTags, + SigningProfilesOptionType, + ImageRepositoryType, + ImageRepositoriesType, +) from samcli.commands._utils.custom_options.option_nargs import OptionNargs from samcli.commands._utils.template import get_template_artifacts_format +from samcli.lib.utils.packagetype import ZIP, IMAGE _TEMPLATE_OPTION_DEFAULT_VALUE = "template.[yaml|yml|json]" DEFAULT_STACK_NAME = "sam-app" +DEFAULT_BUILD_DIR = os.path.join(".aws-sam", "build") +DEFAULT_CACHE_DIR = os.path.join(".aws-sam", "cache") LOG = logging.getLogger(__name__) +def parameterized_option(option): + """Meta decorator for option decorators. + This adds the ability to specify optional parameters for option decorators. + + Usage: + @parameterized_option + def some_option(f, required=False) + ... + + @some_option + def command(...) + + or + + @some_option(required=True) + def command(...) + """ + + def parameter_wrapper(*args, **kwargs): + if len(args) == 1 and isinstance(args[0], types.FunctionType): + # Case when option decorator does not have parameter + # @stack_name_option + # def command(...) + return option(args[0]) + + # Case when option decorator does have parameter + # @stack_name_option("a", "b") + # def command(...) + + def option_wrapper(f): + return option(f, *args, **kwargs) + + return option_wrapper + + return parameter_wrapper + + def get_or_default_template_file_name(ctx, param, provided_value, include_build): """ Default value for the template file name option is more complex than what Click can handle. @@ -296,6 +344,46 @@ def signing_profiles_option(f): return signing_profiles_click_option()(f) +def common_observability_click_options(): + return [ + click.option( + "--start-time", + "-s", + default="10m ago", + help="Fetch events starting at this time. Time can be relative values like '5mins ago', 'yesterday' or " + "formatted timestamp like '2018-01-01 10:10:10'. Defaults to '10mins ago'.", + ), + click.option( + "--end-time", + "-e", + default=None, + help="Fetch events up to this time. Time can be relative values like '5mins ago', 'tomorrow' or " + "formatted timestamp like '2018-01-01 10:10:10'", + ), + click.option( + "--tail", + "-t", + is_flag=True, + help="Tail events. This will ignore the end time argument and continue to fetch events as they " + "become available.", + ), + click.option( + "--unformatted", + "-u", + is_flag=True, + help="Print events without any text formatting in JSON. 
This option might be useful if you are reading " + "output into another tool.", + ), + ] + + +def common_observability_options(f): + for option in common_observability_click_options(): + option(f) + + return f + + def metadata_click_option(): return click.option( "--metadata", @@ -304,31 +392,33 @@ ) -def metadata_override_option(f): +def metadata_option(f): return metadata_click_option()(f) -def capabilities_click_option(): +def capabilities_click_option(default): return click.option( "--capabilities", cls=OptionNargs, required=False, + default=default, type=FuncParamType(func=_space_separated_list_func_type), - help="A list of capabilities that you must specify" "before AWS Cloudformation can create certain stacks. Some stack tem-" "plates might include resources that can affect permissions in your AWS" "account, for example, by creating new AWS Identity and Access Manage-" "ment (IAM) users. For those stacks, you must explicitly acknowledge" "their capabilities by specifying this parameter. The only valid values" "are CAPABILITY_IAM and CAPABILITY_NAMED_IAM. If you have IAM resources," "you can specify either capability. If you have IAM resources with cus-" "tom names, you must specify CAPABILITY_NAMED_IAM. If you don't specify" "this parameter, this action returns an InsufficientCapabilities error.", + help="A list of capabilities that you must specify " + "before AWS CloudFormation can create certain stacks. Some stack templates " + "might include resources that can affect permissions in your AWS " + "account, for example, by creating new AWS Identity and Access Management " + "(IAM) users. For those stacks, you must explicitly acknowledge " + "their capabilities by specifying this parameter. The only valid values " + "are CAPABILITY_IAM and CAPABILITY_NAMED_IAM. If you have IAM resources, " + "you can specify either capability. If you have IAM resources with custom " + "names, you must specify CAPABILITY_NAMED_IAM. If you don't specify " + "this parameter, this action returns an InsufficientCapabilities error.", ) -def capabilities_override_option(f): - return capabilities_click_option()(f) +@parameterized_option +def capabilities_option(f, default=None): + return capabilities_click_option(default)(f) def tags_click_option(): @@ -343,7 +433,7 @@ ) -def tags_override_option(f): +def tags_option(f): return tags_click_option()(f) @@ -359,10 +449,255 @@ ) -def notification_arns_override_option(f): +def notification_arns_option(f): return notification_arns_click_option()(f) +def stack_name_click_option(required): + return click.option( + "--stack-name", + required=required, + help="The name of the AWS CloudFormation stack you're deploying to. " + "If you specify an existing stack, the command updates the stack. 
" + "If you specify a new stack, the command creates it.", + ) + + +@parameterized_option +def stack_name_option(f, required=False): + return stack_name_click_option(required)(f) + + +def s3_bucket_click_option(): + return click.option( + "--s3-bucket", + required=False, + callback=partial(artifact_callback, artifact=ZIP), + help="The name of the S3 bucket where this command uploads the artifacts that are referenced in your template.", + ) + + +def s3_bucket_option(f): + return s3_bucket_click_option()(f) + + +def build_dir_click_option(): + return click.option( + "--build-dir", + "-b", + default=DEFAULT_BUILD_DIR, + type=click.Path(file_okay=False, dir_okay=True, writable=True), # Must be a directory + help="Path to a folder where the built artifacts will be stored. " + "This directory will be first removed before starting a build.", + ) + + +def build_dir_option(f): + return build_dir_click_option()(f) + + +def cache_dir_click_option(): + return click.option( + "--cache-dir", + "-cd", + default=DEFAULT_CACHE_DIR, + type=click.Path(file_okay=False, dir_okay=True, writable=True), # Must be a directory + help="The folder where the cache artifacts will be stored when --cached is specified. " + "The default cache directory is .aws-sam/cache", + ) + + +def cache_dir_option(f): + return cache_dir_click_option()(f) + + +def base_dir_click_option(): + return click.option( + "--base-dir", + "-s", + default=None, + type=click.Path(dir_okay=True, file_okay=False), # Must be a directory + help="Resolve relative paths to function's source code with respect to this folder. Use this if " + "SAM template and your source code are not in same enclosing folder. By default, relative paths " + "are resolved with respect to the SAM template's location", + ) + + +def base_dir_option(f): + return base_dir_click_option()(f) + + +def manifest_click_option(): + return click.option( + "--manifest", + "-m", + default=None, + type=click.Path(), + help="Path to a custom dependency manifest (e.g., package.json) to use instead of the default one", + ) + + +def manifest_option(f): + return manifest_click_option()(f) + + +def cached_click_option(): + return click.option( + "--cached", + "-c", + is_flag=True, + help="Enable cached builds. Use this flag to reuse build artifacts that have not changed from previous builds. " + "AWS SAM evaluates whether you have made any changes to files in your project directory. \n\n" + "Note: AWS SAM does not evaluate whether changes have been made to third party modules " + "that your project depends on, where you have not provided a specific version. 
" + "For example, if your Python function includes a requirements.txt file with the following entry " + "requests=1.x and the latest request module version changes from 1.1 to 1.2, " + "SAM will not pull the latest version until you run a non-cached build.", + ) + + +def cached_option(f): + return cached_click_option()(f) + + +def image_repository_click_option(): + return click.option( + "--image-repository", + callback=partial(artifact_callback, artifact=IMAGE), + type=ImageRepositoryType(), + required=False, + help="ECR repo uri where this command uploads the image artifacts that are referenced in your template.", + ) + + +def image_repository_option(f): + return image_repository_click_option()(f) + + +def image_repositories_click_option(): + return click.option( + "--image-repositories", + multiple=True, + callback=image_repositories_callback, + type=ImageRepositoriesType(), + required=False, + help="Specify mapping of Function Logical ID to ECR Repo uri, of the form Function_Logical_ID=ECR_Repo_Uri." + "This option can be specified multiple times.", + ) + + +def image_repositories_option(f): + return image_repositories_click_option()(f) + + +def s3_prefix_click_option(): + return click.option( + "--s3-prefix", + required=False, + help="A prefix name that the command adds to the artifacts " + "name when it uploads them to the S3 bucket. The prefix name is a " + "path name (folder name) for the S3 bucket.", + ) + + +def s3_prefix_option(f): + return s3_prefix_click_option()(f) + + +def kms_key_id_click_option(): + return click.option( + "--kms-key-id", + required=False, + help="The ID of an AWS KMS key that the command uses to encrypt artifacts that are at rest in the S3 bucket.", + ) + + +def kms_key_id_option(f): + return kms_key_id_click_option()(f) + + +def use_json_click_option(): + return click.option( + "--use-json", + required=False, + is_flag=True, + help="Indicates whether to use JSON as the format for " + "the output AWS CloudFormation template. YAML is used by default.", + ) + + +def use_json_option(f): + return use_json_click_option()(f) + + +def force_upload_click_option(): + return click.option( + "--force-upload", + required=False, + is_flag=True, + help="Indicates whether to override existing files " + "in the S3 bucket. Specify this flag to upload artifacts even if they " + "match existing artifacts in the S3 bucket.", + ) + + +def force_upload_option(f): + return force_upload_click_option()(f) + + +def resolve_s3_click_option(): + from samcli.commands.package.exceptions import PackageResolveS3AndS3SetError, PackageResolveS3AndS3NotSetError + + return click.option( + "--resolve-s3", + required=False, + is_flag=True, + callback=partial( + resolve_s3_callback, + artifact=ZIP, + exc_set=PackageResolveS3AndS3SetError, + exc_not_set=PackageResolveS3AndS3NotSetError, + ), + help="Automatically resolve s3 bucket for non-guided deployments. " + "Enabling this option will also create a managed default s3 bucket for you. " + "If you do not provide a --s3-bucket value, the managed bucket will be used. 
" + "Do not use --s3-guided parameter with this option.", + ) + + +def resolve_s3_option(f): + return resolve_s3_click_option()(f) + + +def role_arn_click_option(): + return click.option( + "--role-arn", + required=False, + help="The Amazon Resource Name (ARN) of an AWS Identity " + "and Access Management (IAM) role that AWS CloudFormation assumes when " + "executing the change set.", + ) + + +def role_arn_option(f): + return role_arn_click_option()(f) + + +def resolve_image_repos_click_option(): + return click.option( + "--resolve-image-repos", + required=False, + is_flag=True, + help="Automatically create and delete ECR repositories for image-based functions in non-guided deployments. " + "A companion stack containing ECR repos for each function will be deployed along with the template stack. " + "Automatically created image repositories will be deleted if the corresponding functions are removed.", + ) + + +def resolve_image_repos_option(f): + return resolve_image_repos_click_option()(f) + + def _space_separated_list_func_type(value): if isinstance(value, str): return value.split(" ") diff --git a/samcli/commands/_utils/template.py b/samcli/commands/_utils/template.py index 44cb80ba28..0eed074a03 100644 --- a/samcli/commands/_utils/template.py +++ b/samcli/commands/_utils/template.py @@ -10,16 +10,16 @@ import yaml from botocore.utils import set_value_from_jmespath -from samcli.commands._utils.resources import ( +from samcli.commands.exceptions import UserException +from samcli.lib.utils.packagetype import ZIP +from samcli.yamlhelper import yaml_parse, yaml_dump +from samcli.lib.utils.resources import ( METADATA_WITH_LOCAL_PATHS, RESOURCES_WITH_LOCAL_PATHS, AWS_SERVERLESS_FUNCTION, AWS_LAMBDA_FUNCTION, get_packageable_resource_paths, ) -from samcli.commands.exceptions import UserException -from samcli.lib.utils.packagetype import ZIP -from samcli.yamlhelper import yaml_parse, yaml_dump class TemplateNotFoundException(UserException): diff --git a/samcli/commands/build/build_context.py b/samcli/commands/build/build_context.py index fbaa02ca35..232994fc7b 100644 --- a/samcli/commands/build/build_context.py +++ b/samcli/commands/build/build_context.py @@ -8,6 +8,8 @@ import shutil from typing import Dict, Optional, List +import click + from samcli.commands.build.exceptions import InvalidBuildDirException, MissingBuildMethodException from samcli.lib.intrinsic_resolver.intrinsics_symbol_table import IntrinsicsSymbolTable from samcli.lib.providers.provider import ResourcesToBuildCollector, Stack, Function, LayerVersion @@ -16,6 +18,21 @@ from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider from samcli.local.docker.manager import ContainerManager from samcli.local.lambdafn.exceptions import ResourceNotFound +from samcli.lib.build.exceptions import BuildInsideContainerError + +from samcli.commands.exceptions import UserException + +from samcli.lib.build.app_builder import ( + ApplicationBuilder, + BuildError, + UnsupportedBuilderLibraryVersionError, + ContainerBuildNotSupported, +) +from samcli.commands._utils.options import DEFAULT_BUILD_DIR +from samcli.lib.build.workflow_config import UnsupportedRuntimeException +from samcli.local.lambdafn.exceptions import FunctionNotFound +from samcli.commands._utils.template import move_template +from samcli.lib.build.exceptions import InvalidBuildGraphException LOG = logging.getLogger(__name__) @@ -33,6 +50,7 @@ def __init__( build_dir: str, cache_dir: str, cached: bool, + parallel: bool, mode: Optional[str], manifest_path: 
Optional[str] = None, clean: bool = False, @@ -58,6 +76,7 @@ def __init__( self._build_dir = build_dir self._cache_dir = cache_dir + self._parallel = parallel self._manifest_path = manifest_path self._clean = clean self._use_container = use_container @@ -80,7 +99,12 @@ def __init__( self._stacks: List[Stack] = [] def __enter__(self) -> "BuildContext": + self.set_up() + return self + def set_up(self) -> None: + """Set up class members used for building + This should be called each time before run() if stacks are changed.""" self._stacks, remote_stack_full_paths = SamLocalStackProvider.get_stacks( self._template_file, parameter_overrides=self._parameter_overrides, @@ -116,11 +140,110 @@ def __enter__(self) -> "BuildContext": docker_network_id=self._docker_network, skip_pull_image=self._skip_pull_image ) - return self - def __exit__(self, *args): pass + def get_resources_to_build(self): + return self.resources_to_build + + def run(self): + """Runs the building process by creating an ApplicationBuilder.""" + try: + builder = ApplicationBuilder( + self.get_resources_to_build(), + self.build_dir, + self.base_dir, + self.cache_dir, + self.cached, + self.is_building_specific_resource, + manifest_path_override=self.manifest_path_override, + container_manager=self.container_manager, + mode=self.mode, + parallel=self._parallel, + container_env_var=self._container_env_var, + container_env_var_file=self._container_env_var_file, + build_images=self._build_images, + ) + except FunctionNotFound as ex: + raise UserException(str(ex), wrapped_from=ex.__class__.__name__) from ex + + try: + artifacts = builder.build() + + stack_output_template_path_by_stack_path = { + stack.stack_path: stack.get_output_template_path(self.build_dir) for stack in self.stacks + } + for stack in self.stacks: + modified_template = builder.update_template( + stack, + artifacts, + stack_output_template_path_by_stack_path, + ) + move_template(stack.location, stack.get_output_template_path(self.build_dir), modified_template) + + click.secho("\nBuild Succeeded", fg="green") + + # try to use relpath so the command is easier to understand, however, + # under Windows, when SAM and (build_dir or output_template_path) are + # on different drive, relpath() fails. + root_stack = SamLocalStackProvider.find_root_stack(self.stacks) + out_template_path = root_stack.get_output_template_path(self.build_dir) + try: + build_dir_in_success_message = os.path.relpath(self.build_dir) + output_template_path_in_success_message = os.path.relpath(out_template_path) + except ValueError: + LOG.debug("Failed to retrieve relpath - using the specified path as-is instead") + build_dir_in_success_message = self.build_dir + output_template_path_in_success_message = out_template_path + + msg = self.gen_success_msg( + build_dir_in_success_message, + output_template_path_in_success_message, + os.path.abspath(self.build_dir) == os.path.abspath(DEFAULT_BUILD_DIR), + ) + + click.secho(msg, fg="yellow") + + except ( + UnsupportedRuntimeException, + BuildError, + BuildInsideContainerError, + UnsupportedBuilderLibraryVersionError, + ContainerBuildNotSupported, + InvalidBuildGraphException, + ) as ex: + click.secho("\nBuild Failed", fg="red") + + # Some Exceptions have a deeper wrapped exception that needs to be surfaced + # from deeper than just one level down. 
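+ # (A BuildError, for instance, can carry the originating exception's class name + # in its "wrapped_from" attribute.)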
+ deep_wrap = getattr(ex, "wrapped_from", None) + wrapped_from = deep_wrap if deep_wrap else ex.__class__.__name__ + raise UserException(str(ex), wrapped_from=wrapped_from) from ex + + @staticmethod + def gen_success_msg(artifacts_dir: str, output_template_path: str, is_default_build_dir: bool) -> str: + + invoke_cmd = "sam local invoke" + if not is_default_build_dir: + invoke_cmd += " -t {}".format(output_template_path) + + deploy_cmd = "sam deploy --guided" + if not is_default_build_dir: + deploy_cmd += " --template-file {}".format(output_template_path) + + msg = """\nBuilt Artifacts : {artifacts_dir} +Built Template : {template} + +Commands you can use next +========================= +[*] Invoke Function: {invokecmd} +[*] Deploy: {deploycmd} + """.format( + invokecmd=invoke_cmd, deploycmd=deploy_cmd, artifacts_dir=artifacts_dir, template=output_template_path + ) + + return msg + @staticmethod def _setup_build_dir(build_dir: str, clean: bool) -> str: build_path = pathlib.Path(build_dir) @@ -205,21 +328,57 @@ def resources_to_build(self) -> ResourcesToBuildCollector: ------- ResourcesToBuildCollector """ + return ( + self.collect_build_resources(self._resource_identifier) + if self._resource_identifier + else self.collect_all_build_resources() + ) + + def collect_build_resources(self, resource_identifier: str) -> ResourcesToBuildCollector: + """Collect a single buildable resource and its dependencies. + For a Lambda function, its layers will be included. + + Parameters + ---------- + resource_identifier : str + Resource identifier for the resource to be built + + Returns + ------- + ResourcesToBuildCollector + ResourcesToBuildCollector containing the buildable resource and its dependencies + + Raises + ------ + ResourceNotFound + raises ResourceNotFound is the specified resource cannot be found. + """ result = ResourcesToBuildCollector() - if self._resource_identifier: - self._collect_single_function_and_dependent_layers(self._resource_identifier, result) - self._collect_single_buildable_layer(self._resource_identifier, result) + # Get the functions and its layer. Skips if it's inline. + self._collect_single_function_and_dependent_layers(resource_identifier, result) + self._collect_single_buildable_layer(resource_identifier, result) - if not result.functions and not result.layers: - all_resources = [f.name for f in self.function_provider.get_all() if not f.inlinecode] - all_resources.extend([l.name for l in self.layer_provider.get_all()]) + if not result.functions and not result.layers: + # Collect all functions and layers that are not inline + all_resources = [f.name for f in self.function_provider.get_all() if not f.inlinecode] + all_resources.extend([l.name for l in self.layer_provider.get_all()]) - available_resource_message = ( - f"{self._resource_identifier} not found. Possible options in your " f"template: {all_resources}" - ) - LOG.info(available_resource_message) - raise ResourceNotFound(f"Unable to find a function or layer with name '{self._resource_identifier}'") - return result + available_resource_message = ( + f"{resource_identifier} not found. Possible options in your " f"template: {all_resources}" + ) + LOG.info(available_resource_message) + raise ResourceNotFound(f"Unable to find a function or layer with name '{resource_identifier}'") + return result + + def collect_all_build_resources(self) -> ResourcesToBuildCollector: + """Collect all buildable resources. Including Lambda functions and layers. 
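+ Functions with inline code are excluded, since there is nothing to build for them.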
+ + Returns + ------- + ResourcesToBuildCollector + ResourcesToBuildCollector that contains all the buildable resources. + """ + result = ResourcesToBuildCollector() result.add_functions([f for f in self.function_provider.get_all() if BuildContext._is_function_buildable(f)]) result.add_layers([l for l in self.layer_provider.get_all() if BuildContext._is_layer_buildable(l)]) return result diff --git a/samcli/commands/build/command.py b/samcli/commands/build/command.py index 5c450d6589..6205c37969 100644 --- a/samcli/commands/build/command.py +++ b/samcli/commands/build/command.py @@ -12,10 +12,13 @@ template_option_without_build, docker_common_options, parameter_override_option, + build_dir_option, + cache_dir_option, + base_dir_option, + manifest_option, + cached_option, ) from samcli.cli.main import pass_context, common_options as cli_framework_options, aws_creds_options, print_cmdline_args -from samcli.lib.build.exceptions import BuildInsideContainerError -from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider from samcli.lib.telemetry.metric import track_command from samcli.cli.cli_config_file import configuration_option, TomlProvider from samcli.lib.utils.version_checker import check_newer_version @@ -24,8 +27,6 @@ LOG = logging.getLogger(__name__) -DEFAULT_BUILD_DIR = os.path.join(".aws-sam", "build") -DEFAULT_CACHE_DIR = os.path.join(".aws-sam", "cache") HELP_TEXT = """ Use this command to build your AWS Lambda Functions source code to generate artifacts that target AWS Lambda's @@ -77,31 +78,6 @@ @click.command("build", help=HELP_TEXT, short_help="Build your Lambda function code") @configuration_option(provider=TomlProvider(section="parameters")) -@click.option( - "--build-dir", - "-b", - default=DEFAULT_BUILD_DIR, - type=click.Path(file_okay=False, dir_okay=True, writable=True), # Must be a directory - help="Path to a folder where the built artifacts will be stored. " - "This directory will be first removed before starting a build.", -) -@click.option( - "--cache-dir", - "-cd", - default=DEFAULT_CACHE_DIR, - type=click.Path(file_okay=False, dir_okay=True, writable=True), # Must be a directory - help="The folder where the cache artifacts will be stored when --cached is specified. " - "The default cache directory is .aws-sam/cache", -) -@click.option( - "--base-dir", - "-s", - default=None, - type=click.Path(dir_okay=True, file_okay=False), # Must be a directory - help="Resolve relative paths to function's source code with respect to this folder. Use this if " - "SAM template and your source code are not in same enclosing folder. By default, relative paths " - "are resolved with respect to the SAM template's location", -) @click.option( "--use-container", "-u", @@ -150,25 +126,11 @@ help="Enabled parallel builds. Use this flag to build your AWS SAM template's functions and layers in parallel. " "By default the functions and layers are built in sequence", ) -@click.option( - "--manifest", - "-m", - default=None, - type=click.Path(), - help="Path to a custom dependency manifest (e.g., package.json) to use instead of the default one", -) -@click.option( - "--cached", - "-c", - is_flag=True, - help="Enable cached builds. Use this flag to reuse build artifacts that have not changed from previous builds. " - "AWS SAM evaluates whether you have made any changes to files in your project directory. \n\n" - "Note: AWS SAM does not evaluate whether changes have been made to third party modules " - "that your project depends on, where you have not provided a specific version. 
" - "For example, if your Python function includes a requirements.txt file with the following entry " - "requests=1.x and the latest request module version changes from 1.1 to 1.2, " - "SAM will not pull the latest version until you run a non-cached build.", -) +@build_dir_option +@cache_dir_option +@base_dir_option +@manifest_option +@cached_option @template_option_without_build @parameter_override_option @docker_common_options @@ -253,19 +215,7 @@ def do_cli( # pylint: disable=too-many-locals, too-many-statements Implementation of the ``cli`` method """ - from samcli.commands.exceptions import UserException - from samcli.commands.build.build_context import BuildContext - from samcli.lib.build.app_builder import ( - ApplicationBuilder, - BuildError, - UnsupportedBuilderLibraryVersionError, - ContainerBuildNotSupported, - ) - from samcli.lib.build.workflow_config import UnsupportedRuntimeException - from samcli.local.lambdafn.exceptions import FunctionNotFound - from samcli.commands._utils.template import move_template - from samcli.lib.build.build_graph import InvalidBuildGraphException LOG.debug("'build' command is called") if cached: @@ -283,6 +233,7 @@ def do_cli( # pylint: disable=too-many-locals, too-many-statements build_dir, cache_dir, cached, + parallel=parallel, clean=clean, manifest_path=manifest_path, use_container=use_container, @@ -295,100 +246,7 @@ def do_cli( # pylint: disable=too-many-locals, too-many-statements build_images=processed_build_images, aws_region=click_ctx.region, ) as ctx: - try: - builder = ApplicationBuilder( - ctx.resources_to_build, - ctx.build_dir, - ctx.base_dir, - ctx.cache_dir, - ctx.cached, - ctx.is_building_specific_resource, - manifest_path_override=ctx.manifest_path_override, - container_manager=ctx.container_manager, - mode=ctx.mode, - parallel=parallel, - container_env_var=processed_env_vars, - container_env_var_file=container_env_var_file, - build_images=processed_build_images, - ) - except FunctionNotFound as ex: - raise UserException(str(ex), wrapped_from=ex.__class__.__name__) from ex - - try: - artifacts = builder.build() - stack_output_template_path_by_stack_path = { - stack.stack_path: stack.get_output_template_path(ctx.build_dir) for stack in ctx.stacks - } - for stack in ctx.stacks: - modified_template = builder.update_template( - stack, - artifacts, - stack_output_template_path_by_stack_path, - ) - move_template(stack.location, stack.get_output_template_path(ctx.build_dir), modified_template) - - click.secho("\nBuild Succeeded", fg="green") - - # try to use relpath so the command is easier to understand, however, - # under Windows, when SAM and (build_dir or output_template_path) are - # on different drive, relpath() fails. 
- root_stack = SamLocalStackProvider.find_root_stack(ctx.stacks) - out_template_path = root_stack.get_output_template_path(ctx.build_dir) - try: - build_dir_in_success_message = os.path.relpath(ctx.build_dir) - output_template_path_in_success_message = os.path.relpath(out_template_path) - except ValueError: - LOG.debug("Failed to retrieve relpath - using the specified path as-is instead") - build_dir_in_success_message = ctx.build_dir - output_template_path_in_success_message = out_template_path - - msg = gen_success_msg( - build_dir_in_success_message, - output_template_path_in_success_message, - os.path.abspath(ctx.build_dir) == os.path.abspath(DEFAULT_BUILD_DIR), - ) - - click.secho(msg, fg="yellow") - - except ( - UnsupportedRuntimeException, - BuildError, - BuildInsideContainerError, - UnsupportedBuilderLibraryVersionError, - ContainerBuildNotSupported, - InvalidBuildGraphException, - ) as ex: - click.secho("\nBuild Failed", fg="red") - - # Some Exceptions have a deeper wrapped exception that needs to be surfaced - # from deeper than just one level down. - deep_wrap = getattr(ex, "wrapped_from", None) - wrapped_from = deep_wrap if deep_wrap else ex.__class__.__name__ - raise UserException(str(ex), wrapped_from=wrapped_from) from ex - - -def gen_success_msg(artifacts_dir: str, output_template_path: str, is_default_build_dir: bool) -> str: - - invoke_cmd = "sam local invoke" - if not is_default_build_dir: - invoke_cmd += " -t {}".format(output_template_path) - - deploy_cmd = "sam deploy --guided" - if not is_default_build_dir: - deploy_cmd += " --template-file {}".format(output_template_path) - - msg = """\nBuilt Artifacts : {artifacts_dir} -Built Template : {template} - -Commands you can use next -========================= -[*] Invoke Function: {invokecmd} -[*] Deploy: {deploycmd} - """.format( - invokecmd=invoke_cmd, deploycmd=deploy_cmd, artifacts_dir=artifacts_dir, template=output_template_path - ) - - return msg + ctx.run() def _get_mode_value_from_envvar(name: str, choices: List[str]) -> Optional[str]: diff --git a/samcli/commands/delete/delete_context.py b/samcli/commands/delete/delete_context.py index ad29ce9c04..f228580fd1 100644 --- a/samcli/commands/delete/delete_context.py +++ b/samcli/commands/delete/delete_context.py @@ -12,7 +12,7 @@ from click import prompt from samcli.cli.cli_config_file import TomlProvider -from samcli.lib.utils.botoconfig import get_boto_config_with_user_agent +from samcli.lib.utils.boto_utils import get_boto_config_with_user_agent from samcli.lib.delete.cfn_utils import CfnUtils from samcli.lib.package.s3_uploader import S3Uploader diff --git a/samcli/commands/deploy/command.py b/samcli/commands/deploy/command.py index e1e3ae0452..1f7f82a7d9 100644 --- a/samcli/commands/deploy/command.py +++ b/samcli/commands/deploy/command.py @@ -7,18 +7,26 @@ from samcli.cli.cli_config_file import TomlProvider, configuration_option from samcli.cli.main import aws_creds_options, common_options, pass_context, print_cmdline_args -from samcli.cli.types import ImageRepositoryType, ImageRepositoriesType from samcli.commands._utils.options import ( - capabilities_override_option, - guided_deploy_stack_name, - metadata_override_option, - notification_arns_override_option, + capabilities_option, + metadata_option, + notification_arns_option, parameter_override_option, no_progressbar_option, - tags_override_option, + tags_option, template_click_option, signing_profiles_option, - image_repositories_callback, + stack_name_option, + s3_bucket_option, + image_repository_option, 
+ image_repositories_option, + s3_prefix_option, + kms_key_id_option, + use_json_option, + force_upload_option, + resolve_s3_option, + role_arn_option, + resolve_image_repos_option, ) from samcli.commands.deploy.utils import sanitize_parameter_overrides from samcli.lib.telemetry.metric import track_command @@ -59,56 +67,6 @@ help="Specify this flag to allow SAM CLI to guide you through the deployment using guided prompts.", ) @template_click_option(include_build=True) -@click.option( - "--stack-name", - required=False, - callback=guided_deploy_stack_name, - help="The name of the AWS CloudFormation stack you're deploying to. " - "If you specify an existing stack, the command updates the stack. " - "If you specify a new stack, the command creates it.", -) -@click.option( - "--s3-bucket", - required=False, - help="The name of the S3 bucket where this command uploads your " - "CloudFormation template. This is required the deployments of " - "templates sized greater than 51,200 bytes", -) -@click.option( - "--image-repository", - type=ImageRepositoryType(), - required=False, - help="ECR repo uri where this command uploads the image artifacts that are referenced in your template.", -) -@click.option( - "--image-repositories", - multiple=True, - callback=image_repositories_callback, - type=ImageRepositoriesType(), - required=False, - help="Specify mapping of Function Logical ID to ECR Repo uri, of the form Function_Logical_ID=ECR_Repo_Uri." - "This option can be specified multiple times.", -) -@click.option( - "--force-upload", - required=False, - is_flag=True, - help="Indicates whether to override existing files in the S3 bucket. " - "Specify this flag to upload artifacts even if they " - "match existing artifacts in the S3 bucket.", -) -@click.option( - "--s3-prefix", - required=False, - help="A prefix name that the command adds to the " - "artifacts' name when it uploads them to the S3 bucket. " - "The prefix name is a path name (folder name) for the S3 bucket.", -) -@click.option( - "--kms-key-id", - required=False, - help="The ID of an AWS KMS key that the command uses to encrypt artifacts that are at rest in the S3 bucket.", -) @click.option( "--no-execute-changeset", required=False, @@ -120,13 +78,6 @@ "the changeset looks satisfactory, the stack changes can be made by " "running the same command without specifying `--no-execute-changeset`", ) -@click.option( - "--role-arn", - required=False, - help="The Amazon Resource Name (ARN) of an AWS Identity " - "and Access Management (IAM) role that AWS CloudFormation assumes when " - "executing the change set.", -) @click.option( "--fail-on-empty-changeset/--no-fail-on-empty-changeset", default=True, @@ -143,37 +94,24 @@ is_flag=True, help="Prompt to confirm if the computed changeset is to be deployed by SAM CLI.", ) -@click.option( - "--use-json", - required=False, - is_flag=True, - help="Indicates whether to use JSON as the format for " - "the output AWS CloudFormation template. YAML is used by default.", -) -@click.option( - "--resolve-s3", - required=False, - is_flag=True, - help="Automatically resolve s3 bucket for non-guided deployments. " - "Enabling this option will also create a managed default s3 bucket for you. " - "If you do not provide a --s3-bucket value, the managed bucket will be used. " - "Do not use --s3-guided parameter with this option.", -) -@click.option( - "--resolve-image-repos", - required=False, - is_flag=True, - help="Automatically create and delete ECR repositories for image-based functions in non-guided deployments. 
" - "A companion stack containing ECR repos for each function will be deployed along with the template stack. " - "Automatically created image repositories will be deleted if the corresponding functions are removed.", -) -@metadata_override_option -@notification_arns_override_option -@tags_override_option +@stack_name_option +@s3_bucket_option +@image_repository_option +@image_repositories_option +@force_upload_option +@s3_prefix_option +@kms_key_id_option +@role_arn_option +@use_json_option +@resolve_s3_option +@resolve_image_repos_option +@metadata_option +@notification_arns_option +@tags_option @parameter_override_option @signing_profiles_option @no_progressbar_option -@capabilities_override_option +@capabilities_option @aws_creds_options @common_options @image_repository_validation @@ -361,5 +299,6 @@ def do_cli( profile=profile, confirm_changeset=guided_context.confirm_changeset if guided else confirm_changeset, signing_profiles=guided_context.signing_profiles if guided else signing_profiles, + use_changeset=True, ) as deploy_context: deploy_context.run() diff --git a/samcli/commands/deploy/deploy_context.py b/samcli/commands/deploy/deploy_context.py index 3d4e7bc16a..88c571b47a 100644 --- a/samcli/commands/deploy/deploy_context.py +++ b/samcli/commands/deploy/deploy_context.py @@ -33,7 +33,7 @@ from samcli.lib.intrinsic_resolver.intrinsics_symbol_table import IntrinsicsSymbolTable from samcli.lib.package.s3_uploader import S3Uploader from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider -from samcli.lib.utils.botoconfig import get_boto_config_with_user_agent +from samcli.lib.utils.boto_utils import get_boto_config_with_user_agent from samcli.yamlhelper import yaml_parse LOG = logging.getLogger(__name__) @@ -70,6 +70,7 @@ def __init__( profile, confirm_changeset, signing_profiles, + use_changeset, ): self.template_file = template_file self.stack_name = stack_name @@ -97,6 +98,7 @@ def __init__( self.deployer = None self.confirm_changeset = confirm_changeset self.signing_profiles = signing_profiles + self.use_changeset = use_changeset def __enter__(self): return self @@ -151,6 +153,7 @@ def run(self): display_parameter_overrides, self.confirm_changeset, self.signing_profiles, + self.use_changeset, ) return self.deploy( self.stack_name, @@ -165,6 +168,7 @@ def run(self): region, self.fail_on_empty_changeset, self.confirm_changeset, + self.use_changeset, ) def deploy( @@ -181,6 +185,7 @@ def deploy( region, fail_on_empty_changeset=True, confirm_changeset=False, + use_changeset=True, ): """ Deploy the stack to cloudformation. 
@@ -213,6 +218,8 @@ def deploy( Should fail when changeset is empty confirm_changeset : bool Should wait for customer's confirm before executing the changeset + use_changeset : bool + Involve creation of changesets, false when using sam sync """ stacks, _ = SamLocalStackProvider.get_stacks( self.template_file, @@ -225,36 +232,55 @@ def deploy( if not authorization_required: click.secho(f"{resource} may not have authorization defined.", fg="yellow") - try: - result, changeset_type = self.deployer.create_and_wait_for_changeset( - stack_name=stack_name, - cfn_template=template_str, - parameter_values=parameters, - capabilities=capabilities, - role_arn=role_arn, - notification_arns=notification_arns, - s3_uploader=s3_uploader, - tags=tags, - ) - click.echo(self.MSG_SHOWCASE_CHANGESET.format(changeset_id=result["Id"])) - - if no_execute_changeset: - return - - if confirm_changeset: - click.secho(self.MSG_CONFIRM_CHANGESET_HEADER, fg="yellow") - click.secho("=" * len(self.MSG_CONFIRM_CHANGESET_HEADER), fg="yellow") - if not click.confirm(f"{self.MSG_CONFIRM_CHANGESET}", default=False): + if use_changeset: + try: + result, changeset_type = self.deployer.create_and_wait_for_changeset( + stack_name=stack_name, + cfn_template=template_str, + parameter_values=parameters, + capabilities=capabilities, + role_arn=role_arn, + notification_arns=notification_arns, + s3_uploader=s3_uploader, + tags=tags, + ) + click.echo(self.MSG_SHOWCASE_CHANGESET.format(changeset_id=result["Id"])) + + if no_execute_changeset: return - self.deployer.execute_changeset(result["Id"], stack_name) - self.deployer.wait_for_execute(stack_name, changeset_type) - click.echo(self.MSG_EXECUTE_SUCCESS.format(stack_name=stack_name, region=region)) - - except deploy_exceptions.ChangeEmptyError as ex: - if fail_on_empty_changeset: + if confirm_changeset: + click.secho(self.MSG_CONFIRM_CHANGESET_HEADER, fg="yellow") + click.secho("=" * len(self.MSG_CONFIRM_CHANGESET_HEADER), fg="yellow") + if not click.confirm(f"{self.MSG_CONFIRM_CHANGESET}", default=False): + return + + self.deployer.execute_changeset(result["Id"], stack_name) + self.deployer.wait_for_execute(stack_name, changeset_type) + click.echo(self.MSG_EXECUTE_SUCCESS.format(stack_name=stack_name, region=region)) + + except deploy_exceptions.ChangeEmptyError as ex: + if fail_on_empty_changeset: + raise + LOG.error(str(ex)) + + else: + try: + result = self.deployer.sync( + stack_name=stack_name, + cfn_template=template_str, + parameter_values=parameters, + capabilities=capabilities, + role_arn=role_arn, + notification_arns=notification_arns, + s3_uploader=s3_uploader, + tags=tags, + ) + LOG.info(result) + + except deploy_exceptions.DeployFailedError as ex: + LOG.error(str(ex)) raise - click.echo(str(ex)) @staticmethod def merge_parameters(template_dict: Dict, parameter_overrides: Dict) -> List[Dict]: diff --git a/samcli/commands/deploy/utils.py b/samcli/commands/deploy/utils.py index c961eb710f..b9cfdc0baa 100644 --- a/samcli/commands/deploy/utils.py +++ b/samcli/commands/deploy/utils.py @@ -18,6 +18,7 @@ def print_deploy_args( parameter_overrides, confirm_changeset, signing_profiles, + use_changeset, ): """ Print a table of the values that are used during a sam deploy. @@ -43,6 +44,7 @@ def print_deploy_args( :param parameter_overrides: Cloudformation parameter overrides to be supplied based on the stack's template :param confirm_changeset: Prompt for changeset to be confirmed before going ahead with the deploy. 
:param signing_profiles: Signing profile details which will be used to sign functions/layers + :param use_changeset: Whether changesets are used during the deployment """ _parameters = parameter_overrides.copy() @@ -62,7 +64,8 @@ click.secho("\n\tDeploying with following values\n\t===============================", fg="yellow") click.echo(f"\tStack name : {stack_name}") click.echo(f"\tRegion : {region}") - click.echo(f"\tConfirm changeset : {confirm_changeset}") + if use_changeset: + click.echo(f"\tConfirm changeset : {confirm_changeset}") if image_repository: msg = "Deployment image repository : " # NOTE(sriram-mv): tab length is 8 spaces. diff --git a/samcli/commands/logs/command.py b/samcli/commands/logs/command.py index 7042970a3a..81be09b4cf 100644 --- a/samcli/commands/logs/command.py +++ b/samcli/commands/logs/command.py @@ -3,11 +3,13 @@ """ import logging + import click +from samcli.cli.cli_config_file import configuration_option, TomlProvider from samcli.cli.main import pass_context, common_options as cli_framework_options, aws_creds_options, print_cmdline_args +from samcli.commands._utils.options import common_observability_options from samcli.lib.telemetry.metric import track_command -from samcli.cli.cli_config_file import configuration_option, TomlProvider from samcli.lib.utils.version_checker import check_newer_version LOG = logging.getLogger(__name__) @@ -38,9 +40,11 @@ @click.option( "--name", "-n", - required=True, - help="Name of your AWS Lambda function. If this function is a part of a CloudFormation stack, " - "this can be the LogicalID of function resource in the CloudFormation/SAM template.", + multiple=True, + help="Name(s) of your AWS Lambda function. If this function is a part of a CloudFormation stack, " + "this can be the LogicalID of the function resource in the CloudFormation/SAM template. Multiple names can be " + "provided by repeating the parameter. If no name is provided and no --cw-log-group is given, the command will " + "scan the given stack, find all supported resources, and start pulling log information from them", ) @click.option("--stack-name", default=None, help="Name of the AWS CloudFormation stack that the function is a part of.") @click.option( "--filter", default=None, help="You can specify an expression to quickly find logs that match terms, phrases or values in " "your log events. This could be a simple keyword (e.g. "error") or your custom query compatible with CloudWatch Logs " "https://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/FilterAndPatternSyntax.html", ) @click.option( - "--start-time", - "-s", - default="10m ago", - help="Fetch logs starting at this time. Time can be relative values like '5mins ago', 'yesterday' or " - "formatted timestamp like '2018-01-01 10:10:10'. Defaults to '10mins ago'.", -) -@click.option( - "--end-time", - "-e", - default=None, - help="Fetch logs up to this time. Time can be relative values like '5mins ago', 'tomorrow' or " - "formatted timestamp like '2018-01-01 10:10:10'", + "--include-traces", + "-i", + is_flag=True, + help="Include the AWS X-Ray traces in the log output.", ) @click.option( - "--tail", - "-t", - is_flag=True, - help="Tail the log output. This will ignore the end time argument and continue to fetch logs as they " - "become available.", + "--cw-log-group", + multiple=True, + help="Additional CloudWatch Log group names that are not auto-discovered based upon the --name parameter. " + "When provided, it will only tail the given CloudWatch Log groups. 
If you want to tail log groups related " + "to resources, provide their names as well.", ) +@common_observability_options @cli_framework_options @aws_creds_options @pass_context @track_command @check_newer_version @print_cmdline_args def cli( ctx, name, stack_name, filter, tail, + include_traces, start_time, end_time, + unformatted, + cw_log_group, config_file, config_env, ): # pylint: disable=redefined-builtin @@ -94,30 +94,52 @@ """ # All logic must be implemented in the ``do_cli`` method. This helps with easy unit testing - do_cli(name, stack_name, filter, tail, start_time, end_time) # pragma: no cover + do_cli( + name, stack_name, filter, tail, include_traces, start_time, end_time, cw_log_group, unformatted, ctx.region + ) # pragma: no cover -def do_cli(function_name, stack_name, filter_pattern, tailing, start_time, end_time): +def do_cli( + names, + stack_name, + filter_pattern, + tailing, + include_tracing, + start_time, + end_time, + cw_log_groups, + unformatted, + region, +): """ Implementation of the ``cli`` method """ - from .logs_context import LogsCommandContext - - LOG.debug("'logs' command is called") - - with LogsCommandContext( - function_name, - stack_name=stack_name, - filter_pattern=filter_pattern, - start_time=start_time, - end_time=end_time, - ) as context: - - if tailing: - context.fetcher.tail(start_time=context.start_time, filter_pattern=context.filter_pattern) - else: - context.fetcher.load_time_period( - start_time=context.start_time, - end_time=context.end_time, - filter_pattern=context.filter_pattern, - ) + + from datetime import datetime + + from samcli.commands.logs.logs_context import parse_time, ResourcePhysicalIdResolver + from samcli.commands.logs.puller_factory import generate_puller + from samcli.lib.utils.boto_utils import get_boto_client_provider_with_config, get_boto_resource_provider_with_config + + sanitized_start_time = parse_time(start_time, "start-time") + sanitized_end_time = parse_time(end_time, "end-time") or datetime.utcnow() + + boto_client_provider = get_boto_client_provider_with_config(region_name=region) + boto_resource_provider = get_boto_resource_provider_with_config(region_name=region) + resource_logical_id_resolver = ResourcePhysicalIdResolver(boto_resource_provider, stack_name, names) + + # only fetch all resources when no CloudWatch log group is defined + fetch_all_when_no_resource_name_given = not cw_log_groups + puller = generate_puller( + boto_client_provider, + resource_logical_id_resolver.get_resource_information(fetch_all_when_no_resource_name_given), + filter_pattern, + cw_log_groups, + unformatted, + include_tracing, + ) + + if tailing: + puller.tail(sanitized_start_time, filter_pattern) + else: + puller.load_time_period(sanitized_start_time, sanitized_end_time, filter_pattern) diff --git a/samcli/commands/logs/console_consumers.py b/samcli/commands/logs/console_consumers.py index 2f77e34ab0..9881e11725 100644 --- a/samcli/commands/logs/console_consumers.py +++ b/samcli/commands/logs/console_consumers.py @@ -13,6 +13,16 @@ class CWConsoleEventConsumer(ObservabilityEventConsumer[CWLogEvent]): Consumer implementation that will consume given event as outputting into console """ - # pylint: disable=R0201 + def __init__(self, add_newline: bool = False): + """ + + Parameters + ---------- + add_newline : bool + If True, a newline is added at the end of each echo operation. Otherwise output stays + on the same line when echo is called. 
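+ The default is False because CloudWatch log events typically already include a trailing newline.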
+ """ + self._add_newline = add_newline + def consume(self, event: CWLogEvent): - click.echo(event.message, nl=False) + click.echo(event.message, nl=self._add_newline) diff --git a/samcli/commands/logs/logs_context.py b/samcli/commands/logs/logs_context.py index 5504895a70..777a7791a2 100644 --- a/samcli/commands/logs/logs_context.py +++ b/samcli/commands/logs/logs_context.py @@ -3,274 +3,138 @@ """ import logging +from typing import List, Optional, Set, Any -import boto3 -import botocore - -from samcli.commands.exceptions import UserException -from samcli.commands.logs.console_consumers import CWConsoleEventConsumer -from samcli.lib.observability.cw_logs.cw_log_formatters import ( - CWColorizeErrorsFormatter, - CWJsonFormatter, - CWKeywordHighlighterFormatter, - CWPrettyPrintFormatter, +from samcli.lib.utils.resources import ( + AWS_LAMBDA_FUNCTION, + AWS_APIGATEWAY_RESTAPI, + AWS_APIGATEWAY_V2_API, + AWS_STEPFUNCTIONS_STATEMACHINE, ) -from samcli.lib.observability.cw_logs.cw_log_group_provider import LogGroupProvider -from samcli.lib.observability.cw_logs.cw_log_puller import CWLogPuller -from samcli.lib.observability.observability_info_puller import ObservabilityEventConsumerDecorator -from samcli.lib.utils.colors import Colored +from samcli.commands.exceptions import UserException +from samcli.lib.utils.boto_utils import BotoProviderType +from samcli.lib.utils.cloudformation import get_resource_summaries from samcli.lib.utils.time import to_utc, parse_date LOG = logging.getLogger(__name__) class InvalidTimestampError(UserException): - pass - - -class LogsCommandContext: """ - Sets up a context to run the Logs command by parsing the CLI arguments and creating necessary objects to be able - to fetch and display logs - - This class **must** be used inside a ``with`` statement as follows: - - with LogsCommandContext(**kwargs) as context: - context.fetcher.fetch(...) + Used to indicate that given date time string is an invalid timestamp """ - def __init__( - self, function_name, stack_name=None, filter_pattern=None, start_time=None, end_time=None, output_file=None - ): - """ - Initializes the context - - Parameters - ---------- - function_name : str - Name of the function to fetch logs for - - stack_name : str - Name of the stack where the function is available - - filter_pattern : str - Optional pattern to filter the logs by - - start_time : str - Fetch logs starting at this time - - end_time : str - Fetch logs up to this time - - output_file : str - Write logs to this file instead of Terminal - """ - - self._function_name = function_name - self._stack_name = stack_name - self._filter_pattern = filter_pattern - self._start_time = start_time - self._end_time = end_time - self._output_file = output_file - self._output_file_handle = None - - # No colors when we write to a file. Otherwise use colors - self._must_print_colors = not self._output_file - - self._logs_client = boto3.client("logs") - self._cfn_client = boto3.client("cloudformation") - - def __enter__(self): - """ - Performs some basic checks and returns itself when everything is ready to invoke a Lambda function. 
- - Returns - ------- - LogsCommandContext - Returns this object - """ - - self._output_file_handle = self._setup_output_file(self._output_file) - return self - - def __exit__(self, *args): - """ - Cleanup any necessary opened files - """ - - if self._output_file_handle: - self._output_file_handle.close() - self._output_file_handle = None - - @property - def fetcher(self): - return CWLogPuller( - logs_client=self._logs_client, - consumer=ObservabilityEventConsumerDecorator( - mappers=[ - CWColorizeErrorsFormatter(self.colored), - CWJsonFormatter(), - CWKeywordHighlighterFormatter(self.colored, self._filter_pattern), - CWPrettyPrintFormatter(self.colored), - ], - consumer=CWConsoleEventConsumer(), - ), - cw_log_group=self.log_group_name, - resource_name=self._function_name, - ) - - @property - def start_time(self): - return self._parse_time(self._start_time, "start-time") - - @property - def end_time(self): - return self._parse_time(self._end_time, "end-time") - - @property - def log_group_name(self): - """ - Name of the AWS CloudWatch Log Group that we will be querying. It generates the name based on the - Lambda Function name and stack name provided. - - Returns - ------- - str - Name of the CloudWatch Log Group - """ +def parse_time(time_str: str, property_name: str): + """ + Parse the time from the given string, convert to UTC, and return the datetime object - function_id = self._function_name - if self._stack_name: - function_id = self._get_resource_id_from_stack(self._cfn_client, self._stack_name, self._function_name) - LOG.debug( - "Function with LogicalId '%s' in stack '%s' resolves to actual physical ID '%s'", - self._function_name, - self._stack_name, - function_id, - ) + Parameters + ---------- + time_str : str + The time to parse - return LogGroupProvider.for_lambda_function(function_id) + property_name : str + Name of the property where this time came from. Used in the exception raised if time is not parseable - @property - def colored(self): - """ - Instance of Colored object to colorize strings + Returns + ------- + datetime.datetime + Parsed datetime object - Returns - ------- - samcli.commands.utils.colors.Colored - """ - # No colors if we are writing output to a file - return Colored(colorize=self._must_print_colors) + Raises + ------ + InvalidTimestampError + If the string cannot be parsed as a timestamp + """ + if not time_str: + return None - @property - def filter_pattern(self): - return self._filter_pattern + parsed = parse_date(time_str) + if not parsed: + raise InvalidTimestampError("Unable to parse the time provided by '{}'".format(property_name)) - @property - def output_file_handle(self): - return self._output_file_handle + return to_utc(parsed) - @staticmethod - def _setup_output_file(output_file): - """ - Open a log file if necessary and return the file handle. This will create a file if it does not exist - Parameters - ---------- - output_file : str - Path to a file where the logs should be written to +class ResourcePhysicalIdResolver: + """ + Wrapper class that is used to extract information about resources which we can tail their logs for given stack + """ - Returns - ------- - Handle to the opened log file, if necessary. 
 
-    @staticmethod
-    def _setup_output_file(output_file):
-        """
-        Open a log file if necessary and return the file handle. This will create a file if it does not exist
 
-        Parameters
-        ----------
-        output_file : str
-            Path to a file where the logs should be written to
+class ResourcePhysicalIdResolver:
+    """
+    Wrapper class used to extract information about the resources of a given stack whose logs can be tailed
+    """
 
-        Returns
-        -------
-        Handle to the opened log file, if necessary. None otherwise
-        """
-        if not output_file:
-            return None
+    # resource types that are currently supported for pulling logs
+    DEFAULT_SUPPORTED_RESOURCES: Set[str] = {
+        AWS_LAMBDA_FUNCTION,
+        AWS_APIGATEWAY_RESTAPI,
+        AWS_APIGATEWAY_V2_API,
+        AWS_STEPFUNCTIONS_STATEMACHINE,
+    }
 
-        return open(output_file, "wb")
+    def __init__(
+        self,
+        boto_resource_provider: BotoProviderType,
+        stack_name: str,
+        resource_names: Optional[List[str]] = None,
+        supported_resource_types: Optional[Set[str]] = None,
+    ):
+        self._boto_resource_provider = boto_resource_provider
+        self._stack_name = stack_name
+        if resource_names is None:
+            resource_names = []
+        if supported_resource_types is None:
+            supported_resource_types = ResourcePhysicalIdResolver.DEFAULT_SUPPORTED_RESOURCES
+        self._supported_resource_types: Set[str] = supported_resource_types
+        self._resource_names = set(resource_names)
 
-    @staticmethod
-    def _parse_time(time_str, property_name):
+    def get_resource_information(self, fetch_all_when_no_resource_name_given: bool = True) -> List[Any]:
         """
-        Parse the time from the given string, convert to UTC, and return the datetime object
+        Returns the list of resource information for the given stack.
 
         Parameters
         ----------
-        time_str : str
-            The time to parse
-
-        property_name : str
-            Name of the property where this time came from. Used in the exception raised if time is not parseable
+        fetch_all_when_no_resource_name_given : bool
+            If True (the default), fetch all supported resources when no specific resource name is provided
 
         Returns
         -------
-        datetime.datetime
-            Parsed datetime object
-
-        Raises
-        ------
-        samcli.commands.exceptions.UserException
-            If the string cannot be parsed as a timestamp
+        List[StackResourceSummary]
+            List of resource information, which will be used to fetch the logs
         """
-        if not time_str:
-            return None
+        if self._resource_names:
+            return self._fetch_resources_from_stack(self._resource_names)
+        if fetch_all_when_no_resource_name_given:
+            return self._fetch_resources_from_stack()
+        return []
 
-        parsed = parse_date(time_str)
-        if not parsed:
-            raise InvalidTimestampError("Unable to parse the time provided by '{}'".format(property_name))
-
-        return to_utc(parsed)
-
-    @staticmethod
-    def _get_resource_id_from_stack(cfn_client, stack_name, logical_id):
+    def _fetch_resources_from_stack(self, selected_resource_names: Optional[Set[str]] = None) -> List[Any]:
         """
-        Given the LogicalID of a resource, call AWS CloudFormation to get physical ID of the resource within
-        the specified stack.
+        Returns a list of all supported resources from the given stack,
+        discarding any resource whose type is not supported
 
         Parameters
         ----------
-        cfn_client : boto3.session.Session.client
-            CloudFormation client provided by AWS SDK
-
-        stack_name : str
-            Name of the stack to query
-
-        logical_id : str
-            LogicalId of the resource
+        selected_resource_names : Optional[Set[str]]
+            An optional set of resource names to filter by. If none is given, all resource names in the
+            stack are selected, which means there won't be any filtering by resource name.
 
         Returns
         -------
-        str
-            Physical ID of the resource
-
-        Raises
-        ------
-        samcli.commands.exceptions.UserException
-            If the stack or resource does not exist
+        List[StackResourceSummary]
+            List of resource information, which will be used to fetch the logs
         """
-
-        LOG.debug(
-            "Getting resource's PhysicalId from AWS CloudFormation stack. StackName=%s, LogicalId=%s",
-            stack_name,
-            logical_id,
+        results = []
+        LOG.debug("Getting logical ids of all resources for stack '%s'", self._stack_name)
+        stack_resources = get_resource_summaries(
+            self._boto_resource_provider, self._stack_name, self._supported_resource_types
         )
-        try:
-            response = cfn_client.describe_stack_resource(StackName=stack_name, LogicalResourceId=logical_id)
-
-            LOG.debug("Response from AWS CloudFormation %s", response)
-            return response["StackResourceDetail"]["PhysicalResourceId"]
-
-        except botocore.exceptions.ClientError as ex:
-            LOG.debug(
-                "Unable to fetch resource name from CloudFormation Stack: "
-                "StackName=%s, ResourceLogicalId=%s, Response=%s",
-                stack_name,
-                logical_id,
-                ex.response,
-            )
+        if selected_resource_names is None:
+            selected_resource_names = {stack_resource.logical_resource_id for stack_resource in stack_resources}
 
-            # The exception message already has a well formatted error message that we can surface to user
-            raise UserException(str(ex), wrapped_from=ex.response["Error"]["Code"]) from ex
+        for resource in stack_resources:
+            # if the resource name is not selected, skip it
+            if resource.logical_resource_id not in selected_resource_names:
+                LOG.debug("Resource (%s) is not selected with the given input", resource.logical_resource_id)
+                continue
+            results.append(resource)
+        return results
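As a quick sketch of how the resolver above might be driven (the stack and resource names are made up, and boto_resource_provider stands in for a real BotoProviderType factory):

    resolver = ResourcePhysicalIdResolver(boto_resource_provider, "my-stack", ["HelloWorldFunction"])
    for summary in resolver.get_resource_information():
        # each summary carries the logical id, physical id and type consumed by the puller factory below
        print(summary.logical_resource_id, summary.physical_resource_id)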
diff --git a/samcli/commands/logs/puller_factory.py b/samcli/commands/logs/puller_factory.py
new file mode 100644
index 0000000000..a735b6cf7e
--- /dev/null
+++ b/samcli/commands/logs/puller_factory.py
@@ -0,0 +1,180 @@
+"""
+Factory methods to prepare the required puller instances
+with their producers and consumers
+"""
+import logging
+from typing import List, Optional
+
+from samcli.commands.exceptions import UserException
+from samcli.commands.logs.console_consumers import CWConsoleEventConsumer
+from samcli.commands.traces.traces_puller_factory import generate_trace_puller
+from samcli.lib.observability.cw_logs.cw_log_formatters import (
+    CWColorizeErrorsFormatter,
+    CWJsonFormatter,
+    CWKeywordHighlighterFormatter,
+    CWPrettyPrintFormatter,
+    CWAddNewLineIfItDoesntExist,
+    CWLogEventJSONMapper,
+)
+from samcli.lib.observability.cw_logs.cw_log_group_provider import LogGroupProvider
+from samcli.lib.observability.cw_logs.cw_log_puller import CWLogPuller
+from samcli.lib.observability.observability_info_puller import (
+    ObservabilityPuller,
+    ObservabilityEventConsumerDecorator,
+    ObservabilityEventConsumer,
+    ObservabilityCombinedPuller,
+)
+from samcli.lib.utils.boto_utils import BotoProviderType
+from samcli.lib.utils.cloudformation import CloudFormationResourceSummary
+from samcli.lib.utils.colors import Colored
+
+LOG = logging.getLogger(__name__)
+
+
+class NoPullerGeneratedException(UserException):
+    """
+    Used to indicate that no pullers have been generated;
+    therefore, there is no observability information (logs, xray) to pull
+    """
+
+
+def generate_puller(
+    boto_client_provider: BotoProviderType,
+    resource_information_list: List[CloudFormationResourceSummary],
+    filter_pattern: Optional[str] = None,
+    additional_cw_log_groups: Optional[List[str]] = None,
+    unformatted: bool = False,
+    include_tracing: bool = False,
+) -> ObservabilityPuller:
+    """
+    Generates a generic puller which can be used to
+    pull information from various observability resources.
+
+    Parameters
+    ----------
+    boto_client_provider: BotoProviderType
+        Boto3 client generator, which will create a new instance of the client with a new session that can be
+        used within different threads/coroutines
+    resource_information_list : List[CloudFormationResourceSummary]
+        List of resource information, which keeps the logical id, physical id and type of the resources
+    filter_pattern : Optional[str]
+        Optional filter pattern which will be used to filter incoming events
+    additional_cw_log_groups : Optional[List[str]]
+        Optional list of additional CloudWatch log groups from which log events will be fetched
+    unformatted : bool
+        By default, logs and traces are printed formatted for the terminal. If this option is provided, the events
+        will be printed unformatted in JSON.
+    include_tracing: bool
+        Whether to also pull AWS X-Ray traces
+
+    Returns
+    -------
+    Puller instance that can be used to pull information.
+    """
+    if additional_cw_log_groups is None:
+        additional_cw_log_groups = []
+    pullers: List[ObservabilityPuller] = []
+
+    # populate all puller instances for given resources
+    for resource_information in resource_information_list:
+        cw_log_group_name = LogGroupProvider.for_resource(
+            boto_client_provider,
+            resource_information.resource_type,
+            resource_information.physical_resource_id,
+        )
+        if not cw_log_group_name:
+            LOG.warning(
+                "Can't find CloudWatch LogGroup name for resource (%s)", resource_information.logical_resource_id
+            )
+            continue
+
+        consumer = generate_consumer(filter_pattern, unformatted, resource_information.logical_resource_id)
+        pullers.append(
+            CWLogPuller(
+                boto_client_provider("logs"),
+                consumer,
+                cw_log_group_name,
+                resource_information.logical_resource_id,
+            )
+        )
+
+    # populate puller instances for the additional CloudWatch log groups
+    for cw_log_group in additional_cw_log_groups:
+        consumer = generate_consumer(filter_pattern, unformatted)
+        pullers.append(
+            CWLogPuller(
+                boto_client_provider("logs"),
+                consumer,
+                cw_log_group,
+            )
+        )
+
+    # if the tracing flag is set, add the xray traces puller to fetch debug traces
+    if include_tracing:
+        trace_puller = generate_trace_puller(boto_client_provider("xray"), unformatted)
+        pullers.append(trace_puller)
+
+    # if no pullers have been collected, raise an exception since there is nothing to pull
+    if not pullers:
+        raise NoPullerGeneratedException("No valid resources found to pull information from")
+
+    # return the combined puller instance, which will pull from all pullers collected
+    return ObservabilityCombinedPuller(pullers)
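A minimal sketch of wiring the factory above together (it reuses the resolver sketch from logs_context.py, assumes the combined puller exposes the same tail interface used by `sam traces` below, and the filter pattern is illustrative):

    import boto3

    def client_provider(service):
        # stand-in for BotoProviderType: returns a fresh client for the given service name
        return boto3.client(service)

    puller = generate_puller(
        client_provider,
        resolver.get_resource_information(),
        filter_pattern="ERROR",
        include_tracing=True,
    )
    puller.tail(parse_time("10m ago", "start-time"))  # stream events as they arrive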
+
+
+def generate_consumer(
+    filter_pattern: Optional[str] = None, unformatted: bool = False, resource_name: Optional[str] = None
+):
+    """
+    Generates a consumer instance with the given options.
+    If unformatted is True, it returns a consumer that prints events as plain JSON;
+    otherwise it returns a formatted console consumer.
+    """
+    if unformatted:
+        return generate_unformatted_consumer()
+
+    return generate_console_consumer(filter_pattern)
+
+
+def generate_unformatted_consumer() -> ObservabilityEventConsumer:
+    """
+    Creates an event consumer which prints CloudWatch log events as unformatted JSON to the terminal
+
+    Returns
+    -------
+    ObservabilityEventConsumer which will print events to the terminal as unformatted JSON
+    """
+    return ObservabilityEventConsumerDecorator(
+        [
+            CWLogEventJSONMapper(),
+        ],
+        CWConsoleEventConsumer(True),
+    )
+
+
+def generate_console_consumer(filter_pattern: Optional[str]) -> ObservabilityEventConsumer:
+    """
+    Creates a console event consumer, which is used to display events in the user's console
+
+    Parameters
+    ----------
+    filter_pattern : Optional[str]
+        Filter pattern that is used to display certain keywords in a different style than
+        the rest of the message
+
+    Returns
+    -------
+    A consumer which will display events in the console
+    """
+    colored = Colored()
+    return ObservabilityEventConsumerDecorator(
+        [
+            CWColorizeErrorsFormatter(colored),
+            CWJsonFormatter(),
+            CWKeywordHighlighterFormatter(colored, filter_pattern),
+            CWPrettyPrintFormatter(colored),
+            CWAddNewLineIfItDoesntExist(),
+        ],
+        CWConsoleEventConsumer(),
+    )
diff --git a/samcli/commands/package/command.py b/samcli/commands/package/command.py
index cc0dc35c5d..ab5a00c063 100644
--- a/samcli/commands/package/command.py
+++ b/samcli/commands/package/command.py
@@ -1,24 +1,24 @@
 """
 CLI command for "package" command
 """
-from functools import partial
-
 import click
 
 from samcli.cli.cli_config_file import configuration_option, TomlProvider
 from samcli.cli.main import pass_context, common_options, aws_creds_options, print_cmdline_args
-from samcli.cli.types import ImageRepositoryType, ImageRepositoriesType
-from samcli.commands.package.exceptions import PackageResolveS3AndS3SetError, PackageResolveS3AndS3NotSetError
 from samcli.lib.cli_validation.image_repository_validation import image_repository_validation
-from samcli.lib.utils.packagetype import ZIP, IMAGE
 from samcli.commands._utils.options import (
-    artifact_callback,
-    resolve_s3_callback,
     signing_profiles_option,
-    image_repositories_callback,
+    s3_bucket_option,
+    image_repository_option,
+    image_repositories_option,
+    s3_prefix_option,
+    kms_key_id_option,
+    use_json_option,
+    force_upload_option,
+    resolve_s3_option,
 )
-from samcli.commands._utils.options import metadata_override_option, template_click_option, no_progressbar_option
-from samcli.commands._utils.resources import resources_generator
+from samcli.commands._utils.options import metadata_option, template_click_option, no_progressbar_option
+from samcli.lib.utils.resources import resources_generator
 from samcli.lib.bootstrap.bootstrap import manage_stack
 from samcli.lib.telemetry.metric import track_command, track_template_warnings
 from samcli.lib.utils.version_checker import check_newer_version
@@ -54,40 +54,6 @@ def resources_and_properties_help_string():
 @click.command("package", short_help=SHORT_HELP, help=HELP_TEXT, context_settings=dict(max_content_width=120))
 @configuration_option(provider=TomlProvider(section="parameters"))
 @template_click_option(include_build=True)
-@click.option(
-    "--s3-bucket",
-    required=False,
-    callback=partial(artifact_callback, artifact=ZIP),
-    help="The name of the S3 bucket where this command uploads the artifacts that are referenced in your template.",
-)
-@click.option(
-    "--image-repository",
-    callback=partial(artifact_callback, artifact=IMAGE),
type=ImageRepositoryType(), - required=False, - help="ECR repo uri where this command uploads the image artifacts that are referenced in your template.", -) -@click.option( - "--image-repositories", - multiple=True, - callback=image_repositories_callback, - type=ImageRepositoriesType(), - required=False, - help="Specify mapping of Function Logical ID to ECR Repo uri, of the form Function_Logical_ID=ECR_Repo_Uri." - "This option can be specified multiple times.", -) -@click.option( - "--s3-prefix", - required=False, - help="A prefix name that the command adds to the artifacts " - "name when it uploads them to the S3 bucket. The prefix name is a " - "path name (folder name) for the S3 bucket.", -) -@click.option( - "--kms-key-id", - required=False, - help="The ID of an AWS KMS key that the command uses to encrypt artifacts that are at rest in the S3 bucket.", -) @click.option( "--output-template-file", required=False, @@ -96,37 +62,15 @@ def resources_and_properties_help_string(): "writes the output AWS CloudFormation template. If you don't specify a " "path, the command writes the template to the standard output.", ) -@click.option( - "--use-json", - required=False, - is_flag=True, - help="Indicates whether to use JSON as the format for " - "the output AWS CloudFormation template. YAML is used by default.", -) -@click.option( - "--force-upload", - required=False, - is_flag=True, - help="Indicates whether to override existing files " - "in the S3 bucket. Specify this flag to upload artifacts even if they " - "match existing artifacts in the S3 bucket.", -) -@click.option( - "--resolve-s3", - required=False, - is_flag=True, - callback=partial( - resolve_s3_callback, - artifact=ZIP, - exc_set=PackageResolveS3AndS3SetError, - exc_not_set=PackageResolveS3AndS3NotSetError, - ), - help="Automatically resolve s3 bucket for non-guided deployments. " - "Enabling this option will also create a managed default s3 bucket for you. " - "If you do not provide a --s3-bucket value, the managed bucket will be used. 
" - "Do not use --s3-guided parameter with this option.", -) -@metadata_override_option +@s3_bucket_option +@image_repository_option +@image_repositories_option +@s3_prefix_option +@kms_key_id_option +@use_json_option +@force_upload_option +@resolve_s3_option +@metadata_option @signing_profiles_option @no_progressbar_option @common_options diff --git a/samcli/commands/package/package_context.py b/samcli/commands/package/package_context.py index 0a26577333..8aaca8241c 100644 --- a/samcli/commands/package/package_context.py +++ b/samcli/commands/package/package_context.py @@ -30,7 +30,7 @@ from samcli.lib.package.code_signer import CodeSigner from samcli.lib.package.s3_uploader import S3Uploader from samcli.lib.package.uploaders import Uploaders -from samcli.lib.utils.botoconfig import get_boto_config_with_user_agent +from samcli.lib.utils.boto_utils import get_boto_config_with_user_agent from samcli.yamlhelper import yaml_dump LOG = logging.getLogger(__name__) diff --git a/samcli/commands/sync/__init__.py b/samcli/commands/sync/__init__.py new file mode 100644 index 0000000000..e849905d58 --- /dev/null +++ b/samcli/commands/sync/__init__.py @@ -0,0 +1,4 @@ +"""`sam sync` command.""" + +# Expose the cli object here +from .command import cli # noqa diff --git a/samcli/commands/sync/command.py b/samcli/commands/sync/command.py new file mode 100644 index 0000000000..9d90706a5f --- /dev/null +++ b/samcli/commands/sync/command.py @@ -0,0 +1,349 @@ +"""CLI command for "sync" command.""" +import os +import logging +from typing import List, Set, TYPE_CHECKING, Optional, Tuple + +import click + +from samcli.cli.main import pass_context, common_options as cli_framework_options, aws_creds_options, print_cmdline_args +from samcli.commands._utils.options import ( + template_option_without_build, + parameter_override_option, + capabilities_option, + metadata_option, + notification_arns_option, + tags_option, + stack_name_option, + base_dir_option, + image_repository_option, + image_repositories_option, + s3_prefix_option, + kms_key_id_option, + role_arn_option, +) +from samcli.cli.cli_config_file import configuration_option, TomlProvider +from samcli.lib.utils.version_checker import check_newer_version +from samcli.lib.bootstrap.bootstrap import manage_stack +from samcli.lib.cli_validation.image_repository_validation import image_repository_validation +from samcli.lib.telemetry.metric import track_command, track_template_warnings +from samcli.lib.warnings.sam_cli_warning import CodeDeployWarning, CodeDeployConditionWarning +from samcli.commands.build.command import _get_mode_value_from_envvar +from samcli.lib.sync.sync_flow_factory import SyncFlowFactory +from samcli.lib.sync.sync_flow_executor import SyncFlowExecutor +from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider +from samcli.lib.providers.provider import ( + ResourceIdentifier, + get_all_resource_ids, + get_unique_resource_ids, +) +from samcli.commands._utils.options import DEFAULT_BUILD_DIR, DEFAULT_CACHE_DIR +from samcli.cli.context import Context +from samcli.lib.sync.watch_manager import WatchManager + +if TYPE_CHECKING: + from samcli.commands.deploy.deploy_context import DeployContext + from samcli.commands.package.package_context import PackageContext + from samcli.commands.build.build_context import BuildContext + +LOG = logging.getLogger(__name__) + +HELP_TEXT = """ +Update/sync local artifacts to AWS +""" +SHORT_HELP = "Sync a project to AWS" + +DEFAULT_TEMPLATE_NAME = "template.yaml" + + +@click.command("sync", 
help=HELP_TEXT, short_help=SHORT_HELP) +@configuration_option(provider=TomlProvider(section="parameters")) +@template_option_without_build +@click.option( + "--infra", + is_flag=True, + help="Sync infrastructure", +) +@click.option( + "--code", + is_flag=True, + help="Sync code resources. This includes Lambda Functions, API Gateway, and Step Functions.", +) +@click.option( + "--watch", + is_flag=True, + help="Watch local files and automatically sync with remote.", +) +@click.option( + "--resource-id", + multiple=True, + help="Sync code for all the resources with the ID.", +) +@click.option( + "--resource", + multiple=True, + help="Sync code for all types of the resource.", +) +@stack_name_option(required=True) # pylint: disable=E1120 +@base_dir_option +@image_repository_option +@image_repositories_option +@s3_prefix_option +@kms_key_id_option +@role_arn_option +@parameter_override_option +@cli_framework_options +@aws_creds_options +@metadata_option +@notification_arns_option +@tags_option +@capabilities_option(default=("CAPABILITY_NAMED_IAM", "CAPABILITY_AUTO_EXPAND")) # pylint: disable=E1120 +@pass_context +@track_command +@image_repository_validation +@track_template_warnings([CodeDeployWarning.__name__, CodeDeployConditionWarning.__name__]) +@check_newer_version +@print_cmdline_args +def cli( + ctx: Context, + template_file: str, + infra: bool, + code: bool, + watch: bool, + resource_id: Optional[Tuple[str]], + resource: Optional[Tuple[str]], + stack_name: str, + base_dir: Optional[str], + parameter_overrides: dict, + image_repository: str, + image_repositories: Optional[Tuple[str]], + s3_prefix: str, + kms_key_id: str, + capabilities: Optional[List[str]], + role_arn: Optional[str], + notification_arns: Optional[List[str]], + tags: dict, + metadata: dict, + config_file: str, + config_env: str, +) -> None: + """ + `sam sync` command entry point + """ + mode = _get_mode_value_from_envvar("SAM_BUILD_MODE", choices=["debug"]) + # All logic must be implemented in the ``do_cli`` method. 
This helps with easy unit testing + + do_cli( + template_file, + infra, + code, + watch, + resource_id, + resource, + stack_name, + ctx.region, + ctx.profile, + base_dir, + parameter_overrides, + mode, + image_repository, + image_repositories, + s3_prefix, + kms_key_id, + capabilities, + role_arn, + notification_arns, + tags, + metadata, + config_file, + config_env, + ) # pragma: no cover + + +def do_cli( + template_file: str, + infra: bool, + code: bool, + watch: bool, + resource_id: Optional[Tuple[str]], + resource: Optional[Tuple[str]], + stack_name: str, + region: str, + profile: str, + base_dir: Optional[str], + parameter_overrides: dict, + mode: Optional[str], + image_repository: str, + image_repositories: Optional[Tuple[str]], + s3_prefix: str, + kms_key_id: str, + capabilities: Optional[List[str]], + role_arn: Optional[str], + notification_arns: Optional[List[str]], + tags: dict, + metadata: dict, + config_file: str, + config_env: str, +) -> None: + """ + Implementation of the ``cli`` method + """ + from samcli.lib.utils import osutils + from samcli.commands.build.build_context import BuildContext + from samcli.commands.package.package_context import PackageContext + from samcli.commands.deploy.deploy_context import DeployContext + + s3_bucket = manage_stack(profile=profile, region=region) + click.echo(f"\n\t\tManaged S3 bucket: {s3_bucket}") + click.echo("\t\tA different default S3 bucket can be set in samconfig.toml") + click.echo("\t\tOr by specifying --s3-bucket explicitly.") + + with BuildContext( + resource_identifier=None, + template_file=template_file, + base_dir=base_dir, + build_dir=DEFAULT_BUILD_DIR, + cache_dir=DEFAULT_CACHE_DIR, + clean=True, + use_container=False, + cached=True, + parallel=True, + parameter_overrides=parameter_overrides, + mode=mode, + ) as build_context: + built_template = os.path.join(".aws-sam", "build", DEFAULT_TEMPLATE_NAME) + + with osutils.tempfile_platform_independent() as output_template_file: + with PackageContext( + template_file=built_template, + s3_bucket=s3_bucket, + image_repository=image_repository, + image_repositories=image_repositories, + s3_prefix=s3_prefix, + kms_key_id=kms_key_id, + output_template_file=output_template_file.name, + no_progressbar=True, + metadata=metadata, + region=region, + profile=profile, + use_json=False, + force_upload=True, + ) as package_context: + + with DeployContext( + template_file=output_template_file.name, + stack_name=stack_name, + s3_bucket=s3_bucket, + image_repository=image_repository, + image_repositories=image_repositories, + no_progressbar=True, + s3_prefix=s3_prefix, + kms_key_id=kms_key_id, + parameter_overrides=parameter_overrides, + capabilities=capabilities, + role_arn=role_arn, + notification_arns=notification_arns, + tags=tags, + region=region, + profile=profile, + no_execute_changeset=True, + fail_on_empty_changeset=True, + confirm_changeset=False, + use_changeset=False, + force_upload=True, + signing_profiles=None, + ) as deploy_context: + if watch: + execute_watch(template_file, build_context, package_context, deploy_context) + elif code: + execute_code_sync(template_file, build_context, deploy_context, resource_id, resource) + else: + execute_infra_contexts(build_context, package_context, deploy_context) + + +def execute_infra_contexts( + build_context: "BuildContext", + package_context: "PackageContext", + deploy_context: "DeployContext", +) -> None: + """Executes the sync for infra. 
+
+    Parameters
+    ----------
+    build_context : BuildContext
+        BuildContext
+    package_context : PackageContext
+        PackageContext
+    deploy_context : DeployContext
+        DeployContext
+    """
+    LOG.debug("Executing the build using build context.")
+    build_context.run()
+    LOG.debug("Executing the packaging using package context.")
+    package_context.run()
+    LOG.debug("Executing the deployment using deploy context.")
+    deploy_context.run()
+
+
+def execute_code_sync(
+    template: str,
+    build_context: "BuildContext",
+    deploy_context: "DeployContext",
+    resource_ids: Optional[Tuple[str]],
+    resource_types: Optional[Tuple[str]],
+) -> None:
+    """Executes the sync flow for code.
+
+    Parameters
+    ----------
+    template : str
+        Template file name
+    build_context : BuildContext
+        BuildContext
+    deploy_context : DeployContext
+        DeployContext
+    resource_ids : Optional[Tuple[str]]
+        List of resource IDs to be synced.
+    resource_types : Optional[Tuple[str]]
+        List of resource types to be synced.
+    """
+    stacks = SamLocalStackProvider.get_stacks(template)[0]
+    factory = SyncFlowFactory(build_context, deploy_context, stacks)
+    factory.load_physical_id_mapping()
+    executor = SyncFlowExecutor()
+
+    sync_flow_resource_ids: Set[ResourceIdentifier] = (
+        get_unique_resource_ids(stacks, resource_ids, resource_types)
+        if resource_ids or resource_types
+        else set(get_all_resource_ids(stacks))
+    )
+
+    for resource_id in sync_flow_resource_ids:
+        sync_flow = factory.create_sync_flow(resource_id)
+        if sync_flow:
+            executor.add_sync_flow(sync_flow)
+        else:
+            LOG.warning("Cannot create SyncFlow for %s. Skipping.", resource_id)
+    executor.execute()
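To make the selection logic above concrete, a short sketch (the template path and names are illustrative): explicit --resource-id/--resource values narrow the set, otherwise every resource id in the stack is selected:

    stacks = SamLocalStackProvider.get_stacks("template.yaml")[0]

    # e.g. `sam sync --code --resource-id HelloWorldFunction`
    selected = get_unique_resource_ids(stacks, ("HelloWorldFunction",), ())

    # e.g. `sam sync --code` with no filters
    everything = set(get_all_resource_ids(stacks))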
+
+
+def execute_watch(
+    template: str,
+    build_context: "BuildContext",
+    package_context: "PackageContext",
+    deploy_context: "DeployContext",
+) -> None:
+    """Starts the sync watch execution
+
+    Parameters
+    ----------
+    template : str
+        Template file path
+    build_context : BuildContext
+        BuildContext
+    package_context : PackageContext
+        PackageContext
+    deploy_context : DeployContext
+        DeployContext
+    """
+    watch_manager = WatchManager(template, build_context, package_context, deploy_context)
+    watch_manager.start()
diff --git a/samcli/commands/traces/__init__.py b/samcli/commands/traces/__init__.py
new file mode 100644
index 0000000000..596c05f79b
--- /dev/null
+++ b/samcli/commands/traces/__init__.py
@@ -0,0 +1,6 @@
+"""
+`sam traces` command
+"""
+
+# Expose the cli object here
+from samcli.commands.traces.command import cli
diff --git a/samcli/commands/traces/command.py b/samcli/commands/traces/command.py
new file mode 100644
index 0000000000..3699365b98
--- /dev/null
+++ b/samcli/commands/traces/command.py
@@ -0,0 +1,76 @@
+"""
+CLI command for "traces" command
+"""
+import logging
+
+import click
+
+from samcli.cli.cli_config_file import configuration_option, TomlProvider
+from samcli.cli.main import pass_context, common_options as cli_framework_options, aws_creds_options, print_cmdline_args
+from samcli.commands._utils.options import common_observability_options
+from samcli.lib.telemetry.metric import track_command
+from samcli.lib.utils.version_checker import check_newer_version
+
+LOG = logging.getLogger(__name__)
+
+HELP_TEXT = """
+Use this command to fetch AWS X-Ray traces generated by your stack.\n
+"""
+
+
+@click.command("traces", help=HELP_TEXT, short_help="Fetch AWS X-Ray traces")
+@configuration_option(provider=TomlProvider(section="parameters"))
+@click.option(
+    "--trace-id",
+    "-ti",
+    multiple=True,
+    help="Fetch a specific trace by providing its id",
+)
+@common_observability_options
+@cli_framework_options
+@aws_creds_options
+@pass_context
+@track_command
+@check_newer_version
+@print_cmdline_args
+def cli(
+    ctx,
+    trace_id,
+    start_time,
+    end_time,
+    tail,
+    unformatted,
+    config_file,
+    config_env,
+):
+    """
+    `sam traces` command entry point
+    """
+    do_cli(trace_id, start_time, end_time, tail, unformatted, ctx.region)
+
+
+def do_cli(trace_ids, start_time, end_time, tailing, unformatted, region):
+    """
+    Implementation of the ``cli`` method
+    """
+    from datetime import datetime
+    import boto3
+    from samcli.commands.logs.logs_context import parse_time
+    from samcli.commands.traces.traces_puller_factory import generate_trace_puller
+    from samcli.lib.utils.boto_utils import get_boto_config_with_user_agent
+
+    sanitized_start_time = parse_time(start_time, "start-time")
+    sanitized_end_time = parse_time(end_time, "end-time") or datetime.utcnow()
+
+    boto_config = get_boto_config_with_user_agent(region_name=region)
+    xray_client = boto3.client("xray", config=boto_config)
+
+    # generate the puller depending on the parameters
+    puller = generate_trace_puller(xray_client, unformatted)
+
+    if trace_ids:
+        puller.load_events(trace_ids)
+    elif tailing:
+        puller.tail(sanitized_start_time)
+    else:
+        puller.load_time_period(sanitized_start_time, sanitized_end_time)
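For illustration, the factory used above can also be driven directly (the trace id is a made-up sample value):

    import boto3
    from samcli.commands.traces.traces_puller_factory import generate_trace_puller

    puller = generate_trace_puller(boto3.client("xray"), unformatted=True)
    puller.load_events(["1-581cf771-a006649127e371903a2de979"])  # fetch specific traces by id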
diff --git a/samcli/commands/traces/trace_console_consumers.py b/samcli/commands/traces/trace_console_consumers.py
new file mode 100644
index 0000000000..9d84383a35
--- /dev/null
+++ b/samcli/commands/traces/trace_console_consumers.py
@@ -0,0 +1,18 @@
+"""
+Contains console consumers for outputting XRay information back to the console/terminal
+"""
+
+import click
+
+from samcli.lib.observability.observability_info_puller import ObservabilityEventConsumer
+from samcli.lib.observability.xray_traces.xray_events import XRayTraceEvent
+
+
+class XRayTraceConsoleConsumer(ObservabilityEventConsumer[XRayTraceEvent]):
+    """
+    An XRayTraceEvent consumer which prints incoming XRayTraceEvents back to the console
+    """
+
+    # pylint: disable=R0201
+    def consume(self, event: XRayTraceEvent):
+        click.echo(event.message)
diff --git a/samcli/commands/traces/traces_puller_factory.py b/samcli/commands/traces/traces_puller_factory.py
new file mode 100644
index 0000000000..7c3f5a4860
--- /dev/null
+++ b/samcli/commands/traces/traces_puller_factory.py
@@ -0,0 +1,112 @@
+"""
+Factory methods which generate puller and consumer instances for XRay events
+"""
+from typing import Any, List
+
+from samcli.commands.traces.trace_console_consumers import XRayTraceConsoleConsumer
+from samcli.lib.observability.observability_info_puller import (
+    ObservabilityPuller,
+    ObservabilityEventConsumer,
+    ObservabilityEventConsumerDecorator,
+    ObservabilityCombinedPuller,
+)
+from samcli.lib.observability.xray_traces.xray_event_mappers import (
+    XRayTraceConsoleMapper,
+    XRayServiceGraphConsoleMapper,
+    XRayServiceGraphJSONMapper,
+    XRayTraceJSONMapper,
+)
+from samcli.lib.observability.xray_traces.xray_event_puller import XRayTracePuller
+from samcli.lib.observability.xray_traces.xray_service_graph_event_puller import XRayServiceGraphPuller
+
+
+def generate_trace_puller(
+    xray_client: Any,
+    unformatted: bool = False,
+) -> ObservabilityPuller:
+    """
+    Generates a puller instance with the correct consumer and/or mapper configuration
+
+    Parameters
+    ----------
+    xray_client : Any
+        boto3 xray client to be used in the XRayTracePuller instance
+    unformatted : bool
+        By default, logs and traces are printed formatted for the terminal. If this option is provided, the events
+        will be printed unformatted in JSON.
+
+    Returns
+    -------
+    Puller instance with the desired configuration
+    """
+    pullers: List[ObservabilityPuller] = []
+    pullers.append(XRayTracePuller(xray_client, generate_xray_event_consumer(unformatted)))
+    pullers.append(XRayServiceGraphPuller(xray_client, generate_xray_service_graph_consumer(unformatted)))
+
+    return ObservabilityCombinedPuller(pullers)
+
+
+def generate_unformatted_xray_event_consumer() -> ObservabilityEventConsumer:
+    """
+    Generates an unformatted consumer, which will print XRay events as unformatted JSON to the terminal
+
+    Returns
+    -------
+    Console consumer instance with the desired mapper configuration
+    """
+    return ObservabilityEventConsumerDecorator([XRayTraceJSONMapper()], XRayTraceConsoleConsumer())
+
+
+def generate_xray_event_console_consumer() -> ObservabilityEventConsumer:
+    """
+    Generates an instance of event consumer which will print events into the console
+
+    Returns
+    -------
+    Console consumer instance with the desired mapper configuration
+    """
+    return ObservabilityEventConsumerDecorator([XRayTraceConsoleMapper()], XRayTraceConsoleConsumer())
+
+
+def generate_xray_event_consumer(unformatted: bool = False) -> ObservabilityEventConsumer:
+    """
+    Generates consumer instance with the given variables.
+    If unformatted is True, then it will return consumer with formatters for just JSON.
+    If not, it will return console consumer
+    """
+    if unformatted:
+        return generate_unformatted_xray_event_consumer()
+    return generate_xray_event_console_consumer()
+
+
+def generate_unformatted_xray_service_graph_consumer() -> ObservabilityEventConsumer:
+    """
+    Generates an unformatted consumer, which will print XRay service graph events as unformatted JSON to the terminal
+
+    Returns
+    -------
+    Console consumer instance with the desired mapper configuration
+    """
+    return ObservabilityEventConsumerDecorator([XRayServiceGraphJSONMapper()], XRayTraceConsoleConsumer())
+
+
+def generate_xray_service_graph_console_consumer() -> ObservabilityEventConsumer:
+    """
+    Generates an instance of event consumer which will print events into the console
+
+    Returns
+    -------
+    Console consumer instance with the desired mapper configuration
+    """
+    return ObservabilityEventConsumerDecorator([XRayServiceGraphConsoleMapper()], XRayTraceConsoleConsumer())
+
+
+def generate_xray_service_graph_consumer(unformatted: bool = False) -> ObservabilityEventConsumer:
+    """
+    Generates consumer instance with the given variables.
+    If unformatted is True, then it will return consumer with formatters for just JSON.
+ If not, it will return console consumer + """ + if unformatted: + return generate_unformatted_xray_service_graph_consumer() + return generate_xray_service_graph_console_consumer() diff --git a/samcli/commands/validate/lib/sam_template_validator.py b/samcli/commands/validate/lib/sam_template_validator.py index d9de756674..ca27ac8c56 100644 --- a/samcli/commands/validate/lib/sam_template_validator.py +++ b/samcli/commands/validate/lib/sam_template_validator.py @@ -10,7 +10,7 @@ from boto3.session import Session from samcli.lib.utils.packagetype import ZIP, IMAGE -from samcli.commands._utils.resources import AWS_SERVERLESS_FUNCTION +from samcli.lib.utils.resources import AWS_SERVERLESS_FUNCTION from samcli.yamlhelper import yaml_dump from .exceptions import InvalidSamDocumentException diff --git a/samcli/lib/build/app_builder.py b/samcli/lib/build/app_builder.py index 9823db0b4e..80c5f8eeea 100644 --- a/samcli/lib/build/app_builder.py +++ b/samcli/lib/build/app_builder.py @@ -10,20 +10,30 @@ import docker import docker.errors -from aws_lambda_builders import RPC_PROTOCOL_VERSION as lambda_builders_protocol_version +from aws_lambda_builders import ( + RPC_PROTOCOL_VERSION as lambda_builders_protocol_version, + __version__ as lambda_builders_version, +) from aws_lambda_builders.builder import LambdaBuilder from aws_lambda_builders.exceptions import LambdaBuilderError from samcli.commands.local.lib.exceptions import OverridesNotWellDefinedError from samcli.lib.build.build_graph import FunctionBuildDefinition, LayerBuildDefinition, BuildGraph from samcli.lib.build.build_strategy import ( DefaultBuildStrategy, - CachedBuildStrategy, + CachedOrIncrementalBuildStrategyWrapper, ParallelBuildStrategy, BuildStrategy, ) +from samcli.lib.utils.resources import ( + AWS_CLOUDFORMATION_STACK, + AWS_LAMBDA_FUNCTION, + AWS_LAMBDA_LAYERVERSION, + AWS_SERVERLESS_APPLICATION, + AWS_SERVERLESS_FUNCTION, + AWS_SERVERLESS_LAYERVERSION, +) from samcli.lib.docker.log_streamer import LogStreamer, LogStreamError from samcli.lib.providers.provider import ResourcesToBuildCollector, Function, get_full_path, Stack, LayerVersion -from samcli.lib.providers.sam_base_provider import SamBaseProvider from samcli.lib.utils.colors import Colored from samcli.lib.utils import osutils from samcli.lib.utils.packagetype import IMAGE, ZIP @@ -146,24 +156,26 @@ def build(self) -> Dict[str, str]: if self._cached: build_strategy = ParallelBuildStrategy( build_graph, - CachedBuildStrategy( + CachedOrIncrementalBuildStrategyWrapper( build_graph, build_strategy, self._base_dir, self._build_dir, self._cache_dir, + self._manifest_path_override, self._is_building_specific_resource, ), ) else: build_strategy = ParallelBuildStrategy(build_graph, build_strategy) elif self._cached: - build_strategy = CachedBuildStrategy( + build_strategy = CachedOrIncrementalBuildStrategyWrapper( build_graph, build_strategy, self._base_dir, self._build_dir, self._cache_dir, + self._manifest_path_override, self._is_building_specific_resource, ) @@ -266,26 +278,26 @@ def update_template( store_path = os.path.relpath(absolute_output_path, original_dir) if has_build_artifact: - if resource_type == SamBaseProvider.SERVERLESS_FUNCTION and properties.get("PackageType", ZIP) == ZIP: + if resource_type == AWS_SERVERLESS_FUNCTION and properties.get("PackageType", ZIP) == ZIP: properties["CodeUri"] = store_path - if resource_type == SamBaseProvider.LAMBDA_FUNCTION and properties.get("PackageType", ZIP) == ZIP: + if resource_type == AWS_LAMBDA_FUNCTION and 
properties.get("PackageType", ZIP) == ZIP:
                 properties["Code"] = store_path
 
-            if resource_type in [SamBaseProvider.SERVERLESS_LAYER, SamBaseProvider.LAMBDA_LAYER]:
+            if resource_type in [AWS_SERVERLESS_LAYERVERSION, AWS_LAMBDA_LAYERVERSION]:
                 properties["ContentUri"] = store_path
 
-            if resource_type == SamBaseProvider.LAMBDA_FUNCTION and properties.get("PackageType", ZIP) == IMAGE:
+            if resource_type == AWS_LAMBDA_FUNCTION and properties.get("PackageType", ZIP) == IMAGE:
                 properties["Code"] = built_artifacts[full_path]
 
-            if resource_type == SamBaseProvider.SERVERLESS_FUNCTION and properties.get("PackageType", ZIP) == IMAGE:
+            if resource_type == AWS_SERVERLESS_FUNCTION and properties.get("PackageType", ZIP) == IMAGE:
                 properties["ImageUri"] = built_artifacts[full_path]
 
         if is_stack:
-            if resource_type == SamBaseProvider.SERVERLESS_APPLICATION:
+            if resource_type == AWS_SERVERLESS_APPLICATION:
                 properties["Location"] = store_path
 
-            if resource_type == SamBaseProvider.CLOUDFORMATION_STACK:
+            if resource_type == AWS_CLOUDFORMATION_STACK:
                 properties["TemplateURL"] = store_path
 
     return template_dict
@@ -381,6 +393,8 @@ def _build_layer(
         compatible_runtimes: List[str],
         artifact_dir: str,
         container_env_vars: Optional[Dict] = None,
+        dependencies_dir: Optional[str] = None,
+        download_dependencies: bool = True,
     ) -> str:
         """
         Given the layer information, this method will build the Lambda layer. Depending on the configuration
@@ -390,22 +404,23 @@ def _build_layer(
         ----------
         layer_name : str
             Name or LogicalId of the function
-
         codeuri : str
            Path to where the code lives
-
         specified_workflow : str
             The specified workflow
-
         compatible_runtimes : List[str]
             List of runtimes the layer build is compatible with
-
         artifact_dir : str
             Path to where layer will be build into.
             A subfolder will be created in this directory depending on the specified workflow.
-
         container_env_vars : Optional[Dict]
             An optional dictionary of environment variables to pass to the container.
+        dependencies_dir: Optional[str]
+            An optional directory path which Lambda Builders will use to download dependencies into a
+            separate folder
+        download_dependencies: bool
+            An optional boolean to inform Lambda Builders whether to download dependencies or use previously
+            downloaded ones. Default value is True.
 
         Returns
         -------
@@ -444,7 +459,15 @@ def _build_layer(
             )
         else:
             self._build_function_in_process(
-                config, code_dir, artifact_subdir, scratch_dir, manifest_path, build_runtime, options
+                config,
+                code_dir,
+                artifact_subdir,
+                scratch_dir,
+                manifest_path,
+                build_runtime,
+                options,
+                dependencies_dir,
+                download_dependencies,
             )
 
         # Not including subfolder in return so that we copy subfolder, instead of copying artifacts inside it.
@@ -460,6 +483,8 @@ def _build_function(  # pylint: disable=R1710
         artifact_dir: str,
         metadata: Optional[Dict] = None,
         container_env_vars: Optional[Dict] = None,
+        dependencies_dir: Optional[str] = None,
+        download_dependencies: bool = True,
     ) -> str:
         """
         Given the function information, this method will build the Lambda function. Depending on the configuration
@@ -483,6 +508,12 @@ def _build_function(  # pylint: disable=R1710
             AWS Lambda function metadata
         container_env_vars : Optional[Dict]
             An optional dictionary of environment variables to pass to the container.
+        dependencies_dir: Optional[str]
+            An optional directory path which Lambda Builders will use to download dependencies into a
+            separate folder
+        download_dependencies: bool
+            An optional boolean to inform Lambda Builders whether to download dependencies or use previously
+            downloaded ones. Default value is True.
 
         Returns
         -------
@@ -534,7 +565,15 @@ def _build_function(  # pylint: disable=R1710
             )
 
             return self._build_function_in_process(
-                config, code_dir, artifact_dir, scratch_dir, manifest_path, runtime, options
+                config,
+                code_dir,
+                artifact_dir,
+                scratch_dir,
+                manifest_path,
+                runtime,
+                options,
+                dependencies_dir,
+                download_dependencies,
             )
 
     # pylint: disable=fixme
@@ -573,6 +612,8 @@ def _build_function_in_process(
         manifest_path: str,
         runtime: str,
         options: Optional[Dict],
+        dependencies_dir: Optional[str],
+        download_dependencies: bool,
     ) -> str:
 
         builder = LambdaBuilder(
@@ -583,17 +624,19 @@ def _build_function_in_process(
 
         runtime = runtime.replace(".al2", "")
 
+        kwargs = {
+            "runtime": runtime,
+            "executable_search_paths": config.executable_search_paths,
+            "mode": self._mode,
+            "options": options,
+        }
+        # todo: remove this check once the lambda builder release is finished
+        if lambda_builders_version == "1.8.0":
+            kwargs["dependencies_dir"] = dependencies_dir
+            kwargs["download_dependencies"] = download_dependencies
+
         try:
-            builder.build(
-                source_dir,
-                artifacts_dir,
-                scratch_dir,
-                manifest_path,
-                runtime=runtime,
-                executable_search_paths=config.executable_search_paths,
-                mode=self._mode,
-                options=options,
-            )
+            builder.build(source_dir, artifacts_dir, scratch_dir, manifest_path, **kwargs)
         except LambdaBuilderError as ex:
             raise BuildError(wrapped_from=ex.__class__.__name__, msg=str(ex)) from ex
diff --git a/samcli/lib/build/build_graph.py b/samcli/lib/build/build_graph.py
index f2d328be02..f3b1b1a837 100644
--- a/samcli/lib/build/build_graph.py
+++ b/samcli/lib/build/build_graph.py
@@ -2,10 +2,13 @@
 Holds classes and utility methods related to build graph
 """
 
+import copy
 import logging
-from copy import deepcopy
+import os
+import threading
 from pathlib import Path
-from typing import Tuple, List, Any, Optional, Dict, cast
+from typing import Sequence, Tuple, List, Any, Optional, Dict, cast, NamedTuple
+from copy import deepcopy
 from uuid import uuid4
 
 import tomlkit
@@ -18,6 +21,8 @@
 
 DEFAULT_BUILD_GRAPH_FILE_NAME = "build.toml"
 
+DEFAULT_DEPENDENCIES_DIR = os.path.join(".aws-sam", "deps")
+
 # filed names for the toml table
 PACKAGETYPE_FIELD = "packagetype"
 CODE_URI_FIELD = "codeuri"
@@ -25,6 +30,7 @@
 METADATA_FIELD = "metadata"
 FUNCTIONS_FIELD = "functions"
 SOURCE_MD5_FIELD = "source_md5"
+MANIFEST_MD5_FIELD = "manifest_md5"
 ENV_VARS_FIELD = "env_vars"
 LAYER_NAME_FIELD = "layer_name"
 BUILD_METHOD_FIELD = "build_method"
@@ -52,7 +58,9 @@ def _function_build_definition_to_toml_table(
     if function_build_definition.packagetype == ZIP:
         toml_table[CODE_URI_FIELD] = function_build_definition.codeuri
         toml_table[RUNTIME_FIELD] = function_build_definition.runtime
-    toml_table[SOURCE_MD5_FIELD] = function_build_definition.source_md5
+    if function_build_definition.source_md5:
+        toml_table[SOURCE_MD5_FIELD] = function_build_definition.source_md5
+    toml_table[MANIFEST_MD5_FIELD] = function_build_definition.manifest_md5
     toml_table[PACKAGETYPE_FIELD] = function_build_definition.packagetype
     toml_table[FUNCTIONS_FIELD] = [f.full_path for f in function_build_definition.functions]
 
@@ -86,6 +94,7 @@ def _toml_table_to_function_build_definition(uuid: str, toml_table: tomlkit.api.
         toml_table.get(PACKAGETYPE_FIELD, ZIP),
         dict(toml_table.get(METADATA_FIELD, {})),
         toml_table.get(SOURCE_MD5_FIELD, ""),
+        toml_table.get(MANIFEST_MD5_FIELD, ""),
         dict(toml_table.get(ENV_VARS_FIELD, {})),
     )
     function_build_definition.uuid = uuid
@@ -111,8 +120,9 @@ def _layer_build_definition_to_toml_table(layer_build_definition: "LayerBuildDef
     toml_table[CODE_URI_FIELD] = layer_build_definition.codeuri
     toml_table[BUILD_METHOD_FIELD] = layer_build_definition.build_method
     toml_table[COMPATIBLE_RUNTIMES_FIELD] = layer_build_definition.compatible_runtimes
-    toml_table[SOURCE_MD5_FIELD] = layer_build_definition.source_md5
-    toml_table[LAYER_FIELD] = layer_build_definition.layer.name
+    if layer_build_definition.source_md5:
+        toml_table[SOURCE_MD5_FIELD] = layer_build_definition.source_md5
+    toml_table[MANIFEST_MD5_FIELD] = layer_build_definition.manifest_md5
     if layer_build_definition.env_vars:
         toml_table[ENV_VARS_FIELD] = layer_build_definition.env_vars
     toml_table[LAYER_FIELD] = layer_build_definition.layer.full_path
@@ -142,17 +152,30 @@ def _toml_table_to_layer_build_definition(uuid: str, toml_table: tomlkit.api.Tab
         toml_table.get(BUILD_METHOD_FIELD),
         toml_table.get(COMPATIBLE_RUNTIMES_FIELD),
         toml_table.get(SOURCE_MD5_FIELD, ""),
+        toml_table.get(MANIFEST_MD5_FIELD, ""),
         dict(toml_table.get(ENV_VARS_FIELD, {})),
     )
     layer_build_definition.uuid = uuid
     return layer_build_definition
 
 
+class BuildHashingInformation(NamedTuple):
+    """
+    Holds hashing information for the source folder and the manifest file
+    """
+
+    source_md5: str
+    manifest_md5: str
+
+
 class BuildGraph:
     """
     Contains list of build definitions, with ability to read and write them into build.toml file
     """
 
+    # private lock for build.toml reads and writes
+    __toml_lock = threading.Lock()
+
     # global table build definitions key
     FUNCTION_BUILD_DEFINITIONS = "function_build_definitions"
     LAYER_BUILD_DEFINITIONS = "layer_build_definitions"
 
@@ -162,7 +185,7 @@ def __init__(self, build_dir: str) -> None:
         self._filepath = Path(build_dir).parent.joinpath(DEFAULT_BUILD_GRAPH_FILE_NAME)
         self._function_build_definitions: List["FunctionBuildDefinition"] = []
         self._layer_build_definitions: List["LayerBuildDefinition"] = []
-        self._read()
+        self._atomic_read()
 
     def get_function_build_definitions(self) -> Tuple["FunctionBuildDefinition", ...]:
         return tuple(self._function_build_definitions)
 
@@ -254,7 +277,81 @@ def clean_redundant_definitions_and_update(self, persist: bool) -> None:
         ]
         self._layer_build_definitions[:] = [bd for bd in self._layer_build_definitions if bd.layer]
         if persist:
-            self._write()
+            self._atomic_write()
+
+    def update_definition_md5(self) -> None:
+        """
+        Updates the build.toml file with the newest source_md5 values of the partial build's definitions
+
+        This operation is atomic: no other thread can access build.toml
+        while its md5 values are being read and modified
+        """
+        with BuildGraph.__toml_lock:
+            stored_definitions = copy.deepcopy(self._function_build_definitions)
+            stored_layers = copy.deepcopy(self._layer_build_definitions)
+            self._read()
+
+            function_content = BuildGraph._compare_md5_changes(stored_definitions, self._function_build_definitions)
+            layer_content = BuildGraph._compare_md5_changes(stored_layers, self._layer_build_definitions)
+
+            if function_content or layer_content:
+                self._write_source_md5(function_content, layer_content)
+
+    @staticmethod
+    def _compare_md5_changes(
+        input_list: Sequence["AbstractBuildDefinition"], compared_list: Sequence["AbstractBuildDefinition"]
+    ) -> Dict[str, BuildHashingInformation]:
+        """
+        Helper to compare function and layer definitions for changes in their md5 values
+
+        Returns a dictionary keyed by uuid, whose values hold the updated hashing information
+        """
+        content = {}
+        for compared_def in compared_list:
+            for stored_def in input_list:
+                if stored_def == compared_def:
+                    old_md5 = compared_def.source_md5
+                    updated_md5 = stored_def.source_md5
+                    old_manifest_md5 = compared_def.manifest_md5
+                    updated_manifest_md5 = stored_def.manifest_md5
+                    uuid = stored_def.uuid
+                    if old_md5 != updated_md5 or old_manifest_md5 != updated_manifest_md5:
+                        content[uuid] = BuildHashingInformation(updated_md5, updated_manifest_md5)
+        return content
+
+    def _write_source_md5(
+        self, function_content: Dict[str, BuildHashingInformation], layer_content: Dict[str, BuildHashingInformation]
+    ) -> None:
+        """
+        Helper to write updated source_md5 and manifest_md5 values to the build.toml file
+        """
+        document = {}
+        if not self._filepath.exists():
+            open(self._filepath, "a+").close()
+
+        txt = self._filepath.read_text()
+        # .loads() returns a TOMLDocument,
+        # and it behaves like a standard dictionary according to https://github.com/sdispater/tomlkit.
+        # in tomlkit 0.7.2, the types are broken (tomlkit#128, #130, #134) so here we convert it to Dict.
+        document = cast(Dict, tomlkit.loads(txt))
+
+        for function_uuid, hashing_info in function_content.items():
+            if function_uuid in document.get(BuildGraph.FUNCTION_BUILD_DEFINITIONS, {}):
+                function_build_definition = document[BuildGraph.FUNCTION_BUILD_DEFINITIONS][function_uuid]
+                function_build_definition[SOURCE_MD5_FIELD] = hashing_info.source_md5
+                function_build_definition[MANIFEST_MD5_FIELD] = hashing_info.manifest_md5
+                LOG.info(
+                    "Updated source_md5 and manifest_md5 field in build.toml for function with UUID %s", function_uuid
+                )
+
+        for layer_uuid, hashing_info in layer_content.items():
+            if layer_uuid in document.get(BuildGraph.LAYER_BUILD_DEFINITIONS, {}):
+                layer_build_definition = document[BuildGraph.LAYER_BUILD_DEFINITIONS][layer_uuid]
+                layer_build_definition[SOURCE_MD5_FIELD] = hashing_info.source_md5
+                layer_build_definition[MANIFEST_MD5_FIELD] = hashing_info.manifest_md5
+                LOG.info("Updated source_md5 and manifest_md5 field in build.toml for layer with UUID %s", layer_uuid)
+
+        self._filepath.write_text(tomlkit.dumps(document))  # type: ignore
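As a sketch of the persisted shape (the uuid and hash values below are made up), the tables written above can be read back with tomlkit like a plain mapping:

    import tomlkit

    doc = tomlkit.loads(
        '[function_build_definitions.325a4e2c]\n'
        'source_md5 = "2c26b46b68ffc68ff99b453c1d304134"\n'
        'manifest_md5 = "9d5ed678fe57bcca610140957afab571"\n'
    )
    print(doc["function_build_definitions"]["325a4e2c"]["manifest_md5"])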
 
     def _read(self) -> None:
         """
@@ -273,20 +370,29 @@ def _read(self) -> None:
             document = cast(Dict, tomlkit.loads(txt))
         except OSError:
             LOG.debug("No previous build graph found, generating new one")
-        function_build_definitions_table = document.get(BuildGraph.FUNCTION_BUILD_DEFINITIONS, [])
+        function_build_definitions_table = document.get(BuildGraph.FUNCTION_BUILD_DEFINITIONS, {})
         for function_build_definition_key in function_build_definitions_table:
             function_build_definition = _toml_table_to_function_build_definition(
                 function_build_definition_key, function_build_definitions_table[function_build_definition_key]
             )
             self._function_build_definitions.append(function_build_definition)
 
-        layer_build_definitions_table = document.get(BuildGraph.LAYER_BUILD_DEFINITIONS, [])
+        layer_build_definitions_table = document.get(BuildGraph.LAYER_BUILD_DEFINITIONS, {})
         for layer_build_definition_key in layer_build_definitions_table:
             layer_build_definition = _toml_table_to_layer_build_definition(
                 layer_build_definition_key, layer_build_definitions_table[layer_build_definition_key]
             )
             self._layer_build_definitions.append(layer_build_definition)
 
+    def _atomic_read(self) -> None:
+        """
+        Performs the _read() method with a global lock acquired.
+        It makes sure no other thread accesses build.toml while a read is happening
+        """
+
+        with BuildGraph.__toml_lock:
+            self._read()
+
     def _write(self) -> None:
         """
         Writes build definition details into build.toml file, which would be used by the next build.
@@ -317,6 +423,15 @@ def _write(self) -> None:
 
         self._filepath.write_text(tomlkit.dumps(document))
 
+    def _atomic_write(self) -> None:
+        """
+        Performs the _write() method with a global lock acquired.
+        It makes sure no other thread accesses build.toml while a write is happening
+        """
+
+        with BuildGraph.__toml_lock:
+            self._write()
+
 
 class AbstractBuildDefinition:
     """
@@ -324,11 +439,18 @@ class AbstractBuildDefinition:
     Build definition holds information about each unique build
     """
 
-    def __init__(self, source_md5: str, env_vars: Optional[Dict] = None) -> None:
+    def __init__(self, source_md5: str, manifest_md5: str, env_vars: Optional[Dict] = None) -> None:
         self.uuid = str(uuid4())
         self.source_md5 = source_md5
+        self.manifest_md5 = manifest_md5
+        # the following properties are used during build time and are not serialized into the build.toml file
+        self.download_dependencies: bool = True
         self._env_vars = env_vars if env_vars else {}
 
+    @property
+    def dependencies_dir(self) -> str:
+        return str(os.path.join(DEFAULT_DEPENDENCIES_DIR, self.uuid))
+
     @property
     def env_vars(self) -> Dict:
         return deepcopy(self._env_vars)
 
@@ -346,9 +468,10 @@ def __init__(
         build_method: Optional[str],
         compatible_runtimes: Optional[List[str]],
         source_md5: str = "",
+        manifest_md5: str = "",
         env_vars: Optional[Dict] = None,
     ):
-        super().__init__(source_md5, env_vars)
+        super().__init__(source_md5, manifest_md5, env_vars)
         self.name = name
         self.codeuri = codeuri
         self.build_method = build_method
@@ -401,9 +524,10 @@ def __init__(
         packagetype: str,
         metadata: Optional[Dict],
         source_md5: str = "",
+        manifest_md5: str = "",
         env_vars: Optional[Dict] = None,
     ) -> None:
-        super().__init__(source_md5, env_vars)
+        super().__init__(source_md5, manifest_md5, env_vars)
         self.runtime = runtime
         self.codeuri = codeuri
         self.packagetype = packagetype
diff --git a/samcli/lib/build/build_strategy.py b/samcli/lib/build/build_strategy.py
index ecded3a743..dc6ac9f24b 100644
--- a/samcli/lib/build/build_strategy.py
+++ b/samcli/lib/build/build_strategy.py
@@ -5,18 +5,47 @@
 import pathlib
 import shutil
 from abc import abstractmethod, ABC
-from typing import Callable, Dict, List, Any, Optional, cast
+from copy import deepcopy
+from typing import Callable, Dict, List, Any, Optional, cast, Set
 
-from samcli.commands.build.exceptions import MissingBuildMethodException
 from samcli.lib.utils import osutils
 from samcli.lib.utils.async_utils import AsyncContext
 from samcli.lib.utils.hash import dir_checksum
 from samcli.lib.utils.packagetype import ZIP, IMAGE
-from samcli.lib.build.build_graph import BuildGraph, FunctionBuildDefinition, LayerBuildDefinition
+from samcli.lib.build.dependency_hash_generator import DependencyHashGenerator
+from samcli.lib.build.build_graph import (
+    BuildGraph,
+    FunctionBuildDefinition,
+    LayerBuildDefinition,
+    AbstractBuildDefinition,
+    DEFAULT_DEPENDENCIES_DIR,
+)
+from samcli.lib.build.exceptions import MissingBuildMethodException
 
 LOG = logging.getLogger(__name__)
 
 
+def clean_redundant_folders(base_dir: str, uuids: Set[str]) -> None:
+    """
+    Compares the existing folders inside base_dir and removes the ones which are not in the uuids set.
+
+    Parameters
+    ----------
+    base_dir : str
+        Base directory to operate on
+    uuids : Set[str]
+        Expected folder names.
If any folder name in the base_dir is not present in this Set, it will be deleted. + """ + base_dir_path = pathlib.Path(base_dir) + + if not base_dir_path.exists(): + return + + for full_dir_path in pathlib.Path(base_dir).iterdir(): + if full_dir_path.name not in uuids: + shutil.rmtree(pathlib.Path(base_dir, full_dir_path.name)) + + class BuildStrategy(ABC): """ Base class for BuildStrategy @@ -87,8 +116,8 @@ def __init__( self, build_graph: BuildGraph, build_dir: str, - build_function: Callable[[str, str, str, str, Optional[str], str, dict, dict], str], - build_layer: Callable[[str, str, str, List[str], str, dict], str], + build_function: Callable[[str, str, str, str, Optional[str], str, dict, dict, Optional[str], bool], str], + build_layer: Callable[[str, str, str, List[str], str, dict, Optional[str], bool], str], ) -> None: super().__init__(build_graph) self._build_dir = build_dir @@ -114,6 +143,10 @@ def build_single_function_definition(self, build_definition: FunctionBuildDefini LOG.debug("Building to following folder %s", single_build_dir) + # we should create a copy and pass it down, otherwise additional env vars like LAMBDA_BUILDERS_LOG_LEVEL + # will make cache invalid all the time + container_env_vars = deepcopy(build_definition.env_vars) + # when a function is passed here, it is ZIP function, codeuri and runtime are not None result = self._build_function( build_definition.get_function_name(), @@ -123,7 +156,9 @@ def build_single_function_definition(self, build_definition: FunctionBuildDefini build_definition.get_handler_name(), single_build_dir, build_definition.metadata, - build_definition.env_vars, + container_env_vars, + build_definition.dependencies_dir, + build_definition.download_dependencies, ) function_build_results[single_full_path] = result @@ -168,6 +203,8 @@ def build_single_layer_definition(self, layer_definition: LayerBuildDefinition) layer.compatible_runtimes, # type: ignore single_build_dir, layer_definition.env_vars, + layer_definition.dependencies_dir, + layer_definition.download_dependencies, ) } @@ -188,21 +225,16 @@ def __init__( base_dir: str, build_dir: str, cache_dir: str, - is_building_specific_resource: bool, ) -> None: super().__init__(build_graph) self._delegate_build_strategy = delegate_build_strategy self._base_dir = base_dir self._build_dir = build_dir self._cache_dir = cache_dir - self._is_building_specific_resource = is_building_specific_resource - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - self._clean_redundant_cached() def build(self) -> Dict[str, str]: result = {} - with self, self._delegate_build_strategy: + with self._delegate_build_strategy: result.update(super().build()) return result @@ -290,12 +322,9 @@ def _clean_redundant_cached(self) -> None: """ clean the redundant cached folder """ - self._build_graph.clean_redundant_definitions_and_update(not self._is_building_specific_resource) uuids = {bd.uuid for bd in self._build_graph.get_function_build_definitions()} uuids.update({ld.uuid for ld in self._build_graph.get_layer_build_definitions()}) - for cache_dir in pathlib.Path(self._cache_dir).iterdir(): - if cache_dir.name not in uuids: - shutil.rmtree(pathlib.Path(self._cache_dir, cache_dir.name)) + clean_redundant_folders(self._cache_dir, uuids) class ParallelBuildStrategy(BuildStrategy): @@ -309,18 +338,18 @@ def __init__( self, build_graph: BuildGraph, delegate_build_strategy: BuildStrategy, - async_context: AsyncContext = AsyncContext(), + async_context: Optional[AsyncContext] = None, ) -> None: 
super().__init__(build_graph) self._delegate_build_strategy = delegate_build_strategy - self._async_context = async_context + self._async_context = async_context if async_context else AsyncContext() def build(self) -> Dict[str, str]: """ Runs all builds and collects results from the async context """ result = {} - with self, self._delegate_build_strategy: + with self._delegate_build_strategy: # ignore result super().build() # wait for other executions to complete @@ -348,3 +377,181 @@ def build_single_layer_definition(self, layer_definition: LayerBuildDefinition) self._delegate_build_strategy.build_single_layer_definition, layer_definition ) return {} + + +class IncrementalBuildStrategy(BuildStrategy): + """ + Incremental build is supported for certain runtimes in aws-lambda-builders, with the dependencies_dir (str) + and download_dependencies (bool) options. + + This build strategy decides whether dependencies need to be downloaded again (the download_dependencies option) + by comparing the md5 of the given runtime's manifest file with the previously stored one, and passes the + dependencies directory location (the dependencies_dir option) down to the builder. + """ + + def __init__( + self, + build_graph: BuildGraph, + delegate_build_strategy: BuildStrategy, + base_dir: str, + manifest_path_override: Optional[str], + ): + super().__init__(build_graph) + self._delegate_build_strategy = delegate_build_strategy + self._base_dir = base_dir + self._manifest_path_override = manifest_path_override + + def build(self) -> Dict[str, str]: + result = {} + with self, self._delegate_build_strategy: + result.update(super().build()) + return result + + def build_single_function_definition(self, build_definition: FunctionBuildDefinition) -> Dict[str, str]: + self._check_whether_manifest_is_changed(build_definition, build_definition.codeuri, build_definition.runtime) + return self._delegate_build_strategy.build_single_function_definition(build_definition) + + def build_single_layer_definition(self, layer_definition: LayerBuildDefinition) -> Dict[str, str]: + self._check_whether_manifest_is_changed( + layer_definition, layer_definition.codeuri, layer_definition.build_method + ) + return self._delegate_build_strategy.build_single_layer_definition(layer_definition) + + def _check_whether_manifest_is_changed( + self, + build_definition: AbstractBuildDefinition, + codeuri: Optional[str], + runtime: Optional[str], + ) -> None: + """ + Checks whether the manifest file has been changed by comparing its md5 with the previously stored one, and + sets the download_dependencies property of the build definition to True if it has changed + """ + manifest_hash = DependencyHashGenerator( + cast(str, codeuri), self._base_dir, cast(str, runtime), self._manifest_path_override + ).hash + + is_manifest_changed = True + if manifest_hash: + is_manifest_changed = manifest_hash != build_definition.manifest_md5 + if is_manifest_changed: + build_definition.manifest_md5 = manifest_hash + LOG.info( + "Manifest is changed for %s, downloading dependencies and copying/building source", + build_definition.uuid, + ) + else: + LOG.info("Manifest is not changed for %s, running incremental build", build_definition.uuid) + + build_definition.download_dependencies = is_manifest_changed + + def _clean_redundant_dependencies(self) -> None: + """ + Update build definitions with possible new manifest md5 information and clean the redundant dependencies folders + """ + uuids = {bd.uuid for bd in self._build_graph.get_function_build_definitions()} + uuids.update({ld.uuid for ld in
self._build_graph.get_layer_build_definitions()}) + clean_redundant_folders(DEFAULT_DEPENDENCIES_DIR, uuids) + + +class CachedOrIncrementalBuildStrategyWrapper(BuildStrategy): + """ + A wrapper class which holds instances of CachedBuildStrategy and IncrementalBuildStrategy + and selects one of them during a function or layer build, depending on the runtime that is used + """ + + SUPPORTED_RUNTIME_PREFIXES: Set[str] = { + "python", + "ruby", + "nodejs", + } + + def __init__( + self, + build_graph: BuildGraph, + delegate_build_strategy: BuildStrategy, + base_dir: str, + build_dir: str, + cache_dir: str, + manifest_path_override: Optional[str], + is_building_specific_resource: bool, + ): + super().__init__(build_graph) + self._incremental_build_strategy = IncrementalBuildStrategy( + build_graph, + delegate_build_strategy, + base_dir, + manifest_path_override, + ) + self._cached_build_strategy = CachedBuildStrategy( + build_graph, + delegate_build_strategy, + base_dir, + build_dir, + cache_dir, + ) + self._is_building_specific_resource = is_building_specific_resource + + def build(self) -> Dict[str, str]: + result = {} + with self._cached_build_strategy, self._incremental_build_strategy: + result.update(super().build()) + return result + + def build_single_function_definition(self, build_definition: FunctionBuildDefinition) -> Dict[str, str]: + if self._is_incremental_build_supported(build_definition.runtime): + LOG.debug( + "Running incremental build for runtime %s for build definition %s", + build_definition.runtime, + build_definition.uuid, + ) + return self._incremental_build_strategy.build_single_function_definition(build_definition) + + LOG.debug( + "Running cached build for runtime %s for build definition %s", + build_definition.runtime, + build_definition.uuid, + ) + return self._cached_build_strategy.build_single_function_definition(build_definition) + + def build_single_layer_definition(self, layer_definition: LayerBuildDefinition) -> Dict[str, str]: + if self._is_incremental_build_supported(layer_definition.build_method): + LOG.debug( + "Running incremental build for runtime %s for build definition %s", + layer_definition.build_method, + layer_definition.uuid, + ) + return self._incremental_build_strategy.build_single_layer_definition(layer_definition) + + LOG.debug( + "Running cached build for runtime %s for build definition %s", + layer_definition.build_method, + layer_definition.uuid, + ) + return self._cached_build_strategy.build_single_layer_definition(layer_definition) + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """ + After the build is complete, this method cleans up redundant folders in the cache directory as well as in the + dependencies directory. It also updates the hashes of the functions and layers, if only a single function or + layer is being built. + + If SAM CLI switches to using only IncrementalBuildStrategy, the contents of this method should be moved inside + IncrementalBuildStrategy so that it continues to clean up redundant folders.
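A hedged usage sketch (assuming, as elsewhere in SAM CLI, that the base
BuildStrategy.build() enters ``with self`` and therefore triggers this method
when the build finishes; argument names are illustrative):

    strategy = CachedOrIncrementalBuildStrategyWrapper(
        build_graph, delegate_strategy, base_dir, build_dir, cache_dir, None, False
    )
    strategy.build()  # on exit, stale cache and dependencies folders are pruned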
+ """ + if self._is_building_specific_resource: + self._build_graph.update_definition_md5() + else: + self._build_graph.clean_redundant_definitions_and_update(not self._is_building_specific_resource) + self._cached_build_strategy._clean_redundant_cached() + self._incremental_build_strategy._clean_redundant_dependencies() + + @staticmethod + def _is_incremental_build_supported(runtime: Optional[str]) -> bool: + if not runtime: + return False + + for supported_runtime_prefix in CachedOrIncrementalBuildStrategyWrapper.SUPPORTED_RUNTIME_PREFIXES: + if runtime.startswith(supported_runtime_prefix): + return True + + return False diff --git a/samcli/lib/build/dependency_hash_generator.py b/samcli/lib/build/dependency_hash_generator.py new file mode 100644 index 0000000000..d92b1460a8 --- /dev/null +++ b/samcli/lib/build/dependency_hash_generator.py @@ -0,0 +1,86 @@ +"""Utility Class for Getting Function or Layer Manifest Dependency Hashes""" +import pathlib + +from typing import Any, Optional + +from samcli.lib.build.workflow_config import get_workflow_config +from samcli.lib.utils.hash import file_checksum + +# TODO Expand this class to hash specific sections of the manifest +class DependencyHashGenerator: + _code_uri: str + _base_dir: str + _code_dir: str + _runtime: str + _manifest_path_override: Optional[str] + _hash_generator: Any + _calculated: bool + _hash: Optional[str] + + def __init__( + self, + code_uri: str, + base_dir: str, + runtime: str, + manifest_path_override: Optional[str] = None, + hash_generator: Any = None, + ): + """ + Parameters + ---------- + code_uri : str + Relative path specified in the function/layer resource + base_dir : str + Absolute path which the function/layer dir is located + runtime : str + Runtime of the function/layer + manifest_path_override : Optional[str], optional + Override default manifest path for each runtime, by default None + hash_generator : Any, optional + Hash generation function. Can be hashlib.md5(), hashlib.sha256(), etc, by default None + """ + self._code_uri = code_uri + self._base_dir = base_dir + self._code_dir = str(pathlib.Path(self._base_dir, self._code_uri).resolve()) + self._runtime = runtime + self._manifest_path_override = manifest_path_override + self._hash_generator = hash_generator + self._calculated = False + self._hash = None + + def _calculate_dependency_hash(self) -> Optional[str]: + """Calculate the manifest file hash + + Returns + ------- + Optional[str] + Returns manifest hash. If manifest does not exist or not supported, None will be returned. + """ + if self._manifest_path_override: + manifest_file = self._manifest_path_override + else: + config = get_workflow_config(self._runtime, self._code_dir, self._base_dir) + manifest_file = config.manifest_name + + if not manifest_file: + return None + + manifest_path = pathlib.Path(self._code_dir, manifest_file).resolve() + if not manifest_path.is_file(): + return None + + return file_checksum(str(manifest_path), hash_generator=self._hash_generator) + + @property + def hash(self) -> Optional[str]: + """ + Returns + ------- + Optional[str] + Hash for dependencies in the manifest. + If the manifest does not exist or not supported, this value will be None. 
+ """ + if not self._calculated: + self._hash = self._calculate_dependency_hash() + self._calculated = True + return self._hash diff --git a/samcli/lib/build/exceptions.py b/samcli/lib/build/exceptions.py index 7b5fc265d4..321302c677 100644 --- a/samcli/lib/build/exceptions.py +++ b/samcli/lib/build/exceptions.py @@ -41,6 +41,11 @@ def __init__(self, msg: str) -> None: BuildError.__init__(self, "DockerBuildFailed", msg) +class MissingBuildMethodException(BuildError): + def __init__(self, msg: str) -> None: + BuildError.__init__(self, "MissingBuildMethodException", msg) + + class InvalidBuildGraphException(Exception): def __init__(self, msg: str) -> None: Exception.__init__(self, msg) diff --git a/samcli/lib/deploy/deployer.py b/samcli/lib/deploy/deployer.py index 8c1326698f..8b17cab7e5 100644 --- a/samcli/lib/deploy/deployer.py +++ b/samcli/lib/deploy/deployer.py @@ -21,7 +21,7 @@ import logging import time from datetime import datetime -from typing import Dict, List +from typing import Dict, List, Optional import botocore @@ -52,7 +52,7 @@ } ) -DESCRIBE_STACK_EVENTS_TABLE_HEADER_NAME = "CloudFormation events from changeset" +DESCRIBE_STACK_EVENTS_TABLE_HEADER_NAME = "CloudFormation events from stack operations" DESCRIBE_CHANGESET_FORMAT_STRING = "{Operation:<{0}} {LogicalResourceId:<{1}} {ResourceType:<{2}} {Replacement:<{3}}" DESCRIBE_CHANGESET_DEFAULT_ARGS = OrderedDict( @@ -172,6 +172,17 @@ def create_changeset( "Tags": tags, } + kwargs = self._process_kwargs(kwargs, s3_uploader, capabilities, role_arn, notification_arns) + return self._create_change_set(stack_name=stack_name, changeset_type=changeset_type, **kwargs) + + @staticmethod + def _process_kwargs( + kwargs: dict, + s3_uploader: Optional[S3Uploader], + capabilities: Optional[List[str]], + role_arn: Optional[str], + notification_arns: Optional[List[str]], + ) -> dict: # If an S3 uploader is available, use TemplateURL to deploy rather than # TemplateBody. This is required for large templates. if s3_uploader: @@ -192,7 +203,7 @@ def create_changeset( kwargs["RoleARN"] = role_arn if notification_arns is not None: kwargs["NotificationARNs"] = notification_arns - return self._create_change_set(stack_name=stack_name, changeset_type=changeset_type, **kwargs) + return kwargs def _create_change_set(self, stack_name, changeset_type, **kwargs): try: @@ -405,17 +416,17 @@ def describe_stack_events(self, stack_name, time_stamp_marker, **kwargs): def _check_stack_complete(status: str) -> bool: return "COMPLETE" in status and "CLEANUP" not in status - def wait_for_execute(self, stack_name, changeset_type): + def wait_for_execute(self, stack_name: str, stack_operation: str) -> None: """ - Wait for changeset to execute and return when execution completes. + Wait for stack operation to execute and return when execution completes. If the stack has "Outputs," they will be printed. 
Parameters ---------- stack_name : str The name of the stack - changeset_type : str - The type of the changeset, 'CREATE' or 'UPDATE' + stack_operation : str + The type of the stack operation, 'CREATE' or 'UPDATE' """ sys.stdout.write( "\n{} - Waiting for stack create/update " @@ -426,12 +437,12 @@ def wait_for_execute(self, stack_name, changeset_type): self.describe_stack_events(stack_name, self.get_last_event_time(stack_name)) # Pick the right waiter - if changeset_type == "CREATE": + if stack_operation == "CREATE": waiter = self._client.get_waiter("stack_create_complete") - elif changeset_type == "UPDATE": + elif stack_operation == "UPDATE": waiter = self._client.get_waiter("stack_update_complete") else: - raise RuntimeError("Invalid changeset type {0}".format(changeset_type)) + raise RuntimeError("Invalid stack operation type {0}".format(stack_operation)) # Poll every 30 seconds. Polling too frequently risks hitting rate limits # on CloudFormation's DescribeStacks API @@ -440,7 +451,7 @@ def wait_for_execute(self, stack_name, changeset_type): try: waiter.wait(StackName=stack_name, WaiterConfig=waiter_config) except botocore.exceptions.WaiterError as ex: - LOG.debug("Execute changeset waiter exception", exc_info=ex) + LOG.debug("Execute stack waiter exception", exc_info=ex) raise deploy_exceptions.DeployFailedError(stack_name=stack_name, msg=str(ex)) @@ -461,6 +472,99 @@ def create_and_wait_for_changeset( except botocore.exceptions.ClientError as ex: raise DeployFailedError(stack_name=stack_name, msg=str(ex)) from ex + def create_stack(self, **kwargs): + stack_name = kwargs.get("StackName") + try: + resp = self._client.create_stack(**kwargs) + return resp + except botocore.exceptions.ClientError as ex: + if "The bucket you are attempting to access must be addressed using the specified endpoint" in str(ex): + raise DeployBucketInDifferentRegionError(f"Failed to create/update stack {stack_name}") from ex + raise DeployFailedError(stack_name=stack_name, msg=str(ex)) from ex + + except Exception as ex: + LOG.debug("Unable to create stack", exc_info=ex) + raise DeployFailedError(stack_name=stack_name, msg=str(ex)) from ex + + def update_stack(self, **kwargs): + stack_name = kwargs.get("StackName") + try: + resp = self._client.update_stack(**kwargs) + return resp + except botocore.exceptions.ClientError as ex: + if "The bucket you are attempting to access must be addressed using the specified endpoint" in str(ex): + raise DeployBucketInDifferentRegionError(f"Failed to create/update stack {stack_name}") from ex + raise DeployFailedError(stack_name=stack_name, msg=str(ex)) from ex + + except Exception as ex: + LOG.debug("Unable to update stack", exc_info=ex) + raise DeployFailedError(stack_name=stack_name, msg=str(ex)) from ex + + def sync( + self, + stack_name: str, + cfn_template: str, + parameter_values: List[Dict], + capabilities: Optional[List[str]], + role_arn: Optional[str], + notification_arns: Optional[List[str]], + s3_uploader: Optional[S3Uploader], + tags: Optional[Dict], + ): + """ + Call the sync command to directly update stack or create stack + + Parameters + ---------- + :param stack_name: The name of the stack + :param cfn_template: CloudFormation template string + :param parameter_values: Template parameters object + :param capabilities: Array of capabilities passed to CloudFormation + :param role_arn: the Arn of the role to create changeset + :param notification_arns: Arns for sending notifications + :param s3_uploader: S3Uploader object to upload files to S3 buckets + :param 
tags: Array of tags passed to CloudFormation + :return: + """ + exists = self.has_stack(stack_name) + + if not exists: + # When creating a new stack, UsePreviousValue=True is invalid. + # For such parameters, users should either override with new value, + # or set a Default value in template to successfully create a stack. + parameter_values = [x for x in parameter_values if not x.get("UsePreviousValue", False)] + else: + summary = self._client.get_template_summary(StackName=stack_name) + existing_parameters = [parameter["ParameterKey"] for parameter in summary["Parameters"]] + parameter_values = [ + x + for x in parameter_values + if not (x.get("UsePreviousValue", False) and x["ParameterKey"] not in existing_parameters) + ] + + kwargs = { + "StackName": stack_name, + "TemplateBody": cfn_template, + "Parameters": parameter_values, + "Tags": tags, + } + + kwargs = self._process_kwargs(kwargs, s3_uploader, capabilities, role_arn, notification_arns) + + try: + if exists: + result = self.update_stack(**kwargs) + self.wait_for_execute(stack_name, "UPDATE") + LOG.info("\nStack update succeeded. Sync infra completed.\n") + else: + result = self.create_stack(**kwargs) + self.wait_for_execute(stack_name, "CREATE") + LOG.info("\nStack creation succeeded. Sync infra completed.\n") + + return result + except botocore.exceptions.ClientError as ex: + raise DeployFailedError(stack_name=stack_name, msg=str(ex)) from ex + @staticmethod @pprint_column_names( format_string=OUTPUTS_FORMAT_STRING, format_kwargs=OUTPUTS_DEFAULTS_ARGS, table_header=OUTPUTS_TABLE_HEADER_NAME diff --git a/samcli/lib/observability/cw_logs/cw_log_formatters.py b/samcli/lib/observability/cw_logs/cw_log_formatters.py index f0d35a18a6..63b2ffd983 100644 --- a/samcli/lib/observability/cw_logs/cw_log_formatters.py +++ b/samcli/lib/observability/cw_logs/cw_log_formatters.py @@ -4,6 +4,7 @@ import json import logging from json import JSONDecodeError +from typing import Any from samcli.lib.observability.cw_logs.cw_log_event import CWLogEvent from samcli.lib.observability.observability_info_puller import ObservabilityEventMapper @@ -92,3 +93,31 @@ def map(self, event: CWLogEvent) -> CWLogEvent: log_stream_name = self._colored.cyan(event.log_stream_name) event.message = f"{log_stream_name} {timestamp} {event.message}" return event + + +class CWAddNewLineIfItDoesntExist(ObservabilityEventMapper): + """ + Mapper implementation which will add new lines at the end of events if it is not already there + """ + + def map(self, event: Any) -> Any: + # if it is a CWLogEvent, append new line at the end of event.message + if isinstance(event, CWLogEvent) and not event.message.endswith("\n"): + event.message = f"{event.message}\n" + return event + # if event is a string, then append new line at the end of the string + if isinstance(event, str) and not event.endswith("\n"): + return f"{event}\n" + # no-action for unknown events + return event + + +class CWLogEventJSONMapper(ObservabilityEventMapper[CWLogEvent]): + """ + Converts given CWLogEvent into JSON string + """ + + # pylint: disable=no-self-use + def map(self, event: CWLogEvent) -> CWLogEvent: + event.message = json.dumps(event.event) + return event diff --git a/samcli/lib/observability/cw_logs/cw_log_group_provider.py b/samcli/lib/observability/cw_logs/cw_log_group_provider.py index 90893e5238..3d365ad1e7 100644 --- a/samcli/lib/observability/cw_logs/cw_log_group_provider.py +++ b/samcli/lib/observability/cw_logs/cw_log_group_provider.py @@ -1,6 +1,18 @@ """ Discover & provide the log group name 
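For example (an illustrative sketch of the naming conventions implemented below):

    LogGroupProvider.for_lambda_function("my-function")
    # -> "/aws/lambda/my-function"
    LogGroupProvider.for_apigw_rest_api("abc123")
    # -> "API-Gateway-Execution-Logs_abc123/Prod"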
""" +import logging +from typing import Optional + +from samcli.lib.utils.resources import ( + AWS_LAMBDA_FUNCTION, + AWS_APIGATEWAY_RESTAPI, + AWS_APIGATEWAY_V2_API, + AWS_STEPFUNCTIONS_STATEMACHINE, +) +from samcli.lib.utils.boto_utils import BotoProviderType + +LOG = logging.getLogger(__name__) class LogGroupProvider: @@ -9,7 +21,21 @@ class LogGroupProvider: """ @staticmethod - def for_lambda_function(function_name): + def for_resource(boto_client_provider: BotoProviderType, resource_type: str, name: str) -> Optional[str]: + log_group = None + if resource_type == AWS_LAMBDA_FUNCTION: + log_group = LogGroupProvider.for_lambda_function(name) + elif resource_type == AWS_APIGATEWAY_RESTAPI: + log_group = LogGroupProvider.for_apigw_rest_api(name) + elif resource_type == AWS_APIGATEWAY_V2_API: + log_group = LogGroupProvider.for_apigwv2_http_api(boto_client_provider, name) + elif resource_type == AWS_STEPFUNCTIONS_STATEMACHINE: + log_group = LogGroupProvider.for_step_functions(boto_client_provider, name) + + return log_group + + @staticmethod + def for_lambda_function(function_name: str) -> str: """ Returns the CloudWatch Log Group Name created by default for the AWS Lambda function with given name @@ -24,3 +50,98 @@ def for_lambda_function(function_name): Default Log Group name used by this function """ return "/aws/lambda/{}".format(function_name) + + @staticmethod + def for_apigw_rest_api(rest_api_id: str, stage: str = "Prod") -> str: + """ + Returns the CloudWatch Log Group Name created by default for the AWS Api gateway rest api with given id + + Parameters + ---------- + rest_api_id : str + Id of the rest api + stage: str + Stage of the rest api (the default value is "Prod") + + Returns + ------- + str + Default Log Group name used by this rest api + """ + + # TODO: A rest api may have multiple stage, here just log out the prod stage and can be extended to log out + # all stages or a specific stage if needed. + return "API-Gateway-Execution-Logs_{}/{}".format(rest_api_id, stage) + + @staticmethod + def for_apigwv2_http_api( + boto_client_provider: BotoProviderType, http_api_id: str, stage: str = "$default" + ) -> Optional[str]: + """ + Returns the CloudWatch Log Group Name created by default for the AWS Api gatewayv2 http api with given id + + Parameters + ---------- + boto_client_provider: BotoProviderType + Boto client provider which contains region and other configurations + http_api_id : str + Id of the http api + stage: str + Stage of the rest api (the default value is "$default") + + Returns + ------- + str + Default Log Group name used by this http api + """ + apigw2_client = boto_client_provider("apigatewayv2") + + # TODO: A http api may have multiple stage, here just log out the default stage and can be extended to log out + # all stages or a specific stage if needed. + stage_info = apigw2_client.get_stage(ApiId=http_api_id, StageName=stage) + log_setting = stage_info.get("AccessLogSettings", None) + if not log_setting: + LOG.warning("Access logging is disabled for http api id (%s)", http_api_id) + return None + log_group_name = str(log_setting.get("DestinationArn").split(":")[-1]) + return log_group_name + + @staticmethod + def for_step_functions( + boto_client_provider: BotoProviderType, + step_function_name: str, + ) -> Optional[str]: + """ + Calls describe_state_machine API to get details of the State Machine, + then extracts logging information to find the configured CW log group. 
+ If nothing is configured it will return None + + Parameters + ---------- + boto_client_provider : BotoProviderType + Boto client provider which contains region and other configurations + step_function_name : str + Name of the step functions resource + + Returns + ------- + Optional[str] + CW log group name if logging is configured, None otherwise + """ + sfn_client = boto_client_provider("stepfunctions") + + state_machine_info = sfn_client.describe_state_machine(stateMachineArn=step_function_name) + LOG.debug("State machine info: %s", state_machine_info) + + logging_destinations = state_machine_info.get("loggingConfiguration", {}).get("destinations", []) + LOG.debug("State Machine logging destinations: %s", logging_destinations) + + # users may configure multiple log groups to send state machine logs, find one and return it + for logging_destination in logging_destinations: + log_group_arn = logging_destination.get("cloudWatchLogsLogGroup", {}).get("logGroupArn") + LOG.debug("Log group ARN: %s", log_group_arn) + if log_group_arn: + log_group_arn_parts = log_group_arn.split(":") + log_group_name = log_group_arn_parts[6] + return str(log_group_name) + + return None diff --git a/samcli/lib/observability/cw_logs/cw_log_puller.py b/samcli/lib/observability/cw_logs/cw_log_puller.py index e7d8b7fb10..7eb7bd0d5c 100644 --- a/samcli/lib/observability/cw_logs/cw_log_puller.py +++ b/samcli/lib/observability/cw_logs/cw_log_puller.py @@ -4,7 +4,9 @@ import logging import time from datetime import datetime -from typing import Optional, Any +from typing import Optional, Any, List + +from botocore.exceptions import ClientError from samcli.lib.observability.cw_logs.cw_log_event import CWLogEvent from samcli.lib.observability.observability_info_puller import ObservabilityPuller, ObservabilityEventConsumer @@ -30,7 +32,7 @@ def __init__( """ Parameters ---------- - logs_client: Any + logs_client: CloudWatchLogsClient boto3 logs client instance consumer : ObservabilityEventConsumer Consumer instance that will process pulled events @@ -51,6 +53,7 @@ def __init__( self._poll_interval = poll_interval self.latest_event_time = 0 self.had_data = False + self._invalid_log_group = False def tail(self, start_time: Optional[datetime] = None, filter_pattern: Optional[str] = None): if start_time: @@ -61,7 +64,26 @@ def tail(self, start_time: Optional[datetime] = None, filter_pattern: Optional[s LOG.debug("Tailing logs from %s starting at %s", self.cw_log_group, str(self.latest_event_time)) counter -= 1 - self.load_time_period(to_datetime(self.latest_event_time), filter_pattern=filter_pattern) + try: + self.load_time_period(to_datetime(self.latest_event_time), filter_pattern=filter_pattern) + except ClientError as err: + error_code = err.response.get("Error", {}).get("Code") + if error_code == "ThrottlingException": + # if throttled, back off: a 1 second interval is bumped to 2 seconds, + # then the interval is squared on each subsequent throttle + if self._poll_interval == 1: + self._poll_interval += 1 + else: + self._poll_interval **= 2 + LOG.warning( + "Throttled by CloudWatch Logs API, consider pulling logs for certain resources. " + "Increasing the poll interval time for resource %s to %s seconds", + self.cw_log_group, + self._poll_interval, + ) + else: + # if the error is other than throttling, re-raise it + LOG.error("Failed while fetching new log events", exc_info=err) + raise err # This poll fetched logs.
Reset the retry counter and set the timestamp for next poll if self.had_data: @@ -92,12 +114,23 @@ def load_time_period( while True: LOG.debug("Fetching logs from CloudWatch with parameters %s", kwargs) - result = self.logs_client.filter_log_events(**kwargs) + try: + result = self.logs_client.filter_log_events(**kwargs) + self._invalid_log_group = False + except self.logs_client.exceptions.ResourceNotFoundException: + if not self._invalid_log_group: + LOG.warning( + "The specified log group %s does not exist. " + "Please make sure logging is enabled and the log group is created", + self.cw_log_group, + ) + self._invalid_log_group = True + break - # Several events will be returned. Yield one at a time + # Several events will be returned. Consume one at a time for event in result.get("events", []): self.had_data = True - cw_event = CWLogEvent(self.cw_log_group, event, self.resource_name) + cw_event = CWLogEvent(self.cw_log_group, dict(event), self.resource_name) if cw_event.timestamp > self.latest_event_time: self.latest_event_time = cw_event.timestamp @@ -109,3 +142,6 @@ def load_time_period( kwargs["nextToken"] = next_token if not next_token: break + + def load_events(self, event_ids: List[Any]): + LOG.debug("Loading specific events is not supported via CloudWatch Log Group") diff --git a/samcli/lib/observability/observability_info_puller.py b/samcli/lib/observability/observability_info_puller.py index b6d6f2b906..48028aaba3 100644 --- a/samcli/lib/observability/observability_info_puller.py +++ b/samcli/lib/observability/observability_info_puller.py @@ -4,7 +4,9 @@ import logging from abc import ABC, abstractmethod from datetime import datetime -from typing import List, Optional, Generic, TypeVar, Any +from typing import List, Optional, Generic, TypeVar, Any, Sequence + +from samcli.lib.utils.async_utils import AsyncContext LOG = logging.getLogger(__name__) @@ -72,6 +74,17 @@ def load_time_period( Optional parameter to filter events with given string """ + @abstractmethod + def load_events(self, event_ids: List[Any]): + """ + This method will load specific events which are given by the event_ids parameter + + Parameters + ---------- + event_ids : List[str] + List of event ids that will be pulled + """ + # pylint: disable=fixme # fixme add ABC parent class back once we bump the pylint to a version 2.8.2 or higher @@ -141,3 +154,59 @@ def consume(self, event: ObservabilityEvent): event = mapper.map(event) LOG.debug("Calling consumer (%s) for event (%s)", self._consumer, event) self._consumer.consume(event) + + +class ObservabilityCombinedPuller(ObservabilityPuller): + """ + A decorator class which will contain multiple ObservabilityPuller instances and pull information from each of them + """ + + def __init__(self, pullers: Sequence[ObservabilityPuller]): + """ + Parameters + ---------- + pullers : List[ObservabilityPuller] + List of pullers which will be managed by this class + """ + self._pullers = pullers + + def tail(self, start_time: Optional[datetime] = None, filter_pattern: Optional[str] = None): + """ + Implementation of ObservabilityPuller.tail method with AsyncContext.
+ It will create tasks by calling the tail methods of all given pullers, and execute them asynchronously + """ + async_context = AsyncContext() + for puller in self._pullers: + LOG.debug("Adding task 'tail' for puller (%s)", puller) + async_context.add_async_task(puller.tail, start_time, filter_pattern) + LOG.debug("Running all 'tail' tasks in parallel") + async_context.run_async() + + def load_time_period( + self, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + filter_pattern: Optional[str] = None, + ): + """ + Implementation of ObservabilityPuller.load_time_period method with AsyncContext. + It will create tasks by calling the load_time_period methods of all given pullers, and execute them asynchronously + """ + async_context = AsyncContext() + for puller in self._pullers: + LOG.debug("Adding task 'load_time_period' for puller (%s)", puller) + async_context.add_async_task(puller.load_time_period, start_time, end_time, filter_pattern) + LOG.debug("Running all 'load_time_period' tasks in parallel") + async_context.run_async() + + def load_events(self, event_ids: List[Any]): + """ + Implementation of ObservabilityPuller.load_events method with AsyncContext. + It will create tasks by calling the load_events methods of all given pullers, and execute them asynchronously + """ + async_context = AsyncContext() + for puller in self._pullers: + LOG.debug("Adding task 'load_events' for puller (%s)", puller) + async_context.add_async_task(puller.load_events, event_ids) + LOG.debug("Running all 'load_events' tasks in parallel") + async_context.run_async() diff --git a/samcli/lib/observability/xray_traces/__init__.py b/samcli/lib/observability/xray_traces/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/samcli/lib/observability/xray_traces/xray_event_mappers.py b/samcli/lib/observability/xray_traces/xray_event_mappers.py new file mode 100644 index 0000000000..4f18ff3a31 --- /dev/null +++ b/samcli/lib/observability/xray_traces/xray_event_mappers.py @@ -0,0 +1,165 @@ +""" +Contains mapper implementations of XRay events +""" +import json +from copy import deepcopy +from typing import List + +from samcli.lib.observability.observability_info_puller import ObservabilityEventMapper +from samcli.lib.observability.xray_traces.xray_events import ( + XRayTraceEvent, + XRayTraceSegment, + XRayServiceGraphEvent, + XRayGraphServiceInfo, +) +from samcli.lib.utils.time import to_utc, utc_to_timestamp, timestamp_to_iso + + +class XRayTraceConsoleMapper(ObservabilityEventMapper[XRayTraceEvent]): + """ + Maps the given XRayTraceEvent.message field into a printable format for the console consumer + """ + + def map(self, event: XRayTraceEvent) -> XRayTraceEvent: + formatted_segments = self.format_segments(event.segments) + iso_formatted_timestamp = timestamp_to_iso(event.timestamp) + mapped_message = ( + f"\nXRay Event at ({iso_formatted_timestamp}) with id ({event.id}) and duration ({event.duration:.3f}s)" + f"{formatted_segments}" + ) + event.message = mapped_message + + return event + + def format_segments(self, segments: List[XRayTraceSegment], level: int = 0) -> str: + """ + Formats the given segment information for the console. + + Parameters + ---------- + segments : List[XRayTraceSegment] + List of segments which will be printed into console + level : int + Optional level value which will be used to make the indentation of each segment.
Default value is 0 + """ + formatted_str = "" + for segment in segments: + formatted_str += f"\n{' ' * level} - {segment.get_duration():.3f}s - {segment.name}" + if segment.http_status: + formatted_str += f" [HTTP: {segment.http_status}]" + formatted_str += self.format_segments(segment.sub_segments, (level + 1)) + + return formatted_str + + +class XRayTraceJSONMapper(ObservabilityEventMapper[XRayTraceEvent]): + """ + The original response from the xray client contains JSON in an escaped string. This mapper re-constructs the JSON + object and converts it into a JSON string that can be printed to the console. + """ + + # pylint: disable=R0201 + def map(self, event: XRayTraceEvent) -> XRayTraceEvent: + mapped_event = deepcopy(event.event) + segments = [segment.document for segment in event.segments] + mapped_event["Segments"] = segments + event.event = mapped_event + event.message = json.dumps(mapped_event) + return event + + +class XRayServiceGraphConsoleMapper(ObservabilityEventMapper[XRayServiceGraphEvent]): + """ + Maps the given XRayServiceGraphEvent.message field into a printable format for the console consumer + """ + + def map(self, event: XRayServiceGraphEvent) -> XRayServiceGraphEvent: + formatted_services = self.format_services(event.services) + mapped_message = "\nNew XRay Service Graph" + mapped_message += f"\n  Start time: {event.start_time}" + mapped_message += f"\n  End time: {event.end_time}" + mapped_message += formatted_services + event.message = mapped_message + + return event + + def format_services(self, services: List[XRayGraphServiceInfo]) -> str: + """ + Formats the given services' information for the console. + + Parameters + ---------- + services : List[XRayGraphServiceInfo] + List of services which will be printed into console + """ + formatted_str = "" + for service in services: + formatted_str += f"\n  Reference Id: {service.id}" + formatted_str += f"{ ' - (Root)' if service.is_root else ' -'}" + formatted_str += f" {service.type} - {service.name}" + formatted_str += f" - Edges: {self.format_edges(service)}" + formatted_str += self.format_summary_statistics(service, 1) + + return formatted_str + + @staticmethod + def format_edges(service: XRayGraphServiceInfo) -> str: + edge_ids = service.edge_ids + return str(edge_ids) + + @staticmethod + def format_summary_statistics(service: XRayGraphServiceInfo, level) -> str: + """ + Formats the given summary statistics information for the console. + + Parameters + ---------- + service: XRayGraphServiceInfo + Summary statistics of the service which will be printed into console + level : int + Level value which will be used to make the indentation of the statistics + """ + formatted_str = f"\n{' ' * level} Summary_statistics:" + formatted_str += f"\n{' ' * (level + 1)} - total requests: {service.total_count}" + formatted_str += f"\n{' ' * (level + 1)} - ok count(2XX): {service.ok_count}" + formatted_str += f"\n{' ' * (level + 1)} - error count(4XX): {service.error_count}" + formatted_str += f"\n{' ' * (level + 1)} - fault count(5XX): {service.fault_count}" + formatted_str += f"\n{' ' * (level + 1)} - total response time: {service.response_time}" + return formatted_str + + +class XRayServiceGraphJSONMapper(ObservabilityEventMapper[XRayServiceGraphEvent]): + """ + The original response from the xray client contains datetime objects. This mapper converts the datetime objects + to ISO strings and converts the final JSON object into a string.
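A minimal sketch of the transformation (values are illustrative):

    {"StartTime": datetime(2021, 1, 1, 10, 0), ...}
    # becomes, after map():
    '{"StartTime": "2021-01-01T10:00:00", ...}'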
+ """ + + def map(self, event: XRayServiceGraphEvent) -> XRayServiceGraphEvent: + mapped_event = deepcopy(event.event) + + self._convert_start_and_end_time_to_iso(mapped_event) + services = mapped_event.get("Services", []) + for service in services: + self._convert_start_and_end_time_to_iso(service) + edges = service.get("Edges", []) + for edge in edges: + self._convert_start_and_end_time_to_iso(edge) + + event.event = mapped_event + event.message = json.dumps(mapped_event) + return event + + def _convert_start_and_end_time_to_iso(self, event): + self.convert_event_datetime_to_iso(event, "StartTime") + self.convert_event_datetime_to_iso(event, "EndTime") + + def convert_event_datetime_to_iso(self, event, datetime_key): + event_datetime = event.get(datetime_key, None) + if event_datetime: + event[datetime_key] = self.convert_local_datetime_to_iso(event_datetime) + + @staticmethod + def convert_local_datetime_to_iso(local_datetime): + utc_datetime = to_utc(local_datetime) + time_stamp = utc_to_timestamp(utc_datetime) + return timestamp_to_iso(time_stamp) diff --git a/samcli/lib/observability/xray_traces/xray_event_puller.py b/samcli/lib/observability/xray_traces/xray_event_puller.py new file mode 100644 index 0000000000..2d9342817e --- /dev/null +++ b/samcli/lib/observability/xray_traces/xray_event_puller.py @@ -0,0 +1,151 @@ +""" +This file contains puller implementations for XRay +""" +import logging +import time +from datetime import datetime +from itertools import zip_longest +from typing import Optional, Any, List, Set, Dict + +from botocore.exceptions import ClientError + +from samcli.lib.observability.observability_info_puller import ObservabilityPuller, ObservabilityEventConsumer +from samcli.lib.observability.xray_traces.xray_events import XRayTraceEvent +from samcli.lib.utils.time import to_timestamp, to_datetime + +LOG = logging.getLogger(__name__) + + +class AbstractXRayPuller(ObservabilityPuller): + def __init__( + self, + max_retries: int = 1000, + poll_interval: int = 1, + ): + """ + Parameters + ---------- + max_retries : int + Optional maximum number of retries which can be used to pull information. Default value is 1000 + poll_interval : int + Optional interval value that will be used to wait between calls in tail operation. 
Default value is 1 + """ + self._max_retries = max_retries + self._poll_interval = poll_interval + self._had_data = False + self.latest_event_time = 0 + + def tail(self, start_time: Optional[datetime] = None, filter_pattern: Optional[str] = None): + if start_time: + self.latest_event_time = to_timestamp(start_time) + + counter = self._max_retries + while counter > 0: + LOG.debug("Tailing XRay traces starting at %s", self.latest_event_time) + + counter -= 1 + try: + self.load_time_period(to_datetime(self.latest_event_time), datetime.utcnow()) + except ClientError as err: + error_code = err.response.get("Error", {}).get("Code") + if error_code == "ThrottlingException": + # if throttled, back off: a 1 second interval is bumped to 2 seconds, + # then the interval is squared on each subsequent throttle + if self._poll_interval == 1: + self._poll_interval += 1 + else: + self._poll_interval **= 2 + LOG.warning( + "Throttled by XRay API, increasing the poll interval time to %s seconds", + self._poll_interval, + ) + else: + # if the exception is not throttling, re-raise it + LOG.error("Failed while fetching new AWS X-Ray events", exc_info=err) + raise err + + if self._had_data: + counter = self._max_retries + self.latest_event_time += 1 + self._had_data = False + + time.sleep(self._poll_interval) + + +class XRayTracePuller(AbstractXRayPuller): + """ + ObservabilityPuller implementation which pulls XRay trace information by summarizing XRay traces first + and then getting them as a batch later. + """ + + def __init__( + self, xray_client: Any, consumer: ObservabilityEventConsumer, max_retries: int = 1000, poll_interval: int = 1 + ): + """ + Parameters + ---------- + xray_client : boto3.client + XRay boto3 client instance + consumer : ObservabilityEventConsumer + Consumer instance which will process pulled events + max_retries : int + Optional maximum number of retries which can be used to pull information. Default value is 1000 + poll_interval : int + Optional interval value that will be used to wait between calls in tail operation.
Default value is 1 + """ + super().__init__(max_retries, poll_interval) + self.xray_client = xray_client + self.consumer = consumer + self._previous_trace_ids: Set[str] = set() + + def load_time_period( + self, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + filter_pattern: Optional[str] = None, + ): + kwargs = {"TimeRangeType": "TraceId", "StartTime": start_time, "EndTime": end_time} + + # first, collect all trace ids in given period + trace_ids = [] + LOG.debug("Fetching XRay trace summaries %s", kwargs) + result_paginator = self.xray_client.get_paginator("get_trace_summaries") + result_iterator = result_paginator.paginate(**kwargs) + for result in result_iterator: + trace_summaries = result.get("TraceSummaries", []) + for trace_summary in trace_summaries: + trace_id = trace_summary.get("Id", None) + if trace_id not in self._previous_trace_ids: + trace_ids.append(trace_id) + self._previous_trace_ids.add(trace_id) + + # now load collected events + self.load_events(trace_ids) + + def load_events(self, event_ids: List[str]): + if not event_ids: + LOG.debug("Nothing to fetch, empty event_id list given (%s)", event_ids) + return + + # xray client only accepts 5 items at max, so create batches of 5 element arrays + event_batches = zip_longest(*([iter(event_ids)] * 5)) + + for event_batch in event_batches: + kwargs: Dict[str, Any] = {"TraceIds": list(filter(None, event_batch))} + result_paginator = self.xray_client.get_paginator("batch_get_traces") + result_iterator = result_paginator.paginate(**kwargs) + for result in result_iterator: + traces = result.get("Traces", []) + + if not traces: + LOG.debug("No event found with given trace ids %s", str(event_ids)) + + for trace in traces: + self._had_data = True + xray_trace_event = XRayTraceEvent(trace) + + # update latest fetched event + latest_event_time = xray_trace_event.get_latest_event_time() + if latest_event_time > self.latest_event_time: + self.latest_event_time = latest_event_time + + self.consumer.consume(xray_trace_event) diff --git a/samcli/lib/observability/xray_traces/xray_events.py b/samcli/lib/observability/xray_traces/xray_events.py new file mode 100644 index 0000000000..281c073c95 --- /dev/null +++ b/samcli/lib/observability/xray_traces/xray_events.py @@ -0,0 +1,160 @@ +""" +Keeps XRay event definitions +""" +import json +import operator +from typing import List + +from samcli.lib.observability.observability_info_puller import ObservabilityEvent +from samcli.lib.utils.hash import str_checksum + + +start_time_getter = operator.attrgetter("start_time") + + +class XRayTraceEvent(ObservabilityEvent[dict]): + """ + Represents a result of each XRay trace event, which is returned by boto3 client by calling 'batch_get_traces' + See XRayTracePuller + """ + + def __init__(self, event: dict): + super().__init__(event, 0) + self.id = event.get("Id", "") + self.duration = event.get("Duration", 0.0) + self.message = json.dumps(event) + self.segments: List[XRayTraceSegment] = [] + + self._construct_segments(event) + if self.segments: + self.timestamp = self.segments[0].start_time + + def _construct_segments(self, event_dict): + """ + Each event is represented by segment, and it is like a Tree model (each segment also have subsegments). 
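A sketch of the raw shape being parsed (from the X-Ray batch_get_traces response;
each Document value is itself a JSON string, values are illustrative):

    {"Id": "1-...", "Duration": 0.5, "Segments": [
        {"Id": "...", "Document": "{\"id\": \"...\", \"start_time\": ..., \"subsegments\": [...]}"}]}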
+ """ + raw_segments = event_dict.get("Segments", []) + for raw_segment in raw_segments: + segment_document = raw_segment.get("Document", "{}") + self.segments.append(XRayTraceSegment(json.loads(segment_document))) + self.segments.sort(key=start_time_getter) + + def get_latest_event_time(self): + """ + Returns the latest event time for this specific XRayTraceEvent by calling get_latest_event_time for each segment + """ + latest_event_time = 0 + for segment in self.segments: + segment_latest_event_time = segment.get_latest_event_time() + if segment_latest_event_time > latest_event_time: + latest_event_time = segment_latest_event_time + + return latest_event_time + + +class XRayTraceSegment: + """ + Represents each segment information for a XRayTraceEvent + """ + + def __init__(self, document: dict): + self.id = document.get("Id", "") + self.document = document + self.name = document.get("name", "") + self.start_time = document.get("start_time", 0) + self.end_time = document.get("end_time", 0) + self.http_status = document.get("http", {}).get("response", {}).get("status", None) + self.sub_segments: List[XRayTraceSegment] = [] + + sub_segments = document.get("subsegments", []) + for sub_segment in sub_segments: + self.sub_segments.append(XRayTraceSegment(sub_segment)) + self.sub_segments.sort(key=start_time_getter) + + def get_duration(self): + return self.end_time - self.start_time + + def get_latest_event_time(self): + """ + Gets the latest event time by comparing all timestamps (end_time) from current segment and all sub-segments + """ + latest_event_time = self.end_time + for sub_segment in self.sub_segments: + sub_segment_latest_time = sub_segment.get_latest_event_time() + if sub_segment_latest_time > latest_event_time: + latest_event_time = sub_segment_latest_time + + return latest_event_time + + +class XRayServiceGraphEvent(ObservabilityEvent[dict]): + """ + Represents a result of each XRay service graph event, which is returned by boto3 client by calling + 'get_service_graph' See XRayServiceGraphPuller + """ + + def __init__(self, event: dict): + self.services: List[XRayGraphServiceInfo] = [] + self.message = str(event) + self._construct_service(event) + self.start_time = event.get("StartTime", None) + self.end_time = event.get("EndTime", None) + super().__init__(event, 0) + + def _construct_service(self, event_dict): + services = event_dict.get("Services", []) + for service in services: + self.services.append(XRayGraphServiceInfo(service)) + + def get_hash(self): + """ + get the hash of the containing services + """ + services = self.event.get("Services", []) + return str_checksum(str(services)) + + +class XRayGraphServiceInfo: + """ + Represents each services information for a XRayServiceGraphEvent + """ + + def __init__(self, service: dict): + self.id = service.get("ReferenceId", "") + self.document = service + self.name = service.get("Name", "") + self.is_root = service.get("Root", False) + self.type = service.get("Type") + self.edge_ids: List[int] = [] + self.ok_count = 0 + self.error_count = 0 + self.fault_count = 0 + self.total_count = 0 + self.response_time = 0 + self._construct_edge_ids(service.get("Edges", [])) + self._set_summary_statistics(service.get("SummaryStatistics", None)) + + def _construct_edge_ids(self, edges): + """ + covert the edges information to a list of edge reference ids + """ + edge_ids: List[int] = [] + for edge in edges: + edge_ids.append(edge.get("ReferenceId", -1)) + self.edge_ids = edge_ids + + def _set_summary_statistics(self, summary_statistics): + 
""" + get some useful information from summary statistics + """ + if not summary_statistics: + return + self.ok_count = summary_statistics.get("OkCount", 0) + error_statistics = summary_statistics.get("ErrorStatistics", None) + if error_statistics: + self.error_count = error_statistics.get("TotalCount", 0) + fault_statistics = summary_statistics.get("FaultStatistics", None) + if fault_statistics: + self.fault_count = fault_statistics.get("TotalCount", 0) + self.total_count = summary_statistics.get("TotalCount", 0) + self.response_time = summary_statistics.get("TotalResponseTime", 0) diff --git a/samcli/lib/observability/xray_traces/xray_service_graph_event_puller.py b/samcli/lib/observability/xray_traces/xray_service_graph_event_puller.py new file mode 100644 index 0000000000..4f2b0aa1fc --- /dev/null +++ b/samcli/lib/observability/xray_traces/xray_service_graph_event_puller.py @@ -0,0 +1,72 @@ +""" +This file contains puller implementations for XRay +""" +import logging +from datetime import datetime +from typing import Optional, Any, List, Set + +from samcli.lib.observability.observability_info_puller import ObservabilityEventConsumer +from samcli.lib.observability.xray_traces.xray_event_puller import AbstractXRayPuller +from samcli.lib.observability.xray_traces.xray_events import XRayServiceGraphEvent +from samcli.lib.utils.time import to_utc, utc_to_timestamp + +LOG = logging.getLogger(__name__) + + +class XRayServiceGraphPuller(AbstractXRayPuller): + """ + ObservabilityPuller implementation which pulls XRay Service Graph + """ + + def __init__( + self, xray_client: Any, consumer: ObservabilityEventConsumer, max_retries: int = 1000, poll_interval: int = 1 + ): + """ + Parameters + ---------- + xray_client : boto3.client + XRay boto3 client instance + consumer : ObservabilityEventConsumer + Consumer instance which will process pulled events + max_retries : int + Optional maximum number of retries which can be used to pull information. Default value is 1000 + poll_interval : int + Optional interval value that will be used to wait between calls in tail operation. 
Default value is 1 """ + super().__init__(max_retries, poll_interval) + self.xray_client = xray_client + self.consumer = consumer + self._previous_xray_service_graphs: Set[str] = set() + + def load_time_period( + self, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + filter_pattern: Optional[str] = None, + ): + # pull the xray service graph for the given time period + kwargs = {"StartTime": start_time, "EndTime": end_time} + result_paginator = self.xray_client.get_paginator("get_service_graph") + result_iterator = result_paginator.paginate(**kwargs) + for result in result_iterator: + services = result.get("Services", []) + + if not services: + LOG.debug("No service graph found") + else: + # update latest fetched event + event_end_time = result.get("EndTime", None) + if event_end_time: + utc_end_time = to_utc(event_end_time) + latest_event_time = utc_to_timestamp(utc_end_time) + if latest_event_time > self.latest_event_time: + self.latest_event_time = latest_event_time + 1 + + self._had_data = True + xray_service_graph_event = XRayServiceGraphEvent(result) + if xray_service_graph_event.get_hash() not in self._previous_xray_service_graphs: + self.consumer.consume(xray_service_graph_event) + self._previous_xray_service_graphs.add(xray_service_graph_event.get_hash()) + + def load_events(self, event_ids: List[str]): + LOG.debug("Loading specific service graph events is not supported via XRay Service Graph") diff --git a/samcli/lib/package/artifact_exporter.py b/samcli/lib/package/artifact_exporter.py index bfa4ad9bf8..7b464b69ce 100644 --- a/samcli/lib/package/artifact_exporter.py +++ b/samcli/lib/package/artifact_exporter.py @@ -20,7 +20,7 @@ from botocore.utils import set_value_from_jmespath -from samcli.commands._utils.resources import ( +from samcli.lib.utils.resources import ( AWS_SERVERLESS_FUNCTION, AWS_CLOUDFORMATION_STACK, RESOURCES_WITH_LOCAL_PATHS, diff --git a/samcli/lib/package/packageable_resources.py b/samcli/lib/package/packageable_resources.py index 140aeaedf2..ec7e0b7c0b 100644 --- a/samcli/lib/package/packageable_resources.py +++ b/samcli/lib/package/packageable_resources.py @@ -25,7 +25,7 @@ is_ecr_url, ) -from samcli.commands._utils.resources import ( +from samcli.lib.utils.resources import ( AWS_SERVERLESSREPO_APPLICATION, AWS_SERVERLESS_FUNCTION, AWS_SERVERLESS_API, diff --git a/samcli/lib/providers/cfn_api_provider.py b/samcli/lib/providers/cfn_api_provider.py index 1224c79054..006461b06c 100644 --- a/samcli/lib/providers/cfn_api_provider.py +++ b/samcli/lib/providers/cfn_api_provider.py @@ -9,29 +9,32 @@ from samcli.lib.providers.cfn_base_api_provider import CfnBaseApiProvider from samcli.lib.providers.api_collector import ApiCollector +from samcli.lib.utils.resources import ( + AWS_APIGATEWAY_METHOD, + AWS_APIGATEWAY_RESOURCE, + AWS_APIGATEWAY_RESTAPI, + AWS_APIGATEWAY_STAGE, + AWS_APIGATEWAY_V2_API, + AWS_APIGATEWAY_V2_INTEGRATION, + AWS_APIGATEWAY_V2_ROUTE, + AWS_APIGATEWAY_V2_STAGE, +) + LOG = logging.getLogger(__name__) class CfnApiProvider(CfnBaseApiProvider): - APIGATEWAY_RESTAPI = "AWS::ApiGateway::RestApi" - APIGATEWAY_STAGE = "AWS::ApiGateway::Stage" - APIGATEWAY_RESOURCE = "AWS::ApiGateway::Resource" - APIGATEWAY_METHOD = "AWS::ApiGateway::Method" - APIGATEWAY_V2_API = "AWS::ApiGatewayV2::Api" - APIGATEWAY_V2_INTEGRATION = "AWS::ApiGatewayV2::Integration" - APIGATEWAY_V2_ROUTE = "AWS::ApiGatewayV2::Route" - APIGATEWAY_V2_STAGE = "AWS::ApiGatewayV2::Stage" METHOD_BINARY_TYPE = "CONVERT_TO_BINARY" HTTP_API_PROTOCOL_TYPE = "HTTP" TYPES = [ -
APIGATEWAY_RESTAPI, - APIGATEWAY_STAGE, - APIGATEWAY_RESOURCE, - APIGATEWAY_METHOD, - APIGATEWAY_V2_API, - APIGATEWAY_V2_INTEGRATION, - APIGATEWAY_V2_ROUTE, - APIGATEWAY_V2_STAGE, + AWS_APIGATEWAY_RESTAPI, + AWS_APIGATEWAY_STAGE, + AWS_APIGATEWAY_RESOURCE, + AWS_APIGATEWAY_METHOD, + AWS_APIGATEWAY_V2_API, + AWS_APIGATEWAY_V2_INTEGRATION, + AWS_APIGATEWAY_V2_ROUTE, + AWS_APIGATEWAY_V2_STAGE, ] def extract_resources(self, stacks: List[Stack], collector: ApiCollector, cwd: Optional[str] = None) -> None: @@ -54,22 +57,22 @@ def extract_resources(self, stacks: List[Stack], collector: ApiCollector, cwd: O resources = stack.resources for logical_id, resource in resources.items(): resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) - if resource_type == CfnApiProvider.APIGATEWAY_RESTAPI: + if resource_type == AWS_APIGATEWAY_RESTAPI: self._extract_cloud_formation_route(stack.stack_path, logical_id, resource, collector, cwd=cwd) - if resource_type == CfnApiProvider.APIGATEWAY_STAGE: + if resource_type == AWS_APIGATEWAY_STAGE: self._extract_cloud_formation_stage(resources, resource, collector) - if resource_type == CfnApiProvider.APIGATEWAY_METHOD: + if resource_type == AWS_APIGATEWAY_METHOD: self._extract_cloud_formation_method(stack.stack_path, resources, logical_id, resource, collector) - if resource_type == CfnApiProvider.APIGATEWAY_V2_API: + if resource_type == AWS_APIGATEWAY_V2_API: self._extract_cfn_gateway_v2_api(stack.stack_path, logical_id, resource, collector, cwd=cwd) - if resource_type == CfnApiProvider.APIGATEWAY_V2_ROUTE: + if resource_type == AWS_APIGATEWAY_V2_ROUTE: self._extract_cfn_gateway_v2_route(stack.stack_path, resources, logical_id, resource, collector) - if resource_type == CfnApiProvider.APIGATEWAY_V2_STAGE: + if resource_type == AWS_APIGATEWAY_V2_STAGE: self._extract_cfn_gateway_v2_stage(resources, resource, collector) @staticmethod @@ -136,7 +139,7 @@ def _extract_cloud_formation_stage( if not logical_id: raise InvalidSamTemplateException("The AWS::ApiGateway::Stage must have a RestApiId property") rest_api_resource_type = resources.get(logical_id, {}).get("Type") - if rest_api_resource_type != CfnApiProvider.APIGATEWAY_RESTAPI: + if rest_api_resource_type != AWS_APIGATEWAY_RESTAPI: raise InvalidSamTemplateException( "The AWS::ApiGateway::Stage must have a valid RestApiId that points to RestApi resource {}".format( logical_id @@ -383,7 +386,7 @@ def _extract_cfn_gateway_v2_stage( if not api_id: raise InvalidSamTemplateException("The AWS::ApiGatewayV2::Stage must have a ApiId property") api_resource_type = resources.get(api_id, {}).get("Type") - if api_resource_type != CfnApiProvider.APIGATEWAY_V2_API: + if api_resource_type != AWS_APIGATEWAY_V2_API: raise InvalidSamTemplateException( "The AWS::ApiGatewayV2::Stag must have a valid ApiId that points to Api resource {}".format(api_id) ) @@ -445,7 +448,7 @@ def _get_route_function_name( integration_resource = resources.get(integration_id, {}) resource_type = integration_resource.get("Type") - if resource_type == CfnApiProvider.APIGATEWAY_V2_INTEGRATION: + if resource_type == AWS_APIGATEWAY_V2_INTEGRATION: properties = integration_resource.get("Properties", {}) integration_uri = properties.get("IntegrationUri") payload_format_version = properties.get("PayloadFormatVersion") diff --git a/samcli/lib/providers/exceptions.py b/samcli/lib/providers/exceptions.py index 60328f781f..81709e4dd6 100644 --- a/samcli/lib/providers/exceptions.py +++ b/samcli/lib/providers/exceptions.py @@ -2,6 +2,12 @@ Exceptions used by 
providers """ +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from samcli.lib.providers.provider import ResourceIdentifier + class InvalidLayerReference(Exception): """ @@ -16,3 +22,36 @@ def __init__(self) -> None: class RemoteStackLocationNotSupported(Exception): pass + + +class MissingCodeUri(Exception): + """Exception when Function or Lambda resources do not have CodeUri specified""" + + +class MissingLocalDefinition(Exception): + """Exception when a resource does not have a local path in its property""" + + _resource_identifier: "ResourceIdentifier" + _property_name: str + + def __init__(self, resource_identifier: "ResourceIdentifier", property_name: str) -> None: + """Exception when a resource does not have a local path in its property + + Parameters + ---------- + resource_identifier : ResourceIdentifier + Resource identifier + property_name : str + Property name that's missing + """ + self._resource_identifier = resource_identifier + self._property_name = property_name + super().__init__(f"Resource {str(resource_identifier)} does not have {property_name} specified.") + + @property + def resource_identifier(self) -> "ResourceIdentifier": + return self._resource_identifier + + @property + def property_name(self) -> str: + return self._property_name diff --git a/samcli/lib/providers/provider.py b/samcli/lib/providers/provider.py index 5f53158c0d..2bc0e12993 100644 --- a/samcli/lib/providers/provider.py +++ b/samcli/lib/providers/provider.py @@ -7,7 +7,7 @@ import os import posixpath from collections import namedtuple -from typing import Set, NamedTuple, Optional, List, Dict, Union, cast, Iterator, TYPE_CHECKING +from typing import Any, Set, NamedTuple, Optional, List, Dict, Tuple, Union, cast, Iterator, TYPE_CHECKING from samcli.commands.local.cli_common.user_exceptions import InvalidLayerVersionArn, UnsupportedIntrinsic from samcli.lib.providers.sam_base_provider import SamBaseProvider @@ -51,7 +51,7 @@ class Function(NamedTuple): # to get credentials to run the container with. This gives a much higher fidelity simulation of cloud Lambda. rolearn: Optional[str] # List of Layers - layers: List + layers: List["LayerVersion"] # Event events: Optional[List] # Metadata @@ -437,14 +437,176 @@ def get_output_template_path(self, build_root: str) -> str: return os.path.join(build_root, self.stack_path.replace(posixpath.sep, os.path.sep), "template.yaml") +class ResourceIdentifier: + """Resource identifier for representing a resource with nested stack support""" + + _stack_path: str + _logical_id: str + + def __init__(self, resource_identifier_str: str): + """ + Parameters + ---------- + resource_identifier_str : str + Resource identifier in the format of: + Stack1/Stack2/ResourceID + """ + parts = resource_identifier_str.rsplit(posixpath.sep, 1) + if len(parts) == 1: + self._stack_path = "" + self._logical_id = parts[0] + else: + self._stack_path = parts[0] + self._logical_id = parts[1] + + @property + def stack_path(self) -> str: + """ + Returns + ------- + str + Stack path of the resource. + This can be an empty string if the resource is in the root stack. + """ + return self._stack_path + + @property + def logical_id(self) -> str: + """ + Returns + ------- + str + Logical ID of the resource.
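A short illustration of the parsing rules (identifiers are hypothetical):

    ResourceIdentifier("MyFunction")                # stack_path == "", logical_id == "MyFunction"
    ResourceIdentifier("Stack1/Stack2/MyFunction")  # stack_path == "Stack1/Stack2", logical_id == "MyFunction"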
+ """ + return self._logical_id + + def __str__(self) -> str: + return self.stack_path + posixpath.sep + self.logical_id if self.stack_path else self.logical_id + + def __eq__(self, other: object) -> bool: + return str(self) == str(other) if isinstance(other, ResourceIdentifier) else False + + def __hash__(self) -> int: + return hash(str(self)) + + def get_full_path(stack_path: str, logical_id: str) -> str: """ Return the unique posix path-like identifier which will be used to identify a resource among resources in a multi-stack situation """ + if not stack_path: + return logical_id return posixpath.join(stack_path, logical_id) +def get_resource_by_id( + stacks: List[Stack], identifier: ResourceIdentifier, explicit_nested: bool = False +) -> Optional[Dict[str, Any]]: + """Search for a resource in stacks based on the identifier + + Parameters + ---------- + stacks : List[Stack] + List of stacks to be searched + identifier : ResourceIdentifier + Resource identifier for the resource to be returned + explicit_nested : bool, optional + Set to True to only search in the root stack if stack_path does not exist. + Otherwise, all stacks will be searched in order to find a matching logical ID. + If stack_path does exist in the identifier, this option will be ignored and behave as if it is True + + Returns + ------- + Dict + Resource dict + """ + search_all_stacks = not identifier.stack_path and not explicit_nested + for stack in stacks: + if stack.stack_path == identifier.stack_path or search_all_stacks: + resource = stack.resources.get(identifier.logical_id) + if resource: + return cast(Dict[str, Any], resource) + return None + + +def get_resource_ids_by_type(stacks: List[Stack], resource_type: str) -> List[ResourceIdentifier]: + """Return a list of resource IDs with the given resource type + + Parameters + ---------- + stacks : List[Stack] + List of stacks + resource_type : str + Resource type to be used for searching related resources.
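To make the identifier format concrete, a minimal sketch (illustrative only, not part of the patch; behavior follows the rsplit parsing shown above):

    from samcli.lib.providers.provider import ResourceIdentifier

    nested = ResourceIdentifier("StackA/StackB/MyFunction")
    assert nested.stack_path == "StackA/StackB"   # everything before the last "/"
    assert nested.logical_id == "MyFunction"      # the trailing segment
    root = ResourceIdentifier("MyFunction")       # no "/" means the root stack
    assert root.stack_path == "" and str(root) == "MyFunction"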
+ + Returns + ------- + List[ResourceIdentifier] + List of ResourceIdentifiers with the type provided + """ + resource_ids: List[ResourceIdentifier] = list() + for stack in stacks: + for resource_id, resource in stack.resources.items(): + if resource.get("Type", "") == resource_type: + resource_ids.append(ResourceIdentifier(get_full_path(stack.stack_path, resource_id))) + return resource_ids + + +def get_all_resource_ids(stacks: List[Stack]) -> List[ResourceIdentifier]: + """Return all resource IDs in stacks + + Parameters + ---------- + stacks : List[Stack] + List of stacks + + Returns + ------- + List[ResourceIdentifier] + List of ResourceIdentifiers + """ + resource_ids: List[ResourceIdentifier] = list() + for stack in stacks: + for resource_id, _ in stack.resources.items(): + resource_ids.append(ResourceIdentifier(get_full_path(stack.stack_path, resource_id))) + return resource_ids + + +def get_unique_resource_ids( + stacks: List[Stack], + resource_ids: Optional[Union[List[str], Tuple[str]]], + resource_types: Optional[Union[List[str], Tuple[str]]], +) -> Set[ResourceIdentifier]: + """Get unique resource IDs for resource_ids and resource_types + + Parameters + ---------- + stacks : List[Stack] + Stacks + resource_ids : Optional[Union[List[str], Tuple[str]]] + Resource ID strings + resource_types : Optional[Union[List[str], Tuple[str]]] + Resource types + + Returns + ------- + Set[ResourceIdentifier] + Set of ResourceIdentifiers that are either listed in resource_ids or have a type listed in resource_types + """ + output_resource_ids: Set[ResourceIdentifier] = set() + if resource_ids: + for resource_id in resource_ids: + output_resource_ids.add(ResourceIdentifier(resource_id)) + + if resource_types: + for resource_type in resource_types: + resource_type_ids = get_resource_ids_by_type(stacks, resource_type) + for resource_id in resource_type_ids: + output_resource_ids.add(resource_id) + return output_resource_ids + + def _get_build_dir(resource: Union[Function, LayerVersion], build_root: str) -> str: """ Return the build directory in which to place the build artifact diff --git a/samcli/lib/providers/sam_api_provider.py b/samcli/lib/providers/sam_api_provider.py index 0ad44eb7b8..1fd48dfde5 100644 --- a/samcli/lib/providers/sam_api_provider.py +++ b/samcli/lib/providers/sam_api_provider.py @@ -8,15 +8,13 @@ from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException from samcli.lib.providers.provider import Stack from samcli.local.apigw.local_apigw_service import Route +from samcli.lib.utils.resources import AWS_SERVERLESS_FUNCTION, AWS_SERVERLESS_API, AWS_SERVERLESS_HTTPAPI LOG = logging.getLogger(__name__) class SamApiProvider(CfnBaseApiProvider): - SERVERLESS_FUNCTION = "AWS::Serverless::Function" - SERVERLESS_API = "AWS::Serverless::Api" - SERVERLESS_HTTP_API = "AWS::Serverless::HttpApi" - TYPES = [SERVERLESS_FUNCTION, SERVERLESS_API, SERVERLESS_HTTP_API] + TYPES = [AWS_SERVERLESS_FUNCTION, AWS_SERVERLESS_API, AWS_SERVERLESS_HTTPAPI] _EVENT_TYPE_API = "Api" _EVENT_TYPE_HTTP_API = "HttpApi" _FUNCTION_EVENT = "Events" @@ -46,11 +44,11 @@ def extract_resources(self, stacks: List[Stack], collector: ApiCollector, cwd: O for stack in stacks: for logical_id, resource in stack.resources.items(): resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) - if resource_type == SamApiProvider.SERVERLESS_FUNCTION: + if resource_type == AWS_SERVERLESS_FUNCTION: self._extract_routes_from_function(stack.stack_path, logical_id, resource, collector) - if resource_type == SamApiProvider.SERVERLESS_API: + if
resource_type == AWS_SERVERLESS_API: self._extract_from_serverless_api(stack.stack_path, logical_id, resource, collector, cwd=cwd) - if resource_type == SamApiProvider.SERVERLESS_HTTP_API: + if resource_type == AWS_SERVERLESS_HTTPAPI: self._extract_from_serverless_http(stack.stack_path, logical_id, resource, collector, cwd=cwd) collector.routes = self.merge_routes(collector) @@ -156,7 +154,7 @@ def _extract_routes_from_function( Path of the stack in which the resource is located logical_id : str - Logical ID of the resourc + Logical ID of the resource function_resource : dict Contents of the function resource including its properties diff --git a/samcli/lib/providers/sam_base_provider.py b/samcli/lib/providers/sam_base_provider.py index 7a75c70cc8..3e72a6285f 100644 --- a/samcli/lib/providers/sam_base_provider.py +++ b/samcli/lib/providers/sam_base_provider.py @@ -5,7 +5,12 @@ import logging from typing import Any, Dict, Optional, cast, Iterable, Union -from samcli.commands._utils.resources import AWS_SERVERLESS_APPLICATION, AWS_CLOUDFORMATION_STACK +from samcli.lib.utils.resources import ( + AWS_LAMBDA_FUNCTION, + AWS_SERVERLESS_FUNCTION, + AWS_LAMBDA_LAYERVERSION, + AWS_SERVERLESS_LAYERVERSION, +) from samcli.lib.intrinsic_resolver.intrinsic_property_resolver import IntrinsicResolver from samcli.lib.intrinsic_resolver.intrinsics_symbol_table import IntrinsicsSymbolTable from samcli.lib.samlib.resource_metadata_normalizer import ResourceMetadataNormalizer @@ -21,24 +26,18 @@ class SamBaseProvider: Base class for SAM Template providers """ - SERVERLESS_FUNCTION = "AWS::Serverless::Function" - LAMBDA_FUNCTION = "AWS::Lambda::Function" - SERVERLESS_LAYER = "AWS::Serverless::LayerVersion" - LAMBDA_LAYER = "AWS::Lambda::LayerVersion" - SERVERLESS_APPLICATION = AWS_SERVERLESS_APPLICATION - CLOUDFORMATION_STACK = AWS_CLOUDFORMATION_STACK DEFAULT_CODEURI = "."
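To make the refactor concrete: the per-class constants removed above now come from the shared samcli.lib.utils.resources module, so every provider resolves resource types from one definition. A minimal sketch (constant values inferred from the removed lines above; the CODE_PROPERTY_KEYS mapping shown next supplies the property lookup):

    from samcli.lib.utils.resources import AWS_LAMBDA_FUNCTION, AWS_SERVERLESS_FUNCTION

    # Values match the removed class-level constants:
    assert AWS_SERVERLESS_FUNCTION == "AWS::Serverless::Function"
    assert AWS_LAMBDA_FUNCTION == "AWS::Lambda::Function"
    # e.g. SamBaseProvider.CODE_PROPERTY_KEYS[AWS_SERVERLESS_FUNCTION] == "CodeUri"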
CODE_PROPERTY_KEYS = { - LAMBDA_FUNCTION: "Code", - SERVERLESS_FUNCTION: "CodeUri", - LAMBDA_LAYER: "Content", - SERVERLESS_LAYER: "ContentUri", + AWS_LAMBDA_FUNCTION: "Code", + AWS_SERVERLESS_FUNCTION: "CodeUri", + AWS_LAMBDA_LAYERVERSION: "Content", + AWS_SERVERLESS_LAYERVERSION: "ContentUri", } IMAGE_PROPERTY_KEYS = { - LAMBDA_FUNCTION: "Code", - SERVERLESS_FUNCTION: "ImageUri", + AWS_LAMBDA_FUNCTION: "Code", + AWS_SERVERLESS_FUNCTION: "ImageUri", } def get(self, name: str) -> Optional[Any]: diff --git a/samcli/lib/providers/sam_function_provider.py b/samcli/lib/providers/sam_function_provider.py index acdc9d91f1..056f943ee6 100644 --- a/samcli/lib/providers/sam_function_provider.py +++ b/samcli/lib/providers/sam_function_provider.py @@ -4,6 +4,12 @@ import logging from typing import Dict, List, Optional, cast, Iterator, Any +from samcli.lib.utils.resources import ( + AWS_LAMBDA_FUNCTION, + AWS_LAMBDA_LAYERVERSION, + AWS_SERVERLESS_FUNCTION, + AWS_SERVERLESS_LAYERVERSION, +) from samcli.commands.local.cli_common.user_exceptions import InvalidLayerVersionArn from samcli.lib.providers.exceptions import InvalidLayerReference from samcli.lib.utils.colors import Colored @@ -129,7 +135,7 @@ def _extract_functions( if resource_metadata: resource_properties["Metadata"] = resource_metadata - if resource_type in [SamFunctionProvider.SERVERLESS_FUNCTION, SamFunctionProvider.LAMBDA_FUNCTION]: + if resource_type in [AWS_SERVERLESS_FUNCTION, AWS_LAMBDA_FUNCTION]: resource_package_type = resource_properties.get("PackageType", ZIP) code_property_key = SamBaseProvider.CODE_PROPERTY_KEYS[resource_type] @@ -156,7 +162,7 @@ def _extract_functions( SamFunctionProvider._warn_imageuri_extraction(resource_type, name, image_property_key) continue - if resource_type == SamFunctionProvider.SERVERLESS_FUNCTION: + if resource_type == AWS_SERVERLESS_FUNCTION: layers = SamFunctionProvider._parse_layer_info( stack, resource_properties.get("Layers", []), @@ -172,7 +178,7 @@ def _extract_functions( ) result[function.full_path] = function - elif resource_type == SamFunctionProvider.LAMBDA_FUNCTION: + elif resource_type == AWS_LAMBDA_FUNCTION: layers = SamFunctionProvider._parse_layer_info( stack, resource_properties.get("Layers", []), @@ -428,8 +434,8 @@ def _locate_layer_from_ref( layer_logical_id = cast(str, layer.get("Ref")) layer_resource = stack.resources.get(layer_logical_id) if not layer_resource or layer_resource.get("Type", "") not in ( - SamFunctionProvider.SERVERLESS_LAYER, - SamFunctionProvider.LAMBDA_LAYER, + AWS_SERVERLESS_LAYERVERSION, + AWS_LAMBDA_LAYERVERSION, ): raise InvalidLayerReference() @@ -438,7 +444,7 @@ def _locate_layer_from_ref( compatible_runtimes = layer_properties.get("CompatibleRuntimes") codeuri: Optional[str] = None - if resource_type in [SamFunctionProvider.LAMBDA_LAYER, SamFunctionProvider.SERVERLESS_LAYER]: + if resource_type in [AWS_LAMBDA_LAYERVERSION, AWS_SERVERLESS_LAYERVERSION]: code_property_key = SamBaseProvider.CODE_PROPERTY_KEYS[resource_type] if SamBaseProvider._is_s3_location(layer_properties.get(code_property_key)): # Content can be a dictionary of S3 Bucket/Key or a S3 URI, neither of which are supported diff --git a/samcli/lib/providers/sam_layer_provider.py b/samcli/lib/providers/sam_layer_provider.py index 0d08086094..529557cb82 100644 --- a/samcli/lib/providers/sam_layer_provider.py +++ b/samcli/lib/providers/sam_layer_provider.py @@ -5,6 +5,7 @@ import posixpath from typing import List, Dict, Optional +from samcli.lib.utils.resources import AWS_LAMBDA_LAYERVERSION, 
AWS_SERVERLESS_LAYERVERSION from .provider import LayerVersion, Stack from .sam_base_provider import SamBaseProvider from .sam_stack_provider import SamLocalStackProvider @@ -86,7 +87,7 @@ def _extract_layers(self) -> List[LayerVersion]: resource_type = resource.get("Type") resource_properties = resource.get("Properties", {}) - if resource_type in [SamBaseProvider.LAMBDA_LAYER, SamBaseProvider.SERVERLESS_LAYER]: + if resource_type in [AWS_LAMBDA_LAYERVERSION, AWS_SERVERLESS_LAYERVERSION]: code_property_key = SamBaseProvider.CODE_PROPERTY_KEYS[resource_type] if SamBaseProvider._is_s3_location(resource_properties.get(code_property_key)): # Content can be a dictionary of S3 Bucket/Key or a S3 URI, neither of which are supported diff --git a/samcli/lib/providers/sam_stack_provider.py b/samcli/lib/providers/sam_stack_provider.py index 33c3ad05bd..ba41d83feb 100644 --- a/samcli/lib/providers/sam_stack_provider.py +++ b/samcli/lib/providers/sam_stack_provider.py @@ -10,6 +10,7 @@ from samcli.lib.providers.exceptions import RemoteStackLocationNotSupported from samcli.lib.providers.provider import Stack, get_full_path from samcli.lib.providers.sam_base_provider import SamBaseProvider +from samcli.lib.utils.resources import AWS_CLOUDFORMATION_STACK, AWS_SERVERLESS_APPLICATION LOG = logging.getLogger(__name__) @@ -110,7 +111,7 @@ def _extract_stacks(self) -> None: stack: Optional[Stack] = None try: - if resource_type == SamLocalStackProvider.SERVERLESS_APPLICATION: + if resource_type == AWS_SERVERLESS_APPLICATION: stack = SamLocalStackProvider._convert_sam_application_resource( self._template_file, self._stack_path, @@ -118,7 +119,7 @@ def _extract_stacks(self) -> None: resource_properties, root_template_dir=self._root_template_dir, ) - if resource_type == SamLocalStackProvider.CLOUDFORMATION_STACK: + if resource_type == AWS_CLOUDFORMATION_STACK: stack = SamLocalStackProvider._convert_cfn_stack_resource( self._template_file, self._stack_path, diff --git a/samcli/lib/sync/__init__.py b/samcli/lib/sync/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/samcli/lib/sync/continuous_sync_flow_executor.py b/samcli/lib/sync/continuous_sync_flow_executor.py new file mode 100644 index 0000000000..2adf2e1f21 --- /dev/null +++ b/samcli/lib/sync/continuous_sync_flow_executor.py @@ -0,0 +1,130 @@ +"""SyncFlowExecutor that will run continuously until stop is called.""" +import time +import logging + +from typing import Callable, Optional +from concurrent.futures.thread import ThreadPoolExecutor + +from dataclasses import dataclass + +from samcli.lib.sync.exceptions import SyncFlowException +from samcli.lib.sync.sync_flow import SyncFlow +from samcli.lib.sync.sync_flow_executor import SyncFlowExecutor, SyncFlowFuture, SyncFlowTask, default_exception_handler + +LOG = logging.getLogger(__name__) + + +@dataclass(frozen=True, eq=True) +class DelayedSyncFlowTask(SyncFlowTask): + """Data struct for individual SyncFlow execution tasks""" + + # Time in seconds of when the task was initially queued + queue_time: float + + # Number of seconds this task should stay in queue before being executed + wait_time: float + + +class ContinuousSyncFlowExecutor(SyncFlowExecutor): + """SyncFlowExecutor that continuously runs and executes SyncFlows. 
+ Call stop() to stop the executor""" + + # Flag for whether the executor should be stopped at the next available time + _stop_flag: bool + + def __init__(self) -> None: + super().__init__() + self._stop_flag = False + + def stop(self, should_stop=True) -> None: + """Stop executor after all current SyncFlows are finished.""" + with self._flow_queue_lock: + self._stop_flag = should_stop + if should_stop: + self._flow_queue.queue.clear() + + def should_stop(self) -> bool: + """ + Returns + ------- + bool + Should executor stop execution at the next available time. + """ + return self._stop_flag + + def _can_exit(self): + return self.should_stop() and super()._can_exit() + + def _submit_sync_flow_task( + self, executor: ThreadPoolExecutor, sync_flow_task: SyncFlowTask + ) -> Optional[SyncFlowFuture]: + """Submit SyncFlowTask to be executed by ThreadPoolExecutor + and return its future + Adds additional time checks for DelayedSyncFlowTask + + Parameters + ---------- + executor : ThreadPoolExecutor + ThreadPoolExecutor to be used for execution + sync_flow_task : SyncFlowTask + SyncFlowTask to be executed. + + Returns + ------- + Optional[SyncFlowFuture] + Returns SyncFlowFuture generated by the SyncFlowTask. + Can be None if the task cannot be executed yet. + """ + if ( + isinstance(sync_flow_task, DelayedSyncFlowTask) + and sync_flow_task.wait_time + sync_flow_task.queue_time > time.time() + ): + return None + + return super()._submit_sync_flow_task(executor, sync_flow_task) + + def _add_sync_flow_task(self, task: SyncFlowTask) -> None: + """Add SyncFlowTask to the queue + Skips if the executor is in the state of being shut down. + + Parameters + ---------- + task : SyncFlowTask + SyncFlowTask to be added. + """ + if self.should_stop(): + LOG.debug( + "%s is skipped from queueing as executor is in the process of stopping.", task.sync_flow.log_prefix + ) + return + + super()._add_sync_flow_task(task) + + def add_delayed_sync_flow(self, sync_flow: SyncFlow, dedup: bool = True, wait_time: float = 0) -> None: + """Add a SyncFlow to queue to be executed + Locks will be set with LockDistributor + + Parameters + ---------- + sync_flow : SyncFlow + SyncFlow to be executed + dedup : bool + SyncFlow will not be added if this flag is True and it has a duplicate in the queue + wait_time : float + Minimum number of seconds before SyncFlow executes + """ + self._add_sync_flow_task(DelayedSyncFlowTask(sync_flow, dedup, time.time(), wait_time)) + + def execute( + self, exception_handler: Optional[Callable[[SyncFlowException], None]] = default_exception_handler + ) -> None: + """Blocking continuous execution of the SyncFlows + + Parameters + ---------- + exception_handler : Optional[Callable[[SyncFlowException], None]], optional + Function to be called if an exception is raised during the execution of a SyncFlow, + by default default_exception_handler + """ + super().execute(exception_handler=exception_handler) + self.stop(should_stop=False) diff --git a/samcli/lib/sync/exceptions.py b/samcli/lib/sync/exceptions.py new file mode 100644 index 0000000000..f012f71086 --- /dev/null +++ b/samcli/lib/sync/exceptions.py @@ -0,0 +1,105 @@ +"""Exceptions related to sync functionalities""" +from typing import Dict, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from samcli.lib.sync.sync_flow import SyncFlow + + +class SyncFlowException(Exception): + """Exception wrapper for exceptions raised in SyncFlows""" + + _sync_flow: "SyncFlow" + _exception: Exception + + def __init__(self, sync_flow: "SyncFlow", exception:
Exception): + """ + Parameters + ---------- + sync_flow : SyncFlow + SyncFlow that raised the exception + exception : Exception + exception raised + """ + super().__init__(f"SyncFlow Exception for {sync_flow.log_name}") + self._sync_flow = sync_flow + self._exception = exception + + @property + def sync_flow(self) -> "SyncFlow": + return self._sync_flow + + @property + def exception(self) -> Exception: + return self._exception + + +class MissingPhysicalResourceError(Exception): + """Exception used when a local stack resource does not have a remote/physical counterpart""" + + _resource_identifier: Optional[str] + _physical_resource_mapping: Optional[Dict[str, str]] + + def __init__( + self, resource_identifier: Optional[str] = None, physical_resource_mapping: Optional[Dict[str, str]] = None + ): + """ + Parameters + ---------- + resource_identifier : str + Logical resource identifier + physical_resource_mapping: Dict[str, str] + Current mapping between logical and physical IDs + """ + super().__init__(f"{resource_identifier} is not found in remote.") + self._resource_identifier = resource_identifier + self._physical_resource_mapping = physical_resource_mapping + + @property + def resource_identifier(self) -> Optional[str]: + """ + Returns + ------- + str + Resource identifier of the resource that does not have a remote/physical counterpart + """ + return self._resource_identifier + + @property + def physical_resource_mapping(self) -> Optional[Dict[str, str]]: + """ + Returns + ------- + Optional[Dict[str, str]] + Physical ID mapping for resources when the exception was raised + """ + return self._physical_resource_mapping + + +class NoLayerVersionsFoundError(Exception): + """This is used when we try to list all versions for a layer, but find none""" + + _layer_name_arn: str + + def __init__(self, layer_name_arn: str): + """ + Parameters + ---------- + layer_name_arn : str + Layer ARN without version info at the end of it + """ + super().__init__(f"{layer_name_arn} doesn't have any versions in remote.") + self._layer_name_arn = layer_name_arn + + @property + def layer_name_arn(self) -> str: + """ + Returns + ------- + str + Layer ARN without version info at the end of it + """ + return self._layer_name_arn + + +class MissingLockException(Exception): + """Exception for not having an associated lock to be used.""" diff --git a/samcli/lib/sync/flows/__init__.py b/samcli/lib/sync/flows/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/samcli/lib/sync/flows/alias_version_sync_flow.py b/samcli/lib/sync/flows/alias_version_sync_flow.py new file mode 100644 index 0000000000..1777a51f86 --- /dev/null +++ b/samcli/lib/sync/flows/alias_version_sync_flow.py @@ -0,0 +1,89 @@ +"""SyncFlow for Lambda Function Alias and Version""" +import logging +from typing import Any, Dict, List, Optional, TYPE_CHECKING, cast + +from boto3.session import Session + +from samcli.lib.providers.provider import Stack +from samcli.lib.sync.sync_flow import SyncFlow, ResourceAPICall + +if TYPE_CHECKING: + from samcli.commands.deploy.deploy_context import DeployContext + from samcli.commands.build.build_context import BuildContext + +LOG = logging.getLogger(__name__) + + +class AliasVersionSyncFlow(SyncFlow): + """This SyncFlow is used for updating Lambda Function version and its associated Alias. + Currently, this is created after a FunctionSyncFlow is finished.
+ """ + + _function_identifier: str + _alias_name: str + _lambda_client: Any + + def __init__( + self, + function_identifier: str, + alias_name: str, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + stacks: Optional[List[Stack]] = None, + ): + """ + Parameters + ---------- + function_identifier : str + Function resource identifier that needs to have its associated Alias and Version updated. + alias_name : str + Alias name for the function + build_context : BuildContext + BuildContext + deploy_context : DeployContext + DeployContext + physical_id_mapping : Dict[str, str] + Physical ID Mapping + stacks : Optional[List[Stack]] + Stacks + """ + super().__init__( + build_context, + deploy_context, + physical_id_mapping, + log_name=f"Alias {alias_name} and Version of {function_identifier}", + stacks=stacks, + ) + self._function_identifier = function_identifier + self._alias_name = alias_name + self._lambda_client = None + + def set_up(self) -> None: + super().set_up() + self._lambda_client = cast(Session, self._session).client("lambda") + + def gather_resources(self) -> None: + pass + + def compare_remote(self) -> bool: + return False + + def sync(self) -> None: + function_physical_id = self.get_physical_id(self._function_identifier) + version = self._lambda_client.publish_version(FunctionName=function_physical_id).get("Version") + LOG.debug("%sCreated new function version: %s", self.log_prefix, version) + if version: + self._lambda_client.update_alias( + FunctionName=function_physical_id, Name=self._alias_name, FunctionVersion=version + ) + + def gather_dependencies(self) -> List[SyncFlow]: + return [] + + def _get_resource_api_calls(self) -> List[ResourceAPICall]: + return [] + + def _equality_keys(self) -> Any: + """Combination of function identifier and alias name can be used to identify each unique SyncFlow""" + return self._function_identifier, self._alias_name diff --git a/samcli/lib/sync/flows/function_sync_flow.py b/samcli/lib/sync/flows/function_sync_flow.py new file mode 100644 index 0000000000..df9bb1ebc5 --- /dev/null +++ b/samcli/lib/sync/flows/function_sync_flow.py @@ -0,0 +1,104 @@ +"""Base SyncFlow for Lambda Function""" +import logging +from typing import Any, Dict, List, TYPE_CHECKING, cast + +from boto3.session import Session + +from samcli.lib.providers.sam_function_provider import SamFunctionProvider +from samcli.lib.sync.flows.alias_version_sync_flow import AliasVersionSyncFlow +from samcli.lib.providers.provider import Function, Stack +from samcli.local.lambdafn.exceptions import FunctionNotFound + +from samcli.lib.sync.sync_flow import SyncFlow + +if TYPE_CHECKING: + from samcli.commands.deploy.deploy_context import DeployContext + from samcli.commands.build.build_context import BuildContext + +LOG = logging.getLogger(__name__) + + +class FunctionSyncFlow(SyncFlow): + _function_identifier: str + _function_provider: SamFunctionProvider + _function: Function + _lambda_client: Any + _lambda_waiter: Any + _lambda_waiter_config: Dict[str, Any] + + def __init__( + self, + function_identifier: str, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + stacks: List[Stack], + ): + """ + Parameters + ---------- + function_identifier : str + Function resource identifier that needs to be synced.
+ build_context : BuildContext + BuildContext + deploy_context : DeployContext + DeployContext + physical_id_mapping : Dict[str, str] + Physical ID Mapping + stacks : Optional[List[Stack]] + Stacks + """ + super().__init__( + build_context, + deploy_context, + physical_id_mapping, + log_name="Lambda Function " + function_identifier, + stacks=stacks, + ) + self._function_identifier = function_identifier + self._function_provider = self._build_context.function_provider + self._function = cast(Function, self._function_provider.functions.get(self._function_identifier)) + self._lambda_client = None + self._lambda_waiter = None + self._lambda_waiter_config = {"Delay": 1, "MaxAttempts": 60} + + def set_up(self) -> None: + super().set_up() + self._lambda_client = cast(Session, self._session).client("lambda") + self._lambda_waiter = self._lambda_client.get_waiter("function_updated") + + def gather_dependencies(self) -> List[SyncFlow]: + """Gathers alias and versions related to a function. + Currently only handles serverless function AutoPublishAlias field + since a manually created function version resource behaves statically in a stack. + Redeploying a version resource through CFN will not create a new version. + """ + LOG.debug("%sWaiting on Remote Function Update", self.log_prefix) + self._lambda_waiter.wait( + FunctionName=self.get_physical_id(self._function_identifier), WaiterConfig=self._lambda_waiter_config + ) + LOG.debug("%sRemote Function Updated", self.log_prefix) + sync_flows: List[SyncFlow] = list() + + function_resource = self._get_resource(self._function_identifier) + if not function_resource: + raise FunctionNotFound(f"Unable to find function {self._function_identifier}") + + auto_publish_alias_name = function_resource.get("Properties", dict()).get("AutoPublishAlias", None) + if auto_publish_alias_name: + sync_flows.append( + AliasVersionSyncFlow( + self._function_identifier, + auto_publish_alias_name, + self._build_context, + self._deploy_context, + self._physical_id_mapping, + self._stacks, + ) + ) + LOG.debug("%sCreated Alias and Version SyncFlow", self.log_prefix) + + return sync_flows + + def _equality_keys(self): + return self._function_identifier diff --git a/samcli/lib/sync/flows/generic_api_sync_flow.py b/samcli/lib/sync/flows/generic_api_sync_flow.py new file mode 100644 index 0000000000..af9c8ad675 --- /dev/null +++ b/samcli/lib/sync/flows/generic_api_sync_flow.py @@ -0,0 +1,89 @@ +"""SyncFlow interface for HttpApi and RestApi""" +import logging +from typing import Any, Dict, List, Optional, TYPE_CHECKING, cast + +from samcli.lib.sync.sync_flow import SyncFlow, ResourceAPICall +from samcli.lib.providers.provider import Stack, get_resource_by_id, ResourceIdentifier + +# BuildContext and DeployContext will only be imported for type checking to improve performance +# since no instances of contexts will be instantiated in this class +if TYPE_CHECKING: + from samcli.commands.build.build_context import BuildContext + from samcli.commands.deploy.deploy_context import DeployContext + +LOG = logging.getLogger(__name__) + + +class GenericApiSyncFlow(SyncFlow): + """SyncFlow interface for HttpApi and RestApi""" + + _api_client: Any + _api_identifier: str + _definition_uri: Optional[str] + _stacks: List[Stack] + _swagger_body: Optional[bytes] + + def __init__( + self, + api_identifier: str, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + log_name: str, + stacks: List[Stack], + ): + """ + Parameters + ---------- +
api_identifier : str + Api resource identifier that needs to have its associated Api updated. + build_context : BuildContext + BuildContext used for build related parameters + deploy_context : DeployContext + DeployContext used for deploy related parameters + physical_id_mapping : Dict[str, str] + Mapping between resource logical identifier and physical identifier + log_name: str + Log name passed from subclasses, HttpApi or RestApi + stacks : List[Stack], optional + List of stacks containing a root stack and optional nested stacks + """ + super().__init__( + build_context, + deploy_context, + physical_id_mapping, + log_name=log_name, + stacks=stacks, + ) + self._api_identifier = api_identifier + + def gather_resources(self) -> None: + self._definition_uri = self._get_definition_file(self._api_identifier) + self._swagger_body = self._process_definition_file() + + def _process_definition_file(self) -> Optional[bytes]: + if self._definition_uri is None: + return None + with open(self._definition_uri, "rb") as swagger_file: + swagger_body = swagger_file.read() + return swagger_body + + def _get_definition_file(self, api_identifier: str) -> Optional[str]: + api_resource = get_resource_by_id(self._stacks, ResourceIdentifier(api_identifier)) + if api_resource is None: + return None + properties = api_resource.get("Properties", {}) + definition_file = properties.get("DefinitionUri") + return cast(Optional[str], definition_file) + + def compare_remote(self) -> bool: + return False + + def gather_dependencies(self) -> List[SyncFlow]: + return [] + + def _get_resource_api_calls(self) -> List[ResourceAPICall]: + return [] + + def _equality_keys(self) -> Any: + return self._api_identifier diff --git a/samcli/lib/sync/flows/http_api_sync_flow.py b/samcli/lib/sync/flows/http_api_sync_flow.py new file mode 100644 index 0000000000..d8c8ba5703 --- /dev/null +++ b/samcli/lib/sync/flows/http_api_sync_flow.py @@ -0,0 +1,69 @@ +"""SyncFlow for HttpApi""" +import logging +from typing import Dict, List, TYPE_CHECKING, cast + +from boto3.session import Session + +from samcli.lib.sync.flows.generic_api_sync_flow import GenericApiSyncFlow +from samcli.lib.providers.provider import ResourceIdentifier, Stack +from samcli.lib.providers.exceptions import MissingLocalDefinition + +# BuildContext and DeployContext will only be imported for type checking to improve performance +# since no instances of contexts will be instantiated in this class +if TYPE_CHECKING: + from samcli.commands.build.build_context import BuildContext + from samcli.commands.deploy.deploy_context import DeployContext + +LOG = logging.getLogger(__name__) + + +class HttpApiSyncFlow(GenericApiSyncFlow): + """SyncFlow for HttpApi's""" + + def __init__( + self, + api_identifier: str, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + stacks: List[Stack], + ): + """ + Parameters + ---------- + api_identifier : str + HttpApi resource identifier that needs to have its associated HttpApi updated.
+ build_context : BuildContext + BuildContext used for build related parameters + deploy_context : DeployContext + DeployContext used for deploy related parameters + physical_id_mapping : Dict[str, str] + Mapping between resource logical identifier and physical identifier + stacks : List[Stack], optional + List of stacks containing a root stack and optional nested stacks + """ + super().__init__( + api_identifier, + build_context, + deploy_context, + physical_id_mapping, + log_name="HttpApi " + api_identifier, + stacks=stacks, + ) + + def set_up(self) -> None: + super().set_up() + self._api_client = cast(Session, self._session).client("apigatewayv2") + + def sync(self) -> None: + api_physical_id = self.get_physical_id(self._api_identifier) + if self._definition_uri is None: + LOG.error( + "%sImport HttpApi fails since no DefinitionUri is defined in the template; \ +if you are using DefinitionBody, please run sam sync --infra", + self.log_prefix, + ) + raise MissingLocalDefinition(ResourceIdentifier(self._api_identifier), "DefinitionUri") + LOG.debug("%sTrying to import HttpApi through client", self.log_prefix) + response = self._api_client.reimport_api(ApiId=api_physical_id, Body=self._swagger_body) + LOG.debug("%sImport HttpApi Result: %s", self.log_prefix, response) diff --git a/samcli/lib/sync/flows/image_function_sync_flow.py b/samcli/lib/sync/flows/image_function_sync_flow.py new file mode 100644 index 0000000000..3fb2457f5b --- /dev/null +++ b/samcli/lib/sync/flows/image_function_sync_flow.py @@ -0,0 +1,110 @@ +"""SyncFlow for Image based Lambda Functions""" +import logging +from typing import Any, Dict, List, Optional, TYPE_CHECKING, cast + +import docker +from boto3.session import Session +from docker.client import DockerClient + +from samcli.lib.providers.provider import Stack +from samcli.lib.sync.flows.function_sync_flow import FunctionSyncFlow +from samcli.lib.package.ecr_uploader import ECRUploader + +from samcli.lib.build.app_builder import ApplicationBuilder +from samcli.lib.sync.sync_flow import ResourceAPICall + +if TYPE_CHECKING: + from samcli.commands.deploy.deploy_context import DeployContext + from samcli.commands.build.build_context import BuildContext + +LOG = logging.getLogger(__name__) + + +class ImageFunctionSyncFlow(FunctionSyncFlow): + _ecr_client: Any + _docker_client: Optional[DockerClient] + _image_name: Optional[str] + + def __init__( + self, + function_identifier: str, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + stacks: List[Stack], + docker_client: Optional[DockerClient] = None, + ): + """ + Parameters + ---------- + function_identifier : str + Image function resource identifier that needs to be synced. + build_context : BuildContext + BuildContext + deploy_context : DeployContext + DeployContext + physical_id_mapping : Dict[str, str] + Physical ID Mapping + stacks : Optional[List[Stack]] + Stacks + docker_client : Optional[DockerClient] + Docker client to be used for building and uploading images. + Defaults to docker.from_env() if None is provided.
+ """ + super().__init__(function_identifier, build_context, deploy_context, physical_id_mapping, stacks) + self._ecr_client = None + self._image_name = None + self._docker_client = docker_client + + def set_up(self) -> None: + super().set_up() + self._ecr_client = cast(Session, self._session).client("ecr") + if not self._docker_client: + self._docker_client = docker.from_env() + + def gather_resources(self) -> None: + """Build function image and save it in self._image_name""" + builder = ApplicationBuilder( + self._build_context.collect_build_resources(self._function_identifier), + self._build_context.build_dir, + self._build_context.base_dir, + self._build_context.cache_dir, + cached=False, + is_building_specific_resource=True, + manifest_path_override=self._build_context.manifest_path_override, + container_manager=self._build_context.container_manager, + mode=self._build_context.mode, + ) + self._image_name = builder.build().get(self._function_identifier) + + def compare_remote(self) -> bool: + return False + + def sync(self) -> None: + if not self._image_name: + LOG.debug("%sSkipping sync. Image name is None.", self.log_prefix) + return + function_physical_id = self.get_physical_id(self._function_identifier) + # Load ECR Repo from --image-repository + ecr_repo = self._deploy_context.image_repository + + # Load ECR Repo from --image-repositories + if ( + not ecr_repo + and self._deploy_context.image_repositories + and isinstance(self._deploy_context.image_repositories, dict) + ): + ecr_repo = self._deploy_context.image_repositories.get(self._function_identifier) + + # Load ECR Repo directly from remote function + if not ecr_repo: + LOG.debug("%sGetting ECR Repo from Remote Function", self.log_prefix) + function_result = self._lambda_client.get_function(FunctionName=function_physical_id) + ecr_repo = function_result.get("Code", dict()).get("ImageUri", "").split(":")[0] + ecr_uploader = ECRUploader(self._docker_client, self._ecr_client, ecr_repo, None) + image_uri = ecr_uploader.upload(self._image_name, self._function_identifier) + + self._lambda_client.update_function_code(FunctionName=function_physical_id, ImageUri=image_uri) + + def _get_resource_api_calls(self) -> List[ResourceAPICall]: + return [] diff --git a/samcli/lib/sync/flows/layer_sync_flow.py b/samcli/lib/sync/flows/layer_sync_flow.py new file mode 100644 index 0000000000..e02754eca1 --- /dev/null +++ b/samcli/lib/sync/flows/layer_sync_flow.py @@ -0,0 +1,303 @@ +"""SyncFlow for Layers""" +import base64 +import hashlib +import logging +import os +import re +import tempfile +import uuid +from typing import Any, TYPE_CHECKING, cast, Dict, List, Optional + +from boto3.session import Session +from samcli.lib.build.app_builder import ApplicationBuilder +from samcli.lib.package.utils import make_zip +from samcli.lib.providers.provider import ResourceIdentifier, Stack, get_resource_by_id +from samcli.lib.providers.sam_function_provider import SamFunctionProvider +from samcli.lib.sync.exceptions import MissingPhysicalResourceError, NoLayerVersionsFoundError +from samcli.lib.sync.sync_flow import SyncFlow, ResourceAPICall +from samcli.lib.sync.sync_flow_executor import HELP_TEXT_FOR_SYNC_INFRA +from samcli.lib.utils.hash import file_checksum + +if TYPE_CHECKING: + from samcli.commands.build.build_context import BuildContext + from samcli.commands.deploy.deploy_context import DeployContext + +LOG = logging.getLogger(__name__) + + +class LayerSyncFlow(SyncFlow): + """SyncFlow for Lambda Layers""" + + _layer_identifier: str + 
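Stepping back to ImageFunctionSyncFlow.sync() above, the ECR repository is resolved through a three-step fallback. A hedged sketch of that order (the helper name and its arguments are illustrative, not part of the patch):

    def resolve_ecr_repo(deploy_context, function_id, remote_image_uri):
        # 1. --image-repository, if given
        repo = deploy_context.image_repository
        # 2. the per-function --image-repositories mapping
        if not repo and isinstance(deploy_context.image_repositories, dict):
            repo = deploy_context.image_repositories.get(function_id)
        # 3. the repository portion of the deployed function's ImageUri
        if not repo:
            repo = remote_image_uri.split(":")[0]
        return repo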
_layer_physical_name: Optional[str] + _old_layer_version: Optional[int] + _new_layer_version: Optional[int] + _artifact_folder: Optional[str] + _zip_file: Optional[str] + _local_sha: Optional[str] + _s3_client: Any + _lambda_client: Any + _stacks: List[Stack] + + def __init__( + self, + layer_identifier: str, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + stacks: List[Stack], + ): + super().__init__(build_context, deploy_context, physical_id_mapping, f"Layer {layer_identifier}", stacks) + self._layer_identifier = layer_identifier + self._layer_physical_name = None + self._old_layer_version = None + self._new_layer_version = None + + def set_up(self) -> None: + super().set_up() + self._s3_client = cast(Session, self._session).client("s3") + self._lambda_client = cast(Session, self._session).client("lambda") + + # if layer is a serverless layer, its physical id contains hashes, try to find layer resource + if self._layer_identifier not in self._physical_id_mapping: + expression = re.compile(f"^{self._layer_identifier}[0-9a-z]{{10}}$") + for logical_id, _ in self._physical_id_mapping.items(): + # Skip over resources that do exist in the template as generated LayerVersion should not be in there + if get_resource_by_id(self._stacks, ResourceIdentifier(logical_id), True): + continue + # Check if the logical ID starts with the serverless layer's name and is followed by a 10-character suffix + if not expression.match(logical_id): + continue + + self._layer_physical_name = self.get_physical_id(logical_id).rsplit(":", 1)[0] + LOG.debug("%sLayer physical name has been set to %s", self.log_prefix, self._layer_physical_name) + break + else: + raise MissingPhysicalResourceError( + self._layer_identifier, + self._physical_id_mapping, + ) + else: + self._layer_physical_name = self.get_physical_id(self._layer_identifier).rsplit(":", 1)[0] + LOG.debug("%sLayer physical name has been set to %s", self.log_prefix, self._layer_physical_name) + + def gather_resources(self) -> None: + """Build layer and ZIP it into a temp file in self._zip_file""" + with self._get_lock_chain(): + builder = ApplicationBuilder( + self._build_context.collect_build_resources(self._layer_identifier), + self._build_context.build_dir, + self._build_context.base_dir, + self._build_context.cache_dir, + cached=True, + is_building_specific_resource=True, + manifest_path_override=self._build_context.manifest_path_override, + container_manager=self._build_context.container_manager, + mode=self._build_context.mode, + ) + LOG.debug("%sBuilding Layer", self.log_prefix) + self._artifact_folder = builder.build().get(self._layer_identifier) + + zip_file_path = os.path.join(tempfile.gettempdir(), f"data-{uuid.uuid4().hex}") + self._zip_file = make_zip(zip_file_path, self._artifact_folder) + LOG.debug("%sCreated artifact ZIP file: %s", self.log_prefix, self._zip_file) + self._local_sha = file_checksum(cast(str, self._zip_file), hashlib.sha256()) + + def compare_remote(self) -> bool: + """ + Compare the SHA256 of the deployed layer code with the one just built; True if they are the same, False otherwise + """ + self._old_layer_version = self._get_latest_layer_version() + old_layer_info = self._lambda_client.get_layer_version( + LayerName=self._layer_physical_name, + VersionNumber=self._old_layer_version, + ) + remote_sha = base64.b64decode(old_layer_info.get("Content", {}).get("CodeSha256", "")).hex() + LOG.debug("%sLocal SHA: %s Remote SHA: %s", self.log_prefix, self._local_sha, remote_sha) + + return self._local_sha == remote_sha
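The digest round-trip in compare_remote above can be checked in isolation: Lambda reports CodeSha256 as the base64-encoded raw SHA-256 digest, while the local checksum is a hex string, so the remote value is base64-decoded and rendered as hex before comparing. A self-contained sketch:

    import base64
    import hashlib

    payload = b"example zip bytes"
    local_hex = hashlib.sha256(payload).hexdigest()                           # file_checksum-style hex digest
    remote_b64 = base64.b64encode(hashlib.sha256(payload).digest()).decode()  # CodeSha256-style value
    assert base64.b64decode(remote_b64).hex() == local_hex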
+ + def sync(self) -> None: + """ + Publish the new layer version, and delete the existing (old) one + """ + LOG.debug("%sPublishing new Layer Version", self.log_prefix) + self._new_layer_version = self._publish_new_layer_version() + self._delete_old_layer_version() + + def _publish_new_layer_version(self) -> int: + """ + Publish the new layer version and keep the new layer version ARN so that we can update related functions + """ + layer_resource = cast(Dict[str, Any], self._get_resource(self._layer_identifier)) + compatible_runtimes = layer_resource.get("Properties", {}).get("CompatibleRuntimes", []) + with open(cast(str, self._zip_file), "rb") as zip_file: + data = zip_file.read() + layer_publish_result = self._lambda_client.publish_layer_version( + LayerName=self._layer_physical_name, Content={"ZipFile": data}, CompatibleRuntimes=compatible_runtimes + ) + LOG.debug("%sPublish Layer Version Result %s", self.log_prefix, layer_publish_result) + return int(layer_publish_result.get("Version")) + + def _delete_old_layer_version(self) -> None: + """ + Delete the old layer version so that we do not hit the layer version limit + """ + LOG.debug( + "%sDeleting old Layer Version %s:%s", self.log_prefix, self._layer_physical_name, self._old_layer_version + ) + delete_layer_version_result = self._lambda_client.delete_layer_version( + LayerName=self._layer_physical_name, + VersionNumber=self._old_layer_version, + ) + LOG.debug("%sDelete Layer Version Result %s", self.log_prefix, delete_layer_version_result) + + def gather_dependencies(self) -> List[SyncFlow]: + if self._zip_file and os.path.exists(self._zip_file): + os.remove(self._zip_file) + + dependencies: List[SyncFlow] = list() + if self._stacks: + function_provider = SamFunctionProvider(self._stacks) + for function in function_provider.get_all(): + if self._layer_identifier in [layer.full_path for layer in function.layers]: + LOG.debug( + "%sAdding function %s for updating its Layers with this new version", + self.log_prefix, + function.name, + ) + dependencies.append( + FunctionLayerReferenceSync( + function.full_path, + cast(str, self._layer_physical_name), + cast(int, self._new_layer_version), + self._build_context, + self._deploy_context, + self._physical_id_mapping, + self._stacks, + ) + ) + return dependencies + + def _get_resource_api_calls(self) -> List[ResourceAPICall]: + return [ResourceAPICall(self._layer_identifier, ["Build"])] + + def _get_latest_layer_version(self): + """Fetches all layer versions from remote and returns the latest one""" + layer_versions = self._lambda_client.list_layer_versions(LayerName=self._layer_physical_name).get( + "LayerVersions", [] + ) + if not layer_versions: + raise NoLayerVersionsFoundError(self._layer_physical_name) + return layer_versions[0].get("Version") + + def _equality_keys(self) -> Any: + return self._layer_identifier + + +class FunctionLayerReferenceSync(SyncFlow): + """ + Used for updating the new Layer version for the related functions + """ + + UPDATE_FUNCTION_CONFIGURATION = "UpdateFunctionConfiguration" + + _lambda_client: Any + + _function_identifier: str + _layer_physical_name: str + _old_layer_version: int + _new_layer_version: int + + def __init__( + self, + function_identifier: str, + layer_physical_name: str, + new_layer_version: int, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + stacks: List[Stack], + ): + super().__init__( + build_context, + deploy_context, + physical_id_mapping, + log_name="Function Layer Reference Sync " +
function_identifier, + stacks=stacks, + ) + self._function_identifier = function_identifier + self._layer_physical_name = layer_physical_name + self._new_layer_version = new_layer_version + + def set_up(self) -> None: + super().set_up() + self._lambda_client = cast(Session, self._session).client("lambda") + + def sync(self) -> None: + """ + First read the current Layers property and replace the old layer version ARN with the new one, + then call update function configuration to update the function with the new layer version ARN + """ + if not self._locks: + LOG.warning("%sLocks is None", self.log_prefix) + return + lock_key = SyncFlow._get_lock_key( + self._function_identifier, FunctionLayerReferenceSync.UPDATE_FUNCTION_CONFIGURATION + ) + lock = self._locks.get(lock_key) + if not lock: + LOG.warning("%s%s lock is None", self.log_prefix, lock_key) + return + + with lock: + new_layer_arn = f"{self._layer_physical_name}:{self._new_layer_version}" + + function_physical_id = self.get_physical_id(self._function_identifier) + get_function_result = self._lambda_client.get_function(FunctionName=function_physical_id) + + # get the current layer version arns + layer_arns = [layer.get("Arn") for layer in get_function_result.get("Configuration", {}).get("Layers", [])] + + # Check whether the layer version is up to date + if new_layer_arn in layer_arns: + LOG.warning( + "%sLambda Function (%s) is already up to date with new Layer version (%d).", + self.log_prefix, + self._function_identifier, + self._new_layer_version, + ) + return + + # Check that the function uses the layer + old_layer_arn = [layer_arn for layer_arn in layer_arns if layer_arn.startswith(self._layer_physical_name)] + old_layer_arn = old_layer_arn[0] if len(old_layer_arn) == 1 else None + if not old_layer_arn: + LOG.warning( + "%sLambda Function (%s) does not have layer (%s).%s", + self.log_prefix, + self._function_identifier, + self._layer_physical_name, + HELP_TEXT_FOR_SYNC_INFRA, + ) + return + + # remove the old layer version arn and add the new one + layer_arns.remove(old_layer_arn) + layer_arns.append(new_layer_arn) + self._lambda_client.update_function_configuration(FunctionName=function_physical_id, Layers=layer_arns) + + def _get_resource_api_calls(self) -> List[ResourceAPICall]: + return [ResourceAPICall(self._function_identifier, [FunctionLayerReferenceSync.UPDATE_FUNCTION_CONFIGURATION])] + + def compare_remote(self) -> bool: + return False + + def gather_resources(self) -> None: + pass + + def gather_dependencies(self) -> List["SyncFlow"]: + return [] + + def _equality_keys(self) -> Any: + return self._function_identifier, self._layer_physical_name, self._new_layer_version diff --git a/samcli/lib/sync/flows/rest_api_sync_flow.py b/samcli/lib/sync/flows/rest_api_sync_flow.py new file mode 100644 index 0000000000..881d2ab1a9 --- /dev/null +++ b/samcli/lib/sync/flows/rest_api_sync_flow.py @@ -0,0 +1,69 @@ +"""SyncFlow for RestApi""" +import logging +from typing import Dict, List, TYPE_CHECKING, cast + +from boto3.session import Session + +from samcli.lib.sync.flows.generic_api_sync_flow import GenericApiSyncFlow +from samcli.lib.providers.provider import ResourceIdentifier, Stack +from samcli.lib.providers.exceptions import MissingLocalDefinition + +# BuildContext and DeployContext will only be imported for type checking to improve performance +# since no instances of contexts will be instantiated in this class +if TYPE_CHECKING: + from samcli.commands.build.build_context import BuildContext + from samcli.commands.deploy.deploy_context import
DeployContext + +LOG = logging.getLogger(__name__) + + +class RestApiSyncFlow(GenericApiSyncFlow): + """SyncFlow for RestApi's""" + + def __init__( + self, + api_identifier: str, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + stacks: List[Stack], + ): + """ + Parameters + ---------- + api_identifier : str + RestApi resource identifier that needs to have its associated RestApi updated. + build_context : BuildContext + BuildContext used for build related parameters + deploy_context : DeployContext + DeployContext used for deploy related parameters + physical_id_mapping : Dict[str, str] + Mapping between resource logical identifier and physical identifier + stacks : List[Stack], optional + List of stacks containing a root stack and optional nested stacks + """ + super().__init__( + api_identifier, + build_context, + deploy_context, + physical_id_mapping, + log_name="RestApi " + api_identifier, + stacks=stacks, + ) + + def set_up(self) -> None: + super().set_up() + self._api_client = cast(Session, self._session).client("apigateway") + + def sync(self) -> None: + api_physical_id = self.get_physical_id(self._api_identifier) + if self._definition_uri is None: + LOG.error( + "%sImport RestApi fails since no DefinitionUri is defined in the template; \ +if you are using DefinitionBody, please run sam sync --infra", + self.log_prefix, + ) + raise MissingLocalDefinition(ResourceIdentifier(self._api_identifier), "DefinitionUri") + LOG.debug("%sTrying to put RestApi through client", self.log_prefix) + response = self._api_client.put_rest_api(restApiId=api_physical_id, mode="overwrite", body=self._swagger_body) + LOG.debug("%sPut RestApi Result: %s", self.log_prefix, response) diff --git a/samcli/lib/sync/flows/stepfunctions_sync_flow.py b/samcli/lib/sync/flows/stepfunctions_sync_flow.py new file mode 100644 index 0000000000..1728dff09b --- /dev/null +++ b/samcli/lib/sync/flows/stepfunctions_sync_flow.py @@ -0,0 +1,109 @@ +"""Base SyncFlow for StepFunctions""" +import logging +from typing import Any, Dict, List, TYPE_CHECKING, cast, Optional + + +from boto3.session import Session + +from samcli.lib.providers.provider import Stack, get_resource_by_id, ResourceIdentifier +from samcli.lib.sync.sync_flow import SyncFlow, ResourceAPICall +from samcli.lib.providers.exceptions import MissingLocalDefinition + +if TYPE_CHECKING: + from samcli.commands.deploy.deploy_context import DeployContext + from samcli.commands.build.build_context import BuildContext + +LOG = logging.getLogger(__name__) + + +class StepFunctionsSyncFlow(SyncFlow): + _state_machine_identifier: str + _stepfunctions_client: Any + _definition_uri: Optional[str] + _stacks: List[Stack] + _states_definition: Optional[str] + + def __init__( + self, + state_machine_identifier: str, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + stacks: List[Stack], + ): + """ + Parameters + ---------- + state_machine_identifier : str + State Machine resource identifier that needs to be synced.
+ build_context : BuildContext + BuildContext used for build related parameters + deploy_context : DeployContext + DeployContext used for deploy related parameters + physical_id_mapping : Dict[str, str] + Mapping between resource logical identifier and physical identifier + stacks : List[Stack], optional + List of stacks containing a root stack and optional nested stacks + """ + super().__init__( + build_context, + deploy_context, + physical_id_mapping, + log_name="StepFunctions " + state_machine_identifier, + stacks=stacks, + ) + self._state_machine_identifier = state_machine_identifier + self._stepfunctions_client = None + + def set_up(self) -> None: + super().set_up() + self._stepfunctions_client = cast(Session, self._session).client("stepfunctions") + + def gather_resources(self) -> None: + self._definition_uri = self._get_definition_file(self._state_machine_identifier) + self._states_definition = self._process_definition_file() + + def _process_definition_file(self) -> Optional[str]: + if self._definition_uri is None: + return None + with open(self._definition_uri, "r", encoding="utf-8") as states_file: + states_data = states_file.read() + return states_data + + def _get_definition_file(self, state_machine_identifier: str) -> Optional[str]: + state_machine_resource = get_resource_by_id(self._stacks, ResourceIdentifier(state_machine_identifier)) + if state_machine_resource is None: + return None + properties = state_machine_resource.get("Properties", {}) + definition_file = properties.get("DefinitionUri") + return cast(Optional[str], definition_file) + + def compare_remote(self) -> bool: + # Not comparing with remote right now, instead only making update api calls + # Note: describe state machine has a better rate limit than update state machine, + # so if we face any throttling issues, comparing first would be desirable + return False + + def gather_dependencies(self) -> List[SyncFlow]: + return [] + + def _get_resource_api_calls(self) -> List[ResourceAPICall]: + return [] + + def _equality_keys(self): + return self._state_machine_identifier + + def sync(self) -> None: + state_machine_arn = self.get_physical_id(self._state_machine_identifier) + if self._definition_uri is None: + LOG.error( + "%sUpdate State Machine fails since no DefinitionUri is defined in the template; \ +if you are using an inline Definition, please run sam sync --infra", + self.log_prefix, + ) + raise MissingLocalDefinition(ResourceIdentifier(self._state_machine_identifier), "DefinitionUri") + LOG.debug("%sTrying to update State Machine definition", self.log_prefix) + response = self._stepfunctions_client.update_state_machine( + stateMachineArn=state_machine_arn, definition=self._states_definition + ) + LOG.debug("%sUpdate State Machine: %s", self.log_prefix, response) diff --git a/samcli/lib/sync/flows/zip_function_sync_flow.py b/samcli/lib/sync/flows/zip_function_sync_flow.py new file mode 100644 index 0000000000..0c6656ba49 --- /dev/null +++ b/samcli/lib/sync/flows/zip_function_sync_flow.py @@ -0,0 +1,145 @@ +"""SyncFlow for ZIP based Lambda Functions""" +import hashlib +import logging +import os +import base64 +import tempfile +import uuid + +from contextlib import ExitStack +from typing import Any, Dict, List, Optional, TYPE_CHECKING, cast + +from boto3.session import Session + +from samcli.lib.providers.provider import Stack + +from samcli.lib.sync.flows.function_sync_flow import FunctionSyncFlow +from samcli.lib.package.s3_uploader import S3Uploader +from samcli.lib.utils.hash import file_checksum +from
samcli.lib.package.utils import make_zip + +from samcli.lib.build.app_builder import ApplicationBuilder +from samcli.lib.sync.sync_flow import ResourceAPICall + +if TYPE_CHECKING: + from samcli.commands.deploy.deploy_context import DeployContext + from samcli.commands.build.build_context import BuildContext + +LOG = logging.getLogger(__name__) +MAXIMUM_FUNCTION_ZIP_SIZE = 50 * 1024 * 1024 # 50MB limit for Lambda direct ZIP upload + + +class ZipFunctionSyncFlow(FunctionSyncFlow): + """SyncFlow for ZIP based functions""" + + _s3_client: Any + _artifact_folder: Optional[str] + _zip_file: Optional[str] + _local_sha: Optional[str] + + def __init__( + self, + function_identifier: str, + build_context: "BuildContext", + deploy_context: "DeployContext", + physical_id_mapping: Dict[str, str], + stacks: List[Stack], + ): + + """ + Parameters + ---------- + function_identifier : str + ZIP function resource identifier that needs to be synced. + build_context : BuildContext + BuildContext + deploy_context : DeployContext + DeployContext + physical_id_mapping : Dict[str, str] + Physical ID Mapping + stacks : Optional[List[Stack]] + Stacks + """ + super().__init__(function_identifier, build_context, deploy_context, physical_id_mapping, stacks) + self._s3_client = None + self._artifact_folder = None + self._zip_file = None + self._local_sha = None + + def set_up(self) -> None: + super().set_up() + self._s3_client = cast(Session, self._session).client("s3") + + def gather_resources(self) -> None: + """Build function and ZIP it into a temp file in self._zip_file""" + with ExitStack() as exit_stack: + if self._function.layers: + exit_stack.enter_context(self._get_lock_chain()) + + builder = ApplicationBuilder( + self._build_context.collect_build_resources(self._function_identifier), + self._build_context.build_dir, + self._build_context.base_dir, + self._build_context.cache_dir, + cached=True, + is_building_specific_resource=True, + manifest_path_override=self._build_context.manifest_path_override, + container_manager=self._build_context.container_manager, + mode=self._build_context.mode, + ) + LOG.debug("%sBuilding Function", self.log_prefix) + self._artifact_folder = builder.build().get(self._function_identifier) + + zip_file_path = os.path.join(tempfile.gettempdir(), "data-" + uuid.uuid4().hex) + self._zip_file = make_zip(zip_file_path, self._artifact_folder) + LOG.debug("%sCreated artifact ZIP file: %s", self.log_prefix, self._zip_file) + self._local_sha = file_checksum(cast(str, self._zip_file), hashlib.sha256()) + + def compare_remote(self) -> bool: + remote_info = self._lambda_client.get_function(FunctionName=self.get_physical_id(self._function_identifier)) + remote_sha = base64.b64decode(remote_info["Configuration"]["CodeSha256"]).hex() + LOG.debug("%sLocal SHA: %s Remote SHA: %s", self.log_prefix, self._local_sha, remote_sha) + + return self._local_sha == remote_sha + + def sync(self) -> None: + if not self._zip_file: + LOG.debug("%sSkipping Sync.
+
+        zip_file_size = os.path.getsize(self._zip_file)
+        if zip_file_size < MAXIMUM_FUNCTION_ZIP_SIZE:
+            # Direct upload through Lambda API
+            LOG.debug("%sUploading Function Directly", self.log_prefix)
+            with open(self._zip_file, "rb") as zip_file:
+                data = zip_file.read()
+                self._lambda_client.update_function_code(
+                    FunctionName=self.get_physical_id(self._function_identifier), ZipFile=data
+                )
+        else:
+            # Upload to S3 first for oversized ZIPs
+            LOG.debug("%sUploading Function Through S3", self.log_prefix)
+            uploader = S3Uploader(
+                s3_client=self._s3_client,
+                bucket_name=self._deploy_context.s3_bucket,
+                prefix=self._deploy_context.s3_prefix,
+                kms_key_id=self._deploy_context.kms_key_id,
+                force_upload=True,
+                no_progressbar=True,
+            )
+            s3_url = uploader.upload_with_dedup(self._zip_file)
+            # s3_url has the form "s3://<bucket>/<key>"; strip the scheme and bucket to get the key
+            s3_key = s3_url[5:].split("/", 1)[1]
+            self._lambda_client.update_function_code(
+                FunctionName=self.get_physical_id(self._function_identifier),
+                S3Bucket=self._deploy_context.s3_bucket,
+                S3Key=s3_key,
+            )
+
+        if os.path.exists(self._zip_file):
+            os.remove(self._zip_file)
+
+    def _get_resource_api_calls(self) -> List[ResourceAPICall]:
+        resource_calls = list()
+        for layer in self._function.layers:
+            resource_calls.append(ResourceAPICall(layer.full_path, ["Build"]))
+        return resource_calls
diff --git a/samcli/lib/sync/sync_flow.py b/samcli/lib/sync/sync_flow.py
new file mode 100644
index 0000000000..3d661f72da
--- /dev/null
+++ b/samcli/lib/sync/sync_flow.py
@@ -0,0 +1,295 @@
+"""SyncFlow base class """
+import logging
+
+from abc import ABC, abstractmethod
+from threading import Lock
+from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING, cast
+from boto3.session import Session
+
+from samcli.lib.providers.provider import get_resource_by_id
+
+from samcli.lib.providers.provider import ResourceIdentifier, Stack
+from samcli.lib.utils.lock_distributor import LockDistributor, LockChain
+from samcli.lib.sync.exceptions import MissingLockException, MissingPhysicalResourceError
+
+if TYPE_CHECKING:
+    from samcli.commands.deploy.deploy_context import DeployContext
+    from samcli.commands.build.build_context import BuildContext
+
+# Logging with multiple processes is not safe. Use a log queue in the future.
+# https://docs.python.org/3/howto/logging-cookbook.html#:~:text=Although%20logging%20is%20thread%2Dsafe,across%20multiple%20processes%20in%20Python.
+LOG = logging.getLogger(__name__)
+
+
+class ResourceAPICall(NamedTuple):
+    """Named tuple for a resource and its potential API calls"""
+
+    resource_identifier: str
+    api_calls: List[str]
+
+
+class SyncFlow(ABC):
+    """Base class for a SyncFlow"""
+
+    _log_name: str
+    _build_context: "BuildContext"
+    _deploy_context: "DeployContext"
+    _stacks: Optional[List[Stack]]
+    _session: Optional[Session]
+    _physical_id_mapping: Dict[str, str]
+    _locks: Optional[Dict[str, Lock]]
+
+    def __init__(
+        self,
+        build_context: "BuildContext",
+        deploy_context: "DeployContext",
+        physical_id_mapping: Dict[str, str],
+        log_name: str,
+        stacks: Optional[List[Stack]] = None,
+    ):
+        """
+        Parameters
+        ----------
+        build_context : BuildContext
+            BuildContext used for build related parameters
+        deploy_context : DeployContext
+            DeployContext used for deploy related parameters
+        physical_id_mapping : Dict[str, str]
+            Mapping between resource logical identifier and physical identifier
+        log_name : str
+            Name to be used for logging purposes
+        stacks : List[Stack], optional
+            List of stacks containing a root stack and optional nested stacks
+        """
+        self._build_context = build_context
+        self._deploy_context = deploy_context
+        self._log_name = log_name
+        self._stacks = stacks
+        self._session = None
+        self._physical_id_mapping = physical_id_mapping
+        self._locks = None
+
+    def set_up(self) -> None:
+        """Clients and other expensive setup should be handled here instead of the constructor"""
+        self._session = Session(profile_name=self._deploy_context.profile, region_name=self._deploy_context.region)
+
+    @abstractmethod
+    def gather_resources(self) -> None:
+        """Local operations that need to be done before comparison and syncing with remote
+        Ex: Building lambda functions
+        """
+        raise NotImplementedError("gather_resources")
+
+    @abstractmethod
+    def compare_remote(self) -> bool:
+        """Comparison between local and remote resources.
+        This can be used for optimization if comparison is a lot faster than sync.
+        If the resources are identical, sync and gather dependencies will be skipped.
+        Simply return False if no comparison is needed.
+        Ex: Comparing local Lambda function artifact with remote SHA256
+
+        Returns
+        -------
+        bool
+            Return True if local and remote are in sync, skipping the rest of the execution.
+            Return False otherwise.
+        """
+        raise NotImplementedError("compare_remote")
+
+    @abstractmethod
+    def sync(self) -> None:
+        """Step that syncs local resources with remote.
+        Ex: Call UpdateFunctionCode for Lambda Functions
+        """
+        raise NotImplementedError("sync")
+
+    @abstractmethod
+    def gather_dependencies(self) -> List["SyncFlow"]:
+        """Gather a list of SyncFlows that should be executed after the current change.
+        These can be sync flows for other resources that depend on the current one.
+        Ex: Update Lambda functions if a layer sync flow creates a new version.
+
+        Returns
+        -------
+        List[SyncFlow]
+            List of sync flows that need to be executed after the current one finishes.
+        """
+        raise NotImplementedError("gather_dependencies")
+
+    @abstractmethod
+    def _get_resource_api_calls(self) -> List[ResourceAPICall]:
+        """Get resources and their associated API calls. This is used for locking purposes.
+
+        Returns
+        -------
+        List[ResourceAPICall]
+            List of ResourceAPICalls, each holding a resource logical ID and the
+            API calls that the resource can make
+        """
+        raise NotImplementedError("_get_resource_api_calls")
+
+    def get_lock_keys(self) -> List[str]:
+        """Get a list of resource + API call combinations that can be used as keys for LockDistributor
+
+        Returns
+        -------
+        List[str]
+            List of keys for all resources and their API calls
+        """
+        lock_keys = list()
+        for resource_api_calls in self._get_resource_api_calls():
+            for api_call in resource_api_calls.api_calls:
+                lock_keys.append(SyncFlow._get_lock_key(resource_api_calls.resource_identifier, api_call))
+        return lock_keys
+
+    def set_locks_with_distributor(self, distributor: LockDistributor):
+        """Set locks to be used with a LockDistributor. Keys should be generated using get_lock_keys().
+
+        Parameters
+        ----------
+        distributor : LockDistributor
+            Lock distributor
+        """
+        self.set_locks_with_dict(distributor.get_locks(self.get_lock_keys()))
+
+    def set_locks_with_dict(self, locks: Dict[str, Lock]):
+        """Set locks to be used. Keys should be generated using get_lock_keys().
+
+        Parameters
+        ----------
+        locks : Dict[str, Lock]
+            Dict of locks with keys from get_lock_keys()
+        """
+        self._locks = locks
+
+    @staticmethod
+    def _get_lock_key(logical_id: str, api_call: str) -> str:
+        """Get a single lock key for a pair of resource and API call.
+
+        Parameters
+        ----------
+        logical_id : str
+            Logical ID of a resource.
+        api_call : str
+            API call the resource will use.
+
+        Returns
+        -------
+        str
+            String key created with logical ID and API call name.
+        """
+        return logical_id + "_" + api_call
+
+    def _get_lock_chain(self) -> LockChain:
+        """Return a LockChain object for all the locks
+
+        Returns
+        -------
+        LockChain
+            A LockChain object containing all locks.
+
+        Raises
+        ------
+        MissingLockException
+            Raised if no locks have been set with set_locks_with_distributor() or set_locks_with_dict().
+        """
+        if self._locks:
+            return LockChain(self._locks)
+        raise MissingLockException("Missing Locks for LockChain")
+
+    def _get_resource(self, resource_identifier: str) -> Optional[Dict[str, Any]]:
+        """Get a resource dict with resource identifier
+
+        Parameters
+        ----------
+        resource_identifier : str
+            Resource identifier
+
+        Returns
+        -------
+        Optional[Dict[str, Any]]
+            Resource dict containing its template fields.
+        """
+        return get_resource_by_id(self._stacks, ResourceIdentifier(resource_identifier)) if self._stacks else None
+
+    def get_physical_id(self, resource_identifier: str) -> str:
+        """Get the physical ID of a resource using physical_id_mapping. This does not directly check with remote.
+
+        Parameters
+        ----------
+        resource_identifier : str
+            Resource identifier
+
+        Returns
+        -------
+        str
+            Resource physical ID
+
+        Raises
+        ------
+        MissingPhysicalResourceError
+            Resource does not exist in the physical ID mapping.
+            This could mean remote and local templates are not in sync.
+        """
+        physical_id = self._physical_id_mapping.get(resource_identifier)
+        if not physical_id:
+            raise MissingPhysicalResourceError(resource_identifier)
+
+        return physical_id
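+
+    # Illustrative note (editor's sketch, not part of the original change): subclasses
+    # return a stable, hashable identifier from _equality_keys() below; for example, the
+    # StepFunctions flow above returns its state machine identifier, so duplicate flows
+    # for the same resource compare equal and can be de-duplicated in the executor queue.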
+
+    @abstractmethod
+    def _equality_keys(self) -> Any:
+        """This method needs to be overridden to distinguish between multiple instances of SyncFlows.
+        If the return values of two instances are the same, then those two instances are assumed to be equal.
+
+        Returns
+        -------
+        Any
+            Anything that can be hashed and compared with "=="
+        """
+        raise NotImplementedError("_equality_keys is not implemented.")
+
+    def __hash__(self) -> int:
+        return hash((type(self), self._equality_keys()))
+
+    def __eq__(self, o: object) -> bool:
+        if type(o) is not type(self):
+            return False
+        return cast(bool, self._equality_keys() == cast(SyncFlow, o)._equality_keys())
+
+    @property
+    def log_name(self) -> str:
+        """
+        Returns
+        -------
+        str
+            Human readable name/identifier for logging purposes
+        """
+        return self._log_name
+
+    @property
+    def log_prefix(self) -> str:
+        """
+        Returns
+        -------
+        str
+            Log prefix to be used for logging.
+        """
+        return f"SyncFlow [{self.log_name}]: "
+
+    def execute(self) -> List["SyncFlow"]:
+        """Execute the sync flow and return a list of dependent sync flows.
+        Skips sync() and gather_dependencies() if compare_remote() returns True.
+
+        Returns
+        -------
+        List[SyncFlow]
+            A list of dependent sync flows
+        """
+        dependencies: List["SyncFlow"] = list()
+        LOG.debug("%sSetting Up", self.log_prefix)
+        self.set_up()
+        LOG.debug("%sGathering Resources", self.log_prefix)
+        self.gather_resources()
+        LOG.debug("%sComparing with Remote", self.log_prefix)
+        if not self.compare_remote():
+            LOG.debug("%sSyncing", self.log_prefix)
+            self.sync()
+            LOG.debug("%sGathering Dependencies", self.log_prefix)
+            dependencies = self.gather_dependencies()
+        LOG.debug("%sFinished", self.log_prefix)
+        return dependencies
diff --git a/samcli/lib/sync/sync_flow_executor.py b/samcli/lib/sync/sync_flow_executor.py
new file mode 100644
index 0000000000..dac68ddf2f
--- /dev/null
+++ b/samcli/lib/sync/sync_flow_executor.py
@@ -0,0 +1,318 @@
+"""Executor for SyncFlows"""
+import logging
+import time
+
+from queue import Queue
+from typing import Callable, List, Optional, Set
+from dataclasses import dataclass
+
+from threading import RLock
+from concurrent.futures import ThreadPoolExecutor, Future
+
+from botocore.exceptions import ClientError
+
+from samcli.lib.utils.colors import Colored
+from samcli.lib.sync.exceptions import (
+    MissingPhysicalResourceError,
+    NoLayerVersionsFoundError,
+    SyncFlowException,
+)
+
+from samcli.lib.utils.lock_distributor import LockDistributor, LockDistributorType
+from samcli.lib.sync.sync_flow import SyncFlow
+
+LOG = logging.getLogger(__name__)
+
+HELP_TEXT_FOR_SYNC_INFRA = " Try sam sync --infra or sam deploy."
+
+
+@dataclass(frozen=True, eq=True)
+class SyncFlowTask:
+    """Data struct for individual SyncFlow execution tasks"""
+
+    # SyncFlow to be executed
+    sync_flow: SyncFlow
+
+    # Should this task be ignored if there is an identical sync flow already in the queue
+    dedup: bool
+
+
+@dataclass(frozen=True, eq=True)
+class SyncFlowResult:
+    """Data struct for SyncFlow results"""
+
+    sync_flow: SyncFlow
+    dependent_sync_flows: List[SyncFlow]
+
+
+@dataclass(frozen=True, eq=True)
+class SyncFlowFuture:
+    """Data struct for SyncFlow futures"""
+
+    sync_flow: SyncFlow
+    future: Future
+
+
+def default_exception_handler(sync_flow_exception: SyncFlowException) -> None:
+    """Default exception handler for SyncFlowExecutor.
+    This will try to log and parse common SyncFlow exceptions.
+
+    Parameters
+    ----------
+    sync_flow_exception : SyncFlowException
+        SyncFlowException containing the exception to be handled and the SyncFlow that raised it
+
+    Raises
+    ------
+    exception
+        Unhandled exception
+    """
+    exception = sync_flow_exception.exception
+    if isinstance(exception, MissingPhysicalResourceError):
+        LOG.error("Cannot find resource %s in remote.%s", exception.resource_identifier, HELP_TEXT_FOR_SYNC_INFRA)
+    elif (
+        isinstance(exception, ClientError)
+        and exception.response.get("Error", dict()).get("Code", "") == "ResourceNotFoundException"
+    ):
+        LOG.error("Cannot find resource in remote.%s", HELP_TEXT_FOR_SYNC_INFRA)
+        LOG.error(exception.response.get("Error", dict()).get("Message", ""))
+    elif isinstance(exception, NoLayerVersionsFoundError):
+        LOG.error("Cannot find any versions for layer %s.%s", exception.layer_name_arn, HELP_TEXT_FOR_SYNC_INFRA)
+    else:
+        raise exception
+
+
+class SyncFlowExecutor:
+    """Executor for SyncFlows
+    Can be used with ThreadPoolExecutor or ProcessPoolExecutor with/without a manager
+    """
+
+    _flow_queue: Queue
+    _flow_queue_lock: RLock
+    _lock_distributor: LockDistributor
+    _running_flag: bool
+    _color: Colored
+    _running_futures: Set[SyncFlowFuture]
+
+    def __init__(
+        self,
+    ) -> None:
+        self._flow_queue = Queue()
+        self._lock_distributor = LockDistributor(LockDistributorType.THREAD)
+        self._running_flag = False
+        self._flow_queue_lock = RLock()
+        self._color = Colored()
+        self._running_futures = set()
+
+    def _add_sync_flow_task(self, task: SyncFlowTask) -> None:
+        """Add SyncFlowTask to the queue
+
+        Parameters
+        ----------
+        task : SyncFlowTask
+            SyncFlowTask to be added.
+        """
+        # Lock flow_queue since the dedup check and the add are not atomic together
+        with self._flow_queue_lock:
+            if task.dedup and task.sync_flow in [task.sync_flow for task in self._flow_queue.queue]:
+                LOG.debug("Found the same SyncFlow in queue. Skip adding.")
+                return
+
+            task.sync_flow.set_locks_with_distributor(self._lock_distributor)
+            self._flow_queue.put(task)
+
+    def add_sync_flow(self, sync_flow: SyncFlow, dedup: bool = True) -> None:
+        """Add a SyncFlow to the queue to be executed.
+        Locks will be set with the LockDistributor.
+
+        Parameters
+        ----------
+        sync_flow : SyncFlow
+            SyncFlow to be executed
+        dedup : bool
+            SyncFlow will not be added if this flag is True and it has a duplicate in the queue
+        """
+        self._add_sync_flow_task(SyncFlowTask(sync_flow, dedup))
+
+    def is_running(self) -> bool:
+        """
+        Returns
+        -------
+        bool
+            Is executor running
+        """
+        return self._running_flag
+
+    def _can_exit(self) -> bool:
+        """
+        Returns
+        -------
+        bool
+            Can executor be safely exited
+        """
+        return not self._running_futures and self._flow_queue.empty()
+
+    def execute(
+        self, exception_handler: Optional[Callable[[SyncFlowException], None]] = default_exception_handler
+    ) -> None:
+        """Blocking execution of the SyncFlows
+
+        Parameters
+        ----------
+        exception_handler : Optional[Callable[[SyncFlowException], None]], optional
+            Function to be called if an exception is raised during the execution of a SyncFlow,
+            by default default_exception_handler
+        """
+        self._running_flag = True
+        with ThreadPoolExecutor() as executor:
+            self._running_futures.clear()
+            while True:
+
+                self._execute_step(executor, exception_handler)
+
+                # Exit execution if there are no running and pending sync flows
+                if self._can_exit():
+                    LOG.debug("No more SyncFlows in executor. Stopping.")
+                    break
+
+                # Sleep for a bit to cut down CPU utilization of this busy wait loop
+                time.sleep(0.1)
+        self._running_flag = False
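+
+    # Illustrative usage (editor's sketch, not part of the original change):
+    #
+    #     executor = SyncFlowExecutor()
+    #     executor.add_sync_flow(sync_flow)
+    #     executor.execute()  # blocks until the queue and all running futures drain
+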
+    def _execute_step(
+        self,
+        executor: ThreadPoolExecutor,
+        exception_handler: Optional[Callable[[SyncFlowException], None]],
+    ) -> None:
+        """A single step in the execution flow
+
+        Parameters
+        ----------
+        executor : ThreadPoolExecutor
+            ThreadPoolExecutor to be used for execution
+        exception_handler : Optional[Callable[[SyncFlowException], None]]
+            Exception handler
+        """
+        # Execute all pending sync flows
+        with self._flow_queue_lock:
+            # Put not-yet-submitted tasks into this deferred tasks list
+            # to avoid modifying the queue while emptying it
+            deferred_tasks = list()
+
+            # Go through all queued tasks and try to execute them
+            while not self._flow_queue.empty():
+                sync_flow_task: SyncFlowTask = self._flow_queue.get()
+
+                sync_flow_future = self._submit_sync_flow_task(executor, sync_flow_task)
+
+                # sync_flow_future can be None if the task cannot be submitted currently.
+                # Put it into deferred_tasks and add all of them back at the end to avoid an endless loop.
+                if sync_flow_future:
+                    self._running_futures.add(sync_flow_future)
+                    LOG.info(self._color.cyan(f"Syncing {sync_flow_future.sync_flow.log_name}..."))
+                else:
+                    deferred_tasks.append(sync_flow_task)
+
+            # Add tasks that cannot be executed yet
+            for task in deferred_tasks:
+                self._add_sync_flow_task(task)
+
+        # Check for finished sync flows
+        for sync_flow_future in set(self._running_futures):
+            if self._handle_result(sync_flow_future, exception_handler):
+                self._running_futures.remove(sync_flow_future)
+
+    def _submit_sync_flow_task(
+        self, executor: ThreadPoolExecutor, sync_flow_task: SyncFlowTask
+    ) -> Optional[SyncFlowFuture]:
+        """Submit SyncFlowTask to be executed by the ThreadPoolExecutor
+        and return its future
+
+        Parameters
+        ----------
+        executor : ThreadPoolExecutor
+            ThreadPoolExecutor to be used for execution
+        sync_flow_task : SyncFlowTask
+            SyncFlowTask to be executed.
+
+        Returns
+        -------
+        Optional[SyncFlowFuture]
+            Returns the SyncFlowFuture generated by the SyncFlowTask.
+            Can be None if the task cannot be executed yet.
+        """
+        sync_flow = sync_flow_task.sync_flow
+
+        # Check whether the same sync flow is already running or not
+        if sync_flow in [future.sync_flow for future in self._running_futures]:
+            return None
+
+        sync_flow_future = SyncFlowFuture(
+            sync_flow=sync_flow, future=executor.submit(SyncFlowExecutor._sync_flow_execute_wrapper, sync_flow)
+        )
+
+        return sync_flow_future
+
+    def _handle_result(
+        self, sync_flow_future: SyncFlowFuture, exception_handler: Optional[Callable[[SyncFlowException], None]]
+    ) -> bool:
+        """Checks and handles the result of a SyncFlowFuture
+
+        Parameters
+        ----------
+        sync_flow_future : SyncFlowFuture
+            The SyncFlowFuture that needs to be handled
+        exception_handler : Optional[Callable[[SyncFlowException], None]]
+            Exception handler that will be called if an exception is raised within the SyncFlow
+
+        Returns
+        -------
+        bool
+            Returns True if the SyncFlowFuture was finished and successfully handled, False otherwise.
+ """ + future = sync_flow_future.future + + if not future.done(): + return False + + exception = future.exception() + + if exception and isinstance(exception, SyncFlowException) and exception_handler: + # Exception handling + exception_handler(exception) + else: + # Add dependency sync flows to queue + sync_flow_result: SyncFlowResult = future.result() + for dependent_sync_flow in sync_flow_result.dependent_sync_flows: + self.add_sync_flow(dependent_sync_flow) + LOG.info(self._color.green(f"Finished syncing {sync_flow_result.sync_flow.log_name}.")) + return True + + @staticmethod + def _sync_flow_execute_wrapper(sync_flow: SyncFlow) -> SyncFlowResult: + """Simple wrapper method for executing SyncFlow and converting all Exceptions into SyncFlowException + + Parameters + ---------- + sync_flow : SyncFlow + SyncFlow to be executed + + Returns + ------- + SyncFlowResult + SyncFlowResult for the SyncFlow executed + + Raises + ------ + SyncFlowException + """ + dependent_sync_flows = [] + try: + dependent_sync_flows = sync_flow.execute() + except ClientError as e: + if e.response.get("Error", dict()).get("Code", "") == "ResourceNotFoundException": + raise SyncFlowException(sync_flow, MissingPhysicalResourceError()) from e + raise SyncFlowException(sync_flow, e) from e + except Exception as e: + raise SyncFlowException(sync_flow, e) from e + return SyncFlowResult(sync_flow=sync_flow, dependent_sync_flows=dependent_sync_flows) diff --git a/samcli/lib/sync/sync_flow_factory.py b/samcli/lib/sync/sync_flow_factory.py new file mode 100644 index 0000000000..e7ee96c8f8 --- /dev/null +++ b/samcli/lib/sync/sync_flow_factory.py @@ -0,0 +1,166 @@ +"""SyncFlow Factory for creating SyncFlows based on resource types""" +import logging +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, cast + +from samcli.lib.providers.provider import Stack, get_resource_by_id, ResourceIdentifier +from samcli.lib.sync.flows.layer_sync_flow import LayerSyncFlow +from samcli.lib.utils.packagetype import ZIP, IMAGE +from samcli.lib.utils.resource_type_based_factory import ResourceTypeBasedFactory + +from samcli.lib.sync.sync_flow import SyncFlow +from samcli.lib.sync.flows.function_sync_flow import FunctionSyncFlow +from samcli.lib.sync.flows.zip_function_sync_flow import ZipFunctionSyncFlow +from samcli.lib.sync.flows.image_function_sync_flow import ImageFunctionSyncFlow +from samcli.lib.sync.flows.rest_api_sync_flow import RestApiSyncFlow +from samcli.lib.sync.flows.http_api_sync_flow import HttpApiSyncFlow +from samcli.lib.sync.flows.stepfunctions_sync_flow import StepFunctionsSyncFlow +from samcli.lib.utils.boto_utils import get_boto_resource_provider_with_config +from samcli.lib.utils.cloudformation import get_physical_id_mapping +from samcli.lib.utils.resources import ( + AWS_SERVERLESS_FUNCTION, + AWS_LAMBDA_FUNCTION, + AWS_SERVERLESS_LAYERVERSION, + AWS_LAMBDA_LAYERVERSION, + AWS_SERVERLESS_API, + AWS_APIGATEWAY_RESTAPI, + AWS_SERVERLESS_HTTPAPI, + AWS_APIGATEWAY_V2_API, + AWS_SERVERLESS_STATEMACHINE, + AWS_STEPFUNCTIONS_STATEMACHINE, +) + +if TYPE_CHECKING: + from samcli.commands.deploy.deploy_context import DeployContext + from samcli.commands.build.build_context import BuildContext + +LOG = logging.getLogger(__name__) + + +class SyncFlowFactory(ResourceTypeBasedFactory[SyncFlow]): # pylint: disable=E1136 + """Factory class for SyncFlow + Creates appropriate SyncFlow types based on stack resource types + """ + + _deploy_context: "DeployContext" + _build_context: "BuildContext" + 
_physical_id_mapping: Dict[str, str] + + def __init__(self, build_context: "BuildContext", deploy_context: "DeployContext", stacks: List[Stack]) -> None: + """ + Parameters + ---------- + build_context : BuildContext + BuildContext to be passed into each individual SyncFlow + deploy_context : DeployContext + DeployContext to be passed into each individual SyncFlow + stacks : List[Stack] + List of stacks containing a root stack and optional nested ones + """ + super().__init__(stacks) + self._deploy_context = deploy_context + self._build_context = build_context + self._physical_id_mapping = dict() + + def load_physical_id_mapping(self) -> None: + """Load physical IDs of the stack resources from remote""" + LOG.debug("Loading physical ID mapping") + self._physical_id_mapping = get_physical_id_mapping( + get_boto_resource_provider_with_config( + region_name=self._deploy_context.region if self._deploy_context.region else None + ), + self._deploy_context.stack_name, + ) + + def _create_lambda_flow( + self, resource_identifier: ResourceIdentifier, resource: Dict[str, Any] + ) -> Optional[FunctionSyncFlow]: + package_type = resource.get("Properties", dict()).get("PackageType", ZIP) + if package_type == ZIP: + return ZipFunctionSyncFlow( + str(resource_identifier), + self._build_context, + self._deploy_context, + self._physical_id_mapping, + self._stacks, + ) + if package_type == IMAGE: + return ImageFunctionSyncFlow( + str(resource_identifier), + self._build_context, + self._deploy_context, + self._physical_id_mapping, + self._stacks, + ) + return None + + def _create_layer_flow(self, resource_identifier: ResourceIdentifier, resource: Dict[str, Any]) -> SyncFlow: + return LayerSyncFlow( + str(resource_identifier), + self._build_context, + self._deploy_context, + self._physical_id_mapping, + self._stacks, + ) + + def _create_rest_api_flow(self, resource_identifier: ResourceIdentifier, resource: Dict[str, Any]) -> SyncFlow: + return RestApiSyncFlow( + str(resource_identifier), + self._build_context, + self._deploy_context, + self._physical_id_mapping, + self._stacks, + ) + + def _create_api_flow(self, resource_identifier: ResourceIdentifier, resource: Dict[str, Any]) -> SyncFlow: + return HttpApiSyncFlow( + str(resource_identifier), + self._build_context, + self._deploy_context, + self._physical_id_mapping, + self._stacks, + ) + + def _create_stepfunctions_flow( + self, resource_identifier: ResourceIdentifier, resource: Dict[str, Any] + ) -> Optional[SyncFlow]: + definition_substitutions = resource.get("Properties", dict()).get("DefinitionSubstitutions", None) + if definition_substitutions: + LOG.warning( + "DefinitionSubstitutions property is specified in resource %s. Skipping this resource. 
" + "Code sync for StepFunctions does not go through CFN, please run sam sync --infra to update.", + resource_identifier, + ) + return None + return StepFunctionsSyncFlow( + str(resource_identifier), + self._build_context, + self._deploy_context, + self._physical_id_mapping, + self._stacks, + ) + + GeneratorFunction = Callable[["SyncFlowFactory", ResourceIdentifier, Dict[str, Any]], Optional[SyncFlow]] + GENERATOR_MAPPING: Dict[str, GeneratorFunction] = { + AWS_LAMBDA_FUNCTION: _create_lambda_flow, + AWS_SERVERLESS_FUNCTION: _create_lambda_flow, + AWS_SERVERLESS_LAYERVERSION: _create_layer_flow, + AWS_LAMBDA_LAYERVERSION: _create_layer_flow, + AWS_SERVERLESS_API: _create_rest_api_flow, + AWS_APIGATEWAY_RESTAPI: _create_rest_api_flow, + AWS_SERVERLESS_HTTPAPI: _create_api_flow, + AWS_APIGATEWAY_V2_API: _create_api_flow, + AWS_SERVERLESS_STATEMACHINE: _create_stepfunctions_flow, + AWS_STEPFUNCTIONS_STATEMACHINE: _create_stepfunctions_flow, + } + + # SyncFlow mapping between resource type and creation function + # Ignoring no-self-use as PyLint has a bug with Generic Abstract Classes + def _get_generator_mapping(self) -> Dict[str, GeneratorFunction]: # pylint: disable=no-self-use + return SyncFlowFactory.GENERATOR_MAPPING + + def create_sync_flow(self, resource_identifier: ResourceIdentifier) -> Optional[SyncFlow]: + resource = get_resource_by_id(self._stacks, resource_identifier) + generator = self._get_generator_function(resource_identifier) + if not generator or not resource: + return None + return cast(SyncFlowFactory.GeneratorFunction, generator)(self, resource_identifier, resource) diff --git a/samcli/lib/sync/watch_manager.py b/samcli/lib/sync/watch_manager.py new file mode 100644 index 0000000000..bfb2772fa2 --- /dev/null +++ b/samcli/lib/sync/watch_manager.py @@ -0,0 +1,228 @@ +""" +WatchManager for Sync Watch Logic +""" +import logging +import time +import threading + +from typing import List, Optional, TYPE_CHECKING + +from samcli.lib.utils.colors import Colored +from samcli.lib.providers.exceptions import MissingCodeUri, MissingLocalDefinition + +from samcli.lib.providers.provider import ResourceIdentifier, Stack, get_all_resource_ids +from samcli.lib.utils.code_trigger_factory import CodeTriggerFactory +from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider +from samcli.lib.utils.path_observer import HandlerObserver + +from samcli.lib.sync.sync_flow_factory import SyncFlowFactory +from samcli.lib.sync.exceptions import MissingPhysicalResourceError, SyncFlowException +from samcli.lib.utils.resource_trigger import OnChangeCallback, TemplateTrigger +from samcli.lib.sync.continuous_sync_flow_executor import ContinuousSyncFlowExecutor + +if TYPE_CHECKING: + from samcli.commands.deploy.deploy_context import DeployContext + from samcli.commands.package.package_context import PackageContext + from samcli.commands.build.build_context import BuildContext + +DEFAULT_WAIT_TIME = 1 +LOG = logging.getLogger(__name__) + + +class WatchManager: + _stacks: Optional[List[Stack]] + _template: str + _build_context: "BuildContext" + _package_context: "PackageContext" + _deploy_context: "DeployContext" + _sync_flow_factory: Optional[SyncFlowFactory] + _sync_flow_executor: ContinuousSyncFlowExecutor + _executor_thread: Optional[threading.Thread] + _observer: HandlerObserver + _trigger_factory: Optional[CodeTriggerFactory] + _waiting_infra_sync: bool + _color: Colored + + def __init__( + self, + template: str, + build_context: "BuildContext", + package_context: "PackageContext", + 
deploy_context: "DeployContext",
+    ):
+        """Manager for sync watch execution logic.
+        This manager will observe the template and its code resources and
+        automatically execute infra/code syncs when changes are detected.
+
+        Parameters
+        ----------
+        template : str
+            Template file path
+        build_context : BuildContext
+            BuildContext
+        package_context : PackageContext
+            PackageContext
+        deploy_context : DeployContext
+            DeployContext
+        """
+        self._stacks = None
+        self._template = template
+        self._build_context = build_context
+        self._package_context = package_context
+        self._deploy_context = deploy_context
+
+        self._sync_flow_factory = None
+        self._sync_flow_executor = ContinuousSyncFlowExecutor()
+        self._executor_thread = None
+
+        self._observer = HandlerObserver()
+        self._trigger_factory = None
+
+        self._waiting_infra_sync = False
+        self._color = Colored()
+
+    def queue_infra_sync(self) -> None:
+        """Queue up an infrastructure sync.
+        A simple bool flag is sufficient here.
+        """
+        self._waiting_infra_sync = True
+
+    def _update_stacks(self) -> None:
+        """
+        Reloads the template and its stacks.
+        Updates all other members that depend on the stacks.
+        This should be called whenever there is a change to the template.
+        """
+        self._stacks = SamLocalStackProvider.get_stacks(self._template)[0]
+        self._sync_flow_factory = SyncFlowFactory(self._build_context, self._deploy_context, self._stacks)
+        self._sync_flow_factory.load_physical_id_mapping()
+        self._trigger_factory = CodeTriggerFactory(self._stacks)
+
+    def _add_code_triggers(self) -> None:
+        """Create CodeResourceTrigger for all resources and add their handlers to the observer"""
+        if not self._stacks or not self._trigger_factory:
+            return
+        resource_ids = get_all_resource_ids(self._stacks)
+        for resource_id in resource_ids:
+            try:
+                trigger = self._trigger_factory.create_trigger(resource_id, self._on_code_change_wrapper(resource_id))
+            except (MissingCodeUri, MissingLocalDefinition):
+                LOG.debug("CodeTrigger not created as CodeUri or DefinitionUri is missing for %s.", str(resource_id))
+                continue
+
+            if not trigger:
+                continue
+            self._observer.schedule_handlers(trigger.get_path_handlers())
+
+    def _add_template_trigger(self) -> None:
+        """Create TemplateTrigger and add its handlers to the observer"""
+        template_trigger = TemplateTrigger(self._template, lambda _=None: self.queue_infra_sync())
+        self._observer.schedule_handlers(template_trigger.get_path_handlers())
+
+    def _execute_infra_context(self) -> None:
+        """Execute infrastructure sync"""
+        self._build_context.set_up()
+        self._build_context.run()
+        self._package_context.run()
+        self._deploy_context.run()
+
+    def _start_code_sync(self) -> None:
+        """Start SyncFlowExecutor in a separate thread."""
+        if not self._executor_thread or not self._executor_thread.is_alive():
+            self._executor_thread = threading.Thread(
+                target=lambda: self._sync_flow_executor.execute(
+                    exception_handler=self._watch_sync_flow_exception_handler
+                )
+            )
+            self._executor_thread.start()
+
+    def _stop_code_sync(self) -> None:
+        """Blocking call that stops SyncFlowExecutor and waits for it to finish."""
+        if self._executor_thread and self._executor_thread.is_alive():
+            self._sync_flow_executor.stop()
+            self._executor_thread.join()
+
+    def start(self) -> None:
+        """Start WatchManager and watch for changes to the template and its code resources."""
+
+        # The actual execution is done in _start().
+        # This is a wrapper for gracefully handling Ctrl+C or other termination cases.
+        try:
+            self.queue_infra_sync()
+            self._start()
+        except KeyboardInterrupt:
+            LOG.info(self._color.cyan("Shutting down sync watch..."))
+            self._observer.stop()
+            self._stop_code_sync()
+            LOG.info(self._color.green("Sync watch stopped."))
+
+    def _start(self) -> None:
+        """Start WatchManager and watch for changes to the template and its code resources."""
+        self._observer.start()
+        while True:
+            if self._waiting_infra_sync:
+                self._execute_infra_sync()
+            time.sleep(1)
+
+    def _execute_infra_sync(self) -> None:
+        LOG.info(self._color.cyan("Queued infra sync. Waiting for in-progress code syncs to complete..."))
+        self._waiting_infra_sync = False
+        self._stop_code_sync()
+        try:
+            LOG.info(self._color.cyan("Starting infra sync."))
+            self._execute_infra_context()
+        except Exception as e:
+            LOG.error(
+                self._color.red("Failed to sync infra. Code sync is paused until template/stack is fixed."),
+                exc_info=e,
+            )
+            # Unschedule all triggers and only add back the template one, as the infra sync is incorrect.
+            self._observer.unschedule_all()
+            self._add_template_trigger()
+        else:
+            # Update stacks and repopulate the triggers.
+            # Triggers are not removed until infra sync is finished, as there
+            # can be code changes during infra sync.
+            self._observer.unschedule_all()
+            self._update_stacks()
+            self._add_template_trigger()
+            self._add_code_triggers()
+            self._start_code_sync()
+            LOG.info(self._color.green("Infra sync completed."))
+
+    def _on_code_change_wrapper(self, resource_id: ResourceIdentifier) -> OnChangeCallback:
+        """Wrapper method that generates a callback for code changes.
+
+        Parameters
+        ----------
+        resource_id : ResourceIdentifier
+            Resource associated with the callback
+
+        Returns
+        -------
+        OnChangeCallback
+            Callback function
+        """
+
+        def on_code_change(_=None):
+            sync_flow = self._sync_flow_factory.create_sync_flow(resource_id)
+            if sync_flow and not self._waiting_infra_sync:
+                self._sync_flow_executor.add_delayed_sync_flow(sync_flow, dedup=True, wait_time=DEFAULT_WAIT_TIME)
+
+        return on_code_change
+
+    def _watch_sync_flow_exception_handler(self, sync_flow_exception: SyncFlowException) -> None:
+        """Exception handler for watch.
+        Simply logs unhandled exceptions instead of failing the entire process.
+
+        Parameters
+        ----------
+        sync_flow_exception : SyncFlowException
+            SyncFlowException
+        """
+        exception = sync_flow_exception.exception
+        if isinstance(exception, MissingPhysicalResourceError):
+            LOG.warning(self._color.yellow("Missing physical resource. Infra sync will be started."))
+            self.queue_infra_sync()
+        else:
+            LOG.error(self._color.red("Code sync encountered an error."), exc_info=exception)
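+
+# Illustrative usage (editor's sketch, not part of the original change):
+#
+#     manager = WatchManager(template, build_context, package_context, deploy_context)
+#     manager.start()  # blocks; Ctrl+C stops the observer and code sync gracefully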
diff --git a/samcli/lib/utils/boto_utils.py b/samcli/lib/utils/boto_utils.py
new file mode 100644
index 0000000000..11dd2ffa48
--- /dev/null
+++ b/samcli/lib/utils/boto_utils.py
@@ -0,0 +1,84 @@
+"""
+This module contains utility functions for boto3 library
+"""
+from typing import Any
+from typing_extensions import Protocol
+
+import boto3
+from botocore.config import Config
+
+from samcli import __version__
+from samcli.cli.global_config import GlobalConfig
+
+
+def get_boto_config_with_user_agent(**kwargs) -> Config:
+    """
+    Automatically add user agent string to boto configs.
+
+    Parameters
+    ----------
+    kwargs :
+        key=value params which will be added to the Config object
+
+    Returns
+    -------
+    Config
+        Returns a config instance which contains the given parameters
+    """
+    gc = GlobalConfig()
+    return Config(
+        user_agent_extra=f"aws-sam-cli/{__version__}/{gc.installation_id}"
+        if gc.telemetry_enabled
+        else f"aws-sam-cli/{__version__}",
+        **kwargs,
+    )
+
+
+# Type definition for the following boto providers, which is equal to Callable[[str], Any]
+class BotoProviderType(Protocol):
+    def __call__(self, service_name: str) -> Any:
+        ...
+
+
+def get_boto_client_provider_with_config(**kwargs) -> BotoProviderType:
+    """
+    Returns a wrapper function for a boto client with the given configuration. It can be used like:
+
+        client_provider = get_boto_client_provider_with_config(region_name=region)
+        lambda_client = client_provider("lambda")
+
+    Parameters
+    ----------
+    kwargs :
+        Key-value params that will be passed to get_boto_config_with_user_agent
+
+    Returns
+    -------
+    A callable function which will return a boto client
+    """
+    # ignore typing because mypy tries to assert client_name with a valid service name
+    return lambda client_name: boto3.session.Session().client(
+        client_name, config=get_boto_config_with_user_agent(**kwargs)
+    )
+
+
+def get_boto_resource_provider_with_config(**kwargs) -> BotoProviderType:
+    """
+    Returns a wrapper function for a boto resource with the given configuration. It can be used like:
+
+        resource_provider = get_boto_resource_provider_with_config(region_name=region)
+        cloudformation_resource = resource_provider("cloudformation")
+
+    Parameters
+    ----------
+    kwargs :
+        Key-value params that will be passed to get_boto_config_with_user_agent
+
+    Returns
+    -------
+    A callable function which will return a boto resource
+    """
+    # ignore typing because mypy tries to assert client_name with a valid service name
+    return lambda resource_name: boto3.session.Session().resource(
+        resource_name, config=get_boto_config_with_user_agent(**kwargs)
+    )
diff --git a/samcli/lib/utils/botoconfig.py b/samcli/lib/utils/botoconfig.py
deleted file mode 100644
index 7a7bd6d792..0000000000
--- a/samcli/lib/utils/botoconfig.py
+++ /dev/null
@@ -1,17 +0,0 @@
-"""
-Automatically add user agent string to boto configs.
-"""
-from botocore.config import Config
-
-from samcli import __version__
-from samcli.cli.global_config import GlobalConfig
-
-
-def get_boto_config_with_user_agent(**kwargs):
-    gc = GlobalConfig()
-    return Config(
-        user_agent_extra=f"aws-sam-cli/{__version__}/{gc.installation_id}"
-        if gc.telemetry_enabled
-        else f"aws-sam-cli/{__version__}",
-        **kwargs,
-    )
diff --git a/samcli/lib/utils/cloudformation.py b/samcli/lib/utils/cloudformation.py
new file mode 100644
index 0000000000..b6590550d4
--- /dev/null
+++ b/samcli/lib/utils/cloudformation.py
@@ -0,0 +1,127 @@
+"""
+This utility file contains methods to read information from a certain CFN stack
+"""
+import logging
+from typing import List, Dict, NamedTuple, Set, Optional
+
+from botocore.exceptions import ClientError
+
+from samcli.lib.utils.boto_utils import BotoProviderType
+
+LOG = logging.getLogger(__name__)
+
+
+class CloudFormationResourceSummary(NamedTuple):
+    """
+    Keeps information about a CFN resource
+    """
+
+    resource_type: str
+    logical_resource_id: str
+    physical_resource_id: str
+
+
+def get_physical_id_mapping(
+    boto_resource_provider: BotoProviderType, stack_name: str, resource_types: Optional[Set[str]] = None
+) -> Dict[str, str]:
+    """
+    Uses the get_resource_summaries method to gather resource summaries and creates a dictionary which contains
+    the logical_id to physical_id mapping
+
+    Parameters
+    ----------
+    boto_resource_provider : BotoProviderType
+        A callable which will return a boto3 resource
+    stack_name : str
+        Name of the stack which is deployed to CFN
+    resource_types : Optional[Set[str]]
+        Set of resource types which will filter the results
+
+    Returns
+    -------
+    Dict[str, str] which maps logical IDs to physical IDs
+
+    """
+    resource_summaries = get_resource_summaries(boto_resource_provider, stack_name, resource_types)
+
+    resource_physical_id_map: Dict[str, str] = {}
+    for resource_summary in resource_summaries:
+        resource_physical_id_map[resource_summary.logical_resource_id] = resource_summary.physical_resource_id
+
+    return resource_physical_id_map
+
+
+def get_resource_summaries(
+    boto_resource_provider: BotoProviderType, stack_name: str, resource_types: Optional[Set[str]] = None
+) -> List[CloudFormationResourceSummary]:
+    """
+    Collects information about CFN resources and returns their summaries as a list
+
+    Parameters
+    ----------
+    boto_resource_provider : BotoProviderType
+        A callable which will return a boto3 resource
+    stack_name : str
+        Name of the stack which is deployed to CFN
+    resource_types : Optional[Set[str]]
+        Set of resource types which will filter the results
+
+    Returns
+    -------
+    List of CloudFormationResourceSummary which contains information about resources in the given stack
+
+    """
+    LOG.debug("Fetching stack (%s) resources", stack_name)
+    cfn_resource_summaries = boto_resource_provider("cloudformation").Stack(stack_name).resource_summaries.all()
+    resource_summaries: List[CloudFormationResourceSummary] = []
+
+    for cfn_resource_summary in cfn_resource_summaries:
+        resource_summary = CloudFormationResourceSummary(
+            cfn_resource_summary.resource_type,
+            cfn_resource_summary.logical_resource_id,
+            cfn_resource_summary.physical_resource_id,
+        )
+        if resource_types and resource_summary.resource_type not in resource_types:
+            LOG.debug(
+                "Skipping resource %s since its type %s is not supported. Supported types %s",
+                resource_summary.logical_resource_id,
+                resource_summary.resource_type,
+                resource_types,
+            )
+            continue
+
+        resource_summaries.append(resource_summary)
+
+    return resource_summaries
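+
+
+# Illustrative result (editor's note, not part of the original change): for a deployed
+# stack, get_physical_id_mapping(provider, "sam-app") returns something like:
+#
+#     {"HelloWorldFunction": "sam-app-HelloWorldFunction-1A2B3C4D5E6F"}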
+
+
+def get_resource_summary(
+    boto_resource_provider: BotoProviderType, stack_name: str, resource_logical_id: str
+) -> Optional[CloudFormationResourceSummary]:
+    """
+    Returns the resource summary of a single resource identified by its logical ID
+
+    Parameters
+    ----------
+    boto_resource_provider : BotoProviderType
+        A callable which will return a boto3 resource
+    stack_name : str
+        Name of the stack which is deployed to CFN
+    resource_logical_id : str
+        Logical ID of the resource that will be returned as a resource summary
+
+    Returns
+    -------
+    CloudFormationResourceSummary of the resource which is identified by the given logical ID,
+    or None if it could not be fetched
+    """
+    try:
+        cfn_resource_summary = boto_resource_provider("cloudformation").StackResource(stack_name, resource_logical_id)
+
+        return CloudFormationResourceSummary(
+            cfn_resource_summary.resource_type,
+            cfn_resource_summary.logical_resource_id,
+            cfn_resource_summary.physical_resource_id,
+        )
+    except ClientError as e:
+        LOG.error(
+            "Failed to pull resource (%s) information from stack (%s)", resource_logical_id, stack_name, exc_info=e
+        )
+        return None
diff --git a/samcli/lib/utils/code_trigger_factory.py b/samcli/lib/utils/code_trigger_factory.py
new file mode 100644
index 0000000000..f577ace803
--- /dev/null
+++ b/samcli/lib/utils/code_trigger_factory.py
@@ -0,0 +1,112 @@
+"""
+Factory for creating CodeResourceTriggers
+"""
+import logging
+from typing import Any, Callable, Dict, List, Optional, cast
+
+from samcli.lib.providers.provider import ResourceIdentifier, Stack, get_resource_by_id
+from samcli.lib.utils.packagetype import IMAGE, ZIP
+from samcli.lib.utils.resource_trigger import (
+    CodeResourceTrigger,
+    DefinitionCodeTrigger,
+    LambdaImageCodeTrigger,
+    LambdaLayerCodeTrigger,
+    LambdaZipCodeTrigger,
+)
+from samcli.lib.utils.resource_type_based_factory import ResourceTypeBasedFactory
+from samcli.lib.utils.resources import (
+    AWS_APIGATEWAY_RESTAPI,
+    AWS_APIGATEWAY_V2_API,
+    AWS_LAMBDA_FUNCTION,
+    AWS_LAMBDA_LAYERVERSION,
+    AWS_SERVERLESS_API,
+    AWS_SERVERLESS_FUNCTION,
+    AWS_SERVERLESS_HTTPAPI,
+    AWS_SERVERLESS_LAYERVERSION,
+    AWS_SERVERLESS_STATEMACHINE,
+    AWS_STEPFUNCTIONS_STATEMACHINE,
+)
+
+LOG = logging.getLogger(__name__)
+
+
+class CodeTriggerFactory(ResourceTypeBasedFactory[CodeResourceTrigger]):  # pylint: disable=E1136
+    _stacks: List[Stack]
+
+    def _create_lambda_trigger(
+        self,
+        resource_identifier: ResourceIdentifier,
+        resource_type: str,
+        resource: Dict[str, Any],
+        on_code_change: Callable,
+    ):
+        package_type = resource.get("Properties", dict()).get("PackageType", ZIP)
+        if package_type == ZIP:
+            return LambdaZipCodeTrigger(resource_identifier, self._stacks, on_code_change)
+        if package_type == IMAGE:
+            return LambdaImageCodeTrigger(resource_identifier, self._stacks, on_code_change)
+        return None
+
+    def _create_layer_trigger(
+        self,
+        resource_identifier: ResourceIdentifier,
+        resource_type: str,
+        resource: Dict[str, Any],
+        on_code_change: Callable,
+    ):
+        return LambdaLayerCodeTrigger(resource_identifier, self._stacks, on_code_change)
+
+    def _create_definition_code_trigger(
+        self,
+        resource_identifier: ResourceIdentifier,
+        resource_type: str,
+        resource: Dict[str, Any],
+        on_code_change: Callable,
+    ):
+        return DefinitionCodeTrigger(resource_identifier, resource_type, self._stacks, on_code_change)
+
+    GeneratorFunction = Callable[
+        ["CodeTriggerFactory", ResourceIdentifier, str, Dict[str, Any], Callable], Optional[CodeResourceTrigger]
+    ]
+    GENERATOR_MAPPING: Dict[str, GeneratorFunction] = {
+        AWS_LAMBDA_FUNCTION: _create_lambda_trigger,
+        AWS_SERVERLESS_FUNCTION: _create_lambda_trigger,
+        AWS_SERVERLESS_LAYERVERSION: _create_layer_trigger,
+        AWS_LAMBDA_LAYERVERSION: _create_layer_trigger,
+        AWS_SERVERLESS_API: _create_definition_code_trigger,
+        AWS_APIGATEWAY_RESTAPI: _create_definition_code_trigger,
+        AWS_SERVERLESS_HTTPAPI: _create_definition_code_trigger,
+        AWS_APIGATEWAY_V2_API: _create_definition_code_trigger,
+        AWS_SERVERLESS_STATEMACHINE: _create_definition_code_trigger,
+        AWS_STEPFUNCTIONS_STATEMACHINE: _create_definition_code_trigger,
+    }
+
+    # Ignoring no-self-use as PyLint has a bug with Generic Abstract Classes
+    def _get_generator_mapping(self) -> Dict[str, GeneratorFunction]:  # pylint: disable=no-self-use
+        return CodeTriggerFactory.GENERATOR_MAPPING
+
+    def create_trigger(
+        self, resource_identifier: ResourceIdentifier, on_code_change: Callable
+    ) -> Optional[CodeResourceTrigger]:
+        """Create a trigger for the given resource type
+
+        Parameters
+        ----------
+        resource_identifier : ResourceIdentifier
+            Resource associated with the trigger
+        on_code_change : Callable
+            Callback for code change
+
+        Returns
+        -------
+        Optional[CodeResourceTrigger]
+            CodeResourceTrigger for the resource
+        """
+        resource = get_resource_by_id(self._stacks, resource_identifier)
+        generator = self._get_generator_function(resource_identifier)
+        resource_type = self._get_resource_type(resource_identifier)
+        if not generator or not resource or not resource_type:
+            return None
+        return cast(CodeTriggerFactory.GeneratorFunction, generator)(
+            self, resource_identifier, resource_type, resource, on_code_change
+        )
diff --git a/samcli/lib/utils/definition_validator.py b/samcli/lib/utils/definition_validator.py
new file mode 100644
index 0000000000..54d06f4101
--- /dev/null
+++ b/samcli/lib/utils/definition_validator.py
@@ -0,0 +1,60 @@
+"""DefinitionValidator for Validating YAML and JSON Files"""
+import logging
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import yaml
+from samcli.yamlhelper import parse_yaml_file
+
+LOG = logging.getLogger(__name__)
+
+
+class DefinitionValidator:
+    _path: Path
+    _detect_change: bool
+    _data: Optional[Dict[str, Any]]
+
+    def __init__(self, path: Path, detect_change: bool = True, initialize_data: bool = True) -> None:
+        """
+        Validator for JSON and YAML files.
+        Calling validate() will return True if the definition is valid and
+        has changes.
+
+        Parameters
+        ----------
+        path : Path
+            Path to the definition file
+        detect_change : bool, optional
+            Validation will only be successful if there are changes between the current and previous data,
+            by default True
+        initialize_data : bool, optional
+            Should the existing definition data be loaded before the first validate, by default True.
+            Used along with detect_change.
+        """
+        super().__init__()
+        self._path = path
+        self._detect_change = detect_change
+        self._data = None
+        if initialize_data:
+            self.validate()
+
+    def validate(self) -> bool:
+        """Validate a JSON or YAML file.
+
+        Returns
+        -------
+        bool
+            True if it is valid, False otherwise.
+            If detect_change is set, False will also be returned if there is
+            no change compared to the previous validation.
+        """
+        if not self._path.exists():
+            return False
+
+        old_data = self._data
+        try:
+            self._data = parse_yaml_file(str(self._path))
+            return old_data != self._data if self._detect_change else True
+        except (ValueError, yaml.YAMLError) as e:
+            LOG.debug("DefinitionValidator failed to validate.", exc_info=e)
+            return False
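+
+
+# Illustrative usage (editor's sketch, not part of the original change):
+#
+#     validator = DefinitionValidator(Path("template.yaml"))
+#     if validator.validate():  # valid YAML/JSON with non-trivial changes since last check
+#         ...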
diff --git a/samcli/lib/utils/hash.py b/samcli/lib/utils/hash.py
index a9cbae1885..57dffa25c1 100644
--- a/samcli/lib/utils/hash.py
+++ b/samcli/lib/utils/hash.py
@@ -3,39 +3,42 @@
 """
 import os
 import hashlib
-from typing import List, Optional
+from typing import Any, cast, List, Optional

 BLOCK_SIZE = 4096


-def file_checksum(file_name: str) -> str:
+def file_checksum(file_name: str, hash_generator: Any = None) -> str:
     """
     Parameters
     ----------
     file_name: file name of the file for which md5 checksum is required.
+    hash_generator: hashlib _Hash object for generating hashes. Defaults to hashlib.md5.
+
     Returns
     -------
-    md5 checksum of the given file.
+    checksum of the given file.
     """
+    # The default is assigned here because mutable default arguments are evaluated once
+    # and shared across calls in Python.
+    if not hash_generator:
+        hash_generator = hashlib.md5()
     with open(file_name, "rb") as file_handle:
-        md5 = hashlib.md5()
-
         # Save current cursor position and reset cursor to start of file
         curpos = file_handle.tell()
         file_handle.seek(0)

         buf = file_handle.read(BLOCK_SIZE)
         while buf:
-            md5.update(buf)
+            hash_generator.update(buf)
             buf = file_handle.read(BLOCK_SIZE)

         # Restore file cursor's position
         file_handle.seek(curpos)

-        return md5.hexdigest()
+        return cast(str, hash_generator.hexdigest())


 def dir_checksum(directory: str, followlinks: bool = True, ignore_list: Optional[List[str]] = None) -> str:
diff --git a/samcli/lib/utils/lock_distributor.py b/samcli/lib/utils/lock_distributor.py
new file mode 100644
index 0000000000..80d53edad0
--- /dev/null
+++ b/samcli/lib/utils/lock_distributor.py
@@ -0,0 +1,142 @@
+"""LockDistributor for creating and managing a set of locks"""
+import threading
+import multiprocessing
+import multiprocessing.managers
+from typing import Dict, List, Optional, cast
+from enum import Enum, auto
+
+
+class LockChain:
+    """Wrapper class for acquiring multiple locks in the same order to prevent deadlocks.
+    Can be used with the `with` statement."""
+
+    def __init__(self, lock_mapping: Dict[str, threading.Lock]):
+        """
+        Parameters
+        ----------
+        lock_mapping : Dict[str, threading.Lock]
+            Dictionary of locks; the keys are used to generate a reproducible order for acquiring and releasing locks.
+        """
+        self._locks = [value for _, value in sorted(lock_mapping.items())]
+
+    def acquire(self) -> None:
+        """Acquire all locks in the LockChain"""
+        for lock in self._locks:
+            lock.acquire()
+
+    def release(self) -> None:
+        """Release all locks in the LockChain"""
+        for lock in self._locks:
+            lock.release()
+
+    def __enter__(self) -> "LockChain":
+        self.acquire()
+        return self
+
+    def __exit__(self, exception_type, exception_value, traceback) -> None:
+        self.release()
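+
+
+# Illustrative usage (editor's sketch, not part of the original change):
+#
+#     locks = {"FuncA_Build": threading.Lock(), "LayerB_Build": threading.Lock()}
+#     with LockChain(locks):
+#         ...  # both locks held, acquired in sorted key order to avoid deadlocks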
+
+
+class LockDistributorType(Enum):
+    """Types of LockDistributor"""
+
+    THREAD = auto()
+    PROCESS = auto()
+
+
+class LockDistributor:
+    """Dynamic lock distributor that supports threads and processes.
+    In the case of processes, either a manager (server process) or shared memory can be used.
+    """
+
+    _lock_type: LockDistributorType
+    _manager: Optional[multiprocessing.managers.SyncManager]
+    _dict_lock: threading.Lock
+    _locks: Dict[str, threading.Lock]
+
+    def __init__(
+        self,
+        lock_type: LockDistributorType = LockDistributorType.THREAD,
+        manager: Optional[multiprocessing.managers.SyncManager] = None,
+    ):
+        """
+        Parameters
+        ----------
+        lock_type : LockDistributorType, optional
+            Whether locking with threads or processes, by default LockDistributorType.THREAD
+        manager : Optional[multiprocessing.managers.SyncManager], optional
+            Optional process sync manager for creating proxy locks, by default None
+        """
+        self._lock_type = lock_type
+        self._manager = manager
+        self._dict_lock = self._create_new_lock()
+        self._locks = (
+            self._manager.dict()
+            if self._lock_type == LockDistributorType.PROCESS and self._manager is not None
+            else dict()
+        )
+
+    def _create_new_lock(self) -> threading.Lock:
+        """Create a new lock based on the lock type
+
+        Returns
+        -------
+        threading.Lock
+            Newly created lock
+        """
+        if self._lock_type == LockDistributorType.THREAD:
+            return threading.Lock()
+
+        return self._manager.Lock() if self._manager is not None else cast(threading.Lock, multiprocessing.Lock())
+
+    def get_lock(self, key: str) -> threading.Lock:
+        """Retrieve the lock associated with the key.
+        If the lock does not exist, a new lock will be created.
+
+        Parameters
+        ----------
+        key : str
+            Key for retrieving the lock
+
+        Returns
+        -------
+        threading.Lock
+            Lock associated with the key
+        """
+        with self._dict_lock:
+            if key not in self._locks:
+                self._locks[key] = self._create_new_lock()
+            return self._locks[key]
+
+    def get_locks(self, keys: List[str]) -> Dict[str, threading.Lock]:
+        """Retrieve a dictionary of locks associated with the given keys
+
+        Parameters
+        ----------
+        keys : List[str]
+            List of keys for retrieving the locks
+
+        Returns
+        -------
+        Dict[str, threading.Lock]
+            Dictionary mapping keys to locks
+        """
+        lock_mapping = dict()
+        for key in keys:
+            lock_mapping[key] = self.get_lock(key)
+        return lock_mapping
+
+    def get_lock_chain(self, keys: List[str]) -> LockChain:
+        """Similar to get_locks, but returns a LockChain object instead of a dictionary
+
+        Parameters
+        ----------
+        keys : List[str]
+            List of keys for retrieving the locks
+
+        Returns
+        -------
+        LockChain
+            LockChain object containing all the locks associated with keys
+        """
+        return LockChain(self.get_locks(keys))
diff --git a/samcli/lib/utils/path_observer.py b/samcli/lib/utils/path_observer.py
new file mode 100644
index 0000000000..615c8963fa
--- /dev/null
+++ b/samcli/lib/utils/path_observer.py
@@ -0,0 +1,161 @@
+"""
+HandlerObserver and its helper classes.
+"""
+import re
+
+from pathlib import Path
+from typing import Callable, List, Optional
+from dataclasses import dataclass
+
+from watchdog.observers import Observer
+from watchdog.events import (
+    FileSystemEvent,
+    FileSystemEventHandler,
+    RegexMatchingEventHandler,
+)
+from watchdog.observers.api import DEFAULT_OBSERVER_TIMEOUT, ObservedWatch
+
+
+@dataclass
+class PathHandler:
+    """PathHandler is an object that can be passed into
+    HandlerObserver directly for watching a specific path with a
+    corresponding EventHandler
+
+    Fields:
+    event_handler : FileSystemEventHandler
+        Handler for the event
+    path : Path
+        Path to the folder to be watched
+    recursive : bool, optional
+        True to watch child folders, by default False
+    static_folder : bool, optional
+        Should the observed folder name be static, by default False.
+        See StaticFolderWrapper for the use case.
+    self_create : Optional[Callable[[], None]], optional
+        Callback for when the folder to be observed itself is created, by default None.
+        This will not be called if static_folder is False.
+    self_delete : Optional[Callable[[], None]], optional
+        Callback for when the folder to be observed itself is deleted, by default None.
+        This will not be called if static_folder is False.
+    """
+
+    event_handler: FileSystemEventHandler
+    path: Path
+    recursive: bool = False
+    static_folder: bool = False
+    self_create: Optional[Callable[[], None]] = None
+    self_delete: Optional[Callable[[], None]] = None
+
+
+class StaticFolderWrapper:
+    """This class is used to alter the behavior of watchdog folder watches.
+    https://github.com/gorakhargosh/watchdog/issues/415
+    By default, if a folder is renamed, the handler will still get triggered for the new folder.
+    Ex:
+    1. Create FolderA
+    2. Watch FolderA
+    3. Rename FolderA to FolderB
+    4. Add file to FolderB
+    5. Handler will get an event for adding the file to FolderB, but with the event path still as FolderA
+    This class watches the parent folder, and if the folder to be watched gets renamed or deleted,
+    the watch will be stopped so that changes in the renamed folder do not get triggered.
+    """
+
+    def __init__(self, observer: "HandlerObserver", initial_watch: ObservedWatch, path_handler: PathHandler):
+        """
+        Parameters
+        ----------
+        observer : HandlerObserver
+            HandlerObserver
+        initial_watch : ObservedWatch
+            Initial watch for the folder to be watched that gets returned by HandlerObserver
+        path_handler : PathHandler
+            PathHandler of the folder to be watched.
+        """
+        self._observer = observer
+        self._path_handler = path_handler
+        self._watch = initial_watch
+
+    def _on_parent_change(self, _: FileSystemEvent) -> None:
+        """Callback for changes detected in the parent folder"""
+
+        # When the folder is being watched but the folder does not exist
+        if self._watch and not self._path_handler.path.exists():
+            if self._path_handler.self_delete:
+                self._path_handler.self_delete()
+            self._observer.unschedule(self._watch)
+            self._watch = None
+        # When the folder is not being watched but the folder does exist
+        elif not self._watch and self._path_handler.path.exists():
+            if self._path_handler.self_create:
+                self._path_handler.self_create()
+            self._watch = self._observer.schedule_handler(self._path_handler)
+
+    def get_dir_parent_path_handler(self) -> PathHandler:
+        """Get a PathHandler that watches the folder changes from the parent folder.
+
+        Returns
+        -------
+        PathHandler
+            PathHandler for the parent folder. This should be added back into the HandlerObserver.
+ """ + dir_path = self._path_handler.path.resolve() + parent_dir_path = dir_path.parent + parent_folder_handler = RegexMatchingEventHandler( + regexes=[f"^{re.escape(str(dir_path))}$"], + ignore_regexes=[], + ignore_directories=False, + case_sensitive=True, + ) + parent_folder_handler.on_any_event = self._on_parent_change + return PathHandler(path=parent_dir_path, event_handler=parent_folder_handler) + + +class HandlerObserver(Observer): # pylint: disable=too-many-ancestors + """ + Extended WatchDog Observer that takes in a single PathHandler object. + """ + + def __init__(self, timeout=DEFAULT_OBSERVER_TIMEOUT): + super().__init__(timeout=timeout) + + def schedule_handlers(self, path_handlers: List[PathHandler]) -> List[ObservedWatch]: + """Schedule a list of PathHandlers + + Parameters + ---------- + path_handlers : List[PathHandler] + List of PathHandlers to be scheduled + + Returns + ------- + List[ObservedWatch] + List of ObservedWatch corresponding to path_handlers in the same order. + """ + watches = list() + for path_handler in path_handlers: + watches.append(self.schedule_handler(path_handler)) + return watches + + def schedule_handler(self, path_handler: PathHandler) -> ObservedWatch: + """Schedule a PathHandler + + Parameters + ---------- + path_handler : PathHandler + PathHandler to be scheduled + + Returns + ------- + ObservedWatch + ObservedWatch corresponding to the PathHandler. + If static_folder is True, the parent folder watch will be returned instead. + """ + watch = self.schedule(path_handler.event_handler, str(path_handler.path), path_handler.recursive) + if path_handler.static_folder: + static_wrapper = StaticFolderWrapper(self, watch, path_handler) + parent_path_handler = static_wrapper.get_dir_parent_path_handler() + watch = self.schedule_handler(parent_path_handler) + return watch diff --git a/samcli/lib/utils/resource_trigger.py b/samcli/lib/utils/resource_trigger.py new file mode 100644 index 0000000000..eb9f7f31a7 --- /dev/null +++ b/samcli/lib/utils/resource_trigger.py @@ -0,0 +1,339 @@ +"""ResourceTrigger Classes for Creating PathHandlers According to a Resource""" +import re +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Dict, List, Optional, cast + +from typing_extensions import Protocol +from watchdog.events import FileSystemEvent, PatternMatchingEventHandler, RegexMatchingEventHandler + +from samcli.lib.providers.exceptions import MissingCodeUri, MissingLocalDefinition +from samcli.lib.providers.provider import Function, LayerVersion, ResourceIdentifier, Stack, get_resource_by_id +from samcli.lib.providers.sam_function_provider import SamFunctionProvider +from samcli.lib.providers.sam_layer_provider import SamLayerProvider +from samcli.lib.utils.definition_validator import DefinitionValidator +from samcli.lib.utils.path_observer import PathHandler +from samcli.local.lambdafn.exceptions import FunctionNotFound, ResourceNotFound +from samcli.lib.utils.resources import RESOURCES_WITH_LOCAL_PATHS + + +class OnChangeCallback(Protocol): + """Callback Type""" + + def __call__(self, event: Optional[FileSystemEvent] = None) -> None: + pass + + +class ResourceTrigger(ABC): + """Abstract class for creating PathHandlers for a resource. 
diff --git a/samcli/lib/utils/resource_trigger.py b/samcli/lib/utils/resource_trigger.py
new file mode 100644
index 0000000000..eb9f7f31a7
--- /dev/null
+++ b/samcli/lib/utils/resource_trigger.py
@@ -0,0 +1,339 @@
+"""ResourceTrigger Classes for Creating PathHandlers According to a Resource"""
+import re
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any, Dict, List, Optional, cast
+
+from typing_extensions import Protocol
+from watchdog.events import FileSystemEvent, PatternMatchingEventHandler, RegexMatchingEventHandler
+
+from samcli.lib.providers.exceptions import MissingCodeUri, MissingLocalDefinition
+from samcli.lib.providers.provider import Function, LayerVersion, ResourceIdentifier, Stack, get_resource_by_id
+from samcli.lib.providers.sam_function_provider import SamFunctionProvider
+from samcli.lib.providers.sam_layer_provider import SamLayerProvider
+from samcli.lib.utils.definition_validator import DefinitionValidator
+from samcli.lib.utils.path_observer import PathHandler
+from samcli.local.lambdafn.exceptions import FunctionNotFound, ResourceNotFound
+from samcli.lib.utils.resources import RESOURCES_WITH_LOCAL_PATHS
+
+
+class OnChangeCallback(Protocol):
+    """Callback Type"""
+
+    def __call__(self, event: Optional[FileSystemEvent] = None) -> None:
+        pass
+
+
+class ResourceTrigger(ABC):
+    """Abstract class for creating PathHandlers for a resource.
+    PathHandlers returned by get_path_handlers() can then be used with an observer for
+    detecting file changes associated with the resource."""
+
+    def __init__(self) -> None:
+        pass
+
+    @abstractmethod
+    def get_path_handlers(self) -> List[PathHandler]:
+        """List of PathHandlers that correspond to the resource
+
+        Returns
+        -------
+        List[PathHandler]
+            List of PathHandlers that correspond to the resource
+        """
+        raise NotImplementedError("get_path_handlers is not implemented.")
+
+    @staticmethod
+    def get_single_file_path_handler(file_path_str: str) -> PathHandler:
+        """Get PathHandler for watching a single file
+
+        Parameters
+        ----------
+        file_path_str : str
+            File path as a string
+
+        Returns
+        -------
+        PathHandler
+            The PathHandler for the file specified
+        """
+        file_path = Path(file_path_str).resolve()
+        folder_path = file_path.parent
+        file_handler = RegexMatchingEventHandler(
+            regexes=[f"^{re.escape(str(file_path))}$"], ignore_regexes=[], ignore_directories=True, case_sensitive=True
+        )
+        return PathHandler(path=folder_path, event_handler=file_handler, recursive=False)
+
+    @staticmethod
+    def get_dir_path_handler(dir_path_str: str) -> PathHandler:
+        """Get PathHandler for watching a single directory
+
+        Parameters
+        ----------
+        dir_path_str : str
+            Folder path as a string
+
+        Returns
+        -------
+        PathHandler
+            The PathHandler for the folder specified
+        """
+        dir_path = Path(dir_path_str).resolve()
+        file_handler = PatternMatchingEventHandler(
+            patterns=["*"], ignore_patterns=[], ignore_directories=False, case_sensitive=True
+        )
+        return PathHandler(path=dir_path, event_handler=file_handler, recursive=True, static_folder=True)
+
+
+class TemplateTrigger(ResourceTrigger):
+    """ResourceTrigger for watching the SAM template file itself."""
+
+    _template_file: str
+    _on_template_change: OnChangeCallback
+    _validator: DefinitionValidator
+
+    def __init__(self, template_file: str, on_template_change: OnChangeCallback) -> None:
+        """
+        Parameters
+        ----------
+        template_file : str
+            Template file to be watched
+        on_template_change : OnChangeCallback
+            Callback when template changes
+        """
+        super().__init__()
+        self._template_file = template_file
+        self._on_template_change = on_template_change
+        self._validator = DefinitionValidator(Path(self._template_file))
+
+    def _validator_wrapper(self, event: Optional[FileSystemEvent] = None) -> None:
+        """Wrapper for callback that only executes if the template is valid and non-trivial changes are detected.
+
+        Parameters
+        ----------
+        event : Optional[FileSystemEvent], optional
+            File system event that triggered the callback, by default None
+        """
+        if self._validator.validate():
+            self._on_template_change(event)
+
+    def get_path_handlers(self) -> List[PathHandler]:
+        file_path_handler = ResourceTrigger.get_single_file_path_handler(self._template_file)
+        file_path_handler.event_handler.on_any_event = self._validator_wrapper
+        return [file_path_handler]
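For context, wiring a TemplateTrigger into the observer from path_observer.py looks roughly like this (a sketch assuming a valid template.yaml in the working directory; not part of the patch):

    from samcli.lib.utils.path_observer import HandlerObserver
    from samcli.lib.utils.resource_trigger import TemplateTrigger

    def on_template_change(event=None):
        # In sync-style flows this would re-load and re-process the template.
        print("Template changed")

    trigger = TemplateTrigger("template.yaml", on_template_change)
    observer = HandlerObserver()
    observer.schedule_handlers(trigger.get_path_handlers())
    observer.start()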
+ """ + super().__init__() + self._resource_identifier = resource_identifier + resource = get_resource_by_id(stacks, resource_identifier) + if not resource: + raise ResourceNotFound() + self._resource = resource + self._on_code_change = on_code_change + + +class LambdaFunctionCodeTrigger(CodeResourceTrigger): + _function: Function + _code_uri: str + + def __init__(self, function_identifier: ResourceIdentifier, stacks: List[Stack], on_code_change: OnChangeCallback): + """ + Parameters + ---------- + function_identifier : ResourceIdentifier + ResourceIdentifier for the function + stacks : List[Stack] + List of stacks + on_code_change : OnChangeCallback + Callback when function code files are changed. + + Raises + ------ + FunctionNotFound + raised when the function cannot be found in stacks + MissingCodeUri + raised when there is no CodeUri property in the function definition. + """ + super().__init__(function_identifier, stacks, on_code_change) + function = SamFunctionProvider(stacks).get(str(function_identifier)) + if not function: + raise FunctionNotFound() + self._function = function + + code_uri = self._get_code_uri() + if not code_uri: + raise MissingCodeUri() + self._code_uri = code_uri + + @abstractmethod + def _get_code_uri(self) -> Optional[str]: + """ + Returns + ------- + Optional[str] + Path for the folder to be watched. + """ + raise NotImplementedError() + + def get_path_handlers(self) -> List[PathHandler]: + """ + Returns + ------- + List[PathHandler] + PathHandlers for the code folder associated with the function + """ + dir_path_handler = ResourceTrigger.get_dir_path_handler(self._code_uri) + dir_path_handler.self_create = self._on_code_change + dir_path_handler.self_delete = self._on_code_change + dir_path_handler.event_handler.on_any_event = self._on_code_change + return [dir_path_handler] + + +class LambdaZipCodeTrigger(LambdaFunctionCodeTrigger): + def _get_code_uri(self) -> Optional[str]: + return self._function.codeuri + + +class LambdaImageCodeTrigger(LambdaFunctionCodeTrigger): + def _get_code_uri(self) -> Optional[str]: + if not self._function.metadata: + return None + return cast(Optional[str], self._function.metadata.get("DockerContext", None)) + + +class LambdaLayerCodeTrigger(CodeResourceTrigger): + _layer: LayerVersion + _code_uri: str + + def __init__( + self, + layer_identifier: ResourceIdentifier, + stacks: List[Stack], + on_code_change: OnChangeCallback, + ): + """ + Parameters + ---------- + layer_identifier : ResourceIdentifier + ResourceIdentifier for the layer + stacks : List[Stack] + List of stacks + on_code_change : OnChangeCallback + Callback when layer code files are changed. + + Raises + ------ + ResourceNotFound + raised when the layer cannot be found in stacks + MissingCodeUri + raised when there is no CodeUri property in the function definition. 
+ """ + super().__init__(layer_identifier, stacks, on_code_change) + layer = SamLayerProvider(stacks).get(str(layer_identifier)) + if not layer: + raise ResourceNotFound() + self._layer = layer + code_uri = self._layer.codeuri + if not code_uri: + raise MissingCodeUri() + self._code_uri = code_uri + + def get_path_handlers(self) -> List[PathHandler]: + """ + Returns + ------- + List[PathHandler] + PathHandlers for the code folder associated with the layer + """ + dir_path_handler = ResourceTrigger.get_dir_path_handler(self._code_uri) + dir_path_handler.self_create = self._on_code_change + dir_path_handler.self_delete = self._on_code_change + dir_path_handler.event_handler.on_any_event = self._on_code_change + return [dir_path_handler] + + +class DefinitionCodeTrigger(CodeResourceTrigger): + _validator: DefinitionValidator + _definition_file: str + + def __init__( + self, + resource_identifier: ResourceIdentifier, + resource_type: str, + stacks: List[Stack], + on_code_change: OnChangeCallback, + ): + """ + Parameters + ---------- + resource_identifier : ResourceIdentifier + ResourceIdentifier for the Resource + resource_type : str + Resource type + stacks : List[Stack] + List of stacks + on_code_change : OnChangeCallback + Callback when definition file is changed. + """ + super().__init__(resource_identifier, stacks, on_code_change) + self._resource_type = resource_type + self._definition_file = self._get_definition_file() + self._validator = DefinitionValidator(Path(self._definition_file)) + + def _get_definition_file(self) -> str: + """ + Returns + ------- + str + JSON/YAML definition file path + + Raises + ------ + MissingLocalDefinition + raised when resource property related to definition path is not specified. + """ + property_name = RESOURCES_WITH_LOCAL_PATHS[self._resource_type][0] + definition_file = self._resource.get("Properties", {}).get(property_name) + if not definition_file or not isinstance(definition_file, str): + raise MissingLocalDefinition(self._resource_identifier, property_name) + return definition_file + + def _validator_wrapper(self, event: Optional[FileSystemEvent] = None): + """Wrapper for callback that only executes if the definition is valid and non-trivial changes are detected. + + Parameters + ---------- + event : Optional[FileSystemEvent], optional + """ + if self._validator.validate(): + self._on_code_change(event) + + def get_path_handlers(self) -> List[PathHandler]: + """ + Returns + ------- + List[PathHandler] + A single PathHandler for watching the definition file. 
+ """ + file_path_handler = ResourceTrigger.get_single_file_path_handler(self._definition_file) + file_path_handler.event_handler.on_any_event = self._validator_wrapper + return [file_path_handler] diff --git a/samcli/lib/utils/resource_type_based_factory.py b/samcli/lib/utils/resource_type_based_factory.py new file mode 100644 index 0000000000..67a46f08af --- /dev/null +++ b/samcli/lib/utils/resource_type_based_factory.py @@ -0,0 +1,69 @@ +"""Base Factory Abstract Class for Creating Objects Specific to a Resource Type""" +import logging +from abc import ABC, abstractmethod +from typing import Callable, Dict, Generic, List, Optional, TypeVar + +from samcli.lib.providers.provider import ResourceIdentifier, Stack, get_resource_by_id + +LOG = logging.getLogger(__name__) + +T = TypeVar("T") # pylint: disable=invalid-name + + +class ResourceTypeBasedFactory(ABC, Generic[T]): + def __init__(self, stacks: List[Stack]) -> None: + self._stacks = stacks + + @abstractmethod + def _get_generator_mapping(self) -> Dict[str, Callable]: + """ + Returns + ------- + Dict[str, GeneratorFunction] + Mapping between resource type and generator function + """ + raise NotImplementedError() + + def _get_resource_type(self, resource_identifier: ResourceIdentifier) -> Optional[str]: + """Get resource type of the resource + + Parameters + ---------- + resource_identifier : ResourceIdentifier + + Returns + ------- + Optional[str] + Resource type of the resource + """ + resource = get_resource_by_id(self._stacks, resource_identifier) + if not resource: + LOG.debug("Resource %s does not exist.", str(resource_identifier)) + return None + + resource_type = resource.get("Type", None) + if not isinstance(resource_type, str): + LOG.debug("Resource %s has none string property Type.", str(resource_identifier)) + return None + return resource_type + + def _get_generator_function(self, resource_identifier: ResourceIdentifier) -> Optional[Callable]: + """Create an appropriate T object based on stack resource type + + Parameters + ---------- + resource_identifier : ResourceIdentifier + Resource identifier of the resource + + Returns + ------- + Optional[T] + Object T for the resource. Returns None if resource cannot be + found or have no associating T generator function. 
+ """ + resource_type = self._get_resource_type(resource_identifier) + if not resource_type: + LOG.debug("Resource %s has invalid property Type.", str(resource_identifier)) + return None + generator = self._get_generator_mapping().get(resource_type, None) + return generator diff --git a/samcli/commands/_utils/resources.py b/samcli/lib/utils/resources.py similarity index 84% rename from samcli/commands/_utils/resources.py rename to samcli/lib/utils/resources.py index ce448d6968..7e3cfb8943 100644 --- a/samcli/commands/_utils/resources.py +++ b/samcli/lib/utils/resources.py @@ -1,26 +1,49 @@ """ -Enums for Resources and thier Location Properties, along with utility functions +Enums for Resources and their Location Properties, along with utility functions """ from collections import defaultdict -AWS_SERVERLESSREPO_APPLICATION = "AWS::ServerlessRepo::Application" +# Lambda AWS_SERVERLESS_FUNCTION = "AWS::Serverless::Function" +AWS_SERVERLESS_LAYERVERSION = "AWS::Serverless::LayerVersion" + +AWS_LAMBDA_FUNCTION = "AWS::Lambda::Function" +AWS_LAMBDA_LAYERVERSION = "AWS::Lambda::LayerVersion" + +# APIGW AWS_SERVERLESS_API = "AWS::Serverless::Api" AWS_SERVERLESS_HTTPAPI = "AWS::Serverless::HttpApi" + +AWS_APIGATEWAY_RESTAPI = "AWS::ApiGateway::RestApi" +AWS_APIGATEWAY_STAGE = "AWS::ApiGateway::Stage" +AWS_APIGATEWAY_RESOURCE = "AWS::ApiGateway::Resource" +AWS_APIGATEWAY_METHOD = "AWS::ApiGateway::Method" + +AWS_APIGATEWAY_V2_API = "AWS::ApiGatewayV2::Api" +AWS_APIGATEWAY_V2_INTEGRATION = "AWS::ApiGatewayV2::Integration" +AWS_APIGATEWAY_V2_ROUTE = "AWS::ApiGatewayV2::Route" +AWS_APIGATEWAY_V2_STAGE = "AWS::ApiGatewayV2::Stage" + +# SFN +AWS_SERVERLESS_STATEMACHINE = "AWS::Serverless::StateMachine" + +AWS_STEPFUNCTIONS_STATEMACHINE = "AWS::StepFunctions::StateMachine" + +# Others +AWS_SERVERLESS_APPLICATION = "AWS::Serverless::Application" + +AWS_SERVERLESSREPO_APPLICATION = "AWS::ServerlessRepo::Application" AWS_APPSYNC_GRAPHQLSCHEMA = "AWS::AppSync::GraphQLSchema" AWS_APPSYNC_RESOLVER = "AWS::AppSync::Resolver" AWS_APPSYNC_FUNCTIONCONFIGURATION = "AWS::AppSync::FunctionConfiguration" -AWS_LAMBDA_FUNCTION = "AWS::Lambda::Function" -AWS_APIGATEWAY_RESTAPI = "AWS::ApiGateway::RestApi" AWS_ELASTICBEANSTALK_APPLICATIONVERSION = "AWS::ElasticBeanstalk::ApplicationVersion" AWS_CLOUDFORMATION_MODULEVERSION = "AWS::CloudFormation::ModuleVersion" AWS_CLOUDFORMATION_RESOURCEVERSION = "AWS::CloudFormation::ResourceVersion" AWS_CLOUDFORMATION_STACK = "AWS::CloudFormation::Stack" -AWS_SERVERLESS_APPLICATION = "AWS::Serverless::Application" -AWS_LAMBDA_LAYERVERSION = "AWS::Lambda::LayerVersion" -AWS_SERVERLESS_LAYERVERSION = "AWS::Serverless::LayerVersion" AWS_GLUE_JOB = "AWS::Glue::Job" +AWS_SQS_QUEUE = "AWS::SQS::Queue" +AWS_KINESIS_STREAM = "AWS::Kinesis::Stream" AWS_SERVERLESS_STATEMACHINE = "AWS::Serverless::StateMachine" AWS_STEPFUNCTIONS_STATEMACHINE = "AWS::StepFunctions::StateMachine" AWS_ECR_REPOSITORY = "AWS::ECR::Repository" diff --git a/tests/integration/buildcmd/test_build_cmd.py b/tests/integration/buildcmd/test_build_cmd.py index 09860e0212..3c426c02e2 100644 --- a/tests/integration/buildcmd/test_build_cmd.py +++ b/tests/integration/buildcmd/test_build_cmd.py @@ -1349,6 +1349,36 @@ def test_cache_build(self, use_container, code_uri, function1_handler, function2 expected_messages, command_result, self._make_parameter_override_arg(overrides) ) + @skipIf(SKIP_DOCKER_TESTS, SKIP_DOCKER_MESSAGE) + def test_cached_build_with_env_vars(self): + """ + Build 2 times to verify that second time hits the 
cached build + """ + overrides = { + "FunctionCodeUri": "Python", + "Function1Handler": "main.first_function_handler", + "Function2Handler": "main.second_function_handler", + "FunctionRuntime": "python3.8", + } + cmdlist = self.get_command_list( + use_container=True, parameter_overrides=overrides, cached=True, container_env_var="FOO=BAR" + ) + + LOG.info("Running Command (cache should be invalid): %s", cmdlist) + command_result = run_command(cmdlist, cwd=self.working_dir) + self.assertTrue( + "Cache is invalid, running build and copying resources to function build definition" + in command_result.stderr.decode("utf-8") + ) + + LOG.info("Re-Running Command (valid cache should exist): %s", cmdlist) + command_result_with_cache = run_command(cmdlist, cwd=self.working_dir) + + self.assertTrue( + "Valid cache found, copying previously built resources from function build definition" + in command_result_with_cache.stderr.decode("utf-8") + ) + @skipIf( ((IS_WINDOWS and RUNNING_ON_CI) and not CI_OVERRIDE), diff --git a/tests/unit/commands/_utils/test_options.py b/tests/unit/commands/_utils/test_options.py index ea82e5cdbf..02240c403d 100644 --- a/tests/unit/commands/_utils/test_options.py +++ b/tests/unit/commands/_utils/test_options.py @@ -17,6 +17,7 @@ _TEMPLATE_OPTION_DEFAULT_VALUE, guided_deploy_stack_name, artifact_callback, + parameterized_option, resolve_s3_callback, image_repositories_callback, _space_separated_list_func_type, @@ -463,3 +464,33 @@ class TestSpaceSeparatedListInvalidDataTypes: def test_raise_value_error(self, test_input): with pytest.raises(ValueError): _space_separated_list_func_type(test_input) + + +class TestParameterizedOption(TestCase): + @parameterized_option + def option_dec_with_value(f, value=2): + def wrapper(): + return f(value) + + return wrapper + + @parameterized_option + def option_dec_without_value(f, value=2): + def wrapper(): + return f(value) + + return wrapper + + @option_dec_with_value(5) + def some_function_with_value(value): + return value + 2 + + @option_dec_without_value + def some_function_without_value(value): + return value + 2 + + def test_option_dec_with_value(self): + self.assertEqual(TestParameterizedOption.some_function_with_value(), 7) + + def test_option_dec_without_value(self): + self.assertEqual(TestParameterizedOption.some_function_without_value(), 4) diff --git a/tests/unit/commands/_utils/test_template.py b/tests/unit/commands/_utils/test_template.py index 1de707ec38..c75db92c67 100644 --- a/tests/unit/commands/_utils/test_template.py +++ b/tests/unit/commands/_utils/test_template.py @@ -7,7 +7,7 @@ from botocore.utils import set_value_from_jmespath from parameterized import parameterized, param -from samcli.commands._utils.resources import AWS_SERVERLESS_FUNCTION, AWS_SERVERLESS_API +from samcli.lib.utils.resources import AWS_SERVERLESS_FUNCTION, AWS_SERVERLESS_API from samcli.commands._utils.template import ( get_template_data, METADATA_WITH_LOCAL_PATHS, diff --git a/tests/unit/commands/buildcmd/test_build_context.py b/tests/unit/commands/buildcmd/test_build_context.py index 3ab805a7ee..9576c95c0b 100644 --- a/tests/unit/commands/buildcmd/test_build_context.py +++ b/tests/unit/commands/buildcmd/test_build_context.py @@ -1,12 +1,26 @@ import os +from samcli.lib.build.app_builder import ApplicationBuilder from unittest import TestCase -from unittest.mock import patch, Mock, ANY +from unittest.mock import patch, Mock, ANY, call from parameterized import parameterized from samcli.local.lambdafn.exceptions import ResourceNotFound from 
samcli.commands.build.build_context import BuildContext from samcli.commands.build.exceptions import InvalidBuildDirException, MissingBuildMethodException +from samcli.commands.exceptions import UserException +from samcli.lib.build.app_builder import ( + BuildError, + UnsupportedBuilderLibraryVersionError, + BuildInsideContainerError, + ContainerBuildNotSupported, +) +from samcli.lib.build.workflow_config import UnsupportedRuntimeException +from samcli.local.lambdafn.exceptions import FunctionNotFound + + +class DeepWrap(Exception): + pass class TestBuildContext__enter__(TestCase): @@ -56,6 +70,7 @@ def test_must_setup_context( mode="buildmode", cached=False, cache_dir="cache_dir", + parallel=True, aws_region="any_aws_region", ) setup_build_dir_mock = Mock() @@ -134,6 +149,7 @@ def test_must_fail_with_illegal_identifier( mode="buildmode", cached=False, cache_dir="cache_dir", + parallel=True, ) setup_build_dir_mock = Mock() build_dir_result = setup_build_dir_mock.return_value = "my/new/build/dir" @@ -187,6 +203,7 @@ def test_must_return_only_layer_when_layer_is_build( mode="buildmode", cached=False, cache_dir="cache_dir", + parallel=True, ) setup_build_dir_mock = Mock() build_dir_result = setup_build_dir_mock.return_value = "my/new/build/dir" @@ -242,6 +259,7 @@ def test_must_return_buildable_dependent_layer_when_function_is_build( mode="buildmode", cached=False, cache_dir="cache_dir", + parallel=True, ) setup_build_dir_mock = Mock() build_dir_result = setup_build_dir_mock.return_value = "my/new/build/dir" @@ -297,6 +315,7 @@ def test_must_fail_when_layer_is_build_without_buildmethod( mode="buildmode", cached=False, cache_dir="cache_dir", + parallel=True, ) setup_build_dir_mock = Mock() build_dir_result = setup_build_dir_mock.return_value = "my/new/build/dir" @@ -365,6 +384,7 @@ def test_must_return_many_functions_to_build( mode="buildmode", cached=False, cache_dir="cache_dir", + parallel=True, ) setup_build_dir_mock = Mock() build_dir_result = setup_build_dir_mock.return_value = "my/new/build/dir" @@ -430,6 +450,7 @@ def test_must_print_remote_url_warning( mode="buildmode", cached=False, cache_dir="cache_dir", + parallel=True, ) context._setup_build_dir = Mock() @@ -566,6 +587,288 @@ def test_when_build_dir_is_cwd_raises_exception(self, pathlib_patch, os_patch, s pathlib_patch.Path.cwd.assert_called_once() +class TestBuildContext_run(TestCase): + @patch("samcli.commands.build.build_context.SamLocalStackProvider.get_stacks") + @patch("samcli.commands.build.build_context.SamFunctionProvider") + @patch("samcli.commands.build.build_context.SamLayerProvider") + @patch("samcli.commands.build.build_context.pathlib") + @patch("samcli.commands.build.build_context.ContainerManager") + @patch("samcli.commands.build.build_context.BuildContext._setup_build_dir") + @patch("samcli.commands.build.build_context.ApplicationBuilder") + @patch("samcli.commands.build.build_context.BuildContext.get_resources_to_build") + @patch("samcli.commands.build.build_context.move_template") + @patch("samcli.commands.build.build_context.os") + def test_run_build_context( + self, + os_mock, + move_template_mock, + resources_mock, + ApplicationBuilderMock, + build_dir_mock, + ContainerManagerMock, + pathlib_mock, + SamLayerProviderMock, + SamFunctionProviderMock, + get_buildable_stacks_mock, + ): + + root_stack = Mock() + root_stack.is_root_stack = True + root_stack.get_output_template_path = Mock(return_value="./build_dir/template.yaml") + child_stack = Mock() + child_stack.get_output_template_path = 
Mock(return_value="./build_dir/abcd/template.yaml") + stack_output_template_path_by_stack_path = { + root_stack.stack_path: "./build_dir/template.yaml", + child_stack.stack_path: "./build_dir/abcd/template.yaml", + } + resources_mock.return_value = Mock() + + builder_mock = ApplicationBuilderMock.return_value = Mock() + artifacts = builder_mock.build.return_value = "artifacts" + modified_template_root = "modified template 1" + modified_template_child = "modified template 2" + builder_mock.update_template.side_effect = [modified_template_root, modified_template_child] + + get_buildable_stacks_mock.return_value = ([root_stack, child_stack], []) + layer1 = DummyLayer("layer1", "python3.8") + layer_provider_mock = Mock() + layer_provider_mock.get.return_value = layer1 + layerprovider = SamLayerProviderMock.return_value = layer_provider_mock + func1 = DummyFunction("func1", [layer1]) + func_provider_mock = Mock() + func_provider_mock.get.return_value = func1 + funcprovider = SamFunctionProviderMock.return_value = func_provider_mock + base_dir = pathlib_mock.Path.return_value.resolve.return_value.parent = "basedir" + container_mgr_mock = ContainerManagerMock.return_value = Mock() + build_dir_mock.return_value = "build_dir" + + with BuildContext( + resource_identifier="function_identifier", + template_file="template_file", + base_dir="base_dir", + build_dir="build_dir", + cache_dir="cache_dir", + cached=False, + clean="clean", + use_container=False, + parallel="parallel", + parameter_overrides="parameter_overrides", + manifest_path="manifest_path", + docker_network="docker_network", + skip_pull_image="skip_pull_image", + mode="mode", + container_env_var={}, + container_env_var_file=None, + build_images={}, + ) as build_context: + build_context.run() + + ApplicationBuilderMock.assert_called_once_with( + ANY, + build_context.build_dir, + build_context.base_dir, + build_context.cache_dir, + build_context.cached, + build_context.is_building_specific_resource, + manifest_path_override=build_context.manifest_path_override, + container_manager=build_context.container_manager, + mode=build_context.mode, + parallel=build_context._parallel, + container_env_var=build_context._container_env_var, + container_env_var_file=build_context._container_env_var_file, + build_images=build_context._build_images, + ) + builder_mock.build.assert_called_once() + builder_mock.update_template.assert_has_calls( + [ + call( + root_stack, + artifacts, + stack_output_template_path_by_stack_path, + ) + ], + [ + call( + child_stack, + artifacts, + stack_output_template_path_by_stack_path, + ) + ], + ) + move_template_mock.assert_has_calls( + [ + call( + root_stack.location, + stack_output_template_path_by_stack_path[root_stack.stack_path], + modified_template_root, + ), + call( + child_stack.location, + stack_output_template_path_by_stack_path[child_stack.stack_path], + modified_template_child, + ), + ] + ) + + @parameterized.expand( + [ + (UnsupportedRuntimeException(), "UnsupportedRuntimeException"), + (BuildInsideContainerError(), "BuildInsideContainerError"), + (BuildError(wrapped_from=DeepWrap().__class__.__name__, msg="Test"), "DeepWrap"), + (ContainerBuildNotSupported(), "ContainerBuildNotSupported"), + ( + UnsupportedBuilderLibraryVersionError(container_name="name", error_msg="msg"), + "UnsupportedBuilderLibraryVersionError", + ), + ] + ) + @patch("samcli.commands.build.build_context.SamLocalStackProvider.get_stacks") + @patch("samcli.commands.build.build_context.SamFunctionProvider") + 
@patch("samcli.commands.build.build_context.SamLayerProvider") + @patch("samcli.commands.build.build_context.pathlib") + @patch("samcli.commands.build.build_context.ContainerManager") + @patch("samcli.commands.build.build_context.BuildContext._setup_build_dir") + @patch("samcli.commands.build.build_context.ApplicationBuilder") + @patch("samcli.commands.build.build_context.BuildContext.get_resources_to_build") + @patch("samcli.commands.build.build_context.move_template") + @patch("samcli.commands.build.build_context.os") + def test_must_catch_known_exceptions( + self, + exception, + wrapped_exception, + os_mock, + move_template_mock, + resources_mock, + ApplicationBuilderMock, + build_dir_mock, + ContainerManagerMock, + pathlib_mock, + SamLayerProviderMock, + SamFunctionProviderMock, + get_buildable_stacks_mock, + ): + + stack = Mock() + resources_mock.return_value = Mock() + + builder_mock = ApplicationBuilderMock.return_value = Mock() + artifacts = builder_mock.build.return_value = "artifacts" + modified_template_root = "modified template 1" + modified_template_child = "modified template 2" + builder_mock.update_template.side_effect = [modified_template_root, modified_template_child] + + get_buildable_stacks_mock.return_value = ([stack], []) + layer1 = DummyLayer("layer1", "python3.8") + layer_provider_mock = Mock() + layer_provider_mock.get.return_value = layer1 + layerprovider = SamLayerProviderMock.return_value = layer_provider_mock + func1 = DummyFunction("func1", [layer1]) + func_provider_mock = Mock() + func_provider_mock.get.return_value = func1 + funcprovider = SamFunctionProviderMock.return_value = func_provider_mock + base_dir = pathlib_mock.Path.return_value.resolve.return_value.parent = "basedir" + container_mgr_mock = ContainerManagerMock.return_value = Mock() + build_dir_mock.return_value = "build_dir" + + builder_mock.build.side_effect = exception + + with self.assertRaises(UserException) as ctx: + with BuildContext( + resource_identifier="function_identifier", + template_file="template_file", + base_dir="base_dir", + build_dir="build_dir", + cache_dir="cache_dir", + cached=False, + clean="clean", + use_container=False, + parallel="parallel", + parameter_overrides="parameter_overrides", + manifest_path="manifest_path", + docker_network="docker_network", + skip_pull_image="skip_pull_image", + mode="mode", + container_env_var={}, + container_env_var_file=None, + build_images={}, + ) as build_context: + build_context.run() + + self.assertEqual(str(ctx.exception), str(exception)) + self.assertEqual(wrapped_exception, ctx.exception.wrapped_from) + + @patch("samcli.commands.build.build_context.SamLocalStackProvider.get_stacks") + @patch("samcli.commands.build.build_context.SamFunctionProvider") + @patch("samcli.commands.build.build_context.SamLayerProvider") + @patch("samcli.commands.build.build_context.pathlib") + @patch("samcli.commands.build.build_context.ContainerManager") + @patch("samcli.commands.build.build_context.BuildContext._setup_build_dir") + @patch("samcli.commands.build.build_context.ApplicationBuilder") + @patch("samcli.commands.build.build_context.BuildContext.get_resources_to_build") + @patch("samcli.commands.build.build_context.move_template") + @patch("samcli.commands.build.build_context.os") + def test_must_catch_function_not_found_exception( + self, + os_mock, + move_template_mock, + resources_mock, + ApplicationBuilderMock, + build_dir_mock, + ContainerManagerMock, + pathlib_mock, + SamLayerProviderMock, + SamFunctionProviderMock, + 
get_buildable_stacks_mock, + ): + stack = Mock() + resources_mock.return_value = Mock() + + builder_mock = ApplicationBuilderMock.return_value = Mock() + artifacts = builder_mock.build.return_value = "artifacts" + modified_template_root = "modified template 1" + modified_template_child = "modified template 2" + builder_mock.update_template.side_effect = [modified_template_root, modified_template_child] + + get_buildable_stacks_mock.return_value = ([stack], []) + layer1 = DummyLayer("layer1", "python3.8") + layer_provider_mock = Mock() + layer_provider_mock.get.return_value = layer1 + layerprovider = SamLayerProviderMock.return_value = layer_provider_mock + func1 = DummyFunction("func1", [layer1]) + func_provider_mock = Mock() + func_provider_mock.get.return_value = func1 + funcprovider = SamFunctionProviderMock.return_value = func_provider_mock + base_dir = pathlib_mock.Path.return_value.resolve.return_value.parent = "basedir" + container_mgr_mock = ContainerManagerMock.return_value = Mock() + build_dir_mock.return_value = "build_dir" + + ApplicationBuilderMock.side_effect = FunctionNotFound("Function Not Found") + + with self.assertRaises(UserException) as ctx: + with BuildContext( + resource_identifier="function_identifier", + template_file="template_file", + base_dir="base_dir", + build_dir="build_dir", + cache_dir="cache_dir", + cached=False, + clean="clean", + use_container=False, + parallel="parallel", + parameter_overrides="parameter_overrides", + manifest_path="manifest_path", + docker_network="docker_network", + skip_pull_image="skip_pull_image", + mode="mode", + container_env_var={}, + container_env_var_file=None, + build_images={}, + ) as build_context: + build_context.run() + + self.assertEqual(str(ctx.exception), "Function Not Found") + + class DummyLayer: def __init__(self, name, build_method, codeuri="layer_src"): self.name = name diff --git a/tests/unit/commands/buildcmd/test_command.py b/tests/unit/commands/buildcmd/test_command.py index 3d6f296d0a..3cc894d03c 100644 --- a/tests/unit/commands/buildcmd/test_command.py +++ b/tests/unit/commands/buildcmd/test_command.py @@ -2,53 +2,20 @@ import click from unittest import TestCase -from unittest.mock import Mock, patch, call +from unittest.mock import Mock, patch from parameterized import parameterized from samcli.commands.build.command import do_cli, _get_mode_value_from_envvar, _process_env_var, _process_image_options -from samcli.commands.exceptions import UserException -from samcli.lib.build.app_builder import ( - BuildError, - UnsupportedBuilderLibraryVersionError, - BuildInsideContainerError, - ContainerBuildNotSupported, -) -from samcli.lib.build.workflow_config import UnsupportedRuntimeException -from samcli.local.lambdafn.exceptions import FunctionNotFound - - -class DeepWrap(Exception): - pass class TestDoCli(TestCase): + @patch("samcli.commands.build.command.click") @patch("samcli.commands.build.build_context.BuildContext") - @patch("samcli.lib.build.app_builder.ApplicationBuilder") - @patch("samcli.commands._utils.template.move_template") @patch("samcli.commands.build.command.os") - def test_must_succeed_build(self, os_mock, move_template_mock, ApplicationBuilderMock, BuildContextMock): + def test_must_succeed_build(self, os_mock, BuildContextMock, mock_build_click): ctx_mock = Mock() - - # create stack mocks - root_stack = Mock() - root_stack.is_root_stack = True - root_stack.get_output_template_path = Mock(return_value="./build_dir/template.yaml") - child_stack = Mock() - child_stack.get_output_template_path = 
Mock(return_value="./build_dir/abcd/template.yaml") - ctx_mock.stacks = [root_stack, child_stack] - stack_output_template_path_by_stack_path = { - root_stack.stack_path: "./build_dir/template.yaml", - child_stack.stack_path: "./build_dir/abcd/template.yaml", - } - - BuildContextMock.return_value.__enter__ = Mock() BuildContextMock.return_value.__enter__.return_value = ctx_mock - builder_mock = ApplicationBuilderMock.return_value = Mock() - artifacts = builder_mock.build.return_value = "artifacts" - modified_template_root = "modified template 1" - modified_template_child = "modified template 2" - builder_mock.update_template.side_effect = [modified_template_root, modified_template_child] do_cli( ctx_mock, @@ -63,7 +30,7 @@ def test_must_succeed_build(self, os_mock, move_template_mock, ApplicationBuilde "parallel", "manifest_path", "docker_network", - "skip_pull", + "skip_pull_image", "parameter_overrides", "mode", (""), @@ -71,132 +38,28 @@ def test_must_succeed_build(self, os_mock, move_template_mock, ApplicationBuilde (), ) - ApplicationBuilderMock.assert_called_once_with( - ctx_mock.resources_to_build, - ctx_mock.build_dir, - ctx_mock.base_dir, - ctx_mock.cache_dir, - ctx_mock.cached, - ctx_mock.is_building_specific_resource, - manifest_path_override=ctx_mock.manifest_path_override, - container_manager=ctx_mock.container_manager, - mode=ctx_mock.mode, + BuildContextMock.assert_called_with( + "function_identifier", + "template", + "base_dir", + "build_dir", + "cache_dir", + "cached", + clean="clean", + use_container="use_container", parallel="parallel", + parameter_overrides="parameter_overrides", + manifest_path="manifest_path", + docker_network="docker_network", + skip_pull_image="skip_pull_image", + mode="mode", container_env_var={}, container_env_var_file="container_env_var_file", build_images={}, + aws_region=ctx_mock.region, ) - builder_mock.build.assert_called_once() - builder_mock.update_template.assert_has_calls( - [ - call( - root_stack, - artifacts, - stack_output_template_path_by_stack_path, - ) - ], - [ - call( - child_stack, - artifacts, - stack_output_template_path_by_stack_path, - ) - ], - ) - move_template_mock.assert_has_calls( - [ - call( - root_stack.location, - stack_output_template_path_by_stack_path[root_stack.stack_path], - modified_template_root, - ), - call( - child_stack.location, - stack_output_template_path_by_stack_path[child_stack.stack_path], - modified_template_child, - ), - ] - ) - - @parameterized.expand( - [ - (UnsupportedRuntimeException(), "UnsupportedRuntimeException"), - (BuildInsideContainerError(), "BuildInsideContainerError"), - (BuildError(wrapped_from=DeepWrap().__class__.__name__, msg="Test"), "DeepWrap"), - (ContainerBuildNotSupported(), "ContainerBuildNotSupported"), - ( - UnsupportedBuilderLibraryVersionError(container_name="name", error_msg="msg"), - "UnsupportedBuilderLibraryVersionError", - ), - ] - ) - @patch("samcli.commands.build.build_context.BuildContext") - @patch("samcli.lib.build.app_builder.ApplicationBuilder") - def test_must_catch_known_exceptions(self, exception, wrapped_exception, ApplicationBuilderMock, BuildContextMock): - - ctx_mock = Mock() - BuildContextMock.return_value.__enter__ = Mock() - BuildContextMock.return_value.__enter__.return_value = ctx_mock - builder_mock = ApplicationBuilderMock.return_value = Mock() - - builder_mock.build.side_effect = exception - - with self.assertRaises(UserException) as ctx: - do_cli( - ctx_mock, - "function_identifier", - "template", - "base_dir", - "build_dir", - "cache_dir", - 
"clean", - "use_container", - "cached", - "parallel", - "manifest_path", - "docker_network", - "skip_pull", - "parameteroverrides", - "mode", - (""), - "container_env_var_file", - (), - ) - - self.assertEqual(str(ctx.exception), str(exception)) - self.assertEqual(wrapped_exception, ctx.exception.wrapped_from) - - @patch("samcli.commands.build.build_context.BuildContext") - @patch("samcli.lib.build.app_builder.ApplicationBuilder") - def test_must_catch_function_not_found_exception(self, ApplicationBuilderMock, BuildContextMock): - ctx_mock = Mock() - BuildContextMock.return_value.__enter__ = Mock() - BuildContextMock.return_value.__enter__.return_value = ctx_mock - ApplicationBuilderMock.side_effect = FunctionNotFound("Function Not Found") - - with self.assertRaises(UserException) as ctx: - do_cli( - ctx_mock, - "function_identifier", - "template", - "base_dir", - "build_dir", - "cache_dir", - "clean", - "use_container", - "cached", - "parallel", - "manifest_path", - "docker_network", - "skip_pull", - "parameteroverrides", - "mode", - (""), - "container_env_var_file", - (), - ) - - self.assertEqual(str(ctx.exception), "Function Not Found") + ctx_mock.run.assert_called_with() + self.assertEqual(ctx_mock.run.call_count, 1) class TestGetModeValueFromEnvvar(TestCase): diff --git a/tests/unit/commands/deploy/test_command.py b/tests/unit/commands/deploy/test_command.py index 46ac917e06..289d89b9f0 100644 --- a/tests/unit/commands/deploy/test_command.py +++ b/tests/unit/commands/deploy/test_command.py @@ -47,6 +47,7 @@ def setUp(self): self.config_env = "mock-default-env" self.config_file = "mock-default-filename" self.signing_profiles = None + self.use_changeset = True self.resolve_image_repos = False MOCK_SAM_CONFIG.reset_mock() @@ -121,6 +122,7 @@ def test_all_args(self, mock_deploy_context, mock_deploy_click, mock_package_con profile=self.profile, confirm_changeset=self.confirm_changeset, signing_profiles=self.signing_profiles, + use_changeset=self.use_changeset, ) context_mock.run.assert_called_with() @@ -323,6 +325,7 @@ def test_all_args_guided( profile=self.profile, confirm_changeset=True, signing_profiles=self.signing_profiles, + use_changeset=self.use_changeset, ) context_mock.run.assert_called_with() @@ -466,6 +469,7 @@ def test_all_args_guided_no_save_echo_param_to_config( profile=self.profile, confirm_changeset=True, signing_profiles=self.signing_profiles, + use_changeset=self.use_changeset, ) context_mock.run.assert_called_with() @@ -613,6 +617,7 @@ def test_all_args_guided_no_params_save_config( profile=self.profile, confirm_changeset=True, signing_profiles=self.signing_profiles, + use_changeset=self.use_changeset, ) context_mock.run.assert_called_with() @@ -745,6 +750,7 @@ def test_all_args_guided_no_params_no_save_config( profile=self.profile, confirm_changeset=True, signing_profiles=self.signing_profiles, + use_changeset=self.use_changeset, ) context_mock.run.assert_called_with() @@ -815,6 +821,7 @@ def test_all_args_resolve_s3( profile=self.profile, confirm_changeset=self.confirm_changeset, signing_profiles=self.signing_profiles, + use_changeset=self.use_changeset, ) context_mock.run.assert_called_with() @@ -922,6 +929,7 @@ def test_all_args_resolve_image_repos( profile=self.profile, confirm_changeset=self.confirm_changeset, signing_profiles=self.signing_profiles, + use_changeset=True, ) context_mock.run.assert_called_with() diff --git a/tests/unit/commands/deploy/test_deploy_context.py b/tests/unit/commands/deploy/test_deploy_context.py index fdc2b49d7f..2bddf0f14a 100644 --- 
a/tests/unit/commands/deploy/test_deploy_context.py +++ b/tests/unit/commands/deploy/test_deploy_context.py @@ -31,6 +31,7 @@ def setUp(self): profile=None, confirm_changeset=False, signing_profiles=None, + use_changeset=True, ) def test_template_improper(self): @@ -152,3 +153,62 @@ def test_template_valid_execute_changeset_with_parameters( patched_get_buildable_stacks.assert_called_once_with( ANY, parameter_overrides={"a": "b"}, global_parameter_overrides={"AWS::Region": "any-aws-region"} ) + + @patch("boto3.Session") + @patch("samcli.commands.deploy.deploy_context.auth_per_resource") + @patch("samcli.commands.deploy.deploy_context.SamLocalStackProvider.get_stacks") + @patch.object(Deployer, "sync", MagicMock()) + def test_sync(self, patched_get_buildable_stacks, patched_auth_required, patched_boto): + sync_context = DeployContext( + template_file="template-file", + stack_name="stack-name", + s3_bucket="s3-bucket", + image_repository="image-repo", + image_repositories=None, + force_upload=True, + no_progressbar=False, + s3_prefix="s3-prefix", + kms_key_id="kms-key-id", + parameter_overrides={"a": "b"}, + capabilities="CAPABILITY_IAM", + no_execute_changeset=False, + role_arn="role-arn", + notification_arns=[], + fail_on_empty_changeset=False, + tags={"a": "b"}, + region=None, + profile=None, + confirm_changeset=False, + signing_profiles=None, + use_changeset=False, + ) + patched_get_buildable_stacks.return_value = (Mock(), []) + patched_auth_required.return_value = [("HelloWorldFunction", False)] + with tempfile.NamedTemporaryFile(delete=False) as template_file: + template_file.write(b'{"Parameters": {"a":"b","c":"d"}}') + template_file.flush() + sync_context.template_file = template_file.name + sync_context.run() + + self.assertEqual(sync_context.deployer.sync.call_count, 1) + print(sync_context.deployer.sync.call_args[1]) + self.assertEqual( + sync_context.deployer.sync.call_args[1]["stack_name"], + "stack-name", + ) + self.assertEqual( + sync_context.deployer.sync.call_args[1]["capabilities"], + "CAPABILITY_IAM", + ) + self.assertEqual( + sync_context.deployer.sync.call_args[1]["cfn_template"], + '{"Parameters": {"a":"b","c":"d"}}', + ) + self.assertEqual( + sync_context.deployer.sync.call_args[1]["notification_arns"], + [], + ) + self.assertEqual( + sync_context.deployer.sync.call_args[1]["role_arn"], + "role-arn", + ) diff --git a/tests/unit/commands/local/lib/test_provider.py b/tests/unit/commands/local/lib/test_provider.py index e84f9bc176..ee316e708c 100644 --- a/tests/unit/commands/local/lib/test_provider.py +++ b/tests/unit/commands/local/lib/test_provider.py @@ -1,11 +1,20 @@ import os from unittest import TestCase -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock, patch from parameterized import parameterized -from samcli.lib.providers.provider import LayerVersion, Stack, _get_build_dir +from samcli.lib.providers.provider import ( + LayerVersion, + ResourceIdentifier, + Stack, + _get_build_dir, + get_all_resource_ids, + get_resource_by_id, + get_resource_ids_by_type, + get_unique_resource_ids, +) from samcli.commands.local.cli_common.user_exceptions import InvalidLayerVersionArn, UnsupportedIntrinsic @@ -93,3 +102,256 @@ def test_layer_version_raises_unsupported_intrinsic(self): with self.assertRaises(UnsupportedIntrinsic): LayerVersion(intrinsic_arn, ".") + + +class TestResourceIdentifier(TestCase): + @parameterized.expand( + [ + ("Function1", "", "Function1"), + ("NestedStack1/Function1", "NestedStack1", "Function1"), + 
("NestedStack1/NestedNestedStack2/Function1", "NestedStack1/NestedNestedStack2", "Function1"), + ("", "", ""), + ] + ) + def test_parser(self, resource_identifier_string, stack_path, logical_id): + resource_identifier = ResourceIdentifier(resource_identifier_string) + self.assertEqual(resource_identifier.stack_path, stack_path) + self.assertEqual(resource_identifier.logical_id, logical_id) + + @parameterized.expand( + [ + ("Function1", "Function1", True), + ("NestedStack1/Function1", "NestedStack1/Function1", True), + ("NestedStack1/NestedNestedStack2/Function1", "NestedStack1/NestedNestedStack2/Function2", False), + ("NestedStack1/NestedNestedStack3/Function1", "NestedStack1/NestedNestedStack2/Function1", False), + ("", "", True), + ] + ) + def test_equal(self, resource_identifier_string_1, resource_identifier_string_2, equal): + resource_identifier_1 = ResourceIdentifier(resource_identifier_string_1) + resource_identifier_2 = ResourceIdentifier(resource_identifier_string_2) + self.assertEqual(resource_identifier_1 == resource_identifier_2, equal) + + @parameterized.expand( + [ + ("Function1"), + ("NestedStack1/Function1"), + ("NestedStack1/NestedNestedStack2/Function1"), + ] + ) + def test_hash(self, resource_identifier_string): + resource_identifier_1 = ResourceIdentifier(resource_identifier_string) + resource_identifier_2 = ResourceIdentifier(resource_identifier_string) + self.assertEqual(hash(resource_identifier_1), hash(resource_identifier_2)) + + @parameterized.expand( + [ + ("Function1"), + ("NestedStack1/Function1"), + ("NestedStack1/NestedNestedStack2/Function1"), + (""), + ] + ) + def test_str(self, resource_identifier_string): + resource_identifier = ResourceIdentifier(resource_identifier_string) + self.assertEqual(str(resource_identifier), resource_identifier_string) + + +class TestGetResourceByID(TestCase): + def setUp(self) -> None: + super().setUp() + self.root_stack = MagicMock() + self.root_stack.stack_path = "" + self.root_stack.resources = {"Function1": "Body1"} + + self.nested_stack = MagicMock() + self.nested_stack.stack_path = "NestedStack1" + self.nested_stack.resources = {"Function1": "Body2"} + + self.nested_nested_stack = MagicMock() + self.nested_nested_stack.stack_path = "NestedStack1/NestedNestedStack1" + self.nested_nested_stack.resources = {"Function2": "Body3"} + + def test_get_resource_by_id_explicit_root( + self, + ): + + resource_identifier = MagicMock() + resource_identifier.stack_path = "" + resource_identifier.logical_id = "Function1" + + result = get_resource_by_id( + [self.root_stack, self.nested_stack, self.nested_nested_stack], resource_identifier, True + ) + self.assertEqual(result, self.root_stack.resources["Function1"]) + + def test_get_resource_by_id_explicit_nested( + self, + ): + + resource_identifier = MagicMock() + resource_identifier.stack_path = "NestedStack1" + resource_identifier.logical_id = "Function1" + + result = get_resource_by_id( + [self.root_stack, self.nested_stack, self.nested_nested_stack], resource_identifier, True + ) + self.assertEqual(result, self.nested_stack.resources["Function1"]) + + def test_get_resource_by_id_explicit_nested_nested( + self, + ): + + resource_identifier = MagicMock() + resource_identifier.stack_path = "NestedStack1/NestedNestedStack1" + resource_identifier.logical_id = "Function2" + + result = get_resource_by_id( + [self.root_stack, self.nested_stack, self.nested_nested_stack], resource_identifier, True + ) + self.assertEqual(result, self.nested_nested_stack.resources["Function2"]) + + def 
test_get_resource_by_id_implicit_root( + self, + ): + + resource_identifier = MagicMock() + resource_identifier.stack_path = "" + resource_identifier.logical_id = "Function1" + + result = get_resource_by_id( + [self.root_stack, self.nested_stack, self.nested_nested_stack], resource_identifier, False + ) + self.assertEqual(result, self.root_stack.resources["Function1"]) + + def test_get_resource_by_id_implicit_nested( + self, + ): + + resource_identifier = MagicMock() + resource_identifier.stack_path = "" + resource_identifier.logical_id = "Function2" + + result = get_resource_by_id( + [self.root_stack, self.nested_stack, self.nested_nested_stack], resource_identifier, False + ) + self.assertEqual(result, self.nested_nested_stack.resources["Function2"]) + + def test_get_resource_by_id_implicit_with_stack_path( + self, + ): + + resource_identifier = MagicMock() + resource_identifier.stack_path = "NestedStack1" + resource_identifier.logical_id = "Function1" + + result = get_resource_by_id( + [self.root_stack, self.nested_stack, self.nested_nested_stack], resource_identifier, False + ) + self.assertEqual(result, self.nested_stack.resources["Function1"]) + + def test_get_resource_by_id_not_found( + self, + ): + + resource_identifier = MagicMock() + resource_identifier.logical_id = "Function3" + + result = get_resource_by_id( + [self.root_stack, self.nested_stack, self.nested_nested_stack], resource_identifier, False + ) + self.assertEqual(result, None) + + +class TestGetResourceIDsByType(TestCase): + def setUp(self) -> None: + super().setUp() + self.root_stack = MagicMock() + self.root_stack.stack_path = "" + self.root_stack.resources = {"Function1": {"Type": "TypeA"}} + + self.nested_stack = MagicMock() + self.nested_stack.stack_path = "NestedStack1" + self.nested_stack.resources = {"Function1": {"Type": "TypeA"}} + + self.nested_nested_stack = MagicMock() + self.nested_nested_stack.stack_path = "NestedStack1/NestedNestedStack1" + self.nested_nested_stack.resources = {"Function2": {"Type": "TypeB"}} + + def test_get_resource_ids_by_type_single_nested( + self, + ): + result = get_resource_ids_by_type([self.root_stack, self.nested_stack, self.nested_nested_stack], "TypeB") + self.assertEqual(result, [ResourceIdentifier("NestedStack1/NestedNestedStack1/Function2")]) + + def test_get_resource_ids_by_type_multiple_nested( + self, + ): + result = get_resource_ids_by_type([self.root_stack, self.nested_stack, self.nested_nested_stack], "TypeA") + self.assertEqual(result, [ResourceIdentifier("Function1"), ResourceIdentifier("NestedStack1/Function1")]) + + +class TestGetAllResourceIDs(TestCase): + def setUp(self) -> None: + super().setUp() + self.root_stack = MagicMock() + self.root_stack.stack_path = "" + self.root_stack.resources = {"Function1": {"Type": "TypeA"}} + + self.nested_stack = MagicMock() + self.nested_stack.stack_path = "NestedStack1" + self.nested_stack.resources = {"Function1": {"Type": "TypeA"}} + + self.nested_nested_stack = MagicMock() + self.nested_nested_stack.stack_path = "NestedStack1/NestedNestedStack1" + self.nested_nested_stack.resources = {"Function2": {"Type": "TypeB"}} + + def test_get_all_resource_ids( + self, + ): + result = get_all_resource_ids([self.root_stack, self.nested_stack, self.nested_nested_stack]) + self.assertEqual( + result, + [ + ResourceIdentifier("Function1"), + ResourceIdentifier("NestedStack1/Function1"), + ResourceIdentifier("NestedStack1/NestedNestedStack1/Function2"), + ], + ) + + +class TestGetUniqueResourceIDs(TestCase): + def setUp(self) -> None: + 
super().setUp() + self.stacks = MagicMock() + + @patch("samcli.lib.providers.provider.get_resource_ids_by_type") + def test_only_resource_ids(self, get_resource_ids_by_type_mock): + resource_ids = ["Function1", "Function2"] + resource_types = [] + get_resource_ids_by_type_mock.return_value = {} + result = get_unique_resource_ids(self.stacks, resource_ids, resource_types) + get_resource_ids_by_type_mock.assert_not_called() + self.assertEqual(result, {ResourceIdentifier("Function1"), ResourceIdentifier("Function2")}) + + @patch("samcli.lib.providers.provider.get_resource_ids_by_type") + def test_only_resource_types(self, get_resource_ids_by_type_mock): + resource_ids = [] + resource_types = ["Type1", "Type2"] + get_resource_ids_by_type_mock.return_value = {ResourceIdentifier("Function1"), ResourceIdentifier("Function2")} + result = get_unique_resource_ids(self.stacks, resource_ids, resource_types) + get_resource_ids_by_type_mock.assert_any_call(self.stacks, "Type1") + get_resource_ids_by_type_mock.assert_any_call(self.stacks, "Type2") + self.assertEqual(result, {ResourceIdentifier("Function1"), ResourceIdentifier("Function2")}) + + @patch("samcli.lib.providers.provider.get_resource_ids_by_type") + def test_duplicates(self, get_resource_ids_by_type_mock): + resource_ids = ["Function1", "Function2"] + resource_types = ["Type1", "Type2"] + get_resource_ids_by_type_mock.return_value = {ResourceIdentifier("Function2"), ResourceIdentifier("Function3")} + result = get_unique_resource_ids(self.stacks, resource_ids, resource_types) + get_resource_ids_by_type_mock.assert_any_call(self.stacks, "Type1") + get_resource_ids_by_type_mock.assert_any_call(self.stacks, "Type2") + self.assertEqual( + result, {ResourceIdentifier("Function1"), ResourceIdentifier("Function2"), ResourceIdentifier("Function3")} + ) diff --git a/tests/unit/commands/local/lib/test_stack_provider.py b/tests/unit/commands/local/lib/test_stack_provider.py index 64315df7c0..faf778eb4c 100644 --- a/tests/unit/commands/local/lib/test_stack_provider.py +++ b/tests/unit/commands/local/lib/test_stack_provider.py @@ -6,7 +6,7 @@ from parameterized import parameterized -from samcli.commands._utils.resources import AWS_SERVERLESS_APPLICATION, AWS_CLOUDFORMATION_STACK +from samcli.lib.utils.resources import AWS_SERVERLESS_APPLICATION, AWS_CLOUDFORMATION_STACK from samcli.lib.providers.provider import Stack from samcli.lib.providers.sam_stack_provider import SamLocalStackProvider diff --git a/tests/unit/commands/logs/test_command.py b/tests/unit/commands/logs/test_command.py index 3a48600ae0..79b132e9a9 100644 --- a/tests/unit/commands/logs/test_command.py +++ b/tests/unit/commands/logs/test_command.py @@ -1,5 +1,7 @@ from unittest import TestCase -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, call, ANY + +from parameterized import parameterized from samcli.commands.logs.command import do_cli @@ -12,47 +14,109 @@ def setUp(self): self.filter_pattern = "filter" self.start_time = "start" self.end_time = "end" + self.output_dir = "output_dir" + self.region = "region" + + @parameterized.expand( + [ + ( + True, + False, + [], + ), + ( + False, + False, + [], + ), + ( + True, + False, + ["cw_log_group"], + ), + ( + False, + False, + ["cw_log_group", "cw_log_group2"], + ), + ] + ) + @patch("samcli.commands.logs.puller_factory.generate_puller") + @patch("samcli.commands.logs.logs_context.ResourcePhysicalIdResolver") + @patch("samcli.commands.logs.logs_context.parse_time") + 
@patch("samcli.lib.utils.boto_utils.get_boto_client_provider_with_config") + @patch("samcli.lib.utils.boto_utils.get_boto_resource_provider_with_config") + def test_logs_command( + self, + tailing, + include_tracing, + cw_log_group, + patched_boto_resource_provider, + patched_boto_client_provider, + patched_parse_time, + patched_resource_physical_id_resolver, + patched_generate_puller, + ): + mocked_start_time = Mock() + mocked_end_time = Mock() + patched_parse_time.side_effect = [mocked_start_time, mocked_end_time] + + mocked_resource_physical_id_resolver = Mock() + mocked_resource_information = Mock() + mocked_resource_physical_id_resolver.get_resource_information.return_value = mocked_resource_information + patched_resource_physical_id_resolver.return_value = mocked_resource_physical_id_resolver - @patch("samcli.commands.logs.logs_context.LogsCommandContext") - def test_without_tail(self, logs_command_context_mock): - tailing = False + mocked_puller = Mock() + patched_generate_puller.return_value = mocked_puller - context_mock = Mock() - logs_command_context_mock.return_value.__enter__.return_value = context_mock + mocked_client_provider = Mock() + patched_boto_client_provider.return_value = mocked_client_provider - do_cli(self.function_name, self.stack_name, self.filter_pattern, tailing, self.start_time, self.end_time) + mocked_resource_provider = Mock() + patched_boto_resource_provider.return_value = mocked_resource_provider - logs_command_context_mock.assert_called_with( + do_cli( self.function_name, - stack_name=self.stack_name, - filter_pattern=self.filter_pattern, - start_time=self.start_time, - end_time=self.end_time, + self.stack_name, + self.filter_pattern, + tailing, + include_tracing, + self.start_time, + self.end_time, + cw_log_group, + self.output_dir, + self.region, ) - context_mock.fetcher.load_time_period.assert_called_with( - filter_pattern=context_mock.filter_pattern, - start_time=context_mock.start_time, - end_time=context_mock.end_time, + patched_parse_time.assert_has_calls( + [ + call(self.start_time, "start-time"), + call(self.end_time, "end-time"), + ] ) - @patch("samcli.commands.logs.logs_context.LogsCommandContext") - def test_with_tailing(self, logs_command_context_mock): - tailing = True + patched_boto_client_provider.assert_called_with(region_name=self.region) + patched_boto_resource_provider.assert_called_with(region_name=self.region) - context_mock = Mock() - logs_command_context_mock.return_value.__enter__.return_value = context_mock + patched_resource_physical_id_resolver.assert_called_with( + mocked_resource_provider, self.stack_name, self.function_name + ) - do_cli(self.function_name, self.stack_name, self.filter_pattern, tailing, self.start_time, self.end_time) + fetch_param = not bool(len(cw_log_group)) + mocked_resource_physical_id_resolver.assert_has_calls([call.get_resource_information(fetch_param)]) - logs_command_context_mock.assert_called_with( - self.function_name, - stack_name=self.stack_name, - filter_pattern=self.filter_pattern, - start_time=self.start_time, - end_time=self.end_time, + patched_generate_puller.assert_called_with( + mocked_client_provider, + mocked_resource_information, + self.filter_pattern, + cw_log_group, + self.output_dir, + False, ) - context_mock.fetcher.tail.assert_called_with( - filter_pattern=context_mock.filter_pattern, start_time=context_mock.start_time - ) + if tailing: + mocked_puller.assert_has_calls([call.tail(mocked_start_time, self.filter_pattern)]) + else: + mocked_puller.assert_has_calls( + 
[call.load_time_period(mocked_start_time, mocked_end_time, self.filter_pattern)] + ) diff --git a/tests/unit/commands/logs/test_console_consumers.py b/tests/unit/commands/logs/test_console_consumers.py index ab824ca769..bfb4a6ba13 100644 --- a/tests/unit/commands/logs/test_console_consumers.py +++ b/tests/unit/commands/logs/test_console_consumers.py @@ -1,15 +1,31 @@ from unittest import TestCase from unittest.mock import patch, Mock +from parameterized import parameterized + from samcli.commands.logs.console_consumers import CWConsoleEventConsumer class TestCWConsoleEventConsumer(TestCase): - def setUp(self): - self.consumer = CWConsoleEventConsumer() + @parameterized.expand( + [ + (True,), + (False,), + ] + ) + @patch("samcli.commands.logs.console_consumers.click") + def test_consumer_with_event(self, add_newline, patched_click): + consumer = CWConsoleEventConsumer(add_newline) + event = Mock() + consumer.consume(event) + + expected_new_line_param = add_newline if add_newline is not None else True + patched_click.echo.assert_called_with(event.message, nl=expected_new_line_param) @patch("samcli.commands.logs.console_consumers.click") - def test_consume_with_event(self, patched_click): + def test_default_consumer_with_event(self, patched_click): + consumer = CWConsoleEventConsumer() event = Mock() - self.consumer.consume(event) + consumer.consume(event) + patched_click.echo.assert_called_with(event.message, nl=False) diff --git a/tests/unit/commands/logs/test_logs_context.py b/tests/unit/commands/logs/test_logs_context.py index abcd792b27..050ae9ef91 100644 --- a/tests/unit/commands/logs/test_logs_context.py +++ b/tests/unit/commands/logs/test_logs_context.py @@ -1,11 +1,14 @@ -from unittest import TestCase -from unittest.mock import Mock, patch, ANY - -import botocore.session -from botocore.stub import Stubber +from unittest import TestCase, mock +from unittest.mock import Mock, patch from samcli.commands.exceptions import UserException -from samcli.commands.logs.logs_context import LogsCommandContext +from samcli.commands.logs.logs_context import parse_time, ResourcePhysicalIdResolver +from samcli.lib.utils.cloudformation import CloudFormationResourceSummary + +AWS_SOME_RESOURCE = "AWS::Some::Resource" +AWS_LAMBDA_FUNCTION = "AWS::Lambda::Function" +AWS_APIGATEWAY_RESTAPI = "AWS::ApiGateway::RestApi" +AWS_APIGATEWAY_HTTPAPI = "AWS::ApiGatewayV2::Api" class TestLogsCommandContext(TestCase): @@ -17,214 +20,110 @@ def setUp(self): self.end_time = "end" self.output_file = "somefile" - self.context = LogsCommandContext( - self.function_name, - stack_name=self.stack_name, - filter_pattern=self.filter_pattern, - start_time=self.start_time, - end_time=self.end_time, - output_file=self.output_file, - ) - - def test_basic_properties(self): - self.assertEqual(self.context.filter_pattern, self.filter_pattern) - self.assertIsNone(self.context.output_file_handle) # before setting context handle will be null - - @patch("samcli.commands.logs.logs_context.Colored") - def test_colored_property(self, ColoredMock): - ColoredMock.return_value = Mock() - - self.assertEqual(self.context.colored, ColoredMock.return_value) - ColoredMock.assert_called_with(colorize=False) - - @patch("samcli.commands.logs.logs_context.Colored") - def test_colored_property_without_output_file(self, ColoredMock): - ColoredMock.return_value = Mock() - - # No output file. It means we are printing to Terminal. 
Hence set the color - ctx = LogsCommandContext( - self.function_name, - stack_name=self.stack_name, - filter_pattern=self.filter_pattern, - start_time=self.start_time, - end_time=self.end_time, - output_file=None, - ) - - self.assertEqual(ctx.colored, ColoredMock.return_value) - ColoredMock.assert_called_with(colorize=True) # Must enable colors - - @patch("samcli.commands.logs.logs_context.LogGroupProvider") - @patch.object(LogsCommandContext, "_get_resource_id_from_stack") - def test_log_group_name_property_with_stack_name(self, get_resource_id_mock, LogGroupProviderMock): - logical_id = "someid" - group = "groupname" - - LogGroupProviderMock.for_lambda_function.return_value = group - get_resource_id_mock.return_value = logical_id - - self.assertEqual(self.context.log_group_name, group) - - LogGroupProviderMock.for_lambda_function.assert_called_with(logical_id) - get_resource_id_mock.assert_called_with(ANY, self.stack_name, self.function_name) - - @patch("samcli.commands.logs.logs_context.LogGroupProvider") - @patch.object(LogsCommandContext, "_get_resource_id_from_stack") - def test_log_group_name_property_without_stack_name(self, get_resource_id_mock, LogGroupProviderMock): - group = "groupname" - - LogGroupProviderMock.for_lambda_function.return_value = group - - ctx = LogsCommandContext( - self.function_name, - stack_name=None, # No Stack Name - filter_pattern=self.filter_pattern, - start_time=self.start_time, - end_time=self.end_time, - output_file=self.output_file, - ) - - self.assertEqual(ctx.log_group_name, group) - - LogGroupProviderMock.for_lambda_function.assert_called_with(self.function_name) - get_resource_id_mock.assert_not_called() - - def test_start_time_property(self): - self.context._parse_time = Mock() - self.context._parse_time.return_value = "foo" - - self.assertEqual(self.context.start_time, "foo") - - def test_end_time_property(self): - self.context._parse_time = Mock() - self.context._parse_time.return_value = "foo" - - self.assertEqual(self.context.end_time, "foo") - @patch("samcli.commands.logs.logs_context.parse_date") @patch("samcli.commands.logs.logs_context.to_utc") def test_parse_time(self, to_utc_mock, parse_date_mock): - input = "some time" + given_input = "some time" parsed_result = "parsed" expected = "bar" parse_date_mock.return_value = parsed_result to_utc_mock.return_value = expected - actual = LogsCommandContext._parse_time(input, "some prop") + actual = parse_time(given_input, "some prop") self.assertEqual(actual, expected) - parse_date_mock.assert_called_with(input) + parse_date_mock.assert_called_with(given_input) to_utc_mock.assert_called_with(parsed_result) @patch("samcli.commands.logs.logs_context.parse_date") def test_parse_time_raises_exception(self, parse_date_mock): - input = "some time" + given_input = "some time" parsed_result = None parse_date_mock.return_value = parsed_result with self.assertRaises(UserException) as ctx: - LogsCommandContext._parse_time(input, "some prop") + parse_time(given_input, "some prop") self.assertEqual(str(ctx.exception), "Unable to parse the time provided by 'some prop'") def test_parse_time_empty_time(self): - result = LogsCommandContext._parse_time(None, "some prop") + result = parse_time(None, "some prop") self.assertIsNone(result) - @patch("samcli.commands.logs.logs_context.open") - def test_setup_output_file(self, open_mock): - - open_mock.return_value = "handle" - result = LogsCommandContext._setup_output_file(self.output_file) - - self.assertEqual(result, "handle") - 
open_mock.assert_called_with(self.output_file, "wb") - - def test_setup_output_file_without_file(self): - self.assertIsNone(LogsCommandContext._setup_output_file(None)) - - @patch.object(LogsCommandContext, "_setup_output_file") - def test_context_manager_with_output_file(self, setup_output_file_mock): - handle = Mock() - setup_output_file_mock.return_value = handle - - with LogsCommandContext( - self.function_name, - stack_name=self.stack_name, - filter_pattern=self.filter_pattern, - start_time=self.start_time, - end_time=self.end_time, - output_file=self.output_file, - ) as context: - self.assertEqual(context._output_file_handle, handle) - - # Context should be reset - self.assertIsNone(self.context._output_file_handle) - handle.close.assert_called_with() - setup_output_file_mock.assert_called_with(self.output_file) - - @patch.object(LogsCommandContext, "_setup_output_file") - def test_context_manager_no_output_file(self, setup_output_file_mock): - setup_output_file_mock.return_value = None - - with LogsCommandContext( - self.function_name, - stack_name=self.stack_name, - filter_pattern=self.filter_pattern, - start_time=self.start_time, - end_time=self.end_time, - output_file=None, - ) as context: - self.assertEqual(context._output_file_handle, None) - - # Context should be reset - setup_output_file_mock.assert_called_with(None) - - -class TestLogsCommandContext_get_resource_id_from_stack(TestCase): - def setUp(self): - - self.real_client = botocore.session.get_session().create_client("cloudformation", region_name="us-east-1") - self.cfn_client_stubber = Stubber(self.real_client) - - self.logical_id = "name" - self.stack_name = "stackname" - self.physical_id = "myid" - - def test_must_get_from_cfn(self): - - expected_params = {"StackName": self.stack_name, "LogicalResourceId": self.logical_id} - - mock_response = { - "StackResourceDetail": { - "PhysicalResourceId": self.physical_id, - "LogicalResourceId": self.logical_id, - "ResourceType": "AWS::Lambda::Function", - "ResourceStatus": "UPDATE_COMPLETE", - "LastUpdatedTimestamp": "2017-07-28T23:34:13.435Z", - } - } - - self.cfn_client_stubber.add_response("describe_stack_resource", mock_response, expected_params) - - with self.cfn_client_stubber: - result = LogsCommandContext._get_resource_id_from_stack(self.real_client, self.stack_name, self.logical_id) - - self.assertEqual(result, self.physical_id) - - def test_must_handle_resource_not_found(self): - errmsg = "Something went wrong" - errcode = "SomeException" - - self.cfn_client_stubber.add_client_error( - "describe_stack_resource", service_error_code=errcode, service_message=errmsg - ) - expected_error_msg = "An error occurred ({}) when calling the DescribeStackResource operation: {}".format( - errcode, errmsg - ) - - with self.cfn_client_stubber: - with self.assertRaises(UserException) as context: - LogsCommandContext._get_resource_id_from_stack(self.real_client, self.stack_name, self.logical_id) - self.assertEqual(expected_error_msg, str(context.exception)) +class TestResourcePhysicalIdResolver(TestCase): + def test_get_resource_information_with_resources(self): + resource_physical_id_resolver = ResourcePhysicalIdResolver(Mock(), "stack_name", ["resource_name"]) + with mock.patch( + "samcli.commands.logs.logs_context.ResourcePhysicalIdResolver._fetch_resources_from_stack" + ) as mocked_fetch: + expected_return = Mock() + mocked_fetch.return_value = expected_return + + actual_return = resource_physical_id_resolver.get_resource_information(False) + + mocked_fetch.assert_called_once() + 
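# mock.patch used as a context manager yields the replacement MagicMock, so the
# stubbed _fetch_resources_from_stack return value can be asserted against directly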
self.assertEqual(actual_return, expected_return) + + def test_get_resource_information_of_all_stack(self): + resource_physical_id_resolver = ResourcePhysicalIdResolver(Mock(), "stack_name", []) + with mock.patch( + "samcli.commands.logs.logs_context.ResourcePhysicalIdResolver._fetch_resources_from_stack" + ) as mocked_fetch: + expected_return = Mock() + mocked_fetch.return_value = expected_return + + actual_return = resource_physical_id_resolver.get_resource_information(True) + + mocked_fetch.assert_called_once() + self.assertEqual(actual_return, expected_return) + + def test_get_no_resource_information(self): + resource_physical_id_resolver = ResourcePhysicalIdResolver(Mock(), "stack_name", None) + actual_return = resource_physical_id_resolver.get_resource_information(False) + self.assertEqual(actual_return, []) + + @patch("samcli.commands.logs.logs_context.get_resource_summaries") + def test_fetch_all_resources(self, patched_get_resources): + resource_physical_id_resolver = ResourcePhysicalIdResolver(Mock(), "stack_name", []) + mocked_return_value = [ + CloudFormationResourceSummary(AWS_LAMBDA_FUNCTION, "logical_id_1", "physical_id_1"), + CloudFormationResourceSummary(AWS_LAMBDA_FUNCTION, "logical_id_2", "physical_id_2"), + CloudFormationResourceSummary(AWS_APIGATEWAY_RESTAPI, "logical_id_3", "physical_id_3"), + CloudFormationResourceSummary(AWS_APIGATEWAY_HTTPAPI, "logical_id_4", "physical_id_4"), + ] + patched_get_resources.return_value = mocked_return_value + + actual_result = resource_physical_id_resolver._fetch_resources_from_stack() + self.assertEqual(len(actual_result), 4) + + expected_results = [ + item + for item in mocked_return_value + if item.resource_type in ResourcePhysicalIdResolver.DEFAULT_SUPPORTED_RESOURCES + ] + self.assertEqual(expected_results, actual_result) + + @patch("samcli.commands.logs.logs_context.get_resource_summaries") + def test_fetch_given_resources(self, patched_get_resources): + given_resources = ["logical_id_1", "logical_id_2", "logical_id_3", "logical_id_5", "logical_id_6"] + resource_physical_id_resolver = ResourcePhysicalIdResolver(Mock(), "stack_name", given_resources) + mocked_return_value = [ + CloudFormationResourceSummary(AWS_LAMBDA_FUNCTION, "logical_id_1", "physical_id_1"), + CloudFormationResourceSummary(AWS_LAMBDA_FUNCTION, "logical_id_2", "physical_id_2"), + CloudFormationResourceSummary(AWS_LAMBDA_FUNCTION, "logical_id_3", "physical_id_3"), + CloudFormationResourceSummary(AWS_APIGATEWAY_RESTAPI, "logical_id_4", "physical_id_4"), + CloudFormationResourceSummary(AWS_APIGATEWAY_HTTPAPI, "logical_id_5", "physical_id_5"), + ] + patched_get_resources.return_value = mocked_return_value + + actual_result = resource_physical_id_resolver._fetch_resources_from_stack(set(given_resources)) + self.assertEqual(len(actual_result), 4) + + expected_results = [ + item + for item in mocked_return_value + if item.resource_type in ResourcePhysicalIdResolver.DEFAULT_SUPPORTED_RESOURCES + and item.logical_resource_id in given_resources + ] + self.assertEqual(expected_results, actual_result) diff --git a/tests/unit/commands/logs/test_puller_factory.py b/tests/unit/commands/logs/test_puller_factory.py new file mode 100644 index 0000000000..bf4f6dd143 --- /dev/null +++ b/tests/unit/commands/logs/test_puller_factory.py @@ -0,0 +1,258 @@ +from unittest import TestCase +from unittest.mock import Mock, patch, call, ANY + +from parameterized import parameterized + +from samcli.lib.utils.resources import AWS_LAMBDA_FUNCTION +from samcli.commands.logs.puller_factory 
import ( + generate_puller, + generate_unformatted_consumer, + generate_console_consumer, + NoPullerGeneratedException, + generate_consumer, +) + + +class TestPullerFactory(TestCase): + @parameterized.expand( + [ + (None, None, False), + ("filter_pattern", None, False), + ("filter_pattern", ["cw_log_groups"], False), + ("filter_pattern", ["cw_log_groups"], True), + (None, ["cw_log_groups"], True), + (None, None, True), + ] + ) + @patch("samcli.commands.logs.puller_factory.generate_console_consumer") + @patch("samcli.commands.logs.puller_factory.generate_unformatted_consumer") + @patch("samcli.commands.logs.puller_factory.CWLogPuller") + @patch("samcli.commands.logs.puller_factory.generate_trace_puller") + @patch("samcli.commands.logs.puller_factory.ObservabilityCombinedPuller") + def test_generate_puller( + self, + param_filter_pattern, + param_cw_log_groups, + param_unformatted, + patched_combined_puller, + patched_xray_puller, + patched_cw_log_puller, + patched_unformatted_consumer, + patched_console_consumer, + ): + mock_logs_client = Mock() + mock_xray_client = Mock() + + mock_client_provider = lambda client_name: mock_logs_client if client_name == "logs" else mock_xray_client + + mock_resource_info_list = [ + Mock(resource_type=AWS_LAMBDA_FUNCTION), + Mock(resource_type=AWS_LAMBDA_FUNCTION), + Mock(resource_type=AWS_LAMBDA_FUNCTION), + ] + + mocked_resource_consumers = [Mock() for _ in mock_resource_info_list] + mocked_cw_specific_consumers = [Mock() for _ in (param_cw_log_groups or [])] + mocked_consumers = mocked_resource_consumers + mocked_cw_specific_consumers + + # depending on the unformatted flag, stub either the unformatted consumer or the console consumer + if param_unformatted: + patched_unformatted_consumer.side_effect = mocked_consumers + else: + patched_console_consumer.side_effect = mocked_consumers + + mocked_xray_puller = Mock() + patched_xray_puller.return_value = mocked_xray_puller + mocked_pullers = [Mock() for _ in mocked_consumers] + mocked_pullers.append(mocked_xray_puller) # appended so the combined-puller assertion below also sees the xray puller + patched_cw_log_puller.side_effect = mocked_pullers + + mocked_combined_puller = Mock() + + patched_combined_puller.return_value = mocked_combined_puller + + puller = generate_puller( + mock_client_provider, + mock_resource_info_list, + param_filter_pattern, + param_cw_log_groups, + param_unformatted, + True, + ) + + self.assertEqual(puller, mocked_combined_puller) + + patched_xray_puller.assert_called_once_with(mock_xray_client, param_unformatted) + + patched_cw_log_puller.assert_has_calls( + [call(mock_logs_client, consumer, ANY, ANY) for consumer in mocked_resource_consumers] + ) + + patched_cw_log_puller.assert_has_calls( + [call(mock_logs_client, consumer, ANY) for consumer in mocked_cw_specific_consumers] + ) + + patched_combined_puller.assert_called_with(mocked_pullers) + + # depending on the unformatted flag, assert calls on the unformatted consumer or the console consumer + if param_unformatted: + patched_unformatted_consumer.assert_has_calls([call() for _ in mocked_consumers]) + else: + patched_console_consumer.assert_has_calls([call(param_filter_pattern) for _ in mocked_consumers]) + + def test_puller_with_invalid_resource_type(self): + mock_logs_client = Mock() + mock_resource_information = Mock() + mock_resource_information.get_log_group_name.return_value = None + + with self.assertRaises(NoPullerGeneratedException): + generate_puller(mock_logs_client, [mock_resource_information]) + + @patch("samcli.commands.logs.puller_factory.generate_console_consumer") +
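# exercises the path where pullers are created solely from extra CloudWatch log
# groups, with no stack resources in the list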
@patch("samcli.commands.logs.puller_factory.CWLogPuller") + @patch("samcli.commands.logs.puller_factory.ObservabilityCombinedPuller") + def test_generate_puller_with_console_with_additional_cw_logs_groups( + self, patched_combined_puller, patched_cw_log_puller, patched_console_consumer + ): + mock_logs_client = Mock() + mock_logs_client_generator = lambda client: mock_logs_client + mock_cw_log_groups = [Mock(), Mock(), Mock()] + + mocked_consumers = [Mock() for _ in mock_cw_log_groups] + patched_console_consumer.side_effect = mocked_consumers + + mocked_pullers = [Mock() for _ in mock_cw_log_groups] + patched_cw_log_puller.side_effect = mocked_pullers + + mocked_combined_puller = Mock() + patched_combined_puller.return_value = mocked_combined_puller + + puller = generate_puller(mock_logs_client_generator, [], additional_cw_log_groups=mock_cw_log_groups) + + self.assertEqual(puller, mocked_combined_puller) + + patched_cw_log_puller.assert_has_calls([call(mock_logs_client, consumer, ANY) for consumer in mocked_consumers]) + + patched_combined_puller.assert_called_with(mocked_pullers) + + patched_console_consumer.assert_has_calls([call(None) for _ in mock_cw_log_groups]) + + @parameterized.expand( + [ + (False,), + (True,), + ] + ) + @patch("samcli.commands.logs.puller_factory.generate_unformatted_consumer") + @patch("samcli.commands.logs.puller_factory.generate_console_consumer") + def test_generate_consumer(self, param_unformatted, patched_console_consumer, patched_unformatted_consumer): + given_filter_pattern = Mock() + given_resource_name = Mock() + + given_console_consumer = Mock() + patched_console_consumer.return_value = given_console_consumer + given_file_consumer = Mock() + patched_unformatted_consumer.return_value = given_file_consumer + + actual_consumer = generate_consumer(given_filter_pattern, param_unformatted, given_resource_name) + + if param_unformatted: + patched_unformatted_consumer.assert_called_with() + self.assertEqual(actual_consumer, given_file_consumer) + else: + patched_console_consumer.assert_called_with(given_filter_pattern) + self.assertEqual(actual_consumer, given_console_consumer) + + @patch("samcli.commands.logs.puller_factory.ObservabilityEventConsumerDecorator") + @patch("samcli.commands.logs.puller_factory.CWLogEventJSONMapper") + @patch("samcli.commands.logs.puller_factory.CWConsoleEventConsumer") + def test_generate_unformatted_consumer( + self, + patched_event_consumer, + patched_json_formatter, + patched_decorated_consumer, + ): + expected_consumer = Mock() + patched_decorated_consumer.return_value = expected_consumer + + expected_event_consumer = Mock() + patched_event_consumer.return_value = expected_event_consumer + + expected_json_formatter = Mock() + patched_json_formatter.return_value = expected_json_formatter + + consumer = generate_unformatted_consumer() + + self.assertEqual(expected_consumer, consumer) + + patched_decorated_consumer.assert_called_with([expected_json_formatter], expected_event_consumer) + patched_event_consumer.assert_called_with(True) + patched_json_formatter.assert_called_once() + + @patch("samcli.commands.logs.puller_factory.Colored") + @patch("samcli.commands.logs.puller_factory.ObservabilityEventConsumerDecorator") + @patch("samcli.commands.logs.puller_factory.CWColorizeErrorsFormatter") + @patch("samcli.commands.logs.puller_factory.CWJsonFormatter") + @patch("samcli.commands.logs.puller_factory.CWKeywordHighlighterFormatter") + @patch("samcli.commands.logs.puller_factory.CWPrettyPrintFormatter") + 
@patch("samcli.commands.logs.puller_factory.CWAddNewLineIfItDoesntExist") + @patch("samcli.commands.logs.puller_factory.CWConsoleEventConsumer") + def test_generate_console_consumer( + self, + patched_event_consumer, + patched_new_line_mapper, + patched_pretty_formatter, + patched_highlighter, + patched_json_formatter, + patched_errors_formatter, + patched_decorated_consumer, + patched_colored, + ): + mock_filter_pattern = Mock() + + expected_colored = Mock() + patched_colored.return_value = expected_colored + + expected_errors_formatter = Mock() + patched_errors_formatter.return_value = expected_errors_formatter + + expected_json_formatter = Mock() + patched_json_formatter.return_value = expected_json_formatter + + expected_highlighter = Mock() + patched_highlighter.return_value = expected_highlighter + + expected_pretty_formatter = Mock() + patched_pretty_formatter.return_value = expected_pretty_formatter + + expected_new_line_mapper = Mock() + patched_new_line_mapper.return_value = expected_new_line_mapper + + expected_event_consumer = Mock() + patched_event_consumer.return_value = expected_event_consumer + + expected_consumer = Mock() + patched_decorated_consumer.return_value = expected_consumer + + consumer = generate_console_consumer(mock_filter_pattern) + + self.assertEqual(expected_consumer, consumer) + + patched_colored.assert_called_once() + patched_event_consumer.assert_called_once() + patched_new_line_mapper.assert_called_once() + patched_pretty_formatter.assert_called_with(expected_colored) + patched_highlighter.assert_called_with(expected_colored, mock_filter_pattern) + patched_json_formatter.assert_called_once() + patched_errors_formatter.assert_called_with(expected_colored) + + patched_decorated_consumer.assert_called_with( + [ + expected_errors_formatter, + expected_json_formatter, + expected_highlighter, + expected_pretty_formatter, + expected_new_line_mapper, + ], + expected_event_consumer, + ) diff --git a/tests/unit/commands/samconfig/test_samconfig.py b/tests/unit/commands/samconfig/test_samconfig.py index 1d943c169c..79d6ba246c 100644 --- a/tests/unit/commands/samconfig/test_samconfig.py +++ b/tests/unit/commands/samconfig/test_samconfig.py @@ -536,10 +536,14 @@ def test_package_with_image_repository_and_image_repositories( self.assertIsNotNone(result.exception) @patch("samcli.lib.cli_validation.image_repository_validation.get_template_artifacts_format") + @patch("samcli.commands._utils.template.get_template_artifacts_format") + @patch("samcli.commands._utils.options.get_template_artifacts_format") @patch("samcli.commands.deploy.command.do_cli") - def test_deploy(self, do_cli_mock, get_template_artifacts_format_mock): + def test_deploy(self, do_cli_mock, template_artifacts_mock1, template_artifacts_mock2, template_artifacts_mock3): - get_template_artifacts_format_mock.return_value = [ZIP] + template_artifacts_mock1.return_value = [ZIP] + template_artifacts_mock2.return_value = [ZIP] + template_artifacts_mock3.return_value = [ZIP] config_values = { "template_file": "mytemplate.yaml", "stack_name": "mystack", @@ -644,10 +648,16 @@ def test_deploy_image_repositories_and_image_repository(self, do_cli_mock): self.assertIsNotNone(result.exception) @patch("samcli.lib.cli_validation.image_repository_validation.get_template_artifacts_format") + @patch("samcli.commands._utils.options.get_template_artifacts_format") + @patch("samcli.commands._utils.template.get_template_artifacts_format") @patch("samcli.commands.deploy.command.do_cli") - def 
test_deploy_different_parameter_override_format(self, do_cli_mock, get_template_artifacts_format_mock): + def test_deploy_different_parameter_override_format( + self, do_cli_mock, template_artifacts_mock1, template_artifacts_mock2, template_artifacts_mock3 + ): - get_template_artifacts_format_mock.return_value = [ZIP] + template_artifacts_mock1.return_value = [ZIP] + template_artifacts_mock2.return_value = [ZIP] + template_artifacts_mock3.return_value = [ZIP] config_values = { "template_file": "mytemplate.yaml", @@ -719,12 +729,15 @@ def test_deploy_different_parameter_override_format(self, do_cli_mock, get_templ @patch("samcli.commands.logs.command.do_cli") def test_logs(self, do_cli_mock): config_values = { - "name": "myfunction", + "name": ["myfunction"], "stack_name": "mystack", "filter": "myfilter", "tail": True, + "include_traces": True, "start_time": "starttime", "end_time": "endtime", + "cw_log_group": ["cw_log_group"], + "region": "myregion", } with samconfig_parameters(["logs"], self.scratch_dir, **config_values) as config_path: @@ -740,7 +753,18 @@ def test_logs(self, do_cli_mock): LOG.exception("Command failed", exc_info=result.exc_info) self.assertIsNone(result.exception) - do_cli_mock.assert_called_with("myfunction", "mystack", "myfilter", True, "starttime", "endtime") + do_cli_mock.assert_called_with( + ("myfunction",), + "mystack", + "myfilter", + True, + True, + "starttime", + "endtime", + ("cw_log_group",), + False, + "myregion", + ) @patch("samcli.commands.publish.command.do_cli") def test_publish(self, do_cli_mock): diff --git a/tests/unit/commands/sync/__init__.py b/tests/unit/commands/sync/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/commands/sync/test_command.py b/tests/unit/commands/sync/test_command.py new file mode 100644 index 0000000000..f67931e2d6 --- /dev/null +++ b/tests/unit/commands/sync/test_command.py @@ -0,0 +1,556 @@ +from unittest import TestCase +from unittest.mock import ANY, MagicMock, Mock, patch +from parameterized import parameterized + +from samcli.commands.sync.command import do_cli, execute_code_sync, execute_watch +from samcli.lib.providers.provider import ResourceIdentifier +from samcli.commands._utils.options import DEFAULT_BUILD_DIR, DEFAULT_CACHE_DIR + + +def get_mock_sam_config(): + mock_sam_config = MagicMock() + mock_sam_config.exists = MagicMock(return_value=True) + return mock_sam_config + + +MOCK_SAM_CONFIG = get_mock_sam_config() + + +class TestDoCli(TestCase): + def setUp(self): + + self.template_file = "input-template-file" + self.stack_name = "stack-name" + self.resource_id = [] + self.resource = [] + self.image_repository = "123456789012.dkr.ecr.us-east-1.amazonaws.com/test1" + self.image_repositories = None + self.mode = "mode" + self.s3_prefix = "s3-prefix" + self.kms_key_id = "kms-key-id" + self.notification_arns = [] + self.parameter_overrides = {"a": "b"} + self.capabilities = ("CAPABILITY_IAM",) + self.tags = {"c": "d"} + self.role_arn = "role_arn" + self.metadata = {} + self.region = None + self.profile = None + self.base_dir = None + self.clean = True + self.config_env = "mock-default-env" + self.config_file = "mock-default-filename" + MOCK_SAM_CONFIG.reset_mock() + + @parameterized.expand([(True, False, False), (False, False, False)]) + @patch("samcli.commands.sync.command.execute_code_sync") + @patch("samcli.commands.build.command.click") + @patch("samcli.commands.build.build_context.BuildContext") + @patch("samcli.commands.package.command.click") + 
@patch("samcli.commands.package.package_context.PackageContext") + @patch("samcli.commands.deploy.command.click") + @patch("samcli.commands.deploy.deploy_context.DeployContext") + @patch("samcli.commands.build.command.os") + @patch("samcli.commands.sync.command.manage_stack") + def test_infra_must_succeed_sync( + self, + infra, + code, + watch, + manage_stack_mock, + os_mock, + DeployContextMock, + mock_deploy_click, + PackageContextMock, + mock_package_click, + BuildContextMock, + mock_build_click, + execute_code_sync_mock, + ): + + build_context_mock = Mock() + BuildContextMock.return_value.__enter__.return_value = build_context_mock + package_context_mock = Mock() + PackageContextMock.return_value.__enter__.return_value = package_context_mock + deploy_context_mock = Mock() + DeployContextMock.return_value.__enter__.return_value = deploy_context_mock + + do_cli( + self.template_file, + infra, + code, + watch, + self.resource_id, + self.resource, + self.stack_name, + self.region, + self.profile, + self.base_dir, + self.parameter_overrides, + self.mode, + self.image_repository, + self.image_repositories, + self.s3_prefix, + self.kms_key_id, + self.capabilities, + self.role_arn, + self.notification_arns, + self.tags, + self.metadata, + self.config_file, + self.config_env, + ) + + BuildContextMock.assert_called_with( + resource_identifier=None, + template_file=self.template_file, + base_dir=self.base_dir, + build_dir=DEFAULT_BUILD_DIR, + cache_dir=DEFAULT_CACHE_DIR, + clean=True, + use_container=False, + parallel=True, + parameter_overrides=self.parameter_overrides, + mode=self.mode, + cached=True, + ) + + PackageContextMock.assert_called_with( + template_file=ANY, + s3_bucket=ANY, + image_repository=self.image_repository, + image_repositories=self.image_repositories, + s3_prefix=self.s3_prefix, + kms_key_id=self.kms_key_id, + output_template_file=ANY, + no_progressbar=True, + metadata=self.metadata, + region=self.region, + profile=self.profile, + use_json=False, + force_upload=True, + ) + + DeployContextMock.assert_called_with( + template_file=ANY, + stack_name=self.stack_name, + s3_bucket=ANY, + image_repository=self.image_repository, + image_repositories=self.image_repositories, + no_progressbar=True, + s3_prefix=self.s3_prefix, + kms_key_id=self.kms_key_id, + parameter_overrides=self.parameter_overrides, + capabilities=self.capabilities, + role_arn=self.role_arn, + notification_arns=self.notification_arns, + tags=self.tags, + region=self.region, + profile=self.profile, + no_execute_changeset=True, + fail_on_empty_changeset=True, + confirm_changeset=False, + use_changeset=False, + force_upload=True, + signing_profiles=None, + ) + package_context_mock.run.assert_called_once_with() + deploy_context_mock.run.assert_called_once_with() + execute_code_sync_mock.assert_not_called() + + @parameterized.expand([(False, False, True)]) + @patch("samcli.commands.sync.command.execute_watch") + @patch("samcli.commands.build.command.click") + @patch("samcli.commands.build.build_context.BuildContext") + @patch("samcli.commands.package.command.click") + @patch("samcli.commands.package.package_context.PackageContext") + @patch("samcli.commands.deploy.command.click") + @patch("samcli.commands.deploy.deploy_context.DeployContext") + @patch("samcli.commands.build.command.os") + @patch("samcli.commands.sync.command.manage_stack") + def test_watch_must_succeed_sync( + self, + infra, + code, + watch, + manage_stack_mock, + os_mock, + DeployContextMock, + mock_deploy_click, + PackageContextMock, + 
mock_package_click, + BuildContextMock, + mock_build_click, + execute_watch_mock, + ): + + build_context_mock = Mock() + BuildContextMock.return_value.__enter__.return_value = build_context_mock + package_context_mock = Mock() + PackageContextMock.return_value.__enter__.return_value = package_context_mock + deploy_context_mock = Mock() + DeployContextMock.return_value.__enter__.return_value = deploy_context_mock + + do_cli( + self.template_file, + infra, + code, + watch, + self.resource_id, + self.resource, + self.stack_name, + self.region, + self.profile, + self.base_dir, + self.parameter_overrides, + self.mode, + self.image_repository, + self.image_repositories, + self.s3_prefix, + self.kms_key_id, + self.capabilities, + self.role_arn, + self.notification_arns, + self.tags, + self.metadata, + self.config_file, + self.config_env, + ) + + BuildContextMock.assert_called_with( + resource_identifier=None, + template_file=self.template_file, + base_dir=self.base_dir, + build_dir=DEFAULT_BUILD_DIR, + cache_dir=DEFAULT_CACHE_DIR, + clean=True, + use_container=False, + parallel=True, + parameter_overrides=self.parameter_overrides, + mode=self.mode, + cached=True, + ) + + PackageContextMock.assert_called_with( + template_file=ANY, + s3_bucket=ANY, + image_repository=self.image_repository, + image_repositories=self.image_repositories, + s3_prefix=self.s3_prefix, + kms_key_id=self.kms_key_id, + output_template_file=ANY, + no_progressbar=True, + metadata=self.metadata, + region=self.region, + profile=self.profile, + use_json=False, + force_upload=True, + ) + + DeployContextMock.assert_called_with( + template_file=ANY, + stack_name=self.stack_name, + s3_bucket=ANY, + image_repository=self.image_repository, + image_repositories=self.image_repositories, + no_progressbar=True, + s3_prefix=self.s3_prefix, + kms_key_id=self.kms_key_id, + parameter_overrides=self.parameter_overrides, + capabilities=self.capabilities, + role_arn=self.role_arn, + notification_arns=self.notification_arns, + tags=self.tags, + region=self.region, + profile=self.profile, + no_execute_changeset=True, + fail_on_empty_changeset=True, + confirm_changeset=False, + use_changeset=False, + force_upload=True, + signing_profiles=None, + ) + execute_watch_mock.assert_called_once_with( + self.template_file, build_context_mock, package_context_mock, deploy_context_mock + ) + + @parameterized.expand([(False, True, False)]) + @patch("samcli.commands.sync.command.execute_code_sync") + @patch("samcli.commands.build.command.click") + @patch("samcli.commands.build.build_context.BuildContext") + @patch("samcli.commands.package.command.click") + @patch("samcli.commands.package.package_context.PackageContext") + @patch("samcli.commands.deploy.command.click") + @patch("samcli.commands.deploy.deploy_context.DeployContext") + @patch("samcli.commands.build.command.os") + @patch("samcli.commands.sync.command.manage_stack") + def test_code_must_succeed_sync( + self, + infra, + code, + watch, + manage_stack_mock, + os_mock, + DeployContextMock, + mock_deploy_click, + PackageContextMock, + mock_package_click, + BuildContextMock, + mock_build_click, + execute_code_sync_mock, + ): + + build_context_mock = Mock() + BuildContextMock.return_value.__enter__.return_value = build_context_mock + package_context_mock = Mock() + PackageContextMock.return_value.__enter__.return_value = package_context_mock + deploy_context_mock = Mock() + DeployContextMock.return_value.__enter__.return_value = deploy_context_mock + + do_cli( + self.template_file, + infra, + code, + 
watch, + self.resource_id, + self.resource, + self.stack_name, + self.region, + self.profile, + self.base_dir, + self.parameter_overrides, + self.mode, + self.image_repository, + self.image_repositories, + self.s3_prefix, + self.kms_key_id, + self.capabilities, + self.role_arn, + self.notification_arns, + self.tags, + self.metadata, + self.config_file, + self.config_env, + ) + execute_code_sync_mock.assert_called_once_with( + self.template_file, build_context_mock, deploy_context_mock, self.resource_id, self.resource + ) + + +class TestSyncCode(TestCase): + def setUp(self) -> None: + self.template_file = "template.yaml" + self.build_context = MagicMock() + self.deploy_context = MagicMock() + + @patch("samcli.commands.sync.command.SamLocalStackProvider.get_stacks") + @patch("samcli.commands.sync.command.SyncFlowFactory") + @patch("samcli.commands.sync.command.SyncFlowExecutor") + @patch("samcli.commands.sync.command.get_unique_resource_ids") + def test_execute_code_sync_single_resource( + self, + get_unique_resource_ids_mock, + sync_flow_executor_mock, + sync_flow_factory_mock, + get_stacks_mock, + ): + + resource_identifier_strings = ["Function1"] + resource_types = [] + sync_flows = [MagicMock()] + sync_flow_factory_mock.return_value.create_sync_flow.side_effect = sync_flows + get_unique_resource_ids_mock.return_value = { + ResourceIdentifier("Function1"), + } + + execute_code_sync( + self.template_file, self.build_context, self.deploy_context, resource_identifier_strings, resource_types + ) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_called_once_with(ResourceIdentifier("Function1")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_called_once_with(sync_flows[0]) + + get_unique_resource_ids_mock.assert_called_once_with( + get_stacks_mock.return_value[0], resource_identifier_strings, [] + ) + + @patch("samcli.commands.sync.command.SamLocalStackProvider.get_stacks") + @patch("samcli.commands.sync.command.SyncFlowFactory") + @patch("samcli.commands.sync.command.SyncFlowExecutor") + @patch("samcli.commands.sync.command.get_unique_resource_ids") + def test_execute_code_sync_multiple_resource( + self, + get_unique_resource_ids_mock, + sync_flow_executor_mock, + sync_flow_factory_mock, + get_stacks_mock, + ): + + resource_identifier_strings = ["Function1", "Function2"] + resource_types = [] + sync_flows = [MagicMock(), MagicMock()] + sync_flow_factory_mock.return_value.create_sync_flow.side_effect = sync_flows + get_unique_resource_ids_mock.return_value = { + ResourceIdentifier("Function1"), + ResourceIdentifier("Function2"), + } + + execute_code_sync( + self.template_file, self.build_context, self.deploy_context, resource_identifier_strings, resource_types + ) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function1")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[0]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function2")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[1]) + + self.assertEqual(sync_flow_factory_mock.return_value.create_sync_flow.call_count, 2) + self.assertEqual(sync_flow_executor_mock.return_value.add_sync_flow.call_count, 2) + + get_unique_resource_ids_mock.assert_called_once_with( + get_stacks_mock.return_value[0], resource_identifier_strings, [] + ) + + @patch("samcli.commands.sync.command.SamLocalStackProvider.get_stacks") + @patch("samcli.commands.sync.command.SyncFlowFactory") + 
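# the factory and executor are patched at the sync command module, where
# execute_code_sync looks them up, so no real sync flows are built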
@patch("samcli.commands.sync.command.SyncFlowExecutor") + @patch("samcli.commands.sync.command.get_unique_resource_ids") + def test_execute_code_sync_single_type_resource( + self, + get_unique_resource_ids_mock, + sync_flow_executor_mock, + sync_flow_factory_mock, + get_stacks_mock, + ): + + resource_identifier_strings = ["Function1", "Function2"] + resource_types = ["Type1"] + sync_flows = [MagicMock(), MagicMock(), MagicMock()] + sync_flow_factory_mock.return_value.create_sync_flow.side_effect = sync_flows + get_unique_resource_ids_mock.return_value = { + ResourceIdentifier("Function1"), + ResourceIdentifier("Function2"), + ResourceIdentifier("Function3"), + } + execute_code_sync( + self.template_file, self.build_context, self.deploy_context, resource_identifier_strings, resource_types + ) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function1")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[0]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function2")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[1]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function3")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[2]) + + self.assertEqual(sync_flow_factory_mock.return_value.create_sync_flow.call_count, 3) + self.assertEqual(sync_flow_executor_mock.return_value.add_sync_flow.call_count, 3) + + get_unique_resource_ids_mock.assert_called_once_with( + get_stacks_mock.return_value[0], resource_identifier_strings, ["Type1"] + ) + + @patch("samcli.commands.sync.command.SamLocalStackProvider.get_stacks") + @patch("samcli.commands.sync.command.SyncFlowFactory") + @patch("samcli.commands.sync.command.SyncFlowExecutor") + @patch("samcli.commands.sync.command.get_unique_resource_ids") + def test_execute_code_sync_multiple_type_resource( + self, + get_unique_resource_ids_mock, + sync_flow_executor_mock, + sync_flow_factory_mock, + get_stacks_mock, + ): + resource_identifier_strings = ["Function1", "Function2"] + resource_types = ["Type1", "Type2"] + sync_flows = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] + sync_flow_factory_mock.return_value.create_sync_flow.side_effect = sync_flows + get_unique_resource_ids_mock.return_value = { + ResourceIdentifier("Function1"), + ResourceIdentifier("Function2"), + ResourceIdentifier("Function3"), + ResourceIdentifier("Function4"), + } + execute_code_sync( + self.template_file, self.build_context, self.deploy_context, resource_identifier_strings, resource_types + ) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function1")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[0]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function2")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[1]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function3")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[2]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function4")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[3]) + + self.assertEqual(sync_flow_factory_mock.return_value.create_sync_flow.call_count, 4) + 
self.assertEqual(sync_flow_executor_mock.return_value.add_sync_flow.call_count, 4) + + get_unique_resource_ids_mock.assert_any_call( + get_stacks_mock.return_value[0], resource_identifier_strings, ["Type1", "Type2"] + ) + + @patch("samcli.commands.sync.command.SamLocalStackProvider.get_stacks") + @patch("samcli.commands.sync.command.SyncFlowFactory") + @patch("samcli.commands.sync.command.SyncFlowExecutor") + @patch("samcli.commands.sync.command.get_all_resource_ids") + def test_execute_code_sync_default_all_resources( + self, + get_all_resource_ids_mock, + sync_flow_executor_mock, + sync_flow_factory_mock, + get_stacks_mock, + ): + sync_flows = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] + sync_flow_factory_mock.return_value.create_sync_flow.side_effect = sync_flows + get_all_resource_ids_mock.return_value = [ + ResourceIdentifier("Function1"), + ResourceIdentifier("Function2"), + ResourceIdentifier("Function3"), + ResourceIdentifier("Function4"), + ] + execute_code_sync(self.template_file, self.build_context, self.deploy_context, "", []) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function1")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[0]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function2")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[1]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function3")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[2]) + + sync_flow_factory_mock.return_value.create_sync_flow.assert_any_call(ResourceIdentifier("Function4")) + sync_flow_executor_mock.return_value.add_sync_flow.assert_any_call(sync_flows[3]) + + self.assertEqual(sync_flow_factory_mock.return_value.create_sync_flow.call_count, 4) + self.assertEqual(sync_flow_executor_mock.return_value.add_sync_flow.call_count, 4) + + get_all_resource_ids_mock.assert_called_once_with(get_stacks_mock.return_value[0]) + + +class TestWatch(TestCase): + def setUp(self) -> None: + self.template_file = "template.yaml" + self.build_context = MagicMock() + self.package_context = MagicMock() + self.deploy_context = MagicMock() + + @patch("samcli.commands.sync.command.WatchManager") + def test_execute_watch( + self, + watch_manager_mock, + ): + execute_watch(self.template_file, self.build_context, self.package_context, self.deploy_context) + + watch_manager_mock.assert_called_once_with( + self.template_file, self.build_context, self.package_context, self.deploy_context + ) + watch_manager_mock.return_value.start.assert_called_once_with() diff --git a/tests/unit/commands/traces/test_command.py b/tests/unit/commands/traces/test_command.py new file mode 100644 index 0000000000..69d457eaa3 --- /dev/null +++ b/tests/unit/commands/traces/test_command.py @@ -0,0 +1,69 @@ +from unittest import TestCase +from unittest.mock import patch, call, Mock + +from parameterized import parameterized + +from samcli.commands.traces.command import do_cli + + +class TestTracesCommand(TestCase): + def setUp(self): + self.region = "region" + + @parameterized.expand( + [ + (None, None, None, False, None), + (["trace_id1", "trace_id2"], None, None, False, None), + (None, "start_time", None, False, None), + (None, "start_time", "end_time", False, None), + (None, None, None, True, None), + (None, None, None, True, "output_dir"), + ] + ) + @patch("samcli.commands.logs.logs_context.parse_time") + 
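# parse_time is shared with the logs command, hence the patch target under
# samcli.commands.logs.logs_context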
@patch("samcli.lib.utils.boto_utils.get_boto_config_with_user_agent") + @patch("boto3.client") + @patch("samcli.commands.traces.traces_puller_factory.generate_trace_puller") + def test_traces_command( + self, + trace_ids, + start_time, + end_time, + tail, + output_dir, + patched_generate_puller, + patched_boto3, + patched_get_boto_config_with_user_agent, + patched_parse_time, + ): + given_start_time = Mock() + given_end_time = Mock() + patched_parse_time.side_effect = [given_start_time, given_end_time] + + given_boto_config = Mock() + patched_get_boto_config_with_user_agent.return_value = given_boto_config + + given_xray_client = Mock() + patched_boto3.return_value = given_xray_client + + given_puller = Mock() + patched_generate_puller.return_value = given_puller + + do_cli(trace_ids, start_time, end_time, tail, output_dir, self.region) + + patched_parse_time.assert_has_calls( + [ + call(start_time, "start-time"), + call(end_time, "end-time"), + ] + ) + patched_get_boto_config_with_user_agent.assert_called_with(region_name=self.region) + patched_boto3.assert_called_with("xray", config=given_boto_config) + patched_generate_puller.assert_called_with(given_xray_client, output_dir) + + if trace_ids: + given_puller.load_events.assert_called_with(trace_ids) + elif tail: + given_puller.tail.assert_called_with(given_start_time) + else: + given_puller.load_time_period.assert_called_with(given_start_time, given_end_time) diff --git a/tests/unit/commands/traces/test_trace_console_consumers.py b/tests/unit/commands/traces/test_trace_console_consumers.py new file mode 100644 index 0000000000..cb98885239 --- /dev/null +++ b/tests/unit/commands/traces/test_trace_console_consumers.py @@ -0,0 +1,14 @@ +from unittest import TestCase +from unittest.mock import patch, Mock + +from samcli.commands.traces.trace_console_consumers import XRayTraceConsoleConsumer + + +class TestTraceConsoleConsumers(TestCase): + @patch("samcli.commands.traces.trace_console_consumers.click") + def test_console_consumer(self, patched_click): + event = Mock() + consumer = XRayTraceConsoleConsumer() + consumer.consume(event) + + patched_click.echo.assert_called_with(event.message) diff --git a/tests/unit/commands/traces/test_traces_puller_factory.py b/tests/unit/commands/traces/test_traces_puller_factory.py new file mode 100644 index 0000000000..10c0ad6c26 --- /dev/null +++ b/tests/unit/commands/traces/test_traces_puller_factory.py @@ -0,0 +1,87 @@ +from unittest import TestCase +from unittest.mock import patch, Mock + +from parameterized import parameterized + +from samcli.commands.traces.traces_puller_factory import ( + generate_trace_puller, + generate_unformatted_xray_event_consumer, + generate_xray_event_console_consumer, +) + + +class TestGenerateTracePuller(TestCase): + @parameterized.expand( + [ + (False,), + (True,), + ] + ) + @patch("samcli.commands.traces.traces_puller_factory.generate_xray_event_console_consumer") + @patch("samcli.commands.traces.traces_puller_factory.generate_unformatted_xray_event_consumer") + @patch("samcli.commands.traces.traces_puller_factory.XRayTracePuller") + @patch("samcli.commands.traces.traces_puller_factory.XRayServiceGraphPuller") + @patch("samcli.commands.traces.traces_puller_factory.ObservabilityCombinedPuller") + def test_generate_trace_puller( + self, + unformatted, + patched_combine_puller, + patched_xray_service_graph_puller, + patched_xray_trace_puller, + patched_generate_unformatted_consumer, + patched_generate_console_consumer, + ): + given_xray_client = Mock() + 
given_xray_trace_puller = Mock() + given_xray_service_graph_puller = Mock() + given_combine_puller = Mock() + patched_xray_trace_puller.return_value = given_xray_trace_puller + patched_xray_service_graph_puller.return_value = given_xray_service_graph_puller + patched_combine_puller.return_value = given_combine_puller + + given_console_consumer = Mock() + patched_generate_console_consumer.return_value = given_console_consumer + + given_file_consumer = Mock() + patched_generate_unformatted_consumer.return_value = given_file_consumer + + actual_puller = generate_trace_puller(given_xray_client, unformatted) + self.assertEqual(given_combine_puller, actual_puller) + + if unformatted: + patched_generate_unformatted_consumer.assert_called_with() + patched_xray_trace_puller.assert_called_with(given_xray_client, given_file_consumer) + else: + patched_generate_console_consumer.assert_called_once() + patched_xray_trace_puller.assert_called_with(given_xray_client, given_console_consumer) + + @patch("samcli.commands.traces.traces_puller_factory.ObservabilityEventConsumerDecorator") + @patch("samcli.commands.traces.traces_puller_factory.XRayTraceJSONMapper") + @patch("samcli.commands.traces.traces_puller_factory.XRayTraceConsoleConsumer") + def test_generate_file_consumer(self, patched_consumer, patched_trace_json_mapper, patched_consumer_decorator): + given_consumer = Mock() + patched_consumer_decorator.return_value = given_consumer + + actual_consumer = generate_unformatted_xray_event_consumer() + self.assertEqual(given_consumer, actual_consumer) + + patched_trace_json_mapper.assert_called_once() + patched_consumer.assert_called_with() + + @patch("samcli.commands.traces.traces_puller_factory.ObservabilityEventConsumerDecorator") + @patch("samcli.commands.traces.traces_puller_factory.XRayTraceConsoleMapper") + @patch("samcli.commands.traces.traces_puller_factory.XRayTraceConsoleConsumer") + def test_generate_console_consumer( + self, + patched_console_consumer, + patched_console_mapper, + patched_consumer_decorator, + ): + given_consumer = Mock() + patched_consumer_decorator.return_value = given_consumer + + actual_consumer = generate_xray_event_console_consumer() + self.assertEqual(given_consumer, actual_consumer) + + patched_console_mapper.assert_called_once() + patched_console_consumer.assert_called_once() diff --git a/tests/unit/lib/build_module/test_app_builder.py b/tests/unit/lib/build_module/test_app_builder.py index 8645a7c6c5..299e3f3ad5 100644 --- a/tests/unit/lib/build_module/test_app_builder.py +++ b/tests/unit/lib/build_module/test_app_builder.py @@ -75,7 +75,14 @@ def test_must_iterate_on_functions_and_layers(self, persist_mock): build_layer_mock = Mock() def build_layer_return( - layer_name, layer_codeuri, layer_build_method, layer_compatible_runtimes, artifact_dir, layer_env_vars + layer_name, + layer_codeuri, + layer_build_method, + layer_compatible_runtimes, + artifact_dir, + layer_env_vars, + dependencies_dir, + download_dependencies, ): return f"{layer_name}_location" @@ -116,6 +123,8 @@ def build_layer_return( ANY, self.func1.metadata, ANY, + ANY, + True, ), call( self.func2.name, @@ -126,6 +135,8 @@ def build_layer_return( ANY, self.func2.metadata, ANY, + ANY, + True, ), call( self.imageFunc1.name, @@ -136,6 +147,8 @@ def build_layer_return( ANY, self.imageFunc1.metadata, ANY, + ANY, + True, ), ], any_order=False, @@ -150,6 +163,8 @@ def build_layer_return( self.layer1.compatible_runtimes, ANY, ANY, + ANY, + True, ), call( self.layer2.name, @@ -158,6 +173,8 @@ def 
build_layer_return( self.layer2.compatible_runtimes, ANY, ANY, + ANY, + True, ), ] ) @@ -165,10 +182,10 @@ def build_layer_return( @patch("samcli.lib.build.build_graph.BuildGraph._write") def test_should_use_function_or_layer_get_build_dir_to_determine_artifact_dir(self, persist_mock): def get_func_call_with_artifact_dir(artifact_dir): - return call(ANY, ANY, ANY, ANY, ANY, artifact_dir, ANY, ANY) + return call(ANY, ANY, ANY, ANY, ANY, artifact_dir, ANY, ANY, ANY, True) def get_layer_call_with_artifact_dir(artifact_dir): - return call(ANY, ANY, ANY, ANY, artifact_dir, ANY) + return call(ANY, ANY, ANY, ANY, artifact_dir, ANY, ANY, True) build_function_mock = Mock() build_layer_mock = Mock() @@ -262,6 +279,8 @@ def test_should_run_build_for_only_unique_builds(self, persist_mock, read_mock, ANY, function1_1.metadata, ANY, + ANY, + True, ), call( function2.name, @@ -272,6 +291,8 @@ def test_should_run_build_for_only_unique_builds(self, persist_mock, read_mock, ANY, function2.metadata, ANY, + ANY, + True, ), ], any_order=True, @@ -293,7 +314,7 @@ def test_default_run_should_pick_default_strategy(self, mock_default_build_strat mock_default_build_strategy.build.assert_called_once() self.assertEqual(result, mock_default_build_strategy.build()) - @patch("samcli.lib.build.app_builder.CachedBuildStrategy") + @patch("samcli.lib.build.app_builder.CachedOrIncrementalBuildStrategyWrapper") def test_cached_run_should_pick_cached_strategy(self, mock_cached_build_strategy_class): mock_cached_build_strategy = Mock() mock_cached_build_strategy_class.return_value = mock_cached_build_strategy @@ -326,7 +347,7 @@ def test_parallel_run_should_pick_parallel_strategy(self, mock_parallel_build_st self.assertEqual(result, mock_parallel_build_strategy.build()) @patch("samcli.lib.build.app_builder.ParallelBuildStrategy") - @patch("samcli.lib.build.app_builder.CachedBuildStrategy") + @patch("samcli.lib.build.app_builder.CachedOrIncrementalBuildStrategyWrapper") def test_parallel_and_cached_run_should_pick_parallel_with_cached_strategy( self, mock_cached_build_strategy_class, mock_parallel_build_strategy_class ): @@ -391,6 +412,8 @@ def test_must_build_layer_in_process(self, get_layer_subfolder_mock, osutils_moc PathValidator("manifest_name"), "python3.8", None, + None, + True, ) @patch("samcli.lib.build.app_builder.get_workflow_config") @@ -880,7 +903,7 @@ def test_must_build_in_process(self, osutils_mock, get_workflow_config_mock): self.builder._build_function(function_name, codeuri, ZIP, runtime, handler, artifacts_dir) self.builder._build_function_in_process.assert_called_with( - config_mock, code_dir, artifacts_dir, scratch_dir, manifest_path, runtime, None + config_mock, code_dir, artifacts_dir, scratch_dir, manifest_path, runtime, None, None, True ) @patch("samcli.lib.build.app_builder.get_workflow_config") @@ -913,7 +936,7 @@ def test_must_build_in_process_with_metadata(self, osutils_mock, get_workflow_co ) self.builder._build_function_in_process.assert_called_with( - config_mock, code_dir, artifacts_dir, scratch_dir, manifest_path, runtime, None + config_mock, code_dir, artifacts_dir, scratch_dir, manifest_path, runtime, None, None, True ) @patch("samcli.lib.build.app_builder.get_workflow_config") @@ -1056,7 +1079,7 @@ def test_must_use_lambda_builder(self, lambda_builder_mock): builder_instance_mock = lambda_builder_mock.return_value = Mock() result = self.builder._build_function_in_process( - config_mock, "source_dir", "artifacts_dir", "scratch_dir", "manifest_path", "runtime", None + config_mock, 
"source_dir", "artifacts_dir", "scratch_dir", "manifest_path", "runtime", None, None, True ) self.assertEqual(result, "artifacts_dir") @@ -1075,6 +1098,9 @@ def test_must_use_lambda_builder(self, lambda_builder_mock): executable_search_paths=config_mock.executable_search_paths, mode="mode", options=None, + # todo: put the two checks back after app builder release + # dependencies_dir=None, + # download_dependencies=True, ) @patch("samcli.lib.build.app_builder.LambdaBuilder") @@ -1086,7 +1112,7 @@ def test_must_raise_on_error(self, lambda_builder_mock): with self.assertRaises(BuildError): self.builder._build_function_in_process( - config_mock, "source_dir", "artifacts_dir", "scratch_dir", "manifest_path", "runtime", None + config_mock, "source_dir", "artifacts_dir", "scratch_dir", "manifest_path", "runtime", None, None, True ) diff --git a/tests/unit/lib/build_module/test_build_graph.py b/tests/unit/lib/build_module/test_build_graph.py index 7b326beea1..67064648c4 100644 --- a/tests/unit/lib/build_module/test_build_graph.py +++ b/tests/unit/lib/build_module/test_build_graph.py @@ -1,9 +1,11 @@ from unittest import TestCase +from unittest.mock import patch from uuid import uuid4 from pathlib import Path import tomlkit from parameterized import parameterized +from typing import Dict, cast from samcli.lib.build.build_graph import ( FunctionBuildDefinition, @@ -25,6 +27,8 @@ BuildGraph, InvalidBuildGraphException, LayerBuildDefinition, + MANIFEST_MD5_FIELD, + BuildHashingInformation, ) from samcli.lib.providers.provider import Function, LayerVersion from samcli.lib.utils import osutils @@ -94,7 +98,7 @@ def generate_layer( class TestConversionFunctions(TestCase): def test_function_build_definition_to_toml_table(self): build_definition = FunctionBuildDefinition( - "runtime", "codeuri", ZIP, {"key": "value"}, "source_md5", env_vars={"env_vars": "value1"} + "runtime", "codeuri", ZIP, {"key": "value"}, "source_md5", "manifest_md5", env_vars={"env_vars": "value1"} ) build_definition.add_function(generate_function()) @@ -106,10 +110,13 @@ def test_function_build_definition_to_toml_table(self): self.assertEqual(toml_table[METADATA_FIELD], build_definition.metadata) self.assertEqual(toml_table[FUNCTIONS_FIELD], [f.name for f in build_definition.functions]) self.assertEqual(toml_table[SOURCE_MD5_FIELD], build_definition.source_md5) + self.assertEqual(toml_table[MANIFEST_MD5_FIELD], build_definition.manifest_md5) self.assertEqual(toml_table[ENV_VARS_FIELD], build_definition.env_vars) def test_layer_build_definition_to_toml_table(self): - build_definition = LayerBuildDefinition("name", "codeuri", "method", "runtime", env_vars={"env_vars": "value"}) + build_definition = LayerBuildDefinition( + "name", "codeuri", "method", ["runtime"], "source_md5", "manifest_md5", env_vars={"env_vars": "value"} + ) build_definition.layer = generate_function() toml_table = _layer_build_definition_to_toml_table(build_definition) @@ -120,6 +127,7 @@ def test_layer_build_definition_to_toml_table(self): self.assertEqual(toml_table[COMPATIBLE_RUNTIMES_FIELD], build_definition.compatible_runtimes) self.assertEqual(toml_table[LAYER_FIELD], build_definition.layer.name) self.assertEqual(toml_table[SOURCE_MD5_FIELD], build_definition.source_md5) + self.assertEqual(toml_table[MANIFEST_MD5_FIELD], build_definition.manifest_md5) self.assertEqual(toml_table[ENV_VARS_FIELD], build_definition.env_vars) def test_toml_table_to_function_build_definition(self): @@ -130,6 +138,7 @@ def test_toml_table_to_function_build_definition(self): 
toml_table[METADATA_FIELD] = {"key": "value"} toml_table[FUNCTIONS_FIELD] = ["function1"] toml_table[SOURCE_MD5_FIELD] = "source_md5" + toml_table[MANIFEST_MD5_FIELD] = "manifest_md5" toml_table[ENV_VARS_FIELD] = {"env_vars": "value"} uuid = str(uuid4()) @@ -142,6 +151,7 @@ def test_toml_table_to_function_build_definition(self): self.assertEqual(build_definition.uuid, uuid) self.assertEqual(build_definition.functions, []) self.assertEqual(build_definition.source_md5, toml_table[SOURCE_MD5_FIELD]) + self.assertEqual(build_definition.manifest_md5, toml_table[MANIFEST_MD5_FIELD]) self.assertEqual(build_definition.env_vars, toml_table[ENV_VARS_FIELD]) def test_toml_table_to_layer_build_definition(self): @@ -152,6 +162,7 @@ toml_table[COMPATIBLE_RUNTIMES_FIELD] = "runtime" toml_table[LAYER_FIELD] = "layer1" toml_table[SOURCE_MD5_FIELD] = "source_md5" + toml_table[MANIFEST_MD5_FIELD] = "manifest_md5" toml_table[ENV_VARS_FIELD] = {"env_vars": "value"} uuid = str(uuid4()) @@ -164,6 +175,7 @@ self.assertEqual(build_definition.compatible_runtimes, toml_table[COMPATIBLE_RUNTIMES_FIELD]) self.assertEqual(build_definition.layer, None) self.assertEqual(build_definition.source_md5, toml_table[SOURCE_MD5_FIELD]) + self.assertEqual(build_definition.manifest_md5, toml_table[MANIFEST_MD5_FIELD]) self.assertEqual(build_definition.env_vars, toml_table[ENV_VARS_FIELD]) @@ -178,6 +190,7 @@ class TestBuildGraph(TestCase): UUID = "3c1c254e-cd4b-4d94-8c74-7ab870b36063" LAYER_UUID = "7dnc257e-cd4b-4d94-8c74-7ab870b3abc3" SOURCE_MD5 = "cae49aa393d669e850bd49869905099d" + MANIFEST_MD5 = "rty87gh393d669e850bd49869905099e" ENV_VARS = {"env_vars": "value"} BUILD_GRAPH_CONTENTS = f""" @@ -186,6 +199,7 @@ codeuri = "{CODEURI}" runtime = "{RUNTIME}" source_md5 = "{SOURCE_MD5}" + manifest_md5 = "{MANIFEST_MD5}" packagetype = "{ZIP}" functions = ["HelloWorldPython", "HelloWorldPython2"] [function_build_definitions.{UUID}.metadata] @@ -201,6 +215,7 @@ build_method = "{LAYER_RUNTIME}" compatible_runtimes = ["{LAYER_RUNTIME}"] source_md5 = "{SOURCE_MD5}" + manifest_md5 = "{MANIFEST_MD5}" layer = "SumLayer" [layer_build_definitions.{LAYER_UUID}.env_vars] env_vars = "{ENV_VARS['env_vars']}" @@ -233,6 +248,7 @@ def test_should_instantiate_first_time_and_update(self): TestBuildGraph.ZIP, TestBuildGraph.METADATA, TestBuildGraph.SOURCE_MD5, + TestBuildGraph.MANIFEST_MD5, TestBuildGraph.ENV_VARS, ) function1 = generate_function( @@ -245,6 +261,7 @@ TestBuildGraph.LAYER_RUNTIME, [TestBuildGraph.LAYER_RUNTIME], TestBuildGraph.SOURCE_MD5, + TestBuildGraph.MANIFEST_MD5, TestBuildGraph.ENV_VARS, ) layer1 = generate_layer( @@ -288,12 +305,15 @@ def test_should_read_existing_build_graph(self): self.assertEqual(function_build_definition.packagetype, TestBuildGraph.ZIP) self.assertEqual(function_build_definition.metadata, TestBuildGraph.METADATA) self.assertEqual(function_build_definition.source_md5, TestBuildGraph.SOURCE_MD5) + self.assertEqual(function_build_definition.manifest_md5, TestBuildGraph.MANIFEST_MD5) self.assertEqual(function_build_definition.env_vars, TestBuildGraph.ENV_VARS) for layer_build_definition in build_graph.get_layer_build_definitions(): self.assertEqual(layer_build_definition.name, TestBuildGraph.LAYER_NAME) self.assertEqual(layer_build_definition.codeuri,
TestBuildGraph.LAYER_CODEURI) self.assertEqual(layer_build_definition.build_method, TestBuildGraph.LAYER_RUNTIME) + self.assertEqual(layer_build_definition.source_md5, TestBuildGraph.SOURCE_MD5) + self.assertEqual(layer_build_definition.manifest_md5, TestBuildGraph.MANIFEST_MD5) self.assertEqual(layer_build_definition.compatible_runtimes, [TestBuildGraph.LAYER_RUNTIME]) self.assertEqual(layer_build_definition.env_vars, TestBuildGraph.ENV_VARS) @@ -313,6 +333,7 @@ def test_functions_should_be_added_existing_build_graph(self): TestBuildGraph.ZIP, TestBuildGraph.METADATA, TestBuildGraph.SOURCE_MD5, + TestBuildGraph.MANIFEST_MD5, TestBuildGraph.ENV_VARS, ) function1 = generate_function( @@ -334,6 +355,7 @@ def test_functions_should_be_added_existing_build_graph(self): TestBuildGraph.ZIP, None, "another_source_md5", + "another_manifest_md5", {"env_vars": "value2"}, ) function2 = generate_function(name="another_function") @@ -360,6 +382,7 @@ def test_layers_should_be_added_existing_build_graph(self): TestBuildGraph.LAYER_RUNTIME, [TestBuildGraph.LAYER_RUNTIME], TestBuildGraph.SOURCE_MD5, + TestBuildGraph.MANIFEST_MD5, TestBuildGraph.ENV_VARS, ) layer1 = generate_layer( @@ -380,6 +403,7 @@ def test_layers_should_be_added_existing_build_graph(self): "another_runtime", ["another_runtime"], "another_source_md5", + "another_manifest_md5", {"env_vars": "value2"}, ) layer2 = generate_layer(arn="arn:aws:lambda:region:account-id:layer:another-layer-name:1") @@ -389,11 +413,115 @@ def test_layers_should_be_added_existing_build_graph(self): self.assertEqual(len(build_definitions), 2) self.assertEqual(build_definitions[1].layer, layer2) + @patch("samcli.lib.build.build_graph.BuildGraph._write_source_md5") + @patch("samcli.lib.build.build_graph.BuildGraph._compare_md5_changes") + def test_update_definition_md5_should_succeed(self, compare_md5_mock, write_md5_mock): + compare_md5_mock.return_value = {"mock": "md5"} + with osutils.mkdir_temp() as temp_base_dir: + build_dir = Path(temp_base_dir, ".aws-sam", "build") + build_dir.mkdir(parents=True) + + build_graph_path = Path(build_dir.parent, "build.toml") + build_graph_path.write_text(TestBuildGraph.BUILD_GRAPH_CONTENTS) + + build_graph = BuildGraph(str(build_dir)) + build_graph.update_definition_md5() + write_md5_mock.assert_called_with({"mock": "md5"}, {"mock": "md5"}) + + def test_compare_md5_changes_should_succeed(self): + with osutils.mkdir_temp() as temp_base_dir: + build_dir = Path(temp_base_dir, ".aws-sam", "build") + build_dir.mkdir(parents=True) + + build_graph_path = Path(build_dir.parent, "build.toml") + build_graph_path.write_text(TestBuildGraph.BUILD_GRAPH_CONTENTS) + + build_graph = BuildGraph(str(build_dir)) + + build_definition = FunctionBuildDefinition( + TestBuildGraph.RUNTIME, + TestBuildGraph.CODEURI, + TestBuildGraph.ZIP, + TestBuildGraph.METADATA, + TestBuildGraph.SOURCE_MD5, + TestBuildGraph.MANIFEST_MD5, + TestBuildGraph.ENV_VARS, + ) + updated_definition = FunctionBuildDefinition( + TestBuildGraph.RUNTIME, + TestBuildGraph.CODEURI, + TestBuildGraph.ZIP, + TestBuildGraph.METADATA, + "new_value", + "new_manifest_value", + TestBuildGraph.ENV_VARS, + ) + updated_definition.uuid = build_definition.uuid + + layer_definition = LayerBuildDefinition( + TestBuildGraph.LAYER_NAME, + TestBuildGraph.LAYER_CODEURI, + TestBuildGraph.LAYER_RUNTIME, + [TestBuildGraph.LAYER_RUNTIME], + TestBuildGraph.SOURCE_MD5, + TestBuildGraph.MANIFEST_MD5, + TestBuildGraph.ENV_VARS, + ) + updated_layer = LayerBuildDefinition( + TestBuildGraph.LAYER_NAME, + 
TestBuildGraph.LAYER_CODEURI, + TestBuildGraph.LAYER_RUNTIME, + [TestBuildGraph.LAYER_RUNTIME], + "new_value", + "new_manifest_value", + TestBuildGraph.ENV_VARS, + ) + updated_layer.uuid = layer_definition.uuid + + build_graph._function_build_definitions = [build_definition] + build_graph._layer_build_definitions = [layer_definition] + + function_content = BuildGraph._compare_md5_changes( + [updated_definition], build_graph._function_build_definitions + ) + layer_content = BuildGraph._compare_md5_changes([updated_layer], build_graph._layer_build_definitions) + self.assertEqual(function_content, {build_definition.uuid: ("new_value", "new_manifest_value")}) + self.assertEqual(layer_content, {layer_definition.uuid: ("new_value", "new_manifest_value")}) + + def test_write_source_md5_should_succeed(self): + with osutils.mkdir_temp() as temp_base_dir: + build_dir = Path(temp_base_dir, ".aws-sam", "build") + build_dir.mkdir(parents=True) + + build_graph_path = Path(build_dir.parent, "build.toml") + build_graph_path.write_text(TestBuildGraph.BUILD_GRAPH_CONTENTS) + + build_graph = BuildGraph(str(build_dir)) + + build_graph._write_source_md5( + {TestBuildGraph.UUID: BuildHashingInformation("new_value", "new_manifest_value")}, + {TestBuildGraph.LAYER_UUID: BuildHashingInformation("new_value", "new_manifest_value")}, + ) + + txt = build_graph_path.read_text() + document = cast(Dict, tomlkit.loads(txt)) + + self.assertEqual(document["function_build_definitions"][TestBuildGraph.UUID][SOURCE_MD5_FIELD], "new_value") + self.assertEqual( + document["function_build_definitions"][TestBuildGraph.UUID][MANIFEST_MD5_FIELD], "new_manifest_value" + ) + self.assertEqual( + document["layer_build_definitions"][TestBuildGraph.LAYER_UUID][SOURCE_MD5_FIELD], "new_value" + ) + self.assertEqual( + document["layer_build_definitions"][TestBuildGraph.LAYER_UUID][MANIFEST_MD5_FIELD], "new_manifest_value" + ) + class TestBuildDefinition(TestCase): def test_single_function_should_return_function_and_handler_name(self): build_definition = FunctionBuildDefinition( - "runtime", "codeuri", ZIP, "metadata", "source_md5", {"env_vars": "value"} + "runtime", "codeuri", ZIP, "metadata", "source_md5", "manifest_md5", {"env_vars": "value"} ) build_definition.add_function(generate_function()) @@ -402,24 +530,28 @@ def test_single_function_should_return_function_and_handler_name(self): def test_no_function_should_raise_exception(self): build_definition = FunctionBuildDefinition( - "runtime", "codeuri", ZIP, "metadata", "source_md5", {"env_vars": "value"} + "runtime", "codeuri", ZIP, "metadata", "source_md5", "manifest_md5", {"env_vars": "value"} ) self.assertRaises(InvalidBuildGraphException, build_definition.get_handler_name) self.assertRaises(InvalidBuildGraphException, build_definition.get_function_name) def test_same_runtime_codeuri_metadata_should_reflect_as_same_object(self): - build_definition1 = FunctionBuildDefinition("runtime", "codeuri", ZIP, {"key": "value"}, "source_md5") - build_definition2 = FunctionBuildDefinition("runtime", "codeuri", ZIP, {"key": "value"}, "source_md5") + build_definition1 = FunctionBuildDefinition( + "runtime", "codeuri", ZIP, {"key": "value"}, "source_md5", "manifest_md5" + ) + build_definition2 = FunctionBuildDefinition( + "runtime", "codeuri", ZIP, {"key": "value"}, "source_md5", "manifest_md5" + ) self.assertEqual(build_definition1, build_definition2) def test_same_env_vars_reflect_as_same_object(self): build_definition1 = FunctionBuildDefinition( - "runtime", "codeuri", ZIP, {"key": "value"}, 
"source_md5", {"env_vars": "value"} + "runtime", "codeuri", ZIP, {"key": "value"}, "source_md5", "manifest_md5", {"env_vars": "value"} ) build_definition2 = FunctionBuildDefinition( - "runtime", "codeuri", ZIP, {"key": "value"}, "source_md5", {"env_vars": "value"} + "runtime", "codeuri", ZIP, {"key": "value"}, "source_md5", "manifest_md5", {"env_vars": "value"} ) self.assertEqual(build_definition1, build_definition2) @@ -479,20 +611,20 @@ def test_different_runtime_codeuri_metadata_should_not_reflect_as_same_object( def test_different_env_vars_should_not_reflect_as_same_object(self): build_definition1 = FunctionBuildDefinition( - "runtime", "codeuri", ZIP, {"key": "value"}, "source_md5", {"env_vars": "value1"} + "runtime", "codeuri", ZIP, {"key": "value"}, "source_md5", "manifest_md5", {"env_vars": "value1"} ) build_definition2 = FunctionBuildDefinition( - "runtime", "codeuri", ZIP, {"key": "value"}, "source_md5", {"env_vars": "value2"} + "runtime", "codeuri", ZIP, {"key": "value"}, "source_md5", "manifest_md5", {"env_vars": "value2"} ) self.assertNotEqual(build_definition1, build_definition2) def test_euqality_with_another_object(self): - build_definition = FunctionBuildDefinition("runtime", "codeuri", ZIP, None, "source_md5") + build_definition = FunctionBuildDefinition("runtime", "codeuri", ZIP, None, "source_md5", "manifest_md5") self.assertNotEqual(build_definition, {}) def test_str_representation(self): - build_definition = FunctionBuildDefinition("runtime", "codeuri", ZIP, None, "source_md5") + build_definition = FunctionBuildDefinition("runtime", "codeuri", ZIP, None, "source_md5", "manifest_md5") self.assertEqual( str(build_definition), f"BuildDefinition(runtime, codeuri, Zip, source_md5, {build_definition.uuid}, {{}}, {{}}, [])", diff --git a/tests/unit/lib/build_module/test_build_strategy.py b/tests/unit/lib/build_module/test_build_strategy.py index f0e7ab3e7d..2a277a2bc3 100644 --- a/tests/unit/lib/build_module/test_build_strategy.py +++ b/tests/unit/lib/build_module/test_build_strategy.py @@ -1,13 +1,18 @@ +from copy import deepcopy from unittest import TestCase from unittest.mock import Mock, patch, MagicMock, call, ANY -from samcli.commands.build.exceptions import MissingBuildMethodException +from parameterized import parameterized + +from samcli.lib.build.exceptions import MissingBuildMethodException from samcli.lib.build.build_graph import BuildGraph, FunctionBuildDefinition, LayerBuildDefinition from samcli.lib.build.build_strategy import ( ParallelBuildStrategy, BuildStrategy, DefaultBuildStrategy, CachedBuildStrategy, + CachedOrIncrementalBuildStrategyWrapper, + IncrementalBuildStrategy, ) from samcli.lib.utils import osutils from pathlib import Path @@ -155,6 +160,8 @@ def test_build_layers_and_functions(self, mock_copy_tree): self.function_build_definition1.get_build_dir(given_build_dir), self.function_build_definition1.metadata, self.function_build_definition1.env_vars, + self.function_build_definition1.dependencies_dir, + True, ), call( self.function_build_definition2.get_function_name(), @@ -165,6 +172,8 @@ def test_build_layers_and_functions(self, mock_copy_tree): self.function_build_definition2.get_build_dir(given_build_dir), self.function_build_definition2.metadata, self.function_build_definition2.env_vars, + self.function_build_definition2.dependencies_dir, + True, ), ] ) @@ -178,7 +187,9 @@ def test_build_layers_and_functions(self, mock_copy_tree): self.layer1.build_method, self.layer1.compatible_runtimes, self.layer1.get_build_dir(given_build_dir), - 
self.function_build_definition1.env_vars, + self.layer_build_definition1.env_vars, + self.layer_build_definition1.dependencies_dir, + True, ), call( self.layer2.name, @@ -186,7 +197,9 @@ def test_build_layers_and_functions(self, mock_copy_tree): self.layer2.build_method, self.layer2.compatible_runtimes, self.layer2.get_build_dir(given_build_dir), - self.function_build_definition2.env_vars, + self.layer_build_definition2.env_vars, + self.layer_build_definition2.dependencies_dir, + True, ), ] ) @@ -222,7 +235,11 @@ def test_build_single_function_definition_image_functions_with_same_metadata(sel # since they have the same metadata, they are put into the same build_definition. build_definition.functions = [function1, function2] - result = default_build_strategy.build_single_function_definition(build_definition) + with patch("samcli.lib.build.build_strategy.deepcopy", wraps=deepcopy) as patched_deepcopy: + result = default_build_strategy.build_single_function_definition(build_definition) + + patched_deepcopy.assert_called_with(build_definition.env_vars) + # both of the function name should show up in results self.assertEqual(result, {"Function": built_image, "Function2": built_image}) @@ -267,7 +284,7 @@ def test_build_call(self, mock_layer_build, mock_function_build, mock_rmtree, mo self.build_graph, given_build_dir, given_build_function, given_build_layer ) cache_build_strategy = CachedBuildStrategy( - self.build_graph, default_build_strategy, "base_dir", given_build_dir, "cache_dir", True + self.build_graph, default_build_strategy, "base_dir", given_build_dir, "cache_dir" ) cache_build_strategy.build() mock_function_build.assert_called() @@ -277,7 +294,6 @@ def test_build_call(self, mock_layer_build, mock_function_build, mock_rmtree, mo @patch("samcli.lib.build.build_strategy.pathlib.Path.exists") @patch("samcli.lib.build.build_strategy.dir_checksum") def test_if_cached_valid_when_build_single_function_definition(self, dir_checksum_mock, exists_mock, copytree_mock): - pass with osutils.mkdir_temp() as temp_base_dir: build_dir = Path(temp_base_dir, ".aws-sam", "build") build_dir.mkdir(parents=True) @@ -291,7 +307,7 @@ def test_if_cached_valid_when_build_single_function_definition(self, dir_checksu build_graph_path.write_text(CachedBuildStrategyTest.BUILD_GRAPH_CONTENTS) build_graph = BuildGraph(str(build_dir)) cached_build_strategy = CachedBuildStrategy( - build_graph, DefaultBuildStrategy, temp_base_dir, build_dir, cache_dir, True + build_graph, DefaultBuildStrategy, temp_base_dir, build_dir, cache_dir ) func1 = Mock() func1.name = "func1_name" @@ -330,7 +346,7 @@ def test_if_cached_invalid_with_no_cached_folder(self, build_layer_mock, build_f build_graph_path.write_text(CachedBuildStrategyTest.BUILD_GRAPH_CONTENTS) build_graph = BuildGraph(str(build_dir)) cached_build_strategy = CachedBuildStrategy( - build_graph, DefaultBuildStrategy, temp_base_dir, build_dir, cache_dir, True + build_graph, DefaultBuildStrategy, temp_base_dir, build_dir, cache_dir ) cached_build_strategy.build_single_function_definition(build_graph.get_function_build_definitions()[0]) cached_build_strategy.build_single_layer_definition(build_graph.get_layer_build_definitions()[0]) @@ -348,7 +364,7 @@ def test_redundant_cached_should_be_clean(self): redundant_cache_folder = Path(cache_dir, "redundant") redundant_cache_folder.mkdir(parents=True) - cached_build_strategy = CachedBuildStrategy(build_graph, Mock(), temp_base_dir, build_dir, cache_dir, True) + cached_build_strategy = CachedBuildStrategy(build_graph, Mock(), 
temp_base_dir, build_dir, cache_dir) cached_build_strategy._clean_redundant_cached() self.assertTrue(not redundant_cache_folder.exists()) @@ -429,3 +445,134 @@ def test_given_delegate_strategy_it_should_call_delegated_build_methods(self): call(self.layer_build_definition2), ] ) + + +@patch("samcli.lib.build.build_strategy.DependencyHashGenerator") +class TestIncrementalBuildStrategy(TestCase): + def setUp(self): + self.build_function = Mock() + self.build_layer = Mock() + self.build_graph = Mock() + self.delegate_build_strategy = DefaultBuildStrategy( + self.build_graph, Mock(), self.build_function, self.build_layer + ) + self.build_strategy = IncrementalBuildStrategy( + self.build_graph, + self.delegate_build_strategy, + Mock(), + Mock(), + ) + + def test_assert_incremental_build_function(self, patched_manifest_hash): + same_hash = "same_hash" + patched_manifest_hash_instance = Mock(hash=same_hash) + patched_manifest_hash.return_value = patched_manifest_hash_instance + + given_function_build_def = Mock(manifest_md5=same_hash, functions=[Mock()]) + self.build_graph.get_function_build_definitions.return_value = [given_function_build_def] + self.build_graph.get_layer_build_definitions.return_value = [] + + self.build_strategy.build() + self.build_function.assert_called_with(ANY, ANY, ANY, ANY, ANY, ANY, ANY, ANY, ANY, False) + + def test_assert_incremental_build_layer(self, patched_manifest_hash): + same_hash = "same_hash" + patched_manifest_hash_instance = Mock(hash=same_hash) + patched_manifest_hash.return_value = patched_manifest_hash_instance + + given_layer_build_def = Mock(manifest_md5=same_hash, functions=[Mock()]) + self.build_graph.get_function_build_definitions.return_value = [] + self.build_graph.get_layer_build_definitions.return_value = [given_layer_build_def] + + self.build_strategy.build() + self.build_layer.assert_called_with(ANY, ANY, ANY, ANY, ANY, ANY, ANY, False) + + +@patch("samcli.lib.build.build_graph.BuildGraph._write") +@patch("samcli.lib.build.build_graph.BuildGraph._read") +class TestCachedOrIncrementalBuildStrategyWrapper(TestCase): + def setUp(self) -> None: + self.build_graph = BuildGraph("build/graph/location") + + self.build_strategy = CachedOrIncrementalBuildStrategyWrapper( + self.build_graph, + Mock(), + "base_dir", + "build_dir", + "cache_dir", + "manifest_path_override", + False, + ) + + @parameterized.expand( + [ + "python3.7", + "nodejs12.x", + "ruby2.7", + ] + ) + def test_will_call_incremental_build_strategy(self, mocked_read, mocked_write, runtime): + build_definition = FunctionBuildDefinition(runtime, "codeuri", "package_type", {}) + self.build_graph.put_function_build_definition(build_definition, Mock()) + with patch.object( + self.build_strategy, "_incremental_build_strategy" + ) as patched_incremental_build_strategy, patch.object( + self.build_strategy, "_cached_build_strategy" + ) as patched_cached_build_strategy: + self.build_strategy.build() + + patched_incremental_build_strategy.build_single_function_definition.assert_called_with(build_definition) + patched_cached_build_strategy.assert_not_called() + + @parameterized.expand( + [ + "dotnetcore2.1", + "go1.x", + "java11", + ] + ) + def test_will_call_cached_build_strategy(self, mocked_read, mocked_write, runtime): + build_definition = FunctionBuildDefinition(runtime, "codeuri", "package_type", {}) + self.build_graph.put_function_build_definition(build_definition, Mock()) + with patch.object( + self.build_strategy, "_incremental_build_strategy" + ) as patched_incremental_build_strategy,
patch.object( + self.build_strategy, "_cached_build_strategy" + ) as patched_cached_build_strategy: + self.build_strategy.build() + + patched_cached_build_strategy.build_single_function_definition.assert_called_with(build_definition) + patched_incremental_build_strategy.assert_not_called() + + @parameterized.expand([(True,), (False,)]) + @patch("samcli.lib.build.build_strategy.CachedBuildStrategy._clean_redundant_cached") + @patch("samcli.lib.build.build_strategy.IncrementalBuildStrategy._clean_redundant_dependencies") + def test_exit_build_strategy_for_specific_resource( + self, is_building_specific_resource, clean_cache_mock, clean_dep_mock, mocked_read, mocked_write + ): + with osutils.mkdir_temp() as temp_base_dir: + build_dir = Path(temp_base_dir, ".aws-sam", "build") + build_dir.mkdir(parents=True) + cache_dir = Path(temp_base_dir, ".aws-sam", "cache") + cache_dir.mkdir(parents=True) + + mocked_build_graph = Mock() + mocked_build_graph.get_layer_build_definitions.return_value = [] + mocked_build_graph.get_function_build_definitions.return_value = [] + + cached_build_strategy = CachedOrIncrementalBuildStrategyWrapper( + mocked_build_graph, Mock(), temp_base_dir, build_dir, cache_dir, None, is_building_specific_resource + ) + + cached_build_strategy.build() + + if is_building_specific_resource: + mocked_build_graph.update_definition_md5.assert_called_once() + mocked_build_graph.clean_redundant_definitions_and_update.assert_not_called() + clean_cache_mock.assert_not_called() + clean_dep_mock.assert_not_called() + else: + mocked_build_graph.update_definition_md5.assert_not_called() + mocked_build_graph.clean_redundant_definitions_and_update.assert_called_once() + clean_cache_mock.assert_called_once() + clean_dep_mock.assert_called_once() diff --git a/tests/unit/lib/build_module/test_dependency_hash_generator.py b/tests/unit/lib/build_module/test_dependency_hash_generator.py new file mode 100644 index 0000000000..b3cf3d1c41 --- /dev/null +++ b/tests/unit/lib/build_module/test_dependency_hash_generator.py @@ -0,0 +1,86 @@ +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from samcli.lib.build.dependency_hash_generator import DependencyHashGenerator + + +class TestDependencyHashGenerator(TestCase): + def setUp(self): + self.get_workflow_config_patch = patch("samcli.lib.build.dependency_hash_generator.get_workflow_config") + self.get_workflow_config_mock = self.get_workflow_config_patch.start() + self.get_workflow_config_mock.return_value.manifest_name = "manifest_file" + + self.file_checksum_patch = patch("samcli.lib.build.dependency_hash_generator.file_checksum") + self.file_checksum_mock = self.file_checksum_patch.start() + self.file_checksum_mock.return_value = "checksum" + + def tearDown(self): + self.get_workflow_config_patch.stop() + self.file_checksum_patch.stop() + + @patch("samcli.lib.build.dependency_hash_generator.DependencyHashGenerator._calculate_dependency_hash") + @patch("samcli.lib.build.dependency_hash_generator.pathlib.Path") + def test_init_and_properties(self, path_mock, calculate_hash_mock): + path_mock.return_value.resolve.return_value.__str__.return_value = "code_dir" + calculate_hash_mock.return_value = "dependency_hash" + self.generator = DependencyHashGenerator("code_uri", "base_dir", "runtime") + self.assertEqual(self.generator._code_uri, "code_uri") + self.assertEqual(self.generator._base_dir, "base_dir") + self.assertEqual(self.generator._code_dir, "code_dir") + self.assertEqual(self.generator._runtime, "runtime") + 
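# Editor's note: the assertions around this point pin down a small contract for
# DependencyHashGenerator: resolve the code directory, find the dependency
# manifest (an explicit override, or the manifest name from the runtime's
# workflow config), and checksum only that manifest. A minimal sketch of that
# contract, inferred from the patch targets and assertions in these tests and
# not copied from the real implementation (import paths are assumptions):
import pathlib
from typing import Optional

from samcli.lib.build.workflow_config import get_workflow_config  # assumed source of the patched name
from samcli.lib.utils.hash import file_checksum  # hash_generator kwarg as exercised by these tests


class DependencyHashGeneratorSketch:
    def __init__(self, code_uri: str, base_dir: str, runtime: str, manifest_path_override: Optional[str] = None):
        self._code_uri = code_uri
        self._base_dir = base_dir
        self._code_dir = str(pathlib.Path(base_dir, code_uri).resolve())
        self._runtime = runtime
        self._manifest_path_override = manifest_path_override

    @property
    def hash(self) -> Optional[str]:
        # With an override the workflow config is never consulted (see
        # test_calculate_manifest_hash_manifest_override below); a missing
        # manifest file yields None rather than raising.
        manifest_name = self._manifest_path_override or get_workflow_config(
            self._runtime, self._code_dir, self._code_dir
        ).manifest_name
        manifest_path = pathlib.Path(self._code_dir, manifest_name).resolve()
        if not manifest_path.is_file():
            return None
        return file_checksum(str(manifest_path), hash_generator=None)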
self.assertEqual(self.generator.hash, "dependency_hash") + + path_mock.assert_called_once_with("base_dir", "code_uri") + + @patch("samcli.lib.build.dependency_hash_generator.pathlib.Path") + def test_calculate_manifest_hash(self, path_mock): + code_dir_mock = MagicMock() + code_dir_mock.resolve.return_value.__str__.return_value = "code_dir" + manifest_path_mock = MagicMock() + manifest_path_mock.resolve.return_value.__str__.return_value = "manifest_path" + manifest_path_mock.resolve.return_value.is_file.return_value = True + path_mock.side_effect = [code_dir_mock, manifest_path_mock] + + self.generator = DependencyHashGenerator("code_uri", "base_dir", "runtime") + hash = self.generator.hash + self.file_checksum_mock.assert_called_once_with("manifest_path", hash_generator=None) + self.assertEqual(hash, "checksum") + + path_mock.assert_any_call("base_dir", "code_uri") + path_mock.assert_any_call("code_dir", "manifest_file") + + @patch("samcli.lib.build.dependency_hash_generator.pathlib.Path") + def test_calculate_manifest_hash_missing_file(self, path_mock): + code_dir_mock = MagicMock() + code_dir_mock.resolve.return_value.__str__.return_value = "code_dir" + manifest_path_mock = MagicMock() + manifest_path_mock.resolve.return_value.__str__.return_value = "manifest_path" + manifest_path_mock.resolve.return_value.is_file.return_value = False + path_mock.side_effect = [code_dir_mock, manifest_path_mock] + + self.generator = DependencyHashGenerator("code_uri", "base_dir", "runtime") + self.file_checksum_mock.assert_not_called() + self.assertEqual(self.generator.hash, None) + + path_mock.assert_any_call("base_dir", "code_uri") + path_mock.assert_any_call("code_dir", "manifest_file") + + @patch("samcli.lib.build.dependency_hash_generator.pathlib.Path") + def test_calculate_manifest_hash_manifest_override(self, path_mock): + code_dir_mock = MagicMock() + code_dir_mock.resolve.return_value.__str__.return_value = "code_dir" + manifest_path_mock = MagicMock() + manifest_path_mock.resolve.return_value.__str__.return_value = "manifest_path" + manifest_path_mock.resolve.return_value.is_file.return_value = True + path_mock.side_effect = [code_dir_mock, manifest_path_mock] + + self.generator = DependencyHashGenerator( + "code_uri", "base_dir", "runtime", manifest_path_override="manifest_override" + ) + hash = self.generator.hash + self.get_workflow_config_mock.assert_not_called() + self.file_checksum_mock.assert_called_once_with("manifest_path", hash_generator=None) + self.assertEqual(hash, "checksum") + + path_mock.assert_any_call("base_dir", "code_uri") + path_mock.assert_any_call("code_dir", "manifest_override") diff --git a/tests/unit/lib/deploy/test_deployer.py b/tests/unit/lib/deploy/test_deployer.py index a0b29efe5d..62982cf6b0 100644 --- a/tests/unit/lib/deploy/test_deployer.py +++ b/tests/unit/lib/deploy/test_deployer.py @@ -1,3 +1,4 @@ +from logging import captureWarnings import uuid import time from datetime import datetime, timedelta @@ -750,3 +751,107 @@ def test_wait_for_execute_with_outputs(self, patched_time): self.deployer.get_stack_outputs = MagicMock(return_value=outputs["Stacks"][0]["Outputs"]) self.deployer.wait_for_execute("test", "CREATE") self.assertEqual(self.deployer._display_stack_outputs.call_count, 1) + + def test_sync_update_stack(self): + self.deployer.has_stack = MagicMock(return_value=True) + self.deployer.wait_for_execute = MagicMock() + self.deployer.sync( + stack_name="test", + cfn_template=" ", + parameter_values=[ + {"ParameterKey": "a", "ParameterValue": "b"}, + ], + 
capabilities=["CAPABILITY_IAM"], + role_arn="role-arn", + notification_arns=[], + s3_uploader=S3Uploader(s3_client=self.s3_client, bucket_name="test_bucket"), + tags={"unit": "true"}, + ) + + self.assertEqual(self.deployer._client.update_stack.call_count, 1) + self.deployer._client.update_stack.assert_called_with( + Capabilities=["CAPABILITY_IAM"], + NotificationARNs=[], + Parameters=[{"ParameterKey": "a", "ParameterValue": "b"}], + RoleARN="role-arn", + StackName="test", + Tags={"unit": "true"}, + TemplateURL=ANY, + ) + + def test_sync_update_stack_exception(self): + self.deployer.has_stack = MagicMock(return_value=True) + self.deployer.wait_for_execute = MagicMock() + self.deployer._client.update_stack = MagicMock(side_effect=Exception) + with self.assertRaises(DeployFailedError): + self.deployer.sync( + stack_name="test", + cfn_template=" ", + parameter_values=[ + {"ParameterKey": "a", "ParameterValue": "b"}, + ], + capabilities=["CAPABILITY_IAM"], + role_arn="role-arn", + notification_arns=[], + s3_uploader=S3Uploader(s3_client=self.s3_client, bucket_name="test_bucket"), + tags={"unit": "true"}, + ) + + def test_sync_create_stack(self): + self.deployer.has_stack = MagicMock(return_value=False) + self.deployer.wait_for_execute = MagicMock() + self.deployer.sync( + stack_name="test", + cfn_template=" ", + parameter_values=[ + {"ParameterKey": "a", "ParameterValue": "b"}, + ], + capabilities=["CAPABILITY_IAM"], + role_arn="role-arn", + notification_arns=[], + s3_uploader=S3Uploader(s3_client=self.s3_client, bucket_name="test_bucket"), + tags={"unit": "true"}, + ) + + self.assertEqual(self.deployer._client.create_stack.call_count, 1) + self.deployer._client.create_stack.assert_called_with( + Capabilities=["CAPABILITY_IAM"], + NotificationARNs=[], + Parameters=[{"ParameterKey": "a", "ParameterValue": "b"}], + RoleARN="role-arn", + StackName="test", + Tags={"unit": "true"}, + TemplateURL=ANY, + ) + + def test_sync_create_stack_exception(self): + self.deployer.has_stack = MagicMock(return_value=False) + self.deployer.wait_for_execute = MagicMock() + self.deployer._client.create_stack = MagicMock(side_effect=Exception) + with self.assertRaises(DeployFailedError): + self.deployer.sync( + stack_name="test", + cfn_template=" ", + parameter_values=[ + {"ParameterKey": "a", "ParameterValue": "b"}, + ], + capabilities=["CAPABILITY_IAM"], + role_arn="role-arn", + notification_arns=[], + s3_uploader=S3Uploader(s3_client=self.s3_client, bucket_name="test_bucket"), + tags={"unit": "true"}, + ) + + def test_process_kwargs(self): + kwargs = {"Capabilities": []} + capabilities = ["CAPABILITY_IAM"] + role_arn = "role-arn" + notification_arns = ["arn"] + + expected = { + "Capabilities": ["CAPABILITY_IAM"], + "RoleARN": "role-arn", + "NotificationARNs": ["arn"], + } + result = self.deployer._process_kwargs(kwargs, None, capabilities, role_arn, notification_arns) + self.assertEqual(expected, result) diff --git a/tests/unit/lib/observability/cw_logs/test_cw_log_formatters.py b/tests/unit/lib/observability/cw_logs/test_cw_log_formatters.py index f864ff1fe7..652615a81e 100644 --- a/tests/unit/lib/observability/cw_logs/test_cw_log_formatters.py +++ b/tests/unit/lib/observability/cw_logs/test_cw_log_formatters.py @@ -10,6 +10,8 @@ CWColorizeErrorsFormatter, CWKeywordHighlighterFormatter, CWJsonFormatter, + CWAddNewLineIfItDoesntExist, + CWLogEventJSONMapper, ) @@ -118,3 +120,47 @@ def test_ignore_non_json(self, input_msg): result = self.formatter.map(event) self.assertEqual(result.message, input_msg) + + +class 
TestCWAddNewLineIfItDoesntExist(TestCase): + def setUp(self) -> None: + self.formatter = CWAddNewLineIfItDoesntExist() + + @parameterized.expand( + [ + (CWLogEvent("log_group", {"message": "input"}),), + (CWLogEvent("log_group", {"message": "input\n"}),), + ] + ) + def test_cw_log_event(self, log_event): + mapped_event = self.formatter.map(log_event) + self.assertEqual(mapped_event.message, "input\n") + + @parameterized.expand( + [ + ("input",), + ("input\n",), + ] + ) + def test_str_event(self, str_event): + mapped_event = self.formatter.map(str_event) + self.assertEqual(mapped_event, "input\n") + + @parameterized.expand( + [ + ({"some": "dict"},), + (5,), + ] + ) + def test_other_events(self, event): + mapped_event = self.formatter.map(event) + self.assertEqual(mapped_event, event) + + +class TestCWLogEventJSONMapper(TestCase): + def test_mapper(self): + given_event = CWLogEvent("log_group", {"message": "input"}) + mapper = CWLogEventJSONMapper() + + mapped_event = mapper.map(given_event) + self.assertEqual(mapped_event.message, json.dumps(given_event.event)) diff --git a/tests/unit/lib/observability/cw_logs/test_cw_log_group_provider.py b/tests/unit/lib/observability/cw_logs/test_cw_log_group_provider.py index 295ad6d898..4e890bfa12 100644 --- a/tests/unit/lib/observability/cw_logs/test_cw_log_group_provider.py +++ b/tests/unit/lib/observability/cw_logs/test_cw_log_group_provider.py @@ -1,11 +1,67 @@ from unittest import TestCase +from unittest.mock import Mock, ANY from samcli.lib.observability.cw_logs.cw_log_group_provider import LogGroupProvider class TestLogGroupProvider_for_lambda_function(TestCase): def test_must_return_log_group_name(self): - expected = "/aws/lambda/myfunctionname" - result = LogGroupProvider.for_lambda_function("myfunctionname") + expected = "/aws/lambda/my_function_name" + result = LogGroupProvider.for_lambda_function("my_function_name") self.assertEqual(expected, result) + + def test_rest_api_log_group_name(self): + expected = "API-Gateway-Execution-Logs_my_function_name/Prod" + result = LogGroupProvider.for_resource(Mock(), "AWS::ApiGateway::RestApi", "my_function_name") + + self.assertEqual(expected, result) + + def test_http_api_log_group_name(self): + given_client_provider = Mock() + given_client_provider(ANY).get_stage.return_value = { + "AccessLogSettings": {"DestinationArn": "test:my_log_group"} + } + expected = "my_log_group" + result = LogGroupProvider.for_resource(given_client_provider, "AWS::ApiGatewayV2::Api", "my_function_name") + + self.assertEqual(expected, result) + + def test_http_api_log_group_name_not_exist(self): + given_client_provider = Mock() + given_client_provider(ANY).get_stage.return_value = {} + result = LogGroupProvider.for_resource(given_client_provider, "AWS::ApiGatewayV2::Api", "my_function_name") + + self.assertIsNone(result) + + def test_step_functions(self): + given_client_provider = Mock() + given_cw_log_group_name = "sam-app-logs-command-test-MyStateMachineLogGroup-ucwMaQpNBJTD" + given_client_provider(ANY).describe_state_machine.return_value = { + "loggingConfiguration": { + "destinations": [ + { + "cloudWatchLogsLogGroup": { + "logGroupArn": f"arn:aws:logs:us-west-2:694866504768:log-group:{given_cw_log_group_name}:*" + } + } + ] + } + } + + result = LogGroupProvider.for_resource( + given_client_provider, "AWS::StepFunctions::StateMachine", "my_state_machine" + ) + + self.assertIsNotNone(result) + self.assertEqual(result, given_cw_log_group_name) + + def test_invalid_step_functions(self): + given_client_provider = Mock() + 
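# Editor's note: this test and test_invalid_step_functions below pin down how a
# Step Functions log group name is recovered from the state machine's logging
# configuration. A hedged sketch of that extraction (the helper name is
# illustrative; the real logic lives in LogGroupProvider.for_resource):
def _log_group_from_state_machine_sketch(describe_response: dict):
    # The ARN under loggingConfiguration.destinations[0].cloudWatchLogsLogGroup
    # looks like "arn:aws:logs:<region>:<account>:log-group:<name>:*"; the log
    # group name is the piece after "log-group:" with the trailing ":*" removed.
    # An empty destinations list means there is no log group to tail.
    destinations = describe_response.get("loggingConfiguration", {}).get("destinations", [])
    if not destinations:
        return None
    arn = destinations[0].get("cloudWatchLogsLogGroup", {}).get("logGroupArn", "")
    return arn.split(":log-group:")[-1].rstrip(":*") or None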
given_client_provider(ANY).describe_state_machine.return_value = {"loggingConfiguration": {"destinations": []}} + + result = LogGroupProvider.for_resource( + given_client_provider, "AWS::StepFunctions::StateMachine", "my_state_machine" + ) + + self.assertIsNone(result) diff --git a/tests/unit/lib/observability/cw_logs/test_cw_log_puller.py b/tests/unit/lib/observability/cw_logs/test_cw_log_puller.py index 98f4e6d3de..7e609e1b01 100644 --- a/tests/unit/lib/observability/cw_logs/test_cw_log_puller.py +++ b/tests/unit/lib/observability/cw_logs/test_cw_log_puller.py @@ -100,6 +100,39 @@ def test_must_fetch_logs_with_all_params(self): for event in self.expected_events: self.assertIn(event, call_args) + @patch("samcli.lib.observability.cw_logs.cw_log_puller.LOG") + def test_must_print_resource_not_found_only_once(self, patched_log): + pattern = "foobar" + start = datetime.utcnow() + end = datetime.utcnow() + + expected_params = { + "logGroupName": self.log_group_name, + "interleaved": True, + "startTime": to_timestamp(start), + "endTime": to_timestamp(end), + "filterPattern": pattern, + } + + self.client_stubber.add_client_error( + "filter_log_events", expected_params=expected_params, service_error_code="ResourceNotFoundException" + ) + self.client_stubber.add_client_error( + "filter_log_events", expected_params=expected_params, service_error_code="ResourceNotFoundException" + ) + self.client_stubber.add_response("filter_log_events", self.mock_api_response, expected_params) + + with self.client_stubber: + self.assertFalse(self.fetcher._invalid_log_group) + self.fetcher.load_time_period(start_time=start, end_time=end, filter_pattern=pattern) + self.assertTrue(self.fetcher._invalid_log_group) + self.fetcher.load_time_period(start_time=start, end_time=end, filter_pattern=pattern) + self.assertTrue(self.fetcher._invalid_log_group) + self.fetcher.load_time_period(start_time=start, end_time=end, filter_pattern=pattern) + self.assertFalse(self.fetcher._invalid_log_group) + + patched_log.warning.assert_called_once() + def test_must_paginate_using_next_token(self): """Make three API calls, first two returns a nextToken and last does not.""" token = "token" @@ -320,3 +353,31 @@ def test_without_start_time(self, time_mock): self.assertEqual([], expected_consumer_call_args) self.assertEqual(expected_load_time_period_calls, patched_load_time_period.call_args_list) self.assertEqual(expected_sleep_calls, time_mock.sleep.call_args_list) + + @patch("samcli.lib.observability.cw_logs.cw_log_puller.time") + def test_with_throttling(self, time_mock): + expected_params = { + "logGroupName": self.log_group_name, + "interleaved": True, + "startTime": 0, + "filterPattern": self.filter_pattern, + } + + for _ in range(self.max_retries): + self.client_stubber.add_client_error( + "filter_log_events", expected_params=expected_params, service_error_code="ThrottlingException" + ) + + expected_load_time_period_calls = [call(to_datetime(0), filter_pattern=ANY) for _ in range(self.max_retries)] + + expected_time_calls = [call(2), call(4), call(16)] + + with patch.object( + self.fetcher, "load_time_period", wraps=self.fetcher.load_time_period + ) as patched_load_time_period: + with self.client_stubber: + self.fetcher.tail(filter_pattern=self.filter_pattern) + + self.consumer.consume.assert_not_called() + self.assertEqual(expected_load_time_period_calls, patched_load_time_period.call_args_list) + time_mock.sleep.assert_has_calls(expected_time_calls, any_order=True) diff --git 
a/tests/unit/lib/observability/test_observability_info_puller.py b/tests/unit/lib/observability/test_observability_info_puller.py index 3fbbb9fe34..59f5b08689 100644 --- a/tests/unit/lib/observability/test_observability_info_puller.py +++ b/tests/unit/lib/observability/test_observability_info_puller.py @@ -1,9 +1,12 @@ from unittest import TestCase -from unittest.mock import Mock +from unittest.mock import Mock, patch, call from parameterized import parameterized, param -from samcli.lib.observability.observability_info_puller import ObservabilityEventConsumerDecorator +from samcli.lib.observability.observability_info_puller import ( + ObservabilityEventConsumerDecorator, + ObservabilityCombinedPuller, +) class TestObservabilityEventConsumerDecorator(TestCase): @@ -48,3 +51,56 @@ def test_decorator_with_mappers(self, mappers): actual_consumer.consume.assert_called_with(event) for mapper in mappers: mapper.map.assert_called_with(event) + + +class TestObservabilityCombinedPuller(TestCase): + @patch("samcli.lib.observability.observability_info_puller.AsyncContext") + def test_tail(self, patched_async_context): + mocked_async_context = Mock() + patched_async_context.return_value = mocked_async_context + + mock_puller_1 = Mock() + mock_puller_2 = Mock() + + combined_puller = ObservabilityCombinedPuller([mock_puller_1, mock_puller_2]) + + given_start_time = Mock() + given_filter_pattern = Mock() + combined_puller.tail(given_start_time, given_filter_pattern) + + patched_async_context.assert_called_once() + mocked_async_context.assert_has_calls( + [ + call.add_async_task(mock_puller_1.tail, given_start_time, given_filter_pattern), + call.add_async_task(mock_puller_2.tail, given_start_time, given_filter_pattern), + call.run_async(), + ] + ) + + @patch("samcli.lib.observability.observability_info_puller.AsyncContext") + def test_load_time_period(self, patched_async_context): + mocked_async_context = Mock() + patched_async_context.return_value = mocked_async_context + + mock_puller_1 = Mock() + mock_puller_2 = Mock() + + combined_puller = ObservabilityCombinedPuller([mock_puller_1, mock_puller_2]) + + given_start_time = Mock() + given_end_time = Mock() + given_filter_pattern = Mock() + combined_puller.load_time_period(given_start_time, given_end_time, given_filter_pattern) + + patched_async_context.assert_called_once() + mocked_async_context.assert_has_calls( + [ + call.add_async_task( + mock_puller_1.load_time_period, given_start_time, given_end_time, given_filter_pattern + ), + call.add_async_task( + mock_puller_2.load_time_period, given_start_time, given_end_time, given_filter_pattern + ), + call.run_async(), + ] + ) diff --git a/tests/unit/lib/observability/xray_traces/test_xray_event_mappers.py b/tests/unit/lib/observability/xray_traces/test_xray_event_mappers.py new file mode 100644 index 0000000000..c1b0eccf47 --- /dev/null +++ b/tests/unit/lib/observability/xray_traces/test_xray_event_mappers.py @@ -0,0 +1,196 @@ +import json +import time +import uuid +from datetime import datetime +from unittest import TestCase + +from samcli.lib.observability.xray_traces.xray_event_mappers import ( + XRayTraceConsoleMapper, + XRayTraceJSONMapper, + XRayServiceGraphConsoleMapper, + XRayServiceGraphJSONMapper, +) +from samcli.lib.observability.xray_traces.xray_events import XRayTraceEvent, XRayServiceGraphEvent +from samcli.lib.utils.time import to_utc, utc_to_timestamp, timestamp_to_iso + + +class AbstractXRayTraceMapperTest(TestCase): + def setUp(self): + self.trace_event = XRayTraceEvent( + { + "Id":
str(uuid.uuid4()), + "name": str(uuid.uuid4()), + "start_time": time.time(), + "end_time": time.time(), + "http": {"response": {"status": 200}}, + "subsegments": [ + { + "Id": str(uuid.uuid4()), + "Document": json.dumps( + { + "name": str(uuid.uuid4()), + "start_time": time.time(), + "end_time": time.time(), + "http": {"response": {"status": 200}}, + } + ), + }, + { + "Id": str(uuid.uuid4()), + "Document": json.dumps( + { + "name": str(uuid.uuid4()), + "start_time": time.time(), + "end_time": time.time(), + "http": {"response": {"status": 200}}, + "subsegments": [ + { + "Id": str(uuid.uuid4()), + "name": str(uuid.uuid4()), + "start_time": time.time(), + "end_time": time.time(), + "http": {"response": {"status": 200}}, + } + ], + } + ), + }, + ], + } + ) + + +class TestXRayTraceConsoleMapper(AbstractXRayTraceMapperTest): + def test_console_mapper(self): + console_mapper = XRayTraceConsoleMapper() + mapped_event = console_mapper.map(self.trace_event) + + self.assertTrue(isinstance(mapped_event, XRayTraceEvent)) + + event_timestamp = timestamp_to_iso(self.trace_event.timestamp) + self.assertTrue( + f"XRay Event at ({event_timestamp}) with id ({self.trace_event.id}) and duration ({self.trace_event.duration:.3f}s)" + in mapped_event.message + ) + + self.validate_segments(self.trace_event.segments, mapped_event.message) + + def validate_segments(self, segments, message): + for segment in segments: + + if segment.http_status: + self.assertTrue( + f" - {segment.get_duration():.3f}s - {segment.name} [HTTP: {segment.http_status}]" in message + ) + else: + self.assertTrue(f" - {segment.get_duration():.3f}s - {segment.name}" in message) + self.validate_segments(segment.sub_segments, message) + + +class TestXRayTraceJSONMapper(AbstractXRayTraceMapperTest): + def test_escaped_json_will_be_dict(self): + json_mapper = XRayTraceJSONMapper() + mapped_event = json_mapper.map(self.trace_event) + + segments = mapped_event.event.get("Segments") + self.assertTrue(isinstance(segments, list)) + for segment in segments: + self.assertTrue(isinstance(segment, dict)) + self.assertEqual(mapped_event.event, json.loads(mapped_event.message)) + + +class AbstractXRayServiceGraphMapperTest(TestCase): + def setUp(self): + self.service_graph_event = XRayServiceGraphEvent( + { + "StartTime": datetime(2015, 1, 1), + "EndTime": datetime(2015, 1, 1), + "Services": [ + { + "ReferenceId": 123, + "Name": "string", + "Root": True, + "Type": "string", + "StartTime": datetime(2015, 1, 1), + "EndTime": datetime(2015, 1, 1), + "Edges": [ + { + "ReferenceId": 123, + "StartTime": datetime(2015, 1, 1), + "EndTime": datetime(2015, 1, 1), + }, + ], + "SummaryStatistics": { + "OkCount": 123, + "ErrorStatistics": {"TotalCount": 123}, + "FaultStatistics": {"TotalCount": 123}, + "TotalCount": 123, + "TotalResponseTime": 123.0, + }, + }, + ], + } + ) + + +class TestXRayServiceGraphConsoleMapper(AbstractXRayServiceGraphMapperTest): + def test_console_mapper(self): + console_mapper = XRayServiceGraphConsoleMapper() + mapped_event = console_mapper.map(self.service_graph_event) + + self.assertTrue(isinstance(mapped_event, XRayServiceGraphEvent)) + + self.assertTrue(f"\nNew XRay Service Graph" in mapped_event.message) + self.assertTrue(f"\n Start time: {self.service_graph_event.start_time}" in mapped_event.message) + self.assertTrue(f"\n End time: {self.service_graph_event.end_time}" in mapped_event.message) + + self.validate_services(self.service_graph_event.services, mapped_event.message) + + def validate_services(self, services, message): +
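# Editor's note: assembled from the assertions below, the console mapper's
# output for a single service looks roughly like this (illustrative layout
# only; exact spacing and ordering come from XRayServiceGraphConsoleMapper):
#
#   New XRay Service Graph
#    Start time: 2015-01-01 00:00:00
#    End time: 2015-01-01 00:00:00
#   Reference Id: 123 (Root) string - string
#   Edges: [123]
#   Summary_statistics:
#     total requests: 123
#     ok count(2XX): 123
#     error count(4XX): 123
#     fault count(5XX): 123
#     total response time: 123.0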
for service in services: + self.assertTrue(f"Reference Id: {service.id}" in message) + if service.is_root: + self.assertTrue("(Root)" in message) + else: + self.assertFalse("(Root)" in message) + self.assertTrue(f" {service.type} - {service.name}" in message) + edg_id_str = str(service.edge_ids) + self.assertTrue(f"Edges: {edg_id_str}" in message) + self.validate_summary_statistics(service, message) + + def validate_summary_statistics(self, service, message): + self.assertTrue("Summary_statistics:" in message) + self.assertTrue(f"total requests: {service.total_count}" in message) + self.assertTrue(f"ok count(2XX): {service.ok_count}" in message) + self.assertTrue(f"error count(4XX): {service.error_count}" in message) + self.assertTrue(f"fault count(5XX): {service.fault_count}" in message) + self.assertTrue(f"total response time: {service.response_time}" in message) + + +class TestXRayServiceGraphFileMapper(AbstractXRayServiceGraphMapperTest): + def test_datetime_object_convert_to_iso_string(self): + actual_datetime = datetime(2015, 1, 1) + json_mapper = XRayServiceGraphJSONMapper() + mapped_event = json_mapper.map(self.service_graph_event) + mapped_dict = mapped_event.event + + self.validate_start_and_end_time(actual_datetime, mapped_dict) + services = mapped_dict.get("Services", []) + for service in services: + self.validate_start_and_end_time(actual_datetime, service) + edges = service.get("Edges", []) + for edge in edges: + self.validate_start_and_end_time(actual_datetime, edge) + self.assertEqual(mapped_event.event, json.loads(mapped_event.message)) + + def validate_start_and_end_time(self, datetime_obj, event_dict): + self.validate_datetime_object_to_iso_string("StartTime", datetime_obj, event_dict) + self.validate_datetime_object_to_iso_string("EndTime", datetime_obj, event_dict) + + def validate_datetime_object_to_iso_string(self, datetime_key, datetime_obj, event_dict): + datetime_str = event_dict.get(datetime_key) + self.assertTrue(isinstance(datetime_str, str)) + expected_utc_datetime = to_utc(datetime_obj) + expected_timestamp = utc_to_timestamp(expected_utc_datetime) + expected_iso_str = timestamp_to_iso(expected_timestamp) + self.assertEqual(datetime_str, expected_iso_str) diff --git a/tests/unit/lib/observability/xray_traces/test_xray_event_puller.py b/tests/unit/lib/observability/xray_traces/test_xray_event_puller.py new file mode 100644 index 0000000000..380bfa3e5e --- /dev/null +++ b/tests/unit/lib/observability/xray_traces/test_xray_event_puller.py @@ -0,0 +1,151 @@ +import time +import uuid +from itertools import zip_longest +from unittest import TestCase +from unittest.mock import patch, mock_open, call, Mock, ANY + +from botocore.exceptions import ClientError +from parameterized import parameterized + +from samcli.lib.observability.xray_traces.xray_event_puller import XRayTracePuller + + +class TestXrayTracePuller(TestCase): + def setUp(self): + self.xray_client = Mock() + self.consumer = Mock() + + self.max_retries = 4 + self.xray_trace_puller = XRayTracePuller(self.xray_client, self.consumer, self.max_retries) + + @parameterized.expand([(i,) for i in range(1, 15)]) + @patch("samcli.lib.observability.xray_traces.xray_event_puller.XRayTraceEvent") + def test_load_events(self, size, patched_xray_trace_event): + ids = [str(uuid.uuid4()) for _ in range(size)] + batch_ids = list(zip_longest(*([iter(ids)] * 5))) + + given_paginators = [Mock() for _ in batch_ids] + self.xray_client.get_paginator.side_effect = given_paginators + + given_results = [] + for i in 
range(len(batch_ids)): + given_result = [{"Traces": [Mock() for _ in batch]} for batch in batch_ids] + given_paginators[i].paginate.return_value = given_result + given_results.append(given_result) + + collected_events = [] + + def dynamic_mock(trace): + mocked_trace_event = Mock(trace=trace) + mocked_trace_event.get_latest_event_time.return_value = time.time() + collected_events.append(mocked_trace_event) + return mocked_trace_event + + patched_xray_trace_event.side_effect = dynamic_mock + + self.xray_trace_puller.load_events(ids) + + for i in range(len(batch_ids)): + self.xray_client.get_paginator.assert_called_with("batch_get_traces") + given_paginators[i].assert_has_calls([call.paginate(TraceIds=list(filter(None, batch_ids[i])))]) + self.consumer.assert_has_calls([call.consume(event) for event in collected_events]) + for event in collected_events: + event.get_latest_event_time.assert_called_once() + + def test_load_events_with_no_event_ids(self): + self.xray_trace_puller.load_events([]) + self.consumer.assert_not_called() + + def test_load_events_with_no_event_returned(self): + event_ids = [str(uuid.uuid4())] + + given_paginator = Mock() + given_paginator.paginate.return_value = [{"Traces": []}] + self.xray_client.get_paginator.return_value = given_paginator + + self.xray_trace_puller.load_events(event_ids) + given_paginator.paginate.assert_called_with(TraceIds=event_ids) + self.consumer.assert_not_called() + + def test_load_time_period(self): + given_paginator = Mock() + self.xray_client.get_paginator.return_value = given_paginator + + given_trace_summaries = [{"TraceSummaries": [{"Id": str(uuid.uuid4())} for _ in range(10)]}] + given_paginator.paginate.return_value = given_trace_summaries + + start_time = "start_time" + end_time = "end_time" + with patch.object(self.xray_trace_puller, "load_events") as patched_load_events: + self.xray_trace_puller.load_time_period(start_time, end_time) + given_paginator.paginate.assert_called_with(TimeRangeType="TraceId", StartTime=start_time, EndTime=end_time) + + collected_trace_ids = [item.get("Id") for item in given_trace_summaries[0].get("TraceSummaries", [])] + patched_load_events.assert_called_with(collected_trace_ids) + + @patch("samcli.lib.observability.xray_traces.xray_event_puller.time") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_timestamp") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_datetime") + def test_tail_with_no_data(self, patched_to_datetime, patched_to_timestamp, patched_time): + start_time = Mock() + + with patch.object(self.xray_trace_puller, "load_time_period") as patched_load_time_period: + self.xray_trace_puller.tail(start_time) + + patched_to_timestamp.assert_called_with(start_time) + + patched_to_datetime.assert_has_calls( + [call(self.xray_trace_puller.latest_event_time) for _ in range(self.max_retries)] + ) + + patched_time.sleep.assert_has_calls( + [call(self.xray_trace_puller._poll_interval) for _ in range(self.max_retries)] + ) + + patched_load_time_period.assert_has_calls([call(ANY, ANY) for _ in range(self.max_retries)]) + + @patch("samcli.lib.observability.xray_traces.xray_event_puller.time") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_timestamp") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_datetime") + def test_tail_with_data(self, patched_to_datetime, patched_to_timestamp, patched_time): + start_time = Mock() + given_start_time = 5 + patched_to_timestamp.return_value = 5 + with patch.object(self.xray_trace_puller,
"_had_data") as patched_had_data: + patched_had_data.side_effect = [True, False] + + with patch.object(self.xray_trace_puller, "load_time_period") as patched_load_time_period: + self.xray_trace_puller.tail(start_time) + + patched_to_timestamp.assert_called_with(start_time) + + patched_to_datetime.assert_has_calls( + [ + call(given_start_time), + ], + any_order=True, + ) + patched_to_datetime.assert_has_calls([call(given_start_time + 1) for _ in range(self.max_retries)]) + + patched_time.sleep.assert_has_calls( + [call(self.xray_trace_puller._poll_interval) for _ in range(self.max_retries + 1)] + ) + + patched_load_time_period.assert_has_calls([call(ANY, ANY) for _ in range(self.max_retries + 1)]) + + @patch("samcli.lib.observability.xray_traces.xray_event_puller.time") + def test_with_throttling(self, patched_time): + with patch.object( + self.xray_trace_puller, "load_time_period", wraps=self.xray_trace_puller.load_time_period + ) as patched_load_time_period: + patched_load_time_period.side_effect = [ + ClientError({"Error": {"Code": "ThrottlingException"}}, "operation") for _ in range(self.max_retries) + ] + + self.xray_trace_puller.tail() + + patched_load_time_period.assert_has_calls([call(ANY, ANY) for _ in range(self.max_retries)]) + + patched_time.sleep.assert_has_calls([call(2), call(4), call(16), call(256)]) + + self.assertEqual(self.xray_trace_puller._poll_interval, 256) diff --git a/tests/unit/lib/observability/xray_traces/test_xray_events.py b/tests/unit/lib/observability/xray_traces/test_xray_events.py new file mode 100644 index 0000000000..c1db606dd1 --- /dev/null +++ b/tests/unit/lib/observability/xray_traces/test_xray_events.py @@ -0,0 +1,198 @@ +import json +import time +import uuid +from unittest import TestCase + +from samcli.lib.observability.xray_traces.xray_events import XRayTraceSegment, XRayTraceEvent, XRayServiceGraphEvent +from samcli.lib.utils.hash import str_checksum + +LATEST_EVENT_TIME = 9621490723 + + +class AbstractXRayEventTextTest(TestCase): + def validate_segment(self, segment, event_dict): + self.assertEqual(segment.id, event_dict.get("Id")) + self.assertEqual(segment.name, event_dict.get("name")) + self.assertEqual(segment.start_time, event_dict.get("start_time")) + self.assertEqual(segment.end_time, event_dict.get("end_time")) + self.assertEqual(segment.http_status, event_dict.get("http", {}).get("response", {}).get("status", None)) + event_subsegments = event_dict.get("subsegments", []) + self.assertEqual(len(segment.sub_segments), len(event_subsegments)) + + for event_subsegment in event_subsegments: + subsegment = next(x for x in segment.sub_segments if x.id == event_subsegment.get("Id")) + self.validate_segment(subsegment, event_subsegment) + + +class TestXRayTraceEvent(AbstractXRayEventTextTest): + def setUp(self): + self.first_segment_date = time.time() - 1000 + self.segment_1 = { + "Id": str(uuid.uuid4()), + "name": f"Second {str(uuid.uuid4())}", + "start_time": time.time(), + "end_time": time.time(), + "http": {"response": {"status": 200}}, + } + self.segment_2 = { + "Id": str(uuid.uuid4()), + "name": f"First {str(uuid.uuid4())}", + "start_time": self.first_segment_date, + "end_time": LATEST_EVENT_TIME, + "http": {"response": {"status": 200}}, + } + self.event_dict = { + "Id": str(uuid.uuid4()), + "Duration": 400, + "Segments": [ + {"Id": self.segment_1.get("Id"), "Document": json.dumps(self.segment_1)}, + {"Id": self.segment_2.get("Id"), "Document": json.dumps(self.segment_2)}, + ], + } + + def test_xray_trace_event(self): + xray_trace_event = 
XRayTraceEvent(self.event_dict) + self.assertEqual(xray_trace_event.id, self.event_dict.get("Id")) + self.assertEqual(xray_trace_event.duration, self.event_dict.get("Duration")) + segments = self.event_dict.get("Segments", []) + self.assertEqual(len(xray_trace_event.segments), len(segments)) + + for segment in segments: + subsegment = next(x for x in xray_trace_event.segments if x.id == segment.get("Id")) + self.validate_segment(subsegment, json.loads(segment.get("Document"))) + + def test_latest_event_time(self): + xray_trace_event = XRayTraceEvent(self.event_dict) + self.assertEqual(xray_trace_event.get_latest_event_time(), LATEST_EVENT_TIME) + + def test_first_event_time(self): + xray_trace_event = XRayTraceEvent(self.event_dict) + self.assertEqual(xray_trace_event.timestamp, self.first_segment_date) + + def test_segment_order(self): + xray_trace_event = XRayTraceEvent(self.event_dict) + + self.assertEqual(len(xray_trace_event.segments), 2) + self.assertIn("First", xray_trace_event.segments[0].name) + self.assertIn("Second", xray_trace_event.segments[1].name) + + +class TestXRayTraceSegment(AbstractXRayEventTextTest): + def setUp(self): + self.event_dict = { + "Id": uuid.uuid4(), + "name": uuid.uuid4(), + "start_time": time.time(), + "end_time": time.time(), + "http": {"response": {"status": 200}}, + "subsegments": [ + { + "Id": uuid.uuid4(), + "name": uuid.uuid4(), + "start_time": time.time(), + "end_time": time.time(), + "http": {"response": {"status": 200}}, + }, + { + "Id": uuid.uuid4(), + "name": uuid.uuid4(), + "start_time": time.time(), + "end_time": time.time(), + "http": {"response": {"status": 200}}, + "subsegments": [ + { + "Id": uuid.uuid4(), + "name": uuid.uuid4(), + "start_time": time.time(), + "end_time": LATEST_EVENT_TIME, + "http": {"response": {"status": 200}}, + } + ], + }, + ], + } + + def test_xray_trace_segment_duration(self): + xray_trace_segment = XRayTraceSegment(self.event_dict) + self.assertEqual( + xray_trace_segment.get_duration(), self.event_dict.get("end_time") - self.event_dict.get("start_time") + ) + + def test_xray_latest_event_time(self): + xray_trace_segment = XRayTraceSegment(self.event_dict) + self.assertEqual(xray_trace_segment.get_latest_event_time(), LATEST_EVENT_TIME) + + def test_xray_trace_segment(self): + xray_trace_segment = XRayTraceSegment(self.event_dict) + self.validate_segment(xray_trace_segment, self.event_dict) + + +class AbstractXRayServiceTest(TestCase): + def validate_service(self, service, service_dict): + self.assertEqual(service.id, service_dict.get("ReferenceId")) + self.assertEqual(service.name, service_dict.get("Name")) + self.assertEqual(service.is_root, service_dict.get("Root")) + self.assertEqual(service.type, service_dict.get("Type")) + self.assertEqual(service.name, service_dict.get("Name")) + edges = service_dict.get("Edges") + self.assertEqual(len(service.edge_ids), len(edges)) + summary_statistics = service_dict.get("SummaryStatistics") + self.assertEqual(service.ok_count, summary_statistics.get("OkCount")) + self.assertEqual(service.error_count, summary_statistics.get("ErrorStatistics").get("TotalCount")) + self.assertEqual(service.fault_count, summary_statistics.get("FaultStatistics").get("TotalCount")) + self.assertEqual(service.total_count, summary_statistics.get("TotalCount")) + self.assertEqual(service.response_time, summary_statistics.get("TotalResponseTime")) + + +class TestXRayServiceGraphEvent(AbstractXRayServiceTest): + def setUp(self): + self.service_1 = { + "ReferenceId": 0, + "Name": "test1", + "Root": 
True, + "Type": "Lambda", + "Edges": [ + { + "ReferenceId": 1, + }, + ], + "SummaryStatistics": { + "OkCount": 1, + "ErrorStatistics": {"TotalCount": 2}, + "FaultStatistics": {"TotalCount": 3}, + "TotalCount": 6, + "TotalResponseTime": 123.0, + }, + } + + self.service_2 = { + "ReferenceId": 1, + "Name": "test2", + "Root": False, + "Type": "Api", + "Edges": [], + "SummaryStatistics": { + "OkCount": 2, + "ErrorStatistics": {"TotalCount": 3}, + "FaultStatistics": {"TotalCount": 3}, + "TotalCount": 8, + "TotalResponseTime": 200.0, + }, + } + self.event_dict = { + "Services": [self.service_1, self.service_2], + } + + def test_xray_service_graph_event(self): + xray_service_graph_event = XRayServiceGraphEvent(self.event_dict) + services_array = self.event_dict.get("Services", []) + services = xray_service_graph_event.services + self.assertEqual(len(services), len(services_array)) + + for service, service_dict in zip(services, services_array): + self.validate_service(service, service_dict) + + def test__xray_service_graph_event_get_hash(self): + xray_service_graph_event = XRayServiceGraphEvent(self.event_dict) + expected_hash = str_checksum(str(self.event_dict["Services"])) + self.assertEqual(expected_hash, xray_service_graph_event.get_hash()) diff --git a/tests/unit/lib/observability/xray_traces/test_xray_service_grpah_event_puller.py b/tests/unit/lib/observability/xray_traces/test_xray_service_grpah_event_puller.py new file mode 100644 index 0000000000..8bd604a6ee --- /dev/null +++ b/tests/unit/lib/observability/xray_traces/test_xray_service_grpah_event_puller.py @@ -0,0 +1,146 @@ +import time +import uuid +from itertools import zip_longest +from unittest import TestCase +from unittest.mock import patch, mock_open, call, Mock, ANY + +from botocore.exceptions import ClientError +from parameterized import parameterized + +from samcli.lib.observability.xray_traces.xray_event_puller import XRayTracePuller +from samcli.lib.observability.xray_traces.xray_service_graph_event_puller import XRayServiceGraphPuller + + +class TestXRayServiceGraphPuller(TestCase): + def setUp(self): + self.xray_client = Mock() + self.consumer = Mock() + + self.max_retries = 4 + self.xray_service_graph_puller = XRayServiceGraphPuller(self.xray_client, self.consumer, self.max_retries) + + @patch("samcli.lib.observability.xray_traces.xray_service_graph_event_puller.XRayServiceGraphEvent") + @patch("samcli.lib.observability.xray_traces.xray_service_graph_event_puller.to_utc") + @patch("samcli.lib.observability.xray_traces.xray_service_graph_event_puller.utc_to_timestamp") + def test_load_time_period(self, patched_utc_to_timestamp, patched_to_utc, patched_xray_service_graph_event): + given_paginator = Mock() + self.xray_client.get_paginator.return_value = given_paginator + + given_services = [{"EndTime": "endtime", "Services": [{"id": 1}]}] + given_paginator.paginate.return_value = given_services + + start_time = "start_time" + end_time = "end_time" + patched_utc_to_timestamp.return_value = 1 + self.xray_service_graph_puller.load_time_period(start_time, end_time) + patched_utc_to_timestamp.assert_called() + patched_to_utc.assert_called() + given_paginator.paginate.assert_called_with(StartTime=start_time, EndTime=end_time) + patched_xray_service_graph_event.assrt_called_with({"EndTime": "endtime", "Services": [{"id": 1}]}) + self.consumer.consume.assert_called() + + @patch("samcli.lib.observability.xray_traces.xray_service_graph_event_puller.XRayServiceGraphEvent") + 
@patch("samcli.lib.observability.xray_traces.xray_service_graph_event_puller.to_utc") + @patch("samcli.lib.observability.xray_traces.xray_service_graph_event_puller.utc_to_timestamp") + def test_load_time_period_with_same_event_twice( + self, patched_utc_to_timestamp, patched_to_utc, patched_xray_service_graph_event + ): + given_paginator = Mock() + self.xray_client.get_paginator.return_value = given_paginator + + given_services = [{"EndTime": "endtime", "Services": [{"id": 1}]}] + given_paginator.paginate.return_value = given_services + + start_time = "start_time" + end_time = "end_time" + patched_utc_to_timestamp.return_value = 1 + self.xray_service_graph_puller.load_time_period(start_time, end_time) + # called with the same event twice + self.xray_service_graph_puller.load_time_period(start_time, end_time) + patched_utc_to_timestamp.assert_called() + patched_to_utc.assert_called() + given_paginator.paginate.assert_called_with(StartTime=start_time, EndTime=end_time) + patched_xray_service_graph_event.assrt_called_with({"EndTime": "endtime", "Services": [{"id": 1}]}) + # consumer should only get called once + self.consumer.consume.assert_called_once() + + @patch("samcli.lib.observability.xray_traces.xray_service_graph_event_puller.XRayServiceGraphEvent") + def test_load_time_period_with_no_service(self, patched_xray_service_graph_event): + given_paginator = Mock() + self.xray_client.get_paginator.return_value = given_paginator + + given_services = [{"EndTime": "endtime", "Services": []}] + given_paginator.paginate.return_value = given_services + + start_time = "start_time" + end_time = "end_time" + self.xray_service_graph_puller.load_time_period(start_time, end_time) + patched_xray_service_graph_event.assert_not_called() + self.consumer.consume.assert_not_called() + + @patch("samcli.lib.observability.xray_traces.xray_event_puller.time") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_timestamp") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_datetime") + def test_tail_with_no_data(self, patched_to_datetime, patched_to_timestamp, patched_time): + start_time = Mock() + + with patch.object(self.xray_service_graph_puller, "load_time_period") as patched_load_time_period: + self.xray_service_graph_puller.tail(start_time) + + patched_to_timestamp.assert_called_with(start_time) + + patched_to_datetime.assert_has_calls( + [call(self.xray_service_graph_puller.latest_event_time) for _ in range(self.max_retries)] + ) + + patched_time.sleep.assert_has_calls( + [call(self.xray_service_graph_puller._poll_interval) for _ in range(self.max_retries)] + ) + + patched_load_time_period.assert_has_calls([call(ANY, ANY) for _ in range(self.max_retries)]) + + @patch("samcli.lib.observability.xray_traces.xray_event_puller.time") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_timestamp") + @patch("samcli.lib.observability.xray_traces.xray_event_puller.to_datetime") + def test_tail_with_with_data(self, patched_to_datetime, patched_to_timestamp, patched_time): + start_time = Mock() + given_start_time = 5 + patched_to_timestamp.return_value = 5 + with patch.object(self.xray_service_graph_puller, "_had_data") as patched_had_data: + patched_had_data.side_effect = [True, False] + + with patch.object(self.xray_service_graph_puller, "load_time_period") as patched_load_time_period: + self.xray_service_graph_puller.tail(start_time) + + patched_to_timestamp.assert_called_with(start_time) + + patched_to_datetime.assert_has_calls( + [ + call(given_start_time), + ], 
+                    any_order=True,
+                )
+                patched_to_datetime.assert_has_calls([call(given_start_time + 1) for _ in range(self.max_retries)])
+
+                patched_time.sleep.assert_has_calls(
+                    [call(self.xray_service_graph_puller._poll_interval) for _ in range(self.max_retries + 1)]
+                )
+
+                patched_load_time_period.assert_has_calls([call(ANY, ANY) for _ in range(self.max_retries + 1)])
+
+    @patch("samcli.lib.observability.xray_traces.xray_event_puller.time")
+    def test_with_throttling(self, patched_time):
+        with patch.object(
+            self.xray_service_graph_puller, "load_time_period", wraps=self.xray_service_graph_puller.load_time_period
+        ) as patched_load_time_period:
+            patched_load_time_period.side_effect = [
+                ClientError({"Error": {"Code": "ThrottlingException"}}, "operation") for _ in range(self.max_retries)
+            ]
+
+            self.xray_service_graph_puller.tail()
+
+            patched_load_time_period.assert_has_calls([call(ANY, ANY) for _ in range(self.max_retries)])
+
+            patched_time.sleep.assert_has_calls([call(2), call(4), call(16), call(256)])
+
+            self.assertEqual(self.xray_service_graph_puller._poll_interval, 256)
diff --git a/tests/unit/lib/sync/__init__.py b/tests/unit/lib/sync/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/unit/lib/sync/flows/__init__.py b/tests/unit/lib/sync/flows/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/unit/lib/sync/flows/test_alias_version_sync_flow.py b/tests/unit/lib/sync/flows/test_alias_version_sync_flow.py
new file mode 100644
index 0000000000..e3db396e8d
--- /dev/null
+++ b/tests/unit/lib/sync/flows/test_alias_version_sync_flow.py
@@ -0,0 +1,50 @@
+import os
+import hashlib
+
+from samcli.lib.sync.sync_flow import SyncFlow
+from unittest import TestCase
+from unittest.mock import ANY, MagicMock, call, mock_open, patch
+
+from samcli.lib.sync.flows.alias_version_sync_flow import AliasVersionSyncFlow
+
+
+class TestAliasVersionSyncFlow(TestCase):
+    def create_sync_flow(self):
+        sync_flow = AliasVersionSyncFlow(
+            "Function1",
+            "Alias1",
+            build_context=MagicMock(),
+            deploy_context=MagicMock(),
+            physical_id_mapping={},
+            stacks=[MagicMock()],
+        )
+        sync_flow._get_resource_api_calls = MagicMock()
+        return sync_flow
+
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_set_up(self, session_mock):
+        sync_flow = self.create_sync_flow()
+        sync_flow.set_up()
+        session_mock.return_value.client.assert_any_call("lambda")
+
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_sync_direct(self, session_mock):
+        sync_flow = self.create_sync_flow()
+
+        sync_flow.get_physical_id = MagicMock()
+        sync_flow.get_physical_id.return_value = "PhysicalFunction1"
+
+        sync_flow.set_up()
+
+        sync_flow._lambda_client.publish_version.return_value = {"Version": "2"}
+
+        sync_flow.sync()
+
+        sync_flow._lambda_client.publish_version.assert_called_once_with(FunctionName="PhysicalFunction1")
+        sync_flow._lambda_client.update_alias.assert_called_once_with(
+            FunctionName="PhysicalFunction1", Name="Alias1", FunctionVersion="2"
+        )
+
+    def test_equality_keys(self):
+        sync_flow = self.create_sync_flow()
+        self.assertEqual(sync_flow._equality_keys(), ("Function1", "Alias1"))
diff --git a/tests/unit/lib/sync/flows/test_function_sync_flow.py b/tests/unit/lib/sync/flows/test_function_sync_flow.py
new file mode 100644
index 0000000000..c2d3e12f6c
--- /dev/null
+++ b/tests/unit/lib/sync/flows/test_function_sync_flow.py
@@ -0,0 +1,50 @@
+from samcli.lib.providers.provider import ResourceIdentifier
+from unittest import TestCase
+from unittest.mock import ANY, MagicMock, call, patch
+
+from samcli.lib.sync.sync_flow import SyncFlow, ResourceAPICall
+from samcli.lib.sync.flows.function_sync_flow import FunctionSyncFlow
+from samcli.lib.utils.lock_distributor import LockChain
+
+
+class TestFunctionSyncFlow(TestCase):
+    def create_function_sync_flow(self):
+        sync_flow = FunctionSyncFlow(
+            "Function1",
+            build_context=MagicMock(),
+            deploy_context=MagicMock(),
+            physical_id_mapping={},
+            stacks=[MagicMock()],
+        )
+        sync_flow.gather_resources = MagicMock()
+        sync_flow.compare_remote = MagicMock()
+        sync_flow.sync = MagicMock()
+        sync_flow._get_resource_api_calls = MagicMock()
+        return sync_flow
+
+    @patch("samcli.lib.sync.sync_flow.Session")
+    @patch.multiple(FunctionSyncFlow, __abstractmethods__=set())
+    def test_sets_up_clients(self, session_mock):
+        sync_flow = self.create_function_sync_flow()
+        sync_flow.set_up()
+        session_mock.return_value.client.assert_called_once_with("lambda")
+        sync_flow._lambda_client.get_waiter.assert_called_once_with("function_updated")
+
+    @patch("samcli.lib.sync.flows.function_sync_flow.AliasVersionSyncFlow")
+    @patch("samcli.lib.sync.sync_flow.Session")
+    @patch.multiple(FunctionSyncFlow, __abstractmethods__=set())
+    def test_gather_dependencies(self, session_mock, alias_version_mock):
+        sync_flow = self.create_function_sync_flow()
+        sync_flow.get_physical_id = lambda x: "PhysicalFunction1"
+        sync_flow._get_resource = lambda x: MagicMock()
+
+        sync_flow.set_up()
+        result = sync_flow.gather_dependencies()
+
+        sync_flow._lambda_waiter.wait.assert_called_once_with(FunctionName="PhysicalFunction1", WaiterConfig=ANY)
+        self.assertEqual(result, [alias_version_mock.return_value])
+
+    @patch.multiple(FunctionSyncFlow, __abstractmethods__=set())
+    def test_equality_keys(self):
+        sync_flow = self.create_function_sync_flow()
+        self.assertEqual(sync_flow._equality_keys(), "Function1")
diff --git a/tests/unit/lib/sync/flows/test_http_api_sync_flow.py b/tests/unit/lib/sync/flows/test_http_api_sync_flow.py
new file mode 100644
index 0000000000..cc3d0a3f66
--- /dev/null
+++ b/tests/unit/lib/sync/flows/test_http_api_sync_flow.py
@@ -0,0 +1,84 @@
+from unittest import TestCase
+from unittest.mock import ANY, MagicMock, mock_open, patch
+
+from samcli.lib.sync.flows.http_api_sync_flow import HttpApiSyncFlow
+from samcli.lib.providers.exceptions import MissingLocalDefinition
+
+
+class TestHttpApiSyncFlow(TestCase):
+    def create_sync_flow(self):
+        sync_flow = HttpApiSyncFlow(
+            "Api1",
+            build_context=MagicMock(),
+            deploy_context=MagicMock(),
+            physical_id_mapping={},
+            stacks=[MagicMock()],
+        )
+        sync_flow._get_resource_api_calls = MagicMock()
+        return sync_flow
+
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_set_up(self, session_mock):
+        sync_flow = self.create_sync_flow()
+        sync_flow.set_up()
+        session_mock.return_value.client.assert_any_call("apigatewayv2")
+
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_sync_direct(self, session_mock):
+        sync_flow = self.create_sync_flow()
+
+        sync_flow.get_physical_id = MagicMock()
+        sync_flow.get_physical_id.return_value = "PhysicalApi1"
+
+        sync_flow._get_definition_file = MagicMock()
+        sync_flow._get_definition_file.return_value = "file.yaml"
+
+        sync_flow.set_up()
+        with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
+            sync_flow.gather_resources()
+
+        sync_flow._api_client.reimport_api.return_value = {"Response": "success"}
+
+        sync_flow.sync()
+
+        sync_flow._api_client.reimport_api.assert_called_once_with(
+            ApiId="PhysicalApi1", Body='{"key": "value"}'.encode("utf-8")
+        )
+
+    @patch("samcli.lib.sync.flows.generic_api_sync_flow.get_resource_by_id")
+    def test_get_definition_file(self, get_resource_mock):
+        sync_flow = self.create_sync_flow()
+
+        get_resource_mock.return_value = {"Properties": {"DefinitionUri": "test_uri"}}
+        result_uri = sync_flow._get_definition_file("test")
+
+        self.assertEqual(result_uri, "test_uri")
+
+        get_resource_mock.return_value = {"Properties": {}}
+        result_uri = sync_flow._get_definition_file("test")
+
+        self.assertEqual(result_uri, None)
+
+    def test_process_definition_file(self):
+        sync_flow = self.create_sync_flow()
+        sync_flow._definition_uri = "path"
+        with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
+            data = sync_flow._process_definition_file()
+            self.assertEqual(data, '{"key": "value"}'.encode("utf-8"))
+
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_failed_gather_resources(self, session_mock):
+        sync_flow = self.create_sync_flow()
+
+        sync_flow.get_physical_id = MagicMock()
+        sync_flow.get_physical_id.return_value = "PhysicalApi1"
+
+        sync_flow._get_definition_file = MagicMock()
+        sync_flow._get_definition_file.return_value = "file.yaml"
+
+        sync_flow.set_up()
+        sync_flow._definition_uri = None
+
+        with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
+            with self.assertRaises(MissingLocalDefinition):
+                sync_flow.sync()
diff --git a/tests/unit/lib/sync/flows/test_image_function_sync_flow.py b/tests/unit/lib/sync/flows/test_image_function_sync_flow.py
new file mode 100644
index 0000000000..31c8951050
--- /dev/null
+++ b/tests/unit/lib/sync/flows/test_image_function_sync_flow.py
@@ -0,0 +1,109 @@
+from samcli.lib.sync.sync_flow import SyncFlow
+from unittest import TestCase
+from unittest.mock import ANY, MagicMock, call, patch
+
+from samcli.lib.sync.flows.image_function_sync_flow import ImageFunctionSyncFlow
+
+
+class TestImageFunctionSyncFlow(TestCase):
+    def create_function_sync_flow(self):
+        sync_flow = ImageFunctionSyncFlow(
+            "Function1",
+            build_context=MagicMock(),
+            deploy_context=MagicMock(),
+            physical_id_mapping={},
+            stacks=[MagicMock()],
+            docker_client=MagicMock(),
+        )
+        sync_flow._get_resource_api_calls = MagicMock()
+        return sync_flow
+
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_set_up(self, session_mock):
+        sync_flow = self.create_function_sync_flow()
+        sync_flow.set_up()
+        session_mock.return_value.client.assert_any_call("lambda")
+        session_mock.return_value.client.assert_any_call("ecr")
+
+    @patch("samcli.lib.sync.flows.image_function_sync_flow.ApplicationBuilder")
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_gather_resources(self, session_mock, builder_mock):
+        get_mock = MagicMock()
+        get_mock.return_value = "ImageName1"
+        builder_mock.return_value.build.return_value.get = get_mock
+        sync_flow = self.create_function_sync_flow()
+
+        sync_flow.set_up()
+        sync_flow.gather_resources()
+
+        get_mock.assert_called_once_with("Function1")
+        self.assertEqual(sync_flow._image_name, "ImageName1")
+
+    @patch("samcli.lib.sync.flows.image_function_sync_flow.ECRUploader")
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_sync_context_image_repo(self, session_mock, uploader_mock):
+        sync_flow = self.create_function_sync_flow()
+        sync_flow._image_name = "ImageName1"
+
+        uploader_mock.return_value.upload.return_value = "image_uri"
+
+        sync_flow.get_physical_id = MagicMock()
+        sync_flow.get_physical_id.return_value = "PhysicalFunction1"
+        sync_flow._deploy_context.image_repository = "repo_uri"
+
+        sync_flow.set_up()
+        sync_flow.sync()
+
+        uploader_mock.return_value.upload.assert_called_once_with("ImageName1", "Function1")
+        uploader_mock.assert_called_once_with(sync_flow._docker_client, sync_flow._ecr_client, "repo_uri", None)
+        sync_flow._lambda_client.update_function_code.assert_called_once_with(
+            FunctionName="PhysicalFunction1", ImageUri="image_uri"
+        )
+
+    @patch("samcli.lib.sync.flows.image_function_sync_flow.ECRUploader")
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_sync_context_image_repos(self, session_mock, uploader_mock):
+        sync_flow = self.create_function_sync_flow()
+        sync_flow._image_name = "ImageName1"
+
+        uploader_mock.return_value.upload.return_value = "image_uri"
+
+        sync_flow.get_physical_id = MagicMock()
+        sync_flow.get_physical_id.return_value = "PhysicalFunction1"
+        sync_flow._deploy_context.image_repository = ""
+        sync_flow._deploy_context.image_repositories = {"Function1": "repo_uri"}
+
+        sync_flow.set_up()
+        sync_flow.sync()
+
+        uploader_mock.return_value.upload.assert_called_once_with("ImageName1", "Function1")
+        uploader_mock.assert_called_once_with(sync_flow._docker_client, sync_flow._ecr_client, "repo_uri", None)
+        sync_flow._lambda_client.update_function_code.assert_called_once_with(
+            FunctionName="PhysicalFunction1", ImageUri="image_uri"
+        )
+
+    @patch("samcli.lib.sync.flows.image_function_sync_flow.ECRUploader")
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_sync_remote_image_repo(self, session_mock, uploader_mock):
+        sync_flow = self.create_function_sync_flow()
+        sync_flow._image_name = "ImageName1"
+
+        uploader_mock.return_value.upload.return_value = "image_uri"
+
+        sync_flow.get_physical_id = MagicMock()
+        sync_flow.get_physical_id.return_value = "PhysicalFunction1"
+        sync_flow._deploy_context.image_repository = ""
+        sync_flow._deploy_context.image_repositories = {}
+
+        sync_flow.set_up()
+
+        sync_flow._lambda_client.get_function = MagicMock()
+        sync_flow._lambda_client.get_function.return_value = {"Code": {"ImageUri": "repo_uri:tag"}}
+
+        sync_flow.sync()
+
+        uploader_mock.return_value.upload.assert_called_once_with("ImageName1", "Function1")
+        uploader_mock.assert_called_once_with(sync_flow._docker_client, sync_flow._ecr_client, "repo_uri", None)
+        sync_flow._lambda_client.update_function_code.assert_called_once_with(
+            FunctionName="PhysicalFunction1", ImageUri="image_uri"
+        )
diff --git a/tests/unit/lib/sync/flows/test_layer_sync_flow.py b/tests/unit/lib/sync/flows/test_layer_sync_flow.py
new file mode 100644
index 0000000000..9df81f3fe0
--- /dev/null
+++ b/tests/unit/lib/sync/flows/test_layer_sync_flow.py
@@ -0,0 +1,428 @@
+import base64
+import hashlib
+from unittest import TestCase
+from unittest.mock import MagicMock, Mock, patch, call, ANY, mock_open, PropertyMock
+
+from parameterized import parameterized
+
+from samcli.lib.sync.exceptions import MissingPhysicalResourceError, NoLayerVersionsFoundError
+from samcli.lib.sync.flows.layer_sync_flow import LayerSyncFlow, FunctionLayerReferenceSync
+from samcli.lib.sync.sync_flow import SyncFlow
+
+
+class TestLayerSyncFlow(TestCase):
+    def setUp(self):
+        self.layer_identifier = "LayerA"
+        self.build_context_mock = Mock()
+        self.deploy_context_mock = Mock()
+
+        self.layer_sync_flow = LayerSyncFlow(
+            self.layer_identifier,
+            self.build_context_mock,
+            self.deploy_context_mock,
+            {self.layer_identifier: "layer_version_arn"},
+            [],
+        )
+
+    def test_setup(self):
+        with patch.object(self.layer_sync_flow, "_session") as patched_session:
+            with patch.object(SyncFlow, "set_up") as patched_super_setup:
+                self.layer_sync_flow.set_up()
+
+                patched_super_setup.assert_called_once()
+                patched_session.assert_has_calls(
+                    [
+                        call.client("s3"),
+                        call.client("lambda"),
+                    ]
+                )
+
+    @patch("samcli.lib.sync.flows.layer_sync_flow.get_resource_by_id")
+    def test_setup_with_serverless_layer(self, get_resource_by_id_mock):
+        given_layer_name_with_hashes = f"{self.layer_identifier}abcdefghij"
+        self.layer_sync_flow._physical_id_mapping = {given_layer_name_with_hashes: "layer_version_arn"}
+        get_resource_by_id_mock.return_value = False
+        with patch.object(self.layer_sync_flow, "_session") as patched_session:
+            with patch.object(SyncFlow, "set_up") as patched_super_setup:
+                self.layer_sync_flow.set_up()
+
+                patched_super_setup.assert_called_once()
+                patched_session.assert_has_calls(
+                    [
+                        call.client("s3"),
+                        call.client("lambda"),
+                    ]
+                )
+
+                self.assertEqual(self.layer_sync_flow._layer_physical_name, "layer_version_arn")
+
+    def test_setup_with_unknown_layer(self):
+        given_layer_name_with_hashes = "SomeOtherLayerabcdefghij"
+        self.layer_sync_flow._physical_id_mapping = {given_layer_name_with_hashes: "layer_version_arn"}
+        with patch.object(self.layer_sync_flow, "_session") as _:
+            with patch.object(SyncFlow, "set_up") as _:
+                with self.assertRaises(MissingPhysicalResourceError):
+                    self.layer_sync_flow.set_up()
+
+    @patch("samcli.lib.sync.flows.layer_sync_flow.ApplicationBuilder")
+    @patch("samcli.lib.sync.flows.layer_sync_flow.tempfile")
+    @patch("samcli.lib.sync.flows.layer_sync_flow.make_zip")
+    @patch("samcli.lib.sync.flows.layer_sync_flow.file_checksum")
+    @patch("samcli.lib.sync.flows.layer_sync_flow.os")
+    def test_setup_gather_resources(
+        self, patched_os, patched_file_checksum, patched_make_zip, patched_tempfile, patched_app_builder
+    ):
+        given_collect_build_resources = Mock()
+        self.build_context_mock.collect_build_resources.return_value = given_collect_build_resources
+
+        given_app_builder = Mock()
+        given_artifact_folder = Mock()
+        given_app_builder.build().get.return_value = given_artifact_folder
+        patched_app_builder.return_value = given_app_builder
+
+        given_zip_location = Mock()
+        patched_make_zip.return_value = given_zip_location
+
+        given_file_checksum = Mock()
+        patched_file_checksum.return_value = given_file_checksum
+
+        self.layer_sync_flow._get_lock_chain = MagicMock()
+
+        self.layer_sync_flow.gather_resources()
+
+        self.build_context_mock.collect_build_resources.assert_called_with(self.layer_identifier)
+
+        patched_app_builder.assert_called_with(
+            given_collect_build_resources,
+            self.build_context_mock.build_dir,
+            self.build_context_mock.base_dir,
+            self.build_context_mock.cache_dir,
+            cached=True,
+            is_building_specific_resource=True,
+            manifest_path_override=self.build_context_mock.manifest_path_override,
+            container_manager=self.build_context_mock.container_manager,
+            mode=self.build_context_mock.mode,
+        )
+
+        patched_tempfile.gettempdir.assert_called_once()
+        patched_os.path.join.assert_called_with(ANY, ANY)
+        patched_make_zip.assert_called_with(ANY, self.layer_sync_flow._artifact_folder)
+
+        patched_file_checksum.assert_called_with(ANY, ANY)
+
+        self.assertEqual(self.layer_sync_flow._artifact_folder, given_artifact_folder)
+        self.assertEqual(self.layer_sync_flow._zip_file, given_zip_location)
+        self.assertEqual(self.layer_sync_flow._local_sha, given_file_checksum)
+
+        self.layer_sync_flow._get_lock_chain.assert_called_once()
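+        # Building the layer must happen inside the lock chain context manager,
+        # so both __enter__ and __exit__ are expected to fire exactly once.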
+        self.layer_sync_flow._get_lock_chain.return_value.__enter__.assert_called_once()
+        self.layer_sync_flow._get_lock_chain.return_value.__exit__.assert_called_once()
+
+    def test_compare_remote(self):
+        given_lambda_client = Mock()
+        self.layer_sync_flow._lambda_client = given_lambda_client
+
+        given_sha256 = base64.b64encode(b"checksum")
+        given_layer_info = {"Content": {"CodeSha256": given_sha256}}
+        given_lambda_client.get_layer_version.return_value = given_layer_info
+
+        self.layer_sync_flow._local_sha = base64.b64decode(given_sha256).hex()
+
+        with patch.object(self.layer_sync_flow, "_get_latest_layer_version") as patched_get_latest_layer_version:
+            given_layer_name = Mock()
+            given_latest_layer_version = Mock()
+            self.layer_sync_flow._layer_physical_name = given_layer_name
+            patched_get_latest_layer_version.return_value = given_latest_layer_version
+
+            compare_result = self.layer_sync_flow.compare_remote()
+
+            self.assertTrue(compare_result)
+
+    def test_sync(self):
+        with patch.object(self.layer_sync_flow, "_publish_new_layer_version") as patched_publish_new_layer_version:
+            with patch.object(self.layer_sync_flow, "_delete_old_layer_version") as patched_delete_old_layer_version:
+                given_layer_version = Mock()
+                patched_publish_new_layer_version.return_value = given_layer_version
+
+                self.layer_sync_flow.sync()
+                self.assertEqual(self.layer_sync_flow._new_layer_version, given_layer_version)
+
+                patched_publish_new_layer_version.assert_called_once()
+                patched_delete_old_layer_version.assert_called_once()
+
+    def test_publish_new_layer_version(self):
+        given_layer_name = Mock()
+
+        given_lambda_client = Mock()
+        self.layer_sync_flow._lambda_client = given_lambda_client
+
+        given_zip_file = Mock()
+        self.layer_sync_flow._zip_file = given_zip_file
+
+        self.layer_sync_flow._layer_physical_name = given_layer_name
+
+        with patch.object(self.layer_sync_flow, "_get_resource") as patched_get_resource:
+            with patch("builtins.open", mock_open(read_data="data")) as mock_file:
+                given_publish_layer_result = {"Version": 24}
+                given_lambda_client.publish_layer_version.return_value = given_publish_layer_result
+
+                given_layer_resource = Mock()
+                patched_get_resource.return_value = given_layer_resource
+
+                result_version = self.layer_sync_flow._publish_new_layer_version()
+
+                patched_get_resource.assert_called_with(self.layer_identifier)
+                given_lambda_client.publish_layer_version.assert_called_with(
+                    LayerName=given_layer_name,
+                    Content={"ZipFile": "data"},
+                    CompatibleRuntimes=given_layer_resource.get("Properties", {}).get("CompatibleRuntimes", []),
+                )
+
+                self.assertEqual(result_version, given_publish_layer_result.get("Version"))
+
+    def test_delete_old_layer_version(self):
+        given_layer_name = Mock()
+        given_layer_version = Mock()
+
+        given_lambda_client = Mock()
+        self.layer_sync_flow._lambda_client = given_lambda_client
+
+        self.layer_sync_flow._layer_physical_name = given_layer_name
+        self.layer_sync_flow._old_layer_version = given_layer_version
+
+        self.layer_sync_flow._delete_old_layer_version()
+
+        given_lambda_client.delete_layer_version.assert_called_with(
+            LayerName=given_layer_name, VersionNumber=given_layer_version
+        )
+
+    @patch("samcli.lib.sync.flows.layer_sync_flow.os")
+    @patch("samcli.lib.sync.flows.layer_sync_flow.SamFunctionProvider")
+    @patch("samcli.lib.sync.flows.layer_sync_flow.FunctionLayerReferenceSync")
+    def test_gather_dependencies(self, patched_function_ref_sync, patched_function_provider, os_mock):
+        self.layer_sync_flow._new_layer_version = "given_new_layer_version_arn"
+
+        given_function_provider = Mock()
+        patched_function_provider.return_value = given_function_provider
+
+        mock_some_random_layer = PropertyMock()
+        mock_some_random_layer.full_path = "SomeRandomLayer"
+
+        mock_given_layer = PropertyMock()
+        mock_given_layer.full_path = self.layer_identifier
+
+        mock_some_nested_layer = PropertyMock()
+        mock_some_nested_layer.full_path = "NestedStack1/" + self.layer_identifier
+
+        mock_function_a = PropertyMock(layers=[mock_some_random_layer])
+        mock_function_a.full_path = "FunctionA"
+
+        mock_function_b = PropertyMock(layers=[mock_given_layer])
+        mock_function_b.full_path = "FunctionB"
+
+        mock_function_c = PropertyMock(layers=[mock_some_nested_layer])
+        mock_function_c.full_path = "NestedStack1/FunctionC"
+
+        given_layers = [
+            mock_function_a,
+            mock_function_b,
+            mock_function_c,
+        ]
+        given_function_provider.get_all.return_value = given_layers
+
+        self.layer_sync_flow._stacks = Mock()
+
+        given_layer_physical_name = Mock()
+        self.layer_sync_flow._layer_physical_name = given_layer_physical_name
+
+        self.layer_sync_flow._zip_file = Mock()
+
+        dependencies = self.layer_sync_flow.gather_dependencies()
+
+        patched_function_ref_sync.assert_called_once_with(
+            "FunctionB",
+            given_layer_physical_name,
+            self.layer_sync_flow._new_layer_version,
+            self.layer_sync_flow._build_context,
+            self.layer_sync_flow._deploy_context,
+            self.layer_sync_flow._physical_id_mapping,
+            self.layer_sync_flow._stacks,
+        )
+
+        self.assertEqual(len(dependencies), 1)
+
+    @patch("samcli.lib.sync.flows.layer_sync_flow.os")
+    @patch("samcli.lib.sync.flows.layer_sync_flow.SamFunctionProvider")
+    @patch("samcli.lib.sync.flows.layer_sync_flow.FunctionLayerReferenceSync")
+    def test_gather_dependencies_nested_stack(self, patched_function_ref_sync, patched_function_provider, os_mock):
+        self.layer_identifier = "NestedStack1/Layer1"
+        self.layer_sync_flow._layer_identifier = "NestedStack1/Layer1"
+        self.layer_sync_flow._new_layer_version = "given_new_layer_version_arn"
+
+        given_function_provider = Mock()
+        patched_function_provider.return_value = given_function_provider
+
+        mock_some_random_layer = PropertyMock()
+        mock_some_random_layer.full_path = "Layer1"
+
+        mock_given_layer = PropertyMock()
+        mock_given_layer.full_path = self.layer_identifier
+
+        mock_some_nested_layer = PropertyMock()
+        mock_some_nested_layer.full_path = "NestedStack1/Layer2"
+
+        mock_function_a = PropertyMock(layers=[mock_some_random_layer])
+        mock_function_a.full_path = "FunctionA"
+
+        mock_function_b = PropertyMock(layers=[mock_given_layer])
+        mock_function_b.full_path = "NestedStack1/FunctionB"
+
+        mock_function_c = PropertyMock(layers=[mock_some_nested_layer])
+        mock_function_c.full_path = "NestedStack1/FunctionC"
+
+        given_layers = [
+            mock_function_a,
+            mock_function_b,
+            mock_function_c,
+        ]
+        given_function_provider.get_all.return_value = given_layers
+
+        self.layer_sync_flow._stacks = Mock()
+
+        given_layer_physical_name = Mock()
+        self.layer_sync_flow._layer_physical_name = given_layer_physical_name
+
+        self.layer_sync_flow._zip_file = Mock()
+
+        dependencies = self.layer_sync_flow.gather_dependencies()
+
+        patched_function_ref_sync.assert_called_once_with(
+            "NestedStack1/FunctionB",
+            given_layer_physical_name,
+            self.layer_sync_flow._new_layer_version,
+            self.layer_sync_flow._build_context,
+            self.layer_sync_flow._deploy_context,
+            self.layer_sync_flow._physical_id_mapping,
+            self.layer_sync_flow._stacks,
+        )
+
+        self.assertEqual(len(dependencies), 1)
+
+    def test_get_latest_layer_version(self):
+        given_version = Mock()
+        given_layer_name = Mock()
+        given_lambda_client = Mock()
+        given_lambda_client.list_layer_versions.return_value = {"LayerVersions": [{"Version": given_version}]}
+        self.layer_sync_flow._lambda_client = given_lambda_client
+        self.layer_sync_flow._layer_physical_name = given_layer_name
+
+        latest_layer_version = self.layer_sync_flow._get_latest_layer_version()
+
+        given_lambda_client.list_layer_versions.assert_called_with(LayerName=given_layer_name)
+        self.assertEqual(latest_layer_version, given_version)
+
+    def test_get_latest_layer_version_error(self):
+        given_layer_name = Mock()
+        given_lambda_client = Mock()
+        given_lambda_client.list_layer_versions.return_value = {"LayerVersions": []}
+        self.layer_sync_flow._lambda_client = given_lambda_client
+        self.layer_sync_flow._layer_physical_name = given_layer_name
+
+        with self.assertRaises(NoLayerVersionsFoundError):
+            self.layer_sync_flow._get_latest_layer_version()
+
+    def test_equality_keys(self):
+        self.assertEqual(self.layer_sync_flow._equality_keys(), self.layer_identifier)
+
+    @patch("samcli.lib.sync.flows.layer_sync_flow.ResourceAPICall")
+    def test_get_resource_api_calls(self, resource_api_call_mock):
+        result = self.layer_sync_flow._get_resource_api_calls()
+        self.assertEqual(len(result), 1)
+        resource_api_call_mock.assert_called_once_with(self.layer_identifier, ["Build"])
+
+
+class TestFunctionLayerReferenceSync(TestCase):
+    def setUp(self):
+        self.function_identifier = "function"
+        self.layer_name = "Layer1"
+        self.old_layer_version = 1
+        self.new_layer_version = 2
+
+        self.function_layer_sync = FunctionLayerReferenceSync(
+            self.function_identifier, self.layer_name, self.new_layer_version, Mock(), Mock(), {}, []
+        )
+
+    def test_setup(self):
+        with patch.object(self.function_layer_sync, "_session") as patched_session:
+            with patch.object(SyncFlow, "set_up") as patched_super_setup:
+                self.function_layer_sync.set_up()
+
+                patched_super_setup.assert_called_once()
+                patched_session.assert_has_calls(
+                    [
+                        call.client("lambda"),
+                    ]
+                )
+
+    def test_sync(self):
+        given_lambda_client = Mock()
+        self.function_layer_sync._lambda_client = given_lambda_client
+
+        other_layer_version_arn = "SomeOtherLayerVersionArn"
+        given_function_result = {"Configuration": {"Layers": [{"Arn": "Layer1:1"}, {"Arn": other_layer_version_arn}]}}
+        given_lambda_client.get_function.return_value = given_function_result
+
+        with patch.object(self.function_layer_sync, "get_physical_id") as patched_get_physical_id:
+            with patch.object(self.function_layer_sync, "_locks") as patched_locks:
+                given_physical_id = Mock()
+                patched_get_physical_id.return_value = given_physical_id
+
+                self.function_layer_sync.sync()
+
+                patched_get_physical_id.assert_called_with(self.function_identifier)
+
+                patched_locks.get.assert_called_with(
+                    SyncFlow._get_lock_key(
+                        self.function_identifier, FunctionLayerReferenceSync.UPDATE_FUNCTION_CONFIGURATION
+                    )
+                )
+
+                given_lambda_client.get_function.assert_called_with(FunctionName=given_physical_id)
+
+                given_lambda_client.update_function_configuration.assert_called_with(
+                    FunctionName=given_physical_id, Layers=[other_layer_version_arn, "Layer1:2"]
+                )
+
+    def test_sync_with_existing_new_layer_version_arn(self):
+        given_lambda_client = Mock()
+        self.function_layer_sync._lambda_client = given_lambda_client
+
+        given_function_result = {"Configuration": {"Layers": [{"Arn": "Layer1:2"}]}}
+        given_lambda_client.get_function.return_value = given_function_result
+
+        with patch.object(self.function_layer_sync, "get_physical_id") as patched_get_physical_id:
+            with patch.object(self.function_layer_sync, "_locks") as patched_locks:
+                given_physical_id = Mock()
+                patched_get_physical_id.return_value = given_physical_id
+
+                self.function_layer_sync.sync()
+
+                patched_locks.get.assert_called_with(
+                    SyncFlow._get_lock_key(
+                        self.function_identifier, FunctionLayerReferenceSync.UPDATE_FUNCTION_CONFIGURATION
+                    )
+                )
+
+                patched_get_physical_id.assert_called_with(self.function_identifier)
+
+                given_lambda_client.get_function.assert_called_with(FunctionName=given_physical_id)
+
+                given_lambda_client.update_function_configuration.assert_not_called()
+
+    def test_equality_keys(self):
+        self.assertEqual(
+            self.function_layer_sync._equality_keys(),
+            (self.function_identifier, self.layer_name, self.new_layer_version),
+        )
diff --git a/tests/unit/lib/sync/flows/test_rest_api_sync_flow.py b/tests/unit/lib/sync/flows/test_rest_api_sync_flow.py
new file mode 100644
index 0000000000..ec46083907
--- /dev/null
+++ b/tests/unit/lib/sync/flows/test_rest_api_sync_flow.py
@@ -0,0 +1,84 @@
+from unittest import TestCase
+from unittest.mock import ANY, MagicMock, mock_open, patch
+
+from samcli.lib.sync.flows.rest_api_sync_flow import RestApiSyncFlow
+from samcli.lib.providers.exceptions import MissingLocalDefinition
+
+
+class TestRestApiSyncFlow(TestCase):
+    def create_sync_flow(self):
+        sync_flow = RestApiSyncFlow(
+            "Api1",
+            build_context=MagicMock(),
+            deploy_context=MagicMock(),
+            physical_id_mapping={},
+            stacks=[MagicMock()],
+        )
+        sync_flow._get_resource_api_calls = MagicMock()
+        return sync_flow
+
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_set_up(self, session_mock):
+        sync_flow = self.create_sync_flow()
+        sync_flow.set_up()
+        session_mock.return_value.client.assert_any_call("apigateway")
+
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_sync_direct(self, session_mock):
+        sync_flow = self.create_sync_flow()
+
+        sync_flow.get_physical_id = MagicMock()
+        sync_flow.get_physical_id.return_value = "PhysicalApi1"
+
+        sync_flow._get_definition_file = MagicMock()
+        sync_flow._get_definition_file.return_value = "file.yaml"
+
+        sync_flow.set_up()
+        with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
+            sync_flow.gather_resources()
+
+        sync_flow._api_client.put_rest_api.return_value = {"Response": "success"}
+
+        sync_flow.sync()
+
+        sync_flow._api_client.put_rest_api.assert_called_once_with(
+            restApiId="PhysicalApi1", mode="overwrite", body='{"key": "value"}'.encode("utf-8")
+        )
+
+    @patch("samcli.lib.sync.flows.generic_api_sync_flow.get_resource_by_id")
+    def test_get_definition_file(self, get_resource_mock):
+        sync_flow = self.create_sync_flow()
+
+        get_resource_mock.return_value = {"Properties": {"DefinitionUri": "test_uri"}}
+        result_uri = sync_flow._get_definition_file("test")
+
+        self.assertEqual(result_uri, "test_uri")
+
+        get_resource_mock.return_value = {"Properties": {}}
+        result_uri = sync_flow._get_definition_file("test")
+
+        self.assertEqual(result_uri, None)
+
+    def test_process_definition_file(self):
+        sync_flow = self.create_sync_flow()
+        sync_flow._definition_uri = "path"
+        with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
+            data = sync_flow._process_definition_file()
+            self.assertEqual(data, '{"key": "value"}'.encode("utf-8"))
+
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_failed_gather_resources(self, session_mock):
+        sync_flow = self.create_sync_flow()
+
+        sync_flow.get_physical_id = MagicMock()
+        sync_flow.get_physical_id.return_value = "PhysicalApi1"
+
+        sync_flow._get_definition_file = MagicMock()
+        sync_flow._get_definition_file.return_value = "file.yaml"
+
+        sync_flow.set_up()
+        sync_flow._definition_uri = None
+
+        with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
+            with self.assertRaises(MissingLocalDefinition):
+                sync_flow.sync()
diff --git a/tests/unit/lib/sync/flows/test_stepfunctions_sync_flow.py b/tests/unit/lib/sync/flows/test_stepfunctions_sync_flow.py
new file mode 100644
index 0000000000..d5652a2dc7
--- /dev/null
+++ b/tests/unit/lib/sync/flows/test_stepfunctions_sync_flow.py
@@ -0,0 +1,84 @@
+from samcli.lib.providers.exceptions import MissingLocalDefinition
+from unittest import TestCase
+from unittest.mock import ANY, MagicMock, mock_open, patch
+
+from samcli.lib.sync.flows.stepfunctions_sync_flow import StepFunctionsSyncFlow
+
+
+class TestStepFunctionsSyncFlow(TestCase):
+    def create_sync_flow(self):
+        sync_flow = StepFunctionsSyncFlow(
+            "StateMachine1",
+            build_context=MagicMock(),
+            deploy_context=MagicMock(),
+            physical_id_mapping={},
+            stacks=[MagicMock()],
+        )
+        sync_flow._get_resource_api_calls = MagicMock()
+        return sync_flow
+
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_set_up(self, session_mock):
+        sync_flow = self.create_sync_flow()
+        sync_flow.set_up()
+        session_mock.return_value.client.assert_any_call("stepfunctions")
+
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_sync_direct(self, session_mock):
+        sync_flow = self.create_sync_flow()
+
+        sync_flow.get_physical_id = MagicMock()
+        sync_flow.get_physical_id.return_value = "PhysicalId1"
+
+        sync_flow._get_definition_file = MagicMock()
+        sync_flow._get_definition_file.return_value = "file.yaml"
+
+        sync_flow.set_up()
+        with patch("builtins.open", mock_open(read_data='{"key": "value"}')) as mock_file:
+            sync_flow.gather_resources()
+
+        sync_flow._stepfunctions_client.update_state_machine.return_value = {"Response": "success"}
+
+        sync_flow.sync()
+
+        sync_flow._stepfunctions_client.update_state_machine.assert_called_once_with(
+            stateMachineArn="PhysicalId1", definition='{"key": "value"}'
+        )
+
+    @patch("samcli.lib.sync.flows.stepfunctions_sync_flow.get_resource_by_id")
+    def test_get_definition_file(self, get_resource_mock):
+        sync_flow = self.create_sync_flow()
+
+        get_resource_mock.return_value = {"Properties": {"DefinitionUri": "test_uri"}}
+        result_uri = sync_flow._get_definition_file("test")
+
+        self.assertEqual(result_uri, "test_uri")
+
+        get_resource_mock.return_value = {"Properties": {}}
+        result_uri = sync_flow._get_definition_file("test")
+
+        self.assertEqual(result_uri, None)
+
+    def test_process_definition_file(self):
+        sync_flow = self.create_sync_flow()
+        sync_flow._definition_uri = "path"
+        with patch("builtins.open", mock_open(read_data='{"key": "value"}')) as mock_file:
+            data = sync_flow._process_definition_file()
+            self.assertEqual(data, '{"key": "value"}')
+
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_failed_gather_resources(self, session_mock):
+        sync_flow = self.create_sync_flow()
+
+        sync_flow.get_physical_id = MagicMock()
+        sync_flow.get_physical_id.return_value = "PhysicalApi1"
+
+        sync_flow._get_definition_file = MagicMock()
+        sync_flow._get_definition_file.return_value = "file.yaml"
+
+        sync_flow.set_up()
+        sync_flow._definition_uri = None
+
+        with patch("builtins.open", mock_open(read_data='{"key": "value"}')) as mock_file:
+            with self.assertRaises(MissingLocalDefinition):
+                sync_flow.sync()
diff --git a/tests/unit/lib/sync/flows/test_zip_function_sync_flow.py b/tests/unit/lib/sync/flows/test_zip_function_sync_flow.py
new file mode 100644
index 0000000000..cf5364780c
--- /dev/null
+++ b/tests/unit/lib/sync/flows/test_zip_function_sync_flow.py
@@ -0,0 +1,178 @@
+import os
+import hashlib
+
+from samcli.lib.sync.sync_flow import SyncFlow
+from unittest import TestCase
+from unittest.mock import ANY, MagicMock, call, mock_open, patch
+
+from samcli.lib.sync.flows.zip_function_sync_flow import ZipFunctionSyncFlow
+
+
+class TestZipFunctionSyncFlow(TestCase):
+    def create_function_sync_flow(self):
+        sync_flow = ZipFunctionSyncFlow(
+            "Function1",
+            build_context=MagicMock(),
+            deploy_context=MagicMock(),
+            physical_id_mapping={},
+            stacks=[MagicMock()],
+        )
+        sync_flow._get_resource_api_calls = MagicMock()
+        return sync_flow
+
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_set_up(self, session_mock):
+        sync_flow = self.create_function_sync_flow()
+        sync_flow.set_up()
+        session_mock.return_value.client.assert_any_call("lambda")
+        session_mock.return_value.client.assert_any_call("s3")
+
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.hashlib.sha256")
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.uuid.uuid4")
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.file_checksum")
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.make_zip")
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.tempfile.gettempdir")
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.ApplicationBuilder")
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_gather_resources(
+        self, session_mock, builder_mock, gettempdir_mock, make_zip_mock, file_checksum_mock, uuid4_mock, sha256_mock
+    ):
+        get_mock = MagicMock()
+        get_mock.return_value = "ArtifactFolder1"
+        builder_mock.return_value.build.return_value.get = get_mock
+        uuid4_mock.return_value.hex = "uuid_value"
+        gettempdir_mock.return_value = "temp_folder"
+        make_zip_mock.return_value = "zip_file"
+        file_checksum_mock.return_value = "sha256_value"
+        sync_flow = self.create_function_sync_flow()
+
+        sync_flow._get_lock_chain = MagicMock()
+
+        sync_flow.set_up()
+        sync_flow.gather_resources()
+
+        get_mock.assert_called_once_with("Function1")
+        self.assertEqual(sync_flow._artifact_folder, "ArtifactFolder1")
+        make_zip_mock.assert_called_once_with("temp_folder" + os.sep + "data-uuid_value", "ArtifactFolder1")
+        file_checksum_mock.assert_called_once_with("zip_file", sha256_mock.return_value)
+        self.assertEqual("sha256_value", sync_flow._local_sha)
+        sync_flow._get_lock_chain.assert_called_once()
+        sync_flow._get_lock_chain.return_value.__enter__.assert_called_once()
+        sync_flow._get_lock_chain.return_value.__exit__.assert_called_once()
+
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.base64.b64decode")
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_compare_remote_true(self, session_mock, b64decode_mock):
+        b64decode_mock.return_value.hex.return_value = "sha256_value"
+        sync_flow = self.create_function_sync_flow()
+        sync_flow._local_sha = "sha256_value"
+
+        sync_flow.get_physical_id = MagicMock()
+        sync_flow.get_physical_id.return_value = "PhysicalFunction1"
+
+        sync_flow.set_up()
+
+        sync_flow._lambda_client.get_function.return_value = {"Configuration": {"CodeSha256": "sha256_value_b64"}}
+
+        result = sync_flow.compare_remote()
+
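+        # The remote CodeSha256 comes back base64-encoded; compare_remote is expected to
+        # decode it before matching it against the locally computed checksum.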
+        sync_flow._lambda_client.get_function.assert_called_once_with(FunctionName="PhysicalFunction1")
+        b64decode_mock.assert_called_once_with("sha256_value_b64")
+        self.assertTrue(result)
+
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.base64.b64decode")
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_compare_remote_false(self, session_mock, b64decode_mock):
+        b64decode_mock.return_value.hex.return_value = "sha256_value_2"
+        sync_flow = self.create_function_sync_flow()
+        sync_flow._local_sha = "sha256_value"
+
+        sync_flow.get_physical_id = MagicMock()
+        sync_flow.get_physical_id.return_value = "PhysicalFunction1"
+
+        sync_flow.set_up()
+
+        sync_flow._lambda_client.get_function.return_value = {"Configuration": {"CodeSha256": "sha256_value_b64"}}
+
+        result = sync_flow.compare_remote()
+
+        sync_flow._lambda_client.get_function.assert_called_once_with(FunctionName="PhysicalFunction1")
+        b64decode_mock.assert_called_once_with("sha256_value_b64")
+        self.assertFalse(result)
+
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.open", mock_open(read_data=b"zip_content"), create=True)
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.os.remove")
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.os.path.exists")
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.S3Uploader")
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.os.path.getsize")
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_sync_direct(self, session_mock, getsize_mock, uploader_mock, exists_mock, remove_mock):
+        getsize_mock.return_value = 49 * 1024 * 1024
+        exists_mock.return_value = True
+        sync_flow = self.create_function_sync_flow()
+        sync_flow._zip_file = "zip_file"
+
+        sync_flow.get_physical_id = MagicMock()
+        sync_flow.get_physical_id.return_value = "PhysicalFunction1"
+
+        sync_flow.set_up()
+
+        sync_flow.sync()
+
+        sync_flow._lambda_client.update_function_code.assert_called_once_with(
+            FunctionName="PhysicalFunction1", ZipFile=b"zip_content"
+        )
+        remove_mock.assert_called_once_with("zip_file")
+
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.open", mock_open(read_data=b"zip_content"), create=True)
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.os.remove")
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.os.path.exists")
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.S3Uploader")
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.os.path.getsize")
+    @patch("samcli.lib.sync.sync_flow.Session")
+    def test_sync_s3(self, session_mock, getsize_mock, uploader_mock, exists_mock, remove_mock):
+        getsize_mock.return_value = 51 * 1024 * 1024
+        exists_mock.return_value = True
+        uploader_mock.return_value.upload_with_dedup.return_value = "s3://bucket_name/bucket/key"
+        sync_flow = self.create_function_sync_flow()
+        sync_flow._zip_file = "zip_file"
+        sync_flow._deploy_context.s3_bucket = "bucket_name"
+
+        sync_flow.get_physical_id = MagicMock()
+        sync_flow.get_physical_id.return_value = "PhysicalFunction1"
+
+        sync_flow.set_up()
+
+        sync_flow.sync()
+
+        uploader_mock.return_value.upload_with_dedup.assert_called_once_with("zip_file")
+
+        sync_flow._lambda_client.update_function_code.assert_called_once_with(
+            FunctionName="PhysicalFunction1", S3Bucket="bucket_name", S3Key="bucket/key"
+        )
+        remove_mock.assert_called_once_with("zip_file")
+
+    @patch("samcli.lib.sync.flows.zip_function_sync_flow.ResourceAPICall")
+    def test_get_resource_api_calls(self, resource_api_call_mock):
+        build_context = MagicMock()
+        layer1 = MagicMock()
+        layer2 = MagicMock()
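+        # Only full_path is read from each mocked layer; the flow should request a "Build" lock per layer.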
layer1.full_path = "Layer1" + layer2.full_path = "Layer2" + function_mock = MagicMock() + function_mock.layers = [layer1, layer2] + build_context.function_provider.functions.get.return_value = function_mock + sync_flow = ZipFunctionSyncFlow( + "Function1", + build_context=build_context, + deploy_context=MagicMock(), + physical_id_mapping={}, + stacks=[MagicMock()], + ) + + result = sync_flow._get_resource_api_calls() + self.assertEqual(len(result), 2) + resource_api_call_mock.assert_any_call("Layer1", ["Build"]) + resource_api_call_mock.assert_any_call("Layer2", ["Build"]) diff --git a/tests/unit/lib/sync/test_continuous_sync_flow_executor.py b/tests/unit/lib/sync/test_continuous_sync_flow_executor.py new file mode 100644 index 0000000000..d9c526abfe --- /dev/null +++ b/tests/unit/lib/sync/test_continuous_sync_flow_executor.py @@ -0,0 +1,144 @@ +from multiprocessing.managers import ValueProxy +from queue import Queue +from samcli.lib.sync.continuous_sync_flow_executor import ContinuousSyncFlowExecutor, DelayedSyncFlowTask +from samcli.lib.sync.sync_flow import SyncFlow + +from botocore.exceptions import ClientError +from samcli.lib.sync.exceptions import ( + MissingPhysicalResourceError, + NoLayerVersionsFoundError, + SyncFlowException, +) +from unittest import TestCase +from unittest.mock import ANY, MagicMock, call, patch + +from samcli.lib.sync.sync_flow_executor import ( + SyncFlowExecutor, + SyncFlowResult, + SyncFlowTask, + default_exception_handler, + HELP_TEXT_FOR_SYNC_INFRA, +) + + +class TestContinuousSyncFlowExecutor(TestCase): + def setUp(self): + self.thread_pool_executor_patch = patch("samcli.lib.sync.sync_flow_executor.ThreadPoolExecutor") + self.thread_pool_executor_mock = self.thread_pool_executor_patch.start() + self.thread_pool_executor = self.thread_pool_executor_mock.return_value + self.thread_pool_executor.__enter__.return_value = self.thread_pool_executor + self.lock_distributor_patch = patch("samcli.lib.sync.sync_flow_executor.LockDistributor") + self.lock_distributor_mock = self.lock_distributor_patch.start() + self.lock_distributor = self.lock_distributor_mock.return_value + self.executor = ContinuousSyncFlowExecutor() + + def tearDown(self) -> None: + self.thread_pool_executor_patch.stop() + self.lock_distributor_patch.stop() + + @patch("samcli.lib.sync.continuous_sync_flow_executor.time.time") + @patch("samcli.lib.sync.continuous_sync_flow_executor.DelayedSyncFlowTask") + def test_add_delayed_sync_flow(self, task_mock, time_mock): + add_sync_flow_task_mock = MagicMock() + task = MagicMock() + task_mock.return_value = task + time_mock.return_value = 1000 + self.executor._add_sync_flow_task = add_sync_flow_task_mock + sync_flow = MagicMock() + + self.executor.add_delayed_sync_flow(sync_flow, False, 15) + + task_mock.assert_called_once_with(sync_flow, False, 1000, 15) + add_sync_flow_task_mock.assert_called_once_with(task) + + def test_add_sync_flow_task(self): + sync_flow = MagicMock() + task = DelayedSyncFlowTask(sync_flow, False, 1000, 15) + + self.executor._add_sync_flow_task(task) + + sync_flow.set_locks_with_distributor.assert_called_once_with(self.executor._lock_distributor) + + queue_task = self.executor._flow_queue.get() + self.assertEqual(sync_flow, queue_task.sync_flow) + + def test_stop_without_manager(self): + self.executor.stop() + self.assertTrue(self.executor._stop_flag) + + def test_should_stop_without_manager(self): + self.executor._stop_flag = True + self.assertTrue(self.executor.should_stop()) + + 
@patch("samcli.lib.sync.continuous_sync_flow_executor.time.time") + @patch("samcli.lib.sync.sync_flow_executor.time.sleep") + def test_execute_high_level_logic(self, sleep_mock, time_mock): + exception_handler_mock = MagicMock() + time_mock.return_value = 1001 + + flow1 = MagicMock() + flow2 = MagicMock() + flow3 = MagicMock() + + task1 = DelayedSyncFlowTask(flow1, False, 1000, 0) + task2 = DelayedSyncFlowTask(flow2, False, 1000, 0) + task3 = DelayedSyncFlowTask(flow3, False, 1000, 0) + + result1 = SyncFlowResult(flow1, [flow3]) + + future1 = MagicMock() + future2 = MagicMock() + future3 = MagicMock() + + exception1 = MagicMock(spec=Exception) + sync_flow_exception = MagicMock(spec=SyncFlowException) + sync_flow_exception.sync_flow = flow2 + sync_flow_exception.exception = exception1 + + future1.done.side_effect = [False, False, True] + future1.exception.return_value = None + future1.result.return_value = result1 + + future2.done.side_effect = [False, False, False, True] + future2.exception.return_value = sync_flow_exception + + future3.done.side_effect = [False, False, False, True] + future3.exception.return_value = None + + self.thread_pool_executor.submit = MagicMock() + self.thread_pool_executor.submit.side_effect = [future1, future2, future3] + + self.executor._flow_queue.put(task1) + self.executor._flow_queue.put(task2) + + self.executor.add_sync_flow = MagicMock() + self.executor.add_sync_flow.side_effect = lambda x: self.executor._flow_queue.put(task3) + + self.executor.should_stop = MagicMock() + self.executor.should_stop.side_effect = [ + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + ] + + self.executor.execute(exception_handler=exception_handler_mock) + + self.thread_pool_executor.submit.assert_has_calls( + [ + call(SyncFlowExecutor._sync_flow_execute_wrapper, flow1), + call(SyncFlowExecutor._sync_flow_execute_wrapper, flow2), + call(SyncFlowExecutor._sync_flow_execute_wrapper, flow3), + ] + ) + self.executor.add_sync_flow.assert_called_once_with(flow3) + + exception_handler_mock.assert_called_once_with(sync_flow_exception) + self.assertEqual(len(sleep_mock.mock_calls), 10) diff --git a/tests/unit/lib/sync/test_exceptions.py b/tests/unit/lib/sync/test_exceptions.py new file mode 100644 index 0000000000..363e5d9b25 --- /dev/null +++ b/tests/unit/lib/sync/test_exceptions.py @@ -0,0 +1,34 @@ +from unittest import TestCase +from unittest.mock import MagicMock +from samcli.lib.sync.exceptions import ( + MissingPhysicalResourceError, + NoLayerVersionsFoundError, + SyncFlowException, +) + + +class TestSyncFlowException(TestCase): + def test_exception(self): + sync_flow_mock = MagicMock() + exception_mock = MagicMock() + exception = SyncFlowException(sync_flow_mock, exception_mock) + self.assertEqual(exception.sync_flow, sync_flow_mock) + self.assertEqual(exception.exception, exception_mock) + + +class TestMissingPhysicalResourceError(TestCase): + def test_exception(self): + exception = MissingPhysicalResourceError("A") + self.assertEqual(exception.resource_identifier, "A") + + def test_exception_with_mapping(self): + physical_mapping = MagicMock() + exception = MissingPhysicalResourceError("A", physical_mapping) + self.assertEqual(exception.resource_identifier, "A") + self.assertEqual(exception.physical_resource_mapping, physical_mapping) + + +class TestNoLayerVersionsFoundError(TestCase): + def test_exception(self): + exception = NoLayerVersionsFoundError("layer_name_arn") + self.assertEqual(exception.layer_name_arn, 
"layer_name_arn") diff --git a/tests/unit/lib/sync/test_sync_flow.py b/tests/unit/lib/sync/test_sync_flow.py new file mode 100644 index 0000000000..caca763eb8 --- /dev/null +++ b/tests/unit/lib/sync/test_sync_flow.py @@ -0,0 +1,119 @@ +from samcli.lib.providers.provider import ResourceIdentifier +from unittest import TestCase +from unittest.mock import MagicMock, call, patch + +from samcli.lib.sync.sync_flow import SyncFlow, ResourceAPICall +from samcli.lib.utils.lock_distributor import LockChain + + +class TestSyncFlow(TestCase): + def create_sync_flow(self): + sync_flow = SyncFlow( + build_context=MagicMock(), + deploy_context=MagicMock(), + physical_id_mapping={}, + log_name="log-name", + stacks=[MagicMock()], + ) + sync_flow.gather_resources = MagicMock() + sync_flow.compare_remote = MagicMock() + sync_flow.sync = MagicMock() + sync_flow.gather_dependencies = MagicMock() + sync_flow._get_resource_api_calls = MagicMock() + return sync_flow + + @patch("samcli.lib.sync.sync_flow.Session") + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_execute_all_steps(self, session_mock): + sync_flow = self.create_sync_flow() + sync_flow.compare_remote.return_value = False + sync_flow.gather_dependencies.return_value = ["A"] + result = sync_flow.execute() + + sync_flow.gather_resources.assert_called_once() + sync_flow.compare_remote.assert_called_once() + sync_flow.sync.assert_called_once() + sync_flow.gather_dependencies.assert_called_once() + self.assertEqual(result, ["A"]) + + @patch("samcli.lib.sync.sync_flow.Session") + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_execute_skip_after_compare(self, session_mock): + sync_flow = self.create_sync_flow() + sync_flow.compare_remote.return_value = True + sync_flow.gather_dependencies.return_value = ["A"] + result = sync_flow.execute() + + sync_flow.gather_resources.assert_called_once() + sync_flow.compare_remote.assert_called_once() + sync_flow.sync.assert_not_called() + sync_flow.gather_dependencies.assert_not_called() + self.assertEqual(result, []) + + @patch("samcli.lib.sync.sync_flow.Session") + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_set_up(self, session_mock): + sync_flow = self.create_sync_flow() + sync_flow.set_up() + session_mock.assert_called_once() + self.assertIsNotNone(sync_flow._session) + + @patch("samcli.lib.sync.sync_flow.Session") + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_set_locks_with_distributor(self, session_mock): + sync_flow = self.create_sync_flow() + distributor = MagicMock() + locks = {"A": 1, "B": 2} + distributor.get_locks.return_value = locks + sync_flow.set_locks_with_distributor(distributor) + self.assertEqual(locks, sync_flow._locks) + + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_get_lock_keys(self): + sync_flow = self.create_sync_flow() + sync_flow._get_resource_api_calls.return_value = [ResourceAPICall("A", "1"), ResourceAPICall("B", "2")] + result = sync_flow.get_lock_keys() + self.assertEqual(result, ["A_1", "B_2"]) + + @patch("samcli.lib.sync.sync_flow.LockChain") + @patch("samcli.lib.sync.sync_flow.Session") + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_get_lock_chain(self, session_mock, lock_chain_mock): + sync_flow = self.create_sync_flow() + locks = {"A": 1, "B": 2} + sync_flow._locks = locks + result = sync_flow._get_lock_chain() + lock_chain_mock.assert_called_once_with(locks) + + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_log_prefix(self): + sync_flow = 
self.create_sync_flow() + sync_flow._log_name = "A" + self.assertEqual(sync_flow.log_prefix, "SyncFlow [A]: ") + + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_eq_true(self): + sync_flow_1 = self.create_sync_flow() + sync_flow_1._equality_keys = MagicMock() + sync_flow_1._equality_keys.return_value = "A" + sync_flow_2 = self.create_sync_flow() + sync_flow_2._equality_keys = MagicMock() + sync_flow_2._equality_keys.return_value = "A" + self.assertTrue(sync_flow_1 == sync_flow_2) + + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_eq_false(self): + sync_flow_1 = self.create_sync_flow() + sync_flow_1._equality_keys = MagicMock() + sync_flow_1._equality_keys.return_value = "A" + sync_flow_2 = self.create_sync_flow() + sync_flow_2._equality_keys = MagicMock() + sync_flow_2._equality_keys.return_value = "B" + self.assertFalse(sync_flow_1 == sync_flow_2) + + @patch.multiple(SyncFlow, __abstractmethods__=set()) + def test_hash(self): + sync_flow = self.create_sync_flow() + sync_flow._equality_keys = MagicMock() + sync_flow._equality_keys.return_value = "A" + self.assertEqual(hash(sync_flow), hash((type(sync_flow), "A"))) diff --git a/tests/unit/lib/sync/test_sync_flow_executor.py b/tests/unit/lib/sync/test_sync_flow_executor.py new file mode 100644 index 0000000000..2b40595e42 --- /dev/null +++ b/tests/unit/lib/sync/test_sync_flow_executor.py @@ -0,0 +1,204 @@ +from multiprocessing.managers import ValueProxy +from queue import Queue +from samcli.lib.sync.sync_flow import SyncFlow + +from botocore.exceptions import ClientError +from samcli.lib.sync.exceptions import ( + MissingPhysicalResourceError, + NoLayerVersionsFoundError, + SyncFlowException, +) +from unittest import TestCase +from unittest.mock import ANY, MagicMock, call, patch + +from samcli.lib.sync.sync_flow_executor import ( + SyncFlowExecutor, + SyncFlowResult, + SyncFlowTask, + default_exception_handler, + HELP_TEXT_FOR_SYNC_INFRA, +) + + +class TestSyncFlowExecutor(TestCase): + def setUp(self): + self.thread_pool_executor_patch = patch("samcli.lib.sync.sync_flow_executor.ThreadPoolExecutor") + self.thread_pool_executor_mock = self.thread_pool_executor_patch.start() + self.thread_pool_executor = self.thread_pool_executor_mock.return_value + self.thread_pool_executor.__enter__.return_value = self.thread_pool_executor + self.lock_distributor_patch = patch("samcli.lib.sync.sync_flow_executor.LockDistributor") + self.lock_distributor_mock = self.lock_distributor_patch.start() + self.lock_distributor = self.lock_distributor_mock.return_value + self.executor = SyncFlowExecutor() + + def tearDown(self) -> None: + self.thread_pool_executor_patch.stop() + self.lock_distributor_patch.stop() + + @patch("samcli.lib.sync.sync_flow_executor.LOG") + def test_default_exception_handler_missing_physical_resource_error(self, log_mock): + sync_flow_exception = MagicMock(spec=SyncFlowException) + exception = MagicMock(spec=MissingPhysicalResourceError) + exception.resource_identifier = "Resource1" + sync_flow_exception.exception = exception + + default_exception_handler(sync_flow_exception) + log_mock.error.assert_called_once_with( + "Cannot find resource %s in remote.%s", "Resource1", HELP_TEXT_FOR_SYNC_INFRA + ) + + @patch("samcli.lib.sync.sync_flow_executor.LOG") + def test_default_exception_handler_client_error_valid(self, log_mock): + sync_flow_exception = MagicMock(spec=SyncFlowException) + exception = MagicMock(spec=ClientError) + exception.resource_identifier = "Resource1" + exception.response = {"Error": 
{"Code": "ResourceNotFoundException", "Message": "MessageContent"}} + sync_flow_exception.exception = exception + + default_exception_handler(sync_flow_exception) + log_mock.error.assert_has_calls( + [call("Cannot find resource in remote.%s", HELP_TEXT_FOR_SYNC_INFRA), call("MessageContent")] + ) + + @patch("samcli.lib.sync.sync_flow_executor.LOG") + def test_default_exception_no_layer_versions_found(self, log_mock): + sync_flow_exception = MagicMock(spec=SyncFlowException) + exception = MagicMock(spec=NoLayerVersionsFoundError) + exception.layer_name_arn = "layer_name" + sync_flow_exception.exception = exception + + default_exception_handler(sync_flow_exception) + log_mock.error.assert_has_calls( + [ + call( + "Cannot find any versions for layer %s.%s", + exception.layer_name_arn, + HELP_TEXT_FOR_SYNC_INFRA, + ) + ] + ) + + @patch("samcli.lib.sync.sync_flow_executor.LOG") + def test_default_exception_handler_client_error_invalid_code(self, log_mock): + sync_flow_exception = MagicMock(spec=SyncFlowException) + exception = ClientError({"Error": {"Code": "RandomException", "Message": "MessageContent"}}, "") + exception.resource_identifier = "Resource1" + sync_flow_exception.exception = exception + with self.assertRaises(ClientError): + default_exception_handler(sync_flow_exception) + + @patch("samcli.lib.sync.sync_flow_executor.LOG") + def test_default_exception_handler_client_error_invalid_exception(self, log_mock): + sync_flow_exception = MagicMock(spec=SyncFlowException) + + class RandomException(Exception): + pass + + exception = RandomException() + exception.resource_identifier = "Resource1" + sync_flow_exception.exception = exception + with self.assertRaises(RandomException): + default_exception_handler(sync_flow_exception) + + @patch("samcli.lib.sync.sync_flow_executor.time.time") + @patch("samcli.lib.sync.sync_flow_executor.SyncFlowTask") + def test_add_sync_flow(self, task_mock, time_mock): + add_sync_flow_task_mock = MagicMock() + task = MagicMock() + task_mock.return_value = task + time_mock.return_value = 1000 + self.executor._add_sync_flow_task = add_sync_flow_task_mock + sync_flow = MagicMock() + + self.executor.add_sync_flow(sync_flow, False) + + task_mock.assert_called_once_with(sync_flow, False) + add_sync_flow_task_mock.assert_called_once_with(task) + + def test_add_sync_flow_task(self): + sync_flow = MagicMock() + task = SyncFlowTask(sync_flow, False) + + self.executor._add_sync_flow_task(task) + + sync_flow.set_locks_with_distributor.assert_called_once_with(self.executor._lock_distributor) + + queue_task = self.executor._flow_queue.get() + self.assertEqual(sync_flow, queue_task.sync_flow) + + def test_add_sync_flow_task_dedup(self): + sync_flow = MagicMock() + + task1 = SyncFlowTask(sync_flow, True) + task2 = SyncFlowTask(sync_flow, True) + + self.executor._add_sync_flow_task(task1) + self.executor._add_sync_flow_task(task2) + + sync_flow.set_locks_with_distributor.assert_called_once_with(self.executor._lock_distributor) + + queue_task = self.executor._flow_queue.get() + self.assertEqual(sync_flow, queue_task.sync_flow) + self.assertTrue(self.executor._flow_queue.empty()) + + def test_is_running_without_manager(self): + self.executor._running_flag = True + self.assertTrue(self.executor.is_running()) + + @patch("samcli.lib.sync.sync_flow_executor.time.time") + @patch("samcli.lib.sync.sync_flow_executor.time.sleep") + def test_execute_high_level_logic(self, sleep_mock, time_mock): + exception_handler_mock = MagicMock() + time_mock.return_value = 1001 + + flow1 = MagicMock() 
+ flow2 = MagicMock() + flow3 = MagicMock() + + task1 = SyncFlowTask(flow1, False) + task2 = SyncFlowTask(flow2, False) + task3 = SyncFlowTask(flow3, False) + + result1 = SyncFlowResult(flow1, [flow3]) + + future1 = MagicMock() + future2 = MagicMock() + future3 = MagicMock() + + exception1 = MagicMock(spec=Exception) + sync_flow_exception = MagicMock(spec=SyncFlowException) + sync_flow_exception.sync_flow = flow2 + sync_flow_exception.exception = exception1 + + future1.done.side_effect = [False, False, True] + future1.exception.return_value = None + future1.result.return_value = result1 + + future2.done.side_effect = [False, False, False, True] + future2.exception.return_value = sync_flow_exception + + future3.done.side_effect = [False, False, False, True] + future3.exception.return_value = None + + self.thread_pool_executor.submit = MagicMock() + self.thread_pool_executor.submit.side_effect = [future1, future2, future3] + + self.executor._flow_queue.put(task1) + self.executor._flow_queue.put(task2) + + self.executor.add_sync_flow = MagicMock() + self.executor.add_sync_flow.side_effect = lambda x: self.executor._flow_queue.put(task3) + + self.executor.execute(exception_handler=exception_handler_mock) + + self.thread_pool_executor.submit.assert_has_calls( + [ + call(SyncFlowExecutor._sync_flow_execute_wrapper, flow1), + call(SyncFlowExecutor._sync_flow_execute_wrapper, flow2), + call(SyncFlowExecutor._sync_flow_execute_wrapper, flow3), + ] + ) + self.executor.add_sync_flow.assert_called_once_with(flow3) + + exception_handler_mock.assert_called_once_with(sync_flow_exception) + self.assertEqual(len(sleep_mock.mock_calls), 6) diff --git a/tests/unit/lib/sync/test_sync_flow_factory.py b/tests/unit/lib/sync/test_sync_flow_factory.py new file mode 100644 index 0000000000..db6254ab2e --- /dev/null +++ b/tests/unit/lib/sync/test_sync_flow_factory.py @@ -0,0 +1,93 @@ +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from samcli.lib.sync.sync_flow_factory import SyncFlowFactory + + +class TestSyncFlowFactory(TestCase): + def create_factory(self): + factory = SyncFlowFactory( + build_context=MagicMock(), deploy_context=MagicMock(), stacks=[MagicMock(), MagicMock()] + ) + return factory + + @patch("samcli.lib.sync.sync_flow_factory.get_physical_id_mapping") + @patch("samcli.lib.sync.sync_flow_factory.get_boto_resource_provider_with_config") + def test_load_physical_id_mapping(self, get_boto_resource_provider_mock, get_physical_id_mapping_mock): + get_physical_id_mapping_mock.return_value = {"Resource1": "PhysicalResource1", "Resource2": "PhysicalResource2"} + + factory = self.create_factory() + factory.load_physical_id_mapping() + + self.assertEqual(len(factory._physical_id_mapping), 2) + self.assertEqual( + factory._physical_id_mapping, {"Resource1": "PhysicalResource1", "Resource2": "PhysicalResource2"} + ) + + @patch("samcli.lib.sync.sync_flow_factory.ImageFunctionSyncFlow") + @patch("samcli.lib.sync.sync_flow_factory.ZipFunctionSyncFlow") + def test_create_lambda_flow_zip(self, zip_function_mock, image_function_mock): + factory = self.create_factory() + resource = {"Properties": {"PackageType": "Zip"}} + result = factory._create_lambda_flow("Function1", resource) + self.assertEqual(result, zip_function_mock.return_value) + + @patch("samcli.lib.sync.sync_flow_factory.ImageFunctionSyncFlow") + @patch("samcli.lib.sync.sync_flow_factory.ZipFunctionSyncFlow") + def test_create_lambda_flow_image(self, zip_function_mock, image_function_mock): + factory = self.create_factory() + 
resource = {"Properties": {"PackageType": "Image"}} + result = factory._create_lambda_flow("Function1", resource) + self.assertEqual(result, image_function_mock.return_value) + + @patch("samcli.lib.sync.sync_flow_factory.LayerSyncFlow") + def test_create_layer_flow(self, layer_sync_mock): + factory = self.create_factory() + result = factory._create_layer_flow("Layer1", {}) + self.assertEqual(result, layer_sync_mock.return_value) + + @patch("samcli.lib.sync.sync_flow_factory.ImageFunctionSyncFlow") + @patch("samcli.lib.sync.sync_flow_factory.ZipFunctionSyncFlow") + def test_create_lambda_flow_other(self, zip_function_mock, image_function_mock): + factory = self.create_factory() + resource = {"Properties": {"PackageType": "Other"}} + result = factory._create_lambda_flow("Function1", resource) + self.assertEqual(result, None) + + @patch("samcli.lib.sync.sync_flow_factory.RestApiSyncFlow") + def test_create_rest_api_flow(self, rest_api_sync_mock): + factory = self.create_factory() + result = factory._create_rest_api_flow("API1", {}) + self.assertEqual(result, rest_api_sync_mock.return_value) + + @patch("samcli.lib.sync.sync_flow_factory.HttpApiSyncFlow") + def test_create_api_flow(self, http_api_sync_mock): + factory = self.create_factory() + result = factory._create_api_flow("API1", {}) + self.assertEqual(result, http_api_sync_mock.return_value) + + @patch("samcli.lib.sync.sync_flow_factory.StepFunctionsSyncFlow") + def test_create_stepfunctions_flow(self, stepfunctions_sync_mock): + factory = self.create_factory() + result = factory._create_stepfunctions_flow("StateMachine1", {}) + self.assertEqual(result, stepfunctions_sync_mock.return_value) + + @patch("samcli.lib.sync.sync_flow_factory.get_resource_by_id") + def test_create_sync_flow(self, get_resource_by_id_mock): + factory = self.create_factory() + + sync_flow = MagicMock() + resource_identifier = MagicMock() + get_resource_by_id = MagicMock() + get_resource_by_id_mock.return_value = get_resource_by_id + generator_mock = MagicMock() + generator_mock.return_value = sync_flow + + get_generator_function_mock = MagicMock() + get_generator_function_mock.return_value = generator_mock + factory._get_generator_function = get_generator_function_mock + + result = factory.create_sync_flow(resource_identifier) + + self.assertEqual(result, sync_flow) + generator_mock.assert_called_once_with(factory, resource_identifier, get_resource_by_id) diff --git a/tests/unit/lib/sync/test_watch_manager.py b/tests/unit/lib/sync/test_watch_manager.py new file mode 100644 index 0000000000..3bac6f9d75 --- /dev/null +++ b/tests/unit/lib/sync/test_watch_manager.py @@ -0,0 +1,237 @@ +from unittest.case import TestCase +from unittest.mock import MagicMock, patch, ANY +from samcli.lib.sync.watch_manager import WatchManager +from samcli.lib.providers.exceptions import MissingCodeUri, MissingLocalDefinition +from samcli.lib.sync.exceptions import MissingPhysicalResourceError, SyncFlowException + + +class TestWatchManager(TestCase): + def setUp(self) -> None: + self.template = "template.yaml" + self.path_observer_patch = patch("samcli.lib.sync.watch_manager.HandlerObserver") + self.path_observer_mock = self.path_observer_patch.start() + self.path_observer = self.path_observer_mock.return_value + self.executor_patch = patch("samcli.lib.sync.watch_manager.ContinuousSyncFlowExecutor") + self.executor_mock = self.executor_patch.start() + self.executor = self.executor_mock.return_value + self.colored_patch = patch("samcli.lib.sync.watch_manager.Colored") + self.colored_mock = 
self.colored_patch.start() + self.colored = self.colored_mock.return_value + self.build_context = MagicMock() + self.package_context = MagicMock() + self.deploy_context = MagicMock() + self.watch_manager = WatchManager(self.template, self.build_context, self.package_context, self.deploy_context) + + def tearDown(self) -> None: + self.path_observer_patch.stop() + self.executor_patch.stop() + self.colored_patch.stop() + + def test_queue_infra_sync(self): + self.assertFalse(self.watch_manager._waiting_infra_sync) + self.watch_manager.queue_infra_sync() + self.assertTrue(self.watch_manager._waiting_infra_sync) + + @patch("samcli.lib.sync.watch_manager.SamLocalStackProvider.get_stacks") + @patch("samcli.lib.sync.watch_manager.SyncFlowFactory") + @patch("samcli.lib.sync.watch_manager.CodeTriggerFactory") + def test_update_stacks( + self, trigger_factory_mock: MagicMock, sync_flow_factory_mock: MagicMock, get_stacks_mock: MagicMock + ): + stacks = [MagicMock()] + get_stacks_mock.return_value = [ + stacks, + ] + self.watch_manager._update_stacks() + get_stacks_mock.assert_called_once_with(self.template) + sync_flow_factory_mock.assert_called_once_with(self.build_context, self.deploy_context, stacks) + sync_flow_factory_mock.return_value.load_physical_id_mapping.assert_called_once_with() + trigger_factory_mock.assert_called_once_with(stacks) + + @patch("samcli.lib.sync.watch_manager.get_all_resource_ids") + def test_add_code_triggers(self, get_all_resource_ids_mock): + resource_ids = [MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock()] + get_all_resource_ids_mock.return_value = resource_ids + + trigger_1 = MagicMock() + trigger_2 = MagicMock() + + trigger_factory = MagicMock() + trigger_factory.create_trigger.side_effect = [ + trigger_1, + None, + MissingCodeUri(), + trigger_2, + MissingLocalDefinition(MagicMock(), MagicMock()), + ] + self.watch_manager._stacks = [MagicMock()] + self.watch_manager._trigger_factory = trigger_factory + + on_code_change_wrapper_mock = MagicMock() + self.watch_manager._on_code_change_wrapper = on_code_change_wrapper_mock + + self.watch_manager._add_code_triggers() + + trigger_factory.create_trigger.assert_any_call(resource_ids[0], on_code_change_wrapper_mock.return_value) + trigger_factory.create_trigger.assert_any_call(resource_ids[1], on_code_change_wrapper_mock.return_value) + + on_code_change_wrapper_mock.assert_any_call(resource_ids[0]) + on_code_change_wrapper_mock.assert_any_call(resource_ids[1]) + + self.path_observer.schedule_handlers.assert_any_call(trigger_1.get_path_handlers.return_value) + self.path_observer.schedule_handlers.assert_any_call(trigger_2.get_path_handlers.return_value) + self.assertEqual(self.path_observer.schedule_handlers.call_count, 2) + + @patch("samcli.lib.sync.watch_manager.TemplateTrigger") + def test_add_template_trigger(self, template_trigger_mock): + trigger = template_trigger_mock.return_value + + self.watch_manager._add_template_trigger() + + template_trigger_mock.assert_called_once_with(self.template, ANY) + self.path_observer.schedule_handlers.assert_any_call(trigger.get_path_handlers.return_value) + + def test_execute_infra_sync(self): + self.watch_manager._execute_infra_context() + self.build_context.set_up.assert_called_once_with() + self.build_context.run.assert_called_once_with() + self.package_context.run.assert_called_once_with() + self.deploy_context.run.assert_called_once_with() + + @patch("samcli.lib.sync.watch_manager.threading.Thread") + def test_start_code_sync(self, thread_mock): + 
self.watch_manager._start_code_sync() + thread = thread_mock.return_value + + self.assertEqual(self.watch_manager._executor_thread, thread) + thread.start.assert_called_once_with() + + def test_stop_code_sync(self): + thread = MagicMock() + thread.is_alive.return_value = True + self.watch_manager._executor_thread = thread + + self.watch_manager._stop_code_sync() + + self.executor.stop.assert_called_once_with() + thread.join.assert_called_once_with() + + def test_start(self): + queue_infra_sync_mock = MagicMock() + _start_mock = MagicMock() + stop_code_sync_mock = MagicMock() + + self.watch_manager.queue_infra_sync = queue_infra_sync_mock + self.watch_manager._start = _start_mock + self.watch_manager._stop_code_sync = stop_code_sync_mock + + _start_mock.side_effect = KeyboardInterrupt() + + self.watch_manager.start() + + self.path_observer.stop.assert_called_once_with() + stop_code_sync_mock.assert_called_once_with() + + @patch("samcli.lib.sync.watch_manager.time.sleep") + def test__start(self, sleep_mock): + sleep_mock.side_effect = KeyboardInterrupt() + + stop_code_sync_mock = MagicMock() + execute_infra_sync_mock = MagicMock() + + update_stacks_mock = MagicMock() + add_template_trigger_mock = MagicMock() + add_code_trigger_mock = MagicMock() + start_code_sync_mock = MagicMock() + + self.watch_manager._stop_code_sync = stop_code_sync_mock + self.watch_manager._execute_infra_context = execute_infra_sync_mock + self.watch_manager._update_stacks = update_stacks_mock + self.watch_manager._add_template_trigger = add_template_trigger_mock + self.watch_manager._add_code_triggers = add_code_trigger_mock + self.watch_manager._start_code_sync = start_code_sync_mock + + self.watch_manager._waiting_infra_sync = True + with self.assertRaises(KeyboardInterrupt): + self.watch_manager._start() + + self.path_observer.start.assert_called_once_with() + self.assertFalse(self.watch_manager._waiting_infra_sync) + + stop_code_sync_mock.assert_called_once_with() + execute_infra_sync_mock.assert_called_once_with() + update_stacks_mock.assert_called_once_with() + add_template_trigger_mock.assert_called_once_with() + add_code_trigger_mock.assert_called_once_with() + start_code_sync_mock.assert_called_once_with() + + self.path_observer.unschedule_all.assert_called_once_with() + + self.path_observer.start.assert_called_once_with() + + @patch("samcli.lib.sync.watch_manager.time.sleep") + def test__start_infra_exception(self, sleep_mock): + sleep_mock.side_effect = KeyboardInterrupt() + + stop_code_sync_mock = MagicMock() + execute_infra_sync_mock = MagicMock() + execute_infra_sync_mock.side_effect = Exception() + + update_stacks_mock = MagicMock() + add_template_trigger_mock = MagicMock() + add_code_trigger_mock = MagicMock() + start_code_sync_mock = MagicMock() + + self.watch_manager._stop_code_sync = stop_code_sync_mock + self.watch_manager._execute_infra_context = execute_infra_sync_mock + self.watch_manager._update_stacks = update_stacks_mock + self.watch_manager._add_template_trigger = add_template_trigger_mock + self.watch_manager._add_code_triggers = add_code_trigger_mock + self.watch_manager._start_code_sync = start_code_sync_mock + + self.watch_manager._waiting_infra_sync = True + with self.assertRaises(KeyboardInterrupt): + self.watch_manager._start() + + self.path_observer.start.assert_called_once_with() + self.assertFalse(self.watch_manager._waiting_infra_sync) + + stop_code_sync_mock.assert_called_once_with() + execute_infra_sync_mock.assert_called_once_with() + 
add_template_trigger_mock.assert_called_once_with() + + update_stacks_mock.assert_not_called() + add_code_trigger_mock.assert_not_called() + start_code_sync_mock.assert_not_called() + + self.path_observer.unschedule_all.assert_called_once_with() + + self.path_observer.start.assert_called_once_with() + + def test_on_code_change_wrapper(self): + flow1 = MagicMock() + resource_id_mock = MagicMock() + factory_mock = MagicMock() + + self.watch_manager._sync_flow_factory = factory_mock + factory_mock.create_sync_flow.return_value = flow1 + + callback = self.watch_manager._on_code_change_wrapper(resource_id_mock) + + callback() + + self.executor.add_delayed_sync_flow.assert_any_call(flow1, dedup=True, wait_time=ANY) + + def test_watch_sync_flow_exception_handler_missing_physical(self): + sync_flow = MagicMock() + sync_flow_exception = MagicMock(spec=SyncFlowException) + exception = MagicMock(spec=MissingPhysicalResourceError) + sync_flow_exception.exception = exception + sync_flow_exception.sync_flow = sync_flow + + queue_infra_sync_mock = MagicMock() + self.watch_manager.queue_infra_sync = queue_infra_sync_mock + + self.watch_manager._watch_sync_flow_exception_handler(sync_flow_exception) + + queue_infra_sync_mock.assert_called_once_with() diff --git a/tests/unit/lib/utils/test_boto_utils.py b/tests/unit/lib/utils/test_boto_utils.py new file mode 100644 index 0000000000..35d45b77f6 --- /dev/null +++ b/tests/unit/lib/utils/test_boto_utils.py @@ -0,0 +1,75 @@ +from unittest import TestCase +from unittest.mock import patch, Mock + +from parameterized import parameterized + +from samcli.lib.utils.boto_utils import ( + get_boto_config_with_user_agent, + get_boto_client_provider_with_config, + get_boto_resource_provider_with_config, +) + +TEST_VERSION = "1.0.0" + + +class TestBotoUtils(TestCase): + @parameterized.expand([(True,), (False,)]) + @patch("samcli.lib.utils.boto_utils.GlobalConfig") + @patch("samcli.lib.utils.boto_utils.__version__", TEST_VERSION) + def test_get_boto_config_with_user_agent( + self, + telemetry_enabled, + patched_global_config, + ): + given_global_config_instance = Mock() + patched_global_config.return_value = given_global_config_instance + + given_global_config_instance.telemetry_enabled = telemetry_enabled + given_region_name = "us-west-2" + + config = get_boto_config_with_user_agent(region_name=given_region_name) + + self.assertEqual(given_region_name, config.region_name) + + if telemetry_enabled: + self.assertEqual( + config.user_agent_extra, f"aws-sam-cli/{TEST_VERSION}/{given_global_config_instance.installation_id}" + ) + else: + self.assertEqual(config.user_agent_extra, f"aws-sam-cli/{TEST_VERSION}") + + @patch("samcli.lib.utils.boto_utils.get_boto_config_with_user_agent") + @patch("samcli.lib.utils.boto_utils.boto3") + def test_get_boto_client_provider_with_config(self, patched_boto3, patched_get_config): + given_config = Mock() + patched_get_config.return_value = given_config + + given_config_param = Mock() + client_generator = get_boto_client_provider_with_config(param=given_config_param) + + given_service_client = Mock() + patched_boto3.session.Session().client.return_value = given_service_client + + client = client_generator("service") + + self.assertEqual(client, given_service_client) + patched_get_config.assert_called_with(param=given_config_param) + patched_boto3.session.Session().client.assert_called_with("service", config=given_config) + + @patch("samcli.lib.utils.boto_utils.get_boto_config_with_user_agent") + @patch("samcli.lib.utils.boto_utils.boto3") + def 
test_get_boto_resource_provider_with_config(self, patched_boto3, patched_get_config): + given_config = Mock() + patched_get_config.return_value = given_config + + given_config_param = Mock() + client_generator = get_boto_resource_provider_with_config(param=given_config_param) + + given_service_client = Mock() + patched_boto3.session.Session().resource.return_value = given_service_client + + client = client_generator("service") + + self.assertEqual(client, given_service_client) + patched_get_config.assert_called_with(param=given_config_param) + patched_boto3.session.Session().resource.assert_called_with("service", config=given_config) diff --git a/tests/unit/lib/utils/test_cloudformation.py b/tests/unit/lib/utils/test_cloudformation.py new file mode 100644 index 0000000000..f925e295ba --- /dev/null +++ b/tests/unit/lib/utils/test_cloudformation.py @@ -0,0 +1,119 @@ +from unittest import TestCase +from unittest.mock import patch, Mock, ANY + +from botocore.exceptions import ClientError + +from samcli.lib.utils.cloudformation import ( + CloudFormationResourceSummary, + get_physical_id_mapping, + get_resource_summaries, + get_resource_summary, +) + + +class TestCloudFormationResourceSummary(TestCase): + def test_cfn_resource_summary(self): + given_type = "type" + given_logical_id = "logical_id" + given_physical_id = "physical_id" + + resource_summary = CloudFormationResourceSummary(given_type, given_logical_id, given_physical_id) + + self.assertEqual(given_type, resource_summary.resource_type) + self.assertEqual(given_logical_id, resource_summary.logical_resource_id) + self.assertEqual(given_physical_id, resource_summary.physical_resource_id) + + +class TestCloudformationUtils(TestCase): + @patch("samcli.lib.utils.cloudformation.get_resource_summaries") + def test_get_physical_id_mapping(self, patched_get_resource_summaries): + patched_get_resource_summaries.return_value = [ + CloudFormationResourceSummary("", "Logical1", "Physical1"), + CloudFormationResourceSummary("", "Logical2", "Physical2"), + CloudFormationResourceSummary("", "Logical3", "Physical3"), + ] + + given_resource_provider = Mock() + given_resource_types = Mock() + given_stack_name = "stack_name" + physical_id_mapping = get_physical_id_mapping(given_resource_provider, given_stack_name, given_resource_types) + + self.assertEqual( + physical_id_mapping, + { + "Logical1": "Physical1", + "Logical2": "Physical2", + "Logical3": "Physical3", + }, + ) + + patched_get_resource_summaries.assert_called_with( + given_resource_provider, given_stack_name, given_resource_types + ) + + def test_get_resource_summaries(self): + resource_provider_mock = Mock() + given_stack_name = "stack_name" + given_resource_types = {"ResourceType0"} + + given_stack_resource_array = [ + Mock( + physical_resource_id="physical_id_1", logical_resource_id="logical_id_1", resource_type="ResourceType0" + ), + Mock( + physical_resource_id="physical_id_2", logical_resource_id="logical_id_2", resource_type="ResourceType0" + ), + Mock( + physical_resource_id="physical_id_3", logical_resource_id="logical_id_3", resource_type="ResourceType1" + ), + ] + + resource_provider_mock(ANY).Stack(ANY).resource_summaries.all.return_value = given_stack_resource_array + + resource_summaries = get_resource_summaries(resource_provider_mock, given_stack_name, given_resource_types) + + self.assertEqual(len(resource_summaries), 2) + self.assertEqual( + resource_summaries, + [ + CloudFormationResourceSummary("ResourceType0", "logical_id_1", "physical_id_1"), + 
CloudFormationResourceSummary("ResourceType0", "logical_id_2", "physical_id_2"), + ], + ) + + resource_provider_mock.assert_called_with("cloudformation") + resource_provider_mock(ANY).Stack.assert_called_with(given_stack_name) + resource_provider_mock(ANY).Stack(ANY).resource_summaries.all.assert_called_once() + + def test_get_resource_summary(self): + resource_provider_mock = Mock() + given_stack_name = "stack_name" + given_resource_logical_id = "logical_id_1" + + given_resource_type = "ResourceType0" + given_physical_id = "physical_id_1" + resource_provider_mock(ANY).StackResource.return_value = Mock( + physical_resource_id=given_physical_id, + logical_resource_id=given_resource_logical_id, + resource_type=given_resource_type, + ) + + resource_summary = get_resource_summary(resource_provider_mock, given_stack_name, given_resource_logical_id) + + self.assertEqual(resource_summary.resource_type, given_resource_type) + self.assertEqual(resource_summary.logical_resource_id, given_resource_logical_id) + self.assertEqual(resource_summary.physical_resource_id, given_physical_id) + + resource_provider_mock.assert_called_with("cloudformation") + resource_provider_mock(ANY).StackResource.assert_called_with(given_stack_name, given_resource_logical_id) + + def test_get_resource_summary_fail(self): + resource_provider_mock = Mock() + given_stack_name = "stack_name" + given_resource_logical_id = "logical_id_1" + + resource_provider_mock(ANY).StackResource.side_effect = ClientError({}, "operation") + + resource_summary = get_resource_summary(resource_provider_mock, given_stack_name, given_resource_logical_id) + + self.assertIsNone(resource_summary) diff --git a/tests/unit/lib/utils/test_code_trigger_factory.py b/tests/unit/lib/utils/test_code_trigger_factory.py new file mode 100644 index 0000000000..fc250aae85 --- /dev/null +++ b/tests/unit/lib/utils/test_code_trigger_factory.py @@ -0,0 +1,72 @@ +from parameterized import parameterized +from unittest.case import TestCase +from unittest.mock import MagicMock, patch, ANY +from samcli.lib.utils.code_trigger_factory import CodeTriggerFactory +from samcli.lib.providers.provider import ResourceIdentifier + + +class TestCodeTriggerFactory(TestCase): + def setUp(self): + self.stacks = [MagicMock(), MagicMock()] + self.factory = CodeTriggerFactory(self.stacks) + + @patch("samcli.lib.utils.code_trigger_factory.LambdaZipCodeTrigger") + def test_create_zip_function_trigger(self, trigger_mock): + on_code_change_mock = MagicMock() + resource_identifier = ResourceIdentifier("Function1") + resource = {"Properties": {"PackageType": "Zip"}} + result = self.factory._create_lambda_trigger(resource_identifier, "Type", resource, on_code_change_mock) + self.assertEqual(result, trigger_mock.return_value) + trigger_mock.assert_called_once_with(resource_identifier, self.stacks, on_code_change_mock) + + @patch("samcli.lib.utils.code_trigger_factory.LambdaImageCodeTrigger") + def test_create_image_function_trigger(self, trigger_mock): + on_code_change_mock = MagicMock() + resource_identifier = ResourceIdentifier("Function1") + resource = {"Properties": {"PackageType": "Image"}} + result = self.factory._create_lambda_trigger(resource_identifier, "Type", resource, on_code_change_mock) + self.assertEqual(result, trigger_mock.return_value) + trigger_mock.assert_called_once_with(resource_identifier, self.stacks, on_code_change_mock) + + @patch("samcli.lib.utils.code_trigger_factory.LambdaLayerCodeTrigger") + def test_create_layer_trigger(self, trigger_mock): + on_code_change_mock = 
MagicMock() + resource_identifier = ResourceIdentifier("Layer1") + result = self.factory._create_layer_trigger(resource_identifier, "Type", {}, on_code_change_mock) + self.assertEqual(result, trigger_mock.return_value) + trigger_mock.assert_called_once_with(resource_identifier, self.stacks, on_code_change_mock) + + @patch("samcli.lib.utils.code_trigger_factory.DefinitionCodeTrigger") + def test_create_definition_trigger(self, trigger_mock): + on_code_change_mock = MagicMock() + resource_identifier = ResourceIdentifier("API1") + resource_type = "AWS::Serverless::Api" + result = self.factory._create_definition_code_trigger( + resource_identifier, resource_type, {}, on_code_change_mock + ) + self.assertEqual(result, trigger_mock.return_value) + trigger_mock.assert_called_once_with(resource_identifier, resource_type, self.stacks, on_code_change_mock) + + @patch("samcli.lib.utils.code_trigger_factory.get_resource_by_id") + @patch("samcli.lib.utils.resource_type_based_factory.get_resource_by_id") + def test_create_trigger(self, get_resource_by_id_mock, parent_get_resource_by_id_mock): + code_trigger = MagicMock() + resource_identifier = MagicMock() + get_resource_by_id = {"Type": "AWS::Serverless::Api"} + get_resource_by_id_mock.return_value = get_resource_by_id + parent_get_resource_by_id_mock.return_value = get_resource_by_id + generator_mock = MagicMock() + generator_mock.return_value = code_trigger + + on_code_change_mock = MagicMock() + + get_generator_function_mock = MagicMock() + get_generator_function_mock.return_value = generator_mock + self.factory._get_generator_function = get_generator_function_mock + + result = self.factory.create_trigger(resource_identifier, on_code_change_mock) + + self.assertEqual(result, code_trigger) + generator_mock.assert_called_once_with( + self.factory, resource_identifier, "AWS::Serverless::Api", get_resource_by_id, on_code_change_mock + ) diff --git a/tests/unit/lib/utils/test_definition_validator.py b/tests/unit/lib/utils/test_definition_validator.py new file mode 100644 index 0000000000..726c33307b --- /dev/null +++ b/tests/unit/lib/utils/test_definition_validator.py @@ -0,0 +1,61 @@ +from parameterized import parameterized +from unittest.case import TestCase +from unittest.mock import MagicMock, patch, ANY +from samcli.lib.utils.definition_validator import DefinitionValidator + + +class TestDefinitionValidator(TestCase): + def setUp(self) -> None: + self.path = MagicMock() + + @patch("samcli.lib.utils.definition_validator.parse_yaml_file") + def test_invalid_path(self, parse_yaml_file_mock): + parse_yaml_file_mock.side_effect = [{"A": 1}, {"A": 1}] + self.path.exists.return_value = False + + validator = DefinitionValidator(self.path, detect_change=False, initialize_data=False) + self.assertFalse(validator.validate()) + self.assertFalse(validator.validate()) + + @patch("samcli.lib.utils.definition_validator.parse_yaml_file") + def test_no_detect_change_valid(self, parse_yaml_file_mock): + parse_yaml_file_mock.side_effect = [{"A": 1}, {"A": 1}] + + validator = DefinitionValidator(self.path, detect_change=False, initialize_data=False) + self.assertTrue(validator.validate()) + self.assertTrue(validator.validate()) + + @patch("samcli.lib.utils.definition_validator.parse_yaml_file") + def test_no_detect_change_invalid(self, parse_yaml_file_mock): + parse_yaml_file_mock.side_effect = [ValueError(), {"A": 1}] + + validator = DefinitionValidator(self.path, detect_change=False, initialize_data=False) + self.assertFalse(validator.validate()) + 
self.assertTrue(validator.validate()) + + @patch("samcli.lib.utils.definition_validator.parse_yaml_file") + def test_detect_change_valid(self, parse_yaml_file_mock): + parse_yaml_file_mock.side_effect = [{"A": 1}, {"B": 1}] + + validator = DefinitionValidator(self.path, detect_change=True, initialize_data=False) + self.assertTrue(validator.validate()) + self.assertTrue(validator.validate()) + + @patch("samcli.lib.utils.definition_validator.parse_yaml_file") + def test_detect_change_invalid(self, parse_yaml_file_mock): + parse_yaml_file_mock.side_effect = [{"A": 1}, {"A": 1}, ValueError(), {"B": 1}] + + validator = DefinitionValidator(self.path, detect_change=True, initialize_data=False) + self.assertTrue(validator.validate()) + self.assertFalse(validator.validate()) + self.assertFalse(validator.validate()) + self.assertTrue(validator.validate()) + + @patch("samcli.lib.utils.definition_validator.parse_yaml_file") + def test_detect_change_initialize(self, parse_yaml_file_mock): + parse_yaml_file_mock.side_effect = [{"A": 1}, {"A": 1}, ValueError(), {"B": 1}] + + validator = DefinitionValidator(self.path, detect_change=True, initialize_data=True) + self.assertFalse(validator.validate()) + self.assertFalse(validator.validate()) + self.assertTrue(validator.validate()) diff --git a/tests/unit/lib/utils/test_handler_observer.py b/tests/unit/lib/utils/test_handler_observer.py new file mode 100644 index 0000000000..5d7041b6db --- /dev/null +++ b/tests/unit/lib/utils/test_handler_observer.py @@ -0,0 +1,145 @@ +import re +from unittest.case import TestCase +from unittest.mock import MagicMock, patch, ANY +from samcli.lib.utils.path_observer import HandlerObserver, PathHandler, StaticFolderWrapper + + +class TestPathHandler(TestCase): + def test_init(self): + handler_mock = MagicMock() + path_mock = MagicMock() + create_mock = MagicMock() + delete_mock = MagicMock() + bundle = PathHandler(handler_mock, path_mock, True, True, create_mock, delete_mock) + + self.assertEqual(bundle.event_handler, handler_mock) + self.assertEqual(bundle.path, path_mock) + self.assertEqual(bundle.self_create, create_mock) + self.assertEqual(bundle.self_delete, delete_mock) + self.assertTrue(bundle.recursive) + self.assertTrue(bundle.static_folder) + + +class TestStaticFolderWrapper(TestCase): + def setUp(self) -> None: + self.observer = MagicMock() + self.path_handler = MagicMock() + self.initial_watch = MagicMock() + self.wrapper = StaticFolderWrapper(self.observer, self.initial_watch, self.path_handler) + + def test_on_parent_change_on_delete(self): + watch_mock = MagicMock() + self.wrapper._watch = watch_mock + self.wrapper._path_handler.path.exists.return_value = False + + self.wrapper._on_parent_change(MagicMock()) + + self.path_handler.self_delete.assert_called_once_with() + self.observer.unschedule.assert_called_once_with(watch_mock) + self.assertIsNone(self.wrapper._watch) + + def test_on_parent_change_on_create(self): + watch_mock = MagicMock() + self.observer.schedule_handler.return_value = watch_mock + + self.wrapper._watch = None + self.wrapper._path_handler.path.exists.return_value = True + + self.wrapper._on_parent_change(MagicMock()) + + self.path_handler.self_create.assert_called_once_with() + self.observer.schedule_handler.assert_called_once_with(self.wrapper._path_handler) + self.assertEqual(self.wrapper._watch, watch_mock) + + @patch("samcli.lib.utils.path_observer.RegexMatchingEventHandler") + @patch("samcli.lib.utils.path_observer.PathHandler") + def test_get_dir_parent_path_handler(self, 
path_handler_mock, event_handler_mock): + path_mock = MagicMock() + path_mock.resolve.return_value.parent = "/parent/" + path_mock.resolve.return_value.__str__.return_value = "/parent/dir/" + self.path_handler.path = path_mock + + event_handler = MagicMock() + event_handler_mock.return_value = event_handler + path_handler = MagicMock() + path_handler_mock.return_value = path_handler + result = self.wrapper.get_dir_parent_path_handler() + + self.assertEqual(result, path_handler) + path_handler_mock.assert_called_once_with(path="/parent/", event_handler=event_handler) + escaped_path = re.escape("/parent/dir/") + event_handler_mock.assert_called_once_with( + regexes=[f"^{escaped_path}$"], ignore_regexes=[], ignore_directories=False, case_sensitive=True + ) + + +class TestHandlerObserver(TestCase): + def setUp(self) -> None: + self.observer = HandlerObserver() + + def test_schedule_handlers(self): + bundle_1 = MagicMock() + bundle_2 = MagicMock() + watch_1 = MagicMock() + watch_2 = MagicMock() + + schedule_handler_mock = MagicMock() + schedule_handler_mock.side_effect = [watch_1, watch_2] + self.observer.schedule_handler = schedule_handler_mock + result = self.observer.schedule_handlers([bundle_1, bundle_2]) + self.assertEqual(result, [watch_1, watch_2]) + schedule_handler_mock.assert_any_call(bundle_1) + schedule_handler_mock.assert_any_call(bundle_2) + + @patch("samcli.lib.utils.path_observer.StaticFolderWrapper") + def test_schedule_handler_not_static(self, wrapper_mock: MagicMock): + bundle = MagicMock() + event_handler = MagicMock() + bundle.event_handler = event_handler + bundle.path = "dir" + bundle.recursive = True + bundle.static_folder = False + watch = MagicMock() + + schedule_mock = MagicMock() + schedule_mock.return_value = watch + self.observer.schedule = schedule_mock + + result = self.observer.schedule_handler(bundle) + + self.assertEqual(result, watch) + schedule_mock.assert_any_call(bundle.event_handler, "dir", True) + wrapper_mock.assert_not_called() + + @patch("samcli.lib.utils.path_observer.StaticFolderWrapper") + def test_schedule_handler_static(self, wrapper_mock: MagicMock): + bundle = MagicMock() + event_handler = MagicMock() + bundle.event_handler = event_handler + bundle.path = "dir" + bundle.recursive = True + bundle.static_folder = True + watch = MagicMock() + + parent_bundle = MagicMock() + event_handler = MagicMock() + parent_bundle.event_handler = event_handler + parent_bundle.path = "parent" + parent_bundle.recursive = False + parent_bundle.static_folder = False + parent_watch = MagicMock() + + schedule_mock = MagicMock() + schedule_mock.side_effect = [watch, parent_watch] + self.observer.schedule = schedule_mock + + wrapper = MagicMock() + wrapper_mock.return_value = wrapper + wrapper.get_dir_parent_path_handler.return_value = parent_bundle + + result = self.observer.schedule_handler(bundle) + + self.assertEqual(result, parent_watch) + schedule_mock.assert_any_call(bundle.event_handler, "dir", True) + schedule_mock.assert_any_call(parent_bundle.event_handler, "parent", False) + wrapper_mock.assert_called_once_with(self.observer, watch, bundle) diff --git a/tests/unit/lib/utils/test_lock_distributor.py b/tests/unit/lib/utils/test_lock_distributor.py new file mode 100644 index 0000000000..f57ba4e1ed --- /dev/null +++ b/tests/unit/lib/utils/test_lock_distributor.py @@ -0,0 +1,103 @@ +from unittest import TestCase +from unittest.mock import MagicMock, call, patch +from samcli.lib.utils.lock_distributor import LockChain, LockDistributor, LockDistributorType + + 
+class TestLockChain(TestCase): + def test_acquire_order(self): + locks = {"A": MagicMock(), "B": MagicMock(), "C": MagicMock()} + call_mock = MagicMock() + call_mock.a = locks["A"] + call_mock.b = locks["B"] + call_mock.c = locks["C"] + lock_chain = LockChain(locks) + lock_chain.acquire() + call_mock.assert_has_calls([call.a.acquire(), call.b.acquire(), call.c.acquire()]) + + def test_acquire_order_shuffled(self): + locks = {"A": MagicMock(), "C": MagicMock(), "B": MagicMock()} + call_mock = MagicMock() + call_mock.a = locks["A"] + call_mock.b = locks["B"] + call_mock.c = locks["C"] + lock_chain = LockChain(locks) + lock_chain.acquire() + call_mock.assert_has_calls([call.a.acquire(), call.b.acquire(), call.c.acquire()]) + + def test_release_order(self): + locks = {"A": MagicMock(), "B": MagicMock(), "C": MagicMock()} + call_mock = MagicMock() + call_mock.a = locks["A"] + call_mock.b = locks["B"] + call_mock.c = locks["C"] + lock_chain = LockChain(locks) + lock_chain.release() + call_mock.assert_has_calls([call.a.release(), call.b.release(), call.c.release()]) + + def test_release_order_shuffled(self): + locks = {"A": MagicMock(), "C": MagicMock(), "B": MagicMock()} + call_mock = MagicMock() + call_mock.a = locks["A"] + call_mock.b = locks["B"] + call_mock.c = locks["C"] + lock_chain = LockChain(locks) + lock_chain.release() + call_mock.assert_has_calls([call.a.release(), call.b.release(), call.c.release()]) + + def test_with(self): + locks = {"A": MagicMock(), "C": MagicMock(), "B": MagicMock()} + call_mock = MagicMock() + call_mock.a = locks["A"] + call_mock.b = locks["B"] + call_mock.c = locks["C"] + with LockChain(locks) as _: + call_mock.assert_has_calls([call.a.acquire(), call.b.acquire(), call.c.acquire()]) + call_mock.assert_has_calls( + [call.a.acquire(), call.b.acquire(), call.c.acquire(), call.a.release(), call.b.release(), call.c.release()] + ) + + +class TestLockDistributor(TestCase): + @patch("samcli.lib.utils.lock_distributor.threading.Lock") + @patch("samcli.lib.utils.lock_distributor.multiprocessing.Lock") + def test_thread_get_locks(self, process_lock_mock, thread_lock_mock): + locks = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] + thread_lock_mock.side_effect = locks + distributor = LockDistributor(LockDistributorType.THREAD, None) + keys = ["A", "B", "C"] + result = distributor.get_locks(keys) + + self.assertEqual(result["A"], locks[1]) + self.assertEqual(result["B"], locks[2]) + self.assertEqual(result["C"], locks[3]) + self.assertEqual(distributor.get_locks(keys)["A"], locks[1]) + + @patch("samcli.lib.utils.lock_distributor.threading.Lock") + @patch("samcli.lib.utils.lock_distributor.multiprocessing.Lock") + def test_process_get_locks(self, process_lock_mock, thread_lock_mock): + locks = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] + process_lock_mock.side_effect = locks + distributor = LockDistributor(LockDistributorType.PROCESS, None) + keys = ["A", "B", "C"] + result = distributor.get_locks(keys) + + self.assertEqual(result["A"], locks[1]) + self.assertEqual(result["B"], locks[2]) + self.assertEqual(result["C"], locks[3]) + self.assertEqual(distributor.get_locks(keys)["A"], locks[1]) + + @patch("samcli.lib.utils.lock_distributor.threading.Lock") + @patch("samcli.lib.utils.lock_distributor.multiprocessing.Lock") + def test_process_manager_get_locks(self, process_lock_mock, thread_lock_mock): + manager_mock = MagicMock() + locks = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] + manager_mock.dict.return_value = dict() + 
manager_mock.Lock.side_effect = locks + distributor = LockDistributor(LockDistributorType.PROCESS, manager_mock) + keys = ["A", "B", "C"] + result = distributor.get_locks(keys) + + self.assertEqual(result["A"], locks[1]) + self.assertEqual(result["B"], locks[2]) + self.assertEqual(result["C"], locks[3]) + self.assertEqual(distributor.get_locks(keys)["A"], locks[1]) diff --git a/tests/unit/lib/utils/test_resource_trigger.py b/tests/unit/lib/utils/test_resource_trigger.py new file mode 100644 index 0000000000..8feff30b71 --- /dev/null +++ b/tests/unit/lib/utils/test_resource_trigger.py @@ -0,0 +1,258 @@ +import re +from parameterized import parameterized +from unittest.case import TestCase +from unittest.mock import MagicMock, patch, ANY +from samcli.lib.utils.resource_trigger import ( + CodeResourceTrigger, + DefinitionCodeTrigger, + LambdaFunctionCodeTrigger, + LambdaImageCodeTrigger, + LambdaLayerCodeTrigger, + LambdaZipCodeTrigger, + ResourceTrigger, + TemplateTrigger, +) +from samcli.local.lambdafn.exceptions import FunctionNotFound, ResourceNotFound +from samcli.lib.providers.exceptions import MissingLocalDefinition +from samcli.lib.providers.provider import ResourceIdentifier + + +class TestResourceTrigger(TestCase): + @patch("samcli.lib.utils.resource_trigger.PathHandler") + @patch("samcli.lib.utils.resource_trigger.RegexMatchingEventHandler") + @patch("samcli.lib.utils.resource_trigger.Path") + def test_single_file_path_handler(self, path_mock, handler_mock, bundle_mock): + path = MagicMock() + path_mock.return_value = path + file_path = MagicMock() + file_path.__str__.return_value = "/parent/file" + + parent_path = MagicMock() + parent_path.__str__.return_value = "/parent/" + + file_path.parent = parent_path + + path.resolve.return_value = file_path + + ResourceTrigger.get_single_file_path_handler("/parent/file") + + path_mock.assert_called_once_with("/parent/file") + escaped_path = re.escape("/parent/file") + handler_mock.assert_called_once_with( + regexes=[f"^{escaped_path}$"], ignore_regexes=[], ignore_directories=True, case_sensitive=True + ) + bundle_mock.assert_called_once_with(path=parent_path, event_handler=handler_mock.return_value, recursive=False) + + @patch("samcli.lib.utils.resource_trigger.PathHandler") + @patch("samcli.lib.utils.resource_trigger.PatternMatchingEventHandler") + @patch("samcli.lib.utils.resource_trigger.Path") + def test_dir_path_handler(self, path_mock, handler_mock, bundle_mock): + path = MagicMock() + path_mock.return_value = path + folder_path = MagicMock() + + path.resolve.return_value = folder_path + + ResourceTrigger.get_dir_path_handler("/parent/folder/") + + path_mock.assert_called_once_with("/parent/folder/") + handler_mock.assert_called_once_with( + patterns=["*"], ignore_patterns=[], ignore_directories=False, case_sensitive=True + ) + bundle_mock.assert_called_once_with( + path=folder_path, event_handler=handler_mock.return_value, recursive=True, static_folder=True + ) + + +class TestTemplateTrigger(TestCase): + @patch("samcli.lib.utils.resource_trigger.DefinitionValidator") + @patch("samcli.lib.utils.resource_trigger.Path") + @patch("samcli.lib.utils.resource_trigger.ResourceTrigger.get_single_file_path_handler") + def test_get_path_handler(self, single_file_handler_mock, path_mock, validator_mock): + trigger = TemplateTrigger("template.yaml", MagicMock()) + result = trigger.get_path_handlers() + self.assertEqual(result, [single_file_handler_mock.return_value]) + 
self.assertEqual(single_file_handler_mock.return_value.event_handler.on_any_event, trigger._validator_wrapper) + + @patch("samcli.lib.utils.resource_trigger.DefinitionValidator") + @patch("samcli.lib.utils.resource_trigger.Path") + def test_validator_wrapper(self, path_mock, validator_mock): + on_template_change_mock = MagicMock() + event_mock = MagicMock() + validator_mock.return_value.validate.return_value = True + trigger = TemplateTrigger("template.yaml", on_template_change_mock) + trigger._validator_wrapper(event_mock) + on_template_change_mock.assert_called_once_with(event_mock) + + +class TestCodeResourceTrigger(TestCase): + @patch.multiple(CodeResourceTrigger, __abstractmethods__=set()) + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_init(self, get_resource_by_id_mock): + stacks = [MagicMock(), MagicMock()] + on_code_change_mock = MagicMock() + trigger = CodeResourceTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock) + self.assertEqual(trigger._resource, get_resource_by_id_mock.return_value) + self.assertEqual(trigger._on_code_change, on_code_change_mock) + + @patch.multiple(CodeResourceTrigger, __abstractmethods__=set()) + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_init_invalid(self, get_resource_by_id_mock): + stacks = [MagicMock(), MagicMock()] + on_code_change_mock = MagicMock() + get_resource_by_id_mock.return_value = None + + with self.assertRaises(ResourceNotFound): + CodeResourceTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock) + + +class TestLambdaFunctionCodeTrigger(TestCase): + @patch.multiple(LambdaFunctionCodeTrigger, __abstractmethods__=set()) + @patch("samcli.lib.utils.resource_trigger.SamFunctionProvider") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_init(self, get_resource_by_id_mock, function_provider_mock): + stacks = [MagicMock(), MagicMock()] + on_code_change_mock = MagicMock() + function_mock = function_provider_mock.return_value.get.return_value + + code_uri_mock = MagicMock() + LambdaFunctionCodeTrigger._get_code_uri = code_uri_mock + + trigger = LambdaFunctionCodeTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock) + self.assertEqual(trigger._function, function_mock) + self.assertEqual(trigger._code_uri, code_uri_mock.return_value) + + @patch.multiple(LambdaFunctionCodeTrigger, __abstractmethods__=set()) + @patch("samcli.lib.utils.resource_trigger.SamFunctionProvider") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_init_invalid(self, get_resource_by_id_mock, function_provider_mock): + stacks = [MagicMock(), MagicMock()] + on_code_change_mock = MagicMock() + function_provider_mock.return_value.get.return_value = None + + code_uri_mock = MagicMock() + LambdaFunctionCodeTrigger._get_code_uri = code_uri_mock + + with self.assertRaises(FunctionNotFound): + LambdaFunctionCodeTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock) + + @patch.multiple(LambdaFunctionCodeTrigger, __abstractmethods__=set()) + @patch("samcli.lib.utils.resource_trigger.ResourceTrigger.get_dir_path_handler") + @patch("samcli.lib.utils.resource_trigger.SamFunctionProvider") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_get_path_handlers(self, get_resource_by_id_mock, function_provider_mock, get_dir_path_handler_mock): + stacks = [MagicMock(), MagicMock()] + on_code_change_mock = MagicMock() + function_mock = function_provider_mock.return_value.get.return_value + + code_uri_mock = MagicMock() + 
LambdaFunctionCodeTrigger._get_code_uri = code_uri_mock + + bundle = MagicMock() + get_dir_path_handler_mock.return_value = bundle + + trigger = LambdaFunctionCodeTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock) + result = trigger.get_path_handlers() + + self.assertEqual(result, [bundle]) + self.assertEqual(bundle.self_create, on_code_change_mock) + self.assertEqual(bundle.self_delete, on_code_change_mock) + self.assertEqual(bundle.event_handler.on_any_event, on_code_change_mock) + + +class TestLambdaZipCodeTrigger(TestCase): + @patch("samcli.lib.utils.resource_trigger.SamFunctionProvider") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_get_code_uri(self, get_resource_by_id_mock, function_provider_mock): + stacks = [MagicMock(), MagicMock()] + on_code_change_mock = MagicMock() + function_mock = function_provider_mock.return_value.get.return_value + trigger = LambdaZipCodeTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock) + result = trigger._get_code_uri() + self.assertEqual(result, function_mock.codeuri) + + +class TestLambdaImageCodeTrigger(TestCase): + @patch("samcli.lib.utils.resource_trigger.SamFunctionProvider") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_get_code_uri(self, get_resource_by_id_mock, function_provider_mock): + stacks = [MagicMock(), MagicMock()] + on_code_change_mock = MagicMock() + function_mock = function_provider_mock.return_value.get.return_value + trigger = LambdaImageCodeTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock) + result = trigger._get_code_uri() + self.assertEqual(result, function_mock.metadata.get.return_value) + + +class TestLambdaLayerCodeTrigger(TestCase): + @patch("samcli.lib.utils.resource_trigger.SamLayerProvider") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_init(self, get_resource_by_id_mock, layer_provider_mock): + stacks = [MagicMock(), MagicMock()] + on_code_change_mock = MagicMock() + layer_mock = layer_provider_mock.return_value.get.return_value + + trigger = LambdaLayerCodeTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock) + self.assertEqual(trigger._layer, layer_mock) + self.assertEqual(trigger._code_uri, layer_mock.codeuri) + + @patch("samcli.lib.utils.resource_trigger.ResourceTrigger.get_dir_path_handler") + @patch("samcli.lib.utils.resource_trigger.SamLayerProvider") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_get_path_handlers(self, get_resource_by_id_mock, layer_provider_mock, get_dir_path_handler_mock): + stacks = [MagicMock(), MagicMock()] + on_code_change_mock = MagicMock() + layer_mock = layer_provider_mock.return_value.get.return_value + + bundle = MagicMock() + get_dir_path_handler_mock.return_value = bundle + + trigger = LambdaLayerCodeTrigger(ResourceIdentifier("A"), stacks, on_code_change_mock) + result = trigger.get_path_handlers() + + self.assertEqual(result, [bundle]) + self.assertEqual(bundle.self_create, on_code_change_mock) + self.assertEqual(bundle.self_delete, on_code_change_mock) + self.assertEqual(bundle.event_handler.on_any_event, on_code_change_mock) + + +class TestDefinitionCodeTrigger(TestCase): + @patch("samcli.lib.utils.resource_trigger.DefinitionValidator") + @patch("samcli.lib.utils.resource_trigger.Path") + @patch("samcli.lib.utils.resource_trigger.ResourceTrigger.get_single_file_path_handler") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_get_path_handler(self, get_resource_by_id_mock, 
single_file_handler_mock, path_mock, validator_mock): + stacks = [MagicMock(), MagicMock()] + resource = {"Properties": {"DefinitionUri": "abc"}} + get_resource_by_id_mock.return_value = resource + trigger = DefinitionCodeTrigger("TestApi", "AWS::Serverless::Api", stacks, MagicMock()) + result = trigger.get_path_handlers() + self.assertEqual(result, [single_file_handler_mock.return_value]) + self.assertEqual(single_file_handler_mock.return_value.event_handler.on_any_event, trigger._validator_wrapper) + + @patch("samcli.lib.utils.resource_trigger.DefinitionValidator") + @patch("samcli.lib.utils.resource_trigger.Path") + @patch("samcli.lib.utils.resource_trigger.ResourceTrigger.get_single_file_path_handler") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_get_path_handler_missing_definition( + self, get_resource_by_id_mock, single_file_handler_mock, path_mock, validator_mock + ): + stacks = [MagicMock(), MagicMock()] + resource = {"Properties": {"Field": "abc"}} + get_resource_by_id_mock.return_value = resource + with self.assertRaises(MissingLocalDefinition): + trigger = DefinitionCodeTrigger("TestApi", "AWS::Serverless::Api", stacks, MagicMock()) + + @patch("samcli.lib.utils.resource_trigger.DefinitionValidator") + @patch("samcli.lib.utils.resource_trigger.Path") + @patch("samcli.lib.utils.resource_trigger.get_resource_by_id") + def test_validator_wrapper(self, get_resource_by_id_mock, path_mock, validator_mock): + stacks = [MagicMock(), MagicMock()] + on_definition_change_mock = MagicMock() + event_mock = MagicMock() + validator_mock.return_value.validate.return_value = True + resource = {"Properties": {"DefinitionUri": "abc"}} + get_resource_by_id_mock.return_value = resource + trigger = DefinitionCodeTrigger("TestApi", "AWS::Serverless::Api", stacks, on_definition_change_mock) + trigger._validator_wrapper(event_mock) + on_definition_change_mock.assert_called_once_with(event_mock) diff --git a/tests/unit/lib/utils/test_resource_type_based_factory.py b/tests/unit/lib/utils/test_resource_type_based_factory.py new file mode 100644 index 0000000000..302e91f4d6 --- /dev/null +++ b/tests/unit/lib/utils/test_resource_type_based_factory.py @@ -0,0 +1,48 @@ +from samcli.lib.providers.provider import ResourceIdentifier +from samcli.lib.utils.resource_type_based_factory import ResourceTypeBasedFactory +from unittest import TestCase +from unittest.mock import ANY, MagicMock, call, patch + + +class TestResourceTypeBasedFactory(TestCase): + def setUp(self): + self.abstract_method_patch = patch.multiple(ResourceTypeBasedFactory, __abstractmethods__=set()) + self.abstract_method_patch.start() + self.stacks = [MagicMock(), MagicMock()] + self.factory = ResourceTypeBasedFactory(self.stacks) + self.function_generator_mock = MagicMock() + self.layer_generator_mock = MagicMock() + self.factory._get_generator_mapping = MagicMock() + self.factory._get_generator_mapping.return_value = { + "AWS::Lambda::Function": self.function_generator_mock, + "AWS::Lambda::LayerVersion": self.layer_generator_mock, + } + + def tearDown(self): + self.abstract_method_patch.stop() + + @patch("samcli.lib.utils.resource_type_based_factory.get_resource_by_id") + def test_get_generator_function_valid(self, get_resource_by_id_mock): + resource = {"Type": "AWS::Lambda::Function"} + get_resource_by_id_mock.return_value = resource + + generator = self.factory._get_generator_function(ResourceIdentifier("Resource1")) + self.assertEqual(generator, self.function_generator_mock) + + 
@patch("samcli.lib.utils.resource_type_based_factory.get_resource_by_id") + def test_get_generator_function_unknown_type(self, get_resource_by_id_mock): + resource = {"Type": "AWS::Unknown::Type"} + get_resource_by_id_mock.return_value = resource + + generator = self.factory._get_generator_function(ResourceIdentifier("Resource1")) + + self.assertEqual(None, generator) + + @patch("samcli.lib.utils.resource_type_based_factory.get_resource_by_id") + def test_get_generator_function_no_type(self, get_resource_by_id_mock): + resource = {"Properties": {}} + get_resource_by_id_mock.return_value = resource + + generator = self.factory._get_generator_function(ResourceIdentifier("Resource1")) + + self.assertEqual(None, generator)