From cc090d47d04d4eef64c8dae237b51527bb5dc027 Mon Sep 17 00:00:00 2001 From: Balaji Sankar <115105204+balajisankar15@users.noreply.github.com> Date: Tue, 21 Feb 2023 10:15:48 -0700 Subject: [PATCH 01/40] feature: Added Config parser for SageMaker Python SDK (#840) Co-authored-by: Balaji Sankar --- setup.py | 3 + src/sagemaker/config/__init__.py | 15 + src/sagemaker/config/config.py | 268 ++++++++++ src/sagemaker/config/config_schema.py | 488 ++++++++++++++++++ tests/data/config/config.yaml | 123 +++++ .../expected_output_config_after_merge.yaml | 14 + tests/data/config/invalid_config_file.yaml | 11 + .../sample_additional_config_for_merge.yaml | 7 + .../data/config/sample_config_for_merge.yaml | 9 + tests/unit/sagemaker/config/__init__.py | 13 + tests/unit/sagemaker/config/conftest.py | 206 ++++++++ tests/unit/sagemaker/config/test_config.py | 234 +++++++++ .../sagemaker/config/test_config_schema.py | 139 +++++ 13 files changed, 1530 insertions(+) create mode 100644 src/sagemaker/config/__init__.py create mode 100644 src/sagemaker/config/config.py create mode 100644 src/sagemaker/config/config_schema.py create mode 100644 tests/data/config/config.yaml create mode 100644 tests/data/config/expected_output_config_after_merge.yaml create mode 100644 tests/data/config/invalid_config_file.yaml create mode 100644 tests/data/config/sample_additional_config_for_merge.yaml create mode 100644 tests/data/config/sample_config_for_merge.yaml create mode 100644 tests/unit/sagemaker/config/__init__.py create mode 100644 tests/unit/sagemaker/config/conftest.py create mode 100644 tests/unit/sagemaker/config/test_config.py create mode 100644 tests/unit/sagemaker/config/test_config_schema.py diff --git a/setup.py b/setup.py index e2adb6b433..98a63c9d32 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,9 @@ def read_requirements(filename): "pandas", "pathos", "schema", + "PyYAML==5.4.1", + "jsonschema", + "platformdirs", ] # Specific use case dependencies diff --git 
a/src/sagemaker/config/__init__.py b/src/sagemaker/config/__init__.py new file mode 100644 index 0000000000..8ee1bd2962 --- /dev/null +++ b/src/sagemaker/config/__init__.py @@ -0,0 +1,15 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying athis file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module configures the default values for SageMaker Python SDK.""" + +from __future__ import absolute_import diff --git a/src/sagemaker/config/config.py b/src/sagemaker/config/config.py new file mode 100644 index 0000000000..7451de158d --- /dev/null +++ b/src/sagemaker/config/config.py @@ -0,0 +1,268 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module configures the default values for SageMaker Python SDK. + +It supports loading Config files from local file system/S3. +The schema of the Config file is dictated in config_schema.py in the same module. 
+ +""" +from __future__ import absolute_import + +import pathlib +import logging +import os +from typing import List +import boto3 +import yaml +from jsonschema import validate +from platformdirs import site_config_dir, user_config_dir +from botocore.utils import merge_dicts +from six.moves.urllib.parse import urlparse +from sagemaker.config.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA + +logger = logging.getLogger("sagemaker") + +_APP_NAME = "sagemaker" +_DEFAULT_ADMIN_CONFIG_FILE_PATH = os.path.join(site_config_dir(_APP_NAME), "config.yaml") +_DEFAULT_USER_CONFIG_FILE_PATH = os.path.join(user_config_dir(_APP_NAME), "config.yaml") + +ENV_VARIABLE_DEFAULT_CONFIG_OVERRIDE = "SAGEMAKER_DEFAULT_CONFIG_OVERRIDE" +ENV_VARIABLE_USER_CONFIG_OVERRIDE = "SAGEMAKER_USER_CONFIG_OVERRIDE" + +_config_paths = [_DEFAULT_ADMIN_CONFIG_FILE_PATH, _DEFAULT_USER_CONFIG_FILE_PATH] +_BOTO_SESSION = boto3.DEFAULT_SESSION or boto3.Session() +_DEFAULT_S3_RESOURCE = _BOTO_SESSION.resource("s3") + + +class SageMakerConfig(object): + """SageMakerConfig class encapsulates the Config for SageMaker Python SDK. + + Usages: + This class will be integrated with sagemaker.session.Session. Users of SageMaker Python SDK + will have the ability to pass a SageMakerConfig object to sagemaker.session.Session. If + SageMakerConfig object is not provided by the user, then sagemaker.session.Session will + create its own SageMakerConfig object. + + Note: Once sagemaker.session.Session is initialized, it will operate with the configuration + values at that instant. If the users wish to alter configuration files/file paths after + sagemaker.session.Session is initialized, then that will not be reflected in + sagemaker.session.Session. They would have to re-initialize sagemaker.session.Session to + pick the latest changes. + + """ + + def __init__(self, additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE): + """Constructor for SageMakerConfig. 
+ + By default, it will first look for Config files in paths that are dictated by + _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH. + + Users can override the _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH + by using environment variables - SAGEMAKER_DEFAULT_CONFIG_OVERRIDE and + SAGEMAKER_USER_CONFIG_OVERRIDE + + Additional Configuration file paths can also be provided as a constructor parameter. + + This constructor will then + * Load each config file. + * It will validate the schema of the config files. + * It will perform the merge operation in the same order. + + This constructor will throw exceptions for the following cases: + * Schema validation fails for one/more config files. + * When the config file is not a proper YAML file. + * Any S3 related issues that arises while fetching config file from S3. This includes + permission issues, S3 Object is not found in the specified S3 URI. + * File doesn't exist in a path that was specified by the user as part of environment + variable/ additional_config_paths. This doesn't include + _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH + + + Args: + additional_config_paths: List of Config file paths. + These paths can be one of the following: + * Local file path + * Local directory path (in this case, we will look for config.yaml in that + directory) + * S3 URI of the config file + * S3 URI of the directory containing the config file (in this case, we will look for + config.yaml in that directory) + Note: S3 URI follows the format s3:/// + s3_resource: Corresponds to boto3 S3 resource. This will be used to fetch Config + files from S3. If it is not provided, we will create a default s3 resource + See :py:meth:`boto3.session.Session.resource`. 
This argument is not needed if the + config files are present in the local file system + + """ + default_config_path = os.getenv( + ENV_VARIABLE_DEFAULT_CONFIG_OVERRIDE, _DEFAULT_ADMIN_CONFIG_FILE_PATH + ) + user_config_path = os.getenv( + ENV_VARIABLE_USER_CONFIG_OVERRIDE, _DEFAULT_USER_CONFIG_FILE_PATH + ) + self._config_paths = [default_config_path, user_config_path] + if additional_config_paths: + self._config_paths += additional_config_paths + self._s3_resource = s3_resource + config = {} + for file_path in self._config_paths: + if file_path.startswith("s3://"): + config_from_file = _load_config_from_s3(file_path, self._s3_resource) + else: + config_from_file = _load_config_from_file(file_path) + if config_from_file: + validate(config_from_file, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + merge_dicts(config, config_from_file) + self._config = config + + @property + def config_paths(self) -> List[str]: + """Getter for Config paths. + + Returns: + List[str]: This corresponds to the list of config file paths. + """ + return self._config_paths + + @property + def config(self) -> dict: + """Getter for the configuration object. + + Returns: + dict: A dictionary representing the configurations that were loaded from the config + file(s). + """ + return self._config + + +def _load_config_from_file(file_path: str) -> dict: + """This method loads the config file from the path that was specified as parameter. + + If the path that was provided, corresponds to a directory then this method will try to search + for 'config.yaml' in that directory. Note: We will not be doing any recursive search. + + Args: + file_path(str): The file path from which the Config file needs to be loaded. + + Returns: + dict: A dictionary representing the configurations that were loaded from the config file. + + This method will throw Exceptions for the following cases: + * When the config file is not a proper YAML file. + * File doesn't exist in a path that was specified by the consumer. 
This doesn't include + _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH + """ + config = {} + if file_path: + inferred_file_path = file_path + if os.path.isdir(file_path): + inferred_file_path = os.path.join(file_path, "config.yaml") + if not os.path.exists(inferred_file_path): + if inferred_file_path not in ( + _DEFAULT_ADMIN_CONFIG_FILE_PATH, + _DEFAULT_USER_CONFIG_FILE_PATH, + ): + # Customer provided file path is invalid. + raise ValueError( + f"Unable to load config file from the location: {file_path} Please" + f" provide a valid file path" + ) + else: + logger.debug("Fetching configuration file from the path: %s", file_path) + config = yaml.safe_load(open(inferred_file_path, "r")) + return config + + +def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict: + """This method loads the config file from the S3 URI that was specified as parameter. + + If the S3 URI that was provided, corresponds to a directory then this method will try to + search for 'config.yaml' in that directory. Note: We will not be doing any recursive search. + + Args: + s3_uri(str): The S3 URI of the config file. + Note: S3 URI follows the format s3:/// + s3_resource_for_config: Corresponds to boto3 S3 resource. This will be used to fetch Config + files from S3. See :py:meth:`boto3.session.Session.resource`. + + Returns: + dict: A dictionary representing the configurations that were loaded from the config file. + + This method will throw Exceptions for the following cases: + * If Boto3 S3 resource is not provided. + * When the config file is not a proper YAML file. + * If the method is unable to retrieve the list of all the S3 files with the same prefix + * If there are no S3 files with that prefix. + * If a folder in S3 bucket is provided as s3_uri, and if it doesn't have config.yaml, + then we will throw an Exception. 
+ """ + if not s3_resource_for_config: + raise RuntimeError("Please provide a S3 client for loading the config") + logger.debug("Fetching configuration file from the S3 URI: %s", s3_uri) + inferred_s3_uri = _get_inferred_s3_uri(s3_uri, s3_resource_for_config) + parsed_url = urlparse(inferred_s3_uri) + bucket, key_prefix = parsed_url.netloc, parsed_url.path.lstrip("/") + s3_object = s3_resource_for_config.Object(bucket, key_prefix) + s3_file_content = s3_object.get()["Body"].read() + return yaml.safe_load(s3_file_content.decode("utf-8")) + + +def _get_inferred_s3_uri(s3_uri, s3_resource_for_config): + """Verifies whether the given S3 URI exists and returns the URI. + + If there are multiple S3 objects with the same key prefix, + then this method will verify whether S3 URI + /config.yaml exists. + s3://example-bucket/somekeyprefix/config.yaml + + Args: + s3_uri (str) : An S3 uri that refers to a location in which config file is present. + s3_uri must start with 's3://'. + An example s3_uri: 's3://example-bucket/config.yaml'. + s3_resource_for_config: Corresponds to boto3 S3 resource. This will be used to fetch Config + files from S3. + See :py:meth:`boto3.session.Session.resource` + + Returns: + str: Valid S3 URI of the Config file. None if it doesn't exist. + + This method will throw Exceptions for the following cases: + * If the method is unable to retrieve the list of all the S3 files with the same prefix + * If there are no S3 files with that prefix. + * If a folder in S3 bucket is provided as s3_uri, and if it doesn't have config.yaml, + then we will throw an Exception. 
+ """ + parsed_url = urlparse(s3_uri) + bucket, key_prefix = parsed_url.netloc, parsed_url.path.lstrip("/") + try: + s3_bucket = s3_resource_for_config.Bucket(name=bucket) + s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all() + s3_files_with_same_prefix = [ + "s3://{}/{}".format(bucket, s3_object.key) for s3_object in s3_objects + ] + except Exception as e: # pylint: disable=W0703 + # if customers didn't provide us with a valid S3 File/insufficient read permission, + # We will fail hard. + raise RuntimeError(f"Unable to read from S3 with URI: {s3_uri} due to {e}") + if len(s3_files_with_same_prefix) == 0: + # Customer provided us with an incorrect s3 path. + raise ValueError("Please provide a valid s3 path instead of {}".format(s3_uri)) + if len(s3_files_with_same_prefix) > 1: + # Customer has provided us with a S3 URI which points to a directory + # search for s3:///directory-key-prefix/config.yaml + inferred_s3_uri = str(pathlib.PurePosixPath(s3_uri, "config.yaml")).replace("s3:/", "s3://") + if inferred_s3_uri not in s3_files_with_same_prefix: + # We don't know which file we should be operating with. + raise ValueError("Please provide a S3 URI which has config.yaml in the directory") + # Customer has a config.yaml present in the directory that was provided as the S3 URI + return inferred_s3_uri + return s3_uri diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py new file mode 100644 index 0000000000..5a80ccf10f --- /dev/null +++ b/src/sagemaker/config/config_schema.py @@ -0,0 +1,488 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains/maintains the schema of the Config f`i`le.""" +from __future__ import absolute_import, print_function + +SECURITY_GROUP_IDS = "SecurityGroupIds" +SUBNETS = "Subnets" +ENABLE_NETWORK_ISOLATION = "EnableNetworkIsolation" +VOLUME_KMS_KEY_ID = "VolumeKmsKeyId" +KMS_KEY_ID = "KmsKeyId" +ROLE_ARN = "RoleArn" +EXECUTION_ROLE_ARN = "ExecutionRoleArn" +CLUSTER_ROLE_ARN = "ClusterRoleArn" +VPC_CONFIG = "VpcConfig" +OUTPUT_DATA_CONFIG = "OutputDataConfig" +AUTO_ML_JOB_CONFIG = "AutoMLJobConfig" +ASYNC_INFERENCE_CONFIG = "AsyncInferenceConfig" +OUTPUT_CONFIG = "OutputConfig" +PROCESSING_OUTPUT_CONFIG = "ProcessingOutputConfig" +CLUSTER_CONFIG = "ClusterConfig" +NETWORK_CONFIG = "NetworkConfig" +CORE_DUMP_CONFIG = "CoreDumpConfig" +DATA_CAPTURE_CONFIG = "DataCaptureConfig" +MONITORING_OUTPUT_CONFIG = "MonitoringOutputConfig" +RESOURCE_CONFIG = "ResourceConfig" +SCHEMA_VERSION = "SchemaVersion" +DATASET_DEFINITION = "DatasetDefinition" +ATHENA_DATASET_DEFINITION = "AthenaDatasetDefinition" +REDSHIFT_DATASET_DEFINITION = "RedshiftDatasetDefinition" +MONITORING_JOB_DEFINITION = "MonitoringJobDefinition" +SAGEMAKER = "SageMaker" +PYTHON_SDK = "PythonSDK" +MODULES = "Modules" +OFFLINE_STORE_CONFIG = "OfflineStoreConfig" +ONLINE_STORE_CONFIG = "OnlineStoreConfig" +S3_STORAGE_CONFIG = "S3StorageConfig" +SECURITY_CONFIG = "SecurityConfig" +TRANSFORM_JOB_DEFINITION = "TransformJobDefinition" +MONITORING_SCHEDULE_CONFIG = "MonitoringScheduleConfig" +MONITORING_RESOURCES = "MonitoringResources" +PROCESSING_RESOURCES = "ProcessingResources" +PRODUCTION_VARIANTS = "ProductionVariants" +SHADOW_PRODUCTION_VARIANTS = "ShadowProductionVariants" +TRANSFORM_OUTPUT = "TransformOutput" +TRANSFORM_RESOURCES = "TransformResources" +VALIDATION_ROLE = 
"ValidationRole" +VALIDATION_SPECIFICATION = "ValidationSpecification" +VALIDATION_PROFILES = "ValidationProfiles" +PROCESSING_INPUTS = "ProcessingInputs" +FEATURE_GROUP = "FeatureGroup" +EDGE_PACKAGING_JOB = "EdgePackagingJob" +TRAINING_JOB = "TrainingJob" +PROCESSING_JOB = "ProcessingJob" +MODEL_PACKAGE = "ModelPackage" +MODEL = "Model" +MONITORING_SCHEDULE = "MonitoringSchedule" +ENDPOINT_CONFIG = "EndpointConfig" +AUTO_ML = "AutoML" +COMPILATION_JOB = "CompilationJob" +PIPELINE = "Pipeline" +TRANSFORM_JOB = "TransformJob" +PROPERTIES = "properties" +TYPE = "type" +OBJECT = "object" +ADDITIONAL_PROPERTIES = "additionalProperties" +SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = { + "$schema": "https://json-schema.org/draft/2020-12/schema", + TYPE: OBJECT, + "required": [SCHEMA_VERSION], + ADDITIONAL_PROPERTIES: False, + "definitions": { + "roleArn": { + # Schema for IAM Role. This includes a Regex validator. + TYPE: "string", + "pattern": r"^arn:aws[a-z\-]*:iam::\d{12}:role/?[a-zA-Z_0-9+=,.@\-_/]+$", + }, + "securityGroupId": {TYPE: "string", "pattern": r"[-0-9a-zA-Z]+"}, + "subnet": {TYPE: "string", "pattern": r"[-0-9a-zA-Z]+"}, + "vpcConfig": { + # Schema for VPC Configs. 
+ # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference + # /API_VpcConfig.html + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + SECURITY_GROUP_IDS: { + TYPE: "array", + "items": {"$ref": "#/definitions/securityGroupId"}, + }, + SUBNETS: {TYPE: "array", "items": {"$ref": "#/definitions/subnet"}}, + }, + }, + "productionVariant": { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + CORE_DUMP_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + } + }, + }, + "validationProfile": { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + TRANSFORM_JOB_DEFINITION: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + TRANSFORM_OUTPUT: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + }, + TRANSFORM_RESOURCES: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {VOLUME_KMS_KEY_ID: {TYPE: "string"}}, + }, + }, + } + }, + }, + "processingInput": { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + DATASET_DEFINITION: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + ATHENA_DATASET_DEFINITION: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + KMS_KEY_ID: { + TYPE: "string", + } + }, + }, + REDSHIFT_DATASET_DEFINITION: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + KMS_KEY_ID: { + TYPE: "string", + }, + CLUSTER_ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + }, + }, + }, + } + }, + }, + }, + PROPERTIES: { + SCHEMA_VERSION: { + TYPE: "string", + # Currently we support only one schema version (1.0). + # In the future this might change if we introduce any breaking changes. + # So adding an enum as a validator. 
+ "enum": ["1.0"], + "description": "The schema version of the document.", + }, + SAGEMAKER: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + PYTHON_SDK: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + MODULES: { + # Any SageMaker Python SDK specific configuration will be added here. + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + } + }, + }, + # Feature Group + # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateFeatureGroup + # .html + FEATURE_GROUP: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + OFFLINE_STORE_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + S3_STORAGE_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + } + }, + }, + ONLINE_STORE_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + SECURITY_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + } + }, + }, + ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + }, + }, + # Monitoring Schedule + # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateMonitoringSchedule.html + MONITORING_SCHEDULE: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + MONITORING_SCHEDULE_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + MONITORING_JOB_DEFINITION: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + MONITORING_OUTPUT_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + }, + MONITORING_RESOURCES: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + CLUSTER_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + VOLUME_KMS_KEY_ID: {TYPE: "string"} + }, + } + }, + }, + NETWORK_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + ENABLE_NETWORK_ISOLATION: {TYPE: "boolean"}, + VPC_CONFIG: {"$ref": 
"#/definitions/vpcConfig"}, + }, + }, + ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + }, + } + }, + } + }, + }, + # Endpoint Config + # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html + # Note: there is a separate API for creating Endpoints. + # That will be added later to schema once we start + # supporting other parameters such as Tags + ENDPOINT_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + ASYNC_INFERENCE_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + OUTPUT_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + } + }, + }, + DATA_CAPTURE_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + }, + KMS_KEY_ID: {TYPE: "string"}, + PRODUCTION_VARIANTS: { + TYPE: "array", + "items": {"$ref": "#/definitions/productionVariant"}, + }, + SHADOW_PRODUCTION_VARIANTS: { + TYPE: "array", + "items": {"$ref": "#/definitions/productionVariant"}, + }, + }, + }, + # Auto ML + # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateAutoMLJob.html + AUTO_ML: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + AUTO_ML_JOB_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + SECURITY_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + VOLUME_KMS_KEY_ID: { + TYPE: "string", + }, + VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, + }, + } + }, + }, + OUTPUT_DATA_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + }, + ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + }, + }, + # Transform Job + # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTransformJob.html + TRANSFORM_JOB: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + DATA_CAPTURE_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: 
{TYPE: "string"}}, + }, + TRANSFORM_OUTPUT: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + }, + TRANSFORM_RESOURCES: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {VOLUME_KMS_KEY_ID: {TYPE: "string"}}, + }, + }, + }, + # Compilation Job + # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateCompilationJob + # .html + COMPILATION_JOB: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + OUTPUT_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + }, + ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, + }, + }, + # Pipeline + # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreatePipeline.html + PIPELINE: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {ROLE_ARN: {"$ref": "#/definitions/roleArn"}}, + }, + # Model + # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html + MODEL: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + ENABLE_NETWORK_ISOLATION: {TYPE: "boolean"}, + EXECUTION_ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, + }, + }, + # Model Package + # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModelPackage.html + MODEL_PACKAGE: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + VALIDATION_SPECIFICATION: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + VALIDATION_PROFILES: { + TYPE: "array", + "items": {"$ref": "#/definitions/validationProfile"}, + }, + VALIDATION_ROLE: {"$ref": "#/definitions/roleArn"}, + }, + } + }, + }, + # Processing Job + # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html + PROCESSING_JOB: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + NETWORK_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + 
PROPERTIES: { + ENABLE_NETWORK_ISOLATION: {TYPE: "boolean"}, + VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, + }, + }, + PROCESSING_INPUTS: { + TYPE: "array", + "items": {"$ref": "#/definitions/processingInput"}, + }, + PROCESSING_OUTPUT_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + }, + PROCESSING_RESOURCES: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + CLUSTER_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {VOLUME_KMS_KEY_ID: {TYPE: "string"}}, + } + }, + }, + ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + }, + }, + # Training Job + # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html + TRAINING_JOB: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + ENABLE_NETWORK_ISOLATION: {TYPE: "boolean"}, + OUTPUT_DATA_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + }, + RESOURCE_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {VOLUME_KMS_KEY_ID: {TYPE: "string"}}, + }, + ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, + }, + }, + # Edge Packaging Job + # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEdgePackagingJob.html + EDGE_PACKAGING_JOB: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + OUTPUT_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + }, + ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + }, + }, + }, + }, + }, +} diff --git a/tests/data/config/config.yaml b/tests/data/config/config.yaml new file mode 100644 index 0000000000..726f73f07a --- /dev/null +++ b/tests/data/config/config.yaml @@ -0,0 +1,123 @@ +SchemaVersion: '1.0' +SageMaker: + FeatureGroup: + OnlineStoreConfig: + SecurityConfig: + KmsKeyId: 'someotherkmskeyid' + OfflineStoreConfig: + S3StorageConfig: + KmsKeyId: 
'somekmskeyid' + RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + MonitoringSchedule: + MonitoringScheduleConfig: + MonitoringJobDefinition: + MonitoringOutputConfig: + KmsKeyId: 'somekmskey' + MonitoringResources: + ClusterConfig: + VolumeKmsKeyId: 'somevolumekmskey' + NetworkConfig: + EnableNetworkIsolation: true + VpcConfig: + SecurityGroupIds: + - 'sg123' + Subnets: + - 'subnet-1234' + RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + EndpointConfig: + AsyncInferenceConfig: + OutputConfig: + KmsKeyId: 'somekmskey' + DataCaptureConfig: + KmsKeyId: 'somekmskey2' + KmsKeyId: 'somekmskey3' + ProductionVariants: + - CoreDumpConfig: + KmsKeyId: 'somekmskey4' + ShadowProductionVariants: + - CoreDumpConfig: + KmsKeyId: 'somekmskey5' + AutoML: + AutoMLJobConfig: + SecurityConfig: + VolumeKmsKeyId: 'somevolumekmskey' + VpcConfig: + SecurityGroupIds: + - 'sg123' + Subnets: + - 'subnet-1234' + OutputDataConfig: + KmsKeyId: 'somekmskey' + RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + TransformJob: + DataCaptureConfig: + KmsKeyId: 'somekmskey' + TransformOutput: + KmsKeyId: 'somekmskey2' + TransformResources: + VolumeKmsKeyId: 'somevolumekmskey' + CompilationJob: + OutputConfig: + KmsKeyId: 'somekmskey' + RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + VpcConfig: + SecurityGroupIds: + - 'sg123' + Subnets: + - 'subnet-1234' + Pipeline: + RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + Model: + EnableNetworkIsolation: true + ExecutionRoleArn: 'arn:aws:iam::366666666666:role/IMRole' + VpcConfig: + SecurityGroupIds: + - 'sg123' + Subnets: + - 'subnet-1234' + ModelPackage: + ValidationSpecification: + ValidationProfiles: + - TransformJobDefinition: + TransformOutput: + KmsKeyId: 'somerandomkmskeyid' + TransformResources: + VolumeKmsKeyId: 'somerandomkmskeyid' + ValidationRole: 'arn:aws:iam::366666666666:role/IMRole' + ProcessingJob: + NetworkConfig: + EnableNetworkIsolation: true + VpcConfig: + SecurityGroupIds: + - 'sg123' + Subnets: + - 'subnet-1234' + 
ProcessingInputs: + - DatasetDefinition: + AthenaDatasetDefinition: + KmsKeyId: 'somekmskeyid' + RedshiftDatasetDefinition: + KmsKeyId: 'someotherkmskeyid' + ClusterRoleArn: 'arn:aws:iam::366666666666:role/IMRole' + ProcessingOutputConfig: + KmsKeyId: 'somerandomkmskeyid' + ProcessingResources: + ClusterConfig: + VolumeKmsKeyId: 'somerandomkmskeyid' + RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + TrainingJob: + EnableNetworkIsolation: true + OutputDataConfig: + KmsKeyId: 'somekmskey' + ResourceConfig: + VolumeKmsKeyId: 'somevolumekmskey' + RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + VpcConfig: + SecurityGroupIds: + - 'sg123' + Subnets: + - 'subnet-1234' + EdgePackagingJob: + OutputConfig: + KmsKeyId: 'somekeyid' + RoleArn: 'arn:aws:iam::366666666666:role/IMRole' diff --git a/tests/data/config/expected_output_config_after_merge.yaml b/tests/data/config/expected_output_config_after_merge.yaml new file mode 100644 index 0000000000..1cd7d815c6 --- /dev/null +++ b/tests/data/config/expected_output_config_after_merge.yaml @@ -0,0 +1,14 @@ +SchemaVersion: '1.0' +SageMaker: + FeatureGroup: + OnlineStoreConfig: + SecurityConfig: + # Present in the additional override as well as default config. + # Pick the additional config value. 
+ KmsKeyId: 'additionalConfigKmsKeyId' + OfflineStoreConfig: + S3StorageConfig: + # Present only in the default config + KmsKeyId: 'somekmskeyid' + # Present only in the additional config + RoleArn: 'arn:aws:iam::377777777777:role/additionalConfigRole' \ No newline at end of file diff --git a/tests/data/config/invalid_config_file.yaml b/tests/data/config/invalid_config_file.yaml new file mode 100644 index 0000000000..81cba67e98 --- /dev/null +++ b/tests/data/config/invalid_config_file.yaml @@ -0,0 +1,11 @@ +SchemaVersion: '1.0' +SageMaker: + FeatureGroup: + OfflineStoreConfig: + S3StorageConfig: + KmsKeyId: 'mykmskeyid' + OnlineStoreConfig: + SecurityConfig: + KmsKeyId: 'myotherKmsKeyId' + # Role ARN is invalid + RoleArn: 'arn:aws123:iam::366:role/IMRole' diff --git a/tests/data/config/sample_additional_config_for_merge.yaml b/tests/data/config/sample_additional_config_for_merge.yaml new file mode 100644 index 0000000000..94aba11176 --- /dev/null +++ b/tests/data/config/sample_additional_config_for_merge.yaml @@ -0,0 +1,7 @@ +SchemaVersion: '1.0' +SageMaker: + FeatureGroup: + OnlineStoreConfig: + SecurityConfig: + KmsKeyId: 'additionalConfigKmsKeyId' + RoleArn: 'arn:aws:iam::377777777777:role/additionalConfigRole' \ No newline at end of file diff --git a/tests/data/config/sample_config_for_merge.yaml b/tests/data/config/sample_config_for_merge.yaml new file mode 100644 index 0000000000..f933f46afe --- /dev/null +++ b/tests/data/config/sample_config_for_merge.yaml @@ -0,0 +1,9 @@ +SchemaVersion: '1.0' +SageMaker: + FeatureGroup: + OnlineStoreConfig: + SecurityConfig: + KmsKeyId: 'someotherkmskeyid' + OfflineStoreConfig: + S3StorageConfig: + KmsKeyId: 'somekmskeyid' \ No newline at end of file diff --git a/tests/unit/sagemaker/config/__init__.py b/tests/unit/sagemaker/config/__init__.py new file mode 100644 index 0000000000..a6987bc6a6 --- /dev/null +++ b/tests/unit/sagemaker/config/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import diff --git a/tests/unit/sagemaker/config/conftest.py b/tests/unit/sagemaker/config/conftest.py new file mode 100644 index 0000000000..31a2616298 --- /dev/null +++ b/tests/unit/sagemaker/config/conftest.py @@ -0,0 +1,206 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import os +import pytest +from mock import MagicMock + + +@pytest.fixture(scope="module") +def base_config_with_schema(): + return {"SchemaVersion": "1.0"} + + +@pytest.fixture(scope="module") +def valid_vpc_config(): + return {"SecurityGroupIds": ["sg123"], "Subnets": ["subnet-1234"]} + + +@pytest.fixture(scope="module") +def valid_iam_role_arn(): + return "arn:aws:iam::366666666666:role/IMRole" + + +@pytest.fixture(scope="module") +def valid_feature_group_config(valid_iam_role_arn): + s3_storage_config = {"KmsKeyId": "somekmskeyid"} + security_storage_config = {"KmsKeyId": "someotherkmskeyid"} + online_store_config = {"SecurityConfig": security_storage_config} + offline_store_config = {"S3StorageConfig": s3_storage_config} + return { + "OnlineStoreConfig": online_store_config, + "OfflineStoreConfig": offline_store_config, + "RoleArn": valid_iam_role_arn, + } + + +@pytest.fixture(scope="module") +def valid_edge_packaging_config(valid_iam_role_arn): + return { + "OutputConfig": {"KmsKeyId": "somekeyid"}, + "RoleArn": valid_iam_role_arn, + } + + +@pytest.fixture(scope="module") +def valid_model_config(valid_iam_role_arn, valid_vpc_config): + return { + "EnableNetworkIsolation": True, + "ExecutionRoleArn": valid_iam_role_arn, + "VpcConfig": valid_vpc_config, + } + + +@pytest.fixture(scope="module") +def valid_model_package_config(valid_iam_role_arn): + transform_job_definition = { + "TransformOutput": {"KmsKeyId": "somerandomkmskeyid"}, + "TransformResources": {"VolumeKmsKeyId": "somerandomkmskeyid"}, + } + validation_specification = { + "ValidationProfiles": [{"TransformJobDefinition": transform_job_definition}], + "ValidationRole": valid_iam_role_arn, + } + return {"ValidationSpecification": validation_specification} + + +@pytest.fixture(scope="module") +def valid_processing_job_config(valid_iam_role_arn, valid_vpc_config): + network_config = {"EnableNetworkIsolation": True, "VpcConfig": valid_vpc_config} + 
dataset_definition = { + "AthenaDatasetDefinition": {"KmsKeyId": "somekmskeyid"}, + "RedshiftDatasetDefinition": { + "KmsKeyId": "someotherkmskeyid", + "ClusterRoleArn": valid_iam_role_arn, + }, + } + return { + "NetworkConfig": network_config, + "ProcessingInputs": [{"DatasetDefinition": dataset_definition}], + "ProcessingOutputConfig": {"KmsKeyId": "somerandomkmskeyid"}, + "ProcessingResources": {"ClusterConfig": {"VolumeKmsKeyId": "somerandomkmskeyid"}}, + "RoleArn": valid_iam_role_arn, + } + + +@pytest.fixture(scope="module") +def valid_training_job_config(valid_iam_role_arn, valid_vpc_config): + return { + "EnableNetworkIsolation": True, + "OutputDataConfig": {"KmsKeyId": "somekmskey"}, + "ResourceConfig": {"VolumeKmsKeyId": "somevolumekmskey"}, + "RoleArn": valid_iam_role_arn, + "VpcConfig": valid_vpc_config, + } + + +@pytest.fixture(scope="module") +def valid_pipeline_config(valid_iam_role_arn): + return {"RoleArn": valid_iam_role_arn} + + +@pytest.fixture(scope="module") +def valid_compilation_job_config(valid_iam_role_arn, valid_vpc_config): + return { + "OutputConfig": {"KmsKeyId": "somekmskey"}, + "RoleArn": valid_iam_role_arn, + "VpcConfig": valid_vpc_config, + } + + +@pytest.fixture(scope="module") +def valid_transform_job_config(): + return { + "DataCaptureConfig": {"KmsKeyId": "somekmskey"}, + "TransformOutput": {"KmsKeyId": "somekmskey2"}, + "TransformResources": {"VolumeKmsKeyId": "somevolumekmskey"}, + } + + +@pytest.fixture(scope="module") +def valid_automl_config(valid_iam_role_arn, valid_vpc_config): + return { + "AutoMLJobConfig": { + "SecurityConfig": {"VolumeKmsKeyId": "somevolumekmskey", "VpcConfig": valid_vpc_config} + }, + "OutputDataConfig": {"KmsKeyId": "somekmskey"}, + "RoleArn": valid_iam_role_arn, + } + + +@pytest.fixture(scope="module") +def valid_endpointconfig_config(): + return { + "AsyncInferenceConfig": {"OutputConfig": {"KmsKeyId": "somekmskey"}}, + "DataCaptureConfig": {"KmsKeyId": "somekmskey2"}, + "KmsKeyId": "somekmskey3", 
+ "ProductionVariants": [{"CoreDumpConfig": {"KmsKeyId": "somekmskey4"}}], + "ShadowProductionVariants": [{"CoreDumpConfig": {"KmsKeyId": "somekmskey5"}}], + } + + +@pytest.fixture(scope="module") +def valid_monitoring_schedule_config(valid_iam_role_arn, valid_vpc_config): + network_config = {"EnableNetworkIsolation": True, "VpcConfig": valid_vpc_config} + return { + "MonitoringScheduleConfig": { + "MonitoringJobDefinition": { + "MonitoringOutputConfig": {"KmsKeyId": "somekmskey"}, + "MonitoringResources": {"ClusterConfig": {"VolumeKmsKeyId": "somevolumekmskey"}}, + "NetworkConfig": network_config, + "RoleArn": valid_iam_role_arn, + } + } + } + + +@pytest.fixture(scope="module") +def valid_config_with_all_the_scopes( + valid_feature_group_config, + valid_monitoring_schedule_config, + valid_endpointconfig_config, + valid_automl_config, + valid_transform_job_config, + valid_compilation_job_config, + valid_pipeline_config, + valid_model_config, + valid_model_package_config, + valid_processing_job_config, + valid_training_job_config, + valid_edge_packaging_config, +): + return { + "FeatureGroup": valid_feature_group_config, + "MonitoringSchedule": valid_monitoring_schedule_config, + "EndpointConfig": valid_endpointconfig_config, + "AutoML": valid_automl_config, + "TransformJob": valid_transform_job_config, + "CompilationJob": valid_compilation_job_config, + "Pipeline": valid_pipeline_config, + "Model": valid_model_config, + "ModelPackage": valid_model_package_config, + "ProcessingJob": valid_processing_job_config, + "TrainingJob": valid_training_job_config, + "EdgePackagingJob": valid_edge_packaging_config, + } + + +@pytest.fixture(scope="module") +def s3_resource_mock(): + return MagicMock(name="s3") + + +@pytest.fixture(scope="module") +def get_data_dir(): + return os.path.join(os.path.dirname(__file__), "..", "..", "..", "data", "config") diff --git a/tests/unit/sagemaker/config/test_config.py b/tests/unit/sagemaker/config/test_config.py new file mode 100644 index 
0000000000..ac3719479b --- /dev/null +++ b/tests/unit/sagemaker/config/test_config.py @@ -0,0 +1,234 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os +import pytest +import yaml +from mock import Mock, MagicMock + +from sagemaker.config.config import SageMakerConfig +from jsonschema import exceptions + + +@pytest.fixture() +def config_file_as_yaml(get_data_dir): + config_file_path = os.path.join(get_data_dir, "config.yaml") + return open(config_file_path, "r").read() + + +@pytest.fixture() +def expected_merged_config(get_data_dir): + expected_merged_config_file_path = os.path.join( + get_data_dir, "expected_output_config_after_merge.yaml" + ) + return yaml.safe_load(open(expected_merged_config_file_path, "r").read()) + + +def test_config_when_default_config_file_and_user_config_file_is_not_found(): + assert SageMakerConfig().config == {} + + +def test_config_when_overriden_default_config_file_is_not_found(get_data_dir): + fake_config_file_path = os.path.join(get_data_dir, "config-not-found.yaml") + os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = fake_config_file_path + with pytest.raises(ValueError): + SageMakerConfig() + del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] + + +def test_config_when_additional_config_file_path_is_not_found(get_data_dir): + fake_config_file_path = os.path.join(get_data_dir, "config-not-found.yaml") + with pytest.raises(ValueError): + 
SageMakerConfig(additional_config_paths=[fake_config_file_path]) + + +def test_config_factory_when_override_user_config_file_is_not_found(get_data_dir): + fake_additional_override_config_file_path = os.path.join( + get_data_dir, "additional-config-not-found.yaml" + ) + os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = fake_additional_override_config_file_path + with pytest.raises(ValueError): + SageMakerConfig() + del os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] + + +def test_default_config_file_with_invalid_schema(get_data_dir): + config_file_path = os.path.join(get_data_dir, "invalid_config_file.yaml") + os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = config_file_path + with pytest.raises(exceptions.ValidationError): + SageMakerConfig() + del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] + + +def test_default_config_file_when_directory_is_provided_as_the_path( + get_data_dir, valid_config_with_all_the_scopes, base_config_with_schema +): + # This will try to load config.yaml file from that directory if present. + expected_config = base_config_with_schema + expected_config["SageMaker"] = valid_config_with_all_the_scopes + os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = get_data_dir + assert expected_config == SageMakerConfig().config + del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] + + +def test_additional_config_paths_when_directory_is_provided( + get_data_dir, valid_config_with_all_the_scopes, base_config_with_schema +): + # This will try to load config.yaml file from that directory if present. 
+ expected_config = base_config_with_schema + expected_config["SageMaker"] = valid_config_with_all_the_scopes + assert expected_config == SageMakerConfig(additional_config_paths=[get_data_dir]).config + + +def test_default_config_file_when_path_is_provided_as_environment_variable( + get_data_dir, valid_config_with_all_the_scopes, base_config_with_schema +): + os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = get_data_dir + # This will try to load config.yaml file from that directory if present. + expected_config = base_config_with_schema + expected_config["SageMaker"] = valid_config_with_all_the_scopes + assert expected_config == SageMakerConfig().config + del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] + + +def test_merge_behavior_when_additional_config_file_path_is_not_found( + get_data_dir, valid_config_with_all_the_scopes, base_config_with_schema +): + valid_config_file_path = os.path.join(get_data_dir, "config.yaml") + fake_additional_override_config_file_path = os.path.join( + get_data_dir, "additional-config-not-found.yaml" + ) + os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = valid_config_file_path + with pytest.raises(ValueError): + SageMakerConfig(additional_config_paths=[fake_additional_override_config_file_path]) + del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] + + +def test_merge_behavior(get_data_dir, expected_merged_config): + valid_config_file_path = os.path.join(get_data_dir, "sample_config_for_merge.yaml") + additional_override_config_file_path = os.path.join( + get_data_dir, "sample_additional_config_for_merge.yaml" + ) + os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = valid_config_file_path + assert ( + expected_merged_config + == SageMakerConfig(additional_config_paths=[additional_override_config_file_path]).config + ) + os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = additional_override_config_file_path + assert expected_merged_config == SageMakerConfig().config + del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] + del 
os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] + + +def test_s3_config_file( + config_file_as_yaml, valid_config_with_all_the_scopes, base_config_with_schema, s3_resource_mock +): + config_file_bucket = "config-file-bucket" + config_file_s3_prefix = "config/config.yaml" + list_file_entry_mock = Mock() + list_file_entry_mock.key = config_file_s3_prefix + s3_resource_mock.Bucket(name=config_file_bucket).objects.filter( + Prefix=config_file_s3_prefix + ).all.return_value = [list_file_entry_mock] + response_body_mock = MagicMock() + response_body_mock.read.return_value = config_file_as_yaml.encode("utf-8") + s3_resource_mock.Object(config_file_bucket, config_file_s3_prefix).get.return_value = { + "Body": response_body_mock + } + config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix) + expected_config = base_config_with_schema + expected_config["SageMaker"] = valid_config_with_all_the_scopes + assert ( + expected_config + == SageMakerConfig( + additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock + ).config + ) + + +def test_config_factory_when_default_s3_config_file_is_not_found(s3_resource_mock): + config_file_bucket = "config-file-bucket" + config_file_s3_prefix = "config/config.yaml" + # Return empty list during list operation + s3_resource_mock.Bucket(name=config_file_bucket).objects.filter( + Prefix=config_file_s3_prefix + ).all.return_value = [] + config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix) + with pytest.raises(ValueError): + SageMakerConfig(additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock) + + +def test_s3_config_file_when_uri_provided_corresponds_to_a_path( + config_file_as_yaml, + valid_config_with_all_the_scopes, + base_config_with_schema, + s3_resource_mock, +): + config_file_bucket = "config-file-bucket" + config_file_s3_prefix = "config" + list_of_files = ["/config.yaml", "/something.txt", "/README.MD"] + list_s3_files_mock = [] + for file in 
list_of_files: + entry_mock = Mock() + entry_mock.key = config_file_s3_prefix + file + list_s3_files_mock.append(entry_mock) + s3_resource_mock.Bucket(name=config_file_bucket).objects.filter( + Prefix=config_file_s3_prefix + ).all.return_value = list_s3_files_mock + response_body_mock = MagicMock() + response_body_mock.read.return_value = config_file_as_yaml.encode("utf-8") + s3_resource_mock.Object(config_file_bucket, config_file_s3_prefix).get.return_value = { + "Body": response_body_mock + } + config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix) + expected_config = base_config_with_schema + expected_config["SageMaker"] = valid_config_with_all_the_scopes + assert ( + expected_config + == SageMakerConfig( + additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock + ).config + ) + + +def test_merge_of_s3_default_config_file_and_regular_config_file( + get_data_dir, expected_merged_config, s3_resource_mock +): + config_file_content_path = os.path.join(get_data_dir, "sample_config_for_merge.yaml") + config_file_as_yaml = open(config_file_content_path, "r").read() + config_file_bucket = "config-file-bucket" + config_file_s3_prefix = "config/config.yaml" + config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix) + list_file_entry_mock = Mock() + list_file_entry_mock.key = config_file_s3_prefix + s3_resource_mock.Bucket(name=config_file_bucket).objects.filter( + Prefix=config_file_s3_prefix + ).all.return_value = [list_file_entry_mock] + response_body_mock = MagicMock() + response_body_mock.read.return_value = config_file_as_yaml.encode("utf-8") + s3_resource_mock.Object(config_file_bucket, config_file_s3_prefix).get.return_value = { + "Body": response_body_mock + } + additional_override_config_file_path = os.path.join( + get_data_dir, "sample_additional_config_for_merge.yaml" + ) + os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = config_file_s3_uri + assert ( + expected_merged_config + == 
SageMakerConfig( + additional_config_paths=[additional_override_config_file_path], + s3_resource=s3_resource_mock, + ).config + ) + del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] diff --git a/tests/unit/sagemaker/config/test_config_schema.py b/tests/unit/sagemaker/config/test_config_schema.py new file mode 100644 index 0000000000..6ea0296bdf --- /dev/null +++ b/tests/unit/sagemaker/config/test_config_schema.py @@ -0,0 +1,139 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +from jsonschema import validate, exceptions +import pytest +from sagemaker.config.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA + + +def _validate_config(base_config_with_schema, sagemaker_config): + config = base_config_with_schema + config["SageMaker"] = sagemaker_config + validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + +def test_valid_schema_version(base_config_with_schema): + validate(base_config_with_schema, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + +def test_invalid_schema_version(): + config = {"SchemaVersion": "99.0"} + with pytest.raises(exceptions.ValidationError): + validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + # Also test missing schema version. 
+ config = {} + with pytest.raises(exceptions.ValidationError): + validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + +def test_valid_config_with_all_the_features( + base_config_with_schema, valid_config_with_all_the_scopes +): + _validate_config(base_config_with_schema, valid_config_with_all_the_scopes) + + +def test_feature_group_schema(base_config_with_schema, valid_feature_group_config): + _validate_config(base_config_with_schema, {"FeatureGroup": valid_feature_group_config}) + + +def test_valid_edge_packaging_job_schema(base_config_with_schema, valid_edge_packaging_config): + _validate_config(base_config_with_schema, {"EdgePackagingJob": valid_edge_packaging_config}) + + +def test_valid_training_job_schema(base_config_with_schema, valid_training_job_config): + _validate_config(base_config_with_schema, {"TrainingJob": valid_training_job_config}) + + +def test_valid_processing_job_schema(base_config_with_schema, valid_processing_job_config): + _validate_config(base_config_with_schema, {"ProcessingJob": valid_processing_job_config}) + + +def test_valid_model_package_schema(base_config_with_schema, valid_model_package_config): + _validate_config(base_config_with_schema, {"ModelPackage": valid_model_package_config}) + + +def test_valid_model_schema(base_config_with_schema, valid_model_config): + _validate_config(base_config_with_schema, {"Model": valid_model_config}) + + +def test_valid_pipeline_schema(base_config_with_schema, valid_pipeline_config): + _validate_config(base_config_with_schema, {"Pipeline": valid_pipeline_config}) + + +def test_valid_compilation_job_schema(base_config_with_schema, valid_compilation_job_config): + _validate_config(base_config_with_schema, {"CompilationJob": valid_compilation_job_config}) + + +def test_valid_transform_job_schema(base_config_with_schema, valid_transform_job_config): + _validate_config(base_config_with_schema, {"TransformJob": valid_transform_job_config}) + + +def test_valid_automl_schema(base_config_with_schema, 
valid_automl_config): + _validate_config(base_config_with_schema, {"AutoML": valid_automl_config}) + + +def test_valid_endpoint_config_schema(base_config_with_schema, valid_endpointconfig_config): + _validate_config(base_config_with_schema, {"EndpointConfig": valid_endpointconfig_config}) + + +def test_valid_monitoring_schedule_schema( + base_config_with_schema, valid_monitoring_schedule_config +): + _validate_config( + base_config_with_schema, {"MonitoringSchedule": valid_monitoring_schedule_config} + ) + + +def test_invalid_training_job_schema(base_config_with_schema, valid_iam_role_arn, valid_vpc_config): + # Changing key names + training_job_config = { + "EnableNetworkIsolation1": True, + "OutputDataConfig1": {"KmsKeyId": "somekmskey"}, + "ResourceConfig1": {"VolumeKmsKeyId": "somevolumekmskey"}, + "RoleArn1": valid_iam_role_arn, + "VpcConfig1": valid_vpc_config, + } + config = base_config_with_schema + config["SageMaker"] = {"TrainingJob": training_job_config} + with pytest.raises(exceptions.ValidationError): + validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + +def test_invalid_edge_packaging_job_schema(base_config_with_schema, valid_iam_role_arn): + # Using invalid keys + edge_packaging_job_config = { + "OutputConfig1": {"KmsKeyId": "somekeyid"}, + "RoleArn1": valid_iam_role_arn, + } + config = base_config_with_schema + config["SageMaker"] = {"EdgePackagingJob": edge_packaging_job_config} + with pytest.raises(exceptions.ValidationError): + validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + +def test_invalid_feature_group_schema(base_config_with_schema): + s3_storage_config = {"KmsKeyId": "somekmskeyid"} + security_storage_config = {"KmsKeyId": "someotherkmskeyid"} + # Online store doesn't have S3StorageConfig and similarly + # Offline store doesn't have SecurityConfig + online_store_config = {"S3StorageConfig": security_storage_config} + offline_store_config = {"SecurityConfig": s3_storage_config} + feature_group_config = { + "OnlineStoreConfig": 
online_store_config, + "OfflineStoreConfig": offline_store_config, + } + config = base_config_with_schema + config["SageMaker"] = {"FeatureGroup": feature_group_config} + with pytest.raises(exceptions.ValidationError): + validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) From 6c9bc6ade693f94220d85a0e4fae1a9f1c4ac3ef Mon Sep 17 00:00:00 2001 From: rubanh Date: Tue, 21 Feb 2023 14:58:48 -0500 Subject: [PATCH 02/40] intelligent defaults - tags and encryption (#842) * feature: sagemaker config - support tags for all APIs * feature: sagemaker config - support EnableInterContainerTrafficEncryption for relevant APIs --------- Co-authored-by: Ruban Hussain --- src/sagemaker/automl/automl.py | 14 +- src/sagemaker/automl/candidate_estimator.py | 18 +- src/sagemaker/config/config_schema.py | 85 +++- src/sagemaker/estimator.py | 12 +- src/sagemaker/local/local_session.py | 26 +- .../model_monitor/model_monitoring.py | 51 ++- src/sagemaker/processing.py | 9 +- src/sagemaker/session.py | 327 ++++++++++++++- src/sagemaker/utils.py | 62 +++ src/sagemaker/workflow/pipeline.py | 8 + tests/unit/sagemaker/automl/test_auto_ml.py | 9 + .../sagemaker/huggingface/test_estimator.py | 8 + .../sagemaker/huggingface/test_processing.py | 6 + .../monitor/test_clarify_model_monitor.py | 21 + .../monitor/test_model_monitoring.py | 20 + .../sagemaker/tensorflow/test_estimator.py | 17 + .../test_huggingface_pytorch_compiler.py | 8 + .../test_huggingface_tensorflow_compiler.py | 8 + .../test_pytorch_compiler.py | 8 + .../test_tensorflow_compiler.py | 8 + tests/unit/sagemaker/workflow/test_airflow.py | 12 + .../sagemaker/wrangler/test_processing.py | 6 + tests/unit/test_algorithm.py | 7 + tests/unit/test_chainer.py | 9 + tests/unit/test_estimator.py | 9 + tests/unit/test_mxnet.py | 9 + tests/unit/test_processing.py | 23 +- tests/unit/test_pytorch.py | 9 + tests/unit/test_rl.py | 9 + tests/unit/test_session.py | 381 ++++++++++++++++++ tests/unit/test_sklearn.py | 9 + tests/unit/test_tuner.py | 8 + 
tests/unit/test_utils.py | 132 +++++- tests/unit/test_xgboost.py | 8 + 34 files changed, 1328 insertions(+), 28 deletions(-) diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py index 7701eb7fa8..65e3f1b346 100644 --- a/src/sagemaker/automl/automl.py +++ b/src/sagemaker/automl/automl.py @@ -19,6 +19,9 @@ from sagemaker import Model, PipelineModel from sagemaker.automl.candidate_estimator import CandidateEstimator +from sagemaker.config.config_schema import ( + PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION, +) from sagemaker.job import _Job from sagemaker.session import Session from sagemaker.utils import name_from_base @@ -106,7 +109,7 @@ def __init__( compression_type: Optional[str] = None, sagemaker_session: Optional[Session] = None, volume_kms_key: Optional[str] = None, - encrypt_inter_container_traffic: Optional[bool] = False, + encrypt_inter_container_traffic: Optional[bool] = None, vpc_config: Optional[Dict[str, List]] = None, problem_type: Optional[str] = None, max_candidates: Optional[int] = None, @@ -182,7 +185,6 @@ def __init__( self.base_job_name = base_job_name self.compression_type = compression_type self.volume_kms_key = volume_kms_key - self.encrypt_inter_container_traffic = encrypt_inter_container_traffic self.vpc_config = vpc_config self.problem_type = problem_type self.max_candidate = max_candidates @@ -205,6 +207,12 @@ def __init__( self._best_candidate = None self.sagemaker_session = sagemaker_session or Session() + self.encrypt_inter_container_traffic = self.sagemaker_session.resolve_value_from_config( + direct_input=encrypt_inter_container_traffic, + config_path=PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION, + default_value=False, + ) + self._check_problem_type_and_job_objective(self.problem_type, self.job_objective) @runnable_by_pipeline @@ -276,6 +284,8 @@ def attach(cls, auto_ml_job_name, sagemaker_session=None): volume_kms_key=auto_ml_job_desc.get("AutoMLJobConfig", {}) .get("SecurityConfig", {}) .get("VolumeKmsKeyId"), + 
# Do not override encrypt_inter_container_traffic from config because this info + # is pulled from an existing automl job encrypt_inter_container_traffic=auto_ml_job_desc.get("AutoMLJobConfig", {}) .get("SecurityConfig", {}) .get("EnableInterContainerTrafficEncryption", False), diff --git a/src/sagemaker/automl/candidate_estimator.py b/src/sagemaker/automl/candidate_estimator.py index a104132fe9..4231c64e21 100644 --- a/src/sagemaker/automl/candidate_estimator.py +++ b/src/sagemaker/automl/candidate_estimator.py @@ -16,6 +16,7 @@ from six import string_types from sagemaker import Session +from sagemaker.config.config_schema import PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION from sagemaker.job import _Job from sagemaker.utils import name_from_base @@ -72,7 +73,8 @@ def fit( inputs, candidate_name=None, volume_kms_key=None, - encrypt_inter_container_traffic=False, + # default of False for training job, checked inside function + encrypt_inter_container_traffic=None, vpc_config=None, wait=True, logs=True, @@ -87,7 +89,8 @@ def fit( volume_kms_key (str): The KMS key id to encrypt data on the storage volume attached to the ML compute instance(s). encrypt_inter_container_traffic (bool): To encrypt all communications between ML compute - instances in distributed training. Default: False. + instances in distributed training. If not passed, will be fetched from + sagemaker_config. Default: False. vpc_config (dict): Specifies a VPC that jobs and hosted models have access to. Control access to and from training and model containers by configuring the VPC wait (bool): Whether the call should wait until all jobs completes (default: True). 
@@ -131,12 +134,21 @@ def fit( base_name = "sagemaker-automl-training-rerun" step_name = name_from_base(base_name) step["name"] = step_name + + # Check training_job config not auto_ml_job config because this function calls + # training job API + _encrypt_inter_container_traffic = self.sagemaker_session.resolve_value_from_config( + direct_input=encrypt_inter_container_traffic, + config_path=PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION, + default_value=False, + ) + train_args = self._get_train_args( desc, channels, step_name, volume_kms_key, - encrypt_inter_container_traffic, + _encrypt_inter_container_traffic, vpc_config, ) self.sagemaker_session.train(**train_args) diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index 5a80ccf10f..27d0bed94d 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -19,6 +19,9 @@ VOLUME_KMS_KEY_ID = "VolumeKmsKeyId" KMS_KEY_ID = "KmsKeyId" ROLE_ARN = "RoleArn" +TAGS = "Tags" +KEY = "Key" +VALUE = "Value" EXECUTION_ROLE_ARN = "ExecutionRoleArn" CLUSTER_ROLE_ARN = "ClusterRoleArn" VPC_CONFIG = "VpcConfig" @@ -73,6 +76,35 @@ TYPE = "type" OBJECT = "object" ADDITIONAL_PROPERTIES = "additionalProperties" +ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = "EnableInterContainerTrafficEncryption" + + +def _simple_path(*args: str): + """Appends an arbitrary number of strings to use as path constants""" + return ".".join(args) + + +# Paths for reference elsewhere in the SDK. 
+# Names include the schema version since the paths could change with other schema versions +PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION = _simple_path( + SAGEMAKER, + MONITORING_SCHEDULE, + MONITORING_SCHEDULE_CONFIG, + MONITORING_JOB_DEFINITION, + NETWORK_CONFIG, + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, +) +PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION = _simple_path( + SAGEMAKER, AUTO_ML, SECURITY_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION +) +PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION = _simple_path( + SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION +) +PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION = _simple_path( + SAGEMAKER, TRAINING_JOB, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION +) + + SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = { "$schema": "https://json-schema.org/draft/2020-12/schema", TYPE: OBJECT, @@ -164,6 +196,31 @@ } }, }, + # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_Tag.html + "tags": { + TYPE: "array", + "items": { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + KEY: { + TYPE: "string", + "pattern": r"^[\w\s\d_.:/=+\-@]*$", + "minLength": 1, + "maxLength": 128, + }, + VALUE: { + TYPE: "string", + "pattern": r"^[\w\s\d_.:/=+\-@]*$", + "minLength": 0, + "maxLength": 256, + }, + }, + }, + "minItems": 0, + "maxItems": 50, + }, + SUBNETS: {TYPE: "array", "items": {"$ref": "#/definitions/subnet"}}, }, PROPERTIES: { SCHEMA_VERSION: { @@ -219,6 +276,7 @@ }, }, ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + TAGS: {"$ref": "#/definitions/tags"}, }, }, # Monitoring Schedule @@ -257,6 +315,9 @@ TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, PROPERTIES: { + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: { + TYPE: "boolean" + }, ENABLE_NETWORK_ISOLATION: {TYPE: "boolean"}, VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, }, @@ -265,7 +326,8 @@ }, } }, - } + }, + TAGS: {"$ref": "#/definitions/tags"}, }, }, # Endpoint Config @@ -302,6 +364,7 @@ TYPE: "array", "items": 
{"$ref": "#/definitions/productionVariant"}, }, + TAGS: {"$ref": "#/definitions/tags"}, }, }, # Auto ML @@ -318,6 +381,9 @@ TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, PROPERTIES: { + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: { + TYPE: "boolean" + }, VOLUME_KMS_KEY_ID: { TYPE: "string", }, @@ -332,6 +398,7 @@ PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, }, ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + TAGS: {"$ref": "#/definitions/tags"}, }, }, # Transform Job @@ -355,6 +422,7 @@ ADDITIONAL_PROPERTIES: False, PROPERTIES: {VOLUME_KMS_KEY_ID: {TYPE: "string"}}, }, + TAGS: {"$ref": "#/definitions/tags"}, }, }, # Compilation Job @@ -371,6 +439,7 @@ }, ROLE_ARN: {"$ref": "#/definitions/roleArn"}, VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, + TAGS: {"$ref": "#/definitions/tags"}, }, }, # Pipeline @@ -378,7 +447,10 @@ PIPELINE: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {ROLE_ARN: {"$ref": "#/definitions/roleArn"}}, + PROPERTIES: { + ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + TAGS: {"$ref": "#/definitions/tags"}, + }, }, # Model # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html @@ -389,6 +461,7 @@ ENABLE_NETWORK_ISOLATION: {TYPE: "boolean"}, EXECUTION_ROLE_ARN: {"$ref": "#/definitions/roleArn"}, VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, + TAGS: {"$ref": "#/definitions/tags"}, }, }, # Model Package @@ -407,7 +480,8 @@ }, VALIDATION_ROLE: {"$ref": "#/definitions/roleArn"}, }, - } + }, + TAGS: {"$ref": "#/definitions/tags"}, }, }, # Processing Job @@ -420,6 +494,7 @@ TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, PROPERTIES: { + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: {TYPE: "boolean"}, ENABLE_NETWORK_ISOLATION: {TYPE: "boolean"}, VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, }, @@ -445,6 +520,7 @@ }, }, ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + TAGS: {"$ref": "#/definitions/tags"}, }, }, # Training Job @@ -453,6 +529,7 @@ TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, PROPERTIES: { + 
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: {TYPE: "boolean"}, ENABLE_NETWORK_ISOLATION: {TYPE: "boolean"}, OUTPUT_DATA_CONFIG: { TYPE: OBJECT, @@ -466,6 +543,7 @@ }, ROLE_ARN: {"$ref": "#/definitions/roleArn"}, VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, + TAGS: {"$ref": "#/definitions/tags"}, }, }, # Edge Packaging Job @@ -480,6 +558,7 @@ PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, }, ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + TAGS: {"$ref": "#/definitions/tags"}, }, }, }, diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index e012be566a..511f96c027 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -29,6 +29,7 @@ import sagemaker from sagemaker import git_utils, image_uris, vpc_utils from sagemaker.analytics import TrainingJobAnalytics +from sagemaker.config.config_schema import PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION from sagemaker.debugger import ( # noqa: F401 # pylint: disable=unused-import DEBUGGER_FLAG, DebuggerHookConfig, @@ -133,7 +134,7 @@ def __init__( model_uri: Optional[str] = None, model_channel_name: Union[str, PipelineVariable] = "model", metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, - encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False, + encrypt_inter_container_traffic: Union[bool, PipelineVariable] = None, use_spot_instances: Union[bool, PipelineVariable] = False, max_wait: Optional[Union[int, PipelineVariable]] = None, checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, @@ -598,7 +599,12 @@ def __init__( training_repository_credentials_provider_arn ) - self.encrypt_inter_container_traffic = encrypt_inter_container_traffic + self.encrypt_inter_container_traffic = self.sagemaker_session.resolve_value_from_config( + direct_input=encrypt_inter_container_traffic, + config_path=PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION, + default_value=False, + ) + self.use_spot_instances = use_spot_instances self.max_wait = 
max_wait self.checkpoint_s3_uri = checkpoint_s3_uri @@ -2168,6 +2174,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config): # encrypt_inter_container_traffic may be a pipeline variable place holder object # which is parsed in execution time + # This does not check config because the EstimatorBase constructor already did that check if estimator.encrypt_inter_container_traffic: train_args[ "encrypt_inter_container_traffic" ] = @@ -2745,6 +2752,7 @@ def __init__( model_uri=model_uri, model_channel_name=model_channel_name, metric_definitions=metric_definitions, + # Does not check sagemaker config because EstimatorBase will do that check encrypt_inter_container_traffic=encrypt_inter_container_traffic, use_spot_instances=use_spot_instances, max_wait=max_wait, diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index e791e03b4e..5247fe1a45 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -21,6 +21,7 @@ import boto3 from botocore.exceptions import ClientError +from sagemaker.config.config import SageMakerConfig from sagemaker.local.image import _SageMakerContainer from sagemaker.local.utils import get_docker_host from sagemaker.local.entities import ( @@ -599,7 +600,12 @@ class LocalSession(Session): """ def __init__( - self, boto_session=None, default_bucket=None, s3_endpoint_url=None, disable_local_code=False + self, + boto_session=None, + default_bucket=None, + s3_endpoint_url=None, + disable_local_code=False, + sagemaker_config: SageMakerConfig = None, ): """Create a Local SageMaker Session. 
@@ -619,13 +625,22 @@ def __init__( # discourage external use: self._disable_local_code = disable_local_code - super(LocalSession, self).__init__(boto_session=boto_session, default_bucket=default_bucket) + super(LocalSession, self).__init__( + boto_session=boto_session, + default_bucket=default_bucket, + sagemaker_config=sagemaker_config, + ) if platform.system() == "Windows": logger.warning("Windows Support for Local Mode is Experimental") def _initialize( - self, boto_session, sagemaker_client, sagemaker_runtime_client, **kwargs + self, + boto_session, + sagemaker_client, + sagemaker_runtime_client, + sagemaker_config: SageMakerConfig = None, + **kwargs ): # pylint: disable=unused-argument """Initialize this Local SageMaker Session. @@ -671,6 +686,11 @@ def _initialize( if self._disable_local_code and "local" in self.config: self.config["local"]["local_code"] = False + if sagemaker_config: + self.sagemaker_config = sagemaker_config + else: + self.sagemaker_config = SageMakerConfig(s3_resource=self.boto_session) + def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"): """A no-op method meant to override the sagemaker client. 
diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index e831673471..4f7b73ab04 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -31,6 +31,12 @@ from botocore.exceptions import ClientError from sagemaker import image_uris, s3 +from sagemaker.config.config_schema import ( + SAGEMAKER, + MONITORING_SCHEDULE, + TAGS, + PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION, +) from sagemaker.exceptions import UnexpectedStatusException from sagemaker.model_monitor.monitoring_files import Constraints, ConstraintViolations, Statistics from sagemaker.model_monitor.monitoring_alert import ( @@ -159,7 +165,6 @@ def __init__( self.sagemaker_session = sagemaker_session or Session() self.env = env self.tags = tags - self.network_config = network_config self.baselining_jobs = [] self.latest_baselining_job = None @@ -168,6 +173,13 @@ def __init__( self.monitoring_schedule_name = None self.job_definition_name = None + self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + NetworkConfig, + network_config, + "encrypt_inter_container_traffic", + PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION, + ) + def run_baseline( self, baseline_inputs, output, arguments=None, wait=True, logs=True, job_name=None ): @@ -482,6 +494,8 @@ def update_monitoring_schedule( network_config_dict = None if self.network_config is not None: network_config_dict = self.network_config._to_request_dict() + # Do not need to check config because that check is done inside + # self.sagemaker_session.update_monitoring_schedule self.sagemaker_session.update_monitoring_schedule( monitoring_schedule_name=self.monitoring_schedule_name, @@ -1381,10 +1395,27 @@ def _create_monitoring_schedule_from_job_definition( monitoring_schedule_config["ScheduleConfig"] = { "ScheduleExpression": schedule_cron_expression } + all_tags = 
self.sagemaker_session._append_sagemaker_config_tags( + self.tags, "{}.{}.{}".format(SAGEMAKER, MONITORING_SCHEDULE, TAGS) + ) + + _enable_inter_container_traffic_encryption_from_config = ( + self.sagemaker_session.resolve_value_from_config( + config_path=PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION + ) + ) + if _enable_inter_container_traffic_encryption_from_config is not None: + # Not checking 'self.network_config' for 'enable_network_isolation' because that + # wasn't used here before this config value was set. Unclear whether there was a + # specific reason for that omission. + monitoring_schedule_config["MonitoringJobDefinition"]["NetworkConfig"][ + "EnableInterContainerTrafficEncryption" + ] = _enable_inter_container_traffic_encryption_from_config + self.sagemaker_session.sagemaker_client.create_monitoring_schedule( MonitoringScheduleName=monitor_schedule_name, MonitoringScheduleConfig=monitoring_schedule_config, - Tags=self.tags or [], + Tags=all_tags or [], ) def _upload_and_convert_to_processing_input(self, source, destination, name): @@ -1446,6 +1477,20 @@ def _update_monitoring_schedule(self, job_definition_name, schedule_cron_express monitoring_schedule_config["ScheduleConfig"] = { "ScheduleExpression": schedule_cron_expression } + + _enable_inter_container_traffic_encryption_from_config = ( + self.sagemaker_session.resolve_value_from_config( + config_path=PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION + ) + ) + if _enable_inter_container_traffic_encryption_from_config is not None: + # Not checking 'self.network_config' for 'enable_network_isolation' because that + # wasn't used here before this config value was checked. Unclear whether there was a + # specific reason for that omission. 
+ monitoring_schedule_config["MonitoringJobDefinition"]["NetworkConfig"][ + "EnableInterContainerTrafficEncryption" + ] = _enable_inter_container_traffic_encryption_from_config + self.sagemaker_session.sagemaker_client.update_monitoring_schedule( MonitoringScheduleName=self.monitoring_schedule_name, MonitoringScheduleConfig=monitoring_schedule_config, @@ -1958,6 +2003,8 @@ def update_monitoring_schedule( network_config_dict = None if self.network_config is not None: network_config_dict = self.network_config._to_request_dict() + # Do not need to check config because that check is done inside + # self.sagemaker_session.update_monitoring_schedule if role is not None: self.role = role diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 65fa3c7dbc..7bc21bec38 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -30,6 +30,7 @@ from six.moves.urllib.parse import urlparse from six.moves.urllib.request import url2pathname from sagemaker import s3 +from sagemaker.config.config_schema import PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION from sagemaker.job import _Job from sagemaker.local import LocalSession from sagemaker.network import NetworkConfig @@ -130,7 +131,6 @@ def __init__( self.base_job_name = base_job_name self.env = env self.tags = tags - self.network_config = network_config self.jobs = [] self.latest_job = None @@ -144,6 +144,13 @@ def __init__( self.sagemaker_session = sagemaker_session or Session() + self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + NetworkConfig, + network_config, + "encrypt_inter_container_traffic", + PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION, + ) + @runnable_by_pipeline def run( self, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 5e3c788739..bc4b8d187b 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -13,6 +13,7 @@ """Placeholder docstring""" from __future__ import absolute_import, print_function 
+import inspect import json import logging import os @@ -32,7 +33,27 @@ import sagemaker.logs from sagemaker import vpc_utils - +from sagemaker.config.config import SageMakerConfig +from sagemaker.config.config_schema import ( + AUTO_ML, + COMPILATION_JOB, + EDGE_PACKAGING_JOB, + ENDPOINT_CONFIG, + FEATURE_GROUP, + KEY, + HYPER_PARAMETER_TUNING_JOB, + SAGEMAKER, + MODEL, + MONITORING_SCHEDULE, + PROCESSING_JOB, + TAGS, + TRAINING_JOB, + TRANSFORM_JOB, + PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION, + PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION, + PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION, + PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION, +) from sagemaker._studio import _append_project_tags from sagemaker.deprecations import deprecated_class from sagemaker.inputs import ShuffleConfig, TrainingInput, BatchDataCaptureConfig @@ -43,6 +64,9 @@ secondary_training_status_message, sts_regional_endpoint, retries, + get_config_value, + get_nested_value, + set_nested_value, ) from sagemaker import exceptions from sagemaker.session_settings import SessionSettings @@ -50,7 +74,6 @@ LOGGER = logging.getLogger("sagemaker") NOTEBOOK_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json" - _STATUS_CODE_TABLE = { "COMPLETED": "Completed", "INPROGRESS": "InProgress", @@ -93,6 +116,7 @@ def __init__( default_bucket=None, settings=SessionSettings(), sagemaker_metrics_client=None, + sagemaker_config: SageMakerConfig = None, ): """Initialize a SageMaker ``Session``. @@ -139,6 +163,7 @@ def __init__( sagemaker_runtime_client=sagemaker_runtime_client, sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client, sagemaker_metrics_client=sagemaker_metrics_client, + sagemaker_config=sagemaker_config, ) def _initialize( @@ -148,6 +173,7 @@ def _initialize( sagemaker_runtime_client, sagemaker_featurestore_runtime_client, sagemaker_metrics_client, + sagemaker_config: SageMakerConfig = None, ): """Initialize this SageMaker Session. 
@@ -190,6 +216,11 @@ def _initialize( self.local_mode = False + if sagemaker_config: + self.sagemaker_config = sagemaker_config + else: + self.sagemaker_config = SageMakerConfig(s3_resource=self.boto_session) + @property def boto_region_name(self): """Placeholder docstring""" @@ -473,6 +504,216 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): else: raise + def _get_sagemaker_config_value(self, config_path: str): + """returns the value of the config at the path provided""" + return get_config_value(config_path, self.sagemaker_config.config) + + def _print_message_sagemaker_config_used(self, config_value, config_path): + """Informs the SDK user that a config value was substituted in automatically""" + print( + "[Sagemaker Config] config value {} at config path {}".format( + config_value, config_path + ), + "was automatically applied" + ) + + def _print_message_sagemaker_config_present_but_not_used( + self, direct_input, config_value, config_path + ): + """Informs the SDK user that a config value was not substituted in automatically despite + existing""" + print( + "[Sagemaker Config] value {} was specified,".format(direct_input), + "so config value {} at config path {} was not applied".format( + config_value, config_path + ) + ) + + def resolve_value_from_config( + self, direct_input=None, config_path: str = None, default_value=None + ): + """Makes a decision of which value is the right value for the caller to use while + incorporating info from the sagemaker config. + + Uses this order of prioritization: + (1) direct_input, (2) config value, (3) default_value, (4) None + + Args: + direct_input: the value that the caller of this method started with. 
Usually this is an + input to the caller's class or method + config_path (str): a string denoting the path to use to lookup the config value in the + sagemaker config + default_value: the value to use if not present elsewhere + + Returns: + The value that should be used by the caller + """ + config_value = self._get_sagemaker_config_value(config_path) + + if direct_input is not None: + if config_value is not None: + self._print_message_sagemaker_config_present_but_not_used( + direct_input, config_value, config_path + ) + # No print statement if there was nothing in the config, because nothing is + # being overridden + return direct_input + + if config_value is not None: + self._print_message_sagemaker_config_used(config_value, config_path) + return config_value + + return default_value + + def resolve_class_attribute_from_config( + self, clazz, instance, attribute: str, config_path: str, default_value=None + ): + """Takes an instance of a class and, if not already set, sets the instance's attribute to a + value fetched from the sagemaker_config or the default_value. + + Uses this order of prioritization to determine what the value of the attribute should be: + (1) current value of attribute, (2) config value, (3) default_value, (4) does not set it + + Args: + clazz: Class of 'instance'. Used to generate a new instance if the instance is None. 
+ It is advised for the constructor of a given Class to set default values to + None; otherwise the constructor's non-None default value will be used instead + of any config value + instance (str): instance of the Class 'clazz' that has an attribute + of 'attribute' to set + attribute: attribute of the instance to set if not already set + config_path (str): a string denoting the path to use to lookup the config value in the + sagemaker config + default_value: the value to use if not present elsewhere + + Returns: + The updated class instance that should be used by the caller instead of the + 'instance' parameter that was passed in. + """ + config_value = self._get_sagemaker_config_value(config_path) + + if config_value is None and default_value is None: + # return instance unmodified. Could be None or populated + return instance + + if instance is None: + if clazz is None or not inspect.isclass(clazz): + return instance + # construct a new instance if the instance does not exist + instance = clazz() + + if not hasattr(instance, attribute): + raise TypeError( + "Unexpected structure of object.", + "Expected attribute {} to be present inside instance {} of class {}".format( + attribute, instance, clazz + ), + ) + + current_value = getattr(instance, attribute) + if current_value is None: + # only set value if object does not already have a value set + if config_value is not None: + setattr(instance, attribute, config_value) + self._print_message_sagemaker_config_used(config_value, config_path) + elif default_value is not None: + setattr(instance, attribute, default_value) + elif current_value is not None and config_value is not None: + self._print_message_sagemaker_config_present_but_not_used( + current_value, config_value, config_path + ) + + return instance + + def resolve_nested_dict_value_from_config( + self, + dictionary: dict, + nested_keys: list[str], + config_path: str, + default_value: object = None, + ): + """Takes a dictionary and, if not already set, sets 
the value for the provided list of + nested keys to the value fetched from the sagemaker_config or the default_value. + + Uses this order of prioritization to determine what the value of the attribute should be: + (1) current value of nested key, (2) config value, (3) default_value, (4) does not set it + + Args: + dictionary: dict to update + nested_keys: path of keys at which the value should be checked (and set if needed) + config_path (str): a string denoting the path to use to lookup the config value in the + sagemaker config + default_value: the value to use if not present elsewhere + + Returns: + The updated dictionary that should be used by the caller instead of the + 'dictionary' parameter that was passed in. + """ + config_value = self._get_sagemaker_config_value(config_path) + + if config_value is None and default_value is None: + # if there is nothing to set, return early. And there is no need to traverse through + # the dictionary or add nested dicts to it + return dictionary + + try: + current_nested_value = get_nested_value(dictionary, nested_keys) + except ValueError as e: + logging.error("Failed to check dictionary for applying sagemaker config: %s", e) + return dictionary + + if current_nested_value is None: + # only set value if not already set + if config_value is not None: + dictionary = set_nested_value(dictionary, nested_keys, config_value) + self._print_message_sagemaker_config_used(config_value, config_path) + elif default_value is not None: + dictionary = set_nested_value(dictionary, nested_keys, default_value) + elif current_nested_value is not None and config_value is not None: + self._print_message_sagemaker_config_present_but_not_used( + current_nested_value, config_value, config_path + ) + + return dictionary + + def _append_sagemaker_config_tags(self, tags: list, config_path_to_tags: str): + """Appends tags specified in the sagemaker_config to the given list of tags. 
+ + To minimize the chance of duplicate tags being applied, this is intended + to be used right before calls to sagemaker_client (rather + than during initialization of classes like EstimatorBase) + + Args: + tags: the list of tags to append to. + config_path_to_tags: the path to look up in the config + + Returns: + A potentially extended list of tags. + """ + config_tags = self._get_sagemaker_config_value(config_path_to_tags) + + if config_tags is None or len(config_tags) == 0: + return tags + + all_tags = tags or [] + for config_tag in config_tags: + config_tag_key = config_tag[KEY] + if not any(tag.get("Key", None) == config_tag_key for tag in all_tags): + # This check prevents new tags with duplicate keys from being added + # (to prevent API failure and/or overwriting of tags). If there is a conflict, + # the user-provided tag should take precedence over the config-provided tag. + # Note: this does not check user-provided tags for conflicts with other + # user-provided tags. + all_tags.append(config_tag) + + print( + "Appended tags from sagemaker_config to input.\n\texisting tags: {},".format(tags) + + "\n\ttags provided via sagemaker_config: {},".format(config_tags) + + "\n\tcombined tags: {}".format(all_tags) + ) + + return all_tags + def train( # noqa: C901 self, input_mode, @@ -490,7 +731,7 @@ def train( # noqa: C901 image_uri=None, training_image_config=None, algorithm_arn=None, - encrypt_inter_container_traffic=False, + encrypt_inter_container_traffic=None, use_spot_instances=False, checkpoint_s3_uri=None, checkpoint_local_path=None, @@ -615,6 +856,16 @@ def train( # noqa: C901 str: ARN of the training job, if it is created. 
""" tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, TRAINING_JOB, TAGS) + ) + + _encrypt_inter_container_traffic = self.resolve_value_from_config( + direct_input=encrypt_inter_container_traffic, + config_path=PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION, + default_value=False, + ) + train_request = self._get_train_request( input_mode=input_mode, input_config=input_config, @@ -631,7 +882,7 @@ def train( # noqa: C901 image_uri=image_uri, training_image_config=training_image_config, algorithm_arn=algorithm_arn, - encrypt_inter_container_traffic=encrypt_inter_container_traffic, + encrypt_inter_container_traffic=_encrypt_inter_container_traffic, use_spot_instances=use_spot_instances, checkpoint_s3_uri=checkpoint_s3_uri, checkpoint_local_path=checkpoint_local_path, @@ -1002,6 +1253,16 @@ def process( * `TrialComponentDisplayName` is used for display in Studio. """ tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, PROCESSING_JOB, TAGS) + ) + + network_config = self.resolve_nested_dict_value_from_config( + network_config, + ["EnableInterContainerTrafficEncryption"], + PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION, + ) + process_request = self._get_process_request( inputs=inputs, output_config=output_config, @@ -1248,12 +1509,20 @@ def create_monitoring_schedule( "Environment" ] = environment + network_config = self.resolve_nested_dict_value_from_config( + network_config, + ["EnableInterContainerTrafficEncryption"], + PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION, + ) if network_config is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "NetworkConfig" ] = network_config tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, MONITORING_SCHEDULE, TAGS) + ) if tags is not None: monitoring_schedule_request["Tags"] = tags @@ 
-1528,10 +1797,17 @@ def update_monitoring_schedule( existing_network_config = existing_desc["MonitoringScheduleConfig"][ "MonitoringJobDefinition" ].get("NetworkConfig") - if network_config is not None or existing_network_config is not None: + + _network_config = network_config or existing_network_config + _network_config = self.resolve_nested_dict_value_from_config( + _network_config, + ["EnableInterContainerTrafficEncryption"], + PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION, + ) + if _network_config is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "NetworkConfig" - ] = (network_config or existing_network_config) + ] = _network_config LOGGER.info("Updating monitoring schedule with name: %s .", monitoring_schedule_name) LOGGER.debug( @@ -1848,6 +2124,13 @@ def auto_ml( for an automatic one-click Autopilot model deployment. Contains "AutoGenerateEndpointName" and "EndpointName" """ + + auto_ml_job_config = self.resolve_nested_dict_value_from_config( + auto_ml_job_config, + ["SecurityConfig", "EnableInterContainerTrafficEncryption"], + PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION, + ) + auto_ml_job_request = self._get_auto_ml_request( input_config=input_config, output_config=output_config, @@ -1927,6 +2210,7 @@ def _get_auto_ml_request( auto_ml_job_request["ProblemType"] = problem_type tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags(tags, "{}.{}.{}".format(SAGEMAKER, AUTO_ML, TAGS)) if tags is not None: auto_ml_job_request["Tags"] = tags @@ -2122,6 +2406,9 @@ def compile_model( } tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, COMPILATION_JOB, TAGS) + ) if tags is not None: compilation_job_request["Tags"] = tags @@ -2162,6 +2449,9 @@ def package_model_for_edge( "CompilationJobName": compilation_job_name, } + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, EDGE_PACKAGING_JOB, TAGS) + 
) if tags is not None: edge_packaging_job_request["Tags"] = tags if resource_key is not None: @@ -2352,6 +2642,9 @@ def tune( # noqa: C901 tune_request["WarmStartConfig"] = warm_start_config tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, HYPER_PARAMETER_TUNING_JOB, TAGS) + ) if tags is not None: tune_request["Tags"] = tags @@ -2454,6 +2747,9 @@ def _get_tuning_request( tune_request["WarmStartConfig"] = warm_start_config tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, HYPER_PARAMETER_TUNING_JOB, TAGS) + ) if tags is not None: tune_request["Tags"] = tags @@ -2903,6 +3199,10 @@ def transform( specifies the configurations related to the batch data capture for the transform job """ tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, TRANSFORM_JOB, TAGS) + ) + transform_request = self._get_transform_request( job_name=job_name, model_name=model_name, @@ -3028,6 +3328,8 @@ def create_model( str: Name of the Amazon SageMaker ``Model`` created. 
""" tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags(tags, "{}.{}.{}".format(SAGEMAKER, MODEL, TAGS)) + create_model_request = self._create_model_request( name=name, role=role, @@ -3359,6 +3661,9 @@ def create_endpoint_config( } tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS) + ) if tags is not None: request["Tags"] = tags @@ -3428,6 +3733,9 @@ def create_endpoint_config_from_existing( existing_endpoint_config_desc["EndpointConfigArn"] ) request_tags = _append_project_tags(request_tags) + request_tags = self._append_sagemaker_config_tags( + request_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS) + ) if request_tags: request["Tags"] = request_tags @@ -3984,6 +4292,9 @@ def endpoint_from_production_variants( """ config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants} tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS) + ) if tags: config_options["Tags"] = tags if kms_key: @@ -4426,6 +4737,10 @@ def create_feature_group( Response dict from service. """ tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, FEATURE_GROUP, TAGS) + ) + kwargs = dict( FeatureGroupName=feature_group_name, RecordIdentifierFeatureName=record_identifier_name, diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 7da9ced131..9977de0add 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -176,6 +176,68 @@ def get_config_value(key_path, config): return current_section +def get_nested_value(dictionary: dict, nested_keys: list[str]): + """Returns a nested value from the given dictionary, and None if none present. 
+ + Raises + ValueError if the dictionary structure does not match the nested_keys + """ + if ( + dictionary is not None + and isinstance(dictionary, dict) + and nested_keys is not None + and len(nested_keys) > 0 + ): + + current_section = dictionary + + for key in nested_keys[:-1]: + current_section = current_section.get(key, None) + if current_section is None: + # means the full path of nested_keys doesn't exist in the dictionary + # or the value was set to None + return None + if not isinstance(current_section, dict): + raise ValueError( + "Unexpected structure of dictionary.", + "Expected value of type dict at key '{}' but got '{}' for dict '{}'".format( + key, current_section, dictionary + ), + ) + return current_section.get(nested_keys[-1], None) + + return None + + +def set_nested_value(dictionary: dict, nested_keys: list[str], value_to_set: object): + """Sets a nested value inside the given dictionary and returns the new dictionary. Note: if + provided an unintended list of nested keys, this can overwrite an unexpected part of the dict. 
+ Recommended to use after a check with get_nested_value first + """ + + if dictionary is None: + dictionary = {} + + if ( + dictionary is not None + and isinstance(dictionary, dict) + and nested_keys is not None + and len(nested_keys) > 0 + ): + current_section = dictionary + for key in nested_keys[:-1]: + if ( + key not in current_section + or current_section[key] is None + or not isinstance(current_section[key], dict) + ): + current_section[key] = {} + current_section = current_section[key] + + current_section[nested_keys[-1]] = value_to_set + return dictionary + + def get_short_version(framework_version): """Return short version in the format of x.x diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 95d0702ec8..82d89e8cc0 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -25,6 +25,11 @@ from sagemaker import s3 from sagemaker._studio import _append_project_tags +from sagemaker.config.config_schema import ( + SAGEMAKER, + PIPELINE, + TAGS, +) from sagemaker.session import Session from sagemaker.utils import retry_with_backoff from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep @@ -132,6 +137,9 @@ def create( logger.warning("Pipeline parallelism config is not supported in the local mode.") return self.sagemaker_session.sagemaker_client.create_pipeline(self, description) tags = _append_project_tags(tags) + tags = self.sagemaker_session._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, PIPELINE, TAGS) + ) kwargs = self._create_args(role_arn, description, parallelism_config) update_args( kwargs, diff --git a/tests/unit/sagemaker/automl/test_auto_ml.py b/tests/unit/sagemaker/automl/test_auto_ml.py index e68a019ce4..0e6901a9db 100644 --- a/tests/unit/sagemaker/automl/test_auto_ml.py +++ b/tests/unit/sagemaker/automl/test_auto_ml.py @@ -271,6 +271,15 @@ def sagemaker_session(): ) sms.list_candidates = Mock(name="list_candidates", 
return_value={"Candidates": []}) sms.sagemaker_client.list_tags = Mock(name="list_tags", return_value=LIST_TAGS_RESULT) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + sms.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) + return sms diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index 2d7261cdc6..03316250f4 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -75,6 +75,14 @@ def fixture_sagemaker_session(): session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) return session diff --git a/tests/unit/sagemaker/huggingface/test_processing.py b/tests/unit/sagemaker/huggingface/test_processing.py index e7887cd794..869139b573 100644 --- a/tests/unit/sagemaker/huggingface/test_processing.py +++ b/tests/unit/sagemaker/huggingface/test_processing.py @@ -50,6 +50,12 @@ def sagemaker_session(): session_mock.upload_data = Mock(name="upload_data", return_value=MOCKED_S3_URI) session_mock.download_data = Mock(name="download_data") session_mock.expand_role.return_value = ROLE + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session_mock.resolve_class_attribute_from_config = Mock( + name="resolve_class_attribute_from_config", 
+ side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, + ) return session_mock diff --git a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py index 1ca310a30a..7409d14430 100644 --- a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py +++ b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py @@ -413,6 +413,27 @@ def sagemaker_session(sagemaker_client): ) session_mock.download_data = Mock(name="download_data") session_mock.expand_role.return_value = ROLE_ARN + + session_mock._append_sagemaker_config_tags = Mock( + name="_append_sagemaker_config_tags", side_effect=lambda tags, config_path_to_tags: tags + ) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session_mock.resolve_nested_dict_value_from_config = Mock( + name="resolve_nested_dict_value_from_config", + side_effect=lambda dictionary, nested_keys, config_path, default_value=None: dictionary, + ) + session_mock.resolve_class_attribute_from_config = Mock( + name="resolve_class_attribute_from_config", + side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, + ) + session_mock.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) + return session_mock diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index 9c3f563d57..b627996b8e 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -461,6 +461,26 @@ def sagemaker_session(): }, ) session_mock.expand_role.return_value = ROLE + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session_mock.resolve_nested_dict_value_from_config = 
Mock( + name="resolve_nested_dict_value_from_config", + side_effect=lambda dictionary, nested_keys, config_path, default_value=None: dictionary, + ) + session_mock.resolve_class_attribute_from_config = Mock( + name="resolve_class_attribute_from_config", + side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, + ) + session_mock.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) + session_mock._append_sagemaker_config_tags = Mock( + name="_append_sagemaker_config_tags", side_effect=lambda tags, config_path_to_tags: tags + ) + return session_mock diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index aaadbc98d5..634f2d102f 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -82,6 +82,23 @@ def sagemaker_session(): session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session.resolve_nested_dict_value_from_config = Mock( + name="resolve_nested_dict_value_from_config", + side_effect=lambda dictionary, nested_keys, config_path, default_value=None: dictionary, + ) + session.resolve_class_attribute_from_config = Mock( + name="resolve_class_attribute_from_config", + side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, + ) + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else 
default_value, + ) + return session diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index 357ebdf1db..40e3241333 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -81,6 +81,14 @@ def fixture_sagemaker_session(): session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) return session diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index 4f048aa536..3e3aeb7e86 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -79,6 +79,14 @@ def fixture_sagemaker_session(): session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) return session diff --git 
a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py index 8bbed0bbec..643dc6337c 100644 --- a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -80,6 +80,14 @@ def fixture_sagemaker_session(): session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) return session diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index df4fdcc5a5..5f67a4c46b 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -85,6 +85,14 @@ def fixture_sagemaker_session(): session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) return session diff --git a/tests/unit/sagemaker/workflow/test_airflow.py b/tests/unit/sagemaker/workflow/test_airflow.py index fa4b4d2e55..e53efbb30b 100644 --- 
a/tests/unit/sagemaker/workflow/test_airflow.py +++ b/tests/unit/sagemaker/workflow/test_airflow.py @@ -41,6 +41,18 @@ def sagemaker_session(): ) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session._default_bucket = BUCKET_NAME + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session.resolve_class_attribute_from_config = Mock( + name="resolve_class_attribute_from_config", + side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, + ) + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) return session diff --git a/tests/unit/sagemaker/wrangler/test_processing.py b/tests/unit/sagemaker/wrangler/test_processing.py index f83ae74dfa..cce220ff21 100644 --- a/tests/unit/sagemaker/wrangler/test_processing.py +++ b/tests/unit/sagemaker/wrangler/test_processing.py @@ -40,6 +40,12 @@ def sagemaker_session(): settings=SessionSettings(), ) session_mock.expand_role.return_value = ROLE + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session_mock.resolve_class_attribute_from_config = Mock( + name="resolve_class_attribute_from_config", + side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, + ) return session_mock diff --git a/tests/unit/test_algorithm.py b/tests/unit/test_algorithm.py index 73970ebd25..d71268a7c7 100644 --- a/tests/unit/test_algorithm.py +++ b/tests/unit/test_algorithm.py @@ -909,6 +909,13 @@ def test_algorithm_enable_network_isolation_with_product_id(session): @patch("sagemaker.Session") def test_algorithm_encrypt_inter_container_traffic(session): + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: 
direct_input + if direct_input is not None + else default_value, + ) + response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) response["encrypt_inter_container_traffic"] = True session.sagemaker_client.describe_algorithm = Mock(return_value=response) diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index dbcedc1d99..00340d1c8b 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -73,6 +73,15 @@ def sagemaker_session(): session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) + return session diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 341f5b48ae..f5912b17de 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -235,6 +235,15 @@ def sagemaker_session(): sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) sms.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) sms.upload_data = Mock(return_value=OUTPUT_PATH) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + sms.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) + return sms diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 2395856acd..19dfbbfd28 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -100,6 +100,15 @@ def sagemaker_session(): 
session.wait_for_compilation_job = Mock(return_value=describe_compilation) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) + return session diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 34c530747d..0f156d8a0f 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -16,6 +16,7 @@ from mock import Mock, patch, MagicMock from packaging import version +from sagemaker import LocalSession from sagemaker.dataset_definition.inputs import ( S3Input, DatasetDefinition, @@ -79,6 +80,13 @@ def sagemaker_session(): session_mock.describe_processing_job = MagicMock( name="describe_processing_job", return_value=_get_describe_response_inputs_and_ouputs() ) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session_mock.resolve_class_attribute_from_config = Mock( + name="resolve_class_attribute_from_config", + side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, + ) + return session_mock @@ -102,6 +110,13 @@ def pipeline_session(): name="describe_processing_job", return_value=_get_describe_response_inputs_and_ouputs() ) session_mock.__class__ = PipelineSession + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session_mock.resolve_class_attribute_from_config = Mock( + name="resolve_class_attribute_from_config", + side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, + ) + return session_mock @@ -188,9 +203,8 @@ def test_sklearn_with_all_parameters( 
sagemaker_session.process.assert_called_with(**expected_args) -@patch("sagemaker.local.LocalSession.__init__", return_value=None) -def test_local_mode_disables_local_code_by_default(localsession_mock): - Processor( +def test_local_mode_disables_local_code_by_default(): + processor = Processor( image_uri="", role=ROLE, instance_count=1, @@ -199,7 +213,8 @@ def test_local_mode_disables_local_code_by_default(localsession_mock): # Most tests use a fixture for sagemaker_session for consistent behaviour, so this unit test # checks that the default initialization disables unsupported 'local_code' mode: - localsession_mock.assert_called_with(disable_local_code=True) + assert processor.sagemaker_session._disable_local_code + assert isinstance(processor.sagemaker_session, LocalSession) @patch("sagemaker.utils._botocore_resolver") diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 30b1251219..fd085d6590 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -82,6 +82,15 @@ def fixture_sagemaker_session(): session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) + return session diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index fea49a7548..50de3a47c4 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -75,6 +75,15 @@ def fixture_sagemaker_session(): session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = 
Mock(name="expand_role", return_value=ROLE) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) + return session diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index e25d27fcd0..fe32ed5486 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import copy import datetime import io import logging @@ -820,6 +821,27 @@ def sagemaker_session(): } ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=MagicMock()) ims.expand_role = Mock(return_value=EXPANDED_ROLE) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + ims.resolve_nested_dict_value_from_config = Mock( + name="resolve_nested_dict_value_from_config", + side_effect=lambda dictionary, nested_keys, config_path, default_value=None: dictionary, + ) + ims.resolve_class_attribute_from_config = Mock( + name="resolve_class_attribute_from_config", + side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, + ) + return ims + + +@pytest.fixture() +def sagemaker_session_without_mocked_sagemaker_config(): + boto_mock = MagicMock(name="boto_session") + boto_mock.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = { + "Account": "123" + } + ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=MagicMock()) + ims.expand_role = Mock(return_value=EXPANDED_ROLE) return ims @@ -3671,3 +3693,362 @@ def test_wait_for_inference_recommendations_job_invalid_log_level(sagemaker_sess ) assert "log_level must be either Quiet or Verbose" in str(error) + + +def 
test_append_sagemaker_config_tags(sagemaker_session): + tags_base = [ + {"Key": "tagkey4", "Value": "000"}, + {"Key": "tagkey5", "Value": "000"}, + ] + tags_duplicate = [ + {"Key": "tagkey1", "Value": "000"}, + {"Key": "tagkey2", "Value": "000"}, + ] + tags_none = None + tags_empty = [] + + # Helper to sort the lists so that the test is not dependent on order + def sort(tags): + return tags.sort(key=lambda tag: tag["Key"]) + + sagemaker_session._get_sagemaker_config_value = MagicMock( + return_value=[ + {"Key": "tagkey1", "Value": "tagvalue1"}, + {"Key": "tagkey2", "Value": "tagvalue2"}, + {"Key": "tagkey3", "Value": "tagvalue3"}, + ] + ) + + base_case = sagemaker_session._append_sagemaker_config_tags(tags_base, "DUMMY.CONFIG.PATH") + assert sort(base_case) == sort( + [ + {"Key": "tagkey1", "Value": "tagvalue1"}, + {"Key": "tagkey2", "Value": "tagvalue2"}, + {"Key": "tagkey3", "Value": "tagvalue3"}, + {"Key": "tagkey4", "Value": "000"}, + {"Key": "tagkey5", "Value": "000"}, + ] + ) + + duplicate_case = sagemaker_session._append_sagemaker_config_tags( + tags_duplicate, "DUMMY.CONFIG.PATH" + ) + assert sort(duplicate_case) == sort( + [ + {"Key": "tagkey1", "Value": "000"}, + {"Key": "tagkey2", "Value": "000"}, + {"Key": "tagkey3", "Value": "tagvalue3"}, + ] + ) + + none_case = sagemaker_session._append_sagemaker_config_tags(tags_none, "DUMMY.CONFIG.PATH") + assert sort(none_case) == sort( + [ + {"Key": "tagkey1", "Value": "tagvalue1"}, + {"Key": "tagkey2", "Value": "tagvalue2"}, + {"Key": "tagkey3", "Value": "tagvalue3"}, + ] + ) + + empty_case = sagemaker_session._append_sagemaker_config_tags(tags_empty, "DUMMY.CONFIG.PATH") + assert sort(empty_case) == sort( + [ + {"Key": "tagkey1", "Value": "tagvalue1"}, + {"Key": "tagkey2", "Value": "tagvalue2"}, + {"Key": "tagkey3", "Value": "tagvalue3"}, + ] + ) + + sagemaker_session._get_sagemaker_config_value = MagicMock(return_value=tags_none) + config_tags_none = sagemaker_session._append_sagemaker_config_tags( + tags_base, 
"DUMMY.CONFIG.PATH" + ) + assert sort(config_tags_none) == sort( + [ + {"Key": "tagkey4", "Value": "000"}, + {"Key": "tagkey5", "Value": "000"}, + ] + ) + + sagemaker_session._get_sagemaker_config_value = MagicMock(return_value=tags_empty) + config_tags_empty = sagemaker_session._append_sagemaker_config_tags( + tags_base, "DUMMY.CONFIG.PATH" + ) + assert sort(config_tags_empty) == sort( + [ + {"Key": "tagkey4", "Value": "000"}, + {"Key": "tagkey5", "Value": "000"}, + ] + ) + + +def test_resolve_value_from_config(sagemaker_session_without_mocked_sagemaker_config): + # using a shorter name for inside the test + ss = sagemaker_session_without_mocked_sagemaker_config + + # direct_input should be respected + ss._get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") + assert ss.resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", "DEFAULT_VALUE") == "INPUT" + + ss._get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") + assert ss.resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", None) == "INPUT" + + ss._get_sagemaker_config_value = MagicMock(return_value=None) + assert ss.resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", None) == "INPUT" + + # Config or default values should be returned if no direct_input + ss._get_sagemaker_config_value = MagicMock(return_value=None) + assert ss.resolve_value_from_config(None, None, "DEFAULT_VALUE") == "DEFAULT_VALUE" + + ss._get_sagemaker_config_value = MagicMock(return_value=None) + assert ( + ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", "DEFAULT_VALUE") == "DEFAULT_VALUE" + ) + + ss._get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") + assert ( + ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", "DEFAULT_VALUE") == "CONFIG_VALUE" + ) + + ss._get_sagemaker_config_value = MagicMock(return_value=None) + assert ss.resolve_value_from_config(None, None, None) is None + + # Different falsy direct_inputs + ss._get_sagemaker_config_value = 
MagicMock(return_value=None) + assert ss.resolve_value_from_config("", "DUMMY.CONFIG.PATH", None) == "" + + ss._get_sagemaker_config_value = MagicMock(return_value=None) + assert ss.resolve_value_from_config([], "DUMMY.CONFIG.PATH", None) == [] + + ss._get_sagemaker_config_value = MagicMock(return_value=None) + assert ss.resolve_value_from_config(False, "DUMMY.CONFIG.PATH", None) is False + + ss._get_sagemaker_config_value = MagicMock(return_value=None) + assert ss.resolve_value_from_config({}, "DUMMY.CONFIG.PATH", None) == {} + + # Different falsy config_values + ss._get_sagemaker_config_value = MagicMock(return_value="") + assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) == "" + + ss._get_sagemaker_config_value = MagicMock(return_value=[]) + assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) == [] + + ss._get_sagemaker_config_value = MagicMock(return_value=False) + assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) is False + + ss._get_sagemaker_config_value = MagicMock(return_value={}) + assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) == {} + + +@pytest.mark.parametrize( + "existing_value, config_value, default_value", + [ + ("EXISTING_VALUE", "CONFIG_VALUE", "DEFAULT_VALUE"), + (False, True, False), + (False, False, True), + (0, 1, 2), + ], +) +def test_resolve_class_attribute_from_config( + sagemaker_session_without_mocked_sagemaker_config, existing_value, config_value, default_value +): + # using a shorter name for inside the test + ss = sagemaker_session_without_mocked_sagemaker_config + + class TestClass(object): + def __init__(self, test_attribute=None, extra=None): + self.test_attribute = test_attribute + # the presence of an extra value that is set to None by default helps make sure a brand new + # TestClass object is being created only in the right scenarios + self.extra_attribute = extra + + def __eq__(self, other): + if isinstance(other, self.__class__): + return 
self.__dict__ == other.__dict__ + else: + return False + + dummy_config_path = ["DUMMY", "CONFIG", "PATH"] + + # with an existing config value + ss._get_sagemaker_config_value = MagicMock(return_value=config_value) + + # instance exists and has value; config has value + test_instance = TestClass(test_attribute=existing_value, extra="EXTRA_VALUE") + assert ss.resolve_class_attribute_from_config( + TestClass, test_instance, "test_attribute", dummy_config_path + ) == TestClass(test_attribute=existing_value, extra="EXTRA_VALUE") + + # instance exists but doesnt have value; config has value + test_instance = TestClass(extra="EXTRA_VALUE") + assert ss.resolve_class_attribute_from_config( + TestClass, test_instance, "test_attribute", dummy_config_path + ) == TestClass(test_attribute=config_value, extra="EXTRA_VALUE") + + # instance doesnt exist; config has value + test_instance = None + assert ss.resolve_class_attribute_from_config( + TestClass, test_instance, "test_attribute", dummy_config_path + ) == TestClass(test_attribute=config_value, extra=None) + + # wrong attribute used + test_instance = TestClass() + with pytest.raises(TypeError): + ss.resolve_class_attribute_from_config( + TestClass, test_instance, "other_attribute", dummy_config_path + ) + + # instance doesnt exist; clazz doesnt exist + test_instance = None + assert ( + ss.resolve_class_attribute_from_config( + None, test_instance, "test_attribute", dummy_config_path + ) + is None + ) + + # instance doesnt exist; clazz isnt a class + test_instance = None + assert ( + ss.resolve_class_attribute_from_config( + "CLASS", test_instance, "test_attribute", dummy_config_path + ) + is None + ) + + # without an existing config value + ss._get_sagemaker_config_value = MagicMock(return_value=None) + + # instance exists but doesnt have value; config doesnt have value + test_instance = TestClass(extra="EXTRA_VALUE") + assert ss.resolve_class_attribute_from_config( + TestClass, test_instance, "test_attribute", 
dummy_config_path + ) == TestClass(test_attribute=None, extra="EXTRA_VALUE") + + # instance exists but doesnt have value; config doesnt have value; default_value passed in + test_instance = TestClass(extra="EXTRA_VALUE") + assert ss.resolve_class_attribute_from_config( + TestClass, test_instance, "test_attribute", dummy_config_path, default_value=default_value + ) == TestClass(test_attribute=default_value, extra="EXTRA_VALUE") + + # instance doesnt exist; config doesnt have value + test_instance = None + assert ( + ss.resolve_class_attribute_from_config( + TestClass, test_instance, "test_attribute", dummy_config_path + ) + is None + ) + + # instance doesnt exist; config doesnt have value; default_value passed in + test_instance = None + assert ss.resolve_class_attribute_from_config( + TestClass, test_instance, "test_attribute", dummy_config_path, default_value=default_value + ) == TestClass(test_attribute=default_value, extra=None) + + +def test_resolve_nested_dict_value_from_config(sagemaker_session_without_mocked_sagemaker_config): + # using a shorter name for inside the test + ss = sagemaker_session_without_mocked_sagemaker_config + + dummy_config_path = ["DUMMY", "CONFIG", "PATH"] + + # with an existing config value + ss._get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") + + # happy cases: return existing dict with existing values + assert ss.resolve_nested_dict_value_from_config( + {"local": {"region_name": "us-west-2", "port": "123"}}, + ["local", "region_name"], + dummy_config_path, + default_value="DEFAULT_VALUE", + ) == {"local": {"region_name": "us-west-2", "port": "123"}} + assert ss.resolve_nested_dict_value_from_config( + {"local": {"region_name": "us-west-2", "port": "123"}}, + ["local", "region_name"], + dummy_config_path, + default_value=None, + ) == {"local": {"region_name": "us-west-2", "port": "123"}} + + # happy case: return dict with config_value when it wasnt set in dict or was None + assert 
ss.resolve_nested_dict_value_from_config( + {"local": {"port": "123"}}, + ["local", "region_name"], + dummy_config_path, + default_value="DEFAULT_VALUE", + ) == {"local": {"region_name": "CONFIG_VALUE", "port": "123"}} + assert ss.resolve_nested_dict_value_from_config( + {}, ["local", "region_name"], dummy_config_path, default_value=None + ) == {"local": {"region_name": "CONFIG_VALUE"}} + assert ss.resolve_nested_dict_value_from_config( + None, ["local", "region_name"], dummy_config_path, default_value=None + ) == {"local": {"region_name": "CONFIG_VALUE"}} + assert ss.resolve_nested_dict_value_from_config( + { + "local": {"region_name": "us-west-2", "port": "123"}, + "other": {"key": 1}, + "nest1": {"nest2": {"nest3": {"nest4a": "value", "nest4b": None}}}, + }, + ["nest1", "nest2", "nest3", "nest4b", "does_not", "exist"], + dummy_config_path, + default_value="DEFAULT_VALUE", + ) == { + "local": {"region_name": "us-west-2", "port": "123"}, + "other": {"key": 1}, + "nest1": { + "nest2": { + "nest3": {"nest4a": "value", "nest4b": {"does_not": {"exist": "CONFIG_VALUE"}}} + } + }, + } + + # edge case: doesnt overwrite non-None and non-dict values + dictionary = { + "local": {"region_name": "us-west-2", "port": "123"}, + "other": {"key": 1}, + "nest1": {"nest2": {"nest3": {"nest4a": "value", "nest4b": None}}}, + } + dictionary_copy = copy.deepcopy(dictionary) + assert ( + ss.resolve_nested_dict_value_from_config( + dictionary, + ["nest1", "nest2", "nest3", "nest4a", "does_not", "exist"], + dummy_config_path, + default_value="DEFAULT_VALUE", + ) + == dictionary_copy + ) + assert ( + ss.resolve_nested_dict_value_from_config( + dictionary, ["other", "key"], dummy_config_path, default_value="DEFAULT_VALUE" + ) + == dictionary_copy + ) + + # without an existing config value + ss._get_sagemaker_config_value = MagicMock(return_value=None) + + # happy case: return dict with default_value when it wasnt set in dict and in config + assert ss.resolve_nested_dict_value_from_config( + 
{"local": {"port": "123"}}, + ["local", "region_name"], + dummy_config_path, + default_value="DEFAULT_VALUE", + ) == {"local": {"region_name": "DEFAULT_VALUE", "port": "123"}} + + # happy case: return dict as-is when value wasnt set in dict, in config, and as default + assert ss.resolve_nested_dict_value_from_config( + {"local": {"port": "123"}}, ["local", "region_name"], dummy_config_path, default_value=None + ) == {"local": {"port": "123"}} + assert ( + ss.resolve_nested_dict_value_from_config( + {}, ["local", "region_name"], dummy_config_path, default_value=None + ) + == {} + ) + assert ( + ss.resolve_nested_dict_value_from_config( + None, ["local", "region_name"], dummy_config_path, default_value=None + ) + is None + ) diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index d16e887d18..0d3dff2159 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -77,6 +77,15 @@ def sagemaker_session(): session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) + return session diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 77c94188d5..b530de1512 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -66,6 +66,14 @@ def sagemaker_session(): sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) + # For the purposes of unit tests, no values should be fetched from sagemaker config + sms.resolve_value_from_config = Mock( + 
name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) + return sms diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index be15f0f932..2e57328339 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -14,6 +14,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import copy import shutil import tarfile from datetime import datetime @@ -55,6 +56,133 @@ def test_get_config_value(): assert sagemaker.utils.get_config_value("other.key", None) is None +def test_get_nested_value(): + dictionary = { + "local": {"region_name": "us-west-2", "port": "123"}, + "other": {"key": 1}, + "nest1": {"nest2": {"nest3": {"nest4": {"nest5a": "value", "nest5b": None}}}}, + } + + # happy cases: keys and values exist + assert sagemaker.utils.get_nested_value(dictionary, ["local", "region_name"]) == "us-west-2" + assert sagemaker.utils.get_nested_value(dictionary, ["local"]) == { + "region_name": "us-west-2", + "port": "123", + } + assert ( + sagemaker.utils.get_nested_value(dictionary, ["nest1", "nest2", "nest3", "nest4", "nest5a"]) + == "value" + ) + + # edge cases: non-existing keys + assert sagemaker.utils.get_nested_value(dictionary, ["local", "new_depth_1_key"]) is None + assert sagemaker.utils.get_nested_value(dictionary, ["new_depth_0_key"]) is None + assert ( + sagemaker.utils.get_nested_value(dictionary, ["new_depth_0_key", "new_depth_1_key"]) is None + ) + assert ( + sagemaker.utils.get_nested_value( + dictionary, ["nest1", "nest2", "nest3", "nest4", "nest5b", "does_not", "exist"] + ) + is None + ) + + # edge case: specified nested_keys contradict structure of dict + with pytest.raises(ValueError): + sagemaker.utils.get_nested_value( + dictionary, ["nest1", "nest2", "nest3", "nest4", "nest5a", "does_not", "exist"] + ) + + # edge cases: non-actionable inputs + 
assert sagemaker.utils.get_nested_value(None, ["other", "key"]) is None + assert sagemaker.utils.get_nested_value("not_a_dict", ["other", "key"]) is None + assert sagemaker.utils.get_nested_value(dictionary, None) is None + assert sagemaker.utils.get_nested_value(dictionary, []) is None + + +def test_set_nested_value(): + dictionary = { + "local": {"region_name": "us-west-2", "port": "123"}, + "other": {"key": 1}, + "nest1": {"nest2": {"nest3": {"nest4": {"nest5a": "value", "nest5b": None}}}}, + "existing_depth_0_key": None, + } + dictionary_copy = copy.deepcopy(dictionary) + + # happy cases: change existing values + dictionary_copy["local"]["region_name"] = "region1" + assert ( + sagemaker.utils.set_nested_value(dictionary, ["local", "region_name"], "region1") + == dictionary_copy + ) + + dictionary_copy["existing_depth_0_key"] = {"new_key": "new_value"} + assert ( + sagemaker.utils.set_nested_value( + dictionary, ["existing_depth_0_key"], {"new_key": "new_value"} + ) + == dictionary_copy + ) + + dictionary_copy["nest1"]["nest2"]["nest3"]["nest4"]["nest5a"] = "value2" + assert ( + sagemaker.utils.set_nested_value( + dictionary, ["nest1", "nest2", "nest3", "nest4", "nest5a"], "value2" + ) + == dictionary_copy + ) + + # happy cases: add new keys and values + dictionary_copy["local"]["new_depth_1_key"] = "value" + assert ( + sagemaker.utils.set_nested_value(dictionary, ["local", "new_depth_1_key"], "value") + == dictionary_copy + ) + + dictionary_copy["new_depth_0_key"] = "value" + assert ( + sagemaker.utils.set_nested_value(dictionary, ["new_depth_0_key"], "value") + == dictionary_copy + ) + + dictionary_copy["new_depth_0_key_2"] = {"new_depth_1_key_2": "value"} + assert ( + sagemaker.utils.set_nested_value( + dictionary, ["new_depth_0_key_2", "new_depth_1_key_2"], "value" + ) + == dictionary_copy + ) + + dictionary_copy["nest1"]["nest2"]["nest3"]["nest4"]["nest5b"] = {"does_not": {"exist": "value"}} + assert ( + sagemaker.utils.set_nested_value( + dictionary, 
["nest1", "nest2", "nest3", "nest4", "nest5b", "does_not", "exist"], "value" + ) + == dictionary_copy + ) + + # edge case: overwrite non-dict value + dictionary["nest1"]["nest2"]["nest3"]["nest4"]["nest5a"] = "value2" + dictionary_copy["nest1"]["nest2"]["nest3"]["nest4"]["nest5a"] = {"does_not": {"exist": "value"}} + assert ( + sagemaker.utils.set_nested_value( + dictionary, ["nest1", "nest2", "nest3", "nest4", "nest5a", "does_not", "exist"], "value" + ) + == dictionary_copy + ) + + # edge case: dict does not exist + assert sagemaker.utils.set_nested_value(None, ["other", "key"], "value") == { + "other": {"key": "value"} + } + + # edge cases: non-actionable inputs + dictionary_copy_2 = copy.deepcopy(dictionary) + assert sagemaker.utils.set_nested_value("not_a_dict", ["other", "key"], "value") == "not_a_dict" + assert sagemaker.utils.set_nested_value(dictionary, None, "value") == dictionary_copy_2 + assert sagemaker.utils.set_nested_value(dictionary, [], "value") == dictionary_copy_2 + + def test_get_short_version(): assert sagemaker.utils.get_short_version("1.13.1") == "1.13" assert sagemaker.utils.get_short_version("1.13") == "1.13" @@ -130,13 +258,13 @@ def test_base_name_from_image_with_pipeline_param(inputs): @patch("sagemaker.utils.sagemaker_timestamp") def test_name_from_base(sagemaker_timestamp): sagemaker.utils.name_from_base(NAME, short=False) - assert sagemaker_timestamp.called_once + sagemaker_timestamp.assert_called_once @patch("sagemaker.utils.sagemaker_short_timestamp") def test_name_from_base_short(sagemaker_short_timestamp): sagemaker.utils.name_from_base(NAME, short=True) - assert sagemaker_short_timestamp.called_once + sagemaker_short_timestamp.assert_called_once def test_unique_name_from_base(): diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 8fe5a0bc78..2c35ad8584 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -80,6 +80,14 @@ def sagemaker_session(): session.sagemaker_client.list_tags = 
Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) return session From 67e8d941a3f810e50c552e7d9fda1e59289f69b6 Mon Sep 17 00:00:00 2001 From: rubanh Date: Tue, 21 Feb 2023 18:54:41 -0500 Subject: [PATCH 03/40] intelligent defaults - custom parameters and small fixes (#845) * fix: sagemaker-config - S3 session, tuning tags, config schema test side-effects * feature: sagemaker-config - support for custom parameters in config schema --------- Co-authored-by: Ruban Hussain --- src/sagemaker/config/config_schema.py | 10 +++++- src/sagemaker/session.py | 9 +---- tests/unit/sagemaker/config/conftest.py | 36 +++++++++---------- .../sagemaker/config/test_config_schema.py | 30 +++++++++++++++- 4 files changed, 57 insertions(+), 28 deletions(-) diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index 27d0bed94d..2df1650dde 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -70,9 +70,11 @@ ENDPOINT_CONFIG = "EndpointConfig" AUTO_ML = "AutoML" COMPILATION_JOB = "CompilationJob" +CUSTOM_PARAMETERS = "CustomParameters" PIPELINE = "Pipeline" TRANSFORM_JOB = "TransformJob" PROPERTIES = "properties" +PATTERN_PROPERTIES = "patternProperties" TYPE = "type" OBJECT = "object" ADDITIONAL_PROPERTIES = "additionalProperties" @@ -104,7 +106,6 @@ def _simple_path(*args: str): SAGEMAKER, TRAINING_JOB, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION ) - SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = { "$schema": "https://json-schema.org/draft/2020-12/schema", TYPE: OBJECT, @@ -231,6 
+232,13 @@ def _simple_path(*args: str): "enum": ["1.0"], "description": "The schema version of the document.", }, + CUSTOM_PARAMETERS: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PATTERN_PROPERTIES: { + "^[\w\s\d_.:/=+\-@]+$": {TYPE: "string"}, + } + }, SAGEMAKER: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index bc4b8d187b..2422b31827 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -41,7 +41,6 @@ ENDPOINT_CONFIG, FEATURE_GROUP, KEY, - HYPER_PARAMETER_TUNING_JOB, SAGEMAKER, MODEL, MONITORING_SCHEDULE, @@ -219,7 +218,7 @@ def _initialize( if sagemaker_config: self.sagemaker_config = sagemaker_config else: - self.sagemaker_config = SageMakerConfig(s3_resource=self.boto_session) + self.sagemaker_config = SageMakerConfig(s3_resource=self.boto_session.resource("s3")) @property def boto_region_name(self): @@ -2642,9 +2641,6 @@ def tune( # noqa: C901 tune_request["WarmStartConfig"] = warm_start_config tags = _append_project_tags(tags) - tags = self._append_sagemaker_config_tags( - tags, "{}.{}.{}".format(SAGEMAKER, HYPER_PARAMETER_TUNING_JOB, TAGS) - ) if tags is not None: tune_request["Tags"] = tags @@ -2747,9 +2743,6 @@ def _get_tuning_request( tune_request["WarmStartConfig"] = warm_start_config tags = _append_project_tags(tags) - tags = self._append_sagemaker_config_tags( - tags, "{}.{}.{}".format(SAGEMAKER, HYPER_PARAMETER_TUNING_JOB, TAGS) - ) if tags is not None: tune_request["Tags"] = tags diff --git a/tests/unit/sagemaker/config/conftest.py b/tests/unit/sagemaker/config/conftest.py index 31a2616298..ca12a1e14b 100644 --- a/tests/unit/sagemaker/config/conftest.py +++ b/tests/unit/sagemaker/config/conftest.py @@ -17,22 +17,22 @@ from mock import MagicMock -@pytest.fixture(scope="module") +@pytest.fixture() def base_config_with_schema(): return {"SchemaVersion": "1.0"} -@pytest.fixture(scope="module") +@pytest.fixture() def valid_vpc_config(): return 
{"SecurityGroupIds": ["sg123"], "Subnets": ["subnet-1234"]} -@pytest.fixture(scope="module") +@pytest.fixture() def valid_iam_role_arn(): return "arn:aws:iam::366666666666:role/IMRole" -@pytest.fixture(scope="module") +@pytest.fixture() def valid_feature_group_config(valid_iam_role_arn): s3_storage_config = {"KmsKeyId": "somekmskeyid"} security_storage_config = {"KmsKeyId": "someotherkmskeyid"} @@ -45,7 +45,7 @@ def valid_feature_group_config(valid_iam_role_arn): } -@pytest.fixture(scope="module") +@pytest.fixture() def valid_edge_packaging_config(valid_iam_role_arn): return { "OutputConfig": {"KmsKeyId": "somekeyid"}, @@ -53,7 +53,7 @@ def valid_edge_packaging_config(valid_iam_role_arn): } -@pytest.fixture(scope="module") +@pytest.fixture() def valid_model_config(valid_iam_role_arn, valid_vpc_config): return { "EnableNetworkIsolation": True, @@ -62,7 +62,7 @@ def valid_model_config(valid_iam_role_arn, valid_vpc_config): } -@pytest.fixture(scope="module") +@pytest.fixture() def valid_model_package_config(valid_iam_role_arn): transform_job_definition = { "TransformOutput": {"KmsKeyId": "somerandomkmskeyid"}, @@ -75,7 +75,7 @@ def valid_model_package_config(valid_iam_role_arn): return {"ValidationSpecification": validation_specification} -@pytest.fixture(scope="module") +@pytest.fixture() def valid_processing_job_config(valid_iam_role_arn, valid_vpc_config): network_config = {"EnableNetworkIsolation": True, "VpcConfig": valid_vpc_config} dataset_definition = { @@ -94,7 +94,7 @@ def valid_processing_job_config(valid_iam_role_arn, valid_vpc_config): } -@pytest.fixture(scope="module") +@pytest.fixture() def valid_training_job_config(valid_iam_role_arn, valid_vpc_config): return { "EnableNetworkIsolation": True, @@ -105,12 +105,12 @@ def valid_training_job_config(valid_iam_role_arn, valid_vpc_config): } -@pytest.fixture(scope="module") +@pytest.fixture() def valid_pipeline_config(valid_iam_role_arn): return {"RoleArn": valid_iam_role_arn} -@pytest.fixture(scope="module") 
+@pytest.fixture() def valid_compilation_job_config(valid_iam_role_arn, valid_vpc_config): return { "OutputConfig": {"KmsKeyId": "somekmskey"}, @@ -119,7 +119,7 @@ def valid_compilation_job_config(valid_iam_role_arn, valid_vpc_config): } -@pytest.fixture(scope="module") +@pytest.fixture() def valid_transform_job_config(): return { "DataCaptureConfig": {"KmsKeyId": "somekmskey"}, @@ -128,7 +128,7 @@ def valid_transform_job_config(): } -@pytest.fixture(scope="module") +@pytest.fixture() def valid_automl_config(valid_iam_role_arn, valid_vpc_config): return { "AutoMLJobConfig": { @@ -139,7 +139,7 @@ def valid_automl_config(valid_iam_role_arn, valid_vpc_config): } -@pytest.fixture(scope="module") +@pytest.fixture() def valid_endpointconfig_config(): return { "AsyncInferenceConfig": {"OutputConfig": {"KmsKeyId": "somekmskey"}}, @@ -150,7 +150,7 @@ def valid_endpointconfig_config(): } -@pytest.fixture(scope="module") +@pytest.fixture() def valid_monitoring_schedule_config(valid_iam_role_arn, valid_vpc_config): network_config = {"EnableNetworkIsolation": True, "VpcConfig": valid_vpc_config} return { @@ -165,7 +165,7 @@ def valid_monitoring_schedule_config(valid_iam_role_arn, valid_vpc_config): } -@pytest.fixture(scope="module") +@pytest.fixture() def valid_config_with_all_the_scopes( valid_feature_group_config, valid_monitoring_schedule_config, @@ -196,11 +196,11 @@ def valid_config_with_all_the_scopes( } -@pytest.fixture(scope="module") +@pytest.fixture() def s3_resource_mock(): return MagicMock(name="s3") -@pytest.fixture(scope="module") +@pytest.fixture() def get_data_dir(): return os.path.join(os.path.dirname(__file__), "..", "..", "..", "data", "config") diff --git a/tests/unit/sagemaker/config/test_config_schema.py b/tests/unit/sagemaker/config/test_config_schema.py index 6ea0296bdf..745ac770d4 100644 --- a/tests/unit/sagemaker/config/test_config_schema.py +++ b/tests/unit/sagemaker/config/test_config_schema.py @@ -38,7 +38,7 @@ def test_invalid_schema_version(): def 
test_valid_config_with_all_the_features( - base_config_with_schema, valid_config_with_all_the_scopes + base_config_with_schema, valid_config_with_all_the_scopes ): _validate_config(base_config_with_schema, valid_config_with_all_the_scopes) @@ -137,3 +137,31 @@ def test_invalid_feature_group_schema(base_config_with_schema): config["SageMaker"] = {"FeatureGroup": feature_group_config} with pytest.raises(exceptions.ValidationError): validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + +def test_valid_custom_parameters_schema(base_config_with_schema): + config = base_config_with_schema + config["CustomParameters"] = { + "custom_key": "custom_value", + "CustomKey": "CustomValue", + "custom key": "custom value", + "custom-key": "custom-value", + "custom0123 key0123": "custom0123 value0123", + } + validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + +def test_invalid_custom_parameters_schema(base_config_with_schema): + config = base_config_with_schema + + config["CustomParameters"] = {"^&": "custom_value"} + with pytest.raises(exceptions.ValidationError): + validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + config["CustomParameters"] = {"custom_key": 476} + with pytest.raises(exceptions.ValidationError): + validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + config["CustomParameters"] = {"custom_key": {"custom_key": "custom_value"}} + with pytest.raises(exceptions.ValidationError): + validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) From 67fc282aa88fe0e7f6954ceb5131c82feec263db Mon Sep 17 00:00:00 2001 From: Balaji Sankar <115105204+balajisankar15@users.noreply.github.com> Date: Thu, 23 Feb 2023 10:30:50 -0700 Subject: [PATCH 04/40] feature: Added support for VPC Config, EnableNetworkIsolation, KMS Key ID, Volume KMS Key ID, IAM role to be fetched from Config (#846) Co-authored-by: Balaji Sankar --- src/sagemaker/algorithm.py | 2 +- src/sagemaker/amazon/amazon_estimator.py | 2 +- .../amazon/factorization_machines.py | 4 +- 
src/sagemaker/amazon/ipinsights.py | 4 +- src/sagemaker/amazon/kmeans.py | 4 +- src/sagemaker/amazon/knn.py | 4 +- src/sagemaker/amazon/lda.py | 4 +- src/sagemaker/amazon/linear_learner.py | 4 +- src/sagemaker/amazon/ntm.py | 4 +- src/sagemaker/amazon/object2vec.py | 4 +- src/sagemaker/amazon/pca.py | 4 +- src/sagemaker/amazon/randomcutforest.py | 4 +- .../async_inference/async_inference_config.py | 23 +- src/sagemaker/automl/automl.py | 37 +- src/sagemaker/automl/candidate_estimator.py | 15 +- src/sagemaker/chainer/model.py | 4 +- src/sagemaker/clarify.py | 6 +- src/sagemaker/config/__init__.py | 63 ++ src/sagemaker/config/config_schema.py | 29 +- src/sagemaker/dataset_definition/inputs.py | 80 ++ src/sagemaker/estimator.py | 49 +- src/sagemaker/feature_store/feature_group.py | 24 +- src/sagemaker/huggingface/model.py | 2 +- src/sagemaker/huggingface/processing.py | 6 +- src/sagemaker/inputs.py | 21 +- src/sagemaker/local/local_session.py | 18 +- src/sagemaker/model.py | 101 ++- .../model_monitor/data_capture_config.py | 8 +- .../model_monitor/model_monitoring.py | 71 +- src/sagemaker/mxnet/model.py | 4 +- src/sagemaker/mxnet/processing.py | 6 +- src/sagemaker/network.py | 76 +- src/sagemaker/pipeline.py | 35 +- src/sagemaker/processing.py | 103 ++- src/sagemaker/pytorch/model.py | 4 +- src/sagemaker/pytorch/processing.py | 6 +- src/sagemaker/session.py | 779 +++++++++++++--- src/sagemaker/sklearn/model.py | 4 +- src/sagemaker/sklearn/processing.py | 6 +- src/sagemaker/spark/processing.py | 6 +- src/sagemaker/tensorflow/model.py | 2 +- src/sagemaker/tensorflow/processing.py | 6 +- src/sagemaker/transformer.py | 41 +- src/sagemaker/utils.py | 10 +- src/sagemaker/workflow/pipeline.py | 45 +- src/sagemaker/wrangler/processing.py | 8 +- src/sagemaker/xgboost/model.py | 6 +- src/sagemaker/xgboost/processing.py | 6 +- tests/integ/test_local_mode.py | 4 + tests/unit/conftest.py | 7 +- tests/unit/sagemaker/automl/test_auto_ml.py | 101 ++- 
.../sagemaker/config/test_config_schema.py | 2 +- .../feature_store/test_feature_group.py | 69 +- .../sagemaker/huggingface/test_estimator.py | 6 +- .../sagemaker/huggingface/test_processing.py | 5 +- .../sagemaker/local/test_local_pipeline.py | 14 +- tests/unit/sagemaker/model/test_deploy.py | 15 +- tests/unit/sagemaker/model/test_edge.py | 7 +- .../sagemaker/model/test_framework_model.py | 4 + tests/unit/sagemaker/model/test_model.py | 5 +- .../sagemaker/model/test_model_package.py | 5 +- tests/unit/sagemaker/model/test_neo.py | 7 +- .../monitor/test_clarify_model_monitor.py | 6 +- .../monitor/test_data_capture_config.py | 4 + .../monitor/test_model_monitoring.py | 5 +- .../sagemaker/tensorflow/test_estimator.py | 5 +- .../tensorflow/test_estimator_attach.py | 4 + .../tensorflow/test_estimator_init.py | 7 +- tests/unit/sagemaker/tensorflow/test_tfs.py | 4 + .../test_huggingface_pytorch_compiler.py | 4 + .../test_huggingface_tensorflow_compiler.py | 5 +- .../test_pytorch_compiler.py | 4 + .../test_tensorflow_compiler.py | 4 + tests/unit/sagemaker/workflow/conftest.py | 7 +- tests/unit/sagemaker/workflow/test_airflow.py | 4 + .../unit/sagemaker/workflow/test_pipeline.py | 51 +- .../sagemaker/wrangler/test_processing.py | 4 + tests/unit/test_algorithm.py | 4 + tests/unit/test_amazon_estimator.py | 4 + tests/unit/test_chainer.py | 4 + tests/unit/test_estimator.py | 79 ++ tests/unit/test_fm.py | 4 + tests/unit/test_ipinsights.py | 5 +- tests/unit/test_job.py | 5 +- tests/unit/test_kmeans.py | 5 +- tests/unit/test_knn.py | 5 +- tests/unit/test_lda.py | 5 +- tests/unit/test_linear_learner.py | 5 +- tests/unit/test_multidatamodel.py | 4 + tests/unit/test_mxnet.py | 4 + tests/unit/test_ntm.py | 5 +- tests/unit/test_object2vec.py | 5 +- tests/unit/test_pca.py | 5 +- tests/unit/test_pipeline_model.py | 65 ++ tests/unit/test_predictor.py | 4 + tests/unit/test_predictor_async.py | 5 +- tests/unit/test_processing.py | 105 +++ tests/unit/test_pytorch.py | 4 + 
tests/unit/test_randomcutforest.py | 4 + tests/unit/test_rl.py | 4 + tests/unit/test_session.py | 833 +++++++++++++++++- tests/unit/test_sklearn.py | 4 + tests/unit/test_sparkml_serving.py | 4 + tests/unit/test_timeout.py | 4 + tests/unit/test_transformer.py | 87 +- tests/unit/test_tuner.py | 4 + tests/unit/test_xgboost.py | 4 + 107 files changed, 3019 insertions(+), 383 deletions(-) diff --git a/src/sagemaker/algorithm.py b/src/sagemaker/algorithm.py index 95772c5229..f4124fff2a 100644 --- a/src/sagemaker/algorithm.py +++ b/src/sagemaker/algorithm.py @@ -46,7 +46,7 @@ class AlgorithmEstimator(EstimatorBase): def __init__( self, algorithm_arn: str, - role: str, + role: str = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, volume_size: Union[int, PipelineVariable] = 30, diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index 1abea5e48c..6a6177fc81 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -50,7 +50,7 @@ class AmazonAlgorithmEstimatorBase(EstimatorBase): def __init__( self, - role: str, + role: Optional[Union[str, PipelineVariable]] = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, data_location: Optional[str] = None, diff --git a/src/sagemaker/amazon/factorization_machines.py b/src/sagemaker/amazon/factorization_machines.py index 4d01897dbe..b0e628f935 100644 --- a/src/sagemaker/amazon/factorization_machines.py +++ b/src/sagemaker/amazon/factorization_machines.py @@ -87,7 +87,7 @@ class FactorizationMachines(AmazonAlgorithmEstimatorBase): def __init__( self, - role: str, + role: Optional[Union[str, PipelineVariable]] = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, num_factors: Optional[int] = None, @@ -326,7 
+326,7 @@ class FactorizationMachinesModel(Model): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, + role: Optional[str] = None, sagemaker_session: Optional[Session] = None, **kwargs ): diff --git a/src/sagemaker/amazon/ipinsights.py b/src/sagemaker/amazon/ipinsights.py index a3562f8434..bddc17433f 100644 --- a/src/sagemaker/amazon/ipinsights.py +++ b/src/sagemaker/amazon/ipinsights.py @@ -63,7 +63,7 @@ class IPInsights(AmazonAlgorithmEstimatorBase): def __init__( self, - role: str, + role: Optional[Union[str, PipelineVariable]] = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, num_entity_vectors: Optional[int] = None, @@ -229,7 +229,7 @@ class IPInsightsModel(Model): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, + role: Optional[str] = None, sagemaker_session: Optional[Session] = None, **kwargs ): diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py index a6a9a918a7..f6ed27fdad 100644 --- a/src/sagemaker/amazon/kmeans.py +++ b/src/sagemaker/amazon/kmeans.py @@ -62,7 +62,7 @@ class KMeans(AmazonAlgorithmEstimatorBase): def __init__( self, - role: str, + role: Optional[Union[str, PipelineVariable]] = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, k: Optional[int] = None, @@ -255,7 +255,7 @@ class KMeansModel(Model): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, + role: Optional[str] = None, sagemaker_session: Optional[Session] = None, **kwargs ): diff --git a/src/sagemaker/amazon/knn.py b/src/sagemaker/amazon/knn.py index 3ea63f1587..08bdd4a429 100644 --- a/src/sagemaker/amazon/knn.py +++ b/src/sagemaker/amazon/knn.py @@ -73,7 +73,7 @@ class KNN(AmazonAlgorithmEstimatorBase): def __init__( self, - role: str, + role: Optional[Union[str, PipelineVariable]] = None, instance_count: 
Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, k: Optional[int] = None, @@ -246,7 +246,7 @@ class KNNModel(Model): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, + role: Optional[str] = None, sagemaker_session: Optional[Session] = None, **kwargs ): diff --git a/src/sagemaker/amazon/lda.py b/src/sagemaker/amazon/lda.py index cb65d1f82e..1349e70d48 100644 --- a/src/sagemaker/amazon/lda.py +++ b/src/sagemaker/amazon/lda.py @@ -52,7 +52,7 @@ class LDA(AmazonAlgorithmEstimatorBase): def __init__( self, - role: str, + role: Optional[Union[str, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, num_topics: Optional[int] = None, alpha0: Optional[float] = None, @@ -230,7 +230,7 @@ class LDAModel(Model): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, + role: Optional[str] = None, sagemaker_session: Optional[Session] = None, **kwargs ): diff --git a/src/sagemaker/amazon/linear_learner.py b/src/sagemaker/amazon/linear_learner.py index 53663e9fec..c8ef6790cc 100644 --- a/src/sagemaker/amazon/linear_learner.py +++ b/src/sagemaker/amazon/linear_learner.py @@ -145,7 +145,7 @@ class LinearLearner(AmazonAlgorithmEstimatorBase): def __init__( self, - role: str, + role: Optional[Union[str, PipelineVariable]] = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, predictor_type: Optional[str] = None, @@ -499,7 +499,7 @@ class LinearLearnerModel(Model): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, + role: Optional[str] = None, sagemaker_session: Optional[Session] = None, **kwargs ): diff --git a/src/sagemaker/amazon/ntm.py b/src/sagemaker/amazon/ntm.py index f43980eac8..6272181c68 100644 --- a/src/sagemaker/amazon/ntm.py +++ b/src/sagemaker/amazon/ntm.py @@ -74,7 +74,7 @@ class NTM(AmazonAlgorithmEstimatorBase): def __init__( 
self, - role: str, + role: Optional[Union[str, PipelineVariable]] = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, num_topics: Optional[int] = None, @@ -263,7 +263,7 @@ class NTMModel(Model): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, + role: Optional[str] = None, sagemaker_session: Optional[Session] = None, **kwargs ): diff --git a/src/sagemaker/amazon/object2vec.py b/src/sagemaker/amazon/object2vec.py index efc1105fd7..8a967484ec 100644 --- a/src/sagemaker/amazon/object2vec.py +++ b/src/sagemaker/amazon/object2vec.py @@ -153,7 +153,7 @@ class Object2Vec(AmazonAlgorithmEstimatorBase): def __init__( self, - role: str, + role: Optional[Union[str, PipelineVariable]] = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, epochs: Optional[int] = None, @@ -361,7 +361,7 @@ class Object2VecModel(Model): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, + role: Optional[str] = None, sagemaker_session: Optional[Session] = None, **kwargs ): diff --git a/src/sagemaker/amazon/pca.py b/src/sagemaker/amazon/pca.py index 7c0fec94de..767a854ffd 100644 --- a/src/sagemaker/amazon/pca.py +++ b/src/sagemaker/amazon/pca.py @@ -60,7 +60,7 @@ class PCA(AmazonAlgorithmEstimatorBase): def __init__( self, - role: str, + role: Optional[Union[str, PipelineVariable]] = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, num_components: Optional[int] = None, @@ -243,7 +243,7 @@ class PCAModel(Model): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, + role: Optional[str] = None, sagemaker_session: Optional[Session] = None, **kwargs ): diff --git a/src/sagemaker/amazon/randomcutforest.py b/src/sagemaker/amazon/randomcutforest.py index 5fb708b91b..87f786eda3 100644 --- 
a/src/sagemaker/amazon/randomcutforest.py +++ b/src/sagemaker/amazon/randomcutforest.py @@ -54,7 +54,7 @@ class RandomCutForest(AmazonAlgorithmEstimatorBase): def __init__( self, - role: str, + role: Optional[Union[str, PipelineVariable]] = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, num_samples_per_tree: Optional[int] = None, @@ -216,7 +216,7 @@ class RandomCutForestModel(Model): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, + role: Optional[str] = None, sagemaker_session: Optional[Session] = None, **kwargs ): diff --git a/src/sagemaker/async_inference/async_inference_config.py b/src/sagemaker/async_inference/async_inference_config.py index f5e2cb8f57..6f1bd56f6a 100644 --- a/src/sagemaker/async_inference/async_inference_config.py +++ b/src/sagemaker/async_inference/async_inference_config.py @@ -57,9 +57,30 @@ def __init__( """ self.output_path = output_path self.max_concurrent_invocations_per_instance = max_concurrent_invocations_per_instance - self.kms_key_id = kms_key_id + self._kms_key_id = kms_key_id self.notification_config = notification_config + @property + def kms_key_id(self): + """Getter for kms_key_id + + Returns: + str: The KMS Key ID. + """ + return self._kms_key_id + + @kms_key_id.setter + def kms_key_id(self, kms_key_id: str): + """Setter for kms_key_id + + Args: + kms_key_id: The new kms_key_id to replace the existing one. 
+ + Returns: + + """ + self._kms_key_id = kms_key_id + def _to_request_dict(self): """Generates a request dictionary using the parameters provided to the class.""" request_dict = { diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py index 65e3f1b346..67379558b5 100644 --- a/src/sagemaker/automl/automl.py +++ b/src/sagemaker/automl/automl.py @@ -19,11 +19,15 @@ from sagemaker import Model, PipelineModel from sagemaker.automl.candidate_estimator import CandidateEstimator -from sagemaker.config.config_schema import ( +from sagemaker.job import _Job +from sagemaker.session import ( + Session, + AUTO_ML_KMS_KEY_ID_PATH, + AUTO_ML_ROLE_ARN_PATH, + AUTO_ML_VPC_CONFIG_PATH, + AUTO_ML_VOLUME_KMS_KEY_ID_PATH, PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION, ) -from sagemaker.job import _Job -from sagemaker.session import Session from sagemaker.utils import name_from_base from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline @@ -101,8 +105,8 @@ class AutoML(object): def __init__( self, - role: str, - target_attribute_name: str, + role: Optional[str] = None, + target_attribute_name: str = None, output_kms_key: Optional[str] = None, output_path: Optional[str] = None, base_job_name: Optional[str] = None, @@ -179,13 +183,10 @@ def __init__( Returns: AutoML object. 
""" - self.role = role - self.output_kms_key = output_kms_key self.output_path = output_path self.base_job_name = base_job_name self.compression_type = compression_type - self.volume_kms_key = volume_kms_key - self.vpc_config = vpc_config + self.encrypt_inter_container_traffic = encrypt_inter_container_traffic self.problem_type = problem_type self.max_candidate = max_candidates self.max_runtime_per_training_job_in_seconds = max_runtime_per_training_job_in_seconds @@ -206,6 +207,24 @@ def __init__( self._auto_ml_job_desc = None self._best_candidate = None self.sagemaker_session = sagemaker_session or Session() + self.vpc_config = self.sagemaker_session.get_sagemaker_config_override( + AUTO_ML_VPC_CONFIG_PATH, default_value=vpc_config + ) + self.volume_kms_key = self.sagemaker_session.get_sagemaker_config_override( + AUTO_ML_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key + ) + self.output_kms_key = self.sagemaker_session.get_sagemaker_config_override( + AUTO_ML_KMS_KEY_ID_PATH, default_value=output_kms_key + ) + self.role = self.sagemaker_session.get_sagemaker_config_override( + AUTO_ML_ROLE_ARN_PATH, default_value=role + ) + if not self.role: + # Originally IAM role was a required parameter. + # Now we marked that as Optional because we can fetch it from SageMakerConfig + # Because of marking that parameter as optional, we should validate if it is None, even + # after fetching the config. 
+ raise ValueError("IAM role should be provided for creating AutoML jobs.") self.encrypt_inter_container_traffic = self.sagemaker_session.resolve_value_from_config( direct_input=encrypt_inter_container_traffic, diff --git a/src/sagemaker/automl/candidate_estimator.py b/src/sagemaker/automl/candidate_estimator.py index 4231c64e21..eccf6c499a 100644 --- a/src/sagemaker/automl/candidate_estimator.py +++ b/src/sagemaker/automl/candidate_estimator.py @@ -15,8 +15,12 @@ from six import string_types -from sagemaker import Session -from sagemaker.config.config_schema import PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION +from sagemaker.session import ( + Session, + TRAINING_JOB_VPC_CONFIG_PATH, + TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, + PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION, +) from sagemaker.job import _Job from sagemaker.utils import name_from_base @@ -102,7 +106,12 @@ def fit( """Logs can only be shown if wait is set to True. Please either set wait to True or set logs to False.""" ) - + vpc_config = self.sagemaker_session.get_sagemaker_config_override( + TRAINING_JOB_VPC_CONFIG_PATH, default_value=vpc_config + ) + volume_kms_key = self.sagemaker_session.get_sagemaker_config_override( + TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key + ) self.name = candidate_name or self.name running_jobs = {} diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index d723b8f4d5..3d5c645c73 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -82,8 +82,8 @@ class ChainerModel(FrameworkModel): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, - entry_point: str, + role: Optional[str] = None, + entry_point: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, framework_version: Optional[str] = None, py_version: Optional[str] = None, diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 9df055df84..00ec6dc965 100644 --- 
a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -1205,9 +1205,9 @@ class SageMakerClarifyProcessor(Processor): def __init__( self, - role: str, - instance_count: int, - instance_type: str, + role: Optional[str] = None, + instance_count: int = None, + instance_type: str = None, volume_size_in_gb: int = 30, volume_kms_key: Optional[str] = None, output_kms_key: Optional[str] = None, diff --git a/src/sagemaker/config/__init__.py b/src/sagemaker/config/__init__.py index 8ee1bd2962..94f04b0f90 100644 --- a/src/sagemaker/config/__init__.py +++ b/src/sagemaker/config/__init__.py @@ -13,3 +13,66 @@ """This module configures the default values for SageMaker Python SDK.""" from __future__ import absolute_import +from sagemaker.config.config import SageMakerConfig # noqa: F401 +from sagemaker.config.config_schema import ( # noqa: F401 + SECURITY_GROUP_IDS, + SUBNETS, + ENABLE_NETWORK_ISOLATION, + VOLUME_KMS_KEY_ID, + KMS_KEY_ID, + ROLE_ARN, + EXECUTION_ROLE_ARN, + CLUSTER_ROLE_ARN, + VPC_CONFIG, + OUTPUT_DATA_CONFIG, + AUTO_ML_JOB_CONFIG, + ASYNC_INFERENCE_CONFIG, + OUTPUT_CONFIG, + PROCESSING_OUTPUT_CONFIG, + CLUSTER_CONFIG, + NETWORK_CONFIG, + CORE_DUMP_CONFIG, + DATA_CAPTURE_CONFIG, + MONITORING_OUTPUT_CONFIG, + RESOURCE_CONFIG, + SCHEMA_VERSION, + DATASET_DEFINITION, + ATHENA_DATASET_DEFINITION, + REDSHIFT_DATASET_DEFINITION, + MONITORING_JOB_DEFINITION, + SAGEMAKER, + PYTHON_SDK, + MODULES, + OFFLINE_STORE_CONFIG, + ONLINE_STORE_CONFIG, + S3_STORAGE_CONFIG, + SECURITY_CONFIG, + TRANSFORM_JOB_DEFINITION, + MONITORING_SCHEDULE_CONFIG, + MONITORING_RESOURCES, + PROCESSING_RESOURCES, + PRODUCTION_VARIANTS, + SHADOW_PRODUCTION_VARIANTS, + TRANSFORM_OUTPUT, + TRANSFORM_RESOURCES, + VALIDATION_ROLE, + VALIDATION_SPECIFICATION, + VALIDATION_PROFILES, + PROCESSING_INPUTS, + FEATURE_GROUP, + EDGE_PACKAGING_JOB, + TRAINING_JOB, + PROCESSING_JOB, + MODEL_PACKAGE, + MODEL, + MONITORING_SCHEDULE, + ENDPOINT_CONFIG, + AUTO_ML, + COMPILATION_JOB, + PIPELINE, + TRANSFORM_JOB, 
+ ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, + TAGS, + KEY, + VALUE, +) diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index 2df1650dde..a72c41f68c 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -81,31 +81,6 @@ ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = "EnableInterContainerTrafficEncryption" -def _simple_path(*args: str): - """Appends an arbitrary number of strings to use as path constants""" - return ".".join(args) - - -# Paths for reference elsewhere in the SDK. -# Names include the schema version since the paths could change with other schema versions -PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION = _simple_path( - SAGEMAKER, - MONITORING_SCHEDULE, - MONITORING_SCHEDULE_CONFIG, - MONITORING_JOB_DEFINITION, - NETWORK_CONFIG, - ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, -) -PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION = _simple_path( - SAGEMAKER, AUTO_ML, SECURITY_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION -) -PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION = _simple_path( - SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION -) -PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION = _simple_path( - SAGEMAKER, TRAINING_JOB, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION -) - SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = { "$schema": "https://json-schema.org/draft/2020-12/schema", TYPE: OBJECT, @@ -236,8 +211,8 @@ def _simple_path(*args: str): TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, PATTERN_PROPERTIES: { - "^[\w\s\d_.:/=+\-@]+$": {TYPE: "string"}, - } + r"^[\w\s\d_.:/=+\-@]+$": {TYPE: "string"}, + }, }, SAGEMAKER: { TYPE: OBJECT, diff --git a/src/sagemaker/dataset_definition/inputs.py b/src/sagemaker/dataset_definition/inputs.py index 468be22ac3..e037e6aee2 100644 --- a/src/sagemaker/dataset_definition/inputs.py +++ b/src/sagemaker/dataset_definition/inputs.py @@ -71,6 +71,46 @@ def __init__( output_compression=output_compression, ) + 
@property + def kms_key_id(self): + """Getter for KMSKeyId. + + Returns: + str: The KMS Key ID. + + """ + return self.__dict__["kms_key_id"] + + @kms_key_id.setter + def kms_key_id(self, kms_key_id: str): + """Setter for KMSKeyId. + + Args: + kms_key_id: The KMSKeyId to be used. + + """ + self.__dict__["kms_key_id"] = kms_key_id + + @property + def cluster_role_arn(self): + """Getter for Cluster Role ARN. + + Returns: + str: The cluster Role ARN. + + """ + return self.__dict__["cluster_role_arn"] + + @cluster_role_arn.setter + def cluster_role_arn(self, cluster_role_arn: str): + """Setter for Cluster Role ARN. + + Args: + cluster_role_arn: The ClusterRoleArn to be used. + + """ + self.__dict__["cluster_role_arn"] = cluster_role_arn + class AthenaDatasetDefinition(ApiObject): """DatasetDefinition for Athena. @@ -119,6 +159,26 @@ def __init__( output_compression=output_compression, ) + @property + def kms_key_id(self): + """Getter for KMSKeyId. + + Returns: + str: The KMS Key ID. + + """ + return self.__dict__["kms_key_id"] + + @kms_key_id.setter + def kms_key_id(self, kms_key_id: str): + """Setter for KMSKeyId. + + Args: + kms_key_id: The KMSKeyId to be used. + + """ + self.__dict__["kms_key_id"] = kms_key_id + class DatasetDefinition(ApiObject): """DatasetDefinition input.""" @@ -170,6 +230,26 @@ def __init__( athena_dataset_definition=athena_dataset_definition, ) + @property + def redshift_dataset_definition(self): + """Getter for RedshiftDatasetDefinition + + Returns: + RedshiftDatasetDefinition: RedshiftDatasetDefinition object. + + """ + return self.__dict__["redshift_dataset_definition"] + + @property + def athena_dataset_definition(self): + """Getter for AthenaDatasetDefinition + + Returns: + AthenaDatasetDefinition: AthenaDatasetDefinition object. + + """ + return self.__dict__["athena_dataset_definition"] + class S3Input(ApiObject): """Metadata of data objects stored in S3. 
diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 511f96c027..2289d4f77a 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -29,7 +29,6 @@ import sagemaker from sagemaker import git_utils, image_uris, vpc_utils from sagemaker.analytics import TrainingJobAnalytics -from sagemaker.config.config_schema import PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION from sagemaker.debugger import ( # noqa: F401 # pylint: disable=unused-import DEBUGGER_FLAG, DebuggerHookConfig, @@ -73,7 +72,16 @@ ) from sagemaker.predictor import Predictor from sagemaker.s3 import S3Uploader, parse_s3_url -from sagemaker.session import Session +from sagemaker.session import ( + Session, + TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, + TRAINING_JOB_ROLE_ARN_PATH, + TRAINING_JOB_SECURITY_GROUP_IDS_PATH, + TRAINING_JOB_SUBNETS_PATH, + TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + TRAINING_JOB_KMS_KEY_ID_PATH, + PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION, +) from sagemaker.transformer import Transformer from sagemaker.utils import ( base_from_name, @@ -116,7 +124,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man def __init__( self, - role: str, + role: str = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, keep_alive_period_in_seconds: Optional[Union[int, PipelineVariable]] = None, @@ -143,7 +151,7 @@ def __init__( debugger_hook_config: Optional[Union[bool, DebuggerHookConfig]] = None, tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, - enable_network_isolation: Union[bool, PipelineVariable] = False, + enable_network_isolation: Union[bool, PipelineVariable] = None, profiler_config: Optional[ProfilerConfig] = None, disable_profiler: bool = False, environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None, @@ -535,13 +543,11 @@ def 
__init__( enable_network_isolation=enable_network_isolation, ) - self.role = role self.instance_count = instance_count self.instance_type = instance_type self.keep_alive_period_in_seconds = keep_alive_period_in_seconds self.instance_groups = instance_groups self.volume_size = volume_size - self.volume_kms_key = volume_kms_key self.max_run = max_run self.input_mode = input_mode self.metric_definitions = metric_definitions @@ -582,16 +588,34 @@ def __init__( ): raise RuntimeError("file:// output paths are only supported in Local Mode") self.output_path = output_path - self.output_kms_key = output_kms_key self.latest_training_job = None self.jobs = [] self.deploy_instance_type = None self._compiled_models = {} + self.role = self.sagemaker_session.get_sagemaker_config_override( + TRAINING_JOB_ROLE_ARN_PATH, role + ) + if not self.role: + # Originally IAM role was a required parameter. + # Now we marked that as Optional because we can fetch it from SageMakerConfig + # Because of marking that parameter as optional, we should validate if it is None, even + # after fetching the config. 
+ raise ValueError("IAM role should be provided for creating estimators.") + self.output_kms_key = self.sagemaker_session.get_sagemaker_config_override( + TRAINING_JOB_KMS_KEY_ID_PATH, default_value=output_kms_key + ) + self.volume_kms_key = self.sagemaker_session.get_sagemaker_config_override( + TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key + ) # VPC configurations - self.subnets = subnets - self.security_group_ids = security_group_ids + self.subnets = self.sagemaker_session.get_sagemaker_config_override( + TRAINING_JOB_SUBNETS_PATH, default_value=subnets + ) + self.security_group_ids = self.sagemaker_session.get_sagemaker_config_override( + TRAINING_JOB_SECURITY_GROUP_IDS_PATH, default_value=security_group_ids + ) # training image configs self.training_repository_access_mode = training_repository_access_mode @@ -618,7 +642,10 @@ def __init__( self.collection_configs = None self.enable_sagemaker_metrics = enable_sagemaker_metrics - self._enable_network_isolation = enable_network_isolation + self._enable_network_isolation = self.sagemaker_session.get_sagemaker_config_override( + TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + default_value=False if enable_network_isolation is None else enable_network_isolation, + ) self.profiler_config = profiler_config self.disable_profiler = disable_profiler @@ -2336,7 +2363,7 @@ class Estimator(EstimatorBase): def __init__( self, image_uri: Union[str, PipelineVariable], - role: str, + role: str = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, keep_alive_period_in_seconds: Optional[Union[int, PipelineVariable]] = None, diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index ee2163aa60..ca89b91385 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -41,7 +41,12 @@ from botocore.config import Config from 
pathos.multiprocessing import ProcessingPool -from sagemaker import Session +from sagemaker.session import ( + Session, + FEATURE_GROUP_ROLE_ARN_PATH, + FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH, + FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH, +) from sagemaker.feature_store.feature_definition import ( FeatureDefinition, FeatureTypeEnum, @@ -513,7 +518,7 @@ def create( s3_uri: Union[str, bool], record_identifier_name: str, event_time_feature_name: str, - role_arn: str, + role_arn: str = None, online_store_kms_key_id: str = None, enable_online_store: bool = False, offline_store_kms_key_id: str = None, @@ -552,6 +557,21 @@ def create( Returns: Response dict from service. """ + role_arn = self.sagemaker_session.get_sagemaker_config_override( + FEATURE_GROUP_ROLE_ARN_PATH, default_value=role_arn + ) + offline_store_kms_key_id = self.sagemaker_session.get_sagemaker_config_override( + FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH, default_value=offline_store_kms_key_id + ) + online_store_kms_key_id = self.sagemaker_session.get_sagemaker_config_override( + FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH, default_value=online_store_kms_key_id + ) + if not role_arn: + # Originally IAM role was a required parameter. + # Now we marked that as Optional because we can fetch it from SageMakerConfig, + # Because of marking that parameter as optional, we should validate if it is None, even + # after fetching the config. 
+ raise ValueError("IAM role should be provided for creating Feature Groups.") create_feature_store_args = dict( feature_group_name=self.name, record_identifier_name=record_identifier_name, diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 5be10a00aa..437b749a90 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -106,7 +106,7 @@ class HuggingFaceModel(FrameworkModel): def __init__( self, - role: str, + role: Optional[str] = None, model_data: Optional[Union[str, PipelineVariable]] = None, entry_point: Optional[str] = None, transformers_version: Optional[str] = None, diff --git a/src/sagemaker/huggingface/processing.py b/src/sagemaker/huggingface/processing.py index 63810b0eb9..1daeac9c1f 100644 --- a/src/sagemaker/huggingface/processing.py +++ b/src/sagemaker/huggingface/processing.py @@ -34,9 +34,9 @@ class HuggingFaceProcessor(FrameworkProcessor): def __init__( self, - role: str, - instance_count: Union[int, PipelineVariable], - instance_type: Union[str, PipelineVariable], + role: Optional[Union[str, PipelineVariable]] = None, + instance_count: Union[int, PipelineVariable] = None, + instance_type: Union[str, PipelineVariable] = None, transformers_version: Optional[str] = None, tensorflow_version: Optional[str] = None, pytorch_version: Optional[str] = None, diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index f0c678c623..bdedb5dea5 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -260,9 +260,28 @@ def __init__( (default: None) """ self.destination_s3_uri = destination_s3_uri - self.kms_key_id = kms_key_id + self._kms_key_id = kms_key_id self.generate_inference_id = generate_inference_id + @property + def kms_key_id(self): + """Getter for KmsKeyId + + Returns: + str: The KMS Key ID. + """ + return self._kms_key_id + + @kms_key_id.setter + def kms_key_id(self, kms_key_id: str): + """Setter for KmsKeyId + + Args: + kms_key_id: The KMS Key ID to set. 
+ + """ + self._kms_key_id = kms_key_id + def _to_request_dict(self): """Generates a request dictionary using the parameters provided to the class.""" batch_data_capture_config = { diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index 5247fe1a45..d088dcf0c8 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -21,7 +21,7 @@ import boto3 from botocore.exceptions import ClientError -from sagemaker.config.config import SageMakerConfig +from sagemaker.config import SageMakerConfig from sagemaker.local.image import _SageMakerContainer from sagemaker.local.utils import get_docker_host from sagemaker.local.entities import ( @@ -673,6 +673,17 @@ def _initialize( if self.s3_endpoint_url is not None: self.s3_resource = boto_session.resource("s3", endpoint_url=self.s3_endpoint_url) self.s3_client = boto_session.client("s3", endpoint_url=self.s3_endpoint_url) + self.sagemaker_config = sagemaker_config or ( + SageMakerConfig(s3_resource=self.s3_resource) + if "sagemaker_config" not in kwargs + else kwargs.get("sagemaker_config") + ) + else: + self.sagemaker_config = sagemaker_config or ( + SageMakerConfig() + if "sagemaker_config" not in kwargs + else kwargs.get("sagemaker_config") + ) sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml") if os.path.exists(sagemaker_config_file): @@ -686,11 +697,6 @@ def _initialize( if self._disable_local_code and "local" in self.config: self.config["local"]["local_code"] = False - if sagemaker_config: - self.sagemaker_config = sagemaker_config - else: - self.sagemaker_config = SageMakerConfig(s3_resource=self.boto_session) - def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"): """A no-op method meant to override the sagemaker client. 
diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 05a5214e1a..ac3266da3a 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -29,7 +29,16 @@ utils, git_utils, ) -from sagemaker.session import Session +from sagemaker.session import ( + Session, + COMPILATION_JOB_ROLE_ARN_PATH, + EDGE_PACKAGING_KMS_KEY_ID_PATH, + EDGE_PACKAGING_ROLE_ARN_PATH, + MODEL_VPC_CONFIG_PATH, + MODEL_ENABLE_NETWORK_ISOLATION_PATH, + MODEL_EXECUTION_ROLE_ARN_PATH, + ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH, +) from sagemaker.model_metrics import ModelMetrics from sagemaker.deprecations import removed_kwargs from sagemaker.drift_check_baselines import DriftCheckBaselines @@ -97,7 +106,7 @@ def __init__( name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, sagemaker_session: Optional[Session] = None, - enable_network_isolation: Union[bool, PipelineVariable] = False, + enable_network_isolation: Union[bool, PipelineVariable] = None, model_kms_key: Optional[str] = None, image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None, source_dir: Optional[str] = None, @@ -269,20 +278,40 @@ def __init__( """ self.model_data = model_data self.image_uri = image_uri - self.role = role self.predictor_cls = predictor_cls self.env = env or {} self.name = name self._base_name = None - self.vpc_config = vpc_config self.sagemaker_session = sagemaker_session + self.role = ( + self.sagemaker_session.get_sagemaker_config_override( + MODEL_EXECUTION_ROLE_ARN_PATH, default_value=role + ) + if sagemaker_session + else role + ) + self.vpc_config = ( + self.sagemaker_session.get_sagemaker_config_override( + MODEL_VPC_CONFIG_PATH, default_value=vpc_config + ) + if sagemaker_session + else vpc_config + ) self.endpoint_name = None self._is_compiled_model = False self._compilation_job_name = None self._is_edge_packaged_model = False self.inference_recommender_job_results = None self.inference_recommendations = None - 
self._enable_network_isolation = enable_network_isolation + self._enable_network_isolation = ( + self.sagemaker_session.get_sagemaker_config_override( + MODEL_ENABLE_NETWORK_ISOLATION_PATH, default_value=enable_network_isolation + ) + if sagemaker_session + else enable_network_isolation + ) + if self._enable_network_isolation is None: + self._enable_network_isolation = False self.model_kms_key = model_kms_key self.image_config = image_config self.entry_point = entry_point @@ -633,7 +662,7 @@ def enable_network_isolation(self): Returns: bool: If network isolation should be enabled or not. """ - return self._enable_network_isolation + return False if not self._enable_network_isolation else self._enable_network_isolation def _create_sagemaker_model( self, instance_type=None, accelerator_type=None, tags=None, serverless_inference_config=None @@ -674,15 +703,23 @@ def _create_sagemaker_model( ) self._set_model_name_if_needed() - enable_network_isolation = self.enable_network_isolation() - self._init_sagemaker_session_if_does_not_exist(instance_type) + # Depending on the instance type, a local session (or) a session is initialized. 
+ self.role = self.sagemaker_session.get_sagemaker_config_override( + MODEL_EXECUTION_ROLE_ARN_PATH, default_value=self.role + ) + self.vpc_config = self.sagemaker_session.get_sagemaker_config_override( + MODEL_VPC_CONFIG_PATH, default_value=self.vpc_config + ) + self._enable_network_isolation = self.sagemaker_session.get_sagemaker_config_override( + MODEL_ENABLE_NETWORK_ISOLATION_PATH, default_value=self._enable_network_isolation + ) create_model_args = dict( name=self.name, role=self.role, container_defs=container_def, vpc_config=self.vpc_config, - enable_network_isolation=enable_network_isolation, + enable_network_isolation=self._enable_network_isolation, tags=tags, ) self.sagemaker_session.create_model(**create_model_args) @@ -872,10 +909,15 @@ def package_for_edge( raise ValueError("You must first compile this model") if job_name is None: job_name = f"packaging{self._compilation_job_name[11:]}" - if role is None: - role = self.sagemaker_session.expand_role(role) - self._init_sagemaker_session_if_does_not_exist(None) + s3_kms_key = self.sagemaker_session.get_sagemaker_config_override( + EDGE_PACKAGING_KMS_KEY_ID_PATH, default_value=s3_kms_key + ) + role = self.sagemaker_session.get_sagemaker_config_override( + EDGE_PACKAGING_ROLE_ARN_PATH, default_value=role + ) + if role is not None: + role = self.sagemaker_session.expand_role(role) config = self._edge_packaging_job_config( output_path, role, @@ -899,7 +941,7 @@ def compile( target_instance_family, input_shape, output_path, - role, + role=None, tags=None, job_name=None, compile_max_run=15 * 60, @@ -978,6 +1020,15 @@ def compile( framework_version = framework_version or self._get_framework_version() self._init_sagemaker_session_if_does_not_exist(target_instance_family) + role = self.sagemaker_session.get_sagemaker_config_override( + COMPILATION_JOB_ROLE_ARN_PATH, default_value=role + ) + if not role: + # Originally IAM role was a required parameter. 
+ # Now we marked that as Optional because we can fetch it from SageMakerConfig + # Because of marking that parameter as optional, we should validate if it is None, even + # after fetching the config. + raise ValueError("IAM role should be provided for creating compilation jobs.") config = self._compilation_job_config( target_instance_family, input_shape, @@ -1130,6 +1181,16 @@ def deploy( removed_kwargs("update_endpoint", kwargs) self._init_sagemaker_session_if_does_not_exist(instance_type) + # Depending on the instance type, a local session (or) a session is initialized. + self.role = self.sagemaker_session.get_sagemaker_config_override( + MODEL_EXECUTION_ROLE_ARN_PATH, default_value=self.role + ) + self.vpc_config = self.sagemaker_session.get_sagemaker_config_override( + MODEL_VPC_CONFIG_PATH, default_value=self.vpc_config + ) + self._enable_network_isolation = self.sagemaker_session.get_sagemaker_config_override( + MODEL_ENABLE_NETWORK_ISOLATION_PATH, default_value=self._enable_network_isolation + ) tags = add_jumpstart_tags( tags=tags, inference_model_uri=self.model_data, inference_script_uri=self.source_dir @@ -1216,6 +1277,12 @@ def deploy( async_inference_config = self._build_default_async_inference_config( async_inference_config ) + async_inference_config.kms_key_id = ( + self.sagemaker_session.get_sagemaker_config_override( + ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH, + default_value=async_inference_config.kms_key_id, + ) + ) async_inference_config_dict = async_inference_config._to_request_dict() self.sagemaker_session.endpoint_from_production_variants( @@ -1341,8 +1408,8 @@ def __init__( self, model_data: Union[str, PipelineVariable], image_uri: Union[str, PipelineVariable], - role: str, - entry_point: str, + role: Optional[str] = None, + entry_point: Optional[str] = None, source_dir: Optional[str] = None, predictor_cls: Optional[callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, @@ -1522,7 +1589,9 @@ def __init__( class 
ModelPackage(Model): """A SageMaker ``Model`` that can be deployed to an ``Endpoint``.""" - def __init__(self, role, model_data=None, algorithm_arn=None, model_package_arn=None, **kwargs): + def __init__( + self, role=None, model_data=None, algorithm_arn=None, model_package_arn=None, **kwargs + ): """Initialize a SageMaker ModelPackage. Args: diff --git a/src/sagemaker/model_monitor/data_capture_config.py b/src/sagemaker/model_monitor/data_capture_config.py index bc0862abf7..b9febada4b 100644 --- a/src/sagemaker/model_monitor/data_capture_config.py +++ b/src/sagemaker/model_monitor/data_capture_config.py @@ -18,7 +18,7 @@ from __future__ import print_function, absolute_import from sagemaker import s3 -from sagemaker.session import Session +from sagemaker.session import Session, ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH _MODEL_MONITOR_S3_PATH = "model-monitor" _DATA_CAPTURE_S3_PATH = "data-capture" @@ -66,8 +66,8 @@ def __init__( self.enable_capture = enable_capture self.sampling_percentage = sampling_percentage self.destination_s3_uri = destination_s3_uri + sagemaker_session = sagemaker_session or Session() if self.destination_s3_uri is None: - sagemaker_session = sagemaker_session or Session() self.destination_s3_uri = s3.s3_path_join( "s3://", sagemaker_session.default_bucket(), @@ -75,7 +75,9 @@ def __init__( _DATA_CAPTURE_S3_PATH, ) - self.kms_key_id = kms_key_id + self.kms_key_id = sagemaker_session.get_sagemaker_config_override( + ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH, kms_key_id + ) self.capture_options = capture_options or ["REQUEST", "RESPONSE"] self.csv_content_types = csv_content_types or ["text/csv"] self.json_content_types = json_content_types or ["application/json"] diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 4f7b73ab04..407a1a5cb9 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -35,7 +35,6 @@ 
SAGEMAKER, MONITORING_SCHEDULE, TAGS, - PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION, ) from sagemaker.exceptions import UnexpectedStatusException from sagemaker.model_monitor.monitoring_files import Constraints, ConstraintViolations, Statistics @@ -48,7 +47,16 @@ from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat from sagemaker.network import NetworkConfig from sagemaker.processing import Processor, ProcessingInput, ProcessingJob, ProcessingOutput -from sagemaker.session import Session +from sagemaker.session import ( + Session, + MONITORING_JOB_SUBNETS_PATH, + MONITORING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION, + MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, + MONITORING_JOB_SECURITY_GROUP_IDS_PATH, + MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH, + MONITORING_JOB_ROLE_ARN_PATH, +) from sagemaker.utils import name_from_base, retries DEFAULT_REPOSITORY_NAME = "sagemaker-model-monitor-analyzer" @@ -102,8 +110,8 @@ class ModelMonitor(object): def __init__( self, - role, - image_uri, + role=None, + image_uri=None, instance_count=1, instance_type="ml.m5.xlarge", entrypoint=None, @@ -152,14 +160,11 @@ def __init__( inter-container traffic, security group IDs, and subnets. 
""" - self.role = role self.image_uri = image_uri self.instance_count = instance_count self.instance_type = instance_type self.entrypoint = entrypoint self.volume_size_in_gb = volume_size_in_gb - self.volume_kms_key = volume_kms_key - self.output_kms_key = output_kms_key self.max_runtime_in_seconds = max_runtime_in_seconds self.base_job_name = base_job_name self.sagemaker_session = sagemaker_session or Session() @@ -172,10 +177,60 @@ def __init__( self.latest_baselining_job_name = None self.monitoring_schedule_name = None self.job_definition_name = None + self.role = self.sagemaker_session.get_sagemaker_config_override( + MONITORING_JOB_ROLE_ARN_PATH, default_value=role + ) + if not self.role: + # Originally IAM role was a required parameter. + # Now we marked that as Optional because we can fetch it from SageMakerConfig + # Because of marking that parameter as optional, we should validate if it is None, even + # after fetching the config. + raise ValueError("IAM role should be provided for creating Monitoring Schedule.") + self.volume_kms_key = self.sagemaker_session.get_sagemaker_config_override( + MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key + ) + self.output_kms_key = self.sagemaker_session.get_sagemaker_config_override( + MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH, default_value=output_kms_key + ) + _enable_network_isolation_from_config = ( + self.sagemaker_session.get_sagemaker_config_override( + MONITORING_JOB_ENABLE_NETWORK_ISOLATION_PATH + ) + ) + + _subnets_from_config = self.sagemaker_session.get_sagemaker_config_override( + MONITORING_JOB_SUBNETS_PATH + ) + _security_group_ids_from_config = self.sagemaker_session.get_sagemaker_config_override( + MONITORING_JOB_SECURITY_GROUP_IDS_PATH + ) + if network_config: + if not network_config.subnets: + network_config.subnets = _subnets_from_config + if network_config.enable_network_isolation is None: + network_config.enable_network_isolation = ( + _enable_network_isolation_from_config or False + 
) + if not network_config.security_group_ids: + network_config.security_group_ids = _security_group_ids_from_config + self.network_config = network_config + else: + if ( + _enable_network_isolation_from_config is not None + or _subnets_from_config + or _security_group_ids_from_config + ): + self.network_config = NetworkConfig( + enable_network_isolation=_enable_network_isolation_from_config or False, + security_group_ids=_security_group_ids_from_config, + subnets=_subnets_from_config, + ) + else: + self.network_config = None self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( NetworkConfig, - network_config, + self.network_config, "encrypt_inter_container_traffic", PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION, ) diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 32f6a096f5..6983fa64b9 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -84,8 +84,8 @@ class MXNetModel(FrameworkModel): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, - entry_point: str, + role: Optional[str] = None, + entry_point: Optional[str] = None, framework_version: str = _LOWEST_MMS_VERSION, py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, diff --git a/src/sagemaker/mxnet/processing.py b/src/sagemaker/mxnet/processing.py index 71bce7cdff..d85ab5b526 100644 --- a/src/sagemaker/mxnet/processing.py +++ b/src/sagemaker/mxnet/processing.py @@ -34,9 +34,9 @@ class MXNetProcessor(FrameworkProcessor): def __init__( self, framework_version: str, # New arg - role: str, - instance_count: Union[int, PipelineVariable], - instance_type: Union[str, PipelineVariable], + role: Optional[Union[str, PipelineVariable]] = None, + instance_count: Union[int, PipelineVariable] = None, + instance_type: Union[str, PipelineVariable] = None, py_version: str = "py3", # New kwarg image_uri: Optional[Union[str, PipelineVariable]] = None, command: 
Optional[List[str]] = None, diff --git a/src/sagemaker/network.py b/src/sagemaker/network.py index cc874091f6..2f278414c3 100644 --- a/src/sagemaker/network.py +++ b/src/sagemaker/network.py @@ -29,7 +29,7 @@ class NetworkConfig(object): def __init__( self, - enable_network_isolation: Union[bool, PipelineVariable] = False, + enable_network_isolation: Union[bool, PipelineVariable] = None, security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, subnets: Optional[List[Union[str, PipelineVariable]]] = None, encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None, @@ -48,11 +48,79 @@ def __init__( encrypt_inter_container_traffic (bool or PipelineVariable): Boolean that determines whether to encrypt inter-container traffic. Default value is None. """ - self.enable_network_isolation = enable_network_isolation - self.security_group_ids = security_group_ids - self.subnets = subnets + self._enable_network_isolation = enable_network_isolation + self._security_group_ids = security_group_ids + self._subnets = subnets self.encrypt_inter_container_traffic = encrypt_inter_container_traffic + @property + def security_group_ids(self): + """Getter for Security Groups + + Returns: + list[str]: List of Security Groups + + """ + return self._security_group_ids + + @property + def subnets(self): + """Getter for Subnets + + Returns: + list[str]: List of Subnets + + """ + return self._subnets + + @property + def enable_network_isolation(self): + """Getter for Enable Network Isolation + + Returns: + bool: Value of Enable Network Isolation + + """ + return self._enable_network_isolation + + @security_group_ids.setter + def security_group_ids( + self, + security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, + ): + """Setter for security groups. + + Args: + security_group_ids: List of Security Group Ids. 
+ + """ + self._security_group_ids = security_group_ids + + @subnets.setter + def subnets( + self, + subnets: Optional[List[Union[str, PipelineVariable]]] = None, + ): + """Setter for subnets. + + Args: + subnets: List of Subnets. + + """ + self._subnets = subnets + + @enable_network_isolation.setter + def enable_network_isolation( + self, enable_network_isolation: Union[bool, PipelineVariable] = False + ): + """Setter for enable network isolation. + + Args: + enable_network_isolation: Value for enable network isolation + + """ + self._enable_network_isolation = enable_network_isolation + def _to_request_dict(self): """Generates a request dictionary using the parameters provided to the class.""" network_config_request = {"EnableNetworkIsolation": self.enable_network_isolation} diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index ad5ed1291c..5d40a59b39 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -19,7 +19,14 @@ from sagemaker import ModelMetrics, Model from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.metadata_properties import MetadataProperties -from sagemaker.session import Session +from sagemaker.session import ( + Session, + ENDPOINT_CONFIG_KMS_KEY_ID_PATH, + MODEL_VPC_CONFIG_PATH, + MODEL_ENABLE_NETWORK_ISOLATION_PATH, + MODEL_EXECUTION_ROLE_ARN_PATH, +) + from sagemaker.utils import ( name_from_image, update_container_with_inference_params, @@ -38,12 +45,12 @@ class PipelineModel(object): def __init__( self, models: List[Model], - role: str, + role: str = None, predictor_cls: Optional[callable] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, sagemaker_session: Optional[Session] = None, - enable_network_isolation: Union[bool, PipelineVariable] = False, + enable_network_isolation: Union[bool, PipelineVariable] = None, ): """Initialize a SageMaker `Model` instance. 
@@ -80,13 +87,26 @@ def __init__( or from the model container.Boolean """ self.models = models - self.role = role self.predictor_cls = predictor_cls self.name = name - self.vpc_config = vpc_config self.sagemaker_session = sagemaker_session - self.enable_network_isolation = enable_network_isolation self.endpoint_name = None + self.role = self.sagemaker_session.get_sagemaker_config_override( + MODEL_EXECUTION_ROLE_ARN_PATH, default_value=role + ) + self.vpc_config = self.sagemaker_session.get_sagemaker_config_override( + MODEL_VPC_CONFIG_PATH, default_value=vpc_config + ) + self.enable_network_isolation = self.sagemaker_session.get_sagemaker_config_override( + MODEL_ENABLE_NETWORK_ISOLATION_PATH, + default_value=False if enable_network_isolation is None else enable_network_isolation, + ) + if not self.role: + # Originally IAM role was a required parameter. + # Now we marked that as Optional because we can fetch it from SageMakerConfig + # Because of marking that parameter as optional, we should validate if it is None, even + # after fetching the config. + raise ValueError("IAM role should be provided for creating Pipeline Model.") def pipeline_container_def(self, instance_type=None): """The pipeline definition for deploying this model. 
@@ -212,6 +232,9 @@ def deploy( container_startup_health_check_timeout=container_startup_health_check_timeout, ) self.endpoint_name = endpoint_name or self.name + kms_key = self.sagemaker_session.get_sagemaker_config_override( + ENDPOINT_CONFIG_KMS_KEY_ID_PATH, default_value=kms_key + ) data_capture_config_dict = None if data_capture_config is not None: diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 7bc21bec38..0dbff0ca59 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -30,7 +30,6 @@ from six.moves.urllib.parse import urlparse from six.moves.urllib.request import url2pathname from sagemaker import s3 -from sagemaker.config.config_schema import PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION from sagemaker.job import _Job from sagemaker.local import LocalSession from sagemaker.network import NetworkConfig @@ -40,7 +39,16 @@ name_from_base, check_and_get_run_experiment_config, ) -from sagemaker.session import Session +from sagemaker.session import ( + Session, + PROCESSING_JOB_KMS_KEY_ID_PATH, + PROCESSING_JOB_SECURITY_GROUP_IDS_PATH, + PROCESSING_JOB_SUBNETS_PATH, + PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH, + PROCESSING_JOB_ROLE_ARN_PATH, + PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION, +) from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.functions import Join from sagemaker.workflow.pipeline_context import runnable_by_pipeline @@ -60,10 +68,10 @@ class Processor(object): def __init__( self, - role: Union[str, PipelineVariable], - image_uri: Union[str, PipelineVariable], - instance_count: Union[int, PipelineVariable], - instance_type: Union[str, PipelineVariable], + role: str = None, + image_uri: Union[str, PipelineVariable] = None, + instance_count: Union[int, PipelineVariable] = None, + instance_type: Union[str, PipelineVariable] = None, entrypoint: Optional[List[Union[str, PipelineVariable]]] = None, volume_size_in_gb: 
Union[int, PipelineVariable] = 30, volume_kms_key: Optional[Union[str, PipelineVariable]] = None, @@ -119,14 +127,11 @@ def __init__( object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. """ - self.role = role self.image_uri = image_uri self.instance_count = instance_count self.instance_type = instance_type self.entrypoint = entrypoint self.volume_size_in_gb = volume_size_in_gb - self.volume_kms_key = volume_kms_key - self.output_kms_key = output_kms_key self.max_runtime_in_seconds = max_runtime_in_seconds self.base_job_name = base_job_name self.env = env @@ -143,10 +148,60 @@ def __init__( sagemaker_session = LocalSession(disable_local_code=True) self.sagemaker_session = sagemaker_session or Session() + self.output_kms_key = self.sagemaker_session.get_sagemaker_config_override( + PROCESSING_JOB_KMS_KEY_ID_PATH, default_value=output_kms_key + ) + self.volume_kms_key = self.sagemaker_session.get_sagemaker_config_override( + PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key + ) + _enable_network_isolation_from_config = ( + self.sagemaker_session.get_sagemaker_config_override( + PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH + ) + ) + + _subnets_from_config = self.sagemaker_session.get_sagemaker_config_override( + PROCESSING_JOB_SUBNETS_PATH + ) + _security_group_ids_from_config = self.sagemaker_session.get_sagemaker_config_override( + PROCESSING_JOB_SECURITY_GROUP_IDS_PATH + ) + if network_config: + if not network_config.subnets: + network_config.subnets = _subnets_from_config + if network_config.enable_network_isolation is None: + network_config.enable_network_isolation = ( + _enable_network_isolation_from_config or False + ) + if not network_config.security_group_ids: + network_config.security_group_ids = _security_group_ids_from_config + self.network_config = network_config + else: + if ( + _enable_network_isolation_from_config is not None + or _subnets_from_config + or 
_security_group_ids_from_config + ): + self.network_config = NetworkConfig( + enable_network_isolation=_enable_network_isolation_from_config or False, + security_group_ids=_security_group_ids_from_config, + subnets=_subnets_from_config, + ) + else: + self.network_config = None + self.role = self.sagemaker_session.get_sagemaker_config_override( + PROCESSING_JOB_ROLE_ARN_PATH, default_value=role + ) + if not self.role: + # Originally IAM role was a required parameter. + # Now we marked that as Optional because we can fetch it from SageMakerConfig + # Because of marking that parameter as optional, we should validate if it is None, even + # after fetching the config. + raise ValueError("IAM role should be provided for creating Processing jobs.") self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( NetworkConfig, - network_config, + self.network_config, "encrypt_inter_container_traffic", PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION, ) @@ -445,11 +500,11 @@ class ScriptProcessor(Processor): def __init__( self, - role: Union[str, PipelineVariable], - image_uri: Union[str, PipelineVariable], - command: List[str], - instance_count: Union[int, PipelineVariable], - instance_type: Union[str, PipelineVariable], + role: Optional[Union[str, PipelineVariable]] = None, + image_uri: Union[str, PipelineVariable] = None, + command: List[str] = None, + instance_count: Union[int, PipelineVariable] = None, + instance_type: Union[str, PipelineVariable] = None, volume_size_in_gb: Union[int, PipelineVariable] = 30, volume_kms_key: Optional[Union[str, PipelineVariable]] = None, output_kms_key: Optional[Union[str, PipelineVariable]] = None, @@ -1201,10 +1256,20 @@ def __init__( self.s3_data_distribution_type = s3_data_distribution_type self.s3_compression_type = s3_compression_type self.s3_input = s3_input - self.dataset_definition = dataset_definition + self._dataset_definition = dataset_definition self.app_managed = app_managed self._create_s3_input() + 
@property + def dataset_definition(self): + """Getter for DataSetDefinition + + Returns: + DatasetDefinition: The DatasetDefinition Object. + + """ + return self._dataset_definition + def _to_request_dict(self): """Generates a request dictionary using the parameters provided to the class.""" @@ -1368,9 +1433,9 @@ def __init__( self, estimator_cls: type, framework_version: str, - role: Union[str, PipelineVariable], - instance_count: Union[int, PipelineVariable], - instance_type: Union[str, PipelineVariable], + role: Optional[Union[str, PipelineVariable]] = None, + instance_count: Union[int, PipelineVariable] = None, + instance_type: Union[str, PipelineVariable] = None, py_version: str = "py3", image_uri: Optional[Union[str, PipelineVariable]] = None, command: Optional[List[str]] = None, diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 73ecdd7ad7..6da71c8bde 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -85,8 +85,8 @@ class PyTorchModel(FrameworkModel): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, - entry_point: str, + role: Optional[str] = None, + entry_point: Optional[str] = None, framework_version: str = "1.3", py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, diff --git a/src/sagemaker/pytorch/processing.py b/src/sagemaker/pytorch/processing.py index 73551243e3..70fc96497e 100644 --- a/src/sagemaker/pytorch/processing.py +++ b/src/sagemaker/pytorch/processing.py @@ -34,9 +34,9 @@ class PyTorchProcessor(FrameworkProcessor): def __init__( self, framework_version: str, # New arg - role: str, - instance_count: Union[int, PipelineVariable], - instance_type: Union[str, PipelineVariable], + role: Optional[Union[str, PipelineVariable]] = None, + instance_count: Union[int, PipelineVariable] = None, + instance_type: Union[str, PipelineVariable] = None, py_version: str = "py3", # New kwarg image_uri: Optional[Union[str, 
PipelineVariable]] = None, command: Optional[List[str]] = None, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 2422b31827..ba2db269e9 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -28,32 +28,69 @@ import boto3 import botocore import botocore.config +from botocore.utils import merge_dicts from botocore.exceptions import ClientError import six import sagemaker.logs from sagemaker import vpc_utils -from sagemaker.config.config import SageMakerConfig -from sagemaker.config.config_schema import ( - AUTO_ML, - COMPILATION_JOB, +from sagemaker._studio import _append_project_tags +from sagemaker.config import ( # noqa: F401 + SageMakerConfig, + SAGEMAKER, + TRAINING_JOB, + ENABLE_NETWORK_ISOLATION, + KMS_KEY_ID, + RESOURCE_CONFIG, + VOLUME_KMS_KEY_ID, + ROLE_ARN, + VPC_CONFIG, + SECURITY_GROUP_IDS, + SUBNETS, EDGE_PACKAGING_JOB, - ENDPOINT_CONFIG, + OUTPUT_CONFIG, FEATURE_GROUP, - KEY, - SAGEMAKER, - MODEL, + OFFLINE_STORE_CONFIG, + ONLINE_STORE_CONFIG, + AUTO_ML, + AUTO_ML_JOB_CONFIG, + SECURITY_CONFIG, + OUTPUT_DATA_CONFIG, MONITORING_SCHEDULE, + MONITORING_SCHEDULE_CONFIG, + MONITORING_JOB_DEFINITION, + MONITORING_OUTPUT_CONFIG, + MONITORING_RESOURCES, + CLUSTER_CONFIG, + NETWORK_CONFIG, + TRANSFORM_JOB, + TRANSFORM_OUTPUT, + TRANSFORM_RESOURCES, + DATA_CAPTURE_CONFIG, + MODEL, + EXECUTION_ROLE_ARN, + S3_STORAGE_CONFIG, + ENDPOINT_CONFIG, + PIPELINE, + COMPILATION_JOB, PROCESSING_JOB, + PROCESSING_INPUTS, + DATASET_DEFINITION, + REDSHIFT_DATASET_DEFINITION, + ATHENA_DATASET_DEFINITION, + CLUSTER_ROLE_ARN, + PROCESSING_OUTPUT_CONFIG, + PROCESSING_RESOURCES, + ASYNC_INFERENCE_CONFIG, + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, TAGS, - TRAINING_JOB, - TRANSFORM_JOB, - PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION, - PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION, - PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION, - PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION, + KEY, + PRODUCTION_VARIANTS, + VALIDATION_ROLE, + 
VALIDATION_PROFILES, + MODEL_PACKAGE, + VALIDATION_SPECIFICATION, ) -from sagemaker._studio import _append_project_tags from sagemaker.deprecations import deprecated_class from sagemaker.inputs import ShuffleConfig, TrainingInput, BatchDataCaptureConfig from sagemaker.user_agent import prepend_user_agent @@ -85,6 +122,167 @@ } +def _simple_path(*args: str): + """Appends an arbitrary number of strings to use as path constants""" + return ".".join(args) + + +COMPILATION_JOB_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, COMPILATION_JOB, VPC_CONFIG) +COMPILATION_JOB_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, COMPILATION_JOB, OUTPUT_CONFIG, KMS_KEY_ID +) +COMPILATION_JOB_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, COMPILATION_JOB, OUTPUT_CONFIG) +COMPILATION_JOB_ROLE_ARN_PATH = _simple_path(SAGEMAKER, COMPILATION_JOB, ROLE_ARN) +TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH = _simple_path( + SAGEMAKER, TRAINING_JOB, ENABLE_NETWORK_ISOLATION +) +TRAINING_JOB_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, OUTPUT_DATA_CONFIG, KMS_KEY_ID) +TRAINING_JOB_RESOURCE_CONFIG_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, RESOURCE_CONFIG) +TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, OUTPUT_DATA_CONFIG) +TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, TRAINING_JOB, RESOURCE_CONFIG, VOLUME_KMS_KEY_ID +) +TRAINING_JOB_ROLE_ARN_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, ROLE_ARN) +TRAINING_JOB_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, VPC_CONFIG) +TRAINING_JOB_SECURITY_GROUP_IDS_PATH = _simple_path( + TRAINING_JOB_VPC_CONFIG_PATH, SECURITY_GROUP_IDS +) +TRAINING_JOB_SUBNETS_PATH = _simple_path(TRAINING_JOB_VPC_CONFIG_PATH, SUBNETS) +EDGE_PACKAGING_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, EDGE_PACKAGING_JOB, OUTPUT_CONFIG, KMS_KEY_ID +) +EDGE_PACKAGING_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, OUTPUT_CONFIG) +EDGE_PACKAGING_ROLE_ARN_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, ROLE_ARN) 
+ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, ENDPOINT_CONFIG, DATA_CAPTURE_CONFIG, KMS_KEY_ID +) +ENDPOINT_CONFIG_DATA_CAPTURE_PATH = _simple_path(SAGEMAKER, ENDPOINT_CONFIG, DATA_CAPTURE_CONFIG) +ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH = _simple_path( + SAGEMAKER, ENDPOINT_CONFIG, ASYNC_INFERENCE_CONFIG +) +ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH = _simple_path( + SAGEMAKER, ENDPOINT_CONFIG, PRODUCTION_VARIANTS +) +ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, ENDPOINT_CONFIG, ASYNC_INFERENCE_CONFIG, OUTPUT_CONFIG, KMS_KEY_ID +) +ENDPOINT_CONFIG_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, ENDPOINT_CONFIG, KMS_KEY_ID) +FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH = _simple_path(SAGEMAKER, FEATURE_GROUP, ONLINE_STORE_CONFIG) +FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH = _simple_path( + SAGEMAKER, FEATURE_GROUP, OFFLINE_STORE_CONFIG +) +FEATURE_GROUP_ROLE_ARN_PATH = _simple_path(SAGEMAKER, FEATURE_GROUP, ROLE_ARN) +FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH = _simple_path( + FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, S3_STORAGE_CONFIG, KMS_KEY_ID +) +FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH = _simple_path( + FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, SECURITY_CONFIG, KMS_KEY_ID +) +AUTO_ML_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML, OUTPUT_DATA_CONFIG) +AUTO_ML_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, AUTO_ML, OUTPUT_DATA_CONFIG, KMS_KEY_ID) +AUTO_ML_VOLUME_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VOLUME_KMS_KEY_ID +) +AUTO_ML_ROLE_ARN_PATH = _simple_path(SAGEMAKER, AUTO_ML, ROLE_ARN) +AUTO_ML_VPC_CONFIG_PATH = _simple_path( + SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VPC_CONFIG +) +AUTO_ML_JOB_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG) +MONITORING_JOB_DEFINITION_PREFIX = _simple_path( + SAGEMAKER, MONITORING_SCHEDULE, MONITORING_SCHEDULE_CONFIG, MONITORING_JOB_DEFINITION +) +MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH = _simple_path( + 
MONITORING_JOB_DEFINITION_PREFIX, MONITORING_OUTPUT_CONFIG, KMS_KEY_ID +) +MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path( + MONITORING_JOB_DEFINITION_PREFIX, MONITORING_RESOURCES, CLUSTER_CONFIG, VOLUME_KMS_KEY_ID +) +MONITORING_JOB_NETWORK_CONFIG_PATH = _simple_path(MONITORING_JOB_DEFINITION_PREFIX, NETWORK_CONFIG) +MONITORING_JOB_ENABLE_NETWORK_ISOLATION_PATH = _simple_path( + MONITORING_JOB_DEFINITION_PREFIX, NETWORK_CONFIG, ENABLE_NETWORK_ISOLATION +) +MONITORING_JOB_VPC_CONFIG_PATH = _simple_path( + MONITORING_JOB_DEFINITION_PREFIX, NETWORK_CONFIG, VPC_CONFIG +) +MONITORING_JOB_SECURITY_GROUP_IDS_PATH = _simple_path( + MONITORING_JOB_VPC_CONFIG_PATH, SECURITY_GROUP_IDS +) +MONITORING_JOB_SUBNETS_PATH = _simple_path(MONITORING_JOB_VPC_CONFIG_PATH, SUBNETS) +MONITORING_JOB_ROLE_ARN_PATH = _simple_path(MONITORING_JOB_DEFINITION_PREFIX, ROLE_ARN) +PIPELINE_ROLE_ARN_PATH = _simple_path(SAGEMAKER, PIPELINE, ROLE_ARN) +PIPELINE_TAGS_PATH = _simple_path(SAGEMAKER, PIPELINE, TAGS) +TRANSFORM_OUTPUT_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, TRANSFORM_JOB, TRANSFORM_OUTPUT, KMS_KEY_ID +) +TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, TRANSFORM_JOB, TRANSFORM_RESOURCES, VOLUME_KMS_KEY_ID +) +TRANSFORM_JOB_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, TRANSFORM_JOB, DATA_CAPTURE_CONFIG, KMS_KEY_ID +) +MODEL_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, MODEL, VPC_CONFIG) +MODEL_ENABLE_NETWORK_ISOLATION_PATH = _simple_path(SAGEMAKER, MODEL, ENABLE_NETWORK_ISOLATION) +MODEL_EXECUTION_ROLE_ARN_PATH = _simple_path(SAGEMAKER, MODEL, EXECUTION_ROLE_ARN) +PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH = _simple_path( + SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, ENABLE_NETWORK_ISOLATION +) +PROCESSING_JOB_INPUTS_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, PROCESSING_INPUTS) +REDSHIFT_DATASET_DEFINITION_KMS_KEY_ID_PATH = _simple_path( + DATASET_DEFINITION, REDSHIFT_DATASET_DEFINITION, KMS_KEY_ID +) +ATHENA_DATASET_DEFINITION_KMS_KEY_ID_PATH = 
_simple_path( + DATASET_DEFINITION, ATHENA_DATASET_DEFINITION, KMS_KEY_ID +) +REDSHIFT_DATASET_DEFINITION_CLUSTER_ROLE_ARN_PATH = _simple_path( + DATASET_DEFINITION, REDSHIFT_DATASET_DEFINITION, CLUSTER_ROLE_ARN +) +PROCESSING_JOB_NETWORK_CONFIG_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG) +PROCESSING_JOB_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, VPC_CONFIG) +PROCESSING_JOB_SUBNETS_PATH = _simple_path(PROCESSING_JOB_VPC_CONFIG_PATH, SUBNETS) +PROCESSING_JOB_SECURITY_GROUP_IDS_PATH = _simple_path( + PROCESSING_JOB_VPC_CONFIG_PATH, SECURITY_GROUP_IDS +) +PROCESSING_OUTPUT_CONFIG_PATH = _simple_path( + SAGEMAKER, PROCESSING_JOB, PROCESSING_OUTPUT_CONFIG, KMS_KEY_ID +) +PROCESSING_JOB_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, PROCESSING_JOB, PROCESSING_OUTPUT_CONFIG, KMS_KEY_ID +) +PROCESSING_JOB_PROCESSING_RESOURCES_PATH = _simple_path( + SAGEMAKER, PROCESSING_JOB, PROCESSING_RESOURCES +) +PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, PROCESSING_JOB, PROCESSING_RESOURCES, CLUSTER_CONFIG, VOLUME_KMS_KEY_ID +) +PROCESSING_JOB_ROLE_ARN_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, ROLE_ARN) +MODEL_PACKAGE_VALIDATION_ROLE_PATH = _simple_path( + SAGEMAKER, MODEL_PACKAGE, VALIDATION_SPECIFICATION, VALIDATION_ROLE +) +MODEL_PACKAGE_VALIDATION_PROFILES_PATH = _simple_path( + SAGEMAKER, MODEL_PACKAGE, VALIDATION_SPECIFICATION, VALIDATION_PROFILES +) + +# Paths for reference elsewhere in the SDK. 
+# Names include the schema version since the paths could change with other schema versions +PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION = _simple_path( + SAGEMAKER, + MONITORING_SCHEDULE, + MONITORING_SCHEDULE_CONFIG, + MONITORING_JOB_DEFINITION, + NETWORK_CONFIG, + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, +) +PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION = _simple_path( + SAGEMAKER, AUTO_ML, SECURITY_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION +) +PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION = _simple_path( + SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION +) +PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION = _simple_path( + SAGEMAKER, TRAINING_JOB, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION +) + + class LogState(object): """Placeholder docstring""" @@ -147,6 +345,9 @@ def __init__( Client which makes SageMaker Metrics related calls to Amazon SageMaker (default: None). If not provided, one will be created using this instance's ``boto_session``. + sagemaker_config (sagemaker.config.SageMakerConfig): The SageMakerConfig object which + holds the default values for the SageMaker Python SDK. (default: None). If not + provided, This class will create its own SageMakerConfig object. 
""" self._default_bucket = None self._default_bucket_name_override = default_bucket @@ -214,11 +415,14 @@ def _initialize( prepend_user_agent(self.sagemaker_metrics_client) self.local_mode = False - if sagemaker_config: self.sagemaker_config = sagemaker_config else: - self.sagemaker_config = SageMakerConfig(s3_resource=self.boto_session.resource("s3")) + if self.s3_resource is None: + s3 = self.boto_session.resource("s3", region_name=self.boto_region_name) + else: + s3 = self.s3_resource + self.sagemaker_config = SageMakerConfig(s3_resource=s3) @property def boto_region_name(self): @@ -435,6 +639,27 @@ def default_bucket(self): return self._default_bucket + def get_sagemaker_config_override(self, key, default_value=None): + """Util method that fetches a particular key path in the SageMakerConfig and returns it. + + If a default value is provided, then this method will return the default value. + + Args: + key: Key Path of the config file entry. + default_value: The existing value that was passed as method parameter. If this is not + None, then the method will return this value + + Returns: + object: The corresponding value in the Config file/ the default value. + + """ + if default_value is not None: + return default_value + config_value = get_config_value(key, self.sagemaker_config.config) + if config_value is not None: + self._print_message_sagemaker_config_used(config_value, key) + return config_value + def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): """Creates an S3 Bucket if it does not exist. 
@@ -503,36 +728,36 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): else: raise - def _get_sagemaker_config_value(self, config_path: str): - """returns the value of the config at the path provided""" - return get_config_value(config_path, self.sagemaker_config.config) - def _print_message_sagemaker_config_used(self, config_value, config_path): """Informs the SDK user that a config value was substituted in automatically""" print( "[Sagemaker Config] config value {} at config path {}".format( config_value, config_path ), - "was automatically applied" + "was automatically applied", ) def _print_message_sagemaker_config_present_but_not_used( self, direct_input, config_value, config_path ): - """Informs the SDK user that a config value was not substituted in automatically despite - existing""" + """Informs the SDK user that a config value was not substituted in automatically. + + This is because method parameter is already provided. + + """ print( "[Sagemaker Config] value {} was specified,".format(direct_input), "so config value {} at config path {} was not applied".format( config_value, config_path - ) + ), ) def resolve_value_from_config( self, direct_input=None, config_path: str = None, default_value=None ): - """Makes a decision of which value is the right value for the caller to use while - incorporating info from the sagemaker config. + """Makes a decision of which value is the right value for the caller to use. + + Note: This method also incorporates info from the sagemaker config. 
Uses this order of prioritization: (1) direct_input, (2) config value, (3) default_value, (4) None @@ -547,7 +772,7 @@ def resolve_value_from_config( Returns: The value that should be used by the caller """ - config_value = self._get_sagemaker_config_value(config_path) + config_value = self.get_sagemaker_config_override(config_path) if direct_input is not None: if config_value is not None: @@ -567,7 +792,9 @@ def resolve_value_from_config( def resolve_class_attribute_from_config( self, clazz, instance, attribute: str, config_path: str, default_value=None ): - """Takes an instance of a class and, if not already set, sets the instance's attribute to a + """Utility method that merges config values to data classes. + + Takes an instance of a class and, if not already set, sets the instance's attribute to a value fetched from the sagemaker_config or the default_value. Uses this order of prioritization to determine what the value of the attribute should be: @@ -589,7 +816,7 @@ def resolve_class_attribute_from_config( The updated class instance that should be used by the caller instead of the 'instance' parameter that was passed in. """ - config_value = self._get_sagemaker_config_value(config_path) + config_value = self.get_sagemaker_config_override(config_path) if config_value is None and default_value is None: # return instance unmodified. Could be None or populated @@ -627,12 +854,14 @@ def resolve_class_attribute_from_config( def resolve_nested_dict_value_from_config( self, dictionary: dict, - nested_keys: list[str], + nested_keys: List[str], config_path: str, default_value: object = None, ): - """Takes a dictionary and, if not already set, sets the value for the provided list of - nested keys to the value fetched from the sagemaker_config or the default_value. + """Utility method that sets the value of a key path in a nested dictionary . 
+ + This method takes a dictionary and, if not already set, sets the value for the provided + list of nested keys to the value fetched from the sagemaker_config or the default_value. Uses this order of prioritization to determine what the value of the attribute should be: (1) current value of nested key, (2) config value, (3) default_value, (4) does not set it @@ -648,7 +877,7 @@ def resolve_nested_dict_value_from_config( The updated dictionary that should be used by the caller instead of the 'dictionary' parameter that was passed in. """ - config_value = self._get_sagemaker_config_value(config_path) + config_value = self.get_sagemaker_config_override(config_path) if config_value is None and default_value is None: # if there is nothing to set, return early. And there is no need to traverse through @@ -689,7 +918,7 @@ def _append_sagemaker_config_tags(self, tags: list, config_path_to_tags: str): Returns: A potentially extended list of tags. """ - config_tags = self._get_sagemaker_config_value(config_path_to_tags) + config_tags = self.get_sagemaker_config_override(config_path_to_tags) if config_tags is None or len(config_tags) == 0: return tags @@ -717,16 +946,16 @@ def train( # noqa: C901 self, input_mode, input_config, - role, - job_name, - output_config, - resource_config, - vpc_config, - hyperparameters, - stop_condition, - tags, - metric_definitions, - enable_network_isolation=False, + role=None, + job_name=None, + output_config=None, + resource_config=None, + vpc_config=None, + hyperparameters=None, + stop_condition=None, + tags=None, + metric_definitions=None, + enable_network_isolation=None, image_uri=None, training_image_config=None, algorithm_arn=None, @@ -864,15 +1093,29 @@ def train( # noqa: C901 config_path=PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION, default_value=False, ) - + role = self.get_sagemaker_config_override(TRAINING_JOB_ROLE_ARN_PATH, role) + enable_network_isolation = self.resolve_value_from_config( + 
direct_input=enable_network_isolation, + config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + default_value=False, + ) + inferred_vpc_config = self._update_nested_dictionary_with_values_from_config( + vpc_config, TRAINING_JOB_VPC_CONFIG_PATH + ) + inferred_output_config = self._update_nested_dictionary_with_values_from_config( + output_config, TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH + ) + inferred_resource_config = self._update_nested_dictionary_with_values_from_config( + resource_config, TRAINING_JOB_RESOURCE_CONFIG_PATH + ) train_request = self._get_train_request( input_mode=input_mode, input_config=input_config, role=role, job_name=job_name, - output_config=output_config, - resource_config=resource_config, - vpc_config=vpc_config, + output_config=inferred_output_config, + resource_config=inferred_resource_config, + vpc_config=inferred_vpc_config, hyperparameters=hyperparameters, stop_condition=stop_condition, tags=tags, @@ -1198,6 +1441,49 @@ def _get_update_training_job_request( return update_training_job_request + def update_processing_input_from_config(self, inputs): + """Updates Processor Inputs to fetch values from SageMakerConfig wherever applicable. + + Args: + inputs (list[dict]): A list of Processing Input objects. 
+ + """ + processing_inputs_from_config = ( + self.get_sagemaker_config_override(PROCESSING_JOB_INPUTS_PATH) or [] + ) + for i in range(min(len(inputs), len(processing_inputs_from_config))): + processing_input_from_config = processing_inputs_from_config[i] + if "DatasetDefinition" in inputs[i]: + dataset_definition = inputs[i]["DatasetDefinition"] + if "AthenaDatasetDefinition" in dataset_definition: + athena_dataset_definition = dataset_definition["AthenaDatasetDefinition"] + if KMS_KEY_ID not in athena_dataset_definition: + athena_kms_key_id_from_config = get_config_value( + ATHENA_DATASET_DEFINITION_KMS_KEY_ID_PATH, processing_input_from_config + ) + if athena_kms_key_id_from_config: + athena_dataset_definition[KMS_KEY_ID] = athena_kms_key_id_from_config + if "RedshiftDatasetDefinition" in dataset_definition: + redshift_dataset_definition = dataset_definition["RedshiftDatasetDefinition"] + if CLUSTER_ROLE_ARN not in redshift_dataset_definition: + redshift_role_arn_from_config = get_config_value( + REDSHIFT_DATASET_DEFINITION_CLUSTER_ROLE_ARN_PATH, + processing_input_from_config, + ) + if redshift_role_arn_from_config: + redshift_dataset_definition[ + CLUSTER_ROLE_ARN + ] = redshift_role_arn_from_config + if not redshift_dataset_definition.kms_key_id: + redshift_kms_key_id_from_config = get_config_value( + REDSHIFT_DATASET_DEFINITION_KMS_KEY_ID_PATH, + processing_input_from_config, + ) + if redshift_kms_key_id_from_config: + redshift_dataset_definition[ + KMS_KEY_ID + ] = redshift_kms_key_id_from_config + def process( self, inputs, @@ -1207,9 +1493,9 @@ def process( stopping_condition, app_specification, environment, - network_config, - role_arn, - tags, + network_config=None, + role_arn=None, + tags=None, experiment_config=None, ): """Create an Amazon SageMaker processing job. 
@@ -1262,15 +1548,30 @@ def process( PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION, ) + self.update_processing_input_from_config(inputs) + role_arn = self.get_sagemaker_config_override( + PROCESSING_JOB_ROLE_ARN_PATH, default_value=role_arn + ) + inferred_network_config_from_config = ( + self._update_nested_dictionary_with_values_from_config( + network_config, PROCESSING_JOB_NETWORK_CONFIG_PATH + ) + ) + inferred_output_config = self._update_nested_dictionary_with_values_from_config( + output_config, PROCESSING_OUTPUT_CONFIG_PATH + ) + inferred_resources_config = self._update_nested_dictionary_with_values_from_config( + resources, PROCESSING_JOB_PROCESSING_RESOURCES_PATH + ) process_request = self._get_process_request( inputs=inputs, - output_config=output_config, + output_config=inferred_output_config, job_name=job_name, - resources=resources, + resources=inferred_resources_config, stopping_condition=stopping_condition, app_specification=app_specification, environment=environment, - network_config=network_config, + network_config=inferred_network_config_from_config, role_arn=role_arn, tags=tags, experiment_config=experiment_config, @@ -1380,17 +1681,17 @@ def create_monitoring_schedule( instance_count, instance_type, volume_size_in_gb, - volume_kms_key, - image_uri, - entrypoint, - arguments, - record_preprocessor_source_uri, - post_analytics_processor_source_uri, - max_runtime_in_seconds, - environment, - network_config, - role_arn, - tags, + volume_kms_key=None, + image_uri=None, + entrypoint=None, + arguments=None, + record_preprocessor_source_uri=None, + post_analytics_processor_source_uri=None, + max_runtime_in_seconds=None, + environment=None, + network_config=None, + role_arn=None, + tags=None, ): """Create an Amazon SageMaker monitoring schedule. @@ -1430,6 +1731,17 @@ def create_monitoring_schedule( tags ([dict[str,str]]): A list of dictionaries containing key-value pairs. 
""" + role_arn = self.get_sagemaker_config_override( + MONITORING_JOB_ROLE_ARN_PATH, default_value=role_arn + ) + volume_kms_key = self.get_sagemaker_config_override( + MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key + ) + inferred_network_config_from_config = ( + self._update_nested_dictionary_with_values_from_config( + network_config, MONITORING_JOB_NETWORK_CONFIG_PATH + ) + ) monitoring_schedule_request = { "MonitoringScheduleName": monitoring_schedule_name, "MonitoringScheduleConfig": { @@ -1454,6 +1766,11 @@ def create_monitoring_schedule( } if monitoring_output_config is not None: + kms_key_from_config = self.get_sagemaker_config_override( + MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH + ) + if KMS_KEY_ID not in monitoring_output_config and kms_key_from_config: + monitoring_output_config[KMS_KEY_ID] = kms_key_from_config monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringOutputConfig" ] = monitoring_output_config @@ -1508,15 +1825,10 @@ def create_monitoring_schedule( "Environment" ] = environment - network_config = self.resolve_nested_dict_value_from_config( - network_config, - ["EnableInterContainerTrafficEncryption"], - PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION, - ) - if network_config is not None: + if inferred_network_config_from_config is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "NetworkConfig" - ] = network_config + ] = inferred_network_config_from_config tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( @@ -2089,8 +2401,8 @@ def auto_ml( input_config, output_config, auto_ml_job_config, - role, - job_name, + role=None, + job_name=None, problem_type=None, job_objective=None, generate_candidate_definitions_only=False, @@ -2124,16 +2436,17 @@ def auto_ml( Contains "AutoGenerateEndpointName" and "EndpointName" """ - auto_ml_job_config = self.resolve_nested_dict_value_from_config( - auto_ml_job_config, - 
["SecurityConfig", "EnableInterContainerTrafficEncryption"], - PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION, + role = self.get_sagemaker_config_override(AUTO_ML_ROLE_ARN_PATH, default_value=role) + inferred_output_config = self._update_nested_dictionary_with_values_from_config( + output_config, AUTO_ML_OUTPUT_CONFIG_PATH + ) + inferred_automl_job_config = self._update_nested_dictionary_with_values_from_config( + auto_ml_job_config, AUTO_ML_JOB_CONFIG_PATH ) - auto_ml_job_request = self._get_auto_ml_request( input_config=input_config, - output_config=output_config, - auto_ml_job_config=auto_ml_job_config, + output_config=inferred_output_config, + auto_ml_job_config=inferred_automl_job_config, role=role, job_name=job_name, problem_type=problem_type, @@ -2375,7 +2688,13 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m print() def compile_model( - self, input_model_config, output_model_config, role, job_name, stop_condition, tags + self, + input_model_config, + output_model_config, + role=None, + job_name=None, + stop_condition=None, + tags=None, ): """Create an Amazon SageMaker Neo compilation job. @@ -2396,13 +2715,20 @@ def compile_model( Returns: str: ARN of the compile model job, if it is created. 
""" + role = self.get_sagemaker_config_override(COMPILATION_JOB_ROLE_ARN_PATH, default_value=role) + inferred_output_model_config = self._update_nested_dictionary_with_values_from_config( + output_model_config, COMPILATION_JOB_OUTPUT_CONFIG_PATH + ) + vpc_config = self.get_sagemaker_config_override(COMPILATION_JOB_VPC_CONFIG_PATH) compilation_job_request = { "InputConfig": input_model_config, - "OutputConfig": output_model_config, + "OutputConfig": inferred_output_model_config, "RoleArn": role, "StoppingCondition": stop_condition, "CompilationJobName": job_name, } + if vpc_config: + compilation_job_request["VpcConfig"] = vpc_config tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( @@ -2417,13 +2743,13 @@ def compile_model( def package_model_for_edge( self, output_model_config, - role, - job_name, - compilation_job_name, - model_name, - model_version, - resource_key, - tags, + role=None, + job_name=None, + compilation_job_name=None, + model_name=None, + model_version=None, + resource_key=None, + tags=None, ): """Create an Amazon SageMaker Edge packaging job. @@ -2439,15 +2765,19 @@ def package_model_for_edge( tags (list[dict]): List of tags for labeling a compile model job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. 
""" + role = self.get_sagemaker_config_override(EDGE_PACKAGING_ROLE_ARN_PATH, role) + inferred_output_model_config = self._update_nested_dictionary_with_values_from_config( + output_model_config, EDGE_PACKAGING_OUTPUT_CONFIG_PATH + ) edge_packaging_job_request = { - "OutputConfig": output_model_config, + "OutputConfig": inferred_output_model_config, "RoleArn": role, "ModelName": model_name, "ModelVersion": model_version, "EdgePackagingJobName": job_name, "CompilationJobName": compilation_job_name, } - + tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, EDGE_PACKAGING_JOB, TAGS) ) @@ -3149,8 +3479,8 @@ def transform( output_config, resource_config, experiment_config, - tags, - data_processing, + tags=None, + data_processing=None, model_client_config=None, batch_data_capture_config: BatchDataCaptureConfig = None, ): @@ -3195,6 +3525,23 @@ def transform( tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, TRANSFORM_JOB, TAGS) ) + if batch_data_capture_config: + if not batch_data_capture_config.kms_key_id: + batch_data_capture_config.kms_key_id = self.get_sagemaker_config_override( + TRANSFORM_JOB_KMS_KEY_ID_PATH + ) + if output_config: + kms_key_from_config = self.get_sagemaker_config_override( + TRANSFORM_OUTPUT_KMS_KEY_ID_PATH + ) + if KMS_KEY_ID not in output_config and kms_key_from_config: + output_config[KMS_KEY_ID] = kms_key_from_config + if resource_config: + volume_kms_key_from_config = self.get_sagemaker_config_override( + TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH + ) + if VOLUME_KMS_KEY_ID not in resource_config and volume_kms_key_from_config: + resource_config[VOLUME_KMS_KEY_ID] = volume_kms_key_from_config transform_request = self._get_transform_request( job_name=job_name, @@ -3273,10 +3620,10 @@ def _create_model_request( def create_model( self, name, - role, - container_defs, + role=None, + container_defs=None, vpc_config=None, - enable_network_isolation=False, + 
enable_network_isolation=None, primary_container=None, tags=None, ): @@ -3322,7 +3669,15 @@ def create_model( """ tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags(tags, "{}.{}.{}".format(SAGEMAKER, MODEL, TAGS)) - + role = self.get_sagemaker_config_override(MODEL_EXECUTION_ROLE_ARN_PATH, default_value=role) + vpc_config = self.get_sagemaker_config_override( + MODEL_VPC_CONFIG_PATH, default_value=vpc_config + ) + enable_network_isolation = self.resolve_value_from_config( + direct_input=enable_network_isolation, + config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH, + default_value=False, + ) create_model_request = self._create_model_request( name=name, role=role, @@ -3360,7 +3715,7 @@ def create_model_from_job( image_uri=None, model_data_url=None, env=None, - enable_network_isolation=False, + enable_network_isolation=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, tags=None, ): @@ -3395,6 +3750,14 @@ def create_model_from_job( ) name = name or training_job_name role = role or training_job["RoleArn"] + role = self.get_sagemaker_config_override( + MODEL_EXECUTION_ROLE_ARN_PATH, default_value=role or training_job["RoleArn"] + ) + enable_network_isolation = self.resolve_value_from_config( + direct_input=enable_network_isolation, + config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH, + default_value=False, + ) env = env or {} primary_container = container_def( image_uri or training_job["AlgorithmSpecification"]["TrainingImage"], @@ -3402,6 +3765,9 @@ def create_model_from_job( env=env, ) vpc_config = _vpc_config_from_training_job(training_job, vpc_config_override) + vpc_config = self.get_sagemaker_config_override( + MODEL_VPC_CONFIG_PATH, default_value=vpc_config + ) return self.create_model( name, role, @@ -3494,7 +3860,53 @@ def create_model_package_from_containers( "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). 
""" - + validation_role_from_config = self.get_sagemaker_config_override( + MODEL_PACKAGE_VALIDATION_ROLE_PATH + ) + validation_profiles_from_config = self.get_sagemaker_config_override( + MODEL_PACKAGE_VALIDATION_PROFILES_PATH + ) + if validation_role_from_config or validation_profiles_from_config: + if not validation_specification: + validation_specification = { + VALIDATION_ROLE: validation_role_from_config, + VALIDATION_PROFILES: validation_profiles_from_config, + } + else: + # ValidationSpecification is provided as method parameter + # Now we need to carefully merge + if VALIDATION_ROLE not in validation_specification: + # if Validation role is not provided as part of the dict, merge it + validation_specification[VALIDATION_ROLE] = validation_role_from_config + if VALIDATION_PROFILES not in validation_specification: + # if Validation profile is not provided as part of the dict, merge it + validation_specification[VALIDATION_PROFILES] = validation_profiles_from_config + elif validation_profiles_from_config: + # Validation profiles are provided in the config as well as parameter. 
+ validation_profiles = validation_specification[VALIDATION_PROFILES] + for i in range( + min(len(validation_profiles), len(validation_profiles_from_config)) + ): + # Now we need to merge corresponding entries which are not provided in the + # dict , but are present in the config + validation_profile = validation_profiles[i] + validation_profile_from_config = validation_profiles_from_config[i] + original_config_dict_value = validation_profile_from_config.copy() + # Apply the default configurations on top of the config entries + merge_dicts(validation_profile_from_config, validation_profile) + if validation_profile != validation_profile_from_config: + print( + "Config value {} at config path {} was fetched first for " + "index {}.".format( + original_config_dict_value, + MODEL_PACKAGE_VALIDATION_PROFILES_PATH, + i, + ), + "It was then merged with the existing value {} to give {}".format( + validation_profile, validation_profile_from_config + ), + ) + validation_profile.update(validation_profile_from_config) model_pkg_request = get_create_model_package_request( model_package_name, model_package_group_name, @@ -3637,20 +4049,39 @@ def create_endpoint_config( LOGGER.info("Creating endpoint-config with name %s", name) tags = tags or [] + provided_production_variant = production_variant( + model_name, + instance_type, + initial_instance_count, + accelerator_type=accelerator_type, + volume_size=volume_size, + model_data_download_timeout=model_data_download_timeout, + container_startup_health_check_timeout=container_startup_health_check_timeout, + ) + inferred_production_variants_from_config = ( + self.get_sagemaker_config_override(ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH) or [] + ) + if inferred_production_variants_from_config: + inferred_production_variant_from_config = ( + inferred_production_variants_from_config[0] or {} + ) + original_config_dict_value = inferred_production_variant_from_config.copy() + merge_dicts(inferred_production_variant_from_config, 
provided_production_variant) + if provided_production_variant != inferred_production_variant_from_config: + print( + "Config value {} at config path {} was fetched first for " + "index: 0.".format( + original_config_dict_value, ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH + ), + "It was then merged with the existing value {} to give {}".format( + provided_production_variant, inferred_production_variant_from_config + ), + ) + provided_production_variant.update(inferred_production_variant_from_config) request = { "EndpointConfigName": name, - "ProductionVariants": [ - production_variant( - model_name, - instance_type, - initial_instance_count, - accelerator_type=accelerator_type, - volume_size=volume_size, - model_data_download_timeout=model_data_download_timeout, - container_startup_health_check_timeout=container_startup_health_check_timeout, - ) - ], + "ProductionVariants": [provided_production_variant], } tags = _append_project_tags(tags) @@ -3659,12 +4090,19 @@ def create_endpoint_config( ) if tags is not None: request["Tags"] = tags - + kms_key = self.get_sagemaker_config_override( + ENDPOINT_CONFIG_KMS_KEY_ID_PATH, default_value=kms_key + ) if kms_key is not None: request["KmsKeyId"] = kms_key if data_capture_config_dict is not None: - request["DataCaptureConfig"] = data_capture_config_dict + inferred_data_capture_config_dict = ( + self._update_nested_dictionary_with_values_from_config( + data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH + ) + ) + request["DataCaptureConfig"] = inferred_data_capture_config_dict self.sagemaker_client.create_endpoint_config(**request) return name @@ -3718,6 +4156,27 @@ def create_endpoint_config_from_existing( "EndpointConfigName": new_config_name, } + if new_production_variants: + inferred_production_variants_from_config = ( + self.get_sagemaker_config_override(ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH) or [] + ) + for i in range( + min(len(new_production_variants), len(inferred_production_variants_from_config)) + ): + 
original_config_dict_value = inferred_production_variants_from_config[i].copy() + merge_dicts(inferred_production_variants_from_config[i], new_production_variants[i]) + if new_production_variants[i] != inferred_production_variants_from_config[i]: + print( + "Config value {} at config path {} was fetched first for " + "index: 0.".format( + original_config_dict_value, ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH + ), + "It was then merged with the existing value {} to give {}".format( + new_production_variants[i], inferred_production_variants_from_config[i] + ), + ) + new_production_variants[i].update(inferred_production_variants_from_config[i]) + request["ProductionVariants"] = ( new_production_variants or existing_endpoint_config_desc["ProductionVariants"] ) @@ -3734,18 +4193,35 @@ def create_endpoint_config_from_existing( if new_kms_key is not None or existing_endpoint_config_desc.get("KmsKeyId") is not None: request["KmsKeyId"] = new_kms_key or existing_endpoint_config_desc.get("KmsKeyId") + if KMS_KEY_ID not in request: + kms_key_from_config = self.get_sagemaker_config_override( + ENDPOINT_CONFIG_KMS_KEY_ID_PATH + ) + if kms_key_from_config: + request[KMS_KEY_ID] = kms_key_from_config request_data_capture_config_dict = ( new_data_capture_config_dict or existing_endpoint_config_desc.get("DataCaptureConfig") ) if request_data_capture_config_dict is not None: - request["DataCaptureConfig"] = request_data_capture_config_dict + inferred_data_capture_config_dict = ( + self._update_nested_dictionary_with_values_from_config( + request_data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH + ) + ) + request["DataCaptureConfig"] = inferred_data_capture_config_dict if existing_endpoint_config_desc.get("AsyncInferenceConfig") is not None: - request["AsyncInferenceConfig"] = existing_endpoint_config_desc.get( + async_inference_config_dict = existing_endpoint_config_desc.get( "AsyncInferenceConfig", None ) + inferred_async_inference_config_dict = ( + 
self._update_nested_dictionary_with_values_from_config( + async_inference_config_dict, ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH + ) + ) + request["AsyncInferenceConfig"] = inferred_async_inference_config_dict self.sagemaker_client.create_endpoint_config(**request) @@ -4284,6 +4760,9 @@ def endpoint_from_production_variants( str: The name of the created ``Endpoint``. """ config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants} + kms_key = self.get_sagemaker_config_override( + ENDPOINT_CONFIG_KMS_KEY_ID_PATH, default_value=kms_key + ) tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS) @@ -4293,9 +4772,19 @@ def endpoint_from_production_variants( if kms_key: config_options["KmsKeyId"] = kms_key if data_capture_config_dict is not None: - config_options["DataCaptureConfig"] = data_capture_config_dict + inferred_data_capture_config_dict = ( + self._update_nested_dictionary_with_values_from_config( + data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH + ) + ) + config_options["DataCaptureConfig"] = inferred_data_capture_config_dict if async_inference_config_dict is not None: - config_options["AsyncInferenceConfig"] = async_inference_config_dict + inferred_async_inference_config_dict = ( + self._update_nested_dictionary_with_values_from_config( + async_inference_config_dict, ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH + ) + ) + config_options["AsyncInferenceConfig"] = inferred_async_inference_config_dict LOGGER.info("Creating endpoint-config with name %s", name) self.sagemaker_client.create_endpoint_config(**config_options) @@ -4705,7 +5194,7 @@ def create_feature_group( record_identifier_name: str, event_time_feature_name: str, feature_definitions: Sequence[Dict[str, str]], - role_arn: str, + role_arn: str = None, online_store_config: Dict[str, str] = None, offline_store_config: Dict[str, str] = None, description: str = None, @@ -4733,7 +5222,19 @@ def 
create_feature_group( tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, FEATURE_GROUP, TAGS) ) - + role_arn = self.get_sagemaker_config_override( + FEATURE_GROUP_ROLE_ARN_PATH, default_value=role_arn + ) + inferred_online_store_from_config = self._update_nested_dictionary_with_values_from_config( + online_store_config, FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH + ) + if len(inferred_online_store_from_config) > 0: + # OnlineStore should be handled differently because if you set KmsKeyId, then you + # need to set EnableOnlineStore key as well + inferred_online_store_from_config["EnableOnlineStore"] = True + inferred_offline_store_from_config = self._update_nested_dictionary_with_values_from_config( + offline_store_config, FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH + ) kwargs = dict( FeatureGroupName=feature_group_name, RecordIdentifierFeatureName=record_identifier_name, @@ -4743,8 +5244,8 @@ def create_feature_group( ) update_args( kwargs, - OnlineStoreConfig=online_store_config, - OfflineStoreConfig=offline_store_config, + OnlineStoreConfig=inferred_online_store_from_config, + OfflineStoreConfig=inferred_offline_store_from_config, Description=description, Tags=tags, ) @@ -5074,6 +5575,40 @@ def wait_for_athena_query(self, query_execution_id: str, poll: int = 5): else: LOGGER.error("Failed to execute query %s.", query_execution_id) + def _update_nested_dictionary_with_values_from_config( + self, source_dict, config_key_path + ) -> dict: + """Updates a given nested dictionary with missing values which are present in Config. + + Args: + source_dict: The input nested dictionary that was provided as method parameter. + config_key_path: The Key Path in the Config file which corresponds to this + source_dict parameter. + + Returns: + dict: The merged nested dictionary which includes missings values that are present + in the Config file. 
+ + """ + inferred_config_dict = self.get_sagemaker_config_override(config_key_path) or {} + original_config_dict_value = inferred_config_dict.copy() + merge_dicts(inferred_config_dict, source_dict or {}) + if source_dict == inferred_config_dict: + # Corresponds to the case where we didn't use any values from Config. + self._print_message_sagemaker_config_present_but_not_used( + source_dict, original_config_dict_value, config_key_path + ) + else: + print( + "Config value {} at config path {} was fetched first.".format( + original_config_dict_value, config_key_path + ), + "It was then merged with the existing value {} to give {}".format( + source_dict, inferred_config_dict + ), + ) + return inferred_config_dict + def download_athena_query_result( self, bucket: str, diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 1aead6b51d..46425c0660 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -78,8 +78,8 @@ class SKLearnModel(FrameworkModel): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, - entry_point: str, + role: Optional[str] = None, + entry_point: Optional[str] = None, framework_version: Optional[str] = None, py_version: str = "py3", image_uri: Optional[Union[str, PipelineVariable]] = None, diff --git a/src/sagemaker/sklearn/processing.py b/src/sagemaker/sklearn/processing.py index 93896eaf50..86d0df9113 100644 --- a/src/sagemaker/sklearn/processing.py +++ b/src/sagemaker/sklearn/processing.py @@ -32,9 +32,9 @@ class SKLearnProcessor(ScriptProcessor): def __init__( self, framework_version: str, # New arg - role: str, - instance_count: Union[int, PipelineVariable], - instance_type: Union[str, PipelineVariable], + role: Optional[Union[str, PipelineVariable]] = None, + instance_count: Union[int, PipelineVariable] = None, + instance_type: Union[str, PipelineVariable] = None, command: Optional[List[str]] = None, volume_size_in_gb: Union[int, PipelineVariable] = 30, 
volume_kms_key: Optional[Union[str, PipelineVariable]] = None, diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py index afb3e04599..c35d97a588 100644 --- a/src/sagemaker/spark/processing.py +++ b/src/sagemaker/spark/processing.py @@ -94,9 +94,9 @@ class _SparkProcessorBase(ScriptProcessor): def __init__( self, - role, - instance_type, - instance_count, + role=None, + instance_type=None, + instance_count=None, framework_version=None, py_version=None, container_version=None, diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index dfcec76c15..b0eaf753fb 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -131,7 +131,7 @@ class TensorFlowModel(sagemaker.model.FrameworkModel): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, + role: str = None, entry_point: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, framework_version: Optional[str] = None, diff --git a/src/sagemaker/tensorflow/processing.py b/src/sagemaker/tensorflow/processing.py index af5ac68b9d..e4495a39fd 100644 --- a/src/sagemaker/tensorflow/processing.py +++ b/src/sagemaker/tensorflow/processing.py @@ -34,9 +34,9 @@ class TensorFlowProcessor(FrameworkProcessor): def __init__( self, framework_version: str, # New arg - role: str, - instance_count: Union[int, PipelineVariable], - instance_type: Union[str, PipelineVariable], + role: Optional[Union[str, PipelineVariable]] = None, + instance_count: Union[int, PipelineVariable] = None, + instance_type: Union[str, PipelineVariable] = None, py_version: str = "py3", # New kwarg image_uri: Optional[Union[str, PipelineVariable]] = None, command: Optional[List[str]] = None, diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 40ed143ebc..41c158e935 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -20,7 +20,14 @@ from botocore import exceptions from 
sagemaker.job import _Job -from sagemaker.session import Session, get_execution_role +from sagemaker.session import ( + Session, + get_execution_role, + TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, + TRANSFORM_JOB_KMS_KEY_ID_PATH, + TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH, + PIPELINE_ROLE_ARN_PATH, +) from sagemaker.inputs import BatchDataCaptureConfig from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.functions import Join @@ -103,13 +110,11 @@ def __init__( self.env = env self.output_path = output_path - self.output_kms_key = output_kms_key self.accept = accept self.assemble_with = assemble_with self.instance_count = instance_count self.instance_type = instance_type - self.volume_kms_key = volume_kms_key self.max_concurrent_transforms = max_concurrent_transforms self.max_payload = max_payload @@ -121,6 +126,24 @@ def __init__( self._reset_output_path = False self.sagemaker_session = sagemaker_session or Session() + self.volume_kms_key = self.sagemaker_session.get_sagemaker_config_override( + TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key + ) + self.output_kms_key = self.sagemaker_session.get_sagemaker_config_override( + TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, default_value=output_kms_key + ) + + def _update_batch_capture_config(self, batch_capture_config: BatchDataCaptureConfig): + """Utility method that updates BatchDataCaptureConfig with values from SageMakerConfig. + + Args: + batch_capture_config: The BatchDataCaptureConfig object. 
+ + """ + if batch_capture_config: + batch_capture_config.kms_key_id = self.sagemaker_session.get_sagemaker_config_override( + TRANSFORM_JOB_KMS_KEY_ID_PATH, default_value=batch_capture_config.kms_key_id + ) @runnable_by_pipeline def transform( @@ -256,6 +279,7 @@ def transform( self._reset_output_path = True experiment_config = check_and_get_run_experiment_config(experiment_config) + self._update_batch_capture_config(batch_data_capture_config) self.latest_transform_job = _TransformJob.start_new( self, data, @@ -377,7 +401,7 @@ def transform_with_monitoring( transformer = copy.deepcopy(self) transformer.sagemaker_session = PipelineSession() self.sagemaker_session = sagemaker_session - + self._update_batch_capture_config(batch_data_capture_config) transform_step_args = transformer.transform( data=data, data_type=data_type, @@ -416,7 +440,14 @@ def transform_with_monitoring( steps=[monitoring_batch_step], sagemaker_session=transformer.sagemaker_session, ) - pipeline.upsert(role_arn=role if role else get_execution_role()) + pipeline_role_arn = ( + role + if role + else transformer.sagemaker_session.get_sagemaker_config_override( + PIPELINE_ROLE_ARN_PATH, default_value=get_execution_role() + ) + ) + pipeline.upsert(pipeline_role_arn) execution = pipeline.start() if wait: logging.info("Waiting for transform with monitoring to execute ...") diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 9977de0add..762fbf0d3e 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -27,7 +27,7 @@ import abc import uuid from datetime import datetime -from typing import Optional +from typing import List, Optional from importlib import import_module import botocore @@ -176,7 +176,7 @@ def get_config_value(key_path, config): return current_section -def get_nested_value(dictionary: dict, nested_keys: list[str]): +def get_nested_value(dictionary: dict, nested_keys: List[str]): """Returns a nested value from the given dictionary, and None if none present. 
Raises @@ -209,8 +209,10 @@ def get_nested_value(dictionary: dict, nested_keys: list[str]): return None -def set_nested_value(dictionary: dict, nested_keys: list[str], value_to_set: object): - """Sets a nested value inside the given dictionary and returns the new dictionary. Note: if +def set_nested_value(dictionary: dict, nested_keys: List[str], value_to_set: object): + """Sets a nested value in a dictionary. + + This sets a nested value inside the given dictionary and returns the new dictionary. Note: if provided an unintended list of nested keys, this can overwrite an unexpected part of the dict. Recommended to use after a check with get_nested_value first """ diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 82d89e8cc0..6cef6dbff2 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -25,12 +25,7 @@ from sagemaker import s3 from sagemaker._studio import _append_project_tags -from sagemaker.config.config_schema import ( - SAGEMAKER, - PIPELINE, - TAGS, -) -from sagemaker.session import Session +from sagemaker.session import Session, PIPELINE_ROLE_ARN_PATH, PIPELINE_TAGS_PATH from sagemaker.utils import retry_with_backoff from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep from sagemaker.workflow.lambda_step import LambdaOutput, LambdaStep @@ -113,7 +108,7 @@ def to_request(self) -> RequestType: def create( self, - role_arn: str, + role_arn: str = None, description: str = None, tags: List[Dict[str, str]] = None, parallelism_config: ParallelismConfiguration = None, @@ -132,14 +127,21 @@ def create( Returns: A response dict from the service. """ + role_arn = self.sagemaker_session.get_sagemaker_config_override( + PIPELINE_ROLE_ARN_PATH, role_arn + ) + if not role_arn: + # Originally IAM role was a required parameter. 
+ # Now we marked that as Optional because we can fetch it from SageMakerConfig + # Because of marking that parameter as optional, we should validate if it is None, even + # after fetching the config. + raise ValueError("IAM role should be provided for creating Pipeline.") if self.sagemaker_session.local_mode: if parallelism_config: logger.warning("Pipeline parallelism config is not supported in the local mode.") return self.sagemaker_session.sagemaker_client.create_pipeline(self, description) tags = _append_project_tags(tags) - tags = self.sagemaker_session._append_sagemaker_config_tags( - tags, "{}.{}.{}".format(SAGEMAKER, PIPELINE, TAGS) - ) + tags = self.sagemaker_session._append_sagemaker_config_tags(tags, PIPELINE_TAGS_PATH) kwargs = self._create_args(role_arn, description, parallelism_config) update_args( kwargs, @@ -203,7 +205,7 @@ def describe(self) -> Dict[str, Any]: def update( self, - role_arn: str, + role_arn: str = None, description: str = None, parallelism_config: ParallelismConfiguration = None, ) -> Dict[str, Any]: @@ -219,6 +221,15 @@ def update( Returns: A response dict from the service. """ + role_arn = self.sagemaker_session.get_sagemaker_config_override( + PIPELINE_ROLE_ARN_PATH, role_arn + ) + if not role_arn: + # Originally IAM role was a required parameter. + # Now we marked that as Optional because we can fetch it from SageMakerConfig + # Because of marking that parameter as optional, we should validate if it is None, even + # after fetching the config. 
+ raise ValueError("IAM role should be provided for updating Pipeline.") if self.sagemaker_session.local_mode: if parallelism_config: logger.warning("Pipeline parallelism config is not supported in the local mode.") @@ -232,7 +243,7 @@ def update( def upsert( self, - role_arn: str, + role_arn: str = None, description: str = None, tags: List[Dict[str, str]] = None, parallelism_config: ParallelismConfiguration = None, @@ -250,6 +261,16 @@ def upsert( Returns: response dict from service """ + role_arn = self.sagemaker_session.get_sagemaker_config_override( + PIPELINE_ROLE_ARN_PATH, role_arn + ) + if not role_arn: + # Originally IAM role was a required parameter. + # Now we marked that as Optional because we can fetch it from SageMakerConfig + # Because of marking that parameter as optional, we should validate if it is None, even + # after fetching the config. + raise ValueError("IAM role should be provided for creating/updating Pipeline.") + exists = True try: response = self.create(role_arn, description, tags, parallelism_config) except ClientError as ce: diff --git a/src/sagemaker/wrangler/processing.py b/src/sagemaker/wrangler/processing.py index b30feb968d..fe38b670a0 100644 --- a/src/sagemaker/wrangler/processing.py +++ b/src/sagemaker/wrangler/processing.py @@ -30,10 +30,10 @@ class DataWranglerProcessor(Processor): def __init__( self, - role: str, - data_wrangler_flow_source: str, - instance_count: int, - instance_type: str, + role: str = None, + data_wrangler_flow_source: str = None, + instance_count: int = None, + instance_type: str = None, volume_size_in_gb: int = 30, volume_kms_key: str = None, output_kms_key: str = None, diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 1916dc0d77..e2be025b62 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -77,9 +77,9 @@ class XGBoostModel(FrameworkModel): def __init__( self, model_data: Union[str, PipelineVariable], - role: str, - entry_point: str, - 
framework_version: str, + role: str = None, + entry_point: str = None, + framework_version: str = None, image_uri: Optional[Union[str, PipelineVariable]] = None, py_version: str = "py3", predictor_cls: callable = XGBoostPredictor, diff --git a/src/sagemaker/xgboost/processing.py b/src/sagemaker/xgboost/processing.py index 41b557b731..d840bfd960 100644 --- a/src/sagemaker/xgboost/processing.py +++ b/src/sagemaker/xgboost/processing.py @@ -34,9 +34,9 @@ class XGBoostProcessor(FrameworkProcessor): def __init__( self, framework_version: str, # New arg - role: str, - instance_count: Union[int, PipelineVariable], - instance_type: Union[str, PipelineVariable], + role: str = None, + instance_count: Union[int, PipelineVariable] = None, + instance_type: Union[str, PipelineVariable] = None, py_version: str = "py3", # New kwarg image_uri: Optional[Union[str, PipelineVariable]] = None, command: Optional[List[str]] = None, diff --git a/tests/integ/test_local_mode.py b/tests/integ/test_local_mode.py index b0ab426ded..d1416299f9 100644 --- a/tests/integ/test_local_mode.py +++ b/tests/integ/test_local_mode.py @@ -67,6 +67,10 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, self.sagemaker_client = LocalSagemakerClient(self) self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config) self.local_mode = True + self.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) class LocalPipelineNoS3Session(LocalPipelineSession): diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 21fe49cc97..916a8645f2 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -57,10 +57,15 @@ def boto_session(client): def sagemaker_session(boto_session, client): # ideally this would mock Session instead of instantiating it # most unit tests do mock the session correctly - return sagemaker.session.Session( + session = sagemaker.session.Session( 
boto_session=boto_session, sagemaker_client=client, sagemaker_runtime_client=client, default_bucket=_DEFAULT_BUCKET, sagemaker_metrics_client=client, ) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) + return session diff --git a/tests/unit/sagemaker/automl/test_auto_ml.py b/tests/unit/sagemaker/automl/test_auto_ml.py index 0e6901a9db..22c9f88249 100644 --- a/tests/unit/sagemaker/automl/test_auto_ml.py +++ b/tests/unit/sagemaker/automl/test_auto_ml.py @@ -17,6 +17,18 @@ import pytest from mock import Mock, patch from sagemaker import AutoML, AutoMLJob, AutoMLInput, CandidateEstimator, PipelineModel +from sagemaker.config import ( + RESOURCE_CONFIG, + AUTO_ML, + AUTO_ML_JOB_CONFIG, + SECURITY_CONFIG, + VOLUME_KMS_KEY_ID, + KMS_KEY_ID, + OUTPUT_DATA_CONFIG, + ROLE_ARN, + SAGEMAKER, +) + from sagemaker.predictor import Predictor from sagemaker.session_settings import SessionSettings from sagemaker.workflow.functions import Join @@ -271,7 +283,6 @@ def sagemaker_session(): ) sms.list_candidates = Mock(name="list_candidates", return_value={"Candidates": []}) sms.sagemaker_client.list_tags = Mock(name="list_tags", return_value=LIST_TAGS_RESULT) - # For the purposes of unit tests, no values should be fetched from sagemaker config sms.resolve_value_from_config = Mock( name="resolve_value_from_config", @@ -279,7 +290,10 @@ def sagemaker_session(): if direct_input is not None else default_value, ) - + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms @@ -294,6 +308,48 @@ def candidate_mock(sagemaker_session): return candidate +def test_auto_ml_without_role_parameter(sagemaker_session): + with pytest.raises(ValueError): + AutoML( + target_attribute_name=TARGET_ATTRIBUTE_NAME, + sagemaker_session=sagemaker_session, + ) + + +def _config_override_mock(key, 
default_value=None): + kms_key_id_path = "{}.{}.{}.{}".format(SAGEMAKER, AUTO_ML, OUTPUT_DATA_CONFIG, KMS_KEY_ID) + volume_kms_key_id_path = "{}.{}.{}.{}.{}".format( + SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VOLUME_KMS_KEY_ID + ) + role_arn_path = "{}.{}.{}".format(SAGEMAKER, AUTO_ML, ROLE_ARN) + vpc_config_path = "{}.{}.{}.{}.{}".format( + SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, "VpcConfig" + ) + if key == role_arn_path: + return "ConfigRoleArn" + elif key == kms_key_id_path: + return "ConfigKmsKeyId" + elif key == volume_kms_key_id_path: + return "ConfigVolumeKmsKeyId" + elif key == vpc_config_path: + return {"SecurityGroupIds": ["sg-config"], "Subnets": ["subnet-config"]} + return default_value + + +def test_framework_initialization_with_defaults(sagemaker_session): + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", side_effect=_config_override_mock + ) + auto_ml = AutoML( + target_attribute_name=TARGET_ATTRIBUTE_NAME, + sagemaker_session=sagemaker_session, + ) + assert auto_ml.role == "ConfigRoleArn" + assert auto_ml.output_kms_key == "ConfigKmsKeyId" + assert auto_ml.volume_kms_key == "ConfigVolumeKmsKeyId" + assert auto_ml.vpc_config == {"SecurityGroupIds": ["sg-config"], "Subnets": ["subnet-config"]} + + def test_auto_ml_default_channel_name(sagemaker_session): auto_ml = AutoML( role=ROLE, @@ -800,6 +856,47 @@ def test_deploy_optional_args(candidate_estimator, sagemaker_session, candidate_ ) +def _config_override_mock_for_candidate_estimator(key, default_value=None): + from sagemaker.config import TRAINING_JOB + + vpc_config_path = "{}.{}.{}".format(SAGEMAKER, TRAINING_JOB, "VpcConfig") + volume_kms_key_id_path = "{}.{}.{}.{}".format( + SAGEMAKER, TRAINING_JOB, RESOURCE_CONFIG, VOLUME_KMS_KEY_ID + ) + if key == volume_kms_key_id_path: + return "ConfigVolumeKmsKeyId" + elif key == vpc_config_path: + return {"SecurityGroupIds": ["sg-config"], "Subnets": ["subnet-config"]} + return 
default_value + + +def test_candidate_estimator_fit_initialization_with_defaults(sagemaker_session): + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=_config_override_mock_for_candidate_estimator, + ) + desc_training_job_response = TRAINING_JOB + del desc_training_job_response["VpcConfig"] + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=desc_training_job_response + ) + candidate_estimator = CandidateEstimator(CANDIDATE_DICT, sagemaker_session=sagemaker_session) + candidate_estimator._check_all_job_finished = Mock( + name="_check_all_job_finished", return_value=True + ) + inputs = DEFAULT_S3_INPUT_DATA + candidate_estimator.fit(inputs) + sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] + assert "train" in sagemaker_call_names + index_of_train = sagemaker_call_names.index("train") + actual_train_args = sagemaker_session.method_calls[index_of_train][2] + assert actual_train_args["vpc_config"] == { + "SecurityGroupIds": ["sg-config"], + "Subnets": ["subnet-config"], + } + assert actual_train_args["resource_config"]["VolumeKmsKeyId"] == "ConfigVolumeKmsKeyId" + + def test_candidate_estimator_get_steps(sagemaker_session): candidate_estimator = CandidateEstimator(CANDIDATE_DICT, sagemaker_session=sagemaker_session) steps = candidate_estimator.get_steps() diff --git a/tests/unit/sagemaker/config/test_config_schema.py b/tests/unit/sagemaker/config/test_config_schema.py index 745ac770d4..b5b0c3bd89 100644 --- a/tests/unit/sagemaker/config/test_config_schema.py +++ b/tests/unit/sagemaker/config/test_config_schema.py @@ -38,7 +38,7 @@ def test_invalid_schema_version(): def test_valid_config_with_all_the_features( - base_config_with_schema, valid_config_with_all_the_scopes + base_config_with_schema, valid_config_with_all_the_scopes ): _validate_config(base_config_with_schema, valid_config_with_all_the_scopes) diff --git 
a/tests/unit/sagemaker/feature_store/test_feature_group.py b/tests/unit/sagemaker/feature_store/test_feature_group.py index f8ab2d2f8d..9ea083cd61 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_group.py +++ b/tests/unit/sagemaker/feature_store/test_feature_group.py @@ -32,6 +32,11 @@ IngestionError, ) from sagemaker.feature_store.inputs import FeatureParameter +from sagemaker.session import ( + FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH, + FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH, + FEATURE_GROUP_ROLE_ARN_PATH, +) class PicklableMock(Mock): @@ -51,7 +56,12 @@ def s3_uri(): @pytest.fixture def sagemaker_session_mock(): - return Mock() + session_mock = Mock() + session_mock.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) + return session_mock @pytest.fixture @@ -87,6 +97,63 @@ def create_table_ddl(): ) +def test_feature_group_create_without_role( + sagemaker_session_mock, feature_group_dummy_definitions, s3_uri +): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + with pytest.raises(ValueError): + feature_group.create( + s3_uri=s3_uri, + record_identifier_name="feature1", + event_time_feature_name="feature2", + enable_online_store=True, + ) + + +def _config_override_mock(key, default_value=None): + if key == FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH: + return "OnlineConfigKmsKeyId" + elif key == FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH: + return "OfflineConfigKmsKeyId" + elif key == FEATURE_GROUP_ROLE_ARN_PATH: + return "ConfigRoleArn" + return default_value + + +def test_feature_store_create_with_config_injection( + sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri +): + sagemaker_session_mock.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", side_effect=_config_override_mock + ) + 
feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + feature_group.create( + s3_uri=s3_uri, + record_identifier_name="feature1", + event_time_feature_name="feature2", + enable_online_store=True, + ) + sagemaker_session_mock.create_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + record_identifier_name="feature1", + event_time_feature_name="feature2", + feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], + role_arn="ConfigRoleArn", + description=None, + tags=None, + online_store_config={ + "EnableOnlineStore": True, + "SecurityConfig": {"KmsKeyId": "OnlineConfigKmsKeyId"}, + }, + offline_store_config={ + "DisableGlueTableCreation": False, + "S3StorageConfig": {"S3Uri": s3_uri, "KmsKeyId": "OfflineConfigKmsKeyId"}, + }, + ) + + def test_feature_store_create( sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri ): diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index 03316250f4..867f79cd22 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -75,7 +75,6 @@ def fixture_sagemaker_session(): session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) - # For the purposes of unit tests, no values should be fetched from sagemaker config session.resolve_value_from_config = Mock( name="resolve_value_from_config", @@ -83,6 +82,11 @@ def fixture_sagemaker_session(): if direct_input is not None else default_value, ) + + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git 
a/tests/unit/sagemaker/huggingface/test_processing.py b/tests/unit/sagemaker/huggingface/test_processing.py index 869139b573..6d420214ea 100644 --- a/tests/unit/sagemaker/huggingface/test_processing.py +++ b/tests/unit/sagemaker/huggingface/test_processing.py @@ -50,12 +50,15 @@ def sagemaker_session(): session_mock.upload_data = Mock(name="upload_data", return_value=MOCKED_S3_URI) session_mock.download_data = Mock(name="download_data") session_mock.expand_role.return_value = ROLE - # For the purposes of unit tests, no values should be fetched from sagemaker config session_mock.resolve_class_attribute_from_config = Mock( name="resolve_class_attribute_from_config", side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, ) + session_mock.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session_mock diff --git a/tests/unit/sagemaker/local/test_local_pipeline.py b/tests/unit/sagemaker/local/test_local_pipeline.py index 5ff050d0db..f119e52b41 100644 --- a/tests/unit/sagemaker/local/test_local_pipeline.py +++ b/tests/unit/sagemaker/local/test_local_pipeline.py @@ -135,16 +135,26 @@ def boto_session(client): @pytest.fixture def pipeline_session(boto_session, client): - return PipelineSession( + pipeline_session_mock = PipelineSession( boto_session=boto_session, sagemaker_client=client, default_bucket=BUCKET, ) + pipeline_session_mock.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) + return pipeline_session_mock @pytest.fixture() def local_sagemaker_session(boto_session): - return LocalSession(boto_session=boto_session, default_bucket="my-bucket") + local_session_mock = LocalSession(boto_session=boto_session, default_bucket="my-bucket") + local_session_mock.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + 
side_effect=lambda key, default_value=None: default_value, + ) + return local_session_mock @pytest.fixture diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 1fd9a1fad9..a31a431526 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -67,7 +67,12 @@ @pytest.fixture def sagemaker_session(): - return Mock() + session = Mock() + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) + return session @patch("sagemaker.production_variant") @@ -436,6 +441,14 @@ def test_deploy_wrong_serverless_config(sagemaker_session): @patch("sagemaker.session.Session") @patch("sagemaker.local.LocalSession") def test_deploy_creates_correct_session(local_session, session): + local_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) # We expect a LocalSession when deploying to instance_type = 'local' model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE) model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1) diff --git a/tests/unit/sagemaker/model/test_edge.py b/tests/unit/sagemaker/model/test_edge.py index 98bdfefacb..627cd28086 100644 --- a/tests/unit/sagemaker/model/test_edge.py +++ b/tests/unit/sagemaker/model/test_edge.py @@ -30,7 +30,12 @@ @pytest.fixture def sagemaker_session(): - return Mock(boto_region_name=REGION) + session = Mock(boto_region_name=REGION) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) + return session def _create_model(sagemaker_session=None): diff --git 
a/tests/unit/sagemaker/model/test_framework_model.py b/tests/unit/sagemaker/model/test_framework_model.py index 73ff09ef07..9d01d9fead 100644 --- a/tests/unit/sagemaker/model/test_framework_model.py +++ b/tests/unit/sagemaker/model/test_framework_model.py @@ -94,6 +94,10 @@ def sagemaker_session(): settings=SessionSettings(), ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 2274a5feb6..ff7408d6ca 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -109,7 +109,10 @@ def sagemaker_session(): s3_resource=None, ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) - + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index 161940874f..c0cdff6483 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -59,7 +59,10 @@ def sagemaker_session(): session.sagemaker_client.describe_model_package = Mock( return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE ) - + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git a/tests/unit/sagemaker/model/test_neo.py b/tests/unit/sagemaker/model/test_neo.py index 82a7b40afd..b21957a205 100644 --- a/tests/unit/sagemaker/model/test_neo.py +++ b/tests/unit/sagemaker/model/test_neo.py @@ -34,7 +34,12 @@ @pytest.fixture def sagemaker_session(): - return Mock(boto_region_name=REGION) + session 
= Mock(boto_region_name=REGION) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) + return session def _create_model(sagemaker_session=None): diff --git a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py index 7409d14430..75a414b35c 100644 --- a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py +++ b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py @@ -413,7 +413,6 @@ def sagemaker_session(sagemaker_client): ) session_mock.download_data = Mock(name="download_data") session_mock.expand_role.return_value = ROLE_ARN - session_mock._append_sagemaker_config_tags = Mock( name="_append_sagemaker_config_tags", side_effect=lambda tags, config_path_to_tags: tags ) @@ -433,7 +432,10 @@ def sagemaker_session(sagemaker_client): if direct_input is not None else default_value, ) - + session_mock.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session_mock diff --git a/tests/unit/sagemaker/monitor/test_data_capture_config.py b/tests/unit/sagemaker/monitor/test_data_capture_config.py index 88fb0a1db2..5717aa4ad8 100644 --- a/tests/unit/sagemaker/monitor/test_data_capture_config.py +++ b/tests/unit/sagemaker/monitor/test_data_capture_config.py @@ -56,6 +56,10 @@ def test_init_when_non_defaults_provided(): def test_init_when_optionals_not_provided(): sagemaker_session = Mock() sagemaker_session.default_bucket.return_value = DEFAULT_BUCKET_NAME + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) data_capture_config = DataCaptureConfig( enable_capture=DEFAULT_ENABLE_CAPTURE, sagemaker_session=sagemaker_session diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py 
b/tests/unit/sagemaker/monitor/test_model_monitoring.py index b627996b8e..383edb05c9 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -480,7 +480,10 @@ def sagemaker_session(): session_mock._append_sagemaker_config_tags = Mock( name="_append_sagemaker_config_tags", side_effect=lambda tags, config_path_to_tags: tags ) - + session_mock.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session_mock diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index 634f2d102f..43b978ea37 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -98,7 +98,10 @@ def sagemaker_session(): if direct_input is not None else default_value, ) - + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git a/tests/unit/sagemaker/tensorflow/test_estimator_attach.py b/tests/unit/sagemaker/tensorflow/test_estimator_attach.py index 4676c85df9..ceab141c0c 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator_attach.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator_attach.py @@ -42,6 +42,10 @@ def sagemaker_session(): describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git a/tests/unit/sagemaker/tensorflow/test_estimator_init.py b/tests/unit/sagemaker/tensorflow/test_estimator_init.py index 71764b0e93..dd7d974e0a 100644 --- 
a/tests/unit/sagemaker/tensorflow/test_estimator_init.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator_init.py @@ -25,7 +25,12 @@ @pytest.fixture() def sagemaker_session(): - return Mock(name="sagemaker_session", boto_region_name=REGION) + session_mock = Mock(name="sagemaker_session", boto_region_name=REGION) + session_mock.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) + return session_mock def _build_tf(sagemaker_session, **kwargs): diff --git a/tests/unit/sagemaker/tensorflow/test_tfs.py b/tests/unit/sagemaker/tensorflow/test_tfs.py index 67b69efc44..2ce5e71681 100644 --- a/tests/unit/sagemaker/tensorflow/test_tfs.py +++ b/tests/unit/sagemaker/tensorflow/test_tfs.py @@ -67,6 +67,10 @@ def sagemaker_session(): session.sagemaker_client.describe_training_job = Mock(return_value=describe) session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index 40e3241333..002489bb27 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -89,6 +89,10 @@ def fixture_sagemaker_session(): if direct_input is not None else default_value, ) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py 
b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index 3e3aeb7e86..a2a8708f2c 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -79,7 +79,6 @@ def fixture_sagemaker_session(): session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) - # For the purposes of unit tests, no values should be fetched from sagemaker config session.resolve_value_from_config = Mock( name="resolve_value_from_config", @@ -87,6 +86,10 @@ def fixture_sagemaker_session(): if direct_input is not None else default_value, ) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py index 643dc6337c..663917596c 100644 --- a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -88,6 +88,10 @@ def fixture_sagemaker_session(): if direct_input is not None else default_value, ) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index 5f67a4c46b..1719223652 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -93,6 +93,10 @@ def fixture_sagemaker_session(): if direct_input is not None else 
default_value, ) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git a/tests/unit/sagemaker/workflow/conftest.py b/tests/unit/sagemaker/workflow/conftest.py index 0a5cb0cf21..972c6e5d89 100644 --- a/tests/unit/sagemaker/workflow/conftest.py +++ b/tests/unit/sagemaker/workflow/conftest.py @@ -59,8 +59,13 @@ def mock_boto_session(client): @pytest.fixture(scope="module") def pipeline_session(mock_boto_session, mock_client): - return PipelineSession( + pipeline_session = PipelineSession( boto_session=mock_boto_session, sagemaker_client=mock_client, default_bucket=BUCKET, ) + pipeline_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) + return pipeline_session diff --git a/tests/unit/sagemaker/workflow/test_airflow.py b/tests/unit/sagemaker/workflow/test_airflow.py index e53efbb30b..b8010ef0ad 100644 --- a/tests/unit/sagemaker/workflow/test_airflow.py +++ b/tests/unit/sagemaker/workflow/test_airflow.py @@ -53,6 +53,10 @@ def sagemaker_session(): if direct_input is not None else default_value, ) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index 813a945cf4..50cda15e40 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -46,9 +46,58 @@ def sagemaker_session_mock(): session_mock = Mock() session_mock.default_bucket = Mock(name="default_bucket", return_value="s3_bucket") session_mock.local_mode = False + session_mock.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) + 
session_mock._append_sagemaker_config_tags = Mock(
+        name="_append_sagemaker_config_tags", side_effect=lambda tags, config_path_to_tags: tags
+    )
     return session_mock
 
 
+def test_pipeline_create_and_update_without_role_arn(sagemaker_session_mock):
+    pipeline = Pipeline(
+        name="MyPipeline",
+        parameters=[],
+        steps=[],
+        sagemaker_session=sagemaker_session_mock,
+    )
+    with pytest.raises(ValueError):
+        pipeline.create()
+    with pytest.raises(ValueError):
+        pipeline.update()
+    with pytest.raises(ValueError):
+        pipeline.upsert()
+
+
+def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock):
+    sagemaker_session_mock.get_sagemaker_config_override = Mock(
+        name="get_sagemaker_config_override", side_effect=lambda a, b: "ConfigRoleArn"
+    )
+    sagemaker_session_mock.sagemaker_client.describe_pipeline.return_value = {
+        "PipelineArn": "pipeline-arn"
+    }
+    pipeline = Pipeline(
+        name="MyPipeline",
+        parameters=[],
+        steps=[],
+        sagemaker_session=sagemaker_session_mock,
+    )
+    pipeline.create()
+    sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_with(
+        PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn="ConfigRoleArn"
+    )
+    pipeline.update()
+    sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with(
+        PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn="ConfigRoleArn"
+    )
+    pipeline.upsert()
+    sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with(
+        PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn="ConfigRoleArn"
+    )
+
+
 def test_pipeline_create(sagemaker_session_mock, role_arn):
     pipeline = Pipeline(
         name="MyPipeline",
@@ -57,7 +106,7 @@ def test_pipeline_create(sagemaker_session_mock, role_arn):
         sagemaker_session=sagemaker_session_mock,
     )
     pipeline.create(role_arn=role_arn)
-    assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
+    
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn ) diff --git a/tests/unit/sagemaker/wrangler/test_processing.py b/tests/unit/sagemaker/wrangler/test_processing.py index cce220ff21..48b6e17f1d 100644 --- a/tests/unit/sagemaker/wrangler/test_processing.py +++ b/tests/unit/sagemaker/wrangler/test_processing.py @@ -46,6 +46,10 @@ def sagemaker_session(): name="resolve_class_attribute_from_config", side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, ) + session_mock.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session_mock diff --git a/tests/unit/test_algorithm.py b/tests/unit/test_algorithm.py index d71268a7c7..3bc247a307 100644 --- a/tests/unit/test_algorithm.py +++ b/tests/unit/test_algorithm.py @@ -954,6 +954,10 @@ def test_algorithm_no_required_hyperparameters(session): def test_algorithm_attach_from_hyperparameter_tuning(): session = Mock() + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) job_name = "training-job-that-is-part-of-a-tuning-job" algo_arn = "arn:aws:sagemaker:us-east-2:000000000000:algorithm/scikit-decision-trees" role_arn = "arn:aws:iam::123412341234:role/SageMakerRole" diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index 18a576d44f..54e38fda3c 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -74,6 +74,10 @@ def sagemaker_session(): sms.sagemaker_client.describe_training_job = Mock( name="describe_training_job", return_value=returned_job_description ) + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff 
--git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 00340d1c8b..98824ed5c3 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -82,6 +82,10 @@ def sagemaker_session(): else default_value, ) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index f5912b17de..95e52b18e3 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -244,6 +244,10 @@ def sagemaker_session(): else default_value, ) + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms @@ -345,6 +349,77 @@ def test_framework_all_init_args(sagemaker_session): } +def test_framework_without_role_parameter(sagemaker_session): + with pytest.raises(ValueError): + DummyFramework( + entry_point=SCRIPT_PATH, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + ) + + +def test_default_value_of_enable_network_isolation(sagemaker_session): + framework = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + ) + assert framework.enable_network_isolation() is False + + +def _config_override_mock(key, default_value=None): + from sagemaker.session import ( + TRAINING_JOB_SUBNETS_PATH, + TRAINING_JOB_SECURITY_GROUP_IDS_PATH, + TRAINING_JOB_KMS_KEY_ID_PATH, + TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + TRAINING_JOB_ROLE_ARN_PATH, + TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, + ) + + if key == TRAINING_JOB_ROLE_ARN_PATH: + return "ConfigRoleArn" + elif key == TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH: + return True + 
elif key == TRAINING_JOB_KMS_KEY_ID_PATH: + return "ConfigKmsKeyId" + elif key == TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH: + return "ConfigVolumeKmsKeyId" + elif key == TRAINING_JOB_SECURITY_GROUP_IDS_PATH: + return ["sg-config"] + elif key == TRAINING_JOB_SUBNETS_PATH: + return ["subnet-config"] + return default_value + + +def test_framework_initialization_with_defaults(sagemaker_session): + + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", side_effect=_config_override_mock + ) + framework = DummyFramework( + entry_point=SCRIPT_PATH, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + ) + assert framework.role == "ConfigRoleArn" + assert framework.enable_network_isolation() + assert framework.output_kms_key == "ConfigKmsKeyId" + assert framework.volume_kms_key == "ConfigVolumeKmsKeyId" + assert framework.security_group_ids == ["sg-config"] + assert framework.subnets == ["subnet-config"] + + def test_framework_with_heterogeneous_cluster(sagemaker_session): f = DummyFramework( entry_point=SCRIPT_PATH, @@ -3886,6 +3961,10 @@ def test_estimator_local_mode_error(sagemaker_session): def test_estimator_local_mode_ok(sagemaker_local_session): + sagemaker_local_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) # When using instance local with a session which is not LocalSession we should error out Estimator( image_uri="some-image", diff --git a/tests/unit/test_fm.py b/tests/unit/test_fm.py index ceefeb9b3e..5e5e08bbe8 100644 --- a/tests/unit/test_fm.py +++ b/tests/unit/test_fm.py @@ -68,6 +68,10 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) + sms.get_sagemaker_config_override = Mock( 
+ name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git a/tests/unit/test_ipinsights.py b/tests/unit/test_ipinsights.py index 478d33fc06..e44cf68546 100644 --- a/tests/unit/test_ipinsights.py +++ b/tests/unit/test_ipinsights.py @@ -65,7 +65,10 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index deff0a3cdb..f3a451f22e 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -80,7 +80,10 @@ def sagemaker_session(): name="sagemaker_session", boto_session=boto_mock, s3_client=None, s3_resource=None ) mock_session.expand_role = Mock(name="expand_role", return_value=ROLE) - + mock_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return mock_session diff --git a/tests/unit/test_kmeans.py b/tests/unit/test_kmeans.py index 790ee73576..de0e517d63 100644 --- a/tests/unit/test_kmeans.py +++ b/tests/unit/test_kmeans.py @@ -62,7 +62,10 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git a/tests/unit/test_knn.py b/tests/unit/test_knn.py index 02dc073d9d..7880716edd 100644 --- a/tests/unit/test_knn.py +++ b/tests/unit/test_knn.py @@ -68,7 +68,10 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = 
Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git a/tests/unit/test_lda.py b/tests/unit/test_lda.py index f2574b30b5..1dc0001ec7 100644 --- a/tests/unit/test_lda.py +++ b/tests/unit/test_lda.py @@ -57,7 +57,10 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git a/tests/unit/test_linear_learner.py b/tests/unit/test_linear_learner.py index bb2a140200..83126166ae 100644 --- a/tests/unit/test_linear_learner.py +++ b/tests/unit/test_linear_learner.py @@ -63,7 +63,10 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git a/tests/unit/test_multidatamodel.py b/tests/unit/test_multidatamodel.py index 3ea7aefa77..58164be338 100644 --- a/tests/unit/test_multidatamodel.py +++ b/tests/unit/test_multidatamodel.py @@ -80,6 +80,10 @@ def sagemaker_session(): name="upload_data", return_value=os.path.join(VALID_MULTI_MODEL_DATA_PREFIX, "mleap_model.tar.gz"), ) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) s3_mock = Mock() boto_mock.client("s3").return_value = s3_mock diff --git a/tests/unit/test_mxnet.py 
b/tests/unit/test_mxnet.py index 19dfbbfd28..6d4e8e1430 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -109,6 +109,10 @@ def sagemaker_session(): else default_value, ) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git a/tests/unit/test_ntm.py b/tests/unit/test_ntm.py index cbe9f18e36..570ef16b56 100644 --- a/tests/unit/test_ntm.py +++ b/tests/unit/test_ntm.py @@ -62,7 +62,10 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git a/tests/unit/test_object2vec.py b/tests/unit/test_object2vec.py index e6aaf770fa..98ec6f29e1 100644 --- a/tests/unit/test_object2vec.py +++ b/tests/unit/test_object2vec.py @@ -70,7 +70,10 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git a/tests/unit/test_pca.py b/tests/unit/test_pca.py index 222021caa3..dfba755f18 100644 --- a/tests/unit/test_pca.py +++ b/tests/unit/test_pca.py @@ -62,7 +62,10 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git 
a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py index 913ffbc556..53f08ba893 100644 --- a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -18,6 +18,12 @@ from sagemaker.model import FrameworkModel from sagemaker.pipeline import PipelineModel from sagemaker.predictor import Predictor +from sagemaker.session import ( + ENDPOINT_CONFIG_KMS_KEY_ID_PATH, + MODEL_ENABLE_NETWORK_ISOLATION_PATH, + MODEL_EXECUTION_ROLE_ARN_PATH, + MODEL_VPC_CONFIG_PATH, +) from sagemaker.session_settings import SessionSettings from sagemaker.sparkml import SparkMLModel @@ -69,6 +75,10 @@ def sagemaker_session(): settings=SessionSettings(), ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms @@ -289,6 +299,61 @@ def test_deploy_tags(tfo, time, sagemaker_session): ) +def test_pipeline_model_without_role(sagemaker_session): + with pytest.raises(ValueError): + PipelineModel([], sagemaker_session=sagemaker_session) + + +def _config_override_mock(key, default_value=None): + if key == ENDPOINT_CONFIG_KMS_KEY_ID_PATH: + return "ConfigKmsKeyId" + elif key == MODEL_ENABLE_NETWORK_ISOLATION_PATH: + return True + elif key == MODEL_EXECUTION_ROLE_ARN_PATH: + return "ConfigRoleArn" + elif key == MODEL_VPC_CONFIG_PATH: + return {"SecurityGroupIds": ["sg-config"], "Subnets": ["subnet-config"]} + return default_value + + +@patch("tarfile.open") +@patch("time.strftime", return_value=TIMESTAMP) +def test_pipeline_model_with_config_injection(tfo, time, sagemaker_session): + framework_model = DummyFrameworkModel(sagemaker_session) + sparkml_model = SparkMLModel( + model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session + ) + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", side_effect=_config_override_mock + ) + 
pipeline_model = PipelineModel( + [framework_model, sparkml_model], sagemaker_session=sagemaker_session + ) + assert pipeline_model.role == "ConfigRoleArn" + assert pipeline_model.vpc_config == { + "SecurityGroupIds": ["sg-config"], + "Subnets": ["subnet-config"], + } + assert pipeline_model.enable_network_isolation + pipeline_model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) + sagemaker_session.endpoint_from_production_variants.assert_called_with( + name="mi-1-2017-10-10-14-14-15", + production_variants=[ + { + "InitialVariantWeight": 1, + "ModelName": "mi-1-2017-10-10-14-14-15", + "InstanceType": INSTANCE_TYPE, + "InitialInstanceCount": 1, + "VariantName": "AllTraffic", + } + ], + tags=None, + kms_key="ConfigKmsKeyId", + wait=True, + data_capture_config_dict=None, + ) + + def test_delete_model_without_deploy(sagemaker_session): pipeline_model = PipelineModel([], role=ROLE, sagemaker_session=sagemaker_session) diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index ee53628ef4..b86a5f0e67 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -51,6 +51,10 @@ def empty_sagemaker_session(): ims.sagemaker_runtime_client.invoke_endpoint = Mock( name="invoke_endpoint", return_value={"Body": response_body} ) + ims.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return ims diff --git a/tests/unit/test_predictor_async.py b/tests/unit/test_predictor_async.py index e4d2ee829f..45c45ef000 100644 --- a/tests/unit/test_predictor_async.py +++ b/tests/unit/test_predictor_async.py @@ -52,7 +52,10 @@ def empty_sagemaker_session(): "OutputLocation": ASYNC_OUTPUT_LOCATION, }, ) - + ims.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) response_body = Mock("body") response_body.read = Mock("read", return_value=RETURN_VALUE) 
response_body.close = Mock("close", return_value=None) diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 0f156d8a0f..4758ebd3f9 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -30,6 +30,14 @@ ScriptProcessor, ProcessingJob, ) +from sagemaker.session import ( + PROCESSING_JOB_SUBNETS_PATH, + PROCESSING_JOB_SECURITY_GROUP_IDS_PATH, + PROCESSING_JOB_ROLE_ARN_PATH, + PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + PROCESSING_JOB_KMS_KEY_ID_PATH, + PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH, +) from sagemaker.session_settings import SessionSettings from sagemaker.spark.processing import PySparkProcessor from sagemaker.sklearn.processing import SKLearnProcessor @@ -87,6 +95,10 @@ def sagemaker_session(): side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, ) + session_mock.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session_mock @@ -109,6 +121,10 @@ def pipeline_session(): session_mock.describe_processing_job = MagicMock( name="describe_processing_job", return_value=_get_describe_response_inputs_and_ouputs() ) + session_mock.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) session_mock.__class__ = PipelineSession # For the purposes of unit tests, no values should be fetched from sagemaker config @@ -621,6 +637,95 @@ def test_script_processor_with_required_parameters(exists_mock, isfile_mock, sag sagemaker_session.process.assert_called_with(**expected_args) +def _config_override_mock(key, default_value=None): + if key == PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH: + return True + if key == PROCESSING_JOB_ROLE_ARN_PATH: + return "arn:aws:iam::012345678901:role/ConfigRoleArn" + elif key == PROCESSING_JOB_KMS_KEY_ID_PATH: + return "ConfigKmsKeyId" + elif key == 
PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH: + return "ConfigVolumeKmsKeyId" + elif key == PROCESSING_JOB_SECURITY_GROUP_IDS_PATH: + return ["sg-config"] + elif key == PROCESSING_JOB_SUBNETS_PATH: + return ["subnet-config"] + return default_value + + +@patch("os.path.exists", return_value=True) +@patch("os.path.isfile", return_value=True) +def test_script_processor_without_role(exists_mock, isfile_mock, sagemaker_session): + with pytest.raises(ValueError): + ScriptProcessor( + image_uri=CUSTOM_IMAGE_URI, + command=["python3"], + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=100, + volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key", + output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key", + max_runtime_in_seconds=3600, + base_job_name="my_sklearn_processor", + env={"my_env_variable": "my_env_variable_value"}, + tags=[{"Key": "my-tag", "Value": "my-tag-value"}], + network_config=NetworkConfig( + subnets=["my_subnet_id"], + security_group_ids=["my_security_group_id"], + enable_network_isolation=True, + encrypt_inter_container_traffic=True, + ), + sagemaker_session=sagemaker_session, + ) + + +@patch("os.path.exists", return_value=True) +@patch("os.path.isfile", return_value=True) +def test_script_processor_with_some_parameters_from_config( + exists_mock, isfile_mock, sagemaker_session +): + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", side_effect=_config_override_mock + ) + sagemaker_session.expand_role = Mock(name="expand_role", side_effect=lambda a: a) + processor = ScriptProcessor( + image_uri=CUSTOM_IMAGE_URI, + command=["python3"], + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=100, + max_runtime_in_seconds=3600, + base_job_name="my_sklearn_processor", + env={"my_env_variable": "my_env_variable_value"}, + tags=[{"Key": "my-tag", "Value": "my-tag-value"}], + network_config=NetworkConfig( + encrypt_inter_container_traffic=True, + ), + 
sagemaker_session=sagemaker_session, + ) + processor.run( + code="/local/path/to/processing_code.py", + inputs=_get_data_inputs_all_parameters(), + outputs=_get_data_outputs_all_parameters(), + arguments=["--drop-columns", "'SelfEmployed'"], + wait=True, + logs=False, + job_name="my_job_name", + experiment_config={"ExperimentName": "AnExperiment"}, + ) + expected_args = _get_expected_args_all_parameters(processor._current_job_name) + expected_args["resources"]["ClusterConfig"]["VolumeKmsKeyId"] = "ConfigVolumeKmsKeyId" + expected_args["output_config"]["KmsKeyId"] = "ConfigKmsKeyId" + expected_args["role_arn"] = "arn:aws:iam::012345678901:role/ConfigRoleArn" + expected_args["network_config"]["VpcConfig"] = { + "SecurityGroupIds": ["sg-config"], + "Subnets": ["subnet-config"], + } + + sagemaker_session.process.assert_called_with(**expected_args) + assert "my_job_name" in processor._current_job_name + + @patch("os.path.exists", return_value=True) @patch("os.path.isfile", return_value=True) def test_script_processor_with_all_parameters(exists_mock, isfile_mock, sagemaker_session): diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index fd085d6590..4fa9bc146a 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -91,6 +91,10 @@ def fixture_sagemaker_session(): else default_value, ) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git a/tests/unit/test_randomcutforest.py b/tests/unit/test_randomcutforest.py index 3e3bca00dc..65ee77399a 100644 --- a/tests/unit/test_randomcutforest.py +++ b/tests/unit/test_randomcutforest.py @@ -62,6 +62,10 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) + sms.get_sagemaker_config_override = Mock( + 
name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index 50de3a47c4..1b73714ed1 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -84,6 +84,10 @@ def fixture_sagemaker_session(): else default_value, ) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index fe32ed5486..aaaff31343 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -249,6 +249,138 @@ def test_process(boto_session): session.sagemaker_client.create_processing_job.assert_called_with(**expected_request) +def _sagemaker_config_override_mock_for_process(key, default_value=None): + from sagemaker.session import ( + PROCESSING_JOB_ROLE_ARN_PATH, + PROCESSING_JOB_NETWORK_CONFIG_PATH, + PROCESSING_OUTPUT_CONFIG_PATH, + PROCESSING_JOB_PROCESSING_RESOURCES_PATH, + PROCESSING_JOB_INPUTS_PATH, + ) + + if key is PROCESSING_JOB_ROLE_ARN_PATH: + return "arn:aws:iam::111111111111:role/ConfigRole" + elif key is PROCESSING_JOB_NETWORK_CONFIG_PATH: + return { + "VpcConfig": {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, + "EnableNetworkIsolation": True, + } + elif key is PROCESSING_OUTPUT_CONFIG_PATH: + return {"KmsKeyId": "testKmsKeyId"} + elif key is PROCESSING_JOB_PROCESSING_RESOURCES_PATH: + return {"ClusterConfig": {"VolumeKmsKeyId": "testVolumeKmsKeyId"}} + elif key is PROCESSING_JOB_INPUTS_PATH: + return [ + { + "DatasetDefinition": { + "AthenaDatasetDefinition": {"KmsKeyId": "AthenaKmsKeyId"}, + "RedshiftDatasetDefinition": { + "KmsKeyId": "RedshiftKmsKeyId", + "ClusterRoleArn": "clusterrole", + }, + } + } + ] + + return default_value + + +def test_create_process_with_configs(sagemaker_session): + sagemaker_session.get_sagemaker_config_override = Mock( + 
name="get_sagemaker_config_override", + side_effect=_sagemaker_config_override_mock_for_process, + ) + processing_inputs = [ + { + "InputName": "input-1", + "S3Input": { + "S3Uri": "mocked_s3_uri_from_upload_data", + "LocalPath": "/container/path/", + "S3DataType": "Archive", + "S3InputMode": "File", + "S3DataDistributionType": "FullyReplicated", + "S3CompressionType": "None", + }, + } + ] + output_config = { + "Outputs": [ + { + "OutputName": "output-1", + "S3Output": { + "S3Uri": "s3://mybucket/current_job_name/output", + "LocalPath": "/data/output", + "S3UploadMode": "Continuous", + }, + }, + { + "OutputName": "my_output", + "S3Output": { + "S3Uri": "s3://uri/", + "LocalPath": "/container/path/", + "S3UploadMode": "Continuous", + }, + }, + ], + } + job_name = ("current_job_name",) + resource_config = { + "ClusterConfig": { + "InstanceType": "ml.m4.xlarge", + "InstanceCount": 1, + "VolumeSizeInGB": 100, + } + } + app_specification = { + "ImageUri": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3", + "ContainerArguments": ["--drop-columns", "'SelfEmployed'"], + "ContainerEntrypoint": ["python3", "/code/source/sklearn_transformer.py"], + } + + process_request_args = { + "inputs": processing_inputs, + "output_config": output_config, + "job_name": job_name, + "resources": resource_config, + "stopping_condition": {"MaxRuntimeInSeconds": 3600}, + "app_specification": app_specification, + "environment": {"my_env_variable": 20}, + "tags": [{"Name": "my-tag", "Value": "my-tag-value"}], + "experiment_config": {"ExperimentName": "AnExperiment"}, + } + sagemaker_session.process(**process_request_args) + + expected_request = { + "ProcessingJobName": job_name, + "ProcessingResources": resource_config, + "AppSpecification": app_specification, + "RoleArn": "arn:aws:iam::111111111111:role/ConfigRole", + "ProcessingInputs": processing_inputs, + "ProcessingOutputConfig": output_config, + "Environment": {"my_env_variable": 20}, + "NetworkConfig": 
{ + "VpcConfig": {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, + "EnableNetworkIsolation": True, + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 3600}, + "Tags": [{"Name": "my-tag", "Value": "my-tag-value"}], + "ExperimentConfig": {"ExperimentName": "AnExperiment"}, + } + expected_request["ProcessingInputs"][0]["DatasetDefinition"] = { + "AthenaDatasetDefinition": {"KmsKeyId": "AthenaKmsKeyId"}, + "RedshiftDatasetDefinition": { + "KmsKeyId": "RedshiftKmsKeyId", + "ClusterRoleArn": "clusterrole", + }, + } + expected_request["ProcessingOutputConfig"]["KmsKeyId"] = "testKmsKeyId" + expected_request["ProcessingResources"]["ClusterConfig"][ + "VolumeKmsKeyId" + ] = "testVolumeKmsKeyId" + + sagemaker_session.sagemaker_client.create_processing_job.assert_called_with(**expected_request) + + def mock_exists(filepath_to_mock, exists_result): unmocked_exists = os.path.exists @@ -1426,6 +1558,118 @@ def test_stop_tuning_job_client_error(sagemaker_session): ) +def _sagemaker_config_override_mock_for_train(key, default_value=None): + from sagemaker.session import ( + TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + TRAINING_JOB_VPC_CONFIG_PATH, + TRAINING_JOB_RESOURCE_CONFIG_PATH, + TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH, + TRAINING_JOB_ROLE_ARN_PATH, + ) + + if key is TRAINING_JOB_ROLE_ARN_PATH: + return "arn:aws:iam::111111111111:role/ConfigRole" + elif key is TRAINING_JOB_VPC_CONFIG_PATH: + return {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]} + elif key is TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH: + return {"KmsKeyId": "TestKms"} + elif key is TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH: + return True + elif key is TRAINING_JOB_RESOURCE_CONFIG_PATH: + return {"VolumeKmsKeyId": "volumekey"} + return default_value + + +def test_train_with_configs(sagemaker_session): + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=_sagemaker_config_override_mock_for_train, + ) + in_config = [ + { + 
"ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": S3_INPUT_URI, + } + }, + } + ] + + out_config = {"S3OutputPath": S3_OUTPUT} + + resource_config = { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeSizeInGB": MAX_SIZE, + } + + stop_cond = {"MaxRuntimeInSeconds": MAX_TIME} + RETRY_STRATEGY = {"MaximumRetryAttempts": 2} + hyperparameters = {"foo": "bar"} + TRAINING_IMAGE_CONFIG = { + "TrainingRepositoryAccessMode": "Vpc", + "TrainingRepositoryAuthConfig": { + "TrainingRepositoryCredentialsProviderArn": "arn:aws:lambda:us-west-2:1234567897:function:test" + }, + } + + sagemaker_session.train( + image_uri=IMAGE, + input_mode="File", + input_config=in_config, + job_name=JOB_NAME, + output_config=out_config, + resource_config=resource_config, + hyperparameters=hyperparameters, + stop_condition=stop_cond, + tags=TAGS, + metric_definitions=METRIC_DEFINITONS, + encrypt_inter_container_traffic=True, + use_spot_instances=True, + checkpoint_s3_uri="s3://mybucket/checkpoints/", + checkpoint_local_path="/tmp/checkpoints", + enable_sagemaker_metrics=True, + environment=ENV_INPUT, + retry_strategy=RETRY_STRATEGY, + training_image_config=TRAINING_IMAGE_CONFIG, + ) + + _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0] + + assert actual_train_args["VpcConfig"] == { + "Subnets": ["subnets-123"], + "SecurityGroupIds": ["sg-123"], + } + assert actual_train_args["HyperParameters"] == hyperparameters + assert actual_train_args["Tags"] == TAGS + assert actual_train_args["AlgorithmSpecification"]["MetricDefinitions"] == METRIC_DEFINITONS + assert actual_train_args["AlgorithmSpecification"]["EnableSageMakerMetricsTimeSeries"] is True + assert actual_train_args["EnableInterContainerTrafficEncryption"] is True + assert actual_train_args["EnableNetworkIsolation"] is True + assert actual_train_args["EnableManagedSpotTraining"] is True + 
assert actual_train_args["CheckpointConfig"]["S3Uri"] == "s3://mybucket/checkpoints/" + assert actual_train_args["CheckpointConfig"]["LocalPath"] == "/tmp/checkpoints" + assert actual_train_args["Environment"] == ENV_INPUT + assert actual_train_args["RetryStrategy"] == RETRY_STRATEGY + assert ( + actual_train_args["AlgorithmSpecification"]["TrainingImageConfig"] == TRAINING_IMAGE_CONFIG + ) + assert actual_train_args["RoleArn"] == "arn:aws:iam::111111111111:role/ConfigRole" + assert actual_train_args["ResourceConfig"] == { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeSizeInGB": MAX_SIZE, + "VolumeKmsKeyId": "volumekey", + } + assert actual_train_args["OutputDataConfig"] == { + "S3OutputPath": S3_OUTPUT, + "KmsKeyId": "TestKms", + } + + def test_train_pack_to_request_with_optional_params(sagemaker_session): in_config = [ { @@ -1499,6 +1743,77 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): ) +def _sagemaker_config_override_mock_for_transform(key, default_value=None): + from sagemaker.session import ( + TRANSFORM_JOB_KMS_KEY_ID_PATH, + TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, + TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, + ) + + if key is TRANSFORM_JOB_KMS_KEY_ID_PATH: + return "jobKmsKeyId" + elif key is TRANSFORM_OUTPUT_KMS_KEY_ID_PATH: + return "outputKmsKeyId" + elif key is TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH: + return "volumeKmsKeyId" + return default_value + + +def test_create_transform_job_with_configs(sagemaker_session): + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=_sagemaker_config_override_mock_for_transform, + ) + + model_name = "my-model" + + in_config = { + "CompressionType": "None", + "ContentType": "text/csv", + "SplitType": "None", + "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": S3_INPUT_URI}}, + } + + out_config = {"S3OutputPath": S3_OUTPUT} + + resource_config = {"InstanceCount": INSTANCE_COUNT, "InstanceType": 
INSTANCE_TYPE} + + data_processing = {"OutputFilter": "$", "InputFilter": "$", "JoinSource": "Input"} + + data_capture_config = BatchDataCaptureConfig(destination_s3_uri="s3://test") + expected_args = { + "TransformJobName": JOB_NAME, + "ModelName": model_name, + "TransformInput": in_config, + "TransformOutput": out_config, + "TransformResources": resource_config, + "DataProcessing": data_processing, + "DataCaptureConfig": data_capture_config._to_request_dict(), + } + expected_args["DataCaptureConfig"]["KmsKeyId"] = "jobKmsKeyId" + expected_args["TransformOutput"]["KmsKeyId"] = "outputKmsKeyId" + expected_args["TransformResources"]["VolumeKmsKeyId"] = "volumeKmsKeyId" + sagemaker_session.transform( + job_name=JOB_NAME, + model_name=model_name, + strategy=None, + max_concurrent_transforms=None, + max_payload=None, + env=None, + input_config=in_config, + output_config=out_config, + resource_config=resource_config, + experiment_config=None, + model_client_config=None, + tags=None, + data_processing=data_processing, + batch_data_capture_config=data_capture_config, + ) + + _, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0] + assert actual_args == expected_args + + def test_transform_pack_to_request(sagemaker_session): model_name = "my-model" @@ -1831,6 +2146,45 @@ def test_logs_for_transform_job_full_lifecycle(time, cw, sagemaker_session_full_ } +def _sagemaker_config_override_mock_for_model(key, default_value=None): + from sagemaker.session import ( + MODEL_EXECUTION_ROLE_ARN_PATH, + MODEL_VPC_CONFIG_PATH, + MODEL_ENABLE_NETWORK_ISOLATION_PATH, + ) + + if key is MODEL_EXECUTION_ROLE_ARN_PATH: + return "arn:aws:iam::111111111111:role/ConfigRole" + elif key is MODEL_VPC_CONFIG_PATH: + return {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]} + elif key is MODEL_ENABLE_NETWORK_ISOLATION_PATH: + return True + return default_value + + +@patch("sagemaker.session._expand_container_def", return_value=PRIMARY_CONTAINER) +def 
test_create_model_with_configs(expand_container_def, sagemaker_session): + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=_sagemaker_config_override_mock_for_model, + ) + sagemaker_session.expand_role = Mock( + name="expand_role", side_effect=lambda role_name: role_name + ) + model = sagemaker_session.create_model( + MODEL_NAME, + container_defs=PRIMARY_CONTAINER, + ) + assert model == MODEL_NAME + sagemaker_session.sagemaker_client.create_model.assert_called_with( + ExecutionRoleArn="arn:aws:iam::111111111111:role/ConfigRole", + ModelName=MODEL_NAME, + PrimaryContainer=PRIMARY_CONTAINER, + VpcConfig={"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, + EnableNetworkIsolation=True, + ) + + @patch("sagemaker.session._expand_container_def", return_value=PRIMARY_CONTAINER) def test_create_model(expand_container_def, sagemaker_session): model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER) @@ -1999,6 +2353,155 @@ def test_create_model_from_job_with_tags(sagemaker_session): ) +def _sagemaker_config_override_mock_for_edge_packaging(key, default_value=None): + from sagemaker.session import ( + EDGE_PACKAGING_ROLE_ARN_PATH, + EDGE_PACKAGING_OUTPUT_CONFIG_PATH, + ) + + if key is EDGE_PACKAGING_ROLE_ARN_PATH: + return "arn:aws:iam::111111111111:role/ConfigRole" + elif key is EDGE_PACKAGING_OUTPUT_CONFIG_PATH: + return {"KmsKeyId": "configKmsKeyId"} + return default_value + + +def test_create_edge_packaging_with_configs(sagemaker_session): + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=_sagemaker_config_override_mock_for_edge_packaging, + ) + + output_config = {"S3OutputLocation": S3_OUTPUT} + + sagemaker_session.package_model_for_edge( + output_config, + ) + sagemaker_session.sagemaker_client.create_edge_packaging_job.assert_called_with( + RoleArn="arn:aws:iam::111111111111:role/ConfigRole", # provided from config + 
OutputConfig={ + "S3OutputLocation": S3_OUTPUT, # provided as param + "KmsKeyId": "configKmsKeyId", # fetched from config + }, + ModelName=None, + ModelVersion=None, + EdgePackagingJobName=None, + CompilationJobName=None, + ) + + +def _sagemaker_config_override_mock_for_monitoring_schedule(key, default_value=None): + from sagemaker.session import ( + MONITORING_JOB_ROLE_ARN_PATH, + MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, + MONITORING_JOB_NETWORK_CONFIG_PATH, + MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH, + ) + + if key is MONITORING_JOB_ROLE_ARN_PATH: + return "arn:aws:iam::111111111111:role/ConfigRole" + elif key is MONITORING_JOB_NETWORK_CONFIG_PATH: + return { + "VpcConfig": {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, + "EnableNetworkIsolation": True, + } + elif key is MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH: + return "configKmsKeyId" + elif key is MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH: + return "configVolumeKmsKeyId" + return default_value + + +def test_create_monitoring_schedule_with_configs(sagemaker_session): + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=_sagemaker_config_override_mock_for_monitoring_schedule, + ) + + monitoring_output_config = {"MonitoringOutputs": [{"S3Output": {"S3Uri": S3_OUTPUT}}]} + + sagemaker_session.create_monitoring_schedule( + JOB_NAME, + schedule_expression=None, + statistics_s3_uri=None, + constraints_s3_uri=None, + monitoring_inputs=[], + monitoring_output_config=monitoring_output_config, + instance_count=1, + instance_type="ml.m4.xlarge", + volume_size_in_gb=4, + image_uri="someimageuri", + network_config={"VpcConfig": {"SecurityGroupIds": ["sg-asparam"]}}, + ) + sagemaker_session.sagemaker_client.create_monitoring_schedule.assert_called_with( + MonitoringScheduleName=JOB_NAME, + MonitoringScheduleConfig={ + "MonitoringJobDefinition": { + "MonitoringInputs": [], + "MonitoringResources": { + "ClusterConfig": { + "InstanceCount": 1, # provided as param + 
"InstanceType": "ml.m4.xlarge", # provided as param + "VolumeSizeInGB": 4, # provided as param + "VolumeKmsKeyId": "configVolumeKmsKeyId", # Fetched from config + } + }, + "MonitoringAppSpecification": {"ImageUri": "someimageuri"}, # provided as param + "RoleArn": "arn:aws:iam::111111111111:role/ConfigRole", # Fetched from config + "MonitoringOutputConfig": { + "MonitoringOutputs": [ # provided as param + {"S3Output": {"S3Uri": "s3://sagemaker-123/output/jobname"}} + ], + "KmsKeyId": "configKmsKeyId", # fetched from config + }, + "NetworkConfig": { + "VpcConfig": { + "Subnets": ["subnets-123"], # fetched from config + "SecurityGroupIds": ["sg-asparam"], # provided as param + }, + "EnableNetworkIsolation": True, # fetched from config + }, + } + }, + ) + + +def _sagemaker_config_override_mock_for_compile(key, default_value=None): + from sagemaker.session import ( + COMPILATION_JOB_ROLE_ARN_PATH, + COMPILATION_JOB_OUTPUT_CONFIG_PATH, + COMPILATION_JOB_VPC_CONFIG_PATH, + ) + + if key is COMPILATION_JOB_ROLE_ARN_PATH: + return "arn:aws:iam::111111111111:role/ConfigRole" + elif key is COMPILATION_JOB_VPC_CONFIG_PATH: + return {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]} + elif key is COMPILATION_JOB_OUTPUT_CONFIG_PATH: + return {"KmsKeyId": "TestKms"} + return default_value + + +def test_compile_with_configs(sagemaker_session): + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=_sagemaker_config_override_mock_for_compile, + ) + sagemaker_session.compile_model( + input_model_config={}, + output_model_config={"S3OutputLocation": "s3://test"}, + job_name="TestJob", + ) + sagemaker_session.sagemaker_client.create_compilation_job.assert_called_with( + InputConfig={}, + OutputConfig={"S3OutputLocation": "s3://test", "KmsKeyId": "TestKms"}, + RoleArn="arn:aws:iam::111111111111:role/ConfigRole", + StoppingCondition=None, + CompilationJobName="TestJob", + VpcConfig={"Subnets": ["subnets-123"], 
"SecurityGroupIds": ["sg-123"]}, + ) + + def test_create_model_from_job_with_image(sagemaker_session): ims = sagemaker_session ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT @@ -2061,6 +2564,58 @@ def test_endpoint_from_production_variants(sagemaker_session): ) +def _sagemaker_config_override_mock_for_endpoint_config(key, default_value=None): + from sagemaker.session import ( + ENDPOINT_CONFIG_DATA_CAPTURE_PATH, + ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, + ENDPOINT_CONFIG_KMS_KEY_ID_PATH, + ) + + if key is ENDPOINT_CONFIG_KMS_KEY_ID_PATH: + return "testKmsKeyId" + elif key is ENDPOINT_CONFIG_DATA_CAPTURE_PATH: + return {"KmsKeyId": "testDataCaptureKmsKeyId"} + elif key is ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH: + return [{"CoreDumpConfig": {"KmsKeyId": "testCoreKmsKeyId"}}] + return default_value + + +def test_create_enpoint_config_with_configs(sagemaker_session): + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=_sagemaker_config_override_mock_for_endpoint_config, + ) + data_capture_config_dict = {"DestinationS3Uri": "s3://test"} + + tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}] + + sagemaker_session.create_endpoint_config( + "endpoint-test", + "simple-model", + 1, + "local", + tags=tags, + data_capture_config_dict=data_capture_config_dict, + ) + + sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( + EndpointConfigName="endpoint-test", + ProductionVariants=[ + { + "CoreDumpConfig": {"KmsKeyId": "testCoreKmsKeyId"}, + "ModelName": "simple-model", + "VariantName": "AllTraffic", + "InitialVariantWeight": 1, + "InitialInstanceCount": 1, + "InstanceType": "local", + } + ], + DataCaptureConfig={"DestinationS3Uri": "s3://test", "KmsKeyId": "testDataCaptureKmsKeyId"}, + KmsKeyId="testKmsKeyId", + Tags=tags, + ) + + def test_create_endpoint_config_with_tags(sagemaker_session): tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}] 
@@ -2180,7 +2735,7 @@ def test_endpoint_from_production_variants_with_async_config(sagemaker_session): sagemaker_session.endpoint_from_production_variants( "some-endpoint", pvs, - async_inference_config_dict=AsyncInferenceConfig, + async_inference_config_dict=AsyncInferenceConfig()._to_request_dict(), ) sagemaker_session.sagemaker_client.create_endpoint.assert_called_with( EndpointConfigName="some-endpoint", EndpointName="some-endpoint", Tags=[] @@ -2188,7 +2743,7 @@ def test_endpoint_from_production_variants_with_async_config(sagemaker_session): sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( EndpointConfigName="some-endpoint", ProductionVariants=pvs, - AsyncInferenceConfig=AsyncInferenceConfig, + AsyncInferenceConfig=AsyncInferenceConfig()._to_request_dict(), ) @@ -2522,6 +3077,27 @@ def test_wait_until_fail_access_denied_after_5_mins(patched_sleep): } +def _sagemaker_config_override_mock_for_auto_ml(key, default_value=None): + from sagemaker.session import ( + AUTO_ML_OUTPUT_CONFIG_PATH, + AUTO_ML_ROLE_ARN_PATH, + AUTO_ML_JOB_CONFIG_PATH, + ) + + if key is AUTO_ML_ROLE_ARN_PATH: + return "arn:aws:iam::111111111111:role/ConfigRole" + elif key is AUTO_ML_JOB_CONFIG_PATH: + return { + "SecurityConfig": { + "VpcConfig": {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, + "VolumeKmsKeyId": "TestKmsKeyId", + } + } + elif key is AUTO_ML_OUTPUT_CONFIG_PATH: + return {"KmsKeyId": "configKmsKeyId"} + return default_value + + def test_auto_ml_pack_to_request(sagemaker_session): input_config = [ { @@ -2544,11 +3120,56 @@ def test_auto_ml_pack_to_request(sagemaker_session): role = EXPANDED_ROLE sagemaker_session.auto_ml(input_config, output_config, auto_ml_job_config, role, job_name) + sagemaker_session.sagemaker_client.create_auto_ml_job.assert_called_with( + AutoMLJobName=DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS["AutoMLJobName"], + InputDataConfig=DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS["InputDataConfig"], + 
OutputDataConfig=DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS["OutputDataConfig"], + AutoMLJobConfig=DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS["AutoMLJobConfig"], + RoleArn=DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS["RoleArn"], + GenerateCandidateDefinitionsOnly=False, + ) - assert sagemaker_session.sagemaker_client.method_calls[0] == ( - "create_auto_ml_job", - (), - DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS, + +def test_create_auto_ml_with_configs(sagemaker_session): + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=_sagemaker_config_override_mock_for_auto_ml, + ) + input_config = [ + { + "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": S3_INPUT_URI}}, + "TargetAttributeName": "y", + } + ] + + output_config = {"S3OutputPath": S3_OUTPUT} + + auto_ml_job_config = { + "CompletionCriteria": { + "MaxCandidates": 10, + "MaxAutoMLJobRuntimeInSeconds": 36000, + "MaxRuntimePerTrainingJobInSeconds": 3600 * 2, + } + } + + job_name = JOB_NAME + sagemaker_session.auto_ml(input_config, output_config, auto_ml_job_config, job_name=job_name) + expected_call_args = DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS.copy() + expected_call_args["OutputDataConfig"]["KmsKeyId"] = "configKmsKeyId" + expected_call_args["RoleArn"] = "arn:aws:iam::111111111111:role/ConfigRole" + expected_call_args["AutoMLJobConfig"]["SecurityConfig"] = {} + expected_call_args["AutoMLJobConfig"]["SecurityConfig"]["VpcConfig"] = { + "Subnets": ["subnets-123"], + "SecurityGroupIds": ["sg-123"], + } + expected_call_args["AutoMLJobConfig"]["SecurityConfig"]["VolumeKmsKeyId"] = "TestKmsKeyId" + sagemaker_session.sagemaker_client.create_auto_ml_job.assert_called_with( + AutoMLJobName=expected_call_args["AutoMLJobName"], + InputDataConfig=expected_call_args["InputDataConfig"], + OutputDataConfig=expected_call_args["OutputDataConfig"], + AutoMLJobConfig=expected_call_args["AutoMLJobConfig"], + RoleArn=expected_call_args["RoleArn"], + GenerateCandidateDefinitionsOnly=False, ) @@ -2749,6 
+3370,121 @@ def test_create_model_package_from_containers_without_model_package_group_name( ) +def _sagemaker_config_override_mock_for_model_package(key, default_value=None): + from sagemaker.session import ( + MODEL_PACKAGE_VALIDATION_ROLE_PATH, + MODEL_PACKAGE_VALIDATION_PROFILES_PATH, + ) + + if key is MODEL_PACKAGE_VALIDATION_ROLE_PATH: + return "arn:aws:iam::111111111111:role/ConfigRole" + elif key is MODEL_PACKAGE_VALIDATION_PROFILES_PATH: + return [ + { + "TransformJobDefinition": { + "TransformOutput": {"KmsKeyId": "testKmsKeyId"}, + "TransformResources": {"VolumeKmsKeyId": "testVolumeKmsKeyId"}, + } + } + ] + return default_value + + +def test_create_model_package_with_configs(sagemaker_session): + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=_sagemaker_config_override_mock_for_model_package, + ) + model_package_name = "sagemaker-model-package" + containers = ["dummy-container"] + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarget"] + model_metrics = { + "Bias": { + "ContentType": "content-type", + "S3Uri": "s3://...", + } + } + drift_check_baselines = { + "Bias": { + "ConfigFile": { + "ContentType": "content-type", + "S3Uri": "s3://...", + } + } + } + validation_profiles = [ + {"TransformJobDefinition": {"TransformOutput": {"S3OutputPath": "s3://test"}}} + ] + validation_specification = {"ValidationProfiles": validation_profiles} + + metadata_properties = { + "CommitId": "test-commit-id", + "Repository": "test-repository", + "GeneratedBy": "sagemaker-python-sdk", + "ProjectId": "unit-test", + } + marketplace_cert = (True,) + approval_status = ("Approved",) + description = "description" + customer_metadata_properties = {"key1": "value1"} + domain = "COMPUTER_VISION" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + 
sagemaker_session.create_model_package_from_containers( + containers=containers, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + model_package_name=model_package_name, + model_metrics=model_metrics, + metadata_properties=metadata_properties, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + description=description, + drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, + domain=domain, + sample_payload_url=sample_payload_url, + task=task, + validation_specification=validation_specification, + ) + expected_args = { + "ModelPackageName": model_package_name, + "InferenceSpecification": { + "Containers": containers, + "SupportedContentTypes": content_types, + "SupportedResponseMIMETypes": response_types, + "SupportedRealtimeInferenceInstanceTypes": inference_instances, + "SupportedTransformInstanceTypes": transform_instances, + }, + "ModelPackageDescription": description, + "ModelMetrics": model_metrics, + "MetadataProperties": metadata_properties, + "CertifyForMarketplace": marketplace_cert, + "ModelApprovalStatus": approval_status, + "DriftCheckBaselines": drift_check_baselines, + "CustomerMetadataProperties": customer_metadata_properties, + "Domain": domain, + "SamplePayloadUrl": sample_payload_url, + "Task": task, + "ValidationSpecification": validation_specification, + } + expected_args["ValidationSpecification"][ + "ValidationRole" + ] = "arn:aws:iam::111111111111:role/ConfigRole" + expected_args["ValidationSpecification"]["ValidationProfiles"][0]["TransformJobDefinition"][ + "TransformResources" + ] = {"VolumeKmsKeyId": "testVolumeKmsKeyId"} + expected_args["ValidationSpecification"]["ValidationProfiles"][0]["TransformJobDefinition"][ + "TransformOutput" + ]["KmsKeyId"] = "testKmsKeyId" + sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) + + 
def test_create_model_package_from_containers_all_args(sagemaker_session): model_package_name = "sagemaker-model-package" containers = ["dummy-container"] @@ -2957,6 +3693,47 @@ def feature_group_dummy_definitions(): return [{"FeatureName": "feature1", "FeatureType": "String"}] +def _sagemaker_config_override_mock_for_feature_store(key, default_value=None): + from sagemaker.session import ( + FEATURE_GROUP_ROLE_ARN_PATH, + FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, + FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, + ) + + if key is FEATURE_GROUP_ROLE_ARN_PATH: + return "config_role" + elif key is FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH: + return {"S3StorageConfig": {"KmsKeyId": "testKmsId"}} + elif key is FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH: + return {"SecurityConfig": {"KmsKeyId": "testKmsId2"}} + return default_value + + +def test_feature_group_create_with_config_injections( + sagemaker_session, feature_group_dummy_definitions +): + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=_sagemaker_config_override_mock_for_feature_store, + ) + sagemaker_session.create_feature_group( + feature_group_name="MyFeatureGroup", + record_identifier_name="feature1", + event_time_feature_name="feature2", + feature_definitions=feature_group_dummy_definitions, + offline_store_config={"S3StorageConfig": {"S3Uri": "s3://test"}}, + ) + assert sagemaker_session.sagemaker_client.create_feature_group.called_with( + FeatureGroupName="MyFeatureGroup", + RecordIdentifierFeatureName="feature1", + EventTimeFeatureName="feature2", + FeatureDefinitions=feature_group_dummy_definitions, + RoleArn="config_role", + OnlineStoreConfig={"SecurityConfig": {"KmsKeyId": "testKmsId2"}, "EnableOnlineStore": True}, + OfflineStoreConfig={"S3StorageConfig": {"KmsKeyId": "testKmsId", "S3Uri": "s3://test"}}, + ) + + def test_feature_group_create(sagemaker_session, feature_group_dummy_definitions): sagemaker_session.create_feature_group( 
feature_group_name="MyFeatureGroup", @@ -3711,7 +4488,7 @@ def test_append_sagemaker_config_tags(sagemaker_session): def sort(tags): return tags.sort(key=lambda tag: tag["Key"]) - sagemaker_session._get_sagemaker_config_value = MagicMock( + sagemaker_session.get_sagemaker_config_override = MagicMock( return_value=[ {"Key": "tagkey1", "Value": "tagvalue1"}, {"Key": "tagkey2", "Value": "tagvalue2"}, @@ -3759,7 +4536,7 @@ def sort(tags): ] ) - sagemaker_session._get_sagemaker_config_value = MagicMock(return_value=tags_none) + sagemaker_session.get_sagemaker_config_override = MagicMock(return_value=tags_none) config_tags_none = sagemaker_session._append_sagemaker_config_tags( tags_base, "DUMMY.CONFIG.PATH" ) @@ -3770,7 +4547,7 @@ def sort(tags): ] ) - sagemaker_session._get_sagemaker_config_value = MagicMock(return_value=tags_empty) + sagemaker_session.get_sagemaker_config_override = MagicMock(return_value=tags_empty) config_tags_empty = sagemaker_session._append_sagemaker_config_tags( tags_base, "DUMMY.CONFIG.PATH" ) @@ -3787,56 +4564,56 @@ def test_resolve_value_from_config(sagemaker_session_without_mocked_sagemaker_co ss = sagemaker_session_without_mocked_sagemaker_config # direct_input should be respected - ss._get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") + ss.get_sagemaker_config_override = MagicMock(return_value="CONFIG_VALUE") assert ss.resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", "DEFAULT_VALUE") == "INPUT" - ss._get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") + ss.get_sagemaker_config_override = MagicMock(return_value="CONFIG_VALUE") assert ss.resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", None) == "INPUT" - ss._get_sagemaker_config_value = MagicMock(return_value=None) + ss.get_sagemaker_config_override = MagicMock(return_value=None) assert ss.resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", None) == "INPUT" # Config or default values should be returned if no direct_input - 
ss._get_sagemaker_config_value = MagicMock(return_value=None) + ss.get_sagemaker_config_override = MagicMock(return_value=None) assert ss.resolve_value_from_config(None, None, "DEFAULT_VALUE") == "DEFAULT_VALUE" - ss._get_sagemaker_config_value = MagicMock(return_value=None) + ss.get_sagemaker_config_override = MagicMock(return_value=None) assert ( ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", "DEFAULT_VALUE") == "DEFAULT_VALUE" ) - ss._get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") + ss.get_sagemaker_config_override = MagicMock(return_value="CONFIG_VALUE") assert ( ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", "DEFAULT_VALUE") == "CONFIG_VALUE" ) - ss._get_sagemaker_config_value = MagicMock(return_value=None) + ss.get_sagemaker_config_override = MagicMock(return_value=None) assert ss.resolve_value_from_config(None, None, None) is None # Different falsy direct_inputs - ss._get_sagemaker_config_value = MagicMock(return_value=None) + ss.get_sagemaker_config_override = MagicMock(return_value=None) assert ss.resolve_value_from_config("", "DUMMY.CONFIG.PATH", None) == "" - ss._get_sagemaker_config_value = MagicMock(return_value=None) + ss.get_sagemaker_config_override = MagicMock(return_value=None) assert ss.resolve_value_from_config([], "DUMMY.CONFIG.PATH", None) == [] - ss._get_sagemaker_config_value = MagicMock(return_value=None) + ss.get_sagemaker_config_override = MagicMock(return_value=None) assert ss.resolve_value_from_config(False, "DUMMY.CONFIG.PATH", None) is False - ss._get_sagemaker_config_value = MagicMock(return_value=None) + ss.get_sagemaker_config_override = MagicMock(return_value=None) assert ss.resolve_value_from_config({}, "DUMMY.CONFIG.PATH", None) == {} # Different falsy config_values - ss._get_sagemaker_config_value = MagicMock(return_value="") + ss.get_sagemaker_config_override = MagicMock(return_value="") assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) == "" - 
ss._get_sagemaker_config_value = MagicMock(return_value=[]) + ss.get_sagemaker_config_override = MagicMock(return_value=[]) assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) == [] - ss._get_sagemaker_config_value = MagicMock(return_value=False) + ss.get_sagemaker_config_override = MagicMock(return_value=False) assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) is False - ss._get_sagemaker_config_value = MagicMock(return_value={}) + ss.get_sagemaker_config_override = MagicMock(return_value={}) assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) == {} @@ -3871,7 +4648,7 @@ def __eq__(self, other): dummy_config_path = ["DUMMY", "CONFIG", "PATH"] # with an existing config value - ss._get_sagemaker_config_value = MagicMock(return_value=config_value) + ss.get_sagemaker_config_override = MagicMock(return_value=config_value) # instance exists and has value; config has value test_instance = TestClass(test_attribute=existing_value, extra="EXTRA_VALUE") @@ -3917,7 +4694,7 @@ def __eq__(self, other): ) # without an existing config value - ss._get_sagemaker_config_value = MagicMock(return_value=None) + ss.get_sagemaker_config_override = MagicMock(return_value=None) # instance exists but doesnt have value; config doesnt have value test_instance = TestClass(extra="EXTRA_VALUE") @@ -3954,7 +4731,7 @@ def test_resolve_nested_dict_value_from_config(sagemaker_session_without_mocked_ dummy_config_path = ["DUMMY", "CONFIG", "PATH"] # with an existing config value - ss._get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") + ss.get_sagemaker_config_override = MagicMock(return_value="CONFIG_VALUE") # happy cases: return existing dict with existing values assert ss.resolve_nested_dict_value_from_config( @@ -4026,7 +4803,7 @@ def test_resolve_nested_dict_value_from_config(sagemaker_session_without_mocked_ ) # without an existing config value - ss._get_sagemaker_config_value = MagicMock(return_value=None) + 
ss.get_sagemaker_config_override = MagicMock(return_value=None) # happy case: return dict with default_value when it wasnt set in dict and in config assert ss.resolve_nested_dict_value_from_config( diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index 0d3dff2159..e9c79bf152 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -86,6 +86,10 @@ def sagemaker_session(): else default_value, ) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return session diff --git a/tests/unit/test_sparkml_serving.py b/tests/unit/test_sparkml_serving.py index 3fb21d62d2..e6415aada8 100644 --- a/tests/unit/test_sparkml_serving.py +++ b/tests/unit/test_sparkml_serving.py @@ -46,6 +46,10 @@ def sagemaker_session(): sms.boto_region_name = REGION sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git a/tests/unit/test_timeout.py b/tests/unit/test_timeout.py index bded6ce2cc..ee8bb5b138 100644 --- a/tests/unit/test_timeout.py +++ b/tests/unit/test_timeout.py @@ -60,6 +60,10 @@ def session(): settings=SessionSettings(), ) sms.default_bucket = Mock(name=DEFAULT_BUCKET_NAME, return_value=BUCKET_NAME) + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index 10c9123fe9..133d358b32 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -18,6 +18,11 @@ from sagemaker.transformer import _TransformJob, Transformer from sagemaker.workflow.pipeline_context import 
PipelineSession, _PipelineConfig from sagemaker.inputs import BatchDataCaptureConfig +from sagemaker.session import ( + TRANSFORM_JOB_KMS_KEY_ID_PATH, + TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, + TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH, +) from tests.integ import test_local_mode ROLE = "DummyRole" @@ -63,7 +68,12 @@ def mock_create_tar_file(): @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session") - return Mock(name="sagemaker_session", boto_session=boto_mock, local_mode=False) + session = Mock(name="sagemaker_session", boto_session=boto_mock, local_mode=False) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) + return session @pytest.fixture() @@ -97,6 +107,81 @@ def transformer(sagemaker_session): ) +def _config_override_mock(key, default_value=None): + if key == TRANSFORM_OUTPUT_KMS_KEY_ID_PATH: + return "ConfigKmsKeyId" + elif key == TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH: + return "ConfigVolumeKmsKeyId" + elif key == TRANSFORM_JOB_KMS_KEY_ID_PATH: + return "DataCaptureConfigKmsKeyId" + return default_value + + +@patch("sagemaker.transformer._TransformJob.start_new") +def test_transform_with_config_injection(start_new_job, sagemaker_session): + sagemaker_session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", side_effect=_config_override_mock + ) + transformer = Transformer( + MODEL_NAME, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path=OUTPUT_PATH, + sagemaker_session=sagemaker_session, + ) + assert transformer.volume_kms_key == "ConfigVolumeKmsKeyId" + assert transformer.output_kms_key == "ConfigKmsKeyId" + + content_type = "text/csv" + compression = "Gzip" + split = "Line" + input_filter = "$.feature" + output_filter = "$['sagemaker_output', 'id']" + join_source = "Input" + experiment_config = { + "ExperimentName": "exp", + "TrialName": "t", + "TrialComponentDisplayName": "tc", + } + 
model_client_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2} + batch_data_capture_config = BatchDataCaptureConfig( + destination_s3_uri=OUTPUT_PATH, generate_inference_id=False + ) + transformer.transform( + DATA, + S3_DATA_TYPE, + content_type=content_type, + compression_type=compression, + split_type=split, + job_name=JOB_NAME, + input_filter=input_filter, + output_filter=output_filter, + join_source=join_source, + experiment_config=experiment_config, + model_client_config=model_client_config, + batch_data_capture_config=batch_data_capture_config, + ) + + assert transformer._current_job_name == JOB_NAME + assert transformer.output_path == OUTPUT_PATH + start_new_job.assert_called_once_with( + transformer, + DATA, + S3_DATA_TYPE, + content_type, + compression, + split, + input_filter, + output_filter, + join_source, + experiment_config, + model_client_config, + batch_data_capture_config, + ) + # KmsKeyId in BatchDataCapture will be inserted from the config + assert batch_data_capture_config.kms_key_id == "DataCaptureConfigKmsKeyId" + + def test_delete_model(sagemaker_session): transformer = Transformer( MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index b530de1512..8079015e4b 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -74,6 +74,10 @@ def sagemaker_session(): else default_value, ) + sms.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, default_value=None: default_value, + ) return sms diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 2c35ad8584..ebec275d8c 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -88,6 +88,10 @@ def sagemaker_session(): if direct_input is not None else default_value, ) + session.get_sagemaker_config_override = Mock( + name="get_sagemaker_config_override", + side_effect=lambda key, 
default_value=None: default_value, + ) return session From 669e5a663aae143509557f679dd6cddf2a3e61ca Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Thu, 23 Feb 2023 15:54:25 -0700 Subject: [PATCH 05/40] fix: Make Key, Value as required fields for each "Tags" entry in the config file. --- src/sagemaker/config/config_schema.py | 1 + .../sagemaker/config/test_config_schema.py | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index a72c41f68c..730df320c8 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -192,6 +192,7 @@ "maxLength": 256, }, }, + "required": [KEY, VALUE], }, "minItems": 0, "maxItems": 50, diff --git a/tests/unit/sagemaker/config/test_config_schema.py b/tests/unit/sagemaker/config/test_config_schema.py index b5b0c3bd89..e008d670e9 100644 --- a/tests/unit/sagemaker/config/test_config_schema.py +++ b/tests/unit/sagemaker/config/test_config_schema.py @@ -95,6 +95,26 @@ def test_valid_monitoring_schedule_schema( ) +def test_tags_with_invalid_schema(base_config_with_schema, valid_edge_packaging_config): + edge_packaging_config = valid_edge_packaging_config.copy() + edge_packaging_config["Tags"] = [{"Key": "somekey"}] + config = base_config_with_schema + config["SageMaker"] = {"EdgePackagingJob": edge_packaging_config} + with pytest.raises(exceptions.ValidationError): + validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + edge_packaging_config["Tags"] = [{"Value": "somekey"}] + with pytest.raises(exceptions.ValidationError): + validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + +def test_tags_with_valid_schema(base_config_with_schema, valid_edge_packaging_config): + edge_packaging_config = valid_edge_packaging_config.copy() + edge_packaging_config["Tags"] = [{"Key": "somekey", "Value": "somevalue"}] + config = base_config_with_schema + config["SageMaker"] = {"EdgePackagingJob": edge_packaging_config} + 
validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + def test_invalid_training_job_schema(base_config_with_schema, valid_iam_role_arn, valid_vpc_config): # Changing key names training_job_config = { From 29c972816e46c28cd6dbd97fa6c155f9ad0012da Mon Sep 17 00:00:00 2001 From: Balaji Sankar <115105204+balajisankar15@users.noreply.github.com> Date: Fri, 24 Feb 2023 12:15:42 -0700 Subject: [PATCH 06/40] fix: Make 'role' as Optional for ModelQualityMonitor and DefaultModelMonitor, and fixed PROCESSING_CONFIG_PATH (#849) Co-authored-by: Balaji Sankar --- src/sagemaker/model_monitor/clarify_model_monitoring.py | 2 +- src/sagemaker/model_monitor/model_monitoring.py | 4 ++-- src/sagemaker/session.py | 4 +--- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/model_monitor/clarify_model_monitoring.py b/src/sagemaker/model_monitor/clarify_model_monitoring.py index 030de7c6db..82334e2a1d 100644 --- a/src/sagemaker/model_monitor/clarify_model_monitoring.py +++ b/src/sagemaker/model_monitor/clarify_model_monitoring.py @@ -40,7 +40,7 @@ class ClarifyModelMonitor(mm.ModelMonitor): def __init__( self, - role, + role=None, instance_count=1, instance_type="ml.m5.xlarge", volume_size_in_gb=30, diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 407a1a5cb9..7dacea23c5 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -1565,7 +1565,7 @@ class DefaultModelMonitor(ModelMonitor): def __init__( self, - role, + role=None, instance_count=1, instance_type="ml.m5.xlarge", volume_size_in_gb=30, @@ -2632,7 +2632,7 @@ class ModelQualityMonitor(ModelMonitor): def __init__( self, - role, + role=None, instance_count=1, instance_type="ml.m5.xlarge", volume_size_in_gb=30, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index ba2db269e9..6b7958995b 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -242,9 
+242,7 @@ def _simple_path(*args: str): PROCESSING_JOB_SECURITY_GROUP_IDS_PATH = _simple_path( PROCESSING_JOB_VPC_CONFIG_PATH, SECURITY_GROUP_IDS ) -PROCESSING_OUTPUT_CONFIG_PATH = _simple_path( - SAGEMAKER, PROCESSING_JOB, PROCESSING_OUTPUT_CONFIG, KMS_KEY_ID -) +PROCESSING_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, PROCESSING_OUTPUT_CONFIG) PROCESSING_JOB_KMS_KEY_ID_PATH = _simple_path( SAGEMAKER, PROCESSING_JOB, PROCESSING_OUTPUT_CONFIG, KMS_KEY_ID ) From 09cc5303a3bffbb54e00b7521e1f769678e7cbff Mon Sep 17 00:00:00 2001 From: Balaji Sankar <115105204+balajisankar15@users.noreply.github.com> Date: Fri, 24 Feb 2023 13:29:48 -0700 Subject: [PATCH 07/40] Fix: Certain unit tests aren't passing sagemaker_session. Modify the logic to accommodate that case (#850) Co-authored-by: Balaji Sankar --- src/sagemaker/pipeline.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 5d40a59b39..1e3b10327e 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -91,15 +91,29 @@ def __init__( self.name = name self.sagemaker_session = sagemaker_session self.endpoint_name = None - self.role = self.sagemaker_session.get_sagemaker_config_override( - MODEL_EXECUTION_ROLE_ARN_PATH, default_value=role + self.role = ( + self.sagemaker_session.get_sagemaker_config_override( + MODEL_EXECUTION_ROLE_ARN_PATH, default_value=role + ) + if sagemaker_session + else role + ) + self.vpc_config = ( + self.sagemaker_session.get_sagemaker_config_override( + MODEL_VPC_CONFIG_PATH, default_value=vpc_config + ) + if sagemaker_session + else vpc_config ) - self.vpc_config = self.sagemaker_session.get_sagemaker_config_override( - MODEL_VPC_CONFIG_PATH, default_value=vpc_config + default_enable_network_isolation = ( + False if enable_network_isolation is None else enable_network_isolation ) - self.enable_network_isolation = 
self.sagemaker_session.get_sagemaker_config_override( - MODEL_ENABLE_NETWORK_ISOLATION_PATH, - default_value=False if enable_network_isolation is None else enable_network_isolation, + self.enable_network_isolation = ( + self.sagemaker_session.get_sagemaker_config_override( + MODEL_ENABLE_NETWORK_ISOLATION_PATH, default_value=default_enable_network_isolation + ) + if sagemaker_session + else default_enable_network_isolation ) if not self.role: # Originally IAM role was a required parameter. From ce90edac1aa95b70eb59f11ae3504bbe51679efb Mon Sep 17 00:00:00 2001 From: Ruban Hussain Date: Mon, 27 Feb 2023 11:31:45 -0800 Subject: [PATCH 08/40] fix: Sagemaker Config - KeyError: 'MonitoringJobDefinition' in model_monitoring --- .../model_monitor/model_monitoring.py | 30 ++++--------------- 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 7dacea23c5..6dc474b1b8 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -1454,18 +1454,9 @@ def _create_monitoring_schedule_from_job_definition( self.tags, "{}.{}.{}".format(SAGEMAKER, MONITORING_SCHEDULE, TAGS) ) - _enable_inter_container_traffic_encryption_from_config = ( - self.sagemaker_session.resolve_value_from_config( - config_path=PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION - ) - ) - if _enable_inter_container_traffic_encryption_from_config is not None: - # Not checking 'self.network_config' for 'enable_network_isolation' because that - # wasnt used here before this config value was set. Unclear whether there was a - # specific reason for that omission. 
- monitoring_schedule_config["MonitoringJobDefinition"]["NetworkConfig"][ - "EnableInterContainerTrafficEncryption" - ] = _enable_inter_container_traffic_encryption_from_config + # Not using value from sagemaker + # config key PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION here + # because no MonitoringJobDefinition is set for this call self.sagemaker_session.sagemaker_client.create_monitoring_schedule( MonitoringScheduleName=monitor_schedule_name, @@ -1533,18 +1524,9 @@ def _update_monitoring_schedule(self, job_definition_name, schedule_cron_express "ScheduleExpression": schedule_cron_expression } - _enable_inter_container_traffic_encryption_from_config = ( - self.sagemaker_session.resolve_value_from_config( - config_path=PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION - ) - ) - if _enable_inter_container_traffic_encryption_from_config is not None: - # Not checking 'self.network_config' for 'enable_network_isolation' because that - # wasnt used here before this config value was checked. Unclear whether there was a - # specific reason for that omission. 
- monitoring_schedule_config["MonitoringJobDefinition"]["NetworkConfig"][ - "EnableInterContainerTrafficEncryption" - ] = _enable_inter_container_traffic_encryption_from_config + # Not using value from sagemaker + # config key PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION here + # because no MonitoringJobDefinition is set for this call self.sagemaker_session.sagemaker_client.update_monitoring_schedule( MonitoringScheduleName=self.monitoring_schedule_name, From 24b681f9ba426d241a81e138d2ca69e1f874e52a Mon Sep 17 00:00:00 2001 From: Ruban Hussain Date: Mon, 27 Feb 2023 12:38:47 -0800 Subject: [PATCH 09/40] change: Sagemaker Config - improved readability of print statements and simplified its code --- src/sagemaker/session.py | 94 +++++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 50 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 6b7958995b..5957462205 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -651,11 +651,12 @@ def get_sagemaker_config_override(self, key, default_value=None): object: The corresponding value in the Config file/ the default value. 
""" + config_value = get_config_value(key, self.sagemaker_config.config) + self._print_message_on_sagemaker_config_usage(default_value, config_value, key) + if default_value is not None: return default_value - config_value = get_config_value(key, self.sagemaker_config.config) - if config_value is not None: - self._print_message_sagemaker_config_used(config_value, key) + return config_value def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): @@ -726,29 +727,31 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): else: raise - def _print_message_sagemaker_config_used(self, config_value, config_path): - """Informs the SDK user that a config value was substituted in automatically""" - print( - "[Sagemaker Config] config value {} at config path {}".format( - config_value, config_path - ), - "was automatically applied", - ) + def _print_message_on_sagemaker_config_usage(self, direct_input, config_value, config_path): + """Informs the SDK user whether a config value was present and automatically substituted""" - def _print_message_sagemaker_config_present_but_not_used( - self, direct_input, config_value, config_path - ): - """Informs the SDK user that a config value was not substituted in automatically. + if config_value is not None: - This is because method parameter is already provided. 
+ if direct_input is not None and config_value != direct_input: + # Sagemaker Config had a value defined that is NOT going to be used + # and the config value has not already been applied earlier + print( + "[Sagemaker Config - skipped value]\n", + "config key = {}\n".format(config_path), + "config value = {}\n".format(config_value), + "specified value that will be used = {}\n".format(direct_input), + ) - """ - print( - "[Sagemaker Config] value {} was specified,".format(direct_input), - "so config value {} at config path {} was not applied".format( - config_value, config_path - ), - ) + elif direct_input is None: + # Sagemaker Config value is going to be used + print( + "[Sagemaker Config - applied value]\n", + "config key = {}\n".format(config_path), + "config value that will be used = {}\n".format(config_value), + ) + + # There is no print statement needed if nothing was specified in the config and nothing is + # being automatically applied def resolve_value_from_config( self, direct_input=None, config_path: str = None, default_value=None @@ -771,18 +774,12 @@ def resolve_value_from_config( The value that should be used by the caller """ config_value = self.get_sagemaker_config_override(config_path) + self._print_message_on_sagemaker_config_usage(direct_input, config_value, config_path) if direct_input is not None: - if config_value is not None: - self._print_message_sagemaker_config_present_but_not_used( - direct_input, config_value, config_path - ) - # No print statement if there was nothing in the config, because nothing is - # being overridden return direct_input if config_value is not None: - self._print_message_sagemaker_config_used(config_value, config_path) return config_value return default_value @@ -839,13 +836,10 @@ def resolve_class_attribute_from_config( # only set value if object does not already have a value set if config_value is not None: setattr(instance, attribute, config_value) - self._print_message_sagemaker_config_used(config_value, 
config_path) elif default_value is not None: setattr(instance, attribute, default_value) - elif current_value is not None and config_value is not None: - self._print_message_sagemaker_config_present_but_not_used( - current_value, config_value, config_path - ) + + self._print_message_on_sagemaker_config_usage(current_value, config_value, config_path) return instance @@ -892,13 +886,12 @@ def resolve_nested_dict_value_from_config( # only set value if not already set if config_value is not None: dictionary = set_nested_value(dictionary, nested_keys, config_value) - self._print_message_sagemaker_config_used(config_value, config_path) elif default_value is not None: dictionary = set_nested_value(dictionary, nested_keys, default_value) - elif current_nested_value is not None and config_value is not None: - self._print_message_sagemaker_config_present_but_not_used( - current_nested_value, config_value, config_path - ) + + self._print_message_on_sagemaker_config_usage( + current_nested_value, config_value, config_path + ) return dictionary @@ -933,9 +926,11 @@ def _append_sagemaker_config_tags(self, tags: list, config_path_to_tags: str): all_tags.append(config_tag) print( - "Appended tags from sagemaker_config to input.\n\texisting tags: {},".format(tags) - + "\n\ttags provided via sagemaker_config: {},".format(config_tags) - + "\n\tcombined tags: {}".format(all_tags) + "[Sagemaker Config - applied value]\n", + "config key = {}\n".format(config_path_to_tags), + "config value = {}\n".format(config_tags), + "source value = {}\n".format(tags), + "combined value that will be used = {}\n".format(all_tags), ) return all_tags @@ -5593,17 +5588,16 @@ def _update_nested_dictionary_with_values_from_config( merge_dicts(inferred_config_dict, source_dict or {}) if source_dict == inferred_config_dict: # Corresponds to the case where we didn't use any values from Config. 
- self._print_message_sagemaker_config_present_but_not_used( + self._print_message_on_sagemaker_config_usage( source_dict, original_config_dict_value, config_key_path ) else: print( - "Config value {} at config path {} was fetched first.".format( - original_config_dict_value, config_key_path - ), - "It was then merged with the existing value {} to give {}".format( - source_dict, inferred_config_dict - ), + "[Sagemaker Config - applied value]\n", + "config key = {}\n".format(config_key_path), + "config value = {}\n".format(original_config_dict_value), + "source value = {}\n".format(source_dict), + "combined value that will be used = {}\n".format(inferred_config_dict), ) return inferred_config_dict From 6b6e44c1983d9c47e0746c9e1ac39d0bdf3ffdd5 Mon Sep 17 00:00:00 2001 From: Ruban Hussain Date: Mon, 27 Feb 2023 21:44:53 -0800 Subject: [PATCH 10/40] fix: Sagemaker Config - Reduce duplicate and misleading config-related print statements --- src/sagemaker/session.py | 22 +++++++++++---- tests/unit/test_session.py | 58 +++++++++++++++++++++++--------------- 2 files changed, 53 insertions(+), 27 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 5957462205..6f5d13042b 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -637,6 +637,18 @@ def default_bucket(self): return self._default_bucket + def get_sagemaker_config_value(self, key): + """Util method that fetches a particular key path in the SageMakerConfig and returns it. + + Args: + key: Key Path of the config file entry. + + Returns: + object: The corresponding value in the Config file/ the default value. + """ + config_value = get_config_value(key, self.sagemaker_config.config) + return config_value + def get_sagemaker_config_override(self, key, default_value=None): """Util method that fetches a particular key path in the SageMakerConfig and returns it. 
@@ -773,7 +785,7 @@ def resolve_value_from_config( Returns: The value that should be used by the caller """ - config_value = self.get_sagemaker_config_override(config_path) + config_value = self.get_sagemaker_config_value(config_path) self._print_message_on_sagemaker_config_usage(direct_input, config_value, config_path) if direct_input is not None: @@ -811,7 +823,7 @@ def resolve_class_attribute_from_config( The updated class instance that should be used by the caller instead of the 'instance' parameter that was passed in. """ - config_value = self.get_sagemaker_config_override(config_path) + config_value = self.get_sagemaker_config_value(config_path) if config_value is None and default_value is None: # return instance unmodified. Could be None or populated @@ -869,7 +881,7 @@ def resolve_nested_dict_value_from_config( The updated dictionary that should be used by the caller instead of the 'dictionary' parameter that was passed in. """ - config_value = self.get_sagemaker_config_override(config_path) + config_value = self.get_sagemaker_config_value(config_path) if config_value is None and default_value is None: # if there is nothing to set, return early. And there is no need to traverse through @@ -909,7 +921,7 @@ def _append_sagemaker_config_tags(self, tags: list, config_path_to_tags: str): Returns: A potentially extended list of tags. """ - config_tags = self.get_sagemaker_config_override(config_path_to_tags) + config_tags = self.get_sagemaker_config_value(config_path_to_tags) if config_tags is None or len(config_tags) == 0: return tags @@ -5583,7 +5595,7 @@ def _update_nested_dictionary_with_values_from_config( in the Config file. 
""" - inferred_config_dict = self.get_sagemaker_config_override(config_key_path) or {} + inferred_config_dict = self.get_sagemaker_config_value(config_key_path) or {} original_config_dict_value = inferred_config_dict.copy() merge_dicts(inferred_config_dict, source_dict or {}) if source_dict == inferred_config_dict: diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index aaaff31343..71f926ca3f 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -290,6 +290,8 @@ def test_create_process_with_configs(sagemaker_session): name="get_sagemaker_config_override", side_effect=_sagemaker_config_override_mock_for_process, ) + sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override + processing_inputs = [ { "InputName": "input-1", @@ -1585,6 +1587,8 @@ def test_train_with_configs(sagemaker_session): name="get_sagemaker_config_override", side_effect=_sagemaker_config_override_mock_for_train, ) + sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override + in_config = [ { "ChannelName": "training", @@ -2168,6 +2172,8 @@ def test_create_model_with_configs(expand_container_def, sagemaker_session): name="get_sagemaker_config_override", side_effect=_sagemaker_config_override_mock_for_model, ) + sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override + sagemaker_session.expand_role = Mock( name="expand_role", side_effect=lambda role_name: role_name ) @@ -2371,6 +2377,7 @@ def test_create_edge_packaging_with_configs(sagemaker_session): name="get_sagemaker_config_override", side_effect=_sagemaker_config_override_mock_for_edge_packaging, ) + sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override output_config = {"S3OutputLocation": S3_OUTPUT} @@ -2417,6 +2424,7 @@ def test_create_monitoring_schedule_with_configs(sagemaker_session): name="get_sagemaker_config_override", 
side_effect=_sagemaker_config_override_mock_for_monitoring_schedule, ) + sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override monitoring_output_config = {"MonitoringOutputs": [{"S3Output": {"S3Uri": S3_OUTPUT}}]} @@ -2487,6 +2495,8 @@ def test_compile_with_configs(sagemaker_session): name="get_sagemaker_config_override", side_effect=_sagemaker_config_override_mock_for_compile, ) + sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override + sagemaker_session.compile_model( input_model_config={}, output_model_config={"S3OutputLocation": "s3://test"}, @@ -2585,6 +2595,8 @@ def test_create_enpoint_config_with_configs(sagemaker_session): name="get_sagemaker_config_override", side_effect=_sagemaker_config_override_mock_for_endpoint_config, ) + sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override + data_capture_config_dict = {"DestinationS3Uri": "s3://test"} tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}] @@ -3135,6 +3147,8 @@ def test_create_auto_ml_with_configs(sagemaker_session): name="get_sagemaker_config_override", side_effect=_sagemaker_config_override_mock_for_auto_ml, ) + sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override + input_config = [ { "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": S3_INPUT_URI}}, @@ -4488,7 +4502,7 @@ def test_append_sagemaker_config_tags(sagemaker_session): def sort(tags): return tags.sort(key=lambda tag: tag["Key"]) - sagemaker_session.get_sagemaker_config_override = MagicMock( + sagemaker_session.get_sagemaker_config_value = MagicMock( return_value=[ {"Key": "tagkey1", "Value": "tagvalue1"}, {"Key": "tagkey2", "Value": "tagvalue2"}, @@ -4536,7 +4550,7 @@ def sort(tags): ] ) - sagemaker_session.get_sagemaker_config_override = MagicMock(return_value=tags_none) + sagemaker_session.get_sagemaker_config_value = MagicMock(return_value=tags_none) 
config_tags_none = sagemaker_session._append_sagemaker_config_tags( tags_base, "DUMMY.CONFIG.PATH" ) @@ -4547,7 +4561,7 @@ def sort(tags): ] ) - sagemaker_session.get_sagemaker_config_override = MagicMock(return_value=tags_empty) + sagemaker_session.get_sagemaker_config_value = MagicMock(return_value=tags_empty) config_tags_empty = sagemaker_session._append_sagemaker_config_tags( tags_base, "DUMMY.CONFIG.PATH" ) @@ -4564,56 +4578,56 @@ def test_resolve_value_from_config(sagemaker_session_without_mocked_sagemaker_co ss = sagemaker_session_without_mocked_sagemaker_config # direct_input should be respected - ss.get_sagemaker_config_override = MagicMock(return_value="CONFIG_VALUE") + ss.get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") assert ss.resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", "DEFAULT_VALUE") == "INPUT" - ss.get_sagemaker_config_override = MagicMock(return_value="CONFIG_VALUE") + ss.get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") assert ss.resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", None) == "INPUT" - ss.get_sagemaker_config_override = MagicMock(return_value=None) + ss.get_sagemaker_config_value = MagicMock(return_value=None) assert ss.resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", None) == "INPUT" # Config or default values should be returned if no direct_input - ss.get_sagemaker_config_override = MagicMock(return_value=None) + ss.get_sagemaker_config_value = MagicMock(return_value=None) assert ss.resolve_value_from_config(None, None, "DEFAULT_VALUE") == "DEFAULT_VALUE" - ss.get_sagemaker_config_override = MagicMock(return_value=None) + ss.get_sagemaker_config_value = MagicMock(return_value=None) assert ( ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", "DEFAULT_VALUE") == "DEFAULT_VALUE" ) - ss.get_sagemaker_config_override = MagicMock(return_value="CONFIG_VALUE") + ss.get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") assert ( 
ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", "DEFAULT_VALUE") == "CONFIG_VALUE" ) - ss.get_sagemaker_config_override = MagicMock(return_value=None) + ss.get_sagemaker_config_value = MagicMock(return_value=None) assert ss.resolve_value_from_config(None, None, None) is None # Different falsy direct_inputs - ss.get_sagemaker_config_override = MagicMock(return_value=None) + ss.get_sagemaker_config_value = MagicMock(return_value=None) assert ss.resolve_value_from_config("", "DUMMY.CONFIG.PATH", None) == "" - ss.get_sagemaker_config_override = MagicMock(return_value=None) + ss.get_sagemaker_config_value = MagicMock(return_value=None) assert ss.resolve_value_from_config([], "DUMMY.CONFIG.PATH", None) == [] - ss.get_sagemaker_config_override = MagicMock(return_value=None) + ss.get_sagemaker_config_value = MagicMock(return_value=None) assert ss.resolve_value_from_config(False, "DUMMY.CONFIG.PATH", None) is False - ss.get_sagemaker_config_override = MagicMock(return_value=None) + ss.get_sagemaker_config_value = MagicMock(return_value=None) assert ss.resolve_value_from_config({}, "DUMMY.CONFIG.PATH", None) == {} # Different falsy config_values - ss.get_sagemaker_config_override = MagicMock(return_value="") + ss.get_sagemaker_config_value = MagicMock(return_value="") assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) == "" - ss.get_sagemaker_config_override = MagicMock(return_value=[]) + ss.get_sagemaker_config_value = MagicMock(return_value=[]) assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) == [] - ss.get_sagemaker_config_override = MagicMock(return_value=False) + ss.get_sagemaker_config_value = MagicMock(return_value=False) assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) is False - ss.get_sagemaker_config_override = MagicMock(return_value={}) + ss.get_sagemaker_config_value = MagicMock(return_value={}) assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) == {} @@ -4648,7 +4662,7 @@ def 
__eq__(self, other): dummy_config_path = ["DUMMY", "CONFIG", "PATH"] # with an existing config value - ss.get_sagemaker_config_override = MagicMock(return_value=config_value) + ss.get_sagemaker_config_value = MagicMock(return_value=config_value) # instance exists and has value; config has value test_instance = TestClass(test_attribute=existing_value, extra="EXTRA_VALUE") @@ -4694,7 +4708,7 @@ def __eq__(self, other): ) # without an existing config value - ss.get_sagemaker_config_override = MagicMock(return_value=None) + ss.get_sagemaker_config_value = MagicMock(return_value=None) # instance exists but doesnt have value; config doesnt have value test_instance = TestClass(extra="EXTRA_VALUE") @@ -4731,7 +4745,7 @@ def test_resolve_nested_dict_value_from_config(sagemaker_session_without_mocked_ dummy_config_path = ["DUMMY", "CONFIG", "PATH"] # with an existing config value - ss.get_sagemaker_config_override = MagicMock(return_value="CONFIG_VALUE") + ss.get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") # happy cases: return existing dict with existing values assert ss.resolve_nested_dict_value_from_config( @@ -4803,7 +4817,7 @@ def test_resolve_nested_dict_value_from_config(sagemaker_session_without_mocked_ ) # without an existing config value - ss.get_sagemaker_config_override = MagicMock(return_value=None) + ss.get_sagemaker_config_value = MagicMock(return_value=None) # happy case: return dict with default_value when it wasnt set in dict and in config assert ss.resolve_nested_dict_value_from_config( From cb66d7b8f5f4dfced061232cda39069704910b08 Mon Sep 17 00:00:00 2001 From: Ruban Hussain Date: Mon, 6 Mar 2023 12:10:22 -0800 Subject: [PATCH 11/40] fix: Sagemaker Config - add function description --- src/sagemaker/session.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 6f5d13042b..3d0c3a66b5 100644 --- a/src/sagemaker/session.py +++ 
b/src/sagemaker/session.py @@ -739,8 +739,23 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): else: raise - def _print_message_on_sagemaker_config_usage(self, direct_input, config_value, config_path): - """Informs the SDK user whether a config value was present and automatically substituted""" + def _print_message_on_sagemaker_config_usage( + self, direct_input, config_value, config_path: str + ): + """Informs the SDK user whether a config value was present and automatically substituted + + Args: + direct_input: the value that would be used if no sagemaker_config or default values + existed. Usually this will be user-provided input to a Class or to a + session.py method, or None if no input was provided. + config_value: the value fetched from sagemaker_config. This is usually the value that + will be used if direct_input is None. + config_path: a string denoting the path of keys that point to the config value in the + sagemaker_config + + Returns: + No output (just prints information) + """ if config_value is not None: From ceecae5220ab39256b0a47a731a51b884bb4e918 Mon Sep 17 00:00:00 2001 From: Ruban Hussain Date: Thu, 2 Mar 2023 13:52:30 -0800 Subject: [PATCH 12/40] fix: Sagemaker Config - Fix failing Integ tests, fix backwards incompatible behavior, and improved some unit tests --- src/sagemaker/session.py | 72 ++++++----- src/sagemaker/transformer.py | 24 ++-- src/sagemaker/workflow/pipeline_context.py | 5 + tests/integ/test_local_mode.py | 33 ++++- tests/unit/conftest.py | 45 +++++++ tests/unit/test_session.py | 133 ++++++++++----------- tests/unit/test_transformer.py | 34 +++++- 7 files changed, 222 insertions(+), 124 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 3d0c3a66b5..e2dcc09d4f 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -220,6 +220,10 @@ def _simple_path(*args: str): TRANSFORM_JOB_KMS_KEY_ID_PATH = _simple_path( SAGEMAKER, TRANSFORM_JOB, DATA_CAPTURE_CONFIG, 
KMS_KEY_ID ) +TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, TRANSFORM_JOB, TRANSFORM_RESOURCES, VOLUME_KMS_KEY_ID +) + MODEL_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, MODEL, VPC_CONFIG) MODEL_ENABLE_NETWORK_ISOLATION_PATH = _simple_path(SAGEMAKER, MODEL, ENABLE_NETWORK_ISOLATION) MODEL_EXECUTION_ROLE_ARN_PATH = _simple_path(SAGEMAKER, MODEL, EXECUTION_ROLE_ARN) @@ -812,7 +816,12 @@ def resolve_value_from_config( return default_value def resolve_class_attribute_from_config( - self, clazz, instance, attribute: str, config_path: str, default_value=None + self, + clazz: Optional[type], + instance: Optional[object], + attribute: str, + config_path: str, + default_value=None, ): """Utility method that merges config values to data classes. @@ -823,13 +832,14 @@ def resolve_class_attribute_from_config( (1) current value of attribute, (2) config value, (3) default_value, (4) does not set it Args: - clazz: Class of 'instance'. Used to generate a new instance if the instance is None. - It is advised for the constructor of a given Class to set default values to - None; otherwise the constructor's non-None default value will be used instead - of any config value - instance (str): instance of the Class 'clazz' that has an attribute - of 'attribute' to set - attribute: attribute of the instance to set if not already set + clazz (Optional[type]): Class of 'instance'. Used to generate a new instance if the + instance is None. If None is provided here, no new object will be created + if 'instance' doesnt exist. Note: if provided, the constructor should set default + values to None; Otherwise, the constructor's non-None default will be left + as-is even if a config value was defined. 
+ instance (Optional[object]): instance of the Class 'clazz' that has an attribute + of 'attribute' to set + attribute (str): attribute of the instance to set if not already set config_path (str): a string denoting the path to use to lookup the config value in the sagemaker config default_value: the value to use if not present elsewhere @@ -3545,23 +3555,15 @@ def transform( tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, TRANSFORM_JOB, TAGS) ) - if batch_data_capture_config: - if not batch_data_capture_config.kms_key_id: - batch_data_capture_config.kms_key_id = self.get_sagemaker_config_override( - TRANSFORM_JOB_KMS_KEY_ID_PATH - ) - if output_config: - kms_key_from_config = self.get_sagemaker_config_override( - TRANSFORM_OUTPUT_KMS_KEY_ID_PATH - ) - if KMS_KEY_ID not in output_config and kms_key_from_config: - output_config[KMS_KEY_ID] = kms_key_from_config - if resource_config: - volume_kms_key_from_config = self.get_sagemaker_config_override( - TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH - ) - if VOLUME_KMS_KEY_ID not in resource_config and volume_kms_key_from_config: - resource_config[VOLUME_KMS_KEY_ID] = volume_kms_key_from_config + batch_data_capture_config = self.resolve_class_attribute_from_config( + None, batch_data_capture_config, "kms_key_id", TRANSFORM_JOB_KMS_KEY_ID_PATH + ) + output_config = self.resolve_nested_dict_value_from_config( + output_config, [KMS_KEY_ID], TRANSFORM_OUTPUT_KMS_KEY_ID_PATH + ) + resource_config = self.resolve_nested_dict_value_from_config( + resource_config, [VOLUME_KMS_KEY_ID], TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH + ) transform_request = self._get_transform_request( job_name=job_name, @@ -5248,7 +5250,7 @@ def create_feature_group( inferred_online_store_from_config = self._update_nested_dictionary_with_values_from_config( online_store_config, FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH ) - if len(inferred_online_store_from_config) > 0: + if inferred_online_store_from_config is not None: # OnlineStore should 
be handled differently because if you set KmsKeyId, then you # need to set EnableOnlineStore key as well inferred_online_store_from_config["EnableOnlineStore"] = True @@ -5608,17 +5610,31 @@ def _update_nested_dictionary_with_values_from_config( Returns: dict: The merged nested dictionary which includes missings values that are present in the Config file. - """ inferred_config_dict = self.get_sagemaker_config_value(config_key_path) or {} original_config_dict_value = inferred_config_dict.copy() merge_dicts(inferred_config_dict, source_dict or {}) + + if original_config_dict_value == {}: + # The config value is empty. That means either + # (1) inferred_config_dict equals source_dict, or + # (2) if source_dict was None, inferred_config_dict equals {} + # We should return whatever source_dict was to be safe. Because if for example, + # a VpcConfig is set to {} instead of None, some boto calls will fail due to + # ParamValidationError (because a VpcConfig was specified but required parameters for + # the VpcConfig were missing.) + + # Dont need to print because no config value was used or defined + return source_dict + if source_dict == inferred_config_dict: - # Corresponds to the case where we didn't use any values from Config. 
+ # We didnt use any values from the config, but we should print if any of the config + # values were defined self._print_message_on_sagemaker_config_usage( source_dict, original_config_dict_value, config_key_path ) else: + # Something from the config was merged in print( "[Sagemaker Config - applied value]\n", "config key = {}\n".format(config_key_path), diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 41c158e935..804947475a 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -133,18 +133,6 @@ def __init__( TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, default_value=output_kms_key ) - def _update_batch_capture_config(self, batch_capture_config: BatchDataCaptureConfig): - """Utility method that updates BatchDataCaptureConfig with values from SageMakerConfig. - - Args: - batch_capture_config: The BatchDataCaptureConfig object. - - """ - if batch_capture_config: - batch_capture_config.kms_key_id = self.sagemaker_session.get_sagemaker_config_override( - TRANSFORM_JOB_KMS_KEY_ID_PATH, default_value=batch_capture_config.kms_key_id - ) - @runnable_by_pipeline def transform( self, @@ -279,7 +267,11 @@ def transform( self._reset_output_path = True experiment_config = check_and_get_run_experiment_config(experiment_config) - self._update_batch_capture_config(batch_data_capture_config) + + batch_data_capture_config = self.sagemaker_session.resolve_class_attribute_from_config( + None, batch_data_capture_config, "kms_key_id", TRANSFORM_JOB_KMS_KEY_ID_PATH + ) + self.latest_transform_job = _TransformJob.start_new( self, data, @@ -401,7 +393,11 @@ def transform_with_monitoring( transformer = copy.deepcopy(self) transformer.sagemaker_session = PipelineSession() self.sagemaker_session = sagemaker_session - self._update_batch_capture_config(batch_data_capture_config) + + batch_data_capture_config = self.sagemaker_session.resolve_class_attribute_from_config( + None, batch_data_capture_config, "kms_key_id", TRANSFORM_JOB_KMS_KEY_ID_PATH + 
) + transform_step_args = transformer.transform( data=data, data_type=data_type, diff --git a/src/sagemaker/workflow/pipeline_context.py b/src/sagemaker/workflow/pipeline_context.py index c196d1c15b..3dcb8905c4 100644 --- a/src/sagemaker/workflow/pipeline_context.py +++ b/src/sagemaker/workflow/pipeline_context.py @@ -18,6 +18,7 @@ from functools import wraps from typing import Dict, Optional, Callable +from sagemaker.config import SageMakerConfig from sagemaker.session import Session, SessionSettings from sagemaker.local import LocalSession @@ -112,6 +113,7 @@ def __init__( sagemaker_client=None, default_bucket=None, settings=SessionSettings(), + sagemaker_config: SageMakerConfig = None, ): """Initialize a ``PipelineSession``. @@ -131,12 +133,15 @@ def __init__( Example: "sagemaker-my-custom-bucket". settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional parameters to apply to the session. + sagemaker_config (sagemaker.config.SageMakerConfig): The SageMakerConfig object which + holds the default values for the SageMaker Python SDK. (default: None). 
""" super().__init__( boto_session=boto_session, sagemaker_client=sagemaker_client, default_bucket=default_bucket, settings=settings, + sagemaker_config=sagemaker_config, ) self._context = None diff --git a/tests/integ/test_local_mode.py b/tests/integ/test_local_mode.py index d1416299f9..f1a6488ab4 100644 --- a/tests/integ/test_local_mode.py +++ b/tests/integ/test_local_mode.py @@ -23,6 +23,7 @@ import stopit import tests.integ.lock as lock +from sagemaker.config import SageMakerConfig from tests.integ import DATA_DIR from mock import Mock, ANY @@ -58,7 +59,14 @@ class LocalNoS3Session(LocalSession): def __init__(self): super(LocalSession, self).__init__() - def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, **kwargs): + def _initialize( + self, + boto_session, + sagemaker_client, + sagemaker_runtime_client, + sagemaker_config: SageMakerConfig = None, + **kwargs + ): self.boto_session = boto3.Session(region_name=DEFAULT_REGION) if self.config is None: self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}} @@ -67,9 +75,11 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, self.sagemaker_client = LocalSagemakerClient(self) self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config) self.local_mode = True - self.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + + self.sagemaker_config = sagemaker_config or ( + SageMakerConfig() + if "sagemaker_config" not in kwargs + else kwargs.get("sagemaker_config") ) @@ -81,7 +91,14 @@ class LocalPipelineNoS3Session(LocalPipelineSession): def __init__(self): super(LocalPipelineSession, self).__init__() - def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, **kwargs): + def _initialize( + self, + boto_session, + sagemaker_client, + sagemaker_runtime_client, + sagemaker_config: SageMakerConfig = None, + **kwargs + ): 
self.boto_session = boto3.Session(region_name=DEFAULT_REGION) if self.config is None: self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}} @@ -91,6 +108,12 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config) self.local_mode = True + self.sagemaker_config = sagemaker_config or ( + SageMakerConfig() + if "sagemaker_config" not in kwargs + else kwargs.get("sagemaker_config") + ) + @pytest.fixture(scope="module") def sagemaker_local_session_no_local_code(boto_session): diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 916a8645f2..5d9d565582 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -13,10 +13,14 @@ from __future__ import absolute_import import pytest +from mock.mock import MagicMock + import sagemaker from mock import Mock, PropertyMock +from sagemaker.config import SageMakerConfig + _ROLE = "DummyRole" _REGION = "us-west-2" _DEFAULT_BUCKET = "my-bucket" @@ -69,3 +73,44 @@ def sagemaker_session(boto_session, client): side_effect=lambda key, default_value=None: default_value, ) return session + + +@pytest.fixture() +def sagemaker_config_session(): + """ + Returns: a sagemaker.Session to use for tests of injection of default parameters from the + sagemaker_config. + + This session has a custom SageMakerConfig that allows us to set the sagemaker_config.config + dict manually. This allows us to test in unit tests without tight coupling to the exact + sagemaker_config related helpers/utils/methods used. (And those helpers/utils/methods should + have their own separate and specific unit tests.) + + An alternative would be to mock each call to a sagemaker_config-related method, but that would + be harder to maintain/update over time, and be less readable. 
+ """ + + class SageMakerConfigWithSetter(SageMakerConfig): + """ + Version of SageMakerConfig that allows the config to be set + """ + + def __init__(self): + self._config = {} + # no need to call super + + @property + def config(self) -> dict: + return self._config + + @config.setter + def config(self, new_config): + self._config = new_config + + boto_mock = MagicMock(name="boto_session") + session_with_custom_sagemaker_config = sagemaker.Session( + boto_session=boto_mock, + sagemaker_client=MagicMock(), + sagemaker_config=SageMakerConfigWithSetter(), + ) + return session_with_custom_sagemaker_config diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 71f926ca3f..b69d6881aa 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -27,6 +27,16 @@ import sagemaker from sagemaker import TrainingInput, Session, get_execution_role, exceptions from sagemaker.async_inference import AsyncInferenceConfig +from sagemaker.config import ( + SAGEMAKER, + TRANSFORM_JOB, + DATA_CAPTURE_CONFIG, + TRANSFORM_OUTPUT, + TRANSFORM_RESOURCES, + KMS_KEY_ID, + VOLUME_KMS_KEY_ID, + TAGS, +) from sagemaker.session import ( _tuning_job_status, _transform_job_status, @@ -852,7 +862,7 @@ def test_training_input_all_arguments(): MAX_SIZE = 30 MAX_TIME = 3 * 60 * 60 JOB_NAME = "jobname" -TAGS = [{"Name": "some-tag", "Value": "value-for-tag"}] +EXAMPLE_TAGS = [{"Key": "some-tag", "Value": "value-for-tag"}] VPC_CONFIG = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} METRIC_DEFINITONS = [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}] EXPERIMENT_CONFIG = { @@ -956,26 +966,6 @@ def sagemaker_session(): ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=MagicMock()) ims.expand_role = Mock(return_value=EXPANDED_ROLE) - # For the purposes of unit tests, no values should be fetched from sagemaker config - ims.resolve_nested_dict_value_from_config = Mock( - name="resolve_nested_dict_value_from_config", - side_effect=lambda 
dictionary, nested_keys, config_path, default_value=None: dictionary, - ) - ims.resolve_class_attribute_from_config = Mock( - name="resolve_class_attribute_from_config", - side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, - ) - return ims - - -@pytest.fixture() -def sagemaker_session_without_mocked_sagemaker_config(): - boto_mock = MagicMock(name="boto_session") - boto_mock.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = { - "Account": "123" - } - ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=MagicMock()) - ims.expand_role = Mock(return_value=EXPANDED_ROLE) return ims @@ -1629,7 +1619,7 @@ def test_train_with_configs(sagemaker_session): resource_config=resource_config, hyperparameters=hyperparameters, stop_condition=stop_cond, - tags=TAGS, + tags=EXAMPLE_TAGS, metric_definitions=METRIC_DEFINITONS, encrypt_inter_container_traffic=True, use_spot_instances=True, @@ -1648,7 +1638,7 @@ def test_train_with_configs(sagemaker_session): "SecurityGroupIds": ["sg-123"], } assert actual_train_args["HyperParameters"] == hyperparameters - assert actual_train_args["Tags"] == TAGS + assert actual_train_args["Tags"] == EXAMPLE_TAGS assert actual_train_args["AlgorithmSpecification"]["MetricDefinitions"] == METRIC_DEFINITONS assert actual_train_args["AlgorithmSpecification"]["EnableSageMakerMetricsTimeSeries"] is True assert actual_train_args["EnableInterContainerTrafficEncryption"] is True @@ -1717,7 +1707,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): vpc_config=VPC_CONFIG, hyperparameters=hyperparameters, stop_condition=stop_cond, - tags=TAGS, + tags=EXAMPLE_TAGS, metric_definitions=METRIC_DEFINITONS, encrypt_inter_container_traffic=True, use_spot_instances=True, @@ -1733,7 +1723,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): assert actual_train_args["VpcConfig"] == VPC_CONFIG assert actual_train_args["HyperParameters"] == 
hyperparameters - assert actual_train_args["Tags"] == TAGS + assert actual_train_args["Tags"] == EXAMPLE_TAGS assert actual_train_args["AlgorithmSpecification"]["MetricDefinitions"] == METRIC_DEFINITONS assert actual_train_args["AlgorithmSpecification"]["EnableSageMakerMetricsTimeSeries"] is True assert actual_train_args["EnableInterContainerTrafficEncryption"] is True @@ -1747,57 +1737,56 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): ) -def _sagemaker_config_override_mock_for_transform(key, default_value=None): - from sagemaker.session import ( - TRANSFORM_JOB_KMS_KEY_ID_PATH, - TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, - TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, - ) - - if key is TRANSFORM_JOB_KMS_KEY_ID_PATH: - return "jobKmsKeyId" - elif key is TRANSFORM_OUTPUT_KMS_KEY_ID_PATH: - return "outputKmsKeyId" - elif key is TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH: - return "volumeKmsKeyId" - return default_value - +def test_create_transform_job_with_sagemaker_config_injection(sagemaker_config_session): -def test_create_transform_job_with_configs(sagemaker_session): - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=_sagemaker_config_override_mock_for_transform, - ) + # Config to test injection for + sagemaker_config_session.sagemaker_config.config = { + SAGEMAKER: { + TRANSFORM_JOB: { + DATA_CAPTURE_CONFIG: {KMS_KEY_ID: "jobKmsKeyId"}, + TRANSFORM_OUTPUT: {KMS_KEY_ID: "outputKmsKeyId"}, + TRANSFORM_RESOURCES: {VOLUME_KMS_KEY_ID: "volumeKmsKeyId"}, + TAGS: EXAMPLE_TAGS, + } + } + } model_name = "my-model" - in_config = { "CompressionType": "None", "ContentType": "text/csv", "SplitType": "None", "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": S3_INPUT_URI}}, } - out_config = {"S3OutputPath": S3_OUTPUT} - resource_config = {"InstanceCount": INSTANCE_COUNT, "InstanceType": INSTANCE_TYPE} - data_processing = {"OutputFilter": "$", "InputFilter": "$", "JoinSource": "Input"} - 
data_capture_config = BatchDataCaptureConfig(destination_s3_uri="s3://test") - expected_args = { - "TransformJobName": JOB_NAME, - "ModelName": model_name, - "TransformInput": in_config, - "TransformOutput": out_config, - "TransformResources": resource_config, - "DataProcessing": data_processing, - "DataCaptureConfig": data_capture_config._to_request_dict(), - } + + # important to deepcopy, otherwise the original dicts are modified when we add expected params + expected_args = copy.deepcopy( + { + "TransformJobName": JOB_NAME, + "ModelName": model_name, + "TransformInput": in_config, + "TransformOutput": out_config, + "TransformResources": resource_config, + "DataProcessing": data_processing, + "DataCaptureConfig": data_capture_config._to_request_dict(), + "Tags": EXAMPLE_TAGS, + } + ) expected_args["DataCaptureConfig"]["KmsKeyId"] = "jobKmsKeyId" expected_args["TransformOutput"]["KmsKeyId"] = "outputKmsKeyId" expected_args["TransformResources"]["VolumeKmsKeyId"] = "volumeKmsKeyId" - sagemaker_session.transform( + + # make sure the original dicts were not modified before config injection + assert "KmsKeyId" not in in_config + assert "KmsKeyId" not in out_config + assert "VolumeKmsKeyId" not in resource_config + + # injection should happen during this method + sagemaker_config_session.transform( job_name=JOB_NAME, model_name=model_name, strategy=None, @@ -1809,12 +1798,12 @@ def test_create_transform_job_with_configs(sagemaker_session): resource_config=resource_config, experiment_config=None, model_client_config=None, - tags=None, + tags=EXAMPLE_TAGS, data_processing=data_processing, batch_data_capture_config=data_capture_config, ) - _, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0] + _, _, actual_args = sagemaker_config_session.sagemaker_client.method_calls[0] assert actual_args == expected_args @@ -1888,7 +1877,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session): resource_config={}, 
experiment_config=EXPERIMENT_CONFIG, model_client_config=MODEL_CLIENT_CONFIG, - tags=TAGS, + tags=EXAMPLE_TAGS, data_processing=None, batch_data_capture_config=batch_data_capture_config, ) @@ -1898,7 +1887,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session): assert actual_args["MaxConcurrentTransforms"] == max_concurrent_transforms assert actual_args["MaxPayloadInMB"] == max_payload assert actual_args["Environment"] == env - assert actual_args["Tags"] == TAGS + assert actual_args["Tags"] == EXAMPLE_TAGS assert actual_args["ExperimentConfig"] == EXPERIMENT_CONFIG assert actual_args["ModelClientConfig"] == MODEL_CLIENT_CONFIG assert actual_args["DataCaptureConfig"] == batch_data_capture_config._to_request_dict() @@ -2345,7 +2334,7 @@ def test_create_model_from_job(sagemaker_session): def test_create_model_from_job_with_tags(sagemaker_session): ims = sagemaker_session ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT - ims.create_model_from_job(JOB_NAME, tags=TAGS) + ims.create_model_from_job(JOB_NAME, tags=EXAMPLE_TAGS) assert ( call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list @@ -2355,7 +2344,7 @@ def test_create_model_from_job_with_tags(sagemaker_session): ModelName=JOB_NAME, PrimaryContainer=PRIMARY_CONTAINER, VpcConfig=VPC_CONFIG, - Tags=TAGS, + Tags=EXAMPLE_TAGS, ) @@ -4573,9 +4562,9 @@ def sort(tags): ) -def test_resolve_value_from_config(sagemaker_session_without_mocked_sagemaker_config): +def test_resolve_value_from_config(sagemaker_session): # using a shorter name for inside the test - ss = sagemaker_session_without_mocked_sagemaker_config + ss = sagemaker_session # direct_input should be respected ss.get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") @@ -4641,10 +4630,10 @@ def test_resolve_value_from_config(sagemaker_session_without_mocked_sagemaker_co ], ) def test_resolve_class_attribute_from_config( - 
sagemaker_session_without_mocked_sagemaker_config, existing_value, config_value, default_value + sagemaker_session, existing_value, config_value, default_value ): # using a shorter name for inside the test - ss = sagemaker_session_without_mocked_sagemaker_config + ss = sagemaker_session class TestClass(object): def __init__(self, test_attribute=None, extra=None): @@ -4738,9 +4727,9 @@ def __eq__(self, other): ) == TestClass(test_attribute=default_value, extra=None) -def test_resolve_nested_dict_value_from_config(sagemaker_session_without_mocked_sagemaker_config): +def test_resolve_nested_dict_value_from_config(sagemaker_session): # using a shorter name for inside the test - ss = sagemaker_session_without_mocked_sagemaker_config + ss = sagemaker_session dummy_config_path = ["DUMMY", "CONFIG", "PATH"] diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index 133d358b32..cf7d519823 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -15,6 +15,16 @@ import pytest from mock import MagicMock, Mock, patch, PropertyMock +import sagemaker +from sagemaker.config import ( + SAGEMAKER, + TRANSFORM_JOB, + DATA_CAPTURE_CONFIG, + TRANSFORM_OUTPUT, + TRANSFORM_RESOURCES, + VOLUME_KMS_KEY_ID, + TAGS, +) from sagemaker.transformer import _TransformJob, Transformer from sagemaker.workflow.pipeline_context import PipelineSession, _PipelineConfig from sagemaker.inputs import BatchDataCaptureConfig @@ -69,6 +79,12 @@ def mock_create_tar_file(): def sagemaker_session(): boto_mock = Mock(name="boto_session") session = Mock(name="sagemaker_session", boto_session=boto_mock, local_mode=False) + + # For the purposes of unit tests, no values should be fetched from sagemaker config + session.resolve_class_attribute_from_config = Mock( + name="resolve_class_attribute_from_config", + side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, + ) session.get_sagemaker_config_override = Mock( 
name="get_sagemaker_config_override", side_effect=lambda key, default_value=None: default_value, @@ -118,16 +134,24 @@ def _config_override_mock(key, default_value=None): @patch("sagemaker.transformer._TransformJob.start_new") -def test_transform_with_config_injection(start_new_job, sagemaker_session): - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", side_effect=_config_override_mock - ) +def test_transform_with_sagemaker_config_injection(start_new_job, sagemaker_config_session): + sagemaker_config_session.sagemaker_config.config = { + SAGEMAKER: { + TRANSFORM_JOB: { + DATA_CAPTURE_CONFIG: {sagemaker.config.KMS_KEY_ID: "DataCaptureConfigKmsKeyId"}, + TRANSFORM_OUTPUT: {sagemaker.config.KMS_KEY_ID: "ConfigKmsKeyId"}, + TRANSFORM_RESOURCES: {VOLUME_KMS_KEY_ID: "ConfigVolumeKmsKeyId"}, + TAGS: [], + } + } + } + transformer = Transformer( MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH, - sagemaker_session=sagemaker_session, + sagemaker_session=sagemaker_config_session, ) assert transformer.volume_kms_key == "ConfigVolumeKmsKeyId" assert transformer.output_kms_key == "ConfigKmsKeyId" From 018d6826cfec1b8c0e22a0c6e257523a410f3314 Mon Sep 17 00:00:00 2001 From: Ruban Hussain Date: Fri, 3 Mar 2023 22:11:01 -0800 Subject: [PATCH 13/40] change: new integ test for sagemaker_config --- tests/integ/test_sagemaker_config.py | 85 ++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 tests/integ/test_sagemaker_config.py diff --git a/tests/integ/test_sagemaker_config.py b/tests/integ/test_sagemaker_config.py new file mode 100644 index 0000000000..28e2229ca3 --- /dev/null +++ b/tests/integ/test_sagemaker_config.py @@ -0,0 +1,85 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os + +import pytest +import yaml + +from sagemaker.config import SageMakerConfig +from sagemaker.s3 import S3Uploader +from tests.integ.kms_utils import get_or_create_kms_key + + +@pytest.fixture() +def get_data_dir(): + return os.path.join(os.path.dirname(__file__), "..", "data", "config") + + +@pytest.fixture(scope="module") +def s3_files_kms_key(sagemaker_session): + return get_or_create_kms_key(sagemaker_session=sagemaker_session) + + +@pytest.fixture() +def expected_merged_config(get_data_dir): + expected_merged_config_file_path = os.path.join( + get_data_dir, "expected_output_config_after_merge.yaml" + ) + return yaml.safe_load(open(expected_merged_config_file_path, "r").read()) + + +def test_config_download_from_s3_and_merge( + sagemaker_session, + s3_files_kms_key, + get_data_dir, + expected_merged_config, +): + + # Note: not using unique_name_from_base() here because the config contents are expected to + # change very rarely (if ever), so rather than writing new files and deleting them every time + # we can just use the same S3 paths + s3_uri_prefix = os.path.join( + "s3://", + sagemaker_session.default_bucket(), + "integ-test-sagemaker_config", + ) + + config_file_1_local_path = os.path.join(get_data_dir, "sample_config_for_merge.yaml") + config_file_2_local_path = os.path.join(get_data_dir, "sample_additional_config_for_merge.yaml") + + config_file_1_as_yaml = open(config_file_1_local_path, "r").read() + config_file_2_as_yaml = open(config_file_2_local_path, "r").read() + + s3_uri_config_1 = os.path.join(s3_uri_prefix, "config_1.yaml") + 
s3_uri_config_2 = os.path.join(s3_uri_prefix, "config_2.yaml") + + # Upload S3 files in case they don't already exist + S3Uploader.upload_string_as_file_body( + body=config_file_1_as_yaml, + desired_s3_uri=s3_uri_config_1, + kms_key=s3_files_kms_key, + sagemaker_session=sagemaker_session, + ) + S3Uploader.upload_string_as_file_body( + body=config_file_2_as_yaml, + desired_s3_uri=s3_uri_config_2, + kms_key=s3_files_kms_key, + sagemaker_session=sagemaker_session, + ) + + # The thing being tested. + sagemaker_config = SageMakerConfig(additional_config_paths=[s3_uri_config_1, s3_uri_config_2]) + + assert sagemaker_config.config == expected_merged_config From 514618b4987755445ac70bea122480d74e9750f2 Mon Sep 17 00:00:00 2001 From: Ruban Hussain Date: Wed, 8 Mar 2023 15:28:41 -0800 Subject: [PATCH 14/40] fix: Sagemaker Config - fleshed out unit tests and fixed bugs --- src/sagemaker/config/__init__.py | 1 - src/sagemaker/config/config_schema.py | 6 - src/sagemaker/estimator.py | 28 +- src/sagemaker/pipeline.py | 24 +- src/sagemaker/session.py | 101 +-- tests/data/config/config.yaml | 3 - tests/unit/__init__.py | 249 ++++++++ tests/unit/conftest.py | 11 +- tests/unit/sagemaker/automl/test_auto_ml.py | 99 +-- tests/unit/sagemaker/config/conftest.py | 1 - .../feature_store/test_feature_group.py | 32 +- .../sagemaker/huggingface/test_processing.py | 6 + .../monitor/test_model_monitoring.py | 42 ++ tests/unit/test_amazon_estimator.py | 6 + tests/unit/test_estimator.py | 69 ++- tests/unit/test_pipeline_model.py | 57 +- tests/unit/test_processing.py | 70 +-- tests/unit/test_session.py | 576 +++++++----------- tests/unit/test_transformer.py | 44 +- 19 files changed, 760 insertions(+), 665 deletions(-) diff --git a/src/sagemaker/config/__init__.py b/src/sagemaker/config/__init__.py index 94f04b0f90..f2938aaf8c 100644 --- a/src/sagemaker/config/__init__.py +++ b/src/sagemaker/config/__init__.py @@ -52,7 +52,6 @@ MONITORING_RESOURCES, PROCESSING_RESOURCES, PRODUCTION_VARIANTS, - 
SHADOW_PRODUCTION_VARIANTS, TRANSFORM_OUTPUT, TRANSFORM_RESOURCES, VALIDATION_ROLE, diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index 730df320c8..7942f05b52 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -53,7 +53,6 @@ MONITORING_RESOURCES = "MonitoringResources" PROCESSING_RESOURCES = "ProcessingResources" PRODUCTION_VARIANTS = "ProductionVariants" -SHADOW_PRODUCTION_VARIANTS = "ShadowProductionVariants" TRANSFORM_OUTPUT = "TransformOutput" TRANSFORM_RESOURCES = "TransformResources" VALIDATION_ROLE = "ValidationRole" @@ -344,10 +343,6 @@ TYPE: "array", "items": {"$ref": "#/definitions/productionVariant"}, }, - SHADOW_PRODUCTION_VARIANTS: { - TYPE: "array", - "items": {"$ref": "#/definitions/productionVariant"}, - }, TAGS: {"$ref": "#/definitions/tags"}, }, }, @@ -465,7 +460,6 @@ VALIDATION_ROLE: {"$ref": "#/definitions/roleArn"}, }, }, - TAGS: {"$ref": "#/definitions/tags"}, }, }, # Processing Job diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 2289d4f77a..8d23fe036c 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -536,13 +536,6 @@ def __init__( "train_volume_kms_key", "volume_kms_key", volume_kms_key, kwargs ) - validate_source_code_input_against_pipeline_variables( - entry_point=entry_point, - source_dir=source_dir, - git_config=git_config, - enable_network_isolation=enable_network_isolation, - ) - self.instance_count = instance_count self.instance_type = instance_type self.keep_alive_period_in_seconds = keep_alive_period_in_seconds @@ -642,9 +635,11 @@ def __init__( self.collection_configs = None self.enable_sagemaker_metrics = enable_sagemaker_metrics - self._enable_network_isolation = self.sagemaker_session.get_sagemaker_config_override( - TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, - default_value=False if enable_network_isolation is None else enable_network_isolation, + + self._enable_network_isolation = 
self.sagemaker_session.resolve_value_from_config( + direct_input=enable_network_isolation, + config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + default_value=False, ) self.profiler_config = profiler_config @@ -663,6 +658,13 @@ def __init__( self.profiler_rules = None self.debugger_rules = None + validate_source_code_input_against_pipeline_variables( + entry_point=entry_point, + source_dir=source_dir, + git_config=git_config, + enable_network_isolation=self._enable_network_isolation, + ) + @abstractmethod def training_image_uri(self): """Return the Docker image to use for training. @@ -2382,12 +2384,12 @@ def __init__( model_uri: Optional[str] = None, model_channel_name: Union[str, PipelineVariable] = "model", metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, - encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False, + encrypt_inter_container_traffic: Union[bool, PipelineVariable] = None, use_spot_instances: Union[bool, PipelineVariable] = False, max_wait: Optional[Union[int, PipelineVariable]] = None, checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, - enable_network_isolation: Union[bool, PipelineVariable] = False, + enable_network_isolation: Union[bool, PipelineVariable] = None, rules: Optional[List[RuleBase]] = None, debugger_hook_config: Optional[Union[DebuggerHookConfig, bool]] = None, tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, @@ -2925,7 +2927,7 @@ def __init__( code_location: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, dependencies: Optional[List[str]] = None, - enable_network_isolation: Union[bool, PipelineVariable] = False, + enable_network_isolation: Union[bool, PipelineVariable] = None, git_config: Optional[Dict[str, str]] = None, checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, checkpoint_local_path: Optional[Union[str, 
PipelineVariable]] = None, diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 1e3b10327e..ccb87aa966 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -95,26 +95,28 @@ def __init__( self.sagemaker_session.get_sagemaker_config_override( MODEL_EXECUTION_ROLE_ARN_PATH, default_value=role ) - if sagemaker_session + if self.sagemaker_session else role ) self.vpc_config = ( self.sagemaker_session.get_sagemaker_config_override( MODEL_VPC_CONFIG_PATH, default_value=vpc_config ) - if sagemaker_session + if self.sagemaker_session else vpc_config ) - default_enable_network_isolation = ( - False if enable_network_isolation is None else enable_network_isolation - ) - self.enable_network_isolation = ( - self.sagemaker_session.get_sagemaker_config_override( - MODEL_ENABLE_NETWORK_ISOLATION_PATH, default_value=default_enable_network_isolation + + if self.sagemaker_session is not None: + self.enable_network_isolation = self.sagemaker_session.resolve_value_from_config( + direct_input=enable_network_isolation, + config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH, + default_value=False, ) - if sagemaker_session - else default_enable_network_isolation - ) + else: + self.enable_network_isolation = ( + False if enable_network_isolation is None else enable_network_isolation + ) + if not self.role: # Originally IAM role was a required parameter. 
# Now we marked that as Optional because we can fetch it from SageMakerConfig diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index e2dcc09d4f..03e1f88825 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -13,6 +13,7 @@ """Placeholder docstring""" from __future__ import absolute_import, print_function +import copy import inspect import json import logging @@ -275,7 +276,11 @@ def _simple_path(*args: str): ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, ) PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION = _simple_path( - SAGEMAKER, AUTO_ML, SECURITY_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION + SAGEMAKER, + AUTO_ML, + AUTO_ML_JOB_CONFIG, + SECURITY_CONFIG, + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, ) PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION = _simple_path( SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION @@ -651,7 +656,9 @@ def get_sagemaker_config_value(self, key): object: The corresponding value in the Config file/ the default value. """ config_value = get_config_value(key, self.sagemaker_config.config) - return config_value + + # Copy the value so any modifications to the output will not modify the source config + return copy.deepcopy(config_value) def get_sagemaker_config_override(self, key, default_value=None): """Util method that fetches a particular key path in the SageMakerConfig and returns it. @@ -667,7 +674,7 @@ def get_sagemaker_config_override(self, key, default_value=None): object: The corresponding value in the Config file/ the default value. 
""" - config_value = get_config_value(key, self.sagemaker_config.config) + config_value = self.get_sagemaker_config_value(key) self._print_message_on_sagemaker_config_usage(default_value, config_value, key) if default_value is not None: @@ -1471,6 +1478,7 @@ def _get_update_training_job_request( return update_training_job_request + # TODO: unit tests or make a more generic version def update_processing_input_from_config(self, inputs): """Updates Processor Inputs to fetch values from SageMakerConfig wherever applicable. @@ -1478,41 +1486,49 @@ def update_processing_input_from_config(self, inputs): inputs (list[dict]): A list of Processing Input objects. """ + inputs_copy = copy.deepcopy(inputs) processing_inputs_from_config = ( self.get_sagemaker_config_override(PROCESSING_JOB_INPUTS_PATH) or [] ) for i in range(min(len(inputs), len(processing_inputs_from_config))): - processing_input_from_config = processing_inputs_from_config[i] - if "DatasetDefinition" in inputs[i]: - dataset_definition = inputs[i]["DatasetDefinition"] - if "AthenaDatasetDefinition" in dataset_definition: - athena_dataset_definition = dataset_definition["AthenaDatasetDefinition"] - if KMS_KEY_ID not in athena_dataset_definition: - athena_kms_key_id_from_config = get_config_value( - ATHENA_DATASET_DEFINITION_KMS_KEY_ID_PATH, processing_input_from_config - ) - if athena_kms_key_id_from_config: - athena_dataset_definition[KMS_KEY_ID] = athena_kms_key_id_from_config - if "RedshiftDatasetDefinition" in dataset_definition: - redshift_dataset_definition = dataset_definition["RedshiftDatasetDefinition"] - if CLUSTER_ROLE_ARN not in redshift_dataset_definition: - redshift_role_arn_from_config = get_config_value( - REDSHIFT_DATASET_DEFINITION_CLUSTER_ROLE_ARN_PATH, - processing_input_from_config, - ) - if redshift_role_arn_from_config: - redshift_dataset_definition[ - CLUSTER_ROLE_ARN - ] = redshift_role_arn_from_config - if not redshift_dataset_definition.kms_key_id: - redshift_kms_key_id_from_config = 
get_config_value( - REDSHIFT_DATASET_DEFINITION_KMS_KEY_ID_PATH, - processing_input_from_config, - ) - if redshift_kms_key_id_from_config: - redshift_dataset_definition[ - KMS_KEY_ID - ] = redshift_kms_key_id_from_config + dict_from_inputs = inputs[i] + dict_from_config = processing_inputs_from_config[i] + + # The Dataset Definition input must specify exactly one of either + # AthenaDatasetDefinition or RedshiftDatasetDefinition types (source: API reference). + # So to prevent API failure because of sagemaker_config, we will only populate from the + # config for the ones already present in dict_from_inputs. + # If BOTH are present, we will still add to both and let the API call fail as it would + # have even without injection from sagemaker_config. + athena_path = [DATASET_DEFINITION, ATHENA_DATASET_DEFINITION] + redshift_path = [DATASET_DEFINITION, REDSHIFT_DATASET_DEFINITION] + + athena_value_from_inputs = get_nested_value(dict_from_inputs, athena_path) + athena_value_from_config = get_nested_value(dict_from_config, athena_path) + + redshift_value_from_inputs = get_nested_value(dict_from_inputs, redshift_path) + redshift_value_from_config = get_nested_value(dict_from_config, redshift_path) + + if athena_value_from_inputs is not None: + merge_dicts(athena_value_from_config, athena_value_from_inputs) + inputs[i] = set_nested_value( + dict_from_inputs, athena_path, athena_value_from_config + ) + + if redshift_value_from_inputs is not None: + merge_dicts(redshift_value_from_config, redshift_value_from_inputs) + inputs[i] = set_nested_value( + dict_from_inputs, redshift_path, redshift_value_from_config + ) + + if processing_inputs_from_config != []: + print( + "[Sagemaker Config - applied value]\n", + "config key = {}\n".format(PROCESSING_JOB_INPUTS_PATH), + "config value = {}\n".format(processing_inputs_from_config), + "source value = {}\n".format(inputs_copy), + "combined value that will be used = {}\n".format(inputs), + ) def process( self, @@ -4178,6 +4194,7 @@ 
def create_endpoint_config_from_existing( "EndpointConfigName": new_config_name, } + # TODO: should this merge from config even if new_production_variants is None? if new_production_variants: inferred_production_variants_from_config = ( self.get_sagemaker_config_override(ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH) or [] @@ -4234,10 +4251,10 @@ def create_endpoint_config_from_existing( ) request["DataCaptureConfig"] = inferred_data_capture_config_dict - if existing_endpoint_config_desc.get("AsyncInferenceConfig") is not None: - async_inference_config_dict = existing_endpoint_config_desc.get( - "AsyncInferenceConfig", None - ) + async_inference_config_dict = existing_endpoint_config_desc.get( + "AsyncInferenceConfig", None + ) + if async_inference_config_dict is not None: inferred_async_inference_config_dict = ( self._update_nested_dictionary_with_values_from_config( async_inference_config_dict, ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH @@ -5612,7 +5629,7 @@ def _update_nested_dictionary_with_values_from_config( in the Config file. """ inferred_config_dict = self.get_sagemaker_config_value(config_key_path) or {} - original_config_dict_value = inferred_config_dict.copy() + original_config_dict_value = copy.deepcopy(inferred_config_dict) merge_dicts(inferred_config_dict, source_dict or {}) if original_config_dict_value == {}: @@ -5678,7 +5695,11 @@ def account_id(self) -> str: return sts_client.get_caller_identity()["Account"] def _intercept_create_request( - self, request: typing.Dict, create, func_name: str = None # pylint: disable=unused-argument + self, + request: typing.Dict, + create, + func_name: str = None + # pylint: disable=unused-argument ): """This function intercepts the create job request. 
diff --git a/tests/data/config/config.yaml b/tests/data/config/config.yaml index 726f73f07a..8e0e40019a 100644 --- a/tests/data/config/config.yaml +++ b/tests/data/config/config.yaml @@ -34,9 +34,6 @@ SageMaker: ProductionVariants: - CoreDumpConfig: KmsKeyId: 'somekmskey4' - ShadowProductionVariants: - - CoreDumpConfig: - KmsKeyId: 'somekmskey5' AutoML: AutoMLJobConfig: SecurityConfig: diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index a7f9c55fd5..e818f4743c 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -14,5 +14,254 @@ import os +from sagemaker.config import ( + SAGEMAKER, + MONITORING_SCHEDULE, + MONITORING_SCHEDULE_CONFIG, + MONITORING_JOB_DEFINITION, + MONITORING_OUTPUT_CONFIG, + KMS_KEY_ID, + MONITORING_RESOURCES, + CLUSTER_CONFIG, + VOLUME_KMS_KEY_ID, + NETWORK_CONFIG, + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, + ENABLE_NETWORK_ISOLATION, + VPC_CONFIG, + SUBNETS, + SECURITY_GROUP_IDS, + ROLE_ARN, + TAGS, + KEY, + VALUE, + COMPILATION_JOB, + OUTPUT_CONFIG, + EDGE_PACKAGING_JOB, + ENDPOINT_CONFIG, + DATA_CAPTURE_CONFIG, + PRODUCTION_VARIANTS, + AUTO_ML, + AUTO_ML_JOB_CONFIG, + SECURITY_CONFIG, + OUTPUT_DATA_CONFIG, + MODEL_PACKAGE, + VALIDATION_SPECIFICATION, + VALIDATION_PROFILES, + TRANSFORM_JOB_DEFINITION, + TRANSFORM_OUTPUT, + TRANSFORM_RESOURCES, + VALIDATION_ROLE, + FEATURE_GROUP, + OFFLINE_STORE_CONFIG, + S3_STORAGE_CONFIG, + ONLINE_STORE_CONFIG, + PROCESSING_JOB, + PROCESSING_INPUTS, + DATASET_DEFINITION, + ATHENA_DATASET_DEFINITION, + REDSHIFT_DATASET_DEFINITION, + CLUSTER_ROLE_ARN, + PROCESSING_OUTPUT_CONFIG, + PROCESSING_RESOURCES, + TRAINING_JOB, + RESOURCE_CONFIG, + TRANSFORM_JOB, + EXECUTION_ROLE_ARN, + MODEL, + ASYNC_INFERENCE_CONFIG, +) + DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") PY_VERSION = "py3" + + +SAGEMAKER_CONFIG_MONITORING_SCHEDULE = { + SAGEMAKER: { + MONITORING_SCHEDULE: { + MONITORING_SCHEDULE_CONFIG: { + MONITORING_JOB_DEFINITION: { + MONITORING_OUTPUT_CONFIG: 
{KMS_KEY_ID: "configKmsKeyId"}, + MONITORING_RESOURCES: { + CLUSTER_CONFIG: {VOLUME_KMS_KEY_ID: "configVolumeKmsKeyId"}, + }, + NETWORK_CONFIG: { + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: False, + ENABLE_NETWORK_ISOLATION: True, + VPC_CONFIG: {SUBNETS: ["subnets-123"], SECURITY_GROUP_IDS: ["sg-123"]}, + }, + ROLE_ARN: "arn:aws:iam::111111111111:role/ConfigRole", + } + }, + TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], + } + } +} + +SAGEMAKER_CONFIG_COMPILATION_JOB = { + SAGEMAKER: { + COMPILATION_JOB: { + OUTPUT_CONFIG: {KMS_KEY_ID: "TestKms"}, + ROLE_ARN: "arn:aws:iam::111111111111:role/ConfigRole", + VPC_CONFIG: {SUBNETS: ["subnets-123"], SECURITY_GROUP_IDS: ["sg-123"]}, + TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], + }, + } +} + +SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB = { + SAGEMAKER: { + EDGE_PACKAGING_JOB: { + OUTPUT_CONFIG: { + KMS_KEY_ID: "configKmsKeyId", + }, + ROLE_ARN: "arn:aws:iam::111111111111:role/ConfigRole", + TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], + }, + }, +} + +SAGEMAKER_CONFIG_ENDPOINT_CONFIG = { + SAGEMAKER: { + ENDPOINT_CONFIG: { + ASYNC_INFERENCE_CONFIG: { + OUTPUT_CONFIG: { + KMS_KEY_ID: "testOutputKmsKeyId", + } + }, + DATA_CAPTURE_CONFIG: { + KMS_KEY_ID: "testDataCaptureKmsKeyId", + }, + KMS_KEY_ID: "ConfigKmsKeyId", + PRODUCTION_VARIANTS: [ + {"CoreDumpConfig": {"KmsKeyId": "testCoreKmsKeyId"}}, + {"CoreDumpConfig": {"KmsKeyId": "testCoreKmsKeyId2"}}, + ], + TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], + }, + } +} + +SAGEMAKER_CONFIG_AUTO_ML = { + SAGEMAKER: { + AUTO_ML: { + AUTO_ML_JOB_CONFIG: { + SECURITY_CONFIG: { + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: True, + VOLUME_KMS_KEY_ID: "TestKmsKeyId", + VPC_CONFIG: {SUBNETS: ["subnets-123"], SECURITY_GROUP_IDS: ["sg-123"]}, + }, + }, + OUTPUT_DATA_CONFIG: {KMS_KEY_ID: "configKmsKeyId"}, + ROLE_ARN: "arn:aws:iam::111111111111:role/ConfigRole", + TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], + }, + } +} + +SAGEMAKER_CONFIG_MODEL_PACKAGE = { + SAGEMAKER: 
{ + MODEL_PACKAGE: { + VALIDATION_SPECIFICATION: { + VALIDATION_PROFILES: [ + { + TRANSFORM_JOB_DEFINITION: { + TRANSFORM_OUTPUT: {KMS_KEY_ID: "testKmsKeyId"}, + TRANSFORM_RESOURCES: {VOLUME_KMS_KEY_ID: "testVolumeKmsKeyId"}, + } + } + ], + VALIDATION_ROLE: "arn:aws:iam::111111111111:role/ConfigRole", + }, + # TODO - does SDK not support tags for this API? + # TAGS: EXAMPLE_TAGS, + }, + } +} + +SAGEMAKER_CONFIG_FEATURE_GROUP = { + SAGEMAKER: { + FEATURE_GROUP: { + OFFLINE_STORE_CONFIG: { + S3_STORAGE_CONFIG: { + KMS_KEY_ID: "OfflineConfigKmsKeyId", + } + }, + ONLINE_STORE_CONFIG: { + SECURITY_CONFIG: { + KMS_KEY_ID: "OnlineConfigKmsKeyId", + } + }, + ROLE_ARN: "arn:aws:iam::111111111111:role/ConfigRole", + TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], + }, + } +} + +SAGEMAKER_CONFIG_PROCESSING_JOB = { + SAGEMAKER: { + PROCESSING_JOB: { + NETWORK_CONFIG: { + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: False, + ENABLE_NETWORK_ISOLATION: True, + VPC_CONFIG: {SUBNETS: ["subnets-123"], SECURITY_GROUP_IDS: ["sg-123"]}, + }, + PROCESSING_INPUTS: [ + { + DATASET_DEFINITION: { + ATHENA_DATASET_DEFINITION: { + KMS_KEY_ID: "AthenaKmsKeyId", + }, + REDSHIFT_DATASET_DEFINITION: { + KMS_KEY_ID: "RedshiftKmsKeyId", + CLUSTER_ROLE_ARN: "arn:aws:iam::111111111111:role/ClusterRole", + }, + }, + }, + ], + PROCESSING_OUTPUT_CONFIG: {KMS_KEY_ID: "testKmsKeyId"}, + PROCESSING_RESOURCES: { + CLUSTER_CONFIG: { + VOLUME_KMS_KEY_ID: "testVolumeKmsKeyId", + }, + }, + ROLE_ARN: "arn:aws:iam::111111111111:role/ConfigRole", + TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], + }, + } +} + +SAGEMAKER_CONFIG_TRAINING_JOB = { + SAGEMAKER: { + TRAINING_JOB: { + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: True, + ENABLE_NETWORK_ISOLATION: True, + OUTPUT_DATA_CONFIG: {KMS_KEY_ID: "TestKms"}, + RESOURCE_CONFIG: {VOLUME_KMS_KEY_ID: "volumekey"}, + ROLE_ARN: "arn:aws:iam::111111111111:role/ConfigRole", + VPC_CONFIG: {SUBNETS: ["subnets-123"], SECURITY_GROUP_IDS: ["sg-123"]}, + TAGS: [{KEY: 
"some-tag", VALUE: "value-for-tag"}], + }, + }, +} + +SAGEMAKER_CONFIG_TRANSFORM_JOB = { + SAGEMAKER: { + TRANSFORM_JOB: { + DATA_CAPTURE_CONFIG: {KMS_KEY_ID: "jobKmsKeyId"}, + TRANSFORM_OUTPUT: {KMS_KEY_ID: "outputKmsKeyId"}, + TRANSFORM_RESOURCES: {VOLUME_KMS_KEY_ID: "volumeKmsKeyId"}, + TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], + } + } +} + +SAGEMAKER_CONFIG_MODEL = { + SAGEMAKER: { + MODEL: { + ENABLE_NETWORK_ISOLATION: True, + EXECUTION_ROLE_ARN: "arn:aws:iam::111111111111:role/ConfigRole", + VPC_CONFIG: {SUBNETS: ["subnets-123"], SECURITY_GROUP_IDS: ["sg-123"]}, + TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], + }, + } +} diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 5d9d565582..75d7730523 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -12,7 +12,9 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import + import pytest +from jsonschema.validators import validate from mock.mock import MagicMock import sagemaker @@ -20,6 +22,7 @@ from mock import Mock, PropertyMock from sagemaker.config import SageMakerConfig +from sagemaker.config.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA, SCHEMA_VERSION _ROLE = "DummyRole" _REGION = "us-west-2" @@ -105,9 +108,15 @@ def config(self) -> dict: @config.setter def config(self, new_config): + """Validates and sets a new config.""" + # Add schema version if not already there since that is required + if SCHEMA_VERSION not in new_config: + new_config[SCHEMA_VERSION] = "1.0" + # Validate to make sure unit tests are not accidentally testing with a wrong config + validate(new_config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) self._config = new_config - boto_mock = MagicMock(name="boto_session") + boto_mock = MagicMock(name="boto_session", region_name="us-west-2") session_with_custom_sagemaker_config = sagemaker.Session( boto_session=boto_mock, sagemaker_client=MagicMock(), diff --git 
a/tests/unit/sagemaker/automl/test_auto_ml.py b/tests/unit/sagemaker/automl/test_auto_ml.py index 22c9f88249..6c85773709 100644 --- a/tests/unit/sagemaker/automl/test_auto_ml.py +++ b/tests/unit/sagemaker/automl/test_auto_ml.py @@ -17,21 +17,11 @@ import pytest from mock import Mock, patch from sagemaker import AutoML, AutoMLJob, AutoMLInput, CandidateEstimator, PipelineModel -from sagemaker.config import ( - RESOURCE_CONFIG, - AUTO_ML, - AUTO_ML_JOB_CONFIG, - SECURITY_CONFIG, - VOLUME_KMS_KEY_ID, - KMS_KEY_ID, - OUTPUT_DATA_CONFIG, - ROLE_ARN, - SAGEMAKER, -) from sagemaker.predictor import Predictor from sagemaker.session_settings import SessionSettings from sagemaker.workflow.functions import Join +from tests.unit import SAGEMAKER_CONFIG_AUTO_ML, SAGEMAKER_CONFIG_TRAINING_JOB MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -316,38 +306,18 @@ def test_auto_ml_without_role_parameter(sagemaker_session): ) -def _config_override_mock(key, default_value=None): - kms_key_id_path = "{}.{}.{}.{}".format(SAGEMAKER, AUTO_ML, OUTPUT_DATA_CONFIG, KMS_KEY_ID) - volume_kms_key_id_path = "{}.{}.{}.{}.{}".format( - SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VOLUME_KMS_KEY_ID - ) - role_arn_path = "{}.{}.{}".format(SAGEMAKER, AUTO_ML, ROLE_ARN) - vpc_config_path = "{}.{}.{}.{}.{}".format( - SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, "VpcConfig" - ) - if key == role_arn_path: - return "ConfigRoleArn" - elif key == kms_key_id_path: - return "ConfigKmsKeyId" - elif key == volume_kms_key_id_path: - return "ConfigVolumeKmsKeyId" - elif key == vpc_config_path: - return {"SecurityGroupIds": ["sg-config"], "Subnets": ["subnet-config"]} - return default_value +def test_framework_initialization_with_sagemaker_config_injection(sagemaker_config_session): + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_AUTO_ML - -def test_framework_initialization_with_defaults(sagemaker_session): - sagemaker_session.get_sagemaker_config_override 
= Mock( - name="get_sagemaker_config_override", side_effect=_config_override_mock - ) auto_ml = AutoML( target_attribute_name=TARGET_ATTRIBUTE_NAME, - sagemaker_session=sagemaker_session, + sagemaker_session=sagemaker_config_session, ) - assert auto_ml.role == "ConfigRoleArn" - assert auto_ml.output_kms_key == "ConfigKmsKeyId" - assert auto_ml.volume_kms_key == "ConfigVolumeKmsKeyId" - assert auto_ml.vpc_config == {"SecurityGroupIds": ["sg-config"], "Subnets": ["subnet-config"]} + assert auto_ml.role == "arn:aws:iam::111111111111:role/ConfigRole" + assert auto_ml.output_kms_key == "configKmsKeyId" + assert auto_ml.volume_kms_key == "TestKmsKeyId" + assert auto_ml.vpc_config == {"SecurityGroupIds": ["sg-123"], "Subnets": ["subnets-123"]} + assert auto_ml.encrypt_inter_container_traffic is True def test_auto_ml_default_channel_name(sagemaker_session): @@ -856,45 +826,38 @@ def test_deploy_optional_args(candidate_estimator, sagemaker_session, candidate_ ) -def _config_override_mock_for_candidate_estimator(key, default_value=None): - from sagemaker.config import TRAINING_JOB - - vpc_config_path = "{}.{}.{}".format(SAGEMAKER, TRAINING_JOB, "VpcConfig") - volume_kms_key_id_path = "{}.{}.{}.{}".format( - SAGEMAKER, TRAINING_JOB, RESOURCE_CONFIG, VOLUME_KMS_KEY_ID - ) - if key == volume_kms_key_id_path: - return "ConfigVolumeKmsKeyId" - elif key == vpc_config_path: - return {"SecurityGroupIds": ["sg-config"], "Subnets": ["subnet-config"]} - return default_value +def test_candidate_estimator_fit_initialization_with_sagemaker_config_injection( + sagemaker_config_session, +): + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_TRAINING_JOB + sagemaker_config_session.train = Mock() + sagemaker_config_session.transform = Mock() -def test_candidate_estimator_fit_initialization_with_defaults(sagemaker_session): - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - 
side_effect=_config_override_mock_for_candidate_estimator, - ) - desc_training_job_response = TRAINING_JOB + desc_training_job_response = copy.deepcopy(TRAINING_JOB) del desc_training_job_response["VpcConfig"] - sagemaker_session.sagemaker_client.describe_training_job = Mock( + del desc_training_job_response["OutputDataConfig"]["KmsKeyId"] + + sagemaker_config_session.sagemaker_client.describe_training_job = Mock( name="describe_training_job", return_value=desc_training_job_response ) - candidate_estimator = CandidateEstimator(CANDIDATE_DICT, sagemaker_session=sagemaker_session) + candidate_estimator = CandidateEstimator( + CANDIDATE_DICT, sagemaker_session=sagemaker_config_session + ) candidate_estimator._check_all_job_finished = Mock( name="_check_all_job_finished", return_value=True ) inputs = DEFAULT_S3_INPUT_DATA candidate_estimator.fit(inputs) - sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] - assert "train" in sagemaker_call_names - index_of_train = sagemaker_call_names.index("train") - actual_train_args = sagemaker_session.method_calls[index_of_train][2] - assert actual_train_args["vpc_config"] == { - "SecurityGroupIds": ["sg-config"], - "Subnets": ["subnet-config"], - } - assert actual_train_args["resource_config"]["VolumeKmsKeyId"] == "ConfigVolumeKmsKeyId" + + for train_call in sagemaker_config_session.train.call_args_list: + train_args = train_call.kwargs + assert train_args["vpc_config"] == { + "SecurityGroupIds": ["sg-123"], + "Subnets": ["subnets-123"], + } + assert train_args["resource_config"]["VolumeKmsKeyId"] == "volumekey" + assert train_args["encrypt_inter_container_traffic"] is True def test_candidate_estimator_get_steps(sagemaker_session): diff --git a/tests/unit/sagemaker/config/conftest.py b/tests/unit/sagemaker/config/conftest.py index ca12a1e14b..fc517016c2 100644 --- a/tests/unit/sagemaker/config/conftest.py +++ b/tests/unit/sagemaker/config/conftest.py @@ -146,7 +146,6 @@ def valid_endpointconfig_config(): 
"DataCaptureConfig": {"KmsKeyId": "somekmskey2"}, "KmsKeyId": "somekmskey3", "ProductionVariants": [{"CoreDumpConfig": {"KmsKeyId": "somekmskey4"}}], - "ShadowProductionVariants": [{"CoreDumpConfig": {"KmsKeyId": "somekmskey5"}}], } diff --git a/tests/unit/sagemaker/feature_store/test_feature_group.py b/tests/unit/sagemaker/feature_store/test_feature_group.py index 9ea083cd61..0803e96e08 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_group.py +++ b/tests/unit/sagemaker/feature_store/test_feature_group.py @@ -32,11 +32,8 @@ IngestionError, ) from sagemaker.feature_store.inputs import FeatureParameter -from sagemaker.session import ( - FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH, - FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH, - FEATURE_GROUP_ROLE_ARN_PATH, -) + +from tests.unit import SAGEMAKER_CONFIG_FEATURE_GROUP class PicklableMock(Mock): @@ -111,23 +108,14 @@ def test_feature_group_create_without_role( ) -def _config_override_mock(key, default_value=None): - if key == FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH: - return "OnlineConfigKmsKeyId" - elif key == FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH: - return "OfflineConfigKmsKeyId" - elif key == FEATURE_GROUP_ROLE_ARN_PATH: - return "ConfigRoleArn" - return default_value - - def test_feature_store_create_with_config_injection( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri + sagemaker_config_session, role_arn, feature_group_dummy_definitions, s3_uri ): - sagemaker_session_mock.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", side_effect=_config_override_mock - ) - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_FEATURE_GROUP + sagemaker_config_session.create_feature_group = Mock() + + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_config_session) feature_group.feature_definitions = 
feature_group_dummy_definitions feature_group.create( s3_uri=s3_uri, @@ -135,12 +123,12 @@ def test_feature_store_create_with_config_injection( event_time_feature_name="feature2", enable_online_store=True, ) - sagemaker_session_mock.create_feature_group.assert_called_with( + sagemaker_config_session.create_feature_group.assert_called_with( feature_group_name="MyFeatureGroup", record_identifier_name="feature1", event_time_feature_name="feature2", feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn="ConfigRoleArn", + role_arn="arn:aws:iam::111111111111:role/ConfigRole", description=None, tags=None, online_store_config={ diff --git a/tests/unit/sagemaker/huggingface/test_processing.py b/tests/unit/sagemaker/huggingface/test_processing.py index 6d420214ea..def116ac53 100644 --- a/tests/unit/sagemaker/huggingface/test_processing.py +++ b/tests/unit/sagemaker/huggingface/test_processing.py @@ -59,6 +59,12 @@ def sagemaker_session(): name="get_sagemaker_config_override", side_effect=lambda key, default_value=None: default_value, ) + session_mock.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) return session_mock diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index 383edb05c9..ceb2b5e997 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -40,6 +40,7 @@ from sagemaker.network import NetworkConfig from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat, DatasetFormat +from tests.unit import SAGEMAKER_CONFIG_MONITORING_SCHEDULE REGION = "us-west-2" BUCKET_NAME = "mybucket" @@ -889,6 +890,47 @@ def _test_data_quality_batch_transform_monitor_create_schedule( ) +def 
test_data_quality_batch_transform_monitor_create_schedule_with_sagemaker_config_injection( + data_quality_monitor, + sagemaker_config_session, +): + + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_MONITORING_SCHEDULE + + sagemaker_config_session.sagemaker_client.create_monitoring_schedule = Mock() + data_quality_monitor.sagemaker_session = sagemaker_config_session + + # for batch transform input + data_quality_monitor.create_monitoring_schedule( + batch_transform_input=BatchTransformInput( + data_captured_destination_s3_uri=DATA_CAPTURED_S3_URI, + destination=SCHEDULE_DESTINATION, + dataset_format=MonitoringDatasetFormat.csv(header=False), + ), + record_preprocessor_script=PREPROCESSOR_URI, + post_analytics_processor_script=POSTPROCESSOR_URI, + output_s3_uri=OUTPUT_S3_URI, + constraints=CONSTRAINTS, + statistics=STATISTICS, + monitor_schedule_name=SCHEDULE_NAME, + schedule_cron_expression=CRON_HOURLY, + ) + + sagemaker_config_session.sagemaker_client.create_monitoring_schedule.assert_called_with( + MonitoringScheduleName=SCHEDULE_NAME, + MonitoringScheduleConfig={ + "MonitoringJobDefinitionName": data_quality_monitor.job_definition_name, + "MonitoringType": "DataQuality", + "ScheduleConfig": {"ScheduleExpression": CRON_HOURLY}, + }, + # new tags appended from config + Tags=[ + {"Key": "tag_key_1", "Value": "tag_value_1"}, + {"Key": "some-tag", "Value": "value-for-tag"}, + ], + ) + + def _test_data_quality_monitor_update_schedule(data_quality_monitor, sagemaker_session): # update schedule sagemaker_session.describe_monitoring_schedule = MagicMock() diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index 54e38fda3c..15e7c2ecc7 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -78,6 +78,12 @@ def sagemaker_session(): name="get_sagemaker_config_override", side_effect=lambda key, default_value=None: default_value, ) + sms.resolve_value_from_config = Mock( + 
name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) return sms diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 95e52b18e3..3e4b301393 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -18,6 +18,7 @@ import os import subprocess from time import sleep + from sagemaker.fw_utils import UploadedCode @@ -57,6 +58,7 @@ from sagemaker.workflow.parameters import ParameterString, ParameterBoolean from sagemaker.workflow.pipeline_context import PipelineSession, _PipelineConfig from sagemaker.xgboost.estimator import XGBoost +from tests.unit import SAGEMAKER_CONFIG_TRAINING_JOB MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -374,50 +376,47 @@ def test_default_value_of_enable_network_isolation(sagemaker_session): assert framework.enable_network_isolation() is False -def _config_override_mock(key, default_value=None): - from sagemaker.session import ( - TRAINING_JOB_SUBNETS_PATH, - TRAINING_JOB_SECURITY_GROUP_IDS_PATH, - TRAINING_JOB_KMS_KEY_ID_PATH, - TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, - TRAINING_JOB_ROLE_ARN_PATH, - TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, - ) - - if key == TRAINING_JOB_ROLE_ARN_PATH: - return "ConfigRoleArn" - elif key == TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH: - return True - elif key == TRAINING_JOB_KMS_KEY_ID_PATH: - return "ConfigKmsKeyId" - elif key == TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH: - return "ConfigVolumeKmsKeyId" - elif key == TRAINING_JOB_SECURITY_GROUP_IDS_PATH: - return ["sg-config"] - elif key == TRAINING_JOB_SUBNETS_PATH: - return ["subnet-config"] - return default_value +def test_framework_initialization_with_sagemaker_config_injection(sagemaker_config_session): + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_TRAINING_JOB -def test_framework_initialization_with_defaults(sagemaker_session): - - 
sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", side_effect=_config_override_mock - ) framework = DummyFramework( entry_point=SCRIPT_PATH, - sagemaker_session=sagemaker_session, + sagemaker_session=sagemaker_config_session, instance_groups=[ InstanceGroup("group1", "ml.c4.xlarge", 1), InstanceGroup("group2", "ml.m4.xlarge", 2), ], ) - assert framework.role == "ConfigRoleArn" - assert framework.enable_network_isolation() - assert framework.output_kms_key == "ConfigKmsKeyId" - assert framework.volume_kms_key == "ConfigVolumeKmsKeyId" - assert framework.security_group_ids == ["sg-config"] - assert framework.subnets == ["subnet-config"] + assert framework.role == "arn:aws:iam::111111111111:role/ConfigRole" + assert framework.enable_network_isolation() is True + assert framework.encrypt_inter_container_traffic is True + assert framework.output_kms_key == "TestKms" + assert framework.volume_kms_key == "volumekey" + assert framework.security_group_ids == ["sg-123"] + assert framework.subnets == ["subnets-123"] + + +def test_estimator_initialization_with_sagemaker_config_injection(sagemaker_config_session): + + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_TRAINING_JOB + + estimator = Estimator( + image_uri="some-image", + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.p3.16xlarge", 2), + ], + sagemaker_session=sagemaker_config_session, + base_job_name="base_job_name", + ) + assert estimator.role == "arn:aws:iam::111111111111:role/ConfigRole" + assert estimator.enable_network_isolation() is True + assert estimator.encrypt_inter_container_traffic is True + assert estimator.output_kms_key == "TestKms" + assert estimator.volume_kms_key == "volumekey" + assert estimator.security_group_ids == ["sg-123"] + assert estimator.subnets == ["subnets-123"] def test_framework_with_heterogeneous_cluster(sagemaker_session): diff --git a/tests/unit/test_pipeline_model.py 
b/tests/unit/test_pipeline_model.py index 53f08ba893..b11926df44 100644 --- a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -12,20 +12,18 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import copy + import pytest +from botocore.utils import merge_dicts from mock import Mock, patch from sagemaker.model import FrameworkModel from sagemaker.pipeline import PipelineModel from sagemaker.predictor import Predictor -from sagemaker.session import ( - ENDPOINT_CONFIG_KMS_KEY_ID_PATH, - MODEL_ENABLE_NETWORK_ISOLATION_PATH, - MODEL_EXECUTION_ROLE_ARN_PATH, - MODEL_VPC_CONFIG_PATH, -) from sagemaker.session_settings import SessionSettings from sagemaker.sparkml import SparkMLModel +from tests.unit import SAGEMAKER_CONFIG_MODEL, SAGEMAKER_CONFIG_ENDPOINT_CONFIG ENTRY_POINT = "blah.py" MODEL_DATA_1 = "s3://bucket/model_1.tar.gz" @@ -75,10 +73,18 @@ def sagemaker_session(): settings=SessionSettings(), ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + sms.get_sagemaker_config_override = Mock( name="get_sagemaker_config_override", side_effect=lambda key, default_value=None: default_value, ) + sms.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) + return sms @@ -304,39 +310,30 @@ def test_pipeline_model_without_role(sagemaker_session): PipelineModel([], sagemaker_session=sagemaker_session) -def _config_override_mock(key, default_value=None): - if key == ENDPOINT_CONFIG_KMS_KEY_ID_PATH: - return "ConfigKmsKeyId" - elif key == MODEL_ENABLE_NETWORK_ISOLATION_PATH: - return True - elif key == MODEL_EXECUTION_ROLE_ARN_PATH: - return "ConfigRoleArn" - elif key == MODEL_VPC_CONFIG_PATH: - return {"SecurityGroupIds": ["sg-config"], "Subnets": ["subnet-config"]} - return default_value - - 
@patch("tarfile.open") @patch("time.strftime", return_value=TIMESTAMP) -def test_pipeline_model_with_config_injection(tfo, time, sagemaker_session): - framework_model = DummyFrameworkModel(sagemaker_session) +def test_pipeline_model_with_config_injection(tfo, time, sagemaker_config_session): + combined_config = copy.deepcopy(SAGEMAKER_CONFIG_MODEL) + endpoint_config = copy.deepcopy(SAGEMAKER_CONFIG_ENDPOINT_CONFIG) + merge_dicts(combined_config, endpoint_config) + sagemaker_config_session.sagemaker_config.config = combined_config + sagemaker_config_session.endpoint_from_production_variants = Mock() + + framework_model = DummyFrameworkModel(sagemaker_config_session) sparkml_model = SparkMLModel( - model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session - ) - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", side_effect=_config_override_mock + model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_config_session ) pipeline_model = PipelineModel( - [framework_model, sparkml_model], sagemaker_session=sagemaker_session + [framework_model, sparkml_model], sagemaker_session=sagemaker_config_session ) - assert pipeline_model.role == "ConfigRoleArn" + assert pipeline_model.role == "arn:aws:iam::111111111111:role/ConfigRole" assert pipeline_model.vpc_config == { - "SecurityGroupIds": ["sg-config"], - "Subnets": ["subnet-config"], + "SecurityGroupIds": ["sg-123"], + "Subnets": ["subnets-123"], } - assert pipeline_model.enable_network_isolation + assert pipeline_model.enable_network_isolation is True pipeline_model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) - sagemaker_session.endpoint_from_production_variants.assert_called_with( + sagemaker_config_session.endpoint_from_production_variants.assert_called_with( name="mi-1-2017-10-10-14-14-15", production_variants=[ { diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 4758ebd3f9..43dab585db 100644 --- 
a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -12,6 +12,8 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import copy + import pytest from mock import Mock, patch, MagicMock from packaging import version @@ -30,14 +32,6 @@ ScriptProcessor, ProcessingJob, ) -from sagemaker.session import ( - PROCESSING_JOB_SUBNETS_PATH, - PROCESSING_JOB_SECURITY_GROUP_IDS_PATH, - PROCESSING_JOB_ROLE_ARN_PATH, - PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH, - PROCESSING_JOB_KMS_KEY_ID_PATH, - PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH, -) from sagemaker.session_settings import SessionSettings from sagemaker.spark.processing import PySparkProcessor from sagemaker.sklearn.processing import SKLearnProcessor @@ -51,6 +45,7 @@ from sagemaker.workflow.pipeline_context import PipelineSession, _PipelineConfig from sagemaker.workflow.functions import Join from sagemaker.workflow.execution_variables import ExecutionVariables +from tests.unit import SAGEMAKER_CONFIG_PROCESSING_JOB BUCKET_NAME = "mybucket" REGION = "us-west-2" @@ -99,6 +94,12 @@ def sagemaker_session(): name="get_sagemaker_config_override", side_effect=lambda key, default_value=None: default_value, ) + session_mock.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) return session_mock @@ -637,22 +638,6 @@ def test_script_processor_with_required_parameters(exists_mock, isfile_mock, sag sagemaker_session.process.assert_called_with(**expected_args) -def _config_override_mock(key, default_value=None): - if key == PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH: - return True - if key == PROCESSING_JOB_ROLE_ARN_PATH: - return "arn:aws:iam::012345678901:role/ConfigRoleArn" - elif key == PROCESSING_JOB_KMS_KEY_ID_PATH: - return "ConfigKmsKeyId" - elif key == 
PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH: - return "ConfigVolumeKmsKeyId" - elif key == PROCESSING_JOB_SECURITY_GROUP_IDS_PATH: - return ["sg-config"] - elif key == PROCESSING_JOB_SUBNETS_PATH: - return ["subnet-config"] - return default_value - - @patch("os.path.exists", return_value=True) @patch("os.path.isfile", return_value=True) def test_script_processor_without_role(exists_mock, isfile_mock, sagemaker_session): @@ -681,13 +666,19 @@ def test_script_processor_without_role(exists_mock, isfile_mock, sagemaker_sessi @patch("os.path.exists", return_value=True) @patch("os.path.isfile", return_value=True) -def test_script_processor_with_some_parameters_from_config( - exists_mock, isfile_mock, sagemaker_session +def test_script_processor_with_sagemaker_config_injection( + exists_mock, isfile_mock, sagemaker_config_session ): - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", side_effect=_config_override_mock + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_PROCESSING_JOB + + sagemaker_config_session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + sagemaker_config_session.upload_data = Mock(name="upload_data", return_value=MOCKED_S3_URI) + sagemaker_config_session.wait_for_processing_job = MagicMock( + name="wait_for_processing_job", return_value=_get_describe_response_inputs_and_ouputs() ) - sagemaker_session.expand_role = Mock(name="expand_role", side_effect=lambda a: a) + sagemaker_config_session.process = Mock() + sagemaker_config_session.expand_role = Mock(name="expand_role", side_effect=lambda a: a) + processor = ScriptProcessor( image_uri=CUSTOM_IMAGE_URI, command=["python3"], @@ -698,10 +689,7 @@ def test_script_processor_with_some_parameters_from_config( base_job_name="my_sklearn_processor", env={"my_env_variable": "my_env_variable_value"}, tags=[{"Key": "my-tag", "Value": "my-tag-value"}], - network_config=NetworkConfig( - encrypt_inter_container_traffic=True, - ), - 
sagemaker_session=sagemaker_session, + sagemaker_session=sagemaker_config_session, ) processor.run( code="/local/path/to/processing_code.py", @@ -713,16 +701,18 @@ def test_script_processor_with_some_parameters_from_config( job_name="my_job_name", experiment_config={"ExperimentName": "AnExperiment"}, ) - expected_args = _get_expected_args_all_parameters(processor._current_job_name) - expected_args["resources"]["ClusterConfig"]["VolumeKmsKeyId"] = "ConfigVolumeKmsKeyId" - expected_args["output_config"]["KmsKeyId"] = "ConfigKmsKeyId" - expected_args["role_arn"] = "arn:aws:iam::012345678901:role/ConfigRoleArn" + expected_args = copy.deepcopy(_get_expected_args_all_parameters(processor._current_job_name)) + expected_args["resources"]["ClusterConfig"]["VolumeKmsKeyId"] = "testVolumeKmsKeyId" + expected_args["output_config"]["KmsKeyId"] = "testKmsKeyId" + expected_args["role_arn"] = "arn:aws:iam::111111111111:role/ConfigRole" expected_args["network_config"]["VpcConfig"] = { - "SecurityGroupIds": ["sg-config"], - "Subnets": ["subnet-config"], + "SecurityGroupIds": ["sg-123"], + "Subnets": ["subnets-123"], } + expected_args["network_config"]["EnableNetworkIsolation"] = True + expected_args["network_config"]["EnableInterContainerTrafficEncryption"] = False - sagemaker_session.process.assert_called_with(**expected_args) + sagemaker_config_session.process.assert_called_with(**expected_args) assert "my_job_name" in processor._current_job_name diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index b69d6881aa..cf6f2c0026 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -27,16 +27,6 @@ import sagemaker from sagemaker import TrainingInput, Session, get_execution_role, exceptions from sagemaker.async_inference import AsyncInferenceConfig -from sagemaker.config import ( - SAGEMAKER, - TRANSFORM_JOB, - DATA_CAPTURE_CONFIG, - TRANSFORM_OUTPUT, - TRANSFORM_RESOURCES, - KMS_KEY_ID, - VOLUME_KMS_KEY_ID, - TAGS, -) from sagemaker.session 
import ( _tuning_job_status, _transform_job_status, @@ -47,6 +37,19 @@ ) from sagemaker.tuner import WarmStartConfig, WarmStartTypes from sagemaker.inputs import BatchDataCaptureConfig +from tests.unit import ( + SAGEMAKER_CONFIG_MONITORING_SCHEDULE, + SAGEMAKER_CONFIG_COMPILATION_JOB, + SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB, + SAGEMAKER_CONFIG_ENDPOINT_CONFIG, + SAGEMAKER_CONFIG_AUTO_ML, + SAGEMAKER_CONFIG_MODEL_PACKAGE, + SAGEMAKER_CONFIG_FEATURE_GROUP, + SAGEMAKER_CONFIG_PROCESSING_JOB, + SAGEMAKER_CONFIG_TRAINING_JOB, + SAGEMAKER_CONFIG_TRANSFORM_JOB, + SAGEMAKER_CONFIG_MODEL, +) STATIC_HPs = {"feature_dim": "784"} @@ -259,48 +262,8 @@ def test_process(boto_session): session.sagemaker_client.create_processing_job.assert_called_with(**expected_request) -def _sagemaker_config_override_mock_for_process(key, default_value=None): - from sagemaker.session import ( - PROCESSING_JOB_ROLE_ARN_PATH, - PROCESSING_JOB_NETWORK_CONFIG_PATH, - PROCESSING_OUTPUT_CONFIG_PATH, - PROCESSING_JOB_PROCESSING_RESOURCES_PATH, - PROCESSING_JOB_INPUTS_PATH, - ) - - if key is PROCESSING_JOB_ROLE_ARN_PATH: - return "arn:aws:iam::111111111111:role/ConfigRole" - elif key is PROCESSING_JOB_NETWORK_CONFIG_PATH: - return { - "VpcConfig": {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, - "EnableNetworkIsolation": True, - } - elif key is PROCESSING_OUTPUT_CONFIG_PATH: - return {"KmsKeyId": "testKmsKeyId"} - elif key is PROCESSING_JOB_PROCESSING_RESOURCES_PATH: - return {"ClusterConfig": {"VolumeKmsKeyId": "testVolumeKmsKeyId"}} - elif key is PROCESSING_JOB_INPUTS_PATH: - return [ - { - "DatasetDefinition": { - "AthenaDatasetDefinition": {"KmsKeyId": "AthenaKmsKeyId"}, - "RedshiftDatasetDefinition": { - "KmsKeyId": "RedshiftKmsKeyId", - "ClusterRoleArn": "clusterrole", - }, - } - } - ] - - return default_value - - -def test_create_process_with_configs(sagemaker_session): - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - 
side_effect=_sagemaker_config_override_mock_for_process, - ) - sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override +def test_create_process_with_sagemaker_config_injection(sagemaker_config_session): + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_PROCESSING_JOB processing_inputs = [ { @@ -313,6 +276,14 @@ def test_create_process_with_configs(sagemaker_session): "S3DataDistributionType": "FullyReplicated", "S3CompressionType": "None", }, + # DatasetDefinition and (AthenaDatasetDefinition or RedshiftDatasetDefinition) need + # to be present for config injection. Included both AthenaDatasetDefinition and + # RedshiftDatasetDefinition at the same time to test injection (even though this API + # call would fail normally) + "DatasetDefinition": { + "AthenaDatasetDefinition": {}, + "RedshiftDatasetDefinition": {}, + }, } ] output_config = { @@ -357,32 +328,34 @@ def test_create_process_with_configs(sagemaker_session): "stopping_condition": {"MaxRuntimeInSeconds": 3600}, "app_specification": app_specification, "environment": {"my_env_variable": 20}, - "tags": [{"Name": "my-tag", "Value": "my-tag-value"}], "experiment_config": {"ExperimentName": "AnExperiment"}, } - sagemaker_session.process(**process_request_args) + sagemaker_config_session.process(**process_request_args) - expected_request = { - "ProcessingJobName": job_name, - "ProcessingResources": resource_config, - "AppSpecification": app_specification, - "RoleArn": "arn:aws:iam::111111111111:role/ConfigRole", - "ProcessingInputs": processing_inputs, - "ProcessingOutputConfig": output_config, - "Environment": {"my_env_variable": 20}, - "NetworkConfig": { - "VpcConfig": {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, - "EnableNetworkIsolation": True, - }, - "StoppingCondition": {"MaxRuntimeInSeconds": 3600}, - "Tags": [{"Name": "my-tag", "Value": "my-tag-value"}], - "ExperimentConfig": {"ExperimentName": "AnExperiment"}, - } + 
expected_request = copy.deepcopy( + { + "ProcessingJobName": job_name, + "ProcessingResources": resource_config, + "AppSpecification": app_specification, + "RoleArn": "arn:aws:iam::111111111111:role/ConfigRole", + "ProcessingInputs": processing_inputs, + "ProcessingOutputConfig": output_config, + "Environment": {"my_env_variable": 20}, + "NetworkConfig": { + "VpcConfig": {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, + "EnableNetworkIsolation": True, + "EnableInterContainerTrafficEncryption": False, + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 3600}, + "Tags": TAGS, + "ExperimentConfig": {"ExperimentName": "AnExperiment"}, + } + ) expected_request["ProcessingInputs"][0]["DatasetDefinition"] = { "AthenaDatasetDefinition": {"KmsKeyId": "AthenaKmsKeyId"}, "RedshiftDatasetDefinition": { "KmsKeyId": "RedshiftKmsKeyId", - "ClusterRoleArn": "clusterrole", + "ClusterRoleArn": "arn:aws:iam::111111111111:role/ClusterRole", }, } expected_request["ProcessingOutputConfig"]["KmsKeyId"] = "testKmsKeyId" @@ -390,7 +363,9 @@ def test_create_process_with_configs(sagemaker_session): "VolumeKmsKeyId" ] = "testVolumeKmsKeyId" - sagemaker_session.sagemaker_client.create_processing_job.assert_called_with(**expected_request) + sagemaker_config_session.sagemaker_client.create_processing_job.assert_called_with( + **expected_request + ) def mock_exists(filepath_to_mock, exists_result): @@ -862,7 +837,7 @@ def test_training_input_all_arguments(): MAX_SIZE = 30 MAX_TIME = 3 * 60 * 60 JOB_NAME = "jobname" -EXAMPLE_TAGS = [{"Key": "some-tag", "Value": "value-for-tag"}] +TAGS = [{"Key": "some-tag", "Value": "value-for-tag"}] VPC_CONFIG = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} METRIC_DEFINITONS = [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}] EXPERIMENT_CONFIG = { @@ -1550,34 +1525,8 @@ def test_stop_tuning_job_client_error(sagemaker_session): ) -def _sagemaker_config_override_mock_for_train(key, default_value=None): - from sagemaker.session 
import ( - TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, - TRAINING_JOB_VPC_CONFIG_PATH, - TRAINING_JOB_RESOURCE_CONFIG_PATH, - TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH, - TRAINING_JOB_ROLE_ARN_PATH, - ) - - if key is TRAINING_JOB_ROLE_ARN_PATH: - return "arn:aws:iam::111111111111:role/ConfigRole" - elif key is TRAINING_JOB_VPC_CONFIG_PATH: - return {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]} - elif key is TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH: - return {"KmsKeyId": "TestKms"} - elif key is TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH: - return True - elif key is TRAINING_JOB_RESOURCE_CONFIG_PATH: - return {"VolumeKmsKeyId": "volumekey"} - return default_value - - -def test_train_with_configs(sagemaker_session): - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=_sagemaker_config_override_mock_for_train, - ) - sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override +def test_train_with_sagemaker_config_injection(sagemaker_config_session): + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_TRAINING_JOB in_config = [ { @@ -1610,7 +1559,7 @@ def test_train_with_configs(sagemaker_session): }, } - sagemaker_session.train( + sagemaker_config_session.train( image_uri=IMAGE, input_mode="File", input_config=in_config, @@ -1619,9 +1568,7 @@ def test_train_with_configs(sagemaker_session): resource_config=resource_config, hyperparameters=hyperparameters, stop_condition=stop_cond, - tags=EXAMPLE_TAGS, metric_definitions=METRIC_DEFINITONS, - encrypt_inter_container_traffic=True, use_spot_instances=True, checkpoint_s3_uri="s3://mybucket/checkpoints/", checkpoint_local_path="/tmp/checkpoints", @@ -1631,14 +1578,14 @@ def test_train_with_configs(sagemaker_session): training_image_config=TRAINING_IMAGE_CONFIG, ) - _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0] + _, _, actual_train_args = 
sagemaker_config_session.sagemaker_client.method_calls[0] assert actual_train_args["VpcConfig"] == { "Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"], } assert actual_train_args["HyperParameters"] == hyperparameters - assert actual_train_args["Tags"] == EXAMPLE_TAGS + assert actual_train_args["Tags"] == TAGS assert actual_train_args["AlgorithmSpecification"]["MetricDefinitions"] == METRIC_DEFINITONS assert actual_train_args["AlgorithmSpecification"]["EnableSageMakerMetricsTimeSeries"] is True assert actual_train_args["EnableInterContainerTrafficEncryption"] is True @@ -1707,7 +1654,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): vpc_config=VPC_CONFIG, hyperparameters=hyperparameters, stop_condition=stop_cond, - tags=EXAMPLE_TAGS, + tags=TAGS, metric_definitions=METRIC_DEFINITONS, encrypt_inter_container_traffic=True, use_spot_instances=True, @@ -1723,7 +1670,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): assert actual_train_args["VpcConfig"] == VPC_CONFIG assert actual_train_args["HyperParameters"] == hyperparameters - assert actual_train_args["Tags"] == EXAMPLE_TAGS + assert actual_train_args["Tags"] == TAGS assert actual_train_args["AlgorithmSpecification"]["MetricDefinitions"] == METRIC_DEFINITONS assert actual_train_args["AlgorithmSpecification"]["EnableSageMakerMetricsTimeSeries"] is True assert actual_train_args["EnableInterContainerTrafficEncryption"] is True @@ -1738,18 +1685,8 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): def test_create_transform_job_with_sagemaker_config_injection(sagemaker_config_session): - # Config to test injection for - sagemaker_config_session.sagemaker_config.config = { - SAGEMAKER: { - TRANSFORM_JOB: { - DATA_CAPTURE_CONFIG: {KMS_KEY_ID: "jobKmsKeyId"}, - TRANSFORM_OUTPUT: {KMS_KEY_ID: "outputKmsKeyId"}, - TRANSFORM_RESOURCES: {VOLUME_KMS_KEY_ID: "volumeKmsKeyId"}, - TAGS: EXAMPLE_TAGS, - } - } - } + 
sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_TRANSFORM_JOB model_name = "my-model" in_config = { @@ -1773,7 +1710,7 @@ def test_create_transform_job_with_sagemaker_config_injection(sagemaker_config_s "TransformResources": resource_config, "DataProcessing": data_processing, "DataCaptureConfig": data_capture_config._to_request_dict(), - "Tags": EXAMPLE_TAGS, + "Tags": TAGS, } ) expected_args["DataCaptureConfig"]["KmsKeyId"] = "jobKmsKeyId" @@ -1798,7 +1735,7 @@ def test_create_transform_job_with_sagemaker_config_injection(sagemaker_config_s resource_config=resource_config, experiment_config=None, model_client_config=None, - tags=EXAMPLE_TAGS, + tags=TAGS, data_processing=data_processing, batch_data_capture_config=data_capture_config, ) @@ -1877,7 +1814,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session): resource_config={}, experiment_config=EXPERIMENT_CONFIG, model_client_config=MODEL_CLIENT_CONFIG, - tags=EXAMPLE_TAGS, + tags=TAGS, data_processing=None, batch_data_capture_config=batch_data_capture_config, ) @@ -1887,7 +1824,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session): assert actual_args["MaxConcurrentTransforms"] == max_concurrent_transforms assert actual_args["MaxPayloadInMB"] == max_payload assert actual_args["Environment"] == env - assert actual_args["Tags"] == EXAMPLE_TAGS + assert actual_args["Tags"] == TAGS assert actual_args["ExperimentConfig"] == EXPERIMENT_CONFIG assert actual_args["ModelClientConfig"] == MODEL_CLIENT_CONFIG assert actual_args["DataCaptureConfig"] == batch_data_capture_config._to_request_dict() @@ -2139,44 +2076,25 @@ def test_logs_for_transform_job_full_lifecycle(time, cw, sagemaker_session_full_ } -def _sagemaker_config_override_mock_for_model(key, default_value=None): - from sagemaker.session import ( - MODEL_EXECUTION_ROLE_ARN_PATH, - MODEL_VPC_CONFIG_PATH, - MODEL_ENABLE_NETWORK_ISOLATION_PATH, - ) - - if key is MODEL_EXECUTION_ROLE_ARN_PATH: - 
return "arn:aws:iam::111111111111:role/ConfigRole" - elif key is MODEL_VPC_CONFIG_PATH: - return {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]} - elif key is MODEL_ENABLE_NETWORK_ISOLATION_PATH: - return True - return default_value - +def test_create_model_with_sagemaker_config_injection(sagemaker_config_session): -@patch("sagemaker.session._expand_container_def", return_value=PRIMARY_CONTAINER) -def test_create_model_with_configs(expand_container_def, sagemaker_session): - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=_sagemaker_config_override_mock_for_model, - ) - sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_MODEL - sagemaker_session.expand_role = Mock( + sagemaker_config_session.expand_role = Mock( name="expand_role", side_effect=lambda role_name: role_name ) - model = sagemaker_session.create_model( + model = sagemaker_config_session.create_model( MODEL_NAME, container_defs=PRIMARY_CONTAINER, ) assert model == MODEL_NAME - sagemaker_session.sagemaker_client.create_model.assert_called_with( + sagemaker_config_session.sagemaker_client.create_model.assert_called_with( ExecutionRoleArn="arn:aws:iam::111111111111:role/ConfigRole", ModelName=MODEL_NAME, PrimaryContainer=PRIMARY_CONTAINER, VpcConfig={"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, EnableNetworkIsolation=True, + Tags=TAGS, ) @@ -2334,7 +2252,7 @@ def test_create_model_from_job(sagemaker_session): def test_create_model_from_job_with_tags(sagemaker_session): ims = sagemaker_session ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT - ims.create_model_from_job(JOB_NAME, tags=EXAMPLE_TAGS) + ims.create_model_from_job(JOB_NAME, tags=TAGS) assert ( call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list @@ -2344,36 +2262,19 @@ 
def test_create_model_from_job_with_tags(sagemaker_session): ModelName=JOB_NAME, PrimaryContainer=PRIMARY_CONTAINER, VpcConfig=VPC_CONFIG, - Tags=EXAMPLE_TAGS, + Tags=TAGS, ) -def _sagemaker_config_override_mock_for_edge_packaging(key, default_value=None): - from sagemaker.session import ( - EDGE_PACKAGING_ROLE_ARN_PATH, - EDGE_PACKAGING_OUTPUT_CONFIG_PATH, - ) - - if key is EDGE_PACKAGING_ROLE_ARN_PATH: - return "arn:aws:iam::111111111111:role/ConfigRole" - elif key is EDGE_PACKAGING_OUTPUT_CONFIG_PATH: - return {"KmsKeyId": "configKmsKeyId"} - return default_value - - -def test_create_edge_packaging_with_configs(sagemaker_session): - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=_sagemaker_config_override_mock_for_edge_packaging, - ) - sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override +def test_create_edge_packaging_with_sagemaker_config_injection(sagemaker_config_session): + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB output_config = {"S3OutputLocation": S3_OUTPUT} - sagemaker_session.package_model_for_edge( + sagemaker_config_session.package_model_for_edge( output_config, ) - sagemaker_session.sagemaker_client.create_edge_packaging_job.assert_called_with( + sagemaker_config_session.sagemaker_client.create_edge_packaging_job.assert_called_with( RoleArn="arn:aws:iam::111111111111:role/ConfigRole", # provided from config OutputConfig={ "S3OutputLocation": S3_OUTPUT, # provided as param @@ -2383,41 +2284,16 @@ def test_create_edge_packaging_with_configs(sagemaker_session): ModelVersion=None, EdgePackagingJobName=None, CompilationJobName=None, + Tags=TAGS, ) -def _sagemaker_config_override_mock_for_monitoring_schedule(key, default_value=None): - from sagemaker.session import ( - MONITORING_JOB_ROLE_ARN_PATH, - MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, - MONITORING_JOB_NETWORK_CONFIG_PATH, - 
MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH, - ) - - if key is MONITORING_JOB_ROLE_ARN_PATH: - return "arn:aws:iam::111111111111:role/ConfigRole" - elif key is MONITORING_JOB_NETWORK_CONFIG_PATH: - return { - "VpcConfig": {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, - "EnableNetworkIsolation": True, - } - elif key is MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH: - return "configKmsKeyId" - elif key is MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH: - return "configVolumeKmsKeyId" - return default_value - - -def test_create_monitoring_schedule_with_configs(sagemaker_session): - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=_sagemaker_config_override_mock_for_monitoring_schedule, - ) - sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override +def test_create_monitoring_schedule_with_sagemaker_config_injection(sagemaker_config_session): + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_MONITORING_SCHEDULE monitoring_output_config = {"MonitoringOutputs": [{"S3Output": {"S3Uri": S3_OUTPUT}}]} - sagemaker_session.create_monitoring_schedule( + sagemaker_config_session.create_monitoring_schedule( JOB_NAME, schedule_expression=None, statistics_s3_uri=None, @@ -2430,7 +2306,7 @@ def test_create_monitoring_schedule_with_configs(sagemaker_session): image_uri="someimageuri", network_config={"VpcConfig": {"SecurityGroupIds": ["sg-asparam"]}}, ) - sagemaker_session.sagemaker_client.create_monitoring_schedule.assert_called_with( + sagemaker_config_session.sagemaker_client.create_monitoring_schedule.assert_called_with( MonitoringScheduleName=JOB_NAME, MonitoringScheduleConfig={ "MonitoringJobDefinition": { @@ -2457,47 +2333,30 @@ def test_create_monitoring_schedule_with_configs(sagemaker_session): "SecurityGroupIds": ["sg-asparam"], # provided as param }, "EnableNetworkIsolation": True, # fetched from config + "EnableInterContainerTrafficEncryption": False, # fetched 
from config }, - } + }, }, + Tags=TAGS, # fetched from config ) -def _sagemaker_config_override_mock_for_compile(key, default_value=None): - from sagemaker.session import ( - COMPILATION_JOB_ROLE_ARN_PATH, - COMPILATION_JOB_OUTPUT_CONFIG_PATH, - COMPILATION_JOB_VPC_CONFIG_PATH, - ) - - if key is COMPILATION_JOB_ROLE_ARN_PATH: - return "arn:aws:iam::111111111111:role/ConfigRole" - elif key is COMPILATION_JOB_VPC_CONFIG_PATH: - return {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]} - elif key is COMPILATION_JOB_OUTPUT_CONFIG_PATH: - return {"KmsKeyId": "TestKms"} - return default_value - - -def test_compile_with_configs(sagemaker_session): - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=_sagemaker_config_override_mock_for_compile, - ) - sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override +def test_compile_with_sagemaker_config_injection(sagemaker_config_session): + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_COMPILATION_JOB - sagemaker_session.compile_model( + sagemaker_config_session.compile_model( input_model_config={}, output_model_config={"S3OutputLocation": "s3://test"}, job_name="TestJob", ) - sagemaker_session.sagemaker_client.create_compilation_job.assert_called_with( + sagemaker_config_session.sagemaker_client.create_compilation_job.assert_called_with( InputConfig={}, OutputConfig={"S3OutputLocation": "s3://test", "KmsKeyId": "TestKms"}, RoleArn="arn:aws:iam::111111111111:role/ConfigRole", StoppingCondition=None, CompilationJobName="TestJob", VpcConfig={"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, + Tags=TAGS, ) @@ -2563,43 +2422,21 @@ def test_endpoint_from_production_variants(sagemaker_session): ) -def _sagemaker_config_override_mock_for_endpoint_config(key, default_value=None): - from sagemaker.session import ( - ENDPOINT_CONFIG_DATA_CAPTURE_PATH, - ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, - 
ENDPOINT_CONFIG_KMS_KEY_ID_PATH, - ) - - if key is ENDPOINT_CONFIG_KMS_KEY_ID_PATH: - return "testKmsKeyId" - elif key is ENDPOINT_CONFIG_DATA_CAPTURE_PATH: - return {"KmsKeyId": "testDataCaptureKmsKeyId"} - elif key is ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH: - return [{"CoreDumpConfig": {"KmsKeyId": "testCoreKmsKeyId"}}] - return default_value - - -def test_create_enpoint_config_with_configs(sagemaker_session): - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=_sagemaker_config_override_mock_for_endpoint_config, - ) - sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override +def test_create_endpoint_config_with_sagemaker_config_injection(sagemaker_config_session): + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG data_capture_config_dict = {"DestinationS3Uri": "s3://test"} - tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}] - - sagemaker_session.create_endpoint_config( + # This method does not support ASYNC_INFERENCE_CONFIG or multiple PRODUCTION_VARIANTS + sagemaker_config_session.create_endpoint_config( "endpoint-test", "simple-model", 1, "local", - tags=tags, data_capture_config_dict=data_capture_config_dict, ) - sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( + sagemaker_config_session.sagemaker_client.create_endpoint_config.assert_called_with( EndpointConfigName="endpoint-test", ProductionVariants=[ { @@ -2612,8 +2449,88 @@ def test_create_enpoint_config_with_configs(sagemaker_session): } ], DataCaptureConfig={"DestinationS3Uri": "s3://test", "KmsKeyId": "testDataCaptureKmsKeyId"}, - KmsKeyId="testKmsKeyId", - Tags=tags, + KmsKeyId="ConfigKmsKeyId", + Tags=TAGS, + ) + + +def test_create_endpoint_config_from_existing_with_sagemaker_config_injection( + sagemaker_config_session, +): + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG + + pvs = [ + 
sagemaker.production_variant("A", "ml.p2.xlarge"), + sagemaker.production_variant("B", "ml.p2.xlarge"), + sagemaker.production_variant("C", "ml.p2.xlarge"), + ] + existing_endpoint_arn = "arn:aws:sagemaker:us-west-2:123412341234:endpoint-config/foo" + existing_endpoint_name = "foo" + new_endpoint_name = "new-foo" + sagemaker_config_session.sagemaker_client.describe_endpoint_config.return_value = { + "ProductionVariants": [sagemaker.production_variant("A", "ml.m4.xlarge")], + "EndpointConfigArn": existing_endpoint_arn, + "AsyncInferenceConfig": {}, + } + sagemaker_config_session.sagemaker_client.list_tags.return_value = {"Tags": []} + + sagemaker_config_session.create_endpoint_config_from_existing( + existing_endpoint_name, new_endpoint_name, new_production_variants=pvs + ) + + sagemaker_config_session.sagemaker_client.create_endpoint_config.assert_called_with( + EndpointConfigName=new_endpoint_name, + ProductionVariants=[ + { + "CoreDumpConfig": {"KmsKeyId": "testCoreKmsKeyId"}, + **sagemaker.production_variant("A", "ml.p2.xlarge"), + }, + { + "CoreDumpConfig": {"KmsKeyId": "testCoreKmsKeyId2"}, + **sagemaker.production_variant("B", "ml.p2.xlarge"), + }, + sagemaker.production_variant("C", "ml.p2.xlarge"), + ], + KmsKeyId="ConfigKmsKeyId", # from config + Tags=TAGS, # from config + AsyncInferenceConfig={"OutputConfig": {"KmsKeyId": "testOutputKmsKeyId"}}, # from config + ) + + +def test_endpoint_from_production_variants_with_sagemaker_config_injection( + sagemaker_config_session, +): + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG + + sagemaker_config_session.sagemaker_client.describe_endpoint = Mock( + return_value={"EndpointStatus": "InService"} + ) + pvs = [ + sagemaker.production_variant("A", "ml.p2.xlarge"), + sagemaker.production_variant("B", "ml.p2.xlarge"), + sagemaker.production_variant("C", "ml.p2.xlarge"), + ] + sagemaker_config_session.endpoint_from_production_variants( + "some-endpoint", + pvs, + 
data_capture_config_dict={}, + async_inference_config_dict=AsyncInferenceConfig()._to_request_dict(), + ) + + expected_async_inference_config_dict = AsyncInferenceConfig()._to_request_dict() + expected_async_inference_config_dict["OutputConfig"]["KmsKeyId"] = "testOutputKmsKeyId" + sagemaker_config_session.sagemaker_client.create_endpoint_config.assert_called_with( + EndpointConfigName="some-endpoint", + ProductionVariants=pvs, + Tags=TAGS, # from config + KmsKeyId="ConfigKmsKeyId", # from config + AsyncInferenceConfig=expected_async_inference_config_dict, + DataCaptureConfig={"KmsKeyId": "testDataCaptureKmsKeyId"}, + ) + sagemaker_config_session.sagemaker_client.create_endpoint.assert_called_with( + EndpointConfigName="some-endpoint", + EndpointName="some-endpoint", + Tags=TAGS, # from config ) @@ -3078,27 +2995,6 @@ def test_wait_until_fail_access_denied_after_5_mins(patched_sleep): } -def _sagemaker_config_override_mock_for_auto_ml(key, default_value=None): - from sagemaker.session import ( - AUTO_ML_OUTPUT_CONFIG_PATH, - AUTO_ML_ROLE_ARN_PATH, - AUTO_ML_JOB_CONFIG_PATH, - ) - - if key is AUTO_ML_ROLE_ARN_PATH: - return "arn:aws:iam::111111111111:role/ConfigRole" - elif key is AUTO_ML_JOB_CONFIG_PATH: - return { - "SecurityConfig": { - "VpcConfig": {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, - "VolumeKmsKeyId": "TestKmsKeyId", - } - } - elif key is AUTO_ML_OUTPUT_CONFIG_PATH: - return {"KmsKeyId": "configKmsKeyId"} - return default_value - - def test_auto_ml_pack_to_request(sagemaker_session): input_config = [ { @@ -3131,12 +3027,8 @@ def test_auto_ml_pack_to_request(sagemaker_session): ) -def test_create_auto_ml_with_configs(sagemaker_session): - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=_sagemaker_config_override_mock_for_auto_ml, - ) - sagemaker_session.get_sagemaker_config_value = sagemaker_session.get_sagemaker_config_override +def 
test_create_auto_ml_with_sagemaker_config_injection(sagemaker_config_session): + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_AUTO_ML input_config = [ { @@ -3156,23 +3048,28 @@ def test_create_auto_ml_with_configs(sagemaker_session): } job_name = JOB_NAME - sagemaker_session.auto_ml(input_config, output_config, auto_ml_job_config, job_name=job_name) - expected_call_args = DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS.copy() + sagemaker_config_session.auto_ml( + input_config, output_config, auto_ml_job_config, job_name=job_name + ) + expected_call_args = copy.deepcopy(DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS) expected_call_args["OutputDataConfig"]["KmsKeyId"] = "configKmsKeyId" expected_call_args["RoleArn"] = "arn:aws:iam::111111111111:role/ConfigRole" - expected_call_args["AutoMLJobConfig"]["SecurityConfig"] = {} + expected_call_args["AutoMLJobConfig"]["SecurityConfig"] = { + "EnableInterContainerTrafficEncryption": True + } expected_call_args["AutoMLJobConfig"]["SecurityConfig"]["VpcConfig"] = { "Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"], } expected_call_args["AutoMLJobConfig"]["SecurityConfig"]["VolumeKmsKeyId"] = "TestKmsKeyId" - sagemaker_session.sagemaker_client.create_auto_ml_job.assert_called_with( + sagemaker_config_session.sagemaker_client.create_auto_ml_job.assert_called_with( AutoMLJobName=expected_call_args["AutoMLJobName"], InputDataConfig=expected_call_args["InputDataConfig"], OutputDataConfig=expected_call_args["OutputDataConfig"], AutoMLJobConfig=expected_call_args["AutoMLJobConfig"], RoleArn=expected_call_args["RoleArn"], GenerateCandidateDefinitionsOnly=False, + Tags=TAGS, ) @@ -3373,31 +3270,9 @@ def test_create_model_package_from_containers_without_model_package_group_name( ) -def _sagemaker_config_override_mock_for_model_package(key, default_value=None): - from sagemaker.session import ( - MODEL_PACKAGE_VALIDATION_ROLE_PATH, - MODEL_PACKAGE_VALIDATION_PROFILES_PATH, - ) +def 
test_create_model_package_with_sagemaker_config_injection(sagemaker_config_session): + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_MODEL_PACKAGE - if key is MODEL_PACKAGE_VALIDATION_ROLE_PATH: - return "arn:aws:iam::111111111111:role/ConfigRole" - elif key is MODEL_PACKAGE_VALIDATION_PROFILES_PATH: - return [ - { - "TransformJobDefinition": { - "TransformOutput": {"KmsKeyId": "testKmsKeyId"}, - "TransformResources": {"VolumeKmsKeyId": "testVolumeKmsKeyId"}, - } - } - ] - return default_value - - -def test_create_model_package_with_configs(sagemaker_session): - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=_sagemaker_config_override_mock_for_model_package, - ) model_package_name = "sagemaker-model-package" containers = ["dummy-container"] content_types = ["application/json"] @@ -3436,7 +3311,7 @@ def test_create_model_package_with_configs(sagemaker_session): domain = "COMPUTER_VISION" task = "IMAGE_CLASSIFICATION" sample_payload_url = "s3://test-bucket/model" - sagemaker_session.create_model_package_from_containers( + sagemaker_config_session.create_model_package_from_containers( containers=containers, content_types=content_types, response_types=response_types, @@ -3455,27 +3330,29 @@ def test_create_model_package_with_configs(sagemaker_session): task=task, validation_specification=validation_specification, ) - expected_args = { - "ModelPackageName": model_package_name, - "InferenceSpecification": { - "Containers": containers, - "SupportedContentTypes": content_types, - "SupportedResponseMIMETypes": response_types, - "SupportedRealtimeInferenceInstanceTypes": inference_instances, - "SupportedTransformInstanceTypes": transform_instances, - }, - "ModelPackageDescription": description, - "ModelMetrics": model_metrics, - "MetadataProperties": metadata_properties, - "CertifyForMarketplace": marketplace_cert, - "ModelApprovalStatus": approval_status, - "DriftCheckBaselines": 
drift_check_baselines, - "CustomerMetadataProperties": customer_metadata_properties, - "Domain": domain, - "SamplePayloadUrl": sample_payload_url, - "Task": task, - "ValidationSpecification": validation_specification, - } + expected_args = copy.deepcopy( + { + "ModelPackageName": model_package_name, + "InferenceSpecification": { + "Containers": containers, + "SupportedContentTypes": content_types, + "SupportedResponseMIMETypes": response_types, + "SupportedRealtimeInferenceInstanceTypes": inference_instances, + "SupportedTransformInstanceTypes": transform_instances, + }, + "ModelPackageDescription": description, + "ModelMetrics": model_metrics, + "MetadataProperties": metadata_properties, + "CertifyForMarketplace": marketplace_cert, + "ModelApprovalStatus": approval_status, + "DriftCheckBaselines": drift_check_baselines, + "CustomerMetadataProperties": customer_metadata_properties, + "Domain": domain, + "SamplePayloadUrl": sample_payload_url, + "Task": task, + "ValidationSpecification": validation_specification, + } + ) expected_args["ValidationSpecification"][ "ValidationRole" ] = "arn:aws:iam::111111111111:role/ConfigRole" @@ -3485,7 +3362,10 @@ def test_create_model_package_with_configs(sagemaker_session): expected_args["ValidationSpecification"]["ValidationProfiles"][0]["TransformJobDefinition"][ "TransformOutput" ]["KmsKeyId"] = "testKmsKeyId" - sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) + + sagemaker_config_session.sagemaker_client.create_model_package.assert_called_with( + **expected_args + ) def test_create_model_package_from_containers_all_args(sagemaker_session): @@ -3696,37 +3576,20 @@ def feature_group_dummy_definitions(): return [{"FeatureName": "feature1", "FeatureType": "String"}] -def _sagemaker_config_override_mock_for_feature_store(key, default_value=None): - from sagemaker.session import ( - FEATURE_GROUP_ROLE_ARN_PATH, - FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, - 
FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, - ) - - if key is FEATURE_GROUP_ROLE_ARN_PATH: - return "config_role" - elif key is FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH: - return {"S3StorageConfig": {"KmsKeyId": "testKmsId"}} - elif key is FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH: - return {"SecurityConfig": {"KmsKeyId": "testKmsId2"}} - return default_value +def test_feature_group_create_with_sagemaker_config_injection( + sagemaker_config_session, feature_group_dummy_definitions +): + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_FEATURE_GROUP -def test_feature_group_create_with_config_injections( - sagemaker_session, feature_group_dummy_definitions -): - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=_sagemaker_config_override_mock_for_feature_store, - ) - sagemaker_session.create_feature_group( + sagemaker_config_session.create_feature_group( feature_group_name="MyFeatureGroup", record_identifier_name="feature1", event_time_feature_name="feature2", feature_definitions=feature_group_dummy_definitions, offline_store_config={"S3StorageConfig": {"S3Uri": "s3://test"}}, ) - assert sagemaker_session.sagemaker_client.create_feature_group.called_with( + assert sagemaker_config_session.sagemaker_client.create_feature_group.called_with( FeatureGroupName="MyFeatureGroup", RecordIdentifierFeatureName="feature1", EventTimeFeatureName="feature2", @@ -3734,6 +3597,7 @@ def test_feature_group_create_with_config_injections( RoleArn="config_role", OnlineStoreConfig={"SecurityConfig": {"KmsKeyId": "testKmsId2"}, "EnableOnlineStore": True}, OfflineStoreConfig={"S3StorageConfig": {"KmsKeyId": "testKmsId", "S3Uri": "s3://test"}}, + Tags=TAGS, ) diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index cf7d519823..da24cbdd14 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -15,25 +15,12 @@ import pytest from mock import MagicMock, Mock, patch, 
PropertyMock -import sagemaker -from sagemaker.config import ( - SAGEMAKER, - TRANSFORM_JOB, - DATA_CAPTURE_CONFIG, - TRANSFORM_OUTPUT, - TRANSFORM_RESOURCES, - VOLUME_KMS_KEY_ID, - TAGS, -) from sagemaker.transformer import _TransformJob, Transformer from sagemaker.workflow.pipeline_context import PipelineSession, _PipelineConfig from sagemaker.inputs import BatchDataCaptureConfig -from sagemaker.session import ( - TRANSFORM_JOB_KMS_KEY_ID_PATH, - TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, - TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH, -) + from tests.integ import test_local_mode +from tests.unit import SAGEMAKER_CONFIG_TRANSFORM_JOB ROLE = "DummyRole" REGION = "us-west-2" @@ -123,28 +110,9 @@ def transformer(sagemaker_session): ) -def _config_override_mock(key, default_value=None): - if key == TRANSFORM_OUTPUT_KMS_KEY_ID_PATH: - return "ConfigKmsKeyId" - elif key == TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH: - return "ConfigVolumeKmsKeyId" - elif key == TRANSFORM_JOB_KMS_KEY_ID_PATH: - return "DataCaptureConfigKmsKeyId" - return default_value - - @patch("sagemaker.transformer._TransformJob.start_new") def test_transform_with_sagemaker_config_injection(start_new_job, sagemaker_config_session): - sagemaker_config_session.sagemaker_config.config = { - SAGEMAKER: { - TRANSFORM_JOB: { - DATA_CAPTURE_CONFIG: {sagemaker.config.KMS_KEY_ID: "DataCaptureConfigKmsKeyId"}, - TRANSFORM_OUTPUT: {sagemaker.config.KMS_KEY_ID: "ConfigKmsKeyId"}, - TRANSFORM_RESOURCES: {VOLUME_KMS_KEY_ID: "ConfigVolumeKmsKeyId"}, - TAGS: [], - } - } - } + sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_TRANSFORM_JOB transformer = Transformer( MODEL_NAME, @@ -153,8 +121,8 @@ def test_transform_with_sagemaker_config_injection(start_new_job, sagemaker_conf output_path=OUTPUT_PATH, sagemaker_session=sagemaker_config_session, ) - assert transformer.volume_kms_key == "ConfigVolumeKmsKeyId" - assert transformer.output_kms_key == "ConfigKmsKeyId" + assert transformer.volume_kms_key == 
"volumeKmsKeyId" + assert transformer.output_kms_key == "outputKmsKeyId" content_type = "text/csv" compression = "Gzip" @@ -203,7 +171,7 @@ def test_transform_with_sagemaker_config_injection(start_new_job, sagemaker_conf batch_data_capture_config, ) # KmsKeyId in BatchDataCapture will be inserted from the config - assert batch_data_capture_config.kms_key_id == "DataCaptureConfigKmsKeyId" + assert batch_data_capture_config.kms_key_id == "jobKmsKeyId" def test_delete_model(sagemaker_session): From 3dac2d3863dce7403399c51b6a78e02a966e103a Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Mon, 13 Mar 2023 16:41:57 -0700 Subject: [PATCH 15/40] fix: Sagemaker Config - Removed hard coded config values in the unit tests --- tests/unit/sagemaker/automl/test_auto_ml.py | 46 ++- .../feature_store/test_feature_group.py | 13 +- .../monitor/test_model_monitoring.py | 5 +- tests/unit/test_estimator.py | 72 +++- tests/unit/test_pipeline_model.py | 19 +- tests/unit/test_processing.py | 34 +- tests/unit/test_session.py | 322 +++++++++++++----- tests/unit/test_transformer.py | 22 +- 8 files changed, 403 insertions(+), 130 deletions(-) diff --git a/tests/unit/sagemaker/automl/test_auto_ml.py b/tests/unit/sagemaker/automl/test_auto_ml.py index 6c85773709..26d08784ac 100644 --- a/tests/unit/sagemaker/automl/test_auto_ml.py +++ b/tests/unit/sagemaker/automl/test_auto_ml.py @@ -313,11 +313,28 @@ def test_framework_initialization_with_sagemaker_config_injection(sagemaker_conf target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_config_session, ) - assert auto_ml.role == "arn:aws:iam::111111111111:role/ConfigRole" - assert auto_ml.output_kms_key == "configKmsKeyId" - assert auto_ml.volume_kms_key == "TestKmsKeyId" - assert auto_ml.vpc_config == {"SecurityGroupIds": ["sg-123"], "Subnets": ["subnets-123"]} - assert auto_ml.encrypt_inter_container_traffic is True + + expected_volume_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["AutoMLJobConfig"][ + 
"SecurityConfig" + ]["VolumeKmsKeyId"] + expected_role_arn = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["RoleArn"] + expected_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["OutputDataConfig"][ + "KmsKeyId" + ] + expected_vpc_config = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["AutoMLJobConfig"][ + "SecurityConfig" + ]["VpcConfig"] + expected_enable_inter_container_traffic_encryption = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"][ + "AutoML" + ]["AutoMLJobConfig"]["SecurityConfig"]["EnableInterContainerTrafficEncryption"] + assert auto_ml.role == expected_role_arn + assert auto_ml.output_kms_key == expected_kms_key_id + assert auto_ml.volume_kms_key == expected_volume_kms_key_id + assert auto_ml.vpc_config == expected_vpc_config + assert ( + auto_ml.encrypt_inter_container_traffic + == expected_enable_inter_container_traffic_encryption + ) def test_auto_ml_default_channel_name(sagemaker_session): @@ -849,15 +866,22 @@ def test_candidate_estimator_fit_initialization_with_sagemaker_config_injection( ) inputs = DEFAULT_S3_INPUT_DATA candidate_estimator.fit(inputs) + expected_volume_kms_key_id = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ + "ResourceConfig" + ]["VolumeKmsKeyId"] + expected_vpc_config = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]["VpcConfig"] + expected_enable_inter_container_traffic_encryption = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"][ + "TrainingJob" + ]["EnableInterContainerTrafficEncryption"] for train_call in sagemaker_config_session.train.call_args_list: train_args = train_call.kwargs - assert train_args["vpc_config"] == { - "SecurityGroupIds": ["sg-123"], - "Subnets": ["subnets-123"], - } - assert train_args["resource_config"]["VolumeKmsKeyId"] == "volumekey" - assert train_args["encrypt_inter_container_traffic"] is True + assert train_args["vpc_config"] == expected_vpc_config + assert train_args["resource_config"]["VolumeKmsKeyId"] == expected_volume_kms_key_id + assert ( + 
train_args["encrypt_inter_container_traffic"] + == expected_enable_inter_container_traffic_encryption + ) def test_candidate_estimator_get_steps(sagemaker_session): diff --git a/tests/unit/sagemaker/feature_store/test_feature_group.py b/tests/unit/sagemaker/feature_store/test_feature_group.py index 0803e96e08..b13e236ee4 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_group.py +++ b/tests/unit/sagemaker/feature_store/test_feature_group.py @@ -123,21 +123,28 @@ def test_feature_store_create_with_config_injection( event_time_feature_name="feature2", enable_online_store=True, ) + expected_offline_store_kms_key_id = SAGEMAKER_CONFIG_FEATURE_GROUP["SageMaker"]["FeatureGroup"][ + "OfflineStoreConfig" + ]["S3StorageConfig"]["KmsKeyId"] + expected_role_arn = SAGEMAKER_CONFIG_FEATURE_GROUP["SageMaker"]["FeatureGroup"]["RoleArn"] + expected_online_store_kms_key_id = SAGEMAKER_CONFIG_FEATURE_GROUP["SageMaker"]["FeatureGroup"][ + "OnlineStoreConfig" + ]["SecurityConfig"]["KmsKeyId"] sagemaker_config_session.create_feature_group.assert_called_with( feature_group_name="MyFeatureGroup", record_identifier_name="feature1", event_time_feature_name="feature2", feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn="arn:aws:iam::111111111111:role/ConfigRole", + role_arn=expected_role_arn, description=None, tags=None, online_store_config={ "EnableOnlineStore": True, - "SecurityConfig": {"KmsKeyId": "OnlineConfigKmsKeyId"}, + "SecurityConfig": {"KmsKeyId": expected_online_store_kms_key_id}, }, offline_store_config={ "DisableGlueTableCreation": False, - "S3StorageConfig": {"S3Uri": s3_uri, "KmsKeyId": "OfflineConfigKmsKeyId"}, + "S3StorageConfig": {"S3Uri": s3_uri, "KmsKeyId": expected_offline_store_kms_key_id}, }, ) diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index ceb2b5e997..d837112937 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ 
b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -915,6 +915,9 @@ def test_data_quality_batch_transform_monitor_create_schedule_with_sagemaker_con monitor_schedule_name=SCHEDULE_NAME, schedule_cron_expression=CRON_HOURLY, ) + expected_tags_from_config = SAGEMAKER_CONFIG_MONITORING_SCHEDULE["SageMaker"][ + "MonitoringSchedule" + ]["Tags"][0] sagemaker_config_session.sagemaker_client.create_monitoring_schedule.assert_called_with( MonitoringScheduleName=SCHEDULE_NAME, @@ -926,7 +929,7 @@ def test_data_quality_batch_transform_monitor_create_schedule_with_sagemaker_con # new tags appended from config Tags=[ {"Key": "tag_key_1", "Value": "tag_value_1"}, - {"Key": "some-tag", "Value": "value-for-tag"}, + expected_tags_from_config, ], ) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 3e4b301393..44aeea0d64 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -388,13 +388,35 @@ def test_framework_initialization_with_sagemaker_config_injection(sagemaker_conf InstanceGroup("group2", "ml.m4.xlarge", 2), ], ) - assert framework.role == "arn:aws:iam::111111111111:role/ConfigRole" - assert framework.enable_network_isolation() is True - assert framework.encrypt_inter_container_traffic is True - assert framework.output_kms_key == "TestKms" - assert framework.volume_kms_key == "volumekey" - assert framework.security_group_ids == ["sg-123"] - assert framework.subnets == ["subnets-123"] + expected_volume_kms_key_id = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ + "ResourceConfig" + ]["VolumeKmsKeyId"] + expected_role_arn = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]["RoleArn"] + expected_kms_key_id = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ + "OutputDataConfig" + ]["KmsKeyId"] + expected_subnets = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]["VpcConfig"][ + "Subnets" + ] + expected_security_groups = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ + 
"VpcConfig" + ]["SecurityGroupIds"] + expected_enable_network_isolation = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ + "EnableNetworkIsolation" + ] + expected_enable_inter_container_traffic_encryption = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"][ + "TrainingJob" + ]["EnableInterContainerTrafficEncryption"] + assert framework.role == expected_role_arn + assert framework.enable_network_isolation() == expected_enable_network_isolation + assert ( + framework.encrypt_inter_container_traffic + == expected_enable_inter_container_traffic_encryption + ) + assert framework.output_kms_key == expected_kms_key_id + assert framework.volume_kms_key == expected_volume_kms_key_id + assert framework.security_group_ids == expected_security_groups + assert framework.subnets == expected_subnets def test_estimator_initialization_with_sagemaker_config_injection(sagemaker_config_session): @@ -410,13 +432,35 @@ def test_estimator_initialization_with_sagemaker_config_injection(sagemaker_conf sagemaker_session=sagemaker_config_session, base_job_name="base_job_name", ) - assert estimator.role == "arn:aws:iam::111111111111:role/ConfigRole" - assert estimator.enable_network_isolation() is True - assert estimator.encrypt_inter_container_traffic is True - assert estimator.output_kms_key == "TestKms" - assert estimator.volume_kms_key == "volumekey" - assert estimator.security_group_ids == ["sg-123"] - assert estimator.subnets == ["subnets-123"] + expected_volume_kms_key_id = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ + "ResourceConfig" + ]["VolumeKmsKeyId"] + expected_role_arn = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]["RoleArn"] + expected_kms_key_id = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ + "OutputDataConfig" + ]["KmsKeyId"] + expected_subnets = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]["VpcConfig"][ + "Subnets" + ] + expected_security_groups = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ + 
"VpcConfig" + ]["SecurityGroupIds"] + expected_enable_network_isolation = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ + "EnableNetworkIsolation" + ] + expected_enable_inter_container_traffic_encryption = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"][ + "TrainingJob" + ]["EnableInterContainerTrafficEncryption"] + assert estimator.role == expected_role_arn + assert estimator.enable_network_isolation() == expected_enable_network_isolation + assert ( + estimator.encrypt_inter_container_traffic + == expected_enable_inter_container_traffic_encryption + ) + assert estimator.output_kms_key == expected_kms_key_id + assert estimator.volume_kms_key == expected_volume_kms_key_id + assert estimator.security_group_ids == expected_security_groups + assert estimator.subnets == expected_subnets def test_framework_with_heterogeneous_cluster(sagemaker_session): diff --git a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py index b11926df44..af2b816d61 100644 --- a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -326,12 +326,17 @@ def test_pipeline_model_with_config_injection(tfo, time, sagemaker_config_sessio pipeline_model = PipelineModel( [framework_model, sparkml_model], sagemaker_session=sagemaker_config_session ) - assert pipeline_model.role == "arn:aws:iam::111111111111:role/ConfigRole" - assert pipeline_model.vpc_config == { - "SecurityGroupIds": ["sg-123"], - "Subnets": ["subnets-123"], - } - assert pipeline_model.enable_network_isolation is True + expected_role_arn = SAGEMAKER_CONFIG_MODEL["SageMaker"]["Model"]["ExecutionRoleArn"] + expected_enable_network_isolation = SAGEMAKER_CONFIG_MODEL["SageMaker"]["Model"][ + "EnableNetworkIsolation" + ] + expected_vpc_config = SAGEMAKER_CONFIG_MODEL["SageMaker"]["Model"]["VpcConfig"] + expected_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][ + "KmsKeyId" + ] + assert pipeline_model.role == expected_role_arn + assert pipeline_model.vpc_config == 
expected_vpc_config + assert pipeline_model.enable_network_isolation == expected_enable_network_isolation pipeline_model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) sagemaker_config_session.endpoint_from_production_variants.assert_called_with( name="mi-1-2017-10-10-14-14-15", @@ -345,7 +350,7 @@ def test_pipeline_model_with_config_injection(tfo, time, sagemaker_config_sessio } ], tags=None, - kms_key="ConfigKmsKeyId", + kms_key=expected_kms_key_id, wait=True, data_capture_config_dict=None, ) diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 43dab585db..325edac4b9 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -702,15 +702,31 @@ def test_script_processor_with_sagemaker_config_injection( experiment_config={"ExperimentName": "AnExperiment"}, ) expected_args = copy.deepcopy(_get_expected_args_all_parameters(processor._current_job_name)) - expected_args["resources"]["ClusterConfig"]["VolumeKmsKeyId"] = "testVolumeKmsKeyId" - expected_args["output_config"]["KmsKeyId"] = "testKmsKeyId" - expected_args["role_arn"] = "arn:aws:iam::111111111111:role/ConfigRole" - expected_args["network_config"]["VpcConfig"] = { - "SecurityGroupIds": ["sg-123"], - "Subnets": ["subnets-123"], - } - expected_args["network_config"]["EnableNetworkIsolation"] = True - expected_args["network_config"]["EnableInterContainerTrafficEncryption"] = False + expected_volume_kms_key_id = SAGEMAKER_CONFIG_PROCESSING_JOB["SageMaker"]["ProcessingJob"][ + "ProcessingResources" + ]["ClusterConfig"]["VolumeKmsKeyId"] + expected_output_kms_key_id = SAGEMAKER_CONFIG_PROCESSING_JOB["SageMaker"]["ProcessingJob"][ + "ProcessingOutputConfig" + ]["KmsKeyId"] + expected_role_arn = SAGEMAKER_CONFIG_PROCESSING_JOB["SageMaker"]["ProcessingJob"]["RoleArn"] + expected_vpc_config = SAGEMAKER_CONFIG_PROCESSING_JOB["SageMaker"]["ProcessingJob"][ + "NetworkConfig" + ]["VpcConfig"] + expected_enable_network_isolation = 
SAGEMAKER_CONFIG_PROCESSING_JOB["SageMaker"][ + "ProcessingJob" + ]["NetworkConfig"]["EnableNetworkIsolation"] + expected_enable_inter_containter_traffic_encryption = SAGEMAKER_CONFIG_PROCESSING_JOB[ + "SageMaker" + ]["ProcessingJob"]["NetworkConfig"]["EnableInterContainerTrafficEncryption"] + + expected_args["resources"]["ClusterConfig"]["VolumeKmsKeyId"] = expected_volume_kms_key_id + expected_args["output_config"]["KmsKeyId"] = expected_output_kms_key_id + expected_args["role_arn"] = expected_role_arn + expected_args["network_config"]["VpcConfig"] = expected_vpc_config + expected_args["network_config"]["EnableNetworkIsolation"] = expected_enable_network_isolation + expected_args["network_config"][ + "EnableInterContainerTrafficEncryption" + ] = expected_enable_inter_containter_traffic_encryption sagemaker_config_session.process.assert_called_with(**expected_args) assert "my_job_name" in processor._current_job_name diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index cf6f2c0026..91750907c3 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -331,37 +331,50 @@ def test_create_process_with_sagemaker_config_injection(sagemaker_config_session "experiment_config": {"ExperimentName": "AnExperiment"}, } sagemaker_config_session.process(**process_request_args) + expected_volume_kms_key_id = SAGEMAKER_CONFIG_PROCESSING_JOB["SageMaker"]["ProcessingJob"][ + "ProcessingResources" + ]["ClusterConfig"]["VolumeKmsKeyId"] + expected_output_kms_key_id = SAGEMAKER_CONFIG_PROCESSING_JOB["SageMaker"]["ProcessingJob"][ + "ProcessingOutputConfig" + ]["KmsKeyId"] + expected_tags = SAGEMAKER_CONFIG_PROCESSING_JOB["SageMaker"]["ProcessingJob"]["Tags"] + expected_role_arn = SAGEMAKER_CONFIG_PROCESSING_JOB["SageMaker"]["ProcessingJob"]["RoleArn"] + expected_vpc_config = SAGEMAKER_CONFIG_PROCESSING_JOB["SageMaker"]["ProcessingJob"][ + "NetworkConfig" + ]["VpcConfig"] + expected_enable_network_isolation = 
SAGEMAKER_CONFIG_PROCESSING_JOB["SageMaker"][ + "ProcessingJob" + ]["NetworkConfig"]["EnableNetworkIsolation"] + expected_enable_inter_containter_traffic_encryption = SAGEMAKER_CONFIG_PROCESSING_JOB[ + "SageMaker" + ]["ProcessingJob"]["NetworkConfig"]["EnableInterContainerTrafficEncryption"] expected_request = copy.deepcopy( { "ProcessingJobName": job_name, "ProcessingResources": resource_config, "AppSpecification": app_specification, - "RoleArn": "arn:aws:iam::111111111111:role/ConfigRole", + "RoleArn": expected_role_arn, "ProcessingInputs": processing_inputs, "ProcessingOutputConfig": output_config, "Environment": {"my_env_variable": 20}, "NetworkConfig": { - "VpcConfig": {"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, - "EnableNetworkIsolation": True, - "EnableInterContainerTrafficEncryption": False, + "VpcConfig": expected_vpc_config, + "EnableNetworkIsolation": expected_enable_network_isolation, + "EnableInterContainerTrafficEncryption": expected_enable_inter_containter_traffic_encryption, }, "StoppingCondition": {"MaxRuntimeInSeconds": 3600}, - "Tags": TAGS, + "Tags": expected_tags, "ExperimentConfig": {"ExperimentName": "AnExperiment"}, } ) - expected_request["ProcessingInputs"][0]["DatasetDefinition"] = { - "AthenaDatasetDefinition": {"KmsKeyId": "AthenaKmsKeyId"}, - "RedshiftDatasetDefinition": { - "KmsKeyId": "RedshiftKmsKeyId", - "ClusterRoleArn": "arn:aws:iam::111111111111:role/ClusterRole", - }, - } - expected_request["ProcessingOutputConfig"]["KmsKeyId"] = "testKmsKeyId" + expected_request["ProcessingInputs"][0]["DatasetDefinition"] = SAGEMAKER_CONFIG_PROCESSING_JOB[ + "SageMaker" + ]["ProcessingJob"]["ProcessingInputs"][0]["DatasetDefinition"] + expected_request["ProcessingOutputConfig"]["KmsKeyId"] = expected_output_kms_key_id expected_request["ProcessingResources"]["ClusterConfig"][ "VolumeKmsKeyId" - ] = "testVolumeKmsKeyId" + ] = expected_volume_kms_key_id 
sagemaker_config_session.sagemaker_client.create_processing_job.assert_called_with( **expected_request @@ -1580,16 +1593,32 @@ def test_train_with_sagemaker_config_injection(sagemaker_config_session): _, _, actual_train_args = sagemaker_config_session.sagemaker_client.method_calls[0] - assert actual_train_args["VpcConfig"] == { - "Subnets": ["subnets-123"], - "SecurityGroupIds": ["sg-123"], - } + expected_volume_kms_key_id = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ + "ResourceConfig" + ]["VolumeKmsKeyId"] + expected_role_arn = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]["RoleArn"] + expected_kms_key_id = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ + "OutputDataConfig" + ]["KmsKeyId"] + expected_vpc_config = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]["VpcConfig"] + expected_enable_network_isolation = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ + "EnableNetworkIsolation" + ] + expected_enable_inter_container_traffic_encryption = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"][ + "TrainingJob" + ]["EnableInterContainerTrafficEncryption"] + expected_tags = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]["Tags"] + + assert actual_train_args["VpcConfig"] == expected_vpc_config assert actual_train_args["HyperParameters"] == hyperparameters - assert actual_train_args["Tags"] == TAGS + assert actual_train_args["Tags"] == expected_tags assert actual_train_args["AlgorithmSpecification"]["MetricDefinitions"] == METRIC_DEFINITONS assert actual_train_args["AlgorithmSpecification"]["EnableSageMakerMetricsTimeSeries"] is True - assert actual_train_args["EnableInterContainerTrafficEncryption"] is True - assert actual_train_args["EnableNetworkIsolation"] is True + assert ( + actual_train_args["EnableInterContainerTrafficEncryption"] + == expected_enable_inter_container_traffic_encryption + ) + assert actual_train_args["EnableNetworkIsolation"] == expected_enable_network_isolation assert 
actual_train_args["EnableManagedSpotTraining"] is True assert actual_train_args["CheckpointConfig"]["S3Uri"] == "s3://mybucket/checkpoints/" assert actual_train_args["CheckpointConfig"]["LocalPath"] == "/tmp/checkpoints" @@ -1598,16 +1627,16 @@ def test_train_with_sagemaker_config_injection(sagemaker_config_session): assert ( actual_train_args["AlgorithmSpecification"]["TrainingImageConfig"] == TRAINING_IMAGE_CONFIG ) - assert actual_train_args["RoleArn"] == "arn:aws:iam::111111111111:role/ConfigRole" + assert actual_train_args["RoleArn"] == expected_role_arn assert actual_train_args["ResourceConfig"] == { "InstanceCount": INSTANCE_COUNT, "InstanceType": INSTANCE_TYPE, "VolumeSizeInGB": MAX_SIZE, - "VolumeKmsKeyId": "volumekey", + "VolumeKmsKeyId": expected_volume_kms_key_id, } assert actual_train_args["OutputDataConfig"] == { "S3OutputPath": S3_OUTPUT, - "KmsKeyId": "TestKms", + "KmsKeyId": expected_kms_key_id, } @@ -1713,9 +1742,17 @@ def test_create_transform_job_with_sagemaker_config_injection(sagemaker_config_s "Tags": TAGS, } ) - expected_args["DataCaptureConfig"]["KmsKeyId"] = "jobKmsKeyId" - expected_args["TransformOutput"]["KmsKeyId"] = "outputKmsKeyId" - expected_args["TransformResources"]["VolumeKmsKeyId"] = "volumeKmsKeyId" + # The following parameters should be fetched from config + expected_tags = SAGEMAKER_CONFIG_TRANSFORM_JOB["SageMaker"]["TransformJob"]["Tags"] + expected_args["DataCaptureConfig"]["KmsKeyId"] = SAGEMAKER_CONFIG_TRANSFORM_JOB["SageMaker"][ + "TransformJob" + ]["DataCaptureConfig"]["KmsKeyId"] + expected_args["TransformOutput"]["KmsKeyId"] = SAGEMAKER_CONFIG_TRANSFORM_JOB["SageMaker"][ + "TransformJob" + ]["TransformOutput"]["KmsKeyId"] + expected_args["TransformResources"]["VolumeKmsKeyId"] = SAGEMAKER_CONFIG_TRANSFORM_JOB[ + "SageMaker" + ]["TransformJob"]["TransformResources"]["VolumeKmsKeyId"] # make sure the original dicts were not modified before config injection assert "KmsKeyId" not in in_config @@ -1735,7 +1772,7 @@ def 
test_create_transform_job_with_sagemaker_config_injection(sagemaker_config_s resource_config=resource_config, experiment_config=None, model_client_config=None, - tags=TAGS, + tags=expected_tags, data_processing=data_processing, batch_data_capture_config=data_capture_config, ) @@ -2087,14 +2124,20 @@ def test_create_model_with_sagemaker_config_injection(sagemaker_config_session): MODEL_NAME, container_defs=PRIMARY_CONTAINER, ) + expected_role_arn = SAGEMAKER_CONFIG_MODEL["SageMaker"]["Model"]["ExecutionRoleArn"] + expected_enable_network_isolation = SAGEMAKER_CONFIG_MODEL["SageMaker"]["Model"][ + "EnableNetworkIsolation" + ] + expected_vpc_config = SAGEMAKER_CONFIG_MODEL["SageMaker"]["Model"]["VpcConfig"] + expected_tags = SAGEMAKER_CONFIG_MODEL["SageMaker"]["Model"]["Tags"] assert model == MODEL_NAME sagemaker_config_session.sagemaker_client.create_model.assert_called_with( - ExecutionRoleArn="arn:aws:iam::111111111111:role/ConfigRole", + ExecutionRoleArn=expected_role_arn, ModelName=MODEL_NAME, PrimaryContainer=PRIMARY_CONTAINER, - VpcConfig={"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, - EnableNetworkIsolation=True, - Tags=TAGS, + VpcConfig=expected_vpc_config, + EnableNetworkIsolation=expected_enable_network_isolation, + Tags=expected_tags, ) @@ -2274,17 +2317,24 @@ def test_create_edge_packaging_with_sagemaker_config_injection(sagemaker_config_ sagemaker_config_session.package_model_for_edge( output_config, ) + expected_role_arn = SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB["SageMaker"]["EdgePackagingJob"][ + "RoleArn" + ] + expected_kms_key_id = SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB["SageMaker"]["EdgePackagingJob"][ + "OutputConfig" + ]["KmsKeyId"] + expected_tags = SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB["SageMaker"]["EdgePackagingJob"]["Tags"] sagemaker_config_session.sagemaker_client.create_edge_packaging_job.assert_called_with( - RoleArn="arn:aws:iam::111111111111:role/ConfigRole", # provided from config + RoleArn=expected_role_arn, # provided from 
config OutputConfig={ "S3OutputLocation": S3_OUTPUT, # provided as param - "KmsKeyId": "configKmsKeyId", # fetched from config + "KmsKeyId": expected_kms_key_id, # fetched from config }, ModelName=None, ModelVersion=None, EdgePackagingJobName=None, CompilationJobName=None, - Tags=TAGS, + Tags=expected_tags, ) @@ -2306,6 +2356,32 @@ def test_create_monitoring_schedule_with_sagemaker_config_injection(sagemaker_co image_uri="someimageuri", network_config={"VpcConfig": {"SecurityGroupIds": ["sg-asparam"]}}, ) + expected_volume_kms_key_id = SAGEMAKER_CONFIG_MONITORING_SCHEDULE["SageMaker"][ + "MonitoringSchedule" + ]["MonitoringScheduleConfig"]["MonitoringJobDefinition"]["MonitoringResources"][ + "ClusterConfig" + ][ + "VolumeKmsKeyId" + ] + expected_role_arn = SAGEMAKER_CONFIG_MONITORING_SCHEDULE["SageMaker"]["MonitoringSchedule"][ + "MonitoringScheduleConfig" + ]["MonitoringJobDefinition"]["RoleArn"] + expected_kms_key_id = SAGEMAKER_CONFIG_MONITORING_SCHEDULE["SageMaker"]["MonitoringSchedule"][ + "MonitoringScheduleConfig" + ]["MonitoringJobDefinition"]["MonitoringOutputConfig"]["KmsKeyId"] + expected_subnets = SAGEMAKER_CONFIG_MONITORING_SCHEDULE["SageMaker"]["MonitoringSchedule"][ + "MonitoringScheduleConfig" + ]["MonitoringJobDefinition"]["NetworkConfig"]["VpcConfig"]["Subnets"] + expected_enable_network_isolation = SAGEMAKER_CONFIG_MONITORING_SCHEDULE["SageMaker"][ + "MonitoringSchedule" + ]["MonitoringScheduleConfig"]["MonitoringJobDefinition"]["NetworkConfig"][ + "EnableNetworkIsolation" + ] + expected_enable_inter_container_traffic_encryption = SAGEMAKER_CONFIG_MONITORING_SCHEDULE[ + "SageMaker" + ]["MonitoringSchedule"]["MonitoringScheduleConfig"]["MonitoringJobDefinition"]["NetworkConfig"][ + "EnableInterContainerTrafficEncryption" + ] sagemaker_config_session.sagemaker_client.create_monitoring_schedule.assert_called_with( MonitoringScheduleName=JOB_NAME, MonitoringScheduleConfig={ @@ -2316,24 +2392,25 @@ def 
test_create_monitoring_schedule_with_sagemaker_config_injection(sagemaker_co "InstanceCount": 1, # provided as param "InstanceType": "ml.m4.xlarge", # provided as param "VolumeSizeInGB": 4, # provided as param - "VolumeKmsKeyId": "configVolumeKmsKeyId", # Fetched from config + "VolumeKmsKeyId": expected_volume_kms_key_id, # Fetched from config } }, "MonitoringAppSpecification": {"ImageUri": "someimageuri"}, # provided as param - "RoleArn": "arn:aws:iam::111111111111:role/ConfigRole", # Fetched from config + "RoleArn": expected_role_arn, # Fetched from config "MonitoringOutputConfig": { "MonitoringOutputs": [ # provided as param {"S3Output": {"S3Uri": "s3://sagemaker-123/output/jobname"}} ], - "KmsKeyId": "configKmsKeyId", # fetched from config + "KmsKeyId": expected_kms_key_id, # fetched from config }, "NetworkConfig": { "VpcConfig": { - "Subnets": ["subnets-123"], # fetched from config + "Subnets": expected_subnets, # fetched from config "SecurityGroupIds": ["sg-asparam"], # provided as param }, - "EnableNetworkIsolation": True, # fetched from config - "EnableInterContainerTrafficEncryption": False, # fetched from config + # The following are fetched from config + "EnableNetworkIsolation": expected_enable_network_isolation, + "EnableInterContainerTrafficEncryption": expected_enable_inter_container_traffic_encryption, }, }, }, @@ -2349,14 +2426,22 @@ def test_compile_with_sagemaker_config_injection(sagemaker_config_session): output_model_config={"S3OutputLocation": "s3://test"}, job_name="TestJob", ) + expected_role_arn = SAGEMAKER_CONFIG_COMPILATION_JOB["SageMaker"]["CompilationJob"]["RoleArn"] + expected_kms_key_id = SAGEMAKER_CONFIG_COMPILATION_JOB["SageMaker"]["CompilationJob"][ + "OutputConfig" + ]["KmsKeyId"] + expected_vpc_config = SAGEMAKER_CONFIG_COMPILATION_JOB["SageMaker"]["CompilationJob"][ + "VpcConfig" + ] + expected_tags = SAGEMAKER_CONFIG_COMPILATION_JOB["SageMaker"]["CompilationJob"]["Tags"] 
sagemaker_config_session.sagemaker_client.create_compilation_job.assert_called_with( InputConfig={}, - OutputConfig={"S3OutputLocation": "s3://test", "KmsKeyId": "TestKms"}, - RoleArn="arn:aws:iam::111111111111:role/ConfigRole", + OutputConfig={"S3OutputLocation": "s3://test", "KmsKeyId": expected_kms_key_id}, + RoleArn=expected_role_arn, StoppingCondition=None, CompilationJobName="TestJob", - VpcConfig={"Subnets": ["subnets-123"], "SecurityGroupIds": ["sg-123"]}, - Tags=TAGS, + VpcConfig=expected_vpc_config, + Tags=expected_tags, ) @@ -2435,12 +2520,22 @@ def test_create_endpoint_config_with_sagemaker_config_injection(sagemaker_config "local", data_capture_config_dict=data_capture_config_dict, ) + expected_production_variant_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ + "EndpointConfig" + ]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"] + expected_data_capture_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ + "EndpointConfig" + ]["DataCaptureConfig"]["KmsKeyId"] + expected_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][ + "KmsKeyId" + ] + expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"] sagemaker_config_session.sagemaker_client.create_endpoint_config.assert_called_with( EndpointConfigName="endpoint-test", ProductionVariants=[ { - "CoreDumpConfig": {"KmsKeyId": "testCoreKmsKeyId"}, + "CoreDumpConfig": {"KmsKeyId": expected_production_variant_kms_key_id}, "ModelName": "simple-model", "VariantName": "AllTraffic", "InitialVariantWeight": 1, @@ -2448,9 +2543,12 @@ def test_create_endpoint_config_with_sagemaker_config_injection(sagemaker_config "InstanceType": "local", } ], - DataCaptureConfig={"DestinationS3Uri": "s3://test", "KmsKeyId": "testDataCaptureKmsKeyId"}, - KmsKeyId="ConfigKmsKeyId", - Tags=TAGS, + DataCaptureConfig={ + "DestinationS3Uri": "s3://test", + "KmsKeyId": expected_data_capture_kms_key_id, + }, + KmsKeyId=expected_kms_key_id, + 
Tags=expected_tags, ) @@ -2478,22 +2576,36 @@ def test_create_endpoint_config_from_existing_with_sagemaker_config_injection( existing_endpoint_name, new_endpoint_name, new_production_variants=pvs ) + expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ + "EndpointConfig" + ]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"] + expected_production_variant_1_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ + "EndpointConfig" + ]["ProductionVariants"][1]["CoreDumpConfig"]["KmsKeyId"] + expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][ + "AsyncInferenceConfig" + ]["OutputConfig"]["KmsKeyId"] + expected_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][ + "KmsKeyId" + ] + expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"] + sagemaker_config_session.sagemaker_client.create_endpoint_config.assert_called_with( EndpointConfigName=new_endpoint_name, ProductionVariants=[ { - "CoreDumpConfig": {"KmsKeyId": "testCoreKmsKeyId"}, + "CoreDumpConfig": {"KmsKeyId": expected_production_variant_0_kms_key_id}, **sagemaker.production_variant("A", "ml.p2.xlarge"), }, { - "CoreDumpConfig": {"KmsKeyId": "testCoreKmsKeyId2"}, + "CoreDumpConfig": {"KmsKeyId": expected_production_variant_1_kms_key_id}, **sagemaker.production_variant("B", "ml.p2.xlarge"), }, sagemaker.production_variant("C", "ml.p2.xlarge"), ], - KmsKeyId="ConfigKmsKeyId", # from config - Tags=TAGS, # from config - AsyncInferenceConfig={"OutputConfig": {"KmsKeyId": "testOutputKmsKeyId"}}, # from config + KmsKeyId=expected_kms_key_id, # from config + Tags=expected_tags, # from config + AsyncInferenceConfig={"OutputConfig": {"KmsKeyId": expected_inference_kms_key_id}}, ) @@ -2516,21 +2628,31 @@ def test_endpoint_from_production_variants_with_sagemaker_config_injection( data_capture_config_dict={}, async_inference_config_dict=AsyncInferenceConfig()._to_request_dict(), ) 
+ expected_data_capture_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ + "EndpointConfig" + ]["DataCaptureConfig"]["KmsKeyId"] + expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][ + "AsyncInferenceConfig" + ]["OutputConfig"]["KmsKeyId"] + expected_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][ + "KmsKeyId" + ] + expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"] expected_async_inference_config_dict = AsyncInferenceConfig()._to_request_dict() - expected_async_inference_config_dict["OutputConfig"]["KmsKeyId"] = "testOutputKmsKeyId" + expected_async_inference_config_dict["OutputConfig"]["KmsKeyId"] = expected_inference_kms_key_id sagemaker_config_session.sagemaker_client.create_endpoint_config.assert_called_with( EndpointConfigName="some-endpoint", ProductionVariants=pvs, - Tags=TAGS, # from config - KmsKeyId="ConfigKmsKeyId", # from config + Tags=expected_tags, # from config + KmsKeyId=expected_kms_key_id, # from config AsyncInferenceConfig=expected_async_inference_config_dict, - DataCaptureConfig={"KmsKeyId": "testDataCaptureKmsKeyId"}, + DataCaptureConfig={"KmsKeyId": expected_data_capture_kms_key_id}, ) sagemaker_config_session.sagemaker_client.create_endpoint.assert_called_with( EndpointConfigName="some-endpoint", EndpointName="some-endpoint", - Tags=TAGS, # from config + Tags=expected_tags, # from config ) @@ -3052,16 +3174,29 @@ def test_create_auto_ml_with_sagemaker_config_injection(sagemaker_config_session input_config, output_config, auto_ml_job_config, job_name=job_name ) expected_call_args = copy.deepcopy(DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS) - expected_call_args["OutputDataConfig"]["KmsKeyId"] = "configKmsKeyId" - expected_call_args["RoleArn"] = "arn:aws:iam::111111111111:role/ConfigRole" + expected_volume_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["AutoMLJobConfig"][ + "SecurityConfig" + ]["VolumeKmsKeyId"] + 
expected_role_arn = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["RoleArn"] + expected_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["OutputDataConfig"][ + "KmsKeyId" + ] + expected_vpc_config = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["AutoMLJobConfig"][ + "SecurityConfig" + ]["VpcConfig"] + expected_tags = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["Tags"] + expected_enable_inter_container_traffic_encryption = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"][ + "AutoML" + ]["AutoMLJobConfig"]["SecurityConfig"]["EnableInterContainerTrafficEncryption"] + expected_call_args["OutputDataConfig"]["KmsKeyId"] = expected_kms_key_id + expected_call_args["RoleArn"] = expected_role_arn expected_call_args["AutoMLJobConfig"]["SecurityConfig"] = { - "EnableInterContainerTrafficEncryption": True - } - expected_call_args["AutoMLJobConfig"]["SecurityConfig"]["VpcConfig"] = { - "Subnets": ["subnets-123"], - "SecurityGroupIds": ["sg-123"], + "EnableInterContainerTrafficEncryption": expected_enable_inter_container_traffic_encryption } - expected_call_args["AutoMLJobConfig"]["SecurityConfig"]["VolumeKmsKeyId"] = "TestKmsKeyId" + expected_call_args["AutoMLJobConfig"]["SecurityConfig"]["VpcConfig"] = expected_vpc_config + expected_call_args["AutoMLJobConfig"]["SecurityConfig"][ + "VolumeKmsKeyId" + ] = expected_volume_kms_key_id sagemaker_config_session.sagemaker_client.create_auto_ml_job.assert_called_with( AutoMLJobName=expected_call_args["AutoMLJobName"], InputDataConfig=expected_call_args["InputDataConfig"], @@ -3069,7 +3204,7 @@ def test_create_auto_ml_with_sagemaker_config_injection(sagemaker_config_session AutoMLJobConfig=expected_call_args["AutoMLJobConfig"], RoleArn=expected_call_args["RoleArn"], GenerateCandidateDefinitionsOnly=False, - Tags=TAGS, + Tags=expected_tags, ) @@ -3330,6 +3465,15 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_config_s task=task, validation_specification=validation_specification, ) + expected_kms_key_id 
= SAGEMAKER_CONFIG_MODEL_PACKAGE["SageMaker"]["ModelPackage"][ + "ValidationSpecification" + ]["ValidationProfiles"][0]["TransformJobDefinition"]["TransformOutput"]["KmsKeyId"] + expected_role_arn = SAGEMAKER_CONFIG_MODEL_PACKAGE["SageMaker"]["ModelPackage"][ + "ValidationSpecification" + ]["ValidationRole"] + expected_volume_kms_key_id = SAGEMAKER_CONFIG_MODEL_PACKAGE["SageMaker"]["ModelPackage"][ + "ValidationSpecification" + ]["ValidationProfiles"][0]["TransformJobDefinition"]["TransformResources"]["VolumeKmsKeyId"] expected_args = copy.deepcopy( { "ModelPackageName": model_package_name, @@ -3353,15 +3497,13 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_config_s "ValidationSpecification": validation_specification, } ) - expected_args["ValidationSpecification"][ - "ValidationRole" - ] = "arn:aws:iam::111111111111:role/ConfigRole" + expected_args["ValidationSpecification"]["ValidationRole"] = expected_role_arn expected_args["ValidationSpecification"]["ValidationProfiles"][0]["TransformJobDefinition"][ "TransformResources" - ] = {"VolumeKmsKeyId": "testVolumeKmsKeyId"} + ]["VolumeKmsKeyId"] = expected_volume_kms_key_id expected_args["ValidationSpecification"]["ValidationProfiles"][0]["TransformJobDefinition"][ "TransformOutput" - ]["KmsKeyId"] = "testKmsKeyId" + ]["KmsKeyId"] = expected_kms_key_id sagemaker_config_session.sagemaker_client.create_model_package.assert_called_with( **expected_args @@ -3589,15 +3731,31 @@ def test_feature_group_create_with_sagemaker_config_injection( feature_definitions=feature_group_dummy_definitions, offline_store_config={"S3StorageConfig": {"S3Uri": "s3://test"}}, ) - assert sagemaker_config_session.sagemaker_client.create_feature_group.called_with( - FeatureGroupName="MyFeatureGroup", - RecordIdentifierFeatureName="feature1", - EventTimeFeatureName="feature2", - FeatureDefinitions=feature_group_dummy_definitions, - RoleArn="config_role", - OnlineStoreConfig={"SecurityConfig": {"KmsKeyId": "testKmsId2"}, 
"EnableOnlineStore": True}, - OfflineStoreConfig={"S3StorageConfig": {"KmsKeyId": "testKmsId", "S3Uri": "s3://test"}}, - Tags=TAGS, + expected_offline_store_kms_key_id = SAGEMAKER_CONFIG_FEATURE_GROUP["SageMaker"]["FeatureGroup"][ + "OfflineStoreConfig" + ]["S3StorageConfig"]["KmsKeyId"] + expected_role_arn = SAGEMAKER_CONFIG_FEATURE_GROUP["SageMaker"]["FeatureGroup"]["RoleArn"] + expected_online_store_kms_key_id = SAGEMAKER_CONFIG_FEATURE_GROUP["SageMaker"]["FeatureGroup"][ + "OnlineStoreConfig" + ]["SecurityConfig"]["KmsKeyId"] + expected_tags = SAGEMAKER_CONFIG_FEATURE_GROUP["SageMaker"]["FeatureGroup"]["Tags"] + expected_request = { + "FeatureGroupName": "MyFeatureGroup", + "RecordIdentifierFeatureName": "feature1", + "EventTimeFeatureName": "feature2", + "FeatureDefinitions": feature_group_dummy_definitions, + "RoleArn": expected_role_arn, + "OnlineStoreConfig": { + "SecurityConfig": {"KmsKeyId": expected_online_store_kms_key_id}, + "EnableOnlineStore": True, + }, + "OfflineStoreConfig": { + "S3StorageConfig": {"KmsKeyId": expected_offline_store_kms_key_id, "S3Uri": "s3://test"} + }, + "Tags": expected_tags, + } + sagemaker_config_session.sagemaker_client.create_feature_group.assert_called_with( + **expected_request ) diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index da24cbdd14..b6bd867948 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -121,8 +121,19 @@ def test_transform_with_sagemaker_config_injection(start_new_job, sagemaker_conf output_path=OUTPUT_PATH, sagemaker_session=sagemaker_config_session, ) - assert transformer.volume_kms_key == "volumeKmsKeyId" - assert transformer.output_kms_key == "outputKmsKeyId" + # volume kms key and output kms key are inserted from the config + assert ( + transformer.volume_kms_key + == SAGEMAKER_CONFIG_TRANSFORM_JOB["SageMaker"]["TransformJob"]["TransformResources"][ + "VolumeKmsKeyId" + ] + ) + assert ( + transformer.output_kms_key + == 
SAGEMAKER_CONFIG_TRANSFORM_JOB["SageMaker"]["TransformJob"]["TransformOutput"][ + "KmsKeyId" + ] + ) content_type = "text/csv" compression = "Gzip" @@ -171,7 +182,12 @@ def test_transform_with_sagemaker_config_injection(start_new_job, sagemaker_conf batch_data_capture_config, ) # KmsKeyId in BatchDataCapture will be inserted from the config - assert batch_data_capture_config.kms_key_id == "jobKmsKeyId" + assert ( + batch_data_capture_config.kms_key_id + == SAGEMAKER_CONFIG_TRANSFORM_JOB["SageMaker"]["TransformJob"]["DataCaptureConfig"][ + "KmsKeyId" + ] + ) def test_delete_model(sagemaker_session): From e62a37889739f80eb894a04596d71a0e653d791b Mon Sep 17 00:00:00 2001 From: Ruban Hussain Date: Tue, 14 Mar 2023 16:20:29 -0700 Subject: [PATCH 16/40] fix: inject from config into existing ProductionVariants inside create_endpoint_config_from_existing --- src/sagemaker/session.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 03e1f88825..739ed6fe6e 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4194,31 +4194,30 @@ def create_endpoint_config_from_existing( "EndpointConfigName": new_config_name, } - # TODO: should this merge from config even if new_production_variants is None? 
- if new_production_variants: + production_variants = ( + new_production_variants or existing_endpoint_config_desc["ProductionVariants"] + ) + if production_variants: inferred_production_variants_from_config = ( self.get_sagemaker_config_override(ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH) or [] ) for i in range( - min(len(new_production_variants), len(inferred_production_variants_from_config)) + min(len(production_variants), len(inferred_production_variants_from_config)) ): original_config_dict_value = inferred_production_variants_from_config[i].copy() - merge_dicts(inferred_production_variants_from_config[i], new_production_variants[i]) - if new_production_variants[i] != inferred_production_variants_from_config[i]: + merge_dicts(inferred_production_variants_from_config[i], production_variants[i]) + if production_variants[i] != inferred_production_variants_from_config[i]: print( "Config value {} at config path {} was fetched first for " "index: 0.".format( original_config_dict_value, ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH ), "It was then merged with the existing value {} to give {}".format( - new_production_variants[i], inferred_production_variants_from_config[i] + production_variants[i], inferred_production_variants_from_config[i] ), ) - new_production_variants[i].update(inferred_production_variants_from_config[i]) - - request["ProductionVariants"] = ( - new_production_variants or existing_endpoint_config_desc["ProductionVariants"] - ) + production_variants[i].update(inferred_production_variants_from_config[i]) + request["ProductionVariants"] = production_variants request_tags = new_tags or self.list_tags( existing_endpoint_config_desc["EndpointConfigArn"] From 2baeab0092fabc299ea5d51b0f6ba58bd03ee04e Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Wed, 15 Mar 2023 10:37:42 -0700 Subject: [PATCH 17/40] change: added unit test for verifying yaml safe_load method --- tests/data/config/config_file_with_code.yaml | 3 +++ tests/unit/sagemaker/config/test_config.py 
| 12 ++++++++++++ 2 files changed, 15 insertions(+) create mode 100644 tests/data/config/config_file_with_code.yaml diff --git a/tests/data/config/config_file_with_code.yaml b/tests/data/config/config_file_with_code.yaml new file mode 100644 index 0000000000..a20971ad38 --- /dev/null +++ b/tests/data/config/config_file_with_code.yaml @@ -0,0 +1,3 @@ +SchemaVersion: '1.0' +CustomParameters: + TestCode: !!python/object/apply:eval ["[x ** 2 for x in [1, 2, 3]]"] diff --git a/tests/unit/sagemaker/config/test_config.py b/tests/unit/sagemaker/config/test_config.py index ac3719479b..453d008a19 100644 --- a/tests/unit/sagemaker/config/test_config.py +++ b/tests/unit/sagemaker/config/test_config.py @@ -19,6 +19,7 @@ from sagemaker.config.config import SageMakerConfig from jsonschema import exceptions +from yaml.constructor import ConstructorError @pytest.fixture() @@ -47,6 +48,17 @@ def test_config_when_overriden_default_config_file_is_not_found(get_data_dir): del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] +def test_invalid_config_file_which_has_python_code(get_data_dir): + invalid_config_file_path = os.path.join(get_data_dir, "config_file_with_code.yaml") + # no exceptions will be thrown with yaml.unsafe_load + yaml.unsafe_load(open(invalid_config_file_path, "r")) + # PyYAML will throw exceptions for yaml.safe_load. 
SageMaker Config is using + # yaml.safe_load internally + with pytest.raises(ConstructorError) as exception_info: + SageMakerConfig(additional_config_paths=[invalid_config_file_path]) + assert "python/object/apply:eval" in str(exception_info.value) + + def test_config_when_additional_config_file_path_is_not_found(get_data_dir): fake_config_file_path = os.path.join(get_data_dir, "config-not-found.yaml") with pytest.raises(ValueError): From 43492a65136a729bdf80246e6d28ca27347c5d2f Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Wed, 15 Mar 2023 13:28:26 -0700 Subject: [PATCH 18/40] change: addressed PR comments for SageMaker Config --- src/sagemaker/config/config.py | 110 ++++++++++++------ .../expected_output_config_after_merge.yaml | 2 +- .../sample_additional_config_for_merge.yaml | 2 +- .../data/config/sample_config_for_merge.yaml | 2 +- tests/unit/sagemaker/config/test_config.py | 28 ++--- 5 files changed, 90 insertions(+), 54 deletions(-) diff --git a/src/sagemaker/config/config.py b/src/sagemaker/config/config.py index 7451de158d..81559a67d4 100644 --- a/src/sagemaker/config/config.py +++ b/src/sagemaker/config/config.py @@ -33,15 +33,22 @@ logger = logging.getLogger("sagemaker") _APP_NAME = "sagemaker" +# The default config file location of the Administrator provided Config file. This path can be +# overridden with `SAGEMAKER_ADMIN_CONFIG_OVERRIDE` environment variable. _DEFAULT_ADMIN_CONFIG_FILE_PATH = os.path.join(site_config_dir(_APP_NAME), "config.yaml") +# The default config file location of the user provided Config file. This path can be +# overridden with `SAGEMAKER_USER_CONFIG_OVERRIDE` environment variable. 
_DEFAULT_USER_CONFIG_FILE_PATH = os.path.join(user_config_dir(_APP_NAME), "config.yaml") -ENV_VARIABLE_DEFAULT_CONFIG_OVERRIDE = "SAGEMAKER_DEFAULT_CONFIG_OVERRIDE" +ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE = "SAGEMAKER_ADMIN_CONFIG_OVERRIDE" ENV_VARIABLE_USER_CONFIG_OVERRIDE = "SAGEMAKER_USER_CONFIG_OVERRIDE" -_config_paths = [_DEFAULT_ADMIN_CONFIG_FILE_PATH, _DEFAULT_USER_CONFIG_FILE_PATH] _BOTO_SESSION = boto3.DEFAULT_SESSION or boto3.Session() +# The default Boto3 S3 Resource. This is constructed from the default Boto3 session. This will be +# used to fetch SageMakerConfig from S3. Users can override this by passing their own S3 Resource +# as the constructor parameter for SageMakerConfig. _DEFAULT_S3_RESOURCE = _BOTO_SESSION.resource("s3") +S3_PREFIX = "s3://" class SageMakerConfig(object): @@ -68,7 +75,7 @@ def __init__(self, additional_config_paths: List[str] = None, s3_resource=_DEFAU _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH. Users can override the _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH - by using environment variables - SAGEMAKER_DEFAULT_CONFIG_OVERRIDE and + by using environment variables - SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE Additional Configuration file paths can also be provided as a constructor parameter. 
@@ -105,7 +112,7 @@ def __init__(self, additional_config_paths: List[str] = None, s3_resource=_DEFAU """ default_config_path = os.getenv( - ENV_VARIABLE_DEFAULT_CONFIG_OVERRIDE, _DEFAULT_ADMIN_CONFIG_FILE_PATH + ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE, _DEFAULT_ADMIN_CONFIG_FILE_PATH ) user_config_path = os.getenv( ENV_VARIABLE_USER_CONFIG_OVERRIDE, _DEFAULT_USER_CONFIG_FILE_PATH @@ -113,17 +120,8 @@ def __init__(self, additional_config_paths: List[str] = None, s3_resource=_DEFAU self._config_paths = [default_config_path, user_config_path] if additional_config_paths: self._config_paths += additional_config_paths - self._s3_resource = s3_resource - config = {} - for file_path in self._config_paths: - if file_path.startswith("s3://"): - config_from_file = _load_config_from_s3(file_path, self._s3_resource) - else: - config_from_file = _load_config_from_file(file_path) - if config_from_file: - validate(config_from_file, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) - merge_dicts(config, config_from_file) - self._config = config + self._config_paths = list(filter(lambda item: item is not None, self._config_paths)) + self._config = _load_config_files(self._config_paths, s3_resource) @property def config_paths(self) -> List[str]: @@ -145,6 +143,53 @@ def config(self) -> dict: return self._config +def _load_config_files(file_paths: List[str], s3_resource_for_config) -> dict: + """This method loads all the config files from the paths that were provided as Inputs. + + Note: Supported Config file locations are Local File System and S3. + + This method will throw exceptions for the following cases: + * Schema validation fails for one/more config files. + * When the config file is not a proper YAML file. + * Any S3 related issues that arises while fetching config file from S3. This includes + permission issues, S3 Object is not found in the specified S3 URI. + * File doesn't exist in a path that was specified by the user as part of environment + variable/ additional_config_paths. 
This doesn't include + _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH + + Args: + file_paths(List[str]): The list of paths corresponding to the config file. Note: This + path can either be a Local File System path or it can be a S3 URI. + s3_resource_for_config: Corresponds to boto3 S3 resource. This will be used to fetch Config + files from S3. See :py:meth:`boto3.session.Session.resource`. + + Returns: + dict: A dictionary representing the configurations that were loaded from the config files. + + """ + merged_config = {} + for file_path in file_paths: + config_from_file = {} + if file_path.startswith(S3_PREFIX): + config_from_file = _load_config_from_s3(file_path, s3_resource_for_config) + else: + try: + config_from_file = _load_config_from_file(file_path) + except ValueError: + if file_path not in ( + _DEFAULT_ADMIN_CONFIG_FILE_PATH, + _DEFAULT_USER_CONFIG_FILE_PATH, + ): + # Throw exception only when User provided file path is invalid. + # If there are no files in the Default config file locations, don't throw + # Exceptions. + raise + if config_from_file: + validate(config_from_file, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + merge_dicts(merged_config, config_from_file) + return merged_config + + def _load_config_from_file(file_path: str) -> dict: """This method loads the config file from the path that was specified as parameter. @@ -159,28 +204,19 @@ def _load_config_from_file(file_path: str) -> dict: This method will throw Exceptions for the following cases: * When the config file is not a proper YAML file. - * File doesn't exist in a path that was specified by the consumer. This doesn't include - _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH + * File doesn't exist in a path that was specified by the consumer. 
""" - config = {} - if file_path: - inferred_file_path = file_path - if os.path.isdir(file_path): - inferred_file_path = os.path.join(file_path, "config.yaml") - if not os.path.exists(inferred_file_path): - if inferred_file_path not in ( - _DEFAULT_ADMIN_CONFIG_FILE_PATH, - _DEFAULT_USER_CONFIG_FILE_PATH, - ): - # Customer provided file path is invalid. - raise ValueError( - f"Unable to load config file from the location: {file_path} Please" - f" provide a valid file path" - ) - else: - logger.debug("Fetching configuration file from the path: %s", file_path) - config = yaml.safe_load(open(inferred_file_path, "r")) - return config + inferred_file_path = file_path + if os.path.isdir(file_path): + inferred_file_path = os.path.join(file_path, "config.yaml") + if not os.path.exists(inferred_file_path): + raise ValueError( + f"Unable to load config file from the location: {file_path} Please" + f" provide a valid file path" + ) + else: + logger.debug("Fetching configuration file from the path: %s", file_path) + return yaml.safe_load(open(inferred_file_path, "r")) def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict: @@ -247,7 +283,7 @@ def _get_inferred_s3_uri(s3_uri, s3_resource_for_config): s3_bucket = s3_resource_for_config.Bucket(name=bucket) s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all() s3_files_with_same_prefix = [ - "s3://{}/{}".format(bucket, s3_object.key) for s3_object in s3_objects + "{}{}/{}".format(S3_PREFIX, bucket, s3_object.key) for s3_object in s3_objects ] except Exception as e: # pylint: disable=W0703 # if customers didn't provide us with a valid S3 File/insufficient read permission, diff --git a/tests/data/config/expected_output_config_after_merge.yaml b/tests/data/config/expected_output_config_after_merge.yaml index 1cd7d815c6..b3556b4bcf 100644 --- a/tests/data/config/expected_output_config_after_merge.yaml +++ b/tests/data/config/expected_output_config_after_merge.yaml @@ -11,4 +11,4 @@ SageMaker: # Present only in the 
default config KmsKeyId: 'somekmskeyid' # Present only in the additional config - RoleArn: 'arn:aws:iam::377777777777:role/additionalConfigRole' \ No newline at end of file + RoleArn: 'arn:aws:iam::377777777777:role/additionalConfigRole' diff --git a/tests/data/config/sample_additional_config_for_merge.yaml b/tests/data/config/sample_additional_config_for_merge.yaml index 94aba11176..06ed08aca0 100644 --- a/tests/data/config/sample_additional_config_for_merge.yaml +++ b/tests/data/config/sample_additional_config_for_merge.yaml @@ -4,4 +4,4 @@ SageMaker: OnlineStoreConfig: SecurityConfig: KmsKeyId: 'additionalConfigKmsKeyId' - RoleArn: 'arn:aws:iam::377777777777:role/additionalConfigRole' \ No newline at end of file + RoleArn: 'arn:aws:iam::377777777777:role/additionalConfigRole' diff --git a/tests/data/config/sample_config_for_merge.yaml b/tests/data/config/sample_config_for_merge.yaml index f933f46afe..49d4a5b3ee 100644 --- a/tests/data/config/sample_config_for_merge.yaml +++ b/tests/data/config/sample_config_for_merge.yaml @@ -6,4 +6,4 @@ SageMaker: KmsKeyId: 'someotherkmskeyid' OfflineStoreConfig: S3StorageConfig: - KmsKeyId: 'somekmskeyid' \ No newline at end of file + KmsKeyId: 'somekmskeyid' diff --git a/tests/unit/sagemaker/config/test_config.py b/tests/unit/sagemaker/config/test_config.py index 453d008a19..aaf7ef3b09 100644 --- a/tests/unit/sagemaker/config/test_config.py +++ b/tests/unit/sagemaker/config/test_config.py @@ -42,10 +42,10 @@ def test_config_when_default_config_file_and_user_config_file_is_not_found(): def test_config_when_overriden_default_config_file_is_not_found(get_data_dir): fake_config_file_path = os.path.join(get_data_dir, "config-not-found.yaml") - os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = fake_config_file_path + os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = fake_config_file_path with pytest.raises(ValueError): SageMakerConfig() - del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] + del 
os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] def test_invalid_config_file_which_has_python_code(get_data_dir): @@ -77,10 +77,10 @@ def test_config_factory_when_override_user_config_file_is_not_found(get_data_dir def test_default_config_file_with_invalid_schema(get_data_dir): config_file_path = os.path.join(get_data_dir, "invalid_config_file.yaml") - os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = config_file_path + os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = config_file_path with pytest.raises(exceptions.ValidationError): SageMakerConfig() - del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] + del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] def test_default_config_file_when_directory_is_provided_as_the_path( @@ -89,9 +89,9 @@ def test_default_config_file_when_directory_is_provided_as_the_path( # This will try to load config.yaml file from that directory if present. expected_config = base_config_with_schema expected_config["SageMaker"] = valid_config_with_all_the_scopes - os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = get_data_dir + os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = get_data_dir assert expected_config == SageMakerConfig().config - del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] + del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] def test_additional_config_paths_when_directory_is_provided( @@ -106,12 +106,12 @@ def test_additional_config_paths_when_directory_is_provided( def test_default_config_file_when_path_is_provided_as_environment_variable( get_data_dir, valid_config_with_all_the_scopes, base_config_with_schema ): - os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = get_data_dir + os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = get_data_dir # This will try to load config.yaml file from that directory if present. 
expected_config = base_config_with_schema expected_config["SageMaker"] = valid_config_with_all_the_scopes assert expected_config == SageMakerConfig().config - del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] + del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] def test_merge_behavior_when_additional_config_file_path_is_not_found( @@ -121,10 +121,10 @@ def test_merge_behavior_when_additional_config_file_path_is_not_found( fake_additional_override_config_file_path = os.path.join( get_data_dir, "additional-config-not-found.yaml" ) - os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = valid_config_file_path + os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = valid_config_file_path with pytest.raises(ValueError): SageMakerConfig(additional_config_paths=[fake_additional_override_config_file_path]) - del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] + del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] def test_merge_behavior(get_data_dir, expected_merged_config): @@ -132,14 +132,14 @@ def test_merge_behavior(get_data_dir, expected_merged_config): additional_override_config_file_path = os.path.join( get_data_dir, "sample_additional_config_for_merge.yaml" ) - os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = valid_config_file_path + os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = valid_config_file_path assert ( expected_merged_config == SageMakerConfig(additional_config_paths=[additional_override_config_file_path]).config ) os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = additional_override_config_file_path assert expected_merged_config == SageMakerConfig().config - del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] + del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] del os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] @@ -235,7 +235,7 @@ def test_merge_of_s3_default_config_file_and_regular_config_file( additional_override_config_file_path = os.path.join( get_data_dir, "sample_additional_config_for_merge.yaml" ) - os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = 
config_file_s3_uri + os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = config_file_s3_uri assert ( expected_merged_config == SageMakerConfig( @@ -243,4 +243,4 @@ def test_merge_of_s3_default_config_file_and_regular_config_file( s3_resource=s3_resource_mock, ).config ) - del os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] + del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] From 53ca7fc04170d6c8be98627e1011f401c6a77c63 Mon Sep 17 00:00:00 2001 From: Ruban Hussain Date: Wed, 15 Mar 2023 14:12:21 -0700 Subject: [PATCH 19/40] change: Sagemaker Config - minor clarification --- src/sagemaker/config/config_schema.py | 2 +- src/sagemaker/session.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index 7942f05b52..da6d66e396 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -203,7 +203,7 @@ TYPE: "string", # Currently we support only one schema version (1.0). # In the future this might change if we introduce any breaking changes. - # So adding an enum as a validator. + # So added an enum as a validator. "enum": ["1.0"], "description": "The schema version of the document.", }, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 739ed6fe6e..fb5687e02a 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1432,6 +1432,10 @@ def update_training_job( update the keep-alive period if the warm pool status is `Available`. No other fields can be updated. (default: ``None``). 
""" + # No injections from sagemaker_config because the UpdateTrainingJob API's resource_config + # object accepts fewer parameters than the CreateTrainingJob API, and none that the + # sagemaker_config currently supports + update_training_job_request = self._get_update_training_job_request( job_name=job_name, profiler_rule_configs=profiler_rule_configs, From 70056aac67e91860393fa3af835b6bd5610e7093 Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Wed, 15 Mar 2023 15:45:07 -0700 Subject: [PATCH 20/40] change: ModelMonitoring and Processing now use helper methods for updating NetworkConfig --- .../model_monitor/model_monitoring.py | 48 ++++------- src/sagemaker/network.py | 80 ++----------------- src/sagemaker/processing.py | 60 +++++--------- 3 files changed, 44 insertions(+), 144 deletions(-) diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 6dc474b1b8..1cd757cfbd 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -192,42 +192,24 @@ def __init__( self.output_kms_key = self.sagemaker_session.get_sagemaker_config_override( MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH, default_value=output_kms_key ) - _enable_network_isolation_from_config = ( - self.sagemaker_session.get_sagemaker_config_override( - MONITORING_JOB_ENABLE_NETWORK_ISOLATION_PATH - ) + self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + NetworkConfig, + network_config, + "subnets", + MONITORING_JOB_SUBNETS_PATH, ) - - _subnets_from_config = self.sagemaker_session.get_sagemaker_config_override( - MONITORING_JOB_SUBNETS_PATH + self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + NetworkConfig, + self.network_config, + "security_group_ids", + MONITORING_JOB_SECURITY_GROUP_IDS_PATH, ) - _security_group_ids_from_config = self.sagemaker_session.get_sagemaker_config_override( - MONITORING_JOB_SECURITY_GROUP_IDS_PATH + 
self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + NetworkConfig, + self.network_config, + "enable_network_isolation", + MONITORING_JOB_ENABLE_NETWORK_ISOLATION_PATH, ) - if network_config: - if not network_config.subnets: - network_config.subnets = _subnets_from_config - if network_config.enable_network_isolation is None: - network_config.enable_network_isolation = ( - _enable_network_isolation_from_config or False - ) - if not network_config.security_group_ids: - network_config.security_group_ids = _security_group_ids_from_config - self.network_config = network_config - else: - if ( - _enable_network_isolation_from_config is not None - or _subnets_from_config - or _security_group_ids_from_config - ): - self.network_config = NetworkConfig( - enable_network_isolation=_enable_network_isolation_from_config or False, - security_group_ids=_security_group_ids_from_config, - subnets=_subnets_from_config, - ) - else: - self.network_config = None - self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( NetworkConfig, self.network_config, diff --git a/src/sagemaker/network.py b/src/sagemaker/network.py index 2f278414c3..2942e71062 100644 --- a/src/sagemaker/network.py +++ b/src/sagemaker/network.py @@ -48,82 +48,18 @@ def __init__( encrypt_inter_container_traffic (bool or PipelineVariable): Boolean that determines whether to encrypt inter-container traffic. Default value is None. 
""" - self._enable_network_isolation = enable_network_isolation - self._security_group_ids = security_group_ids - self._subnets = subnets + self.enable_network_isolation = enable_network_isolation + self.security_group_ids = security_group_ids + self.subnets = subnets self.encrypt_inter_container_traffic = encrypt_inter_container_traffic - @property - def security_group_ids(self): - """Getter for Security Groups - - Returns: - list[str]: List of Security Groups - - """ - return self._security_group_ids - - @property - def subnets(self): - """Getter for Subnets - - Returns: - list[str]: List of Subnets - - """ - return self._subnets - - @property - def enable_network_isolation(self): - """Getter for Enable Network Isolation - - Returns: - bool: Value of Enable Network Isolation - - """ - return self._enable_network_isolation - - @security_group_ids.setter - def security_group_ids( - self, - security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, - ): - """Setter for security groups. - - Args: - security_group_ids: List of Security Group Ids. - - """ - self._security_group_ids = security_group_ids - - @subnets.setter - def subnets( - self, - subnets: Optional[List[Union[str, PipelineVariable]]] = None, - ): - """Setter for subnets. - - Args: - subnets: List of Subnets. - - """ - self._subnets = subnets - - @enable_network_isolation.setter - def enable_network_isolation( - self, enable_network_isolation: Union[bool, PipelineVariable] = False - ): - """Setter for enable network isolation. - - Args: - enable_network_isolation: Value for enable network isolation - - """ - self._enable_network_isolation = enable_network_isolation - def _to_request_dict(self): """Generates a request dictionary using the parameters provided to the class.""" - network_config_request = {"EnableNetworkIsolation": self.enable_network_isolation} + # Enable Network Isolation should default to False if it is not provided. 
+ enable_network_isolation = ( + False if self.enable_network_isolation is None else self.enable_network_isolation + ) + network_config_request = {"EnableNetworkIsolation": enable_network_isolation} if self.encrypt_inter_container_traffic is not None: network_config_request[ diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 0dbff0ca59..c189e27698 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -154,41 +154,30 @@ def __init__( self.volume_kms_key = self.sagemaker_session.get_sagemaker_config_override( PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key ) - _enable_network_isolation_from_config = ( - self.sagemaker_session.get_sagemaker_config_override( - PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH - ) + self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + NetworkConfig, + network_config, + "subnets", + PROCESSING_JOB_SUBNETS_PATH, ) - - _subnets_from_config = self.sagemaker_session.get_sagemaker_config_override( - PROCESSING_JOB_SUBNETS_PATH + self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + NetworkConfig, + self.network_config, + "security_group_ids", + PROCESSING_JOB_SECURITY_GROUP_IDS_PATH, ) - _security_group_ids_from_config = self.sagemaker_session.get_sagemaker_config_override( - PROCESSING_JOB_SECURITY_GROUP_IDS_PATH + self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + NetworkConfig, + self.network_config, + "enable_network_isolation", + PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + ) + self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + NetworkConfig, + self.network_config, + "encrypt_inter_container_traffic", + PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION, ) - if network_config: - if not network_config.subnets: - network_config.subnets = _subnets_from_config - if network_config.enable_network_isolation is None: - 
network_config.enable_network_isolation = ( - _enable_network_isolation_from_config or False - ) - if not network_config.security_group_ids: - network_config.security_group_ids = _security_group_ids_from_config - self.network_config = network_config - else: - if ( - _enable_network_isolation_from_config is not None - or _subnets_from_config - or _security_group_ids_from_config - ): - self.network_config = NetworkConfig( - enable_network_isolation=_enable_network_isolation_from_config or False, - security_group_ids=_security_group_ids_from_config, - subnets=_subnets_from_config, - ) - else: - self.network_config = None self.role = self.sagemaker_session.get_sagemaker_config_override( PROCESSING_JOB_ROLE_ARN_PATH, default_value=role ) @@ -199,13 +188,6 @@ def __init__( # after fetching the config. raise ValueError("IAM role should be provided for creating Processing jobs.") - self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( - NetworkConfig, - self.network_config, - "encrypt_inter_container_traffic", - PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION, - ) - @runnable_by_pipeline def run( self, From 51190cc0fcd2639179e38e9143c175d148e2a794 Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Wed, 15 Mar 2023 16:59:37 -0700 Subject: [PATCH 21/40] change: Refactoring session.py and added additional schema validation for ValidationProfiles --- src/sagemaker/config/config.py | 5 +- src/sagemaker/config/config_schema.py | 5 ++ src/sagemaker/session.py | 110 +++++++++----------------- 3 files changed, 46 insertions(+), 74 deletions(-) diff --git a/src/sagemaker/config/config.py b/src/sagemaker/config/config.py index 81559a67d4..dd01c7b332 100644 --- a/src/sagemaker/config/config.py +++ b/src/sagemaker/config/config.py @@ -214,9 +214,8 @@ def _load_config_from_file(file_path: str) -> dict: f"Unable to load config file from the location: {file_path} Please" f" provide a valid file path" ) - else: - logger.debug("Fetching configuration file from 
the path: %s", file_path) - return yaml.safe_load(open(inferred_file_path, "r")) + logger.debug("Fetching configuration file from the path: %s", file_path) + return yaml.safe_load(open(inferred_file_path, "r")) def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict: diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index da6d66e396..e23bbcafd9 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -456,6 +456,11 @@ VALIDATION_PROFILES: { TYPE: "array", "items": {"$ref": "#/definitions/validationProfile"}, + # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ModelPackageValidationSpecification.html + # According to the API docs, This array should have exactly 1 + # item. + "minItems": 1, + "maxItems": 1, }, VALIDATION_ROLE: {"$ref": "#/definitions/roleArn"}, }, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index fb5687e02a..04d67d868e 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1483,7 +1483,7 @@ def _get_update_training_job_request( return update_training_job_request # TODO: unit tests or make a more generic version - def update_processing_input_from_config(self, inputs): + def _update_processing_input_from_config(self, inputs): """Updates Processor Inputs to fetch values from SageMakerConfig wherever applicable. Args: @@ -1504,28 +1504,9 @@ def update_processing_input_from_config(self, inputs): # config for the ones already present in dict_from_inputs. # If BOTH are present, we will still add to both and let the API call fail as it would # have even without injection from sagemaker_config. 
- athena_path = [DATASET_DEFINITION, ATHENA_DATASET_DEFINITION] - redshift_path = [DATASET_DEFINITION, REDSHIFT_DATASET_DEFINITION] - - athena_value_from_inputs = get_nested_value(dict_from_inputs, athena_path) - athena_value_from_config = get_nested_value(dict_from_config, athena_path) - - redshift_value_from_inputs = get_nested_value(dict_from_inputs, redshift_path) - redshift_value_from_config = get_nested_value(dict_from_config, redshift_path) - - if athena_value_from_inputs is not None: - merge_dicts(athena_value_from_config, athena_value_from_inputs) - inputs[i] = set_nested_value( - dict_from_inputs, athena_path, athena_value_from_config - ) - - if redshift_value_from_inputs is not None: - merge_dicts(redshift_value_from_config, redshift_value_from_inputs) - inputs[i] = set_nested_value( - dict_from_inputs, redshift_path, redshift_value_from_config - ) - - if processing_inputs_from_config != []: + merge_dicts(dict_from_config, dict_from_inputs) + inputs[i] = dict_from_config + if processing_inputs_from_config: print( "[Sagemaker Config - applied value]\n", "config key = {}\n".format(PROCESSING_JOB_INPUTS_PATH), @@ -1598,7 +1579,7 @@ def process( PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION, ) - self.update_processing_input_from_config(inputs) + self._update_processing_input_from_config(inputs) role_arn = self.get_sagemaker_config_override( PROCESSING_JOB_ROLE_ARN_PATH, default_value=role_arn ) @@ -3902,53 +3883,40 @@ def create_model_package_from_containers( "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). 
""" - validation_role_from_config = self.get_sagemaker_config_override( - MODEL_PACKAGE_VALIDATION_ROLE_PATH - ) - validation_profiles_from_config = self.get_sagemaker_config_override( - MODEL_PACKAGE_VALIDATION_PROFILES_PATH - ) - if validation_role_from_config or validation_profiles_from_config: - if not validation_specification: - validation_specification = { - VALIDATION_ROLE: validation_role_from_config, - VALIDATION_PROFILES: validation_profiles_from_config, - } - else: - # ValidationSpecification is provided as method parameter - # Now we need to carefully merge - if VALIDATION_ROLE not in validation_specification: - # if Validation role is not provided as part of the dict, merge it - validation_specification[VALIDATION_ROLE] = validation_role_from_config - if VALIDATION_PROFILES not in validation_specification: - # if Validation profile is not provided as part of the dict, merge it - validation_specification[VALIDATION_PROFILES] = validation_profiles_from_config - elif validation_profiles_from_config: - # Validation profiles are provided in the config as well as parameter. 
- validation_profiles = validation_specification[VALIDATION_PROFILES] - for i in range( - min(len(validation_profiles), len(validation_profiles_from_config)) - ): - # Now we need to merge corresponding entries which are not provided in the - # dict , but are present in the config - validation_profile = validation_profiles[i] - validation_profile_from_config = validation_profiles_from_config[i] - original_config_dict_value = validation_profile_from_config.copy() - # Apply the default configurations on top of the config entries - merge_dicts(validation_profile_from_config, validation_profile) - if validation_profile != validation_profile_from_config: - print( - "Config value {} at config path {} was fetched first for " - "index {}.".format( - original_config_dict_value, - MODEL_PACKAGE_VALIDATION_PROFILES_PATH, - i, - ), - "It was then merged with the existing value {} to give {}".format( - validation_profile, validation_profile_from_config - ), - ) - validation_profile.update(validation_profile_from_config) + + if validation_specification: + # ValidationSpecification is provided. Now we can merge missing entries from config. + # If ValidationSpecification is not provided, it is safe to ignore. This is because, + # if this object is provided to the API, then both ValidationProfiles and ValidationRole + # are required and for ValidationProfile, ProfileName is a required parameter. That is + # not supported by the config now. So if we merge values from config, then API will + # throw an exception. In the future, when SageMaker Config starts supporting other + # parameters we can add that. 
+ validation_role = self.get_sagemaker_config_override( + MODEL_PACKAGE_VALIDATION_ROLE_PATH, + default_value=validation_specification.get(VALIDATION_ROLE, None), + ) + validation_specification[VALIDATION_ROLE] = validation_role + validation_profiles_from_config = ( + self.get_sagemaker_config_override(MODEL_PACKAGE_VALIDATION_PROFILES_PATH) or [] + ) + validation_profiles = validation_specification.get(VALIDATION_PROFILES, []) + for i in range(min(len(validation_profiles), len(validation_profiles_from_config))): + original_config_dict_value = copy.deepcopy(validation_profiles_from_config[i]) + merge_dicts(validation_profiles_from_config[i], validation_profiles[i]) + if validation_profiles[i] != validation_profiles_from_config[i]: + print( + "Config value {} at config path {} was fetched first for " + "index {}.".format( + original_config_dict_value, + MODEL_PACKAGE_VALIDATION_PROFILES_PATH, + i, + ), + "It was then merged with the existing value {} to give {}".format( + validation_profiles[i], validation_profiles_from_config[i] + ), + ) + validation_profiles[i].update(validation_profiles_from_config[i]) model_pkg_request = get_create_model_package_request( model_package_name, model_package_group_name, From dba07faba75dd060385c32690111811babea638d Mon Sep 17 00:00:00 2001 From: Ruban Hussain Date: Sat, 18 Mar 2023 12:04:51 -0700 Subject: [PATCH 22/40] update: expand one unit test --- tests/unit/test_pipeline_model.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py index af2b816d61..c710298d9d 100644 --- a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -17,6 +17,7 @@ import pytest from botocore.utils import merge_dicts from mock import Mock, patch +from mock.mock import ANY from sagemaker.model import FrameworkModel from sagemaker.pipeline import PipelineModel @@ -317,15 +318,10 @@ def 
test_pipeline_model_with_config_injection(tfo, time, sagemaker_config_sessio endpoint_config = copy.deepcopy(SAGEMAKER_CONFIG_ENDPOINT_CONFIG) merge_dicts(combined_config, endpoint_config) sagemaker_config_session.sagemaker_config.config = combined_config + + sagemaker_config_session.create_model = Mock() sagemaker_config_session.endpoint_from_production_variants = Mock() - framework_model = DummyFrameworkModel(sagemaker_config_session) - sparkml_model = SparkMLModel( - model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_config_session - ) - pipeline_model = PipelineModel( - [framework_model, sparkml_model], sagemaker_session=sagemaker_config_session - ) expected_role_arn = SAGEMAKER_CONFIG_MODEL["SageMaker"]["Model"]["ExecutionRoleArn"] expected_enable_network_isolation = SAGEMAKER_CONFIG_MODEL["SageMaker"]["Model"][ "EnableNetworkIsolation" @@ -334,10 +330,27 @@ def test_pipeline_model_with_config_injection(tfo, time, sagemaker_config_sessio expected_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][ "KmsKeyId" ] + + framework_model = DummyFrameworkModel(sagemaker_config_session) + sparkml_model = SparkMLModel( + model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_config_session + ) + pipeline_model = PipelineModel( + [framework_model, sparkml_model], sagemaker_session=sagemaker_config_session + ) assert pipeline_model.role == expected_role_arn assert pipeline_model.vpc_config == expected_vpc_config assert pipeline_model.enable_network_isolation == expected_enable_network_isolation + pipeline_model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) + + sagemaker_config_session.create_model.assert_called_with( + ANY, + expected_role_arn, + ANY, + vpc_config=expected_vpc_config, + enable_network_isolation=expected_enable_network_isolation, + ) sagemaker_config_session.endpoint_from_production_variants.assert_called_with( name="mi-1-2017-10-10-14-14-15", production_variants=[ From 
95bc7de187423ee0557f47351ccbf19e4463bee8 Mon Sep 17 00:00:00 2001 From: Ruban Hussain Date: Sat, 18 Mar 2023 12:05:34 -0700 Subject: [PATCH 23/40] update: new integ test for cross context injection --- tests/integ/test_sagemaker_config.py | 416 +++++++++++++++++++++++++-- 1 file changed, 388 insertions(+), 28 deletions(-) diff --git a/tests/integ/test_sagemaker_config.py b/tests/integ/test_sagemaker_config.py index 28e2229ca3..6148199c70 100644 --- a/tests/integ/test_sagemaker_config.py +++ b/tests/integ/test_sagemaker_config.py @@ -13,73 +13,433 @@ from __future__ import absolute_import import os +import pathlib +import tempfile import pytest import yaml +from botocore.config import Config +from sagemaker import ( + PipelineModel, + image_uris, + Model, + Predictor, + Session, +) from sagemaker.config import SageMakerConfig +from sagemaker.model_monitor import DataCaptureConfig from sagemaker.s3 import S3Uploader +from sagemaker.sparkml import SparkMLModel +from sagemaker.utils import sagemaker_timestamp + +from tests.integ import DATA_DIR from tests.integ.kms_utils import get_or_create_kms_key +from tests.integ.test_inference_pipeline import SCHEMA +from tests.integ.timeout import timeout_and_delete_endpoint_by_name +S3_KEY_PREFIX = "integ-test-sagemaker_config" +ENDPOINT_CONFIG_TAGS = [ + {"Key": "SagemakerConfigUsed", "Value": "Yes"}, + {"Key": "ConfigOperation", "Value": "EndpointConfig"}, +] +MODEL_TAGS = [ + {"Key": "SagemakerConfigUsed", "Value": "Yes"}, + {"Key": "ConfigOperation", "Value": "Model"}, +] +CONFIG_DATA_DIR = os.path.join(DATA_DIR, "config") -@pytest.fixture() -def get_data_dir(): - return os.path.join(os.path.dirname(__file__), "..", "data", "config") +@pytest.fixture(scope="session") +def role_arn(sagemaker_session): + iam_client = sagemaker_session.boto_session.client("iam") + return iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] -@pytest.fixture(scope="module") -def s3_files_kms_key(sagemaker_session): + 
+@pytest.fixture(scope="session") +def kms_key_arn(sagemaker_session): return get_or_create_kms_key(sagemaker_session=sagemaker_session) @pytest.fixture() -def expected_merged_config(get_data_dir): +def expected_merged_config(): expected_merged_config_file_path = os.path.join( - get_data_dir, "expected_output_config_after_merge.yaml" + CONFIG_DATA_DIR, "expected_output_config_after_merge.yaml" ) return yaml.safe_load(open(expected_merged_config_file_path, "r").read()) -def test_config_download_from_s3_and_merge( - sagemaker_session, - s3_files_kms_key, - get_data_dir, - expected_merged_config, -): - +@pytest.fixture(scope="module") +def s3_uri_prefix(sagemaker_session): # Note: not using unique_name_from_base() here because the config contents are expected to # change very rarely (if ever), so rather than writing new files and deleting them every time # we can just use the same S3 paths s3_uri_prefix = os.path.join( "s3://", sagemaker_session.default_bucket(), - "integ-test-sagemaker_config", + S3_KEY_PREFIX, ) + return s3_uri_prefix - config_file_1_local_path = os.path.join(get_data_dir, "sample_config_for_merge.yaml") - config_file_2_local_path = os.path.join(get_data_dir, "sample_additional_config_for_merge.yaml") - config_file_1_as_yaml = open(config_file_1_local_path, "r").read() - config_file_2_as_yaml = open(config_file_2_local_path, "r").read() +@pytest.fixture(scope="session") +def sagemaker_session_with_dynamically_generated_sagemaker_config( + role_arn, + kms_key_arn, + sagemaker_client_config, + sagemaker_runtime_config, + boto_session, + sagemaker_metrics_config, +): + # This config needs to be dynamically generated so it can include the specific infra parameters + # created/reused for the Integ tests + config_as_dict = { + "SchemaVersion": "1.0", + "SageMaker": { + "EndpointConfig": { + "AsyncInferenceConfig": {"OutputConfig": {"KmsKeyId": kms_key_arn}}, + "DataCaptureConfig": {"KmsKeyId": kms_key_arn}, + "KmsKeyId": kms_key_arn, + # TODO: re-enable 
after ProductionVariants injection is complete + # "ProductionVariants": [{ + # "CoreDumpConfig": { + # "KmsKeyId": kms_key_arn + # } + # }], + "Tags": ENDPOINT_CONFIG_TAGS, + }, + "Model": { + "EnableNetworkIsolation": True, + "ExecutionRoleArn": role_arn, + "Tags": MODEL_TAGS, + # VpcConfig is omitted for now, more info inside test + # test_sagemaker_config_cross_context_injection + }, + }, + } + + dynamic_sagemaker_config_yaml_path = os.path.join( + tempfile.gettempdir(), "dynamic_sagemaker_config.yaml" + ) + + # write to yaml file, and avoid references and anchors + yaml.Dumper.ignore_aliases = lambda *args: True + with open(pathlib.Path(dynamic_sagemaker_config_yaml_path), "w") as f: + yaml.dump(config_as_dict, f, sort_keys=False, default_flow_style=False) + # other Session inputs (same as sagemaker_session fixture) + sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10))) + sagemaker_client = ( + boto_session.client("sagemaker", **sagemaker_client_config) + if sagemaker_client_config + else None + ) + runtime_client = ( + boto_session.client("sagemaker-runtime", **sagemaker_runtime_config) + if sagemaker_runtime_config + else None + ) + metrics_client = ( + boto_session.client("sagemaker-metrics", **sagemaker_metrics_config) + if sagemaker_metrics_config + else None + ) + + session = Session( + boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=runtime_client, + sagemaker_metrics_client=metrics_client, + sagemaker_config=SageMakerConfig( + additional_config_paths=[dynamic_sagemaker_config_yaml_path] + ), + ) + + return session + + +def test_config_download_from_s3_and_merge( + sagemaker_session, + kms_key_arn, + s3_uri_prefix, + expected_merged_config, +): + config_file_1_local_path = os.path.join(CONFIG_DATA_DIR, "sample_config_for_merge.yaml") + config_file_2_local_path = os.path.join( + CONFIG_DATA_DIR, "sample_additional_config_for_merge.yaml" + ) + + config_file_1_as_yaml = 
open(config_file_1_local_path, "r").read() s3_uri_config_1 = os.path.join(s3_uri_prefix, "config_1.yaml") - s3_uri_config_2 = os.path.join(s3_uri_prefix, "config_2.yaml") # Upload S3 files in case they dont already exist S3Uploader.upload_string_as_file_body( body=config_file_1_as_yaml, desired_s3_uri=s3_uri_config_1, - kms_key=s3_files_kms_key, - sagemaker_session=sagemaker_session, - ) - S3Uploader.upload_string_as_file_body( - body=config_file_2_as_yaml, - desired_s3_uri=s3_uri_config_2, - kms_key=s3_files_kms_key, + kms_key=kms_key_arn, sagemaker_session=sagemaker_session, ) # The thing being tested. - sagemaker_config = SageMakerConfig(additional_config_paths=[s3_uri_config_1, s3_uri_config_2]) + sagemaker_config = SageMakerConfig( + additional_config_paths=[s3_uri_config_1, config_file_2_local_path] + ) assert sagemaker_config.config == expected_merged_config + + +@pytest.mark.slow_test +def test_sagemaker_config_cross_context_injection( + sagemaker_session_with_dynamically_generated_sagemaker_config, + role_arn, + kms_key_arn, + s3_uri_prefix, + cpu_instance_type, + alternative_cpu_instance_type, +): + # This tests injection from the sagemaker_config, specifically for one scenario where a method + # call (deploy of PipelineModel) leads to injections from separate + # Model, EndpointConfig, and Endpoint configs. 
+ + sagemaker_session = sagemaker_session_with_dynamically_generated_sagemaker_config + name = "test-sm-config-pipeline-deploy-{}".format(sagemaker_timestamp()) + test_tags = [ + { + "Key": "Test", + "Value": "test_sagemaker_config_cross_context_injection", + }, + ] + data_capture_s3_uri = os.path.join(s3_uri_prefix, "model-monitor", "data-capture") + + sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model") + xgboost_data_path = os.path.join(DATA_DIR, "xgboost_model") + sparkml_model_data = sagemaker_session.upload_data( + path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"), + key_prefix=S3_KEY_PREFIX + "/sparkml/model", + ) + xgb_model_data = sagemaker_session.upload_data( + path=os.path.join(xgboost_data_path, "xgb_model.tar.gz"), + key_prefix=S3_KEY_PREFIX + "/xgboost/model", + ) + + with timeout_and_delete_endpoint_by_name(name, sagemaker_session): + + # Create classes + sparkml_model = SparkMLModel( + model_data=sparkml_model_data, + env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA}, + sagemaker_session=sagemaker_session, + ) + xgb_image = image_uris.retrieve( + "xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference" + ) + xgb_model = Model( + model_data=xgb_model_data, + image_uri=xgb_image, + sagemaker_session=sagemaker_session, + ) + pipeline_model = PipelineModel( + models=[sparkml_model, xgb_model], + predictor_cls=Predictor, + sagemaker_session=sagemaker_session, + name=name, + ) + + # Basic check before any API calls that config parameters were injected. Not included: + # - VpcConfig: The VPC created by the test suite today (via get_or_create_vpc_resources) + # creates two subnets in the same AZ. However, CreateEndpoint fails if it + # does not have at least two AZs. TODO: Can explore either creating a new + # VPC or modifying the existing one, so that it can be included in the + # config too + # - Tags: By design. 
These are injected before the API call, not inside the Model classes + assert [ + sparkml_model.role, + xgb_model.role, + pipeline_model.role, + sparkml_model.enable_network_isolation(), + xgb_model.enable_network_isolation(), + pipeline_model.enable_network_isolation, # This is not a function in PipelineModel + ] == [role_arn, role_arn, role_arn, True, True, True] + + # First mutating API call where sagemaker_config values should be injected in + predictor = pipeline_model.deploy( + 1, + alternative_cpu_instance_type, + endpoint_name=name, + data_capture_config=DataCaptureConfig( + True, + destination_s3_uri=data_capture_s3_uri, + sagemaker_session=sagemaker_session, + ), + tags=test_tags, + ) + endpoint_1 = sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=name) + + # Second mutating API call where sagemaker_config values should be injected in + predictor.update_endpoint(initial_instance_count=1, instance_type=cpu_instance_type) + endpoint_2 = sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=name) + + # Call remaining describe APIs to fetch info that we will validate against + model = sagemaker_session.sagemaker_client.describe_model(ModelName=name) + endpoint_config_1_name = endpoint_1["EndpointConfigName"] + endpoint_config_2_name = endpoint_2["EndpointConfigName"] + endpoint_config_1 = sagemaker_session.sagemaker_client.describe_endpoint_config( + EndpointConfigName=endpoint_config_1_name + ) + endpoint_config_2 = sagemaker_session.sagemaker_client.describe_endpoint_config( + EndpointConfigName=endpoint_config_2_name + ) + model_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=model["ModelArn"]) + endpoint_1_tags = sagemaker_session.sagemaker_client.list_tags( + ResourceArn=endpoint_1["EndpointArn"] + ) + endpoint_2_tags = sagemaker_session.sagemaker_client.list_tags( + ResourceArn=endpoint_2["EndpointArn"] + ) + endpoint_config_1_tags = sagemaker_session.sagemaker_client.list_tags( + 
ResourceArn=endpoint_config_1["EndpointConfigArn"] + ) + endpoint_config_2_tags = sagemaker_session.sagemaker_client.list_tags( + ResourceArn=endpoint_config_2["EndpointConfigArn"] + ) + + # Remove select key-values from the Describe API outputs that we do not need to compare + # (things that may keep changing over time, and ARNs.) + # Still leaving in more than just the sagemaker_config injected fields so we can verify that + # the injection has not overwritten anything it shouldn't. + for key in ["Containers", "CreationTime", "ModelArn", "ResponseMetadata"]: + model.pop(key) + + for key in ["EndpointArn", "CreationTime", "LastModifiedTime", "ResponseMetadata"]: + endpoint_1.pop(key) + endpoint_2.pop(key) + del endpoint_1["ProductionVariants"][0]["DeployedImages"] + del endpoint_2["ProductionVariants"][0]["DeployedImages"] + + for key in ["EndpointConfigArn", "CreationTime", "ResponseMetadata"]: + endpoint_config_1.pop(key) + endpoint_config_2.pop(key) + + for key in ["ResponseMetadata"]: + model_tags.pop(key) + endpoint_1_tags.pop(key) + endpoint_2_tags.pop(key) + endpoint_config_1_tags.pop(key) + endpoint_config_2_tags.pop(key) + + # Expected parameters for these objects + expected_model = { + "ModelName": name, + "InferenceExecutionConfig": {"Mode": "Serial"}, + "ExecutionRoleArn": role_arn, # from sagemaker_config + "EnableNetworkIsolation": True, # from sagemaker_config + } + + expected_endpoint_1 = { + "EndpointName": name, + "EndpointConfigName": endpoint_config_1_name, + "ProductionVariants": [ + { + "VariantName": "AllTraffic", + "CurrentWeight": 1.0, + "DesiredWeight": 1.0, + "CurrentInstanceCount": 1, + "DesiredInstanceCount": 1, + } + ], + "DataCaptureConfig": { + "EnableCapture": True, + "CaptureStatus": "Started", + "CurrentSamplingPercentage": 20, + "DestinationS3Uri": data_capture_s3_uri, + "KmsKeyId": kms_key_arn, # from sagemaker_config + }, + "EndpointStatus": "InService", + } + + expected_endpoint_2 = { + **expected_endpoint_1, + 
"EndpointConfigName": endpoint_config_2_name, + } + + expected_endpoint_config_1 = { + "EndpointConfigName": endpoint_config_1_name, + "ProductionVariants": [ + { + "VariantName": "AllTraffic", + "ModelName": name, + "InitialInstanceCount": 1, + "InstanceType": alternative_cpu_instance_type, + "InitialVariantWeight": 1.0, + "VolumeSizeInGB": 4, + } + ], + "DataCaptureConfig": { + "EnableCapture": True, + "InitialSamplingPercentage": 20, + "DestinationS3Uri": data_capture_s3_uri, + "KmsKeyId": kms_key_arn, # from sagemaker_config + "CaptureOptions": [{"CaptureMode": "Input"}, {"CaptureMode": "Output"}], + "CaptureContentTypeHeader": { + "CsvContentTypes": ["text/csv"], + "JsonContentTypes": ["application/json"], + }, + }, + "KmsKeyId": kms_key_arn, # from sagemaker_config + } + + expected_endpoint_config_2 = { + **expected_endpoint_config_1, + "EndpointConfigName": endpoint_config_2_name, + "ProductionVariants": [ + { + "VariantName": "AllTraffic", + "ModelName": name, + "InitialInstanceCount": 1, + "InstanceType": cpu_instance_type, + "InitialVariantWeight": 1.0, + "VolumeSizeInGB": 16, + } + ], + } + + # TODO: Update expected tags for endpoints if injection behavior is changed + expected_model_tags = {"Tags": MODEL_TAGS} + expected_endpoint_1_tags = {"Tags": test_tags + ENDPOINT_CONFIG_TAGS} + expected_endpoint_2_tags = {"Tags": test_tags + ENDPOINT_CONFIG_TAGS} + expected_endpoint_config_1_tags = {"Tags": test_tags + ENDPOINT_CONFIG_TAGS} + expected_endpoint_config_2_tags = {"Tags": test_tags + ENDPOINT_CONFIG_TAGS} + + # Doing the comparison in this way simplifies debugging failures for this test, + # because all the values can be compared and checked together at once, rather than having + # to run the test repeatedly to get through 10 separate comparisons one at a time + assert [ + model, + endpoint_1, + endpoint_2, + endpoint_config_1, + endpoint_config_2, + set(model_tags), + set(endpoint_1_tags), + set(endpoint_2_tags), + set(endpoint_config_1_tags), + 
set(endpoint_config_2_tags), + ] == [ + expected_model, + expected_endpoint_1, + expected_endpoint_2, + expected_endpoint_config_1, + expected_endpoint_config_2, + set(expected_model_tags), + set(expected_endpoint_1_tags), + set(expected_endpoint_2_tags), + set(expected_endpoint_config_1_tags), + set(expected_endpoint_config_2_tags), + ] + + # Finally delete the model. (Endpoints should be deleted by the + # timeout_and_delete_endpoint_by_name above ) + pipeline_model.delete_model() + with pytest.raises(Exception) as exception: + sagemaker_session.sagemaker_client.describe_model(ModelName=pipeline_model.name) + assert "Could not find model" in str(exception.value) From 19e185aa8aec7a8f415b8ae38095396bb674f225 Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Mon, 20 Mar 2023 10:20:27 -0700 Subject: [PATCH 24/40] change: remove unwanted method and replace it with a different method for config injection --- src/sagemaker/automl/automl.py | 16 ++- src/sagemaker/automl/candidate_estimator.py | 8 +- src/sagemaker/estimator.py | 20 ++-- src/sagemaker/feature_store/feature_group.py | 12 +-- src/sagemaker/model.py | 56 +++++------ .../model_monitor/data_capture_config.py | 4 +- .../model_monitor/model_monitoring.py | 12 +-- src/sagemaker/pipeline.py | 12 +-- src/sagemaker/processing.py | 12 +-- src/sagemaker/session.py | 98 ++++++------------- src/sagemaker/transformer.py | 12 +-- src/sagemaker/workflow/pipeline.py | 12 +-- tests/unit/conftest.py | 8 +- tests/unit/sagemaker/automl/test_auto_ml.py | 4 - .../feature_store/test_feature_group.py | 8 +- .../sagemaker/huggingface/test_estimator.py | 5 - .../sagemaker/huggingface/test_processing.py | 4 - .../sagemaker/local/test_local_pipeline.py | 16 +-- tests/unit/sagemaker/model/test_deploy.py | 26 +++-- tests/unit/sagemaker/model/test_edge.py | 8 +- .../sagemaker/model/test_framework_model.py | 8 +- tests/unit/sagemaker/model/test_model.py | 8 +- .../sagemaker/model/test_model_package.py | 8 +- 
tests/unit/sagemaker/model/test_neo.py | 8 +- .../monitor/test_clarify_model_monitor.py | 4 - .../monitor/test_data_capture_config.py | 8 +- .../monitor/test_model_monitoring.py | 4 - .../sagemaker/tensorflow/test_estimator.py | 4 - .../tensorflow/test_estimator_attach.py | 8 +- .../tensorflow/test_estimator_init.py | 7 +- tests/unit/sagemaker/tensorflow/test_tfs.py | 8 +- .../test_huggingface_pytorch_compiler.py | 4 - .../test_huggingface_tensorflow_compiler.py | 4 - .../test_pytorch_compiler.py | 4 - .../test_tensorflow_compiler.py | 4 - tests/unit/sagemaker/workflow/conftest.py | 8 +- tests/unit/sagemaker/workflow/test_airflow.py | 4 - .../unit/sagemaker/workflow/test_pipeline.py | 13 ++- .../sagemaker/wrangler/test_processing.py | 8 +- tests/unit/test_algorithm.py | 8 +- tests/unit/test_amazon_estimator.py | 4 - tests/unit/test_chainer.py | 5 - tests/unit/test_estimator.py | 13 +-- tests/unit/test_fm.py | 8 +- tests/unit/test_ipinsights.py | 8 +- tests/unit/test_job.py | 8 +- tests/unit/test_kmeans.py | 8 +- tests/unit/test_knn.py | 8 +- tests/unit/test_lda.py | 8 +- tests/unit/test_linear_learner.py | 8 +- tests/unit/test_multidatamodel.py | 8 +- tests/unit/test_mxnet.py | 5 - tests/unit/test_ntm.py | 8 +- tests/unit/test_object2vec.py | 8 +- tests/unit/test_pca.py | 8 +- tests/unit/test_pipeline_model.py | 5 - tests/unit/test_predictor.py | 4 - tests/unit/test_predictor_async.py | 4 - tests/unit/test_processing.py | 13 +-- tests/unit/test_pytorch.py | 5 - tests/unit/test_randomcutforest.py | 9 +- tests/unit/test_rl.py | 5 - tests/unit/test_sklearn.py | 5 - tests/unit/test_sparkml_serving.py | 4 - tests/unit/test_timeout.py | 4 - tests/unit/test_transformer.py | 8 +- tests/unit/test_tuner.py | 5 - tests/unit/test_xgboost.py | 4 - 68 files changed, 286 insertions(+), 389 deletions(-) diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py index 67379558b5..9b59a6579a 100644 --- a/src/sagemaker/automl/automl.py +++ 
b/src/sagemaker/automl/automl.py @@ -207,18 +207,16 @@ def __init__( self._auto_ml_job_desc = None self._best_candidate = None self.sagemaker_session = sagemaker_session or Session() - self.vpc_config = self.sagemaker_session.get_sagemaker_config_override( - AUTO_ML_VPC_CONFIG_PATH, default_value=vpc_config + self.vpc_config = self.sagemaker_session.resolve_value_from_config( + vpc_config, AUTO_ML_VPC_CONFIG_PATH ) - self.volume_kms_key = self.sagemaker_session.get_sagemaker_config_override( - AUTO_ML_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key + self.volume_kms_key = self.sagemaker_session.resolve_value_from_config( + volume_kms_key, AUTO_ML_VOLUME_KMS_KEY_ID_PATH ) - self.output_kms_key = self.sagemaker_session.get_sagemaker_config_override( - AUTO_ML_KMS_KEY_ID_PATH, default_value=output_kms_key - ) - self.role = self.sagemaker_session.get_sagemaker_config_override( - AUTO_ML_ROLE_ARN_PATH, default_value=role + self.output_kms_key = self.sagemaker_session.resolve_value_from_config( + output_kms_key, AUTO_ML_KMS_KEY_ID_PATH ) + self.role = self.sagemaker_session.resolve_value_from_config(role, AUTO_ML_ROLE_ARN_PATH) if not self.role: # Originally IAM role was a required parameter. # Now we marked that as Optional because we can fetch it from SageMakerConfig diff --git a/src/sagemaker/automl/candidate_estimator.py b/src/sagemaker/automl/candidate_estimator.py index eccf6c499a..80d7be1a2d 100644 --- a/src/sagemaker/automl/candidate_estimator.py +++ b/src/sagemaker/automl/candidate_estimator.py @@ -106,11 +106,11 @@ def fit( """Logs can only be shown if wait is set to True. 
Please either set wait to True or set logs to False.""" ) - vpc_config = self.sagemaker_session.get_sagemaker_config_override( - TRAINING_JOB_VPC_CONFIG_PATH, default_value=vpc_config + vpc_config = self.sagemaker_session.resolve_value_from_config( + vpc_config, TRAINING_JOB_VPC_CONFIG_PATH ) - volume_kms_key = self.sagemaker_session.get_sagemaker_config_override( - TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key + volume_kms_key = self.sagemaker_session.resolve_value_from_config( + volume_kms_key, TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH ) self.name = candidate_name or self.name running_jobs = {} diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 8d23fe036c..ba2fa6a91d 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -586,8 +586,8 @@ def __init__( self.deploy_instance_type = None self._compiled_models = {} - self.role = self.sagemaker_session.get_sagemaker_config_override( - TRAINING_JOB_ROLE_ARN_PATH, role + self.role = self.sagemaker_session.resolve_value_from_config( + role, TRAINING_JOB_ROLE_ARN_PATH ) if not self.role: # Originally IAM role was a required parameter. @@ -595,19 +595,19 @@ def __init__( # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. 
raise ValueError("IAM role should be provided for creating estimators.") - self.output_kms_key = self.sagemaker_session.get_sagemaker_config_override( - TRAINING_JOB_KMS_KEY_ID_PATH, default_value=output_kms_key + self.output_kms_key = self.sagemaker_session.resolve_value_from_config( + output_kms_key, TRAINING_JOB_KMS_KEY_ID_PATH ) - self.volume_kms_key = self.sagemaker_session.get_sagemaker_config_override( - TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key + self.volume_kms_key = self.sagemaker_session.resolve_value_from_config( + volume_kms_key, TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH ) # VPC configurations - self.subnets = self.sagemaker_session.get_sagemaker_config_override( - TRAINING_JOB_SUBNETS_PATH, default_value=subnets + self.subnets = self.sagemaker_session.resolve_value_from_config( + subnets, TRAINING_JOB_SUBNETS_PATH ) - self.security_group_ids = self.sagemaker_session.get_sagemaker_config_override( - TRAINING_JOB_SECURITY_GROUP_IDS_PATH, default_value=security_group_ids + self.security_group_ids = self.sagemaker_session.resolve_value_from_config( + security_group_ids, TRAINING_JOB_SECURITY_GROUP_IDS_PATH ) # training image configs diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index ca89b91385..27ed96b603 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -557,14 +557,14 @@ def create( Returns: Response dict from service. 
""" - role_arn = self.sagemaker_session.get_sagemaker_config_override( - FEATURE_GROUP_ROLE_ARN_PATH, default_value=role_arn + role_arn = self.sagemaker_session.resolve_value_from_config( + role_arn, FEATURE_GROUP_ROLE_ARN_PATH ) - offline_store_kms_key_id = self.sagemaker_session.get_sagemaker_config_override( - FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH, default_value=offline_store_kms_key_id + offline_store_kms_key_id = self.sagemaker_session.resolve_value_from_config( + offline_store_kms_key_id, FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH ) - online_store_kms_key_id = self.sagemaker_session.get_sagemaker_config_override( - FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH, default_value=online_store_kms_key_id + online_store_kms_key_id = self.sagemaker_session.resolve_value_from_config( + online_store_kms_key_id, FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH ) if not role_arn: # Originally IAM role was a required parameter. diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index ac3266da3a..5124c1a414 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -284,16 +284,12 @@ def __init__( self._base_name = None self.sagemaker_session = sagemaker_session self.role = ( - self.sagemaker_session.get_sagemaker_config_override( - MODEL_EXECUTION_ROLE_ARN_PATH, default_value=role - ) + self.sagemaker_session.resolve_value_from_config(role, MODEL_EXECUTION_ROLE_ARN_PATH) if sagemaker_session else role ) self.vpc_config = ( - self.sagemaker_session.get_sagemaker_config_override( - MODEL_VPC_CONFIG_PATH, default_value=vpc_config - ) + self.sagemaker_session.resolve_value_from_config(vpc_config, MODEL_VPC_CONFIG_PATH) if sagemaker_session else vpc_config ) @@ -304,8 +300,8 @@ def __init__( self.inference_recommender_job_results = None self.inference_recommendations = None self._enable_network_isolation = ( - self.sagemaker_session.get_sagemaker_config_override( - MODEL_ENABLE_NETWORK_ISOLATION_PATH, default_value=enable_network_isolation + 
self.sagemaker_session.resolve_value_from_config( + enable_network_isolation, MODEL_ENABLE_NETWORK_ISOLATION_PATH ) if sagemaker_session else enable_network_isolation @@ -705,14 +701,14 @@ def _create_sagemaker_model( self._init_sagemaker_session_if_does_not_exist(instance_type) # Depending on the instance type, a local session (or) a session is initialized. - self.role = self.sagemaker_session.get_sagemaker_config_override( - MODEL_EXECUTION_ROLE_ARN_PATH, default_value=self.role + self.role = self.sagemaker_session.resolve_value_from_config( + self.role, MODEL_EXECUTION_ROLE_ARN_PATH ) - self.vpc_config = self.sagemaker_session.get_sagemaker_config_override( - MODEL_VPC_CONFIG_PATH, default_value=self.vpc_config + self.vpc_config = self.sagemaker_session.resolve_value_from_config( + self.vpc_config, MODEL_VPC_CONFIG_PATH ) - self._enable_network_isolation = self.sagemaker_session.get_sagemaker_config_override( - MODEL_ENABLE_NETWORK_ISOLATION_PATH, default_value=self._enable_network_isolation + self._enable_network_isolation = self.sagemaker_session.resolve_value_from_config( + self._enable_network_isolation, MODEL_ENABLE_NETWORK_ISOLATION_PATH ) create_model_args = dict( name=self.name, @@ -910,12 +906,10 @@ def package_for_edge( if job_name is None: job_name = f"packaging{self._compilation_job_name[11:]}" self._init_sagemaker_session_if_does_not_exist(None) - s3_kms_key = self.sagemaker_session.get_sagemaker_config_override( - EDGE_PACKAGING_KMS_KEY_ID_PATH, default_value=s3_kms_key - ) - role = self.sagemaker_session.get_sagemaker_config_override( - EDGE_PACKAGING_ROLE_ARN_PATH, default_value=role + s3_kms_key = self.sagemaker_session.resolve_value_from_config( + s3_kms_key, EDGE_PACKAGING_KMS_KEY_ID_PATH ) + role = self.sagemaker_session.resolve_value_from_config(role, EDGE_PACKAGING_ROLE_ARN_PATH) if role is not None: role = self.sagemaker_session.expand_role(role) config = self._edge_packaging_job_config( @@ -1020,9 +1014,7 @@ def compile( framework_version 
= framework_version or self._get_framework_version() self._init_sagemaker_session_if_does_not_exist(target_instance_family) - role = self.sagemaker_session.get_sagemaker_config_override( - COMPILATION_JOB_ROLE_ARN_PATH, default_value=role - ) + role = self.sagemaker_session.resolve_value_from_config(role, COMPILATION_JOB_ROLE_ARN_PATH) if not role: # Originally IAM role was a required parameter. # Now we marked that as Optional because we can fetch it from SageMakerConfig @@ -1182,14 +1174,14 @@ def deploy( self._init_sagemaker_session_if_does_not_exist(instance_type) # Depending on the instance type, a local session (or) a session is initialized. - self.role = self.sagemaker_session.get_sagemaker_config_override( - MODEL_EXECUTION_ROLE_ARN_PATH, default_value=self.role + self.role = self.sagemaker_session.resolve_value_from_config( + self.role, MODEL_EXECUTION_ROLE_ARN_PATH ) - self.vpc_config = self.sagemaker_session.get_sagemaker_config_override( - MODEL_VPC_CONFIG_PATH, default_value=self.vpc_config + self.vpc_config = self.sagemaker_session.resolve_value_from_config( + self.vpc_config, MODEL_VPC_CONFIG_PATH ) - self._enable_network_isolation = self.sagemaker_session.get_sagemaker_config_override( - MODEL_ENABLE_NETWORK_ISOLATION_PATH, default_value=self._enable_network_isolation + self._enable_network_isolation = self.sagemaker_session.resolve_value_from_config( + self._enable_network_isolation, MODEL_ENABLE_NETWORK_ISOLATION_PATH ) tags = add_jumpstart_tags( @@ -1277,11 +1269,9 @@ def deploy( async_inference_config = self._build_default_async_inference_config( async_inference_config ) - async_inference_config.kms_key_id = ( - self.sagemaker_session.get_sagemaker_config_override( - ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH, - default_value=async_inference_config.kms_key_id, - ) + async_inference_config.kms_key_id = self.sagemaker_session.resolve_value_from_config( + async_inference_config.kms_key_id, + ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH, ) 
async_inference_config_dict = async_inference_config._to_request_dict() diff --git a/src/sagemaker/model_monitor/data_capture_config.py b/src/sagemaker/model_monitor/data_capture_config.py index b9febada4b..d8ee8e21cb 100644 --- a/src/sagemaker/model_monitor/data_capture_config.py +++ b/src/sagemaker/model_monitor/data_capture_config.py @@ -75,8 +75,8 @@ def __init__( _DATA_CAPTURE_S3_PATH, ) - self.kms_key_id = sagemaker_session.get_sagemaker_config_override( - ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH, kms_key_id + self.kms_key_id = sagemaker_session.resolve_value_from_config( + kms_key_id, ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH ) self.capture_options = capture_options or ["REQUEST", "RESPONSE"] self.csv_content_types = csv_content_types or ["text/csv"] diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 1cd757cfbd..8988285eeb 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -177,8 +177,8 @@ def __init__( self.latest_baselining_job_name = None self.monitoring_schedule_name = None self.job_definition_name = None - self.role = self.sagemaker_session.get_sagemaker_config_override( - MONITORING_JOB_ROLE_ARN_PATH, default_value=role + self.role = self.sagemaker_session.resolve_value_from_config( + role, MONITORING_JOB_ROLE_ARN_PATH ) if not self.role: # Originally IAM role was a required parameter. @@ -186,11 +186,11 @@ def __init__( # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. 
raise ValueError("IAM role should be provided for creating Monitoring Schedule.") - self.volume_kms_key = self.sagemaker_session.get_sagemaker_config_override( - MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key + self.volume_kms_key = self.sagemaker_session.resolve_value_from_config( + volume_kms_key, MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH ) - self.output_kms_key = self.sagemaker_session.get_sagemaker_config_override( - MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH, default_value=output_kms_key + self.output_kms_key = self.sagemaker_session.resolve_value_from_config( + output_kms_key, MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH ) self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( NetworkConfig, diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index ccb87aa966..81ef07b080 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -92,16 +92,12 @@ def __init__( self.sagemaker_session = sagemaker_session self.endpoint_name = None self.role = ( - self.sagemaker_session.get_sagemaker_config_override( - MODEL_EXECUTION_ROLE_ARN_PATH, default_value=role - ) + self.sagemaker_session.resolve_value_from_config(role, MODEL_EXECUTION_ROLE_ARN_PATH) if self.sagemaker_session else role ) self.vpc_config = ( - self.sagemaker_session.get_sagemaker_config_override( - MODEL_VPC_CONFIG_PATH, default_value=vpc_config - ) + self.sagemaker_session.resolve_value_from_config(vpc_config, MODEL_VPC_CONFIG_PATH) if self.sagemaker_session else vpc_config ) @@ -248,8 +244,8 @@ def deploy( container_startup_health_check_timeout=container_startup_health_check_timeout, ) self.endpoint_name = endpoint_name or self.name - kms_key = self.sagemaker_session.get_sagemaker_config_override( - ENDPOINT_CONFIG_KMS_KEY_ID_PATH, default_value=kms_key + kms_key = self.sagemaker_session.resolve_value_from_config( + kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH ) data_capture_config_dict = None diff --git a/src/sagemaker/processing.py 
b/src/sagemaker/processing.py index c189e27698..4ad59d3fbd 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -148,11 +148,11 @@ def __init__( sagemaker_session = LocalSession(disable_local_code=True) self.sagemaker_session = sagemaker_session or Session() - self.output_kms_key = self.sagemaker_session.get_sagemaker_config_override( - PROCESSING_JOB_KMS_KEY_ID_PATH, default_value=output_kms_key + self.output_kms_key = self.sagemaker_session.resolve_value_from_config( + output_kms_key, PROCESSING_JOB_KMS_KEY_ID_PATH ) - self.volume_kms_key = self.sagemaker_session.get_sagemaker_config_override( - PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key + self.volume_kms_key = self.sagemaker_session.resolve_value_from_config( + volume_kms_key, PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH ) self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( NetworkConfig, @@ -178,8 +178,8 @@ def __init__( "encrypt_inter_container_traffic", PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION, ) - self.role = self.sagemaker_session.get_sagemaker_config_override( - PROCESSING_JOB_ROLE_ARN_PATH, default_value=role + self.role = self.sagemaker_session.resolve_value_from_config( + role, PROCESSING_JOB_ROLE_ARN_PATH ) if not self.role: # Originally IAM role was a required parameter. diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 04d67d868e..8518593b95 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -660,28 +660,6 @@ def get_sagemaker_config_value(self, key): # Copy the value so any modifications to the output will not modify the source config return copy.deepcopy(config_value) - def get_sagemaker_config_override(self, key, default_value=None): - """Util method that fetches a particular key path in the SageMakerConfig and returns it. - - If a default value is provided, then this method will return the default value. - - Args: - key: Key Path of the config file entry. 
- default_value: The existing value that was passed as method parameter. If this is not - None, then the method will return this value - - Returns: - object: The corresponding value in the Config file/ the default value. - - """ - config_value = self.get_sagemaker_config_value(key) - self._print_message_on_sagemaker_config_usage(default_value, config_value, key) - - if default_value is not None: - return default_value - - return config_value - def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): """Creates an S3 Bucket if it does not exist. @@ -1130,7 +1108,7 @@ def train( # noqa: C901 config_path=PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION, default_value=False, ) - role = self.get_sagemaker_config_override(TRAINING_JOB_ROLE_ARN_PATH, role) + role = self.resolve_value_from_config(role, TRAINING_JOB_ROLE_ARN_PATH) enable_network_isolation = self.resolve_value_from_config( direct_input=enable_network_isolation, config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, @@ -1491,8 +1469,8 @@ def _update_processing_input_from_config(self, inputs): """ inputs_copy = copy.deepcopy(inputs) - processing_inputs_from_config = ( - self.get_sagemaker_config_override(PROCESSING_JOB_INPUTS_PATH) or [] + processing_inputs_from_config = self.resolve_value_from_config( + config_path=PROCESSING_JOB_INPUTS_PATH, default_value=[] ) for i in range(min(len(inputs), len(processing_inputs_from_config))): dict_from_inputs = inputs[i] @@ -1580,9 +1558,7 @@ def process( ) self._update_processing_input_from_config(inputs) - role_arn = self.get_sagemaker_config_override( - PROCESSING_JOB_ROLE_ARN_PATH, default_value=role_arn - ) + role_arn = self.resolve_value_from_config(role_arn, PROCESSING_JOB_ROLE_ARN_PATH) inferred_network_config_from_config = ( self._update_nested_dictionary_with_values_from_config( network_config, PROCESSING_JOB_NETWORK_CONFIG_PATH @@ -1762,11 +1738,9 @@ def create_monitoring_schedule( tags ([dict[str,str]]): A list of dictionaries containing 
key-value pairs. """ - role_arn = self.get_sagemaker_config_override( - MONITORING_JOB_ROLE_ARN_PATH, default_value=role_arn - ) - volume_kms_key = self.get_sagemaker_config_override( - MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key + role_arn = self.resolve_value_from_config(role_arn, MONITORING_JOB_ROLE_ARN_PATH) + volume_kms_key = self.resolve_value_from_config( + volume_kms_key, MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH ) inferred_network_config_from_config = ( self._update_nested_dictionary_with_values_from_config( @@ -1797,8 +1771,8 @@ def create_monitoring_schedule( } if monitoring_output_config is not None: - kms_key_from_config = self.get_sagemaker_config_override( - MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH + kms_key_from_config = self.resolve_value_from_config( + config_path=MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH ) if KMS_KEY_ID not in monitoring_output_config and kms_key_from_config: monitoring_output_config[KMS_KEY_ID] = kms_key_from_config @@ -2467,7 +2441,7 @@ def auto_ml( Contains "AutoGenerateEndpointName" and "EndpointName" """ - role = self.get_sagemaker_config_override(AUTO_ML_ROLE_ARN_PATH, default_value=role) + role = self.resolve_value_from_config(role, AUTO_ML_ROLE_ARN_PATH) inferred_output_config = self._update_nested_dictionary_with_values_from_config( output_config, AUTO_ML_OUTPUT_CONFIG_PATH ) @@ -2746,11 +2720,11 @@ def compile_model( Returns: str: ARN of the compile model job, if it is created. 
""" - role = self.get_sagemaker_config_override(COMPILATION_JOB_ROLE_ARN_PATH, default_value=role) + role = self.resolve_value_from_config(role, COMPILATION_JOB_ROLE_ARN_PATH) inferred_output_model_config = self._update_nested_dictionary_with_values_from_config( output_model_config, COMPILATION_JOB_OUTPUT_CONFIG_PATH ) - vpc_config = self.get_sagemaker_config_override(COMPILATION_JOB_VPC_CONFIG_PATH) + vpc_config = self.resolve_value_from_config(config_path=COMPILATION_JOB_VPC_CONFIG_PATH) compilation_job_request = { "InputConfig": input_model_config, "OutputConfig": inferred_output_model_config, @@ -2796,7 +2770,7 @@ def package_model_for_edge( tags (list[dict]): List of tags for labeling a compile model job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. """ - role = self.get_sagemaker_config_override(EDGE_PACKAGING_ROLE_ARN_PATH, role) + role = self.resolve_value_from_config(role, EDGE_PACKAGING_ROLE_ARN_PATH) inferred_output_model_config = self._update_nested_dictionary_with_values_from_config( output_model_config, EDGE_PACKAGING_OUTPUT_CONFIG_PATH ) @@ -3692,10 +3666,8 @@ def create_model( """ tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags(tags, "{}.{}.{}".format(SAGEMAKER, MODEL, TAGS)) - role = self.get_sagemaker_config_override(MODEL_EXECUTION_ROLE_ARN_PATH, default_value=role) - vpc_config = self.get_sagemaker_config_override( - MODEL_VPC_CONFIG_PATH, default_value=vpc_config - ) + role = self.resolve_value_from_config(role, MODEL_EXECUTION_ROLE_ARN_PATH) + vpc_config = self.resolve_value_from_config(vpc_config, MODEL_VPC_CONFIG_PATH) enable_network_isolation = self.resolve_value_from_config( direct_input=enable_network_isolation, config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH, @@ -3773,8 +3745,8 @@ def create_model_from_job( ) name = name or training_job_name role = role or training_job["RoleArn"] - role = self.get_sagemaker_config_override( - MODEL_EXECUTION_ROLE_ARN_PATH, default_value=role or 
training_job["RoleArn"] + role = self.resolve_value_from_config( + role, MODEL_EXECUTION_ROLE_ARN_PATH, training_job["RoleArn"] ) enable_network_isolation = self.resolve_value_from_config( direct_input=enable_network_isolation, @@ -3788,9 +3760,7 @@ def create_model_from_job( env=env, ) vpc_config = _vpc_config_from_training_job(training_job, vpc_config_override) - vpc_config = self.get_sagemaker_config_override( - MODEL_VPC_CONFIG_PATH, default_value=vpc_config - ) + vpc_config = self.resolve_value_from_config(vpc_config, MODEL_VPC_CONFIG_PATH) return self.create_model( name, role, @@ -3892,13 +3862,13 @@ def create_model_package_from_containers( # not supported by the config now. So if we merge values from config, then API will # throw an exception. In the future, when SageMaker Config starts supporting other # parameters we can add that. - validation_role = self.get_sagemaker_config_override( + validation_role = self.resolve_value_from_config( + validation_specification.get(VALIDATION_ROLE, None), MODEL_PACKAGE_VALIDATION_ROLE_PATH, - default_value=validation_specification.get(VALIDATION_ROLE, None), ) validation_specification[VALIDATION_ROLE] = validation_role - validation_profiles_from_config = ( - self.get_sagemaker_config_override(MODEL_PACKAGE_VALIDATION_PROFILES_PATH) or [] + validation_profiles_from_config = self.resolve_value_from_config( + config_path=MODEL_PACKAGE_VALIDATION_PROFILES_PATH, default_value=[] ) validation_profiles = validation_specification.get(VALIDATION_PROFILES, []) for i in range(min(len(validation_profiles), len(validation_profiles_from_config))): @@ -4068,8 +4038,8 @@ def create_endpoint_config( model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, ) - inferred_production_variants_from_config = ( - self.get_sagemaker_config_override(ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH) or [] + inferred_production_variants_from_config = 
self.resolve_value_from_config( + config_path=ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, default_value=[] ) if inferred_production_variants_from_config: inferred_production_variant_from_config = ( @@ -4100,9 +4070,7 @@ def create_endpoint_config( ) if tags is not None: request["Tags"] = tags - kms_key = self.get_sagemaker_config_override( - ENDPOINT_CONFIG_KMS_KEY_ID_PATH, default_value=kms_key - ) + kms_key = self.resolve_value_from_config(kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH) if kms_key is not None: request["KmsKeyId"] = kms_key @@ -4170,8 +4138,8 @@ def create_endpoint_config_from_existing( new_production_variants or existing_endpoint_config_desc["ProductionVariants"] ) if production_variants: - inferred_production_variants_from_config = ( - self.get_sagemaker_config_override(ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH) or [] + inferred_production_variants_from_config = self.resolve_value_from_config( + config_path=ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, default_value=[] ) for i in range( min(len(production_variants), len(inferred_production_variants_from_config)) @@ -4204,8 +4172,8 @@ def create_endpoint_config_from_existing( if new_kms_key is not None or existing_endpoint_config_desc.get("KmsKeyId") is not None: request["KmsKeyId"] = new_kms_key or existing_endpoint_config_desc.get("KmsKeyId") if KMS_KEY_ID not in request: - kms_key_from_config = self.get_sagemaker_config_override( - ENDPOINT_CONFIG_KMS_KEY_ID_PATH + kms_key_from_config = self.resolve_value_from_config( + config_path=ENDPOINT_CONFIG_KMS_KEY_ID_PATH ) if kms_key_from_config: request[KMS_KEY_ID] = kms_key_from_config @@ -4770,9 +4738,7 @@ def endpoint_from_production_variants( str: The name of the created ``Endpoint``. 
""" config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants} - kms_key = self.get_sagemaker_config_override( - ENDPOINT_CONFIG_KMS_KEY_ID_PATH, default_value=kms_key - ) + kms_key = self.resolve_value_from_config(kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH) tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS) @@ -5232,9 +5198,7 @@ def create_feature_group( tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, FEATURE_GROUP, TAGS) ) - role_arn = self.get_sagemaker_config_override( - FEATURE_GROUP_ROLE_ARN_PATH, default_value=role_arn - ) + role_arn = self.resolve_value_from_config(role_arn, FEATURE_GROUP_ROLE_ARN_PATH) inferred_online_store_from_config = self._update_nested_dictionary_with_values_from_config( online_store_config, FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH ) diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 804947475a..f9bd16369b 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -126,11 +126,11 @@ def __init__( self._reset_output_path = False self.sagemaker_session = sagemaker_session or Session() - self.volume_kms_key = self.sagemaker_session.get_sagemaker_config_override( - TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH, default_value=volume_kms_key + self.volume_kms_key = self.sagemaker_session.resolve_value_from_config( + volume_kms_key, TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH ) - self.output_kms_key = self.sagemaker_session.get_sagemaker_config_override( - TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, default_value=output_kms_key + self.output_kms_key = self.sagemaker_session.resolve_value_from_config( + output_kms_key, TRANSFORM_OUTPUT_KMS_KEY_ID_PATH ) @runnable_by_pipeline @@ -439,8 +439,8 @@ def transform_with_monitoring( pipeline_role_arn = ( role if role - else transformer.sagemaker_session.get_sagemaker_config_override( - PIPELINE_ROLE_ARN_PATH, 
default_value=get_execution_role() + else transformer.sagemaker_session.resolve_value_from_config( + get_execution_role(), PIPELINE_ROLE_ARN_PATH ) ) pipeline.upsert(pipeline_role_arn) diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 6cef6dbff2..41f234e1c0 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -127,8 +127,8 @@ def create( Returns: A response dict from the service. """ - role_arn = self.sagemaker_session.get_sagemaker_config_override( - PIPELINE_ROLE_ARN_PATH, role_arn + role_arn = self.sagemaker_session.resolve_value_from_config( + role_arn, PIPELINE_ROLE_ARN_PATH ) if not role_arn: # Originally IAM role was a required parameter. @@ -221,8 +221,8 @@ def update( Returns: A response dict from the service. """ - role_arn = self.sagemaker_session.get_sagemaker_config_override( - PIPELINE_ROLE_ARN_PATH, role_arn + role_arn = self.sagemaker_session.resolve_value_from_config( + role_arn, PIPELINE_ROLE_ARN_PATH ) if not role_arn: # Originally IAM role was a required parameter. @@ -261,8 +261,8 @@ def upsert( Returns: response dict from service """ - role_arn = self.sagemaker_session.get_sagemaker_config_override( - PIPELINE_ROLE_ARN_PATH, role_arn + role_arn = self.sagemaker_session.resolve_value_from_config( + role_arn, PIPELINE_ROLE_ARN_PATH ) if not role_arn: # Originally IAM role was a required parameter. 
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 75d7730523..39d055091d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -71,9 +71,11 @@ def sagemaker_session(boto_session, client): default_bucket=_DEFAULT_BUCKET, sagemaker_metrics_client=client, ) - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return session diff --git a/tests/unit/sagemaker/automl/test_auto_ml.py b/tests/unit/sagemaker/automl/test_auto_ml.py index 26d08784ac..ddcedc24e1 100644 --- a/tests/unit/sagemaker/automl/test_auto_ml.py +++ b/tests/unit/sagemaker/automl/test_auto_ml.py @@ -280,10 +280,6 @@ def sagemaker_session(): if direct_input is not None else default_value, ) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return sms diff --git a/tests/unit/sagemaker/feature_store/test_feature_group.py b/tests/unit/sagemaker/feature_store/test_feature_group.py index b13e236ee4..9cd5c0f6c4 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_group.py +++ b/tests/unit/sagemaker/feature_store/test_feature_group.py @@ -54,9 +54,11 @@ def s3_uri(): @pytest.fixture def sagemaker_session_mock(): session_mock = Mock() - session_mock.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + session_mock.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return session_mock diff --git 
a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index 867f79cd22..2fac0ca735 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -82,11 +82,6 @@ def fixture_sagemaker_session(): if direct_input is not None else default_value, ) - - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session diff --git a/tests/unit/sagemaker/huggingface/test_processing.py b/tests/unit/sagemaker/huggingface/test_processing.py index def116ac53..46f8fcb775 100644 --- a/tests/unit/sagemaker/huggingface/test_processing.py +++ b/tests/unit/sagemaker/huggingface/test_processing.py @@ -55,10 +55,6 @@ def sagemaker_session(): name="resolve_class_attribute_from_config", side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, ) - session_mock.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) session_mock.resolve_value_from_config = Mock( name="resolve_value_from_config", side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input diff --git a/tests/unit/sagemaker/local/test_local_pipeline.py b/tests/unit/sagemaker/local/test_local_pipeline.py index f119e52b41..43a69e71f1 100644 --- a/tests/unit/sagemaker/local/test_local_pipeline.py +++ b/tests/unit/sagemaker/local/test_local_pipeline.py @@ -140,9 +140,11 @@ def pipeline_session(boto_session, client): sagemaker_client=client, default_bucket=BUCKET, ) - pipeline_session_mock.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + pipeline_session_mock.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, 
default_value=None: direct_input + if direct_input is not None + else default_value, ) return pipeline_session_mock @@ -150,9 +152,11 @@ def pipeline_session(boto_session, client): @pytest.fixture() def local_sagemaker_session(boto_session): local_session_mock = LocalSession(boto_session=boto_session, default_bucket="my-bucket") - local_session_mock.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + local_session_mock.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return local_session_mock diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index a31a431526..16894dfbc2 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -68,9 +68,11 @@ @pytest.fixture def sagemaker_session(): session = Mock() - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return session @@ -441,13 +443,17 @@ def test_deploy_wrong_serverless_config(sagemaker_session): @patch("sagemaker.session.Session") @patch("sagemaker.local.LocalSession") def test_deploy_creates_correct_session(local_session, session): - local_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + 
local_session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, + ) + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) # We expect a LocalSession when deploying to instance_type = 'local' model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE) diff --git a/tests/unit/sagemaker/model/test_edge.py b/tests/unit/sagemaker/model/test_edge.py index 627cd28086..21fb6d710a 100644 --- a/tests/unit/sagemaker/model/test_edge.py +++ b/tests/unit/sagemaker/model/test_edge.py @@ -31,9 +31,11 @@ @pytest.fixture def sagemaker_session(): session = Mock(boto_region_name=REGION) - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return session diff --git a/tests/unit/sagemaker/model/test_framework_model.py b/tests/unit/sagemaker/model/test_framework_model.py index 9d01d9fead..36171477b5 100644 --- a/tests/unit/sagemaker/model/test_framework_model.py +++ b/tests/unit/sagemaker/model/test_framework_model.py @@ -94,9 +94,11 @@ def sagemaker_session(): settings=SessionSettings(), ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + sms.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, 
default_value=None: direct_input + if direct_input is not None + else default_value, ) return sms diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index ff7408d6ca..692bab1c4f 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -109,9 +109,11 @@ def sagemaker_session(): s3_resource=None, ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + sms.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return sms diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index c0cdff6483..068bc164fc 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -59,9 +59,11 @@ def sagemaker_session(): session.sagemaker_client.describe_model_package = Mock( return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE ) - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return session diff --git a/tests/unit/sagemaker/model/test_neo.py b/tests/unit/sagemaker/model/test_neo.py index b21957a205..1e135e221e 100644 --- a/tests/unit/sagemaker/model/test_neo.py +++ b/tests/unit/sagemaker/model/test_neo.py @@ -35,9 +35,11 @@ @pytest.fixture def sagemaker_session(): session = Mock(boto_region_name=REGION) - session.get_sagemaker_config_override = 
Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return session diff --git a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py index 75a414b35c..fc47144df2 100644 --- a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py +++ b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py @@ -432,10 +432,6 @@ def sagemaker_session(sagemaker_client): if direct_input is not None else default_value, ) - session_mock.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session_mock diff --git a/tests/unit/sagemaker/monitor/test_data_capture_config.py b/tests/unit/sagemaker/monitor/test_data_capture_config.py index 5717aa4ad8..8587851266 100644 --- a/tests/unit/sagemaker/monitor/test_data_capture_config.py +++ b/tests/unit/sagemaker/monitor/test_data_capture_config.py @@ -56,9 +56,11 @@ def test_init_when_non_defaults_provided(): def test_init_when_optionals_not_provided(): sagemaker_session = Mock() sagemaker_session.default_bucket.return_value = DEFAULT_BUCKET_NAME - sagemaker_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + sagemaker_session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) data_capture_config = DataCaptureConfig( diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index 
d837112937..f258c8ab54 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -481,10 +481,6 @@ def sagemaker_session(): session_mock._append_sagemaker_config_tags = Mock( name="_append_sagemaker_config_tags", side_effect=lambda tags, config_path_to_tags: tags ) - session_mock.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session_mock diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index 43b978ea37..1de1d304d5 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -98,10 +98,6 @@ def sagemaker_session(): if direct_input is not None else default_value, ) - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session diff --git a/tests/unit/sagemaker/tensorflow/test_estimator_attach.py b/tests/unit/sagemaker/tensorflow/test_estimator_attach.py index ceab141c0c..caa08f8e3c 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator_attach.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator_attach.py @@ -42,9 +42,11 @@ def sagemaker_session(): describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return session diff --git 
a/tests/unit/sagemaker/tensorflow/test_estimator_init.py b/tests/unit/sagemaker/tensorflow/test_estimator_init.py index dd7d974e0a..71764b0e93 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator_init.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator_init.py @@ -25,12 +25,7 @@ @pytest.fixture() def sagemaker_session(): - session_mock = Mock(name="sagemaker_session", boto_region_name=REGION) - session_mock.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) - return session_mock + return Mock(name="sagemaker_session", boto_region_name=REGION) def _build_tf(sagemaker_session, **kwargs): diff --git a/tests/unit/sagemaker/tensorflow/test_tfs.py b/tests/unit/sagemaker/tensorflow/test_tfs.py index 2ce5e71681..ecbdf87a2e 100644 --- a/tests/unit/sagemaker/tensorflow/test_tfs.py +++ b/tests/unit/sagemaker/tensorflow/test_tfs.py @@ -67,9 +67,11 @@ def sagemaker_session(): session.sagemaker_client.describe_training_job = Mock(return_value=describe) session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return session diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index 002489bb27..40e3241333 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -89,10 +89,6 
@@ def fixture_sagemaker_session(): if direct_input is not None else default_value, ) - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index a2a8708f2c..a0abeb2f29 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -86,10 +86,6 @@ def fixture_sagemaker_session(): if direct_input is not None else default_value, ) - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py index 663917596c..643dc6337c 100644 --- a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -88,10 +88,6 @@ def fixture_sagemaker_session(): if direct_input is not None else default_value, ) - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index 1719223652..5f67a4c46b 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -93,10 +93,6 @@ def fixture_sagemaker_session(): if direct_input is not None else default_value, ) - session.get_sagemaker_config_override = Mock( - 
name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session diff --git a/tests/unit/sagemaker/workflow/conftest.py b/tests/unit/sagemaker/workflow/conftest.py index 972c6e5d89..c8d9c44320 100644 --- a/tests/unit/sagemaker/workflow/conftest.py +++ b/tests/unit/sagemaker/workflow/conftest.py @@ -64,8 +64,10 @@ def pipeline_session(mock_boto_session, mock_client): sagemaker_client=mock_client, default_bucket=BUCKET, ) - pipeline_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + pipeline_session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return pipeline_session diff --git a/tests/unit/sagemaker/workflow/test_airflow.py b/tests/unit/sagemaker/workflow/test_airflow.py index b8010ef0ad..e53efbb30b 100644 --- a/tests/unit/sagemaker/workflow/test_airflow.py +++ b/tests/unit/sagemaker/workflow/test_airflow.py @@ -53,10 +53,6 @@ def sagemaker_session(): if direct_input is not None else default_value, ) - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index 50cda15e40..9b4e0d006a 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -46,9 +46,11 @@ def sagemaker_session_mock(): session_mock = Mock() session_mock.default_bucket = Mock(name="default_bucket", return_value="s3_bucket") session_mock.local_mode = False - session_mock.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + 
session_mock.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) session_mock._append_sagemaker_config_tags = Mock( name="_append_sagemaker_config_tags", side_effect=lambda tags, config_path_to_tags: tags @@ -72,8 +74,9 @@ def test_pipeline_create_and_update_without_role_arn(sagemaker_session_mock): def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock): - sagemaker_session_mock.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", side_effect=lambda a, b: "ConfigRoleArn" + sagemaker_session_mock.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: "ConfigRoleArn", ) sagemaker_session_mock.sagemaker_client.describe_pipeline.return_value = { "PipelineArn": "pipeline-arn" diff --git a/tests/unit/sagemaker/wrangler/test_processing.py b/tests/unit/sagemaker/wrangler/test_processing.py index 48b6e17f1d..2f0d9a4875 100644 --- a/tests/unit/sagemaker/wrangler/test_processing.py +++ b/tests/unit/sagemaker/wrangler/test_processing.py @@ -46,9 +46,11 @@ def sagemaker_session(): name="resolve_class_attribute_from_config", side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, ) - session_mock.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + session_mock.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return session_mock diff --git a/tests/unit/test_algorithm.py b/tests/unit/test_algorithm.py index 3bc247a307..cc23cabfa6 100644 --- a/tests/unit/test_algorithm.py +++ 
b/tests/unit/test_algorithm.py @@ -954,9 +954,11 @@ def test_algorithm_no_required_hyperparameters(session): def test_algorithm_attach_from_hyperparameter_tuning(): session = Mock() - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) job_name = "training-job-that-is-part-of-a-tuning-job" algo_arn = "arn:aws:sagemaker:us-east-2:000000000000:algorithm/scikit-decision-trees" diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index 15e7c2ecc7..0d51c7261f 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -74,10 +74,6 @@ def sagemaker_session(): sms.sagemaker_client.describe_training_job = Mock( name="describe_training_job", return_value=returned_job_description ) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) sms.resolve_value_from_config = Mock( name="resolve_value_from_config", side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 98824ed5c3..16e7958821 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -81,11 +81,6 @@ def sagemaker_session(): if direct_input is not None else default_value, ) - - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 44aeea0d64..cdec02e194 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ 
-245,11 +245,6 @@ def sagemaker_session(): if direct_input is not None else default_value, ) - - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return sms @@ -4004,9 +3999,11 @@ def test_estimator_local_mode_error(sagemaker_session): def test_estimator_local_mode_ok(sagemaker_local_session): - sagemaker_local_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + sagemaker_local_session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) # When using instance local with a session which is not LocalSession we should error out Estimator( diff --git a/tests/unit/test_fm.py b/tests/unit/test_fm.py index 5e5e08bbe8..db0f5c7075 100644 --- a/tests/unit/test_fm.py +++ b/tests/unit/test_fm.py @@ -68,9 +68,11 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + sms.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return sms diff --git a/tests/unit/test_ipinsights.py b/tests/unit/test_ipinsights.py index e44cf68546..0d9d4a1e8f 100644 --- a/tests/unit/test_ipinsights.py +++ b/tests/unit/test_ipinsights.py @@ -65,9 +65,11 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = 
Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + sms.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return sms diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index f3a451f22e..f9dc9f1c40 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -80,9 +80,11 @@ def sagemaker_session(): name="sagemaker_session", boto_session=boto_mock, s3_client=None, s3_resource=None ) mock_session.expand_role = Mock(name="expand_role", return_value=ROLE) - mock_session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + mock_session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return mock_session diff --git a/tests/unit/test_kmeans.py b/tests/unit/test_kmeans.py index de0e517d63..13ef923fc2 100644 --- a/tests/unit/test_kmeans.py +++ b/tests/unit/test_kmeans.py @@ -62,9 +62,11 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + sms.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return sms diff --git a/tests/unit/test_knn.py b/tests/unit/test_knn.py index 
7880716edd..2d630dc5c5 100644 --- a/tests/unit/test_knn.py +++ b/tests/unit/test_knn.py @@ -68,9 +68,11 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + sms.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return sms diff --git a/tests/unit/test_lda.py b/tests/unit/test_lda.py index 1dc0001ec7..e1e9110923 100644 --- a/tests/unit/test_lda.py +++ b/tests/unit/test_lda.py @@ -57,9 +57,11 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + sms.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return sms diff --git a/tests/unit/test_linear_learner.py b/tests/unit/test_linear_learner.py index 83126166ae..d6a823674c 100644 --- a/tests/unit/test_linear_learner.py +++ b/tests/unit/test_linear_learner.py @@ -63,9 +63,11 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + sms.resolve_value_from_config = 
Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return sms diff --git a/tests/unit/test_multidatamodel.py b/tests/unit/test_multidatamodel.py index 58164be338..9777d1505d 100644 --- a/tests/unit/test_multidatamodel.py +++ b/tests/unit/test_multidatamodel.py @@ -80,9 +80,11 @@ def sagemaker_session(): name="upload_data", return_value=os.path.join(VALID_MULTI_MODEL_DATA_PREFIX, "mleap_model.tar.gz"), ) - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) s3_mock = Mock() diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 6d4e8e1430..0e3bad1dea 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -108,11 +108,6 @@ def sagemaker_session(): if direct_input is not None else default_value, ) - - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session diff --git a/tests/unit/test_ntm.py b/tests/unit/test_ntm.py index 570ef16b56..20aaa53590 100644 --- a/tests/unit/test_ntm.py +++ b/tests/unit/test_ntm.py @@ -62,9 +62,11 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + sms.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, 
config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return sms diff --git a/tests/unit/test_object2vec.py b/tests/unit/test_object2vec.py index 98ec6f29e1..f72bce977e 100644 --- a/tests/unit/test_object2vec.py +++ b/tests/unit/test_object2vec.py @@ -70,9 +70,11 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + sms.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return sms diff --git a/tests/unit/test_pca.py b/tests/unit/test_pca.py index dfba755f18..555aa508dc 100644 --- a/tests/unit/test_pca.py +++ b/tests/unit/test_pca.py @@ -62,9 +62,11 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + sms.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return sms diff --git a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py index c710298d9d..7d4dd8389c 100644 --- a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -74,11 +74,6 @@ def sagemaker_session(): settings=SessionSettings(), ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) - - 
sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) sms.resolve_value_from_config = Mock( name="resolve_value_from_config", side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index b86a5f0e67..ee53628ef4 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -51,10 +51,6 @@ def empty_sagemaker_session(): ims.sagemaker_runtime_client.invoke_endpoint = Mock( name="invoke_endpoint", return_value={"Body": response_body} ) - ims.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return ims diff --git a/tests/unit/test_predictor_async.py b/tests/unit/test_predictor_async.py index 45c45ef000..cc55cd32ed 100644 --- a/tests/unit/test_predictor_async.py +++ b/tests/unit/test_predictor_async.py @@ -52,10 +52,6 @@ def empty_sagemaker_session(): "OutputLocation": ASYNC_OUTPUT_LOCATION, }, ) - ims.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) response_body = Mock("body") response_body.read = Mock("read", return_value=RETURN_VALUE) response_body.close = Mock("close", return_value=None) diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 325edac4b9..f3a7652bba 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -89,11 +89,6 @@ def sagemaker_session(): name="resolve_class_attribute_from_config", side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, ) - - session_mock.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) session_mock.resolve_value_from_config = Mock( 
name="resolve_value_from_config", side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input @@ -122,9 +117,11 @@ def pipeline_session(): session_mock.describe_processing_job = MagicMock( name="describe_processing_job", return_value=_get_describe_response_inputs_and_ouputs() ) - session_mock.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + session_mock.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) session_mock.__class__ = PipelineSession diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 4fa9bc146a..a1c9ffde3b 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -90,11 +90,6 @@ def fixture_sagemaker_session(): if direct_input is not None else default_value, ) - - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session diff --git a/tests/unit/test_randomcutforest.py b/tests/unit/test_randomcutforest.py index 65ee77399a..31ab045ffe 100644 --- a/tests/unit/test_randomcutforest.py +++ b/tests/unit/test_randomcutforest.py @@ -62,11 +62,12 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + sms.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) - return sms diff --git 
a/tests/unit/test_rl.py b/tests/unit/test_rl.py index 1b73714ed1..f4369f09e6 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -83,11 +83,6 @@ def fixture_sagemaker_session(): if direct_input is not None else default_value, ) - - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index e9c79bf152..7095095a2d 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -85,11 +85,6 @@ def sagemaker_session(): if direct_input is not None else default_value, ) - - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session diff --git a/tests/unit/test_sparkml_serving.py b/tests/unit/test_sparkml_serving.py index e6415aada8..3fb21d62d2 100644 --- a/tests/unit/test_sparkml_serving.py +++ b/tests/unit/test_sparkml_serving.py @@ -46,10 +46,6 @@ def sagemaker_session(): sms.boto_region_name = REGION sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return sms diff --git a/tests/unit/test_timeout.py b/tests/unit/test_timeout.py index ee8bb5b138..bded6ce2cc 100644 --- a/tests/unit/test_timeout.py +++ b/tests/unit/test_timeout.py @@ -60,10 +60,6 @@ def session(): settings=SessionSettings(), ) sms.default_bucket = Mock(name=DEFAULT_BUCKET_NAME, return_value=BUCKET_NAME) - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return sms diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py 
index b6bd867948..8fadcab52f 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -72,9 +72,11 @@ def sagemaker_session(): name="resolve_class_attribute_from_config", side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, ) - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, + session.resolve_value_from_config = Mock( + name="resolve_value_from_config", + side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input + if direct_input is not None + else default_value, ) return session diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 8079015e4b..1d6599c0ff 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -73,11 +73,6 @@ def sagemaker_session(): if direct_input is not None else default_value, ) - - sms.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return sms diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index ebec275d8c..2c35ad8584 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -88,10 +88,6 @@ def sagemaker_session(): if direct_input is not None else default_value, ) - session.get_sagemaker_config_override = Mock( - name="get_sagemaker_config_override", - side_effect=lambda key, default_value=None: default_value, - ) return session From 3bd3a94ecf092f8d7c502275d7656cf33e7262f3 Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Mon, 20 Mar 2023 11:35:03 -0700 Subject: [PATCH 25/40] fix: Address documentation errors and removed unnecessary properties and setters --- .../async_inference/async_inference_config.py | 23 +--- src/sagemaker/config/config.py | 120 +++--------------- 2 files changed, 20 insertions(+), 123 deletions(-) diff --git 
a/src/sagemaker/async_inference/async_inference_config.py b/src/sagemaker/async_inference/async_inference_config.py index 6f1bd56f6a..f5e2cb8f57 100644 --- a/src/sagemaker/async_inference/async_inference_config.py +++ b/src/sagemaker/async_inference/async_inference_config.py @@ -57,30 +57,9 @@ def __init__( """ self.output_path = output_path self.max_concurrent_invocations_per_instance = max_concurrent_invocations_per_instance - self._kms_key_id = kms_key_id + self.kms_key_id = kms_key_id self.notification_config = notification_config - @property - def kms_key_id(self): - """Getter for kms_key_id - - Returns: - str: The KMS Key ID. - """ - return self._kms_key_id - - @kms_key_id.setter - def kms_key_id(self, kms_key_id: str): - """Setter for kms_key_id - - Args: - kms_key_id: The new kms_key_id to replace the existing one. - - Returns: - - """ - self._kms_key_id = kms_key_id - def _to_request_dict(self): """Generates a request dictionary using the parameters provided to the class.""" request_dict = { diff --git a/src/sagemaker/config/config.py b/src/sagemaker/config/config.py index dd01c7b332..b93e192c08 100644 --- a/src/sagemaker/config/config.py +++ b/src/sagemaker/config/config.py @@ -69,31 +69,32 @@ class SageMakerConfig(object): """ def __init__(self, additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE): - """Constructor for SageMakerConfig. + """Initializes the SageMakerConfig object. - By default, it will first look for Config files in paths that are dictated by - _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH. + By default, it will first look for Config files in the default locations as dictated by + the SDK. 
- Users can override the _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH + Users can override the default Admin Config file path and the default User Config file path by using environment variables - SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE Additional Configuration file paths can also be provided as a constructor parameter. - This constructor will then - * Load each config file. + This __init__ method will then + * Load each config file (Can be in S3/Local File system). * It will validate the schema of the config files. * It will perform the merge operation in the same order. - This constructor will throw exceptions for the following cases: - * Schema validation fails for one/more config files. - * When the config file is not a proper YAML file. - * Any S3 related issues that arises while fetching config file from S3. This includes - permission issues, S3 Object is not found in the specified S3 URI. - * File doesn't exist in a path that was specified by the user as part of environment - variable/ additional_config_paths. This doesn't include - _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH - + This __init__ method will throw exceptions for the following cases: + * jsonschema.exceptions.ValidationError: Schema validation fails for one/more config files. + * RuntimeError: If the method is unable to retrieve the list of all the S3 files with the + same prefix/Unable to retrieve the file. + * ValueError: If an S3 URI is provided and there are no S3 files with that prefix. + * ValueError: If a folder in S3 bucket is provided as s3_uri, and if it doesn't have + config.yaml. + * ValueError: A file doesn't exist in a path that was specified by the user as part of + environment variable/ additional_config_paths. This doesn't include the default config + file locations. Args: additional_config_paths: List of Config file paths. 
@@ -105,8 +106,8 @@ def __init__(self, additional_config_paths: List[str] = None, s3_resource=_DEFAU * S3 URI of the directory containing the config file (in this case, we will look for config.yaml in that directory) Note: S3 URI follows the format s3:/// - s3_resource: Corresponds to boto3 S3 resource. This will be used to fetch Config - files from S3. If it is not provided, we will create a default s3 resource + s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This will be used to fetch + Config files from S3. If it is not provided, we will create a default s3 resource See :py:meth:`boto3.session.Session.resource`. This argument is not needed if the config files are present in the local file system @@ -128,7 +129,7 @@ def config_paths(self) -> List[str]: """Getter for Config paths. Returns: - List[str]: This corresponds to the list of config file paths. + List[str]: This method returns the list of config file paths. """ return self._config_paths @@ -144,29 +145,6 @@ def config(self) -> dict: def _load_config_files(file_paths: List[str], s3_resource_for_config) -> dict: - """This method loads all the config files from the paths that were provided as Inputs. - - Note: Supported Config file locations are Local File System and S3. - - This method will throw exceptions for the following cases: - * Schema validation fails for one/more config files. - * When the config file is not a proper YAML file. - * Any S3 related issues that arises while fetching config file from S3. This includes - permission issues, S3 Object is not found in the specified S3 URI. - * File doesn't exist in a path that was specified by the user as part of environment - variable/ additional_config_paths. This doesn't include - _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH - - Args: - file_paths(List[str]): The list of paths corresponding to the config file. Note: This - path can either be a Local File System path or it can be a S3 URI. 
- s3_resource_for_config: Corresponds to boto3 S3 resource. This will be used to fetch Config - files from S3. See :py:meth:`boto3.session.Session.resource`. - - Returns: - dict: A dictionary representing the configurations that were loaded from the config files. - - """ merged_config = {} for file_path in file_paths: config_from_file = {} @@ -191,21 +169,6 @@ def _load_config_files(file_paths: List[str], s3_resource_for_config) -> dict: def _load_config_from_file(file_path: str) -> dict: - """This method loads the config file from the path that was specified as parameter. - - If the path that was provided, corresponds to a directory then this method will try to search - for 'config.yaml' in that directory. Note: We will not be doing any recursive search. - - Args: - file_path(str): The file path from which the Config file needs to be loaded. - - Returns: - dict: A dictionary representing the configurations that were loaded from the config file. - - This method will throw Exceptions for the following cases: - * When the config file is not a proper YAML file. - * File doesn't exist in a path that was specified by the consumer. - """ inferred_file_path = file_path if os.path.isdir(file_path): inferred_file_path = os.path.join(file_path, "config.yaml") @@ -219,28 +182,6 @@ def _load_config_from_file(file_path: str) -> dict: def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict: - """This method loads the config file from the S3 URI that was specified as parameter. - - If the S3 URI that was provided, corresponds to a directory then this method will try to - search for 'config.yaml' in that directory. Note: We will not be doing any recursive search. - - Args: - s3_uri(str): The S3 URI of the config file. - Note: S3 URI follows the format s3:/// - s3_resource_for_config: Corresponds to boto3 S3 resource. This will be used to fetch Config - files from S3. See :py:meth:`boto3.session.Session.resource`. 
- - Returns: - dict: A dictionary representing the configurations that were loaded from the config file. - - This method will throw Exceptions for the following cases: - * If Boto3 S3 resource is not provided. - * When the config file is not a proper YAML file. - * If the method is unable to retrieve the list of all the S3 files with the same prefix - * If there are no S3 files with that prefix. - * If a folder in S3 bucket is provided as s3_uri, and if it doesn't have config.yaml, - then we will throw an Exception. - """ if not s3_resource_for_config: raise RuntimeError("Please provide a S3 client for loading the config") logger.debug("Fetching configuration file from the S3 URI: %s", s3_uri) @@ -253,29 +194,6 @@ def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict: def _get_inferred_s3_uri(s3_uri, s3_resource_for_config): - """Verifies whether the given S3 URI exists and returns the URI. - - If there are multiple S3 objects with the same key prefix, - then this method will verify whether S3 URI + /config.yaml exists. - s3://example-bucket/somekeyprefix/config.yaml - - Args: - s3_uri (str) : An S3 uri that refers to a location in which config file is present. - s3_uri must start with 's3://'. - An example s3_uri: 's3://example-bucket/config.yaml'. - s3_resource_for_config: Corresponds to boto3 S3 resource. This will be used to fetch Config - files from S3. - See :py:meth:`boto3.session.Session.resource` - - Returns: - str: Valid S3 URI of the Config file. None if it doesn't exist. - - This method will throw Exceptions for the following cases: - * If the method is unable to retrieve the list of all the S3 files with the same prefix - * If there are no S3 files with that prefix. - * If a folder in S3 bucket is provided as s3_uri, and if it doesn't have config.yaml, - then we will throw an Exception. 
- """ parsed_url = urlparse(s3_uri) bucket, key_prefix = parsed_url.netloc, parsed_url.path.lstrip("/") try: From d4905c9285173f2709ee4554da38ee8e551e5257 Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Mon, 20 Mar 2023 17:45:12 -0700 Subject: [PATCH 26/40] fix: moving certain config file helper methods to utils.py --- src/sagemaker/automl/automl.py | 33 +- src/sagemaker/automl/candidate_estimator.py | 24 +- src/sagemaker/config/__init__.py | 154 ++-- src/sagemaker/config/config.py | 4 + src/sagemaker/config/config_schema.py | 168 ++++- src/sagemaker/estimator.py | 53 +- src/sagemaker/feature_store/feature_group.py | 21 +- src/sagemaker/model.py | 70 +- .../model_monitor/data_capture_config.py | 10 +- .../model_monitor/model_monitoring.py | 57 +- src/sagemaker/pipeline.py | 43 +- src/sagemaker/processing.py | 50 +- src/sagemaker/session.py | 711 +++++------------- src/sagemaker/transformer.py | 43 +- src/sagemaker/utils.py | 271 +++++++ src/sagemaker/workflow/pipeline.py | 17 +- tests/unit/conftest.py | 6 - tests/unit/sagemaker/automl/test_auto_ml.py | 9 +- .../feature_store/test_feature_group.py | 11 +- .../sagemaker/huggingface/test_estimator.py | 9 +- .../sagemaker/huggingface/test_processing.py | 13 +- .../image_uris/jumpstart/conftest.py | 2 + .../image_uris/jumpstart/test_catboost.py | 2 + .../sagemaker/local/test_local_pipeline.py | 18 +- tests/unit/sagemaker/model/test_deploy.py | 22 +- tests/unit/sagemaker/model/test_edge.py | 8 +- .../sagemaker/model/test_framework_model.py | 8 +- tests/unit/sagemaker/model/test_model.py | 8 +- .../sagemaker/model/test_model_package.py | 8 +- tests/unit/sagemaker/model/test_neo.py | 8 +- .../monitor/test_clarify_model_monitor.py | 18 +- .../monitor/test_data_capture_config.py | 7 +- .../monitor/test_model_monitoring.py | 17 +- .../sagemaker/tensorflow/test_estimator.py | 18 +- .../tensorflow/test_estimator_attach.py | 8 +- .../tensorflow/test_estimator_init.py | 4 +- tests/unit/sagemaker/tensorflow/test_tfs.py | 8 +- 
.../test_huggingface_pytorch_compiler.py | 9 +- .../test_huggingface_tensorflow_compiler.py | 9 +- .../test_pytorch_compiler.py | 9 +- .../test_tensorflow_compiler.py | 9 +- tests/unit/sagemaker/workflow/conftest.py | 9 +- tests/unit/sagemaker/workflow/test_airflow.py | 13 +- .../unit/sagemaker/workflow/test_pipeline.py | 16 +- .../sagemaker/wrangler/test_processing.py | 13 +- tests/unit/test_algorithm.py | 13 +- tests/unit/test_amazon_estimator.py | 8 +- tests/unit/test_chainer.py | 9 +- tests/unit/test_estimator.py | 25 +- tests/unit/test_fm.py | 8 +- tests/unit/test_ipinsights.py | 8 +- tests/unit/test_job.py | 8 +- tests/unit/test_kmeans.py | 8 +- tests/unit/test_knn.py | 8 +- tests/unit/test_lda.py | 8 +- tests/unit/test_linear_learner.py | 8 +- tests/unit/test_multidatamodel.py | 8 +- tests/unit/test_mxnet.py | 9 +- tests/unit/test_ntm.py | 8 +- tests/unit/test_object2vec.py | 8 +- tests/unit/test_pca.py | 8 +- tests/unit/test_pipeline_model.py | 8 +- tests/unit/test_processing.py | 26 +- tests/unit/test_pytorch.py | 9 +- tests/unit/test_randomcutforest.py | 8 +- tests/unit/test_rl.py | 9 +- tests/unit/test_session.py | 293 +------- tests/unit/test_sklearn.py | 9 +- tests/unit/test_sparkml_serving.py | 1 + tests/unit/test_timeout.py | 2 + tests/unit/test_transformer.py | 14 +- tests/unit/test_tuner.py | 9 +- tests/unit/test_utils.py | 312 +++++++- tests/unit/test_xgboost.py | 10 +- tests/unit/tuner_test_utils.py | 2 + 75 files changed, 1426 insertions(+), 1466 deletions(-) diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py index 9b59a6579a..9a9db0a307 100644 --- a/src/sagemaker/automl/automl.py +++ b/src/sagemaker/automl/automl.py @@ -19,16 +19,16 @@ from sagemaker import Model, PipelineModel from sagemaker.automl.candidate_estimator import CandidateEstimator -from sagemaker.job import _Job -from sagemaker.session import ( - Session, - AUTO_ML_KMS_KEY_ID_PATH, +from sagemaker.config import ( AUTO_ML_ROLE_ARN_PATH, + 
AUTO_ML_KMS_KEY_ID_PATH, AUTO_ML_VPC_CONFIG_PATH, AUTO_ML_VOLUME_KMS_KEY_ID_PATH, - PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION, + AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH, ) -from sagemaker.utils import name_from_base +from sagemaker.job import _Job +from sagemaker.session import Session +from sagemaker.utils import name_from_base, resolve_value_from_config from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline @@ -207,16 +207,18 @@ def __init__( self._auto_ml_job_desc = None self._best_candidate = None self.sagemaker_session = sagemaker_session or Session() - self.vpc_config = self.sagemaker_session.resolve_value_from_config( - vpc_config, AUTO_ML_VPC_CONFIG_PATH + self.vpc_config = resolve_value_from_config( + vpc_config, AUTO_ML_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session + ) + self.volume_kms_key = resolve_value_from_config( + volume_kms_key, AUTO_ML_VOLUME_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session ) - self.volume_kms_key = self.sagemaker_session.resolve_value_from_config( - volume_kms_key, AUTO_ML_VOLUME_KMS_KEY_ID_PATH + self.output_kms_key = resolve_value_from_config( + output_kms_key, AUTO_ML_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session ) - self.output_kms_key = self.sagemaker_session.resolve_value_from_config( - output_kms_key, AUTO_ML_KMS_KEY_ID_PATH + self.role = resolve_value_from_config( + role, AUTO_ML_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session ) - self.role = self.sagemaker_session.resolve_value_from_config(role, AUTO_ML_ROLE_ARN_PATH) if not self.role: # Originally IAM role was a required parameter. # Now we marked that as Optional because we can fetch it from SageMakerConfig @@ -224,10 +226,11 @@ def __init__( # after fetching the config. 
raise ValueError("IAM role should be provided for creating AutoML jobs.") - self.encrypt_inter_container_traffic = self.sagemaker_session.resolve_value_from_config( + self.encrypt_inter_container_traffic = resolve_value_from_config( direct_input=encrypt_inter_container_traffic, - config_path=PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION, + config_path=AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH, default_value=False, + sagemaker_session=self.sagemaker_session, ) self._check_problem_type_and_job_objective(self.problem_type, self.job_objective) diff --git a/src/sagemaker/automl/candidate_estimator.py b/src/sagemaker/automl/candidate_estimator.py index 80d7be1a2d..3ec5f6995b 100644 --- a/src/sagemaker/automl/candidate_estimator.py +++ b/src/sagemaker/automl/candidate_estimator.py @@ -14,15 +14,14 @@ from __future__ import absolute_import from six import string_types - -from sagemaker.session import ( - Session, +from sagemaker.config import ( TRAINING_JOB_VPC_CONFIG_PATH, TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, - PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION, + TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, ) +from sagemaker.session import Session from sagemaker.job import _Job -from sagemaker.utils import name_from_base +from sagemaker.utils import name_from_base, resolve_value_from_config class CandidateEstimator(object): @@ -106,11 +105,13 @@ def fit( """Logs can only be shown if wait is set to True. 
Please either set wait to True or set logs to False.""" ) - vpc_config = self.sagemaker_session.resolve_value_from_config( - vpc_config, TRAINING_JOB_VPC_CONFIG_PATH + vpc_config = resolve_value_from_config( + vpc_config, TRAINING_JOB_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session ) - volume_kms_key = self.sagemaker_session.resolve_value_from_config( - volume_kms_key, TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH + volume_kms_key = resolve_value_from_config( + volume_kms_key, + TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, + sagemaker_session=self.sagemaker_session, ) self.name = candidate_name or self.name running_jobs = {} @@ -146,10 +147,11 @@ def fit( # Check training_job config not auto_ml_job config because this function calls # training job API - _encrypt_inter_container_traffic = self.sagemaker_session.resolve_value_from_config( + _encrypt_inter_container_traffic = resolve_value_from_config( direct_input=encrypt_inter_container_traffic, - config_path=PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION, + config_path=TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, default_value=False, + sagemaker_session=self.sagemaker_session, ) train_args = self._get_train_args( diff --git a/src/sagemaker/config/__init__.py b/src/sagemaker/config/__init__.py index f2938aaf8c..c4e92545b6 100644 --- a/src/sagemaker/config/__init__.py +++ b/src/sagemaker/config/__init__.py @@ -15,63 +15,119 @@ from __future__ import absolute_import from sagemaker.config.config import SageMakerConfig # noqa: F401 from sagemaker.config.config_schema import ( # noqa: F401 - SECURITY_GROUP_IDS, - SUBNETS, - ENABLE_NETWORK_ISOLATION, + KEY, + TRAINING_JOB, + TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, + TRAINING_JOB_ROLE_ARN_PATH, + TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + TRAINING_JOB_VPC_CONFIG_PATH, + TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH, + TRAINING_JOB_RESOURCE_CONFIG_PATH, + PROCESSING_JOB_INPUTS_PATH, + PROCESSING_JOB, + PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, + 
PROCESSING_JOB_ROLE_ARN_PATH, + PROCESSING_JOB_NETWORK_CONFIG_PATH, + PROCESSING_OUTPUT_CONFIG_PATH, + PROCESSING_JOB_PROCESSING_RESOURCES_PATH, + MONITORING_JOB_ROLE_ARN_PATH, + MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, + MONITORING_JOB_NETWORK_CONFIG_PATH, + MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH, + MONITORING_SCHEDULE, + MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH, + AUTO_ML_ROLE_ARN_PATH, + AUTO_ML_OUTPUT_CONFIG_PATH, + AUTO_ML_JOB_CONFIG_PATH, + AUTO_ML, + COMPILATION_JOB_ROLE_ARN_PATH, + COMPILATION_JOB_OUTPUT_CONFIG_PATH, + COMPILATION_JOB_VPC_CONFIG_PATH, + COMPILATION_JOB, + EDGE_PACKAGING_ROLE_ARN_PATH, + EDGE_PACKAGING_OUTPUT_CONFIG_PATH, + EDGE_PACKAGING_JOB, + TRANSFORM_JOB, + TRANSFORM_JOB_KMS_KEY_ID_PATH, + TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, VOLUME_KMS_KEY_ID, + TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH, + MODEL, + MODEL_EXECUTION_ROLE_ARN_PATH, + MODEL_ENABLE_NETWORK_ISOLATION_PATH, + MODEL_VPC_CONFIG_PATH, + MODEL_PACKAGE_VALIDATION_ROLE_PATH, + VALIDATION_ROLE, + VALIDATION_PROFILES, + MODEL_PACKAGE_VALIDATION_PROFILES_PATH, + ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, KMS_KEY_ID, - ROLE_ARN, - EXECUTION_ROLE_ARN, - CLUSTER_ROLE_ARN, - VPC_CONFIG, - OUTPUT_DATA_CONFIG, - AUTO_ML_JOB_CONFIG, - ASYNC_INFERENCE_CONFIG, - OUTPUT_CONFIG, - PROCESSING_OUTPUT_CONFIG, + ENDPOINT_CONFIG_KMS_KEY_ID_PATH, + ENDPOINT_CONFIG, + ENDPOINT_CONFIG_DATA_CAPTURE_PATH, + ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH, + SAGEMAKER, + FEATURE_GROUP, + TAGS, + FEATURE_GROUP_ROLE_ARN_PATH, + FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, + FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, + PIPELINE_ROLE_ARN_PATH, + TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH, + EDGE_PACKAGING_KMS_KEY_ID_PATH, + ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH, + PROCESSING_JOB_KMS_KEY_ID_PATH, + PROCESSING_JOB_SECURITY_GROUP_IDS_PATH, + PROCESSING_JOB_SUBNETS_PATH, + PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH, + PIPELINE_TAGS_PATH, + TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, + 
TRAINING_JOB_SECURITY_GROUP_IDS_PATH, + TRAINING_JOB_SUBNETS_PATH, + TRAINING_JOB_KMS_KEY_ID_PATH, + FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH, + FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH, + AUTO_ML_KMS_KEY_ID_PATH, + AUTO_ML_VPC_CONFIG_PATH, + AUTO_ML_VOLUME_KMS_KEY_ID_PATH, + AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH, + ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH, + MONITORING_SCHEDULE_CONFIG, + MONITORING_JOB_DEFINITION, + MONITORING_OUTPUT_CONFIG, + MONITORING_RESOURCES, CLUSTER_CONFIG, NETWORK_CONFIG, - CORE_DUMP_CONFIG, + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, + ENABLE_NETWORK_ISOLATION, + VPC_CONFIG, + SUBNETS, + SECURITY_GROUP_IDS, + ROLE_ARN, + VALUE, + OUTPUT_CONFIG, DATA_CAPTURE_CONFIG, - MONITORING_OUTPUT_CONFIG, - RESOURCE_CONFIG, - SCHEMA_VERSION, - DATASET_DEFINITION, - ATHENA_DATASET_DEFINITION, - REDSHIFT_DATASET_DEFINITION, - MONITORING_JOB_DEFINITION, - SAGEMAKER, - PYTHON_SDK, - MODULES, - OFFLINE_STORE_CONFIG, - ONLINE_STORE_CONFIG, - S3_STORAGE_CONFIG, + PRODUCTION_VARIANTS, + AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, + OUTPUT_DATA_CONFIG, + MODEL_PACKAGE, + VALIDATION_SPECIFICATION, TRANSFORM_JOB_DEFINITION, - MONITORING_SCHEDULE_CONFIG, - MONITORING_RESOURCES, - PROCESSING_RESOURCES, - PRODUCTION_VARIANTS, TRANSFORM_OUTPUT, TRANSFORM_RESOURCES, - VALIDATION_ROLE, - VALIDATION_SPECIFICATION, - VALIDATION_PROFILES, + OFFLINE_STORE_CONFIG, + S3_STORAGE_CONFIG, + ONLINE_STORE_CONFIG, PROCESSING_INPUTS, - FEATURE_GROUP, - EDGE_PACKAGING_JOB, - TRAINING_JOB, - PROCESSING_JOB, - MODEL_PACKAGE, - MODEL, - MONITORING_SCHEDULE, - ENDPOINT_CONFIG, - AUTO_ML, - COMPILATION_JOB, - PIPELINE, - TRANSFORM_JOB, - ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, - TAGS, - KEY, - VALUE, + DATASET_DEFINITION, + ATHENA_DATASET_DEFINITION, + REDSHIFT_DATASET_DEFINITION, + CLUSTER_ROLE_ARN, + PROCESSING_OUTPUT_CONFIG, + PROCESSING_RESOURCES, + RESOURCE_CONFIG, + EXECUTION_ROLE_ARN, + ASYNC_INFERENCE_CONFIG, ) diff --git a/src/sagemaker/config/config.py 
b/src/sagemaker/config/config.py index b93e192c08..66640c7f1a 100644 --- a/src/sagemaker/config/config.py +++ b/src/sagemaker/config/config.py @@ -145,6 +145,7 @@ def config(self) -> dict: def _load_config_files(file_paths: List[str], s3_resource_for_config) -> dict: + """Placeholder docstring""" merged_config = {} for file_path in file_paths: config_from_file = {} @@ -169,6 +170,7 @@ def _load_config_files(file_paths: List[str], s3_resource_for_config) -> dict: def _load_config_from_file(file_path: str) -> dict: + """Placeholder docstring""" inferred_file_path = file_path if os.path.isdir(file_path): inferred_file_path = os.path.join(file_path, "config.yaml") @@ -182,6 +184,7 @@ def _load_config_from_file(file_path: str) -> dict: def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict: + """Placeholder docstring""" if not s3_resource_for_config: raise RuntimeError("Please provide a S3 client for loading the config") logger.debug("Fetching configuration file from the S3 URI: %s", s3_uri) @@ -194,6 +197,7 @@ def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict: def _get_inferred_s3_uri(s3_uri, s3_resource_for_config): + """Placeholder docstring""" parsed_url = urlparse(s3_uri) bucket, key_prefix = parsed_url.netloc, parsed_url.path.lstrip("/") try: diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index e23bbcafd9..ae4104020f 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
-"""This module contains/maintains the schema of the Config f`i`le.""" +"""This module contains/maintains the schema of the Config file.""" from __future__ import absolute_import, print_function SECURITY_GROUP_IDS = "SecurityGroupIds" @@ -80,6 +80,172 @@ ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = "EnableInterContainerTrafficEncryption" +def _simple_path(*args: str): + """Appends an arbitrary number of strings to use as path constants""" + return ".".join(args) + + +COMPILATION_JOB_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, COMPILATION_JOB, VPC_CONFIG) +COMPILATION_JOB_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, COMPILATION_JOB, OUTPUT_CONFIG, KMS_KEY_ID +) +COMPILATION_JOB_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, COMPILATION_JOB, OUTPUT_CONFIG) +COMPILATION_JOB_ROLE_ARN_PATH = _simple_path(SAGEMAKER, COMPILATION_JOB, ROLE_ARN) +TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH = _simple_path( + SAGEMAKER, TRAINING_JOB, ENABLE_NETWORK_ISOLATION +) +TRAINING_JOB_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, OUTPUT_DATA_CONFIG, KMS_KEY_ID) +TRAINING_JOB_RESOURCE_CONFIG_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, RESOURCE_CONFIG) +TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, OUTPUT_DATA_CONFIG) +TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, TRAINING_JOB, RESOURCE_CONFIG, VOLUME_KMS_KEY_ID +) +TRAINING_JOB_ROLE_ARN_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, ROLE_ARN) +TRAINING_JOB_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, VPC_CONFIG) +TRAINING_JOB_SECURITY_GROUP_IDS_PATH = _simple_path( + TRAINING_JOB_VPC_CONFIG_PATH, SECURITY_GROUP_IDS +) +TRAINING_JOB_SUBNETS_PATH = _simple_path(TRAINING_JOB_VPC_CONFIG_PATH, SUBNETS) +EDGE_PACKAGING_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, EDGE_PACKAGING_JOB, OUTPUT_CONFIG, KMS_KEY_ID +) +EDGE_PACKAGING_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, OUTPUT_CONFIG) +EDGE_PACKAGING_ROLE_ARN_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, 
ROLE_ARN) +ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, ENDPOINT_CONFIG, DATA_CAPTURE_CONFIG, KMS_KEY_ID +) +ENDPOINT_CONFIG_DATA_CAPTURE_PATH = _simple_path(SAGEMAKER, ENDPOINT_CONFIG, DATA_CAPTURE_CONFIG) +ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH = _simple_path( + SAGEMAKER, ENDPOINT_CONFIG, ASYNC_INFERENCE_CONFIG +) +ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH = _simple_path( + SAGEMAKER, ENDPOINT_CONFIG, PRODUCTION_VARIANTS +) +ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, ENDPOINT_CONFIG, ASYNC_INFERENCE_CONFIG, OUTPUT_CONFIG, KMS_KEY_ID +) +ENDPOINT_CONFIG_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, ENDPOINT_CONFIG, KMS_KEY_ID) +FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH = _simple_path(SAGEMAKER, FEATURE_GROUP, ONLINE_STORE_CONFIG) +FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH = _simple_path( + SAGEMAKER, FEATURE_GROUP, OFFLINE_STORE_CONFIG +) +FEATURE_GROUP_ROLE_ARN_PATH = _simple_path(SAGEMAKER, FEATURE_GROUP, ROLE_ARN) +FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH = _simple_path( + FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, S3_STORAGE_CONFIG, KMS_KEY_ID +) +FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH = _simple_path( + FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, SECURITY_CONFIG, KMS_KEY_ID +) +AUTO_ML_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML, OUTPUT_DATA_CONFIG) +AUTO_ML_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, AUTO_ML, OUTPUT_DATA_CONFIG, KMS_KEY_ID) +AUTO_ML_VOLUME_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VOLUME_KMS_KEY_ID +) +AUTO_ML_ROLE_ARN_PATH = _simple_path(SAGEMAKER, AUTO_ML, ROLE_ARN) +AUTO_ML_VPC_CONFIG_PATH = _simple_path( + SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VPC_CONFIG +) +AUTO_ML_JOB_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG) +MONITORING_JOB_DEFINITION_PREFIX = _simple_path( + SAGEMAKER, MONITORING_SCHEDULE, MONITORING_SCHEDULE_CONFIG, MONITORING_JOB_DEFINITION +) +MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH = 
_simple_path( + MONITORING_JOB_DEFINITION_PREFIX, MONITORING_OUTPUT_CONFIG, KMS_KEY_ID +) +MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path( + MONITORING_JOB_DEFINITION_PREFIX, MONITORING_RESOURCES, CLUSTER_CONFIG, VOLUME_KMS_KEY_ID +) +MONITORING_JOB_NETWORK_CONFIG_PATH = _simple_path(MONITORING_JOB_DEFINITION_PREFIX, NETWORK_CONFIG) +MONITORING_JOB_ENABLE_NETWORK_ISOLATION_PATH = _simple_path( + MONITORING_JOB_DEFINITION_PREFIX, NETWORK_CONFIG, ENABLE_NETWORK_ISOLATION +) +MONITORING_JOB_VPC_CONFIG_PATH = _simple_path( + MONITORING_JOB_DEFINITION_PREFIX, NETWORK_CONFIG, VPC_CONFIG +) +MONITORING_JOB_SECURITY_GROUP_IDS_PATH = _simple_path( + MONITORING_JOB_VPC_CONFIG_PATH, SECURITY_GROUP_IDS +) +MONITORING_JOB_SUBNETS_PATH = _simple_path(MONITORING_JOB_VPC_CONFIG_PATH, SUBNETS) +MONITORING_JOB_ROLE_ARN_PATH = _simple_path(MONITORING_JOB_DEFINITION_PREFIX, ROLE_ARN) +PIPELINE_ROLE_ARN_PATH = _simple_path(SAGEMAKER, PIPELINE, ROLE_ARN) +PIPELINE_TAGS_PATH = _simple_path(SAGEMAKER, PIPELINE, TAGS) +TRANSFORM_OUTPUT_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, TRANSFORM_JOB, TRANSFORM_OUTPUT, KMS_KEY_ID +) +TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, TRANSFORM_JOB, TRANSFORM_RESOURCES, VOLUME_KMS_KEY_ID +) +TRANSFORM_JOB_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, TRANSFORM_JOB, DATA_CAPTURE_CONFIG, KMS_KEY_ID +) +TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, TRANSFORM_JOB, TRANSFORM_RESOURCES, VOLUME_KMS_KEY_ID +) +MODEL_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, MODEL, VPC_CONFIG) +MODEL_ENABLE_NETWORK_ISOLATION_PATH = _simple_path(SAGEMAKER, MODEL, ENABLE_NETWORK_ISOLATION) +MODEL_EXECUTION_ROLE_ARN_PATH = _simple_path(SAGEMAKER, MODEL, EXECUTION_ROLE_ARN) +PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH = _simple_path( + SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, ENABLE_NETWORK_ISOLATION +) +PROCESSING_JOB_INPUTS_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, PROCESSING_INPUTS) 
+REDSHIFT_DATASET_DEFINITION_KMS_KEY_ID_PATH = _simple_path( + DATASET_DEFINITION, REDSHIFT_DATASET_DEFINITION, KMS_KEY_ID +) +ATHENA_DATASET_DEFINITION_KMS_KEY_ID_PATH = _simple_path( + DATASET_DEFINITION, ATHENA_DATASET_DEFINITION, KMS_KEY_ID +) +REDSHIFT_DATASET_DEFINITION_CLUSTER_ROLE_ARN_PATH = _simple_path( + DATASET_DEFINITION, REDSHIFT_DATASET_DEFINITION, CLUSTER_ROLE_ARN +) +PROCESSING_JOB_NETWORK_CONFIG_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG) +PROCESSING_JOB_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, VPC_CONFIG) +PROCESSING_JOB_SUBNETS_PATH = _simple_path(PROCESSING_JOB_VPC_CONFIG_PATH, SUBNETS) +PROCESSING_JOB_SECURITY_GROUP_IDS_PATH = _simple_path( + PROCESSING_JOB_VPC_CONFIG_PATH, SECURITY_GROUP_IDS +) +PROCESSING_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, PROCESSING_OUTPUT_CONFIG) +PROCESSING_JOB_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, PROCESSING_JOB, PROCESSING_OUTPUT_CONFIG, KMS_KEY_ID +) +PROCESSING_JOB_PROCESSING_RESOURCES_PATH = _simple_path( + SAGEMAKER, PROCESSING_JOB, PROCESSING_RESOURCES +) +PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, PROCESSING_JOB, PROCESSING_RESOURCES, CLUSTER_CONFIG, VOLUME_KMS_KEY_ID +) +PROCESSING_JOB_ROLE_ARN_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, ROLE_ARN) +MODEL_PACKAGE_VALIDATION_ROLE_PATH = _simple_path( + SAGEMAKER, MODEL_PACKAGE, VALIDATION_SPECIFICATION, VALIDATION_ROLE +) +MODEL_PACKAGE_VALIDATION_PROFILES_PATH = _simple_path( + SAGEMAKER, MODEL_PACKAGE, VALIDATION_SPECIFICATION, VALIDATION_PROFILES +) + +# Paths for reference elsewhere in the SDK. 
+# Names include the schema version since the paths could change with other schema versions +MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path( + SAGEMAKER, + MONITORING_SCHEDULE, + MONITORING_SCHEDULE_CONFIG, + MONITORING_JOB_DEFINITION, + NETWORK_CONFIG, + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, +) +AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path( + SAGEMAKER, + AUTO_ML, + AUTO_ML_JOB_CONFIG, + SECURITY_CONFIG, + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, +) +PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path( + SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION +) +TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path( + SAGEMAKER, TRAINING_JOB, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION +) + + SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = { "$schema": "https://json-schema.org/draft/2020-12/schema", TYPE: OBJECT, diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index ba2fa6a91d..51fb3c99b5 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -29,6 +29,15 @@ import sagemaker from sagemaker import git_utils, image_uris, vpc_utils from sagemaker.analytics import TrainingJobAnalytics +from sagemaker.config import ( + TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, + TRAINING_JOB_SECURITY_GROUP_IDS_PATH, + TRAINING_JOB_SUBNETS_PATH, + TRAINING_JOB_KMS_KEY_ID_PATH, + TRAINING_JOB_ROLE_ARN_PATH, + TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, +) from sagemaker.debugger import ( # noqa: F401 # pylint: disable=unused-import DEBUGGER_FLAG, DebuggerHookConfig, @@ -72,16 +81,7 @@ ) from sagemaker.predictor import Predictor from sagemaker.s3 import S3Uploader, parse_s3_url -from sagemaker.session import ( - Session, - TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, - TRAINING_JOB_ROLE_ARN_PATH, - TRAINING_JOB_SECURITY_GROUP_IDS_PATH, - TRAINING_JOB_SUBNETS_PATH, - TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, - TRAINING_JOB_KMS_KEY_ID_PATH, - 
PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION, -) +from sagemaker.session import Session from sagemaker.transformer import Transformer from sagemaker.utils import ( base_from_name, @@ -91,6 +91,7 @@ name_from_base, to_string, check_and_get_run_experiment_config, + resolve_value_from_config, ) from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable @@ -586,8 +587,8 @@ def __init__( self.deploy_instance_type = None self._compiled_models = {} - self.role = self.sagemaker_session.resolve_value_from_config( - role, TRAINING_JOB_ROLE_ARN_PATH + self.role = resolve_value_from_config( + role, TRAINING_JOB_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session ) if not self.role: # Originally IAM role was a required parameter. @@ -595,19 +596,23 @@ def __init__( # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. raise ValueError("IAM role should be provided for creating estimators.") - self.output_kms_key = self.sagemaker_session.resolve_value_from_config( - output_kms_key, TRAINING_JOB_KMS_KEY_ID_PATH + self.output_kms_key = resolve_value_from_config( + output_kms_key, TRAINING_JOB_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session ) - self.volume_kms_key = self.sagemaker_session.resolve_value_from_config( - volume_kms_key, TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH + self.volume_kms_key = resolve_value_from_config( + volume_kms_key, + TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, + sagemaker_session=self.sagemaker_session, ) # VPC configurations - self.subnets = self.sagemaker_session.resolve_value_from_config( - subnets, TRAINING_JOB_SUBNETS_PATH + self.subnets = resolve_value_from_config( + subnets, TRAINING_JOB_SUBNETS_PATH, sagemaker_session=self.sagemaker_session ) - self.security_group_ids = self.sagemaker_session.resolve_value_from_config( - security_group_ids, TRAINING_JOB_SECURITY_GROUP_IDS_PATH + self.security_group_ids = resolve_value_from_config( + 
security_group_ids, + TRAINING_JOB_SECURITY_GROUP_IDS_PATH, + sagemaker_session=self.sagemaker_session, ) # training image configs @@ -616,10 +621,11 @@ def __init__( training_repository_credentials_provider_arn ) - self.encrypt_inter_container_traffic = self.sagemaker_session.resolve_value_from_config( + self.encrypt_inter_container_traffic = resolve_value_from_config( direct_input=encrypt_inter_container_traffic, - config_path=PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION, + config_path=TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, default_value=False, + sagemaker_session=self.sagemaker_session, ) self.use_spot_instances = use_spot_instances @@ -636,10 +642,11 @@ def __init__( self.enable_sagemaker_metrics = enable_sagemaker_metrics - self._enable_network_isolation = self.sagemaker_session.resolve_value_from_config( + self._enable_network_isolation = resolve_value_from_config( direct_input=enable_network_isolation, config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, default_value=False, + sagemaker_session=self.sagemaker_session, ) self.profiler_config = profiler_config diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index 27ed96b603..68e6b9dfef 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -41,12 +41,12 @@ from botocore.config import Config from pathos.multiprocessing import ProcessingPool -from sagemaker.session import ( - Session, +from sagemaker.config import ( FEATURE_GROUP_ROLE_ARN_PATH, FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH, FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH, ) +from sagemaker.session import Session from sagemaker.feature_store.feature_definition import ( FeatureDefinition, FeatureTypeEnum, @@ -61,6 +61,7 @@ FeatureParameter, TableFormatEnum, ) +from sagemaker.utils import resolve_value_from_config logger = logging.getLogger(__name__) @@ -557,14 +558,18 @@ def create( Returns: Response dict from service. 
""" - role_arn = self.sagemaker_session.resolve_value_from_config( - role_arn, FEATURE_GROUP_ROLE_ARN_PATH + role_arn = resolve_value_from_config( + role_arn, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session ) - offline_store_kms_key_id = self.sagemaker_session.resolve_value_from_config( - offline_store_kms_key_id, FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH + offline_store_kms_key_id = resolve_value_from_config( + offline_store_kms_key_id, + FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH, + sagemaker_session=self.sagemaker_session, ) - online_store_kms_key_id = self.sagemaker_session.resolve_value_from_config( - online_store_kms_key_id, FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH + online_store_kms_key_id = resolve_value_from_config( + online_store_kms_key_id, + FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH, + sagemaker_session=self.sagemaker_session, ) if not role_arn: # Originally IAM role was a required parameter. diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 5124c1a414..0357c76f64 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -29,8 +29,7 @@ utils, git_utils, ) -from sagemaker.session import ( - Session, +from sagemaker.config import ( COMPILATION_JOB_ROLE_ARN_PATH, EDGE_PACKAGING_KMS_KEY_ID_PATH, EDGE_PACKAGING_ROLE_ARN_PATH, @@ -39,6 +38,7 @@ MODEL_EXECUTION_ROLE_ARN_PATH, ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH, ) +from sagemaker.session import Session from sagemaker.model_metrics import ModelMetrics from sagemaker.deprecations import removed_kwargs from sagemaker.drift_check_baselines import DriftCheckBaselines @@ -51,6 +51,7 @@ unique_name_from_base, update_container_with_inference_params, to_string, + resolve_value_from_config, ) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.predictor_async import AsyncPredictor @@ -283,15 +284,11 @@ def __init__( self.name = name self._base_name = None self.sagemaker_session = sagemaker_session - self.role = ( - 
self.sagemaker_session.resolve_value_from_config(role, MODEL_EXECUTION_ROLE_ARN_PATH) - if sagemaker_session - else role + self.role = resolve_value_from_config( + role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session ) - self.vpc_config = ( - self.sagemaker_session.resolve_value_from_config(vpc_config, MODEL_VPC_CONFIG_PATH) - if sagemaker_session - else vpc_config + self.vpc_config = resolve_value_from_config( + vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session ) self.endpoint_name = None self._is_compiled_model = False @@ -299,12 +296,10 @@ def __init__( self._is_edge_packaged_model = False self.inference_recommender_job_results = None self.inference_recommendations = None - self._enable_network_isolation = ( - self.sagemaker_session.resolve_value_from_config( - enable_network_isolation, MODEL_ENABLE_NETWORK_ISOLATION_PATH - ) - if sagemaker_session - else enable_network_isolation + self._enable_network_isolation = resolve_value_from_config( + enable_network_isolation, + MODEL_ENABLE_NETWORK_ISOLATION_PATH, + sagemaker_session=self.sagemaker_session, ) if self._enable_network_isolation is None: self._enable_network_isolation = False @@ -701,14 +696,16 @@ def _create_sagemaker_model( self._init_sagemaker_session_if_does_not_exist(instance_type) # Depending on the instance type, a local session (or) a session is initialized. 
- self.role = self.sagemaker_session.resolve_value_from_config( - self.role, MODEL_EXECUTION_ROLE_ARN_PATH + self.role = resolve_value_from_config( + self.role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session ) - self.vpc_config = self.sagemaker_session.resolve_value_from_config( - self.vpc_config, MODEL_VPC_CONFIG_PATH + self.vpc_config = resolve_value_from_config( + self.vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session ) - self._enable_network_isolation = self.sagemaker_session.resolve_value_from_config( - self._enable_network_isolation, MODEL_ENABLE_NETWORK_ISOLATION_PATH + self._enable_network_isolation = resolve_value_from_config( + self._enable_network_isolation, + MODEL_ENABLE_NETWORK_ISOLATION_PATH, + sagemaker_session=self.sagemaker_session, ) create_model_args = dict( name=self.name, @@ -906,10 +903,12 @@ def package_for_edge( if job_name is None: job_name = f"packaging{self._compilation_job_name[11:]}" self._init_sagemaker_session_if_does_not_exist(None) - s3_kms_key = self.sagemaker_session.resolve_value_from_config( - s3_kms_key, EDGE_PACKAGING_KMS_KEY_ID_PATH + s3_kms_key = resolve_value_from_config( + s3_kms_key, EDGE_PACKAGING_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session + ) + role = resolve_value_from_config( + role, EDGE_PACKAGING_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session ) - role = self.sagemaker_session.resolve_value_from_config(role, EDGE_PACKAGING_ROLE_ARN_PATH) if role is not None: role = self.sagemaker_session.expand_role(role) config = self._edge_packaging_job_config( @@ -1014,7 +1013,9 @@ def compile( framework_version = framework_version or self._get_framework_version() self._init_sagemaker_session_if_does_not_exist(target_instance_family) - role = self.sagemaker_session.resolve_value_from_config(role, COMPILATION_JOB_ROLE_ARN_PATH) + role = resolve_value_from_config( + role, COMPILATION_JOB_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session + ) if not 
role: # Originally IAM role was a required parameter. # Now we marked that as Optional because we can fetch it from SageMakerConfig @@ -1174,14 +1175,16 @@ def deploy( self._init_sagemaker_session_if_does_not_exist(instance_type) # Depending on the instance type, a local session (or) a session is initialized. - self.role = self.sagemaker_session.resolve_value_from_config( - self.role, MODEL_EXECUTION_ROLE_ARN_PATH + self.role = resolve_value_from_config( + self.role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session ) - self.vpc_config = self.sagemaker_session.resolve_value_from_config( - self.vpc_config, MODEL_VPC_CONFIG_PATH + self.vpc_config = resolve_value_from_config( + self.vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session ) - self._enable_network_isolation = self.sagemaker_session.resolve_value_from_config( - self._enable_network_isolation, MODEL_ENABLE_NETWORK_ISOLATION_PATH + self._enable_network_isolation = resolve_value_from_config( + self._enable_network_isolation, + MODEL_ENABLE_NETWORK_ISOLATION_PATH, + sagemaker_session=self.sagemaker_session, ) tags = add_jumpstart_tags( @@ -1269,9 +1272,10 @@ def deploy( async_inference_config = self._build_default_async_inference_config( async_inference_config ) - async_inference_config.kms_key_id = self.sagemaker_session.resolve_value_from_config( + async_inference_config.kms_key_id = resolve_value_from_config( async_inference_config.kms_key_id, ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH, + sagemaker_session=self.sagemaker_session, ) async_inference_config_dict = async_inference_config._to_request_dict() diff --git a/src/sagemaker/model_monitor/data_capture_config.py b/src/sagemaker/model_monitor/data_capture_config.py index d8ee8e21cb..aa11d41aad 100644 --- a/src/sagemaker/model_monitor/data_capture_config.py +++ b/src/sagemaker/model_monitor/data_capture_config.py @@ -18,7 +18,9 @@ from __future__ import print_function, absolute_import from sagemaker import s3 -from 
sagemaker.session import Session, ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH +from sagemaker.config import ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH +from sagemaker.session import Session +from sagemaker.utils import resolve_value_from_config _MODEL_MONITOR_S3_PATH = "model-monitor" _DATA_CAPTURE_S3_PATH = "data-capture" @@ -75,8 +77,10 @@ def __init__( _DATA_CAPTURE_S3_PATH, ) - self.kms_key_id = sagemaker_session.resolve_value_from_config( - kms_key_id, ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH + self.kms_key_id = resolve_value_from_config( + kms_key_id, + ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH, + sagemaker_session=sagemaker_session, ) self.capture_options = capture_options or ["REQUEST", "RESPONSE"] self.csv_content_types = csv_content_types or ["text/csv"] diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 8988285eeb..ca1f98714e 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -35,6 +35,13 @@ SAGEMAKER, MONITORING_SCHEDULE, TAGS, + MONITORING_JOB_SUBNETS_PATH, + MONITORING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH, + MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, + MONITORING_JOB_SECURITY_GROUP_IDS_PATH, + MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH, + MONITORING_JOB_ROLE_ARN_PATH, ) from sagemaker.exceptions import UnexpectedStatusException from sagemaker.model_monitor.monitoring_files import Constraints, ConstraintViolations, Statistics @@ -47,17 +54,13 @@ from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat from sagemaker.network import NetworkConfig from sagemaker.processing import Processor, ProcessingInput, ProcessingJob, ProcessingOutput -from sagemaker.session import ( - Session, - MONITORING_JOB_SUBNETS_PATH, - MONITORING_JOB_ENABLE_NETWORK_ISOLATION_PATH, - PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION, - 
MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, - MONITORING_JOB_SECURITY_GROUP_IDS_PATH, - MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH, - MONITORING_JOB_ROLE_ARN_PATH, +from sagemaker.session import Session +from sagemaker.utils import ( + name_from_base, + retries, + resolve_value_from_config, + resolve_class_attribute_from_config, ) -from sagemaker.utils import name_from_base, retries DEFAULT_REPOSITORY_NAME = "sagemaker-model-monitor-analyzer" @@ -177,8 +180,8 @@ def __init__( self.latest_baselining_job_name = None self.monitoring_schedule_name = None self.job_definition_name = None - self.role = self.sagemaker_session.resolve_value_from_config( - role, MONITORING_JOB_ROLE_ARN_PATH + self.role = resolve_value_from_config( + role, MONITORING_JOB_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session ) if not self.role: # Originally IAM role was a required parameter. @@ -186,35 +189,43 @@ def __init__( # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. 
raise ValueError("IAM role should be provided for creating Monitoring Schedule.") - self.volume_kms_key = self.sagemaker_session.resolve_value_from_config( - volume_kms_key, MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH + self.volume_kms_key = resolve_value_from_config( + volume_kms_key, + MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, + sagemaker_session=self.sagemaker_session, ) - self.output_kms_key = self.sagemaker_session.resolve_value_from_config( - output_kms_key, MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH + self.output_kms_key = resolve_value_from_config( + output_kms_key, + MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH, + sagemaker_session=self.sagemaker_session, ) - self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + self.network_config = resolve_class_attribute_from_config( NetworkConfig, network_config, "subnets", MONITORING_JOB_SUBNETS_PATH, + sagemaker_session=self.sagemaker_session, ) - self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + self.network_config = resolve_class_attribute_from_config( NetworkConfig, self.network_config, "security_group_ids", MONITORING_JOB_SECURITY_GROUP_IDS_PATH, + sagemaker_session=self.sagemaker_session, ) - self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + self.network_config = resolve_class_attribute_from_config( NetworkConfig, self.network_config, "enable_network_isolation", MONITORING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + sagemaker_session=self.sagemaker_session, ) - self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + self.network_config = resolve_class_attribute_from_config( NetworkConfig, self.network_config, "encrypt_inter_container_traffic", - PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION, + MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH, + sagemaker_session=self.sagemaker_session, ) def run_baseline( @@ -1437,7 +1448,7 @@ def _create_monitoring_schedule_from_job_definition( ) # Not using value from 
sagemaker - # config key PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION here + # config key MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH here # because no MonitoringJobDefinition is set for this call self.sagemaker_session.sagemaker_client.create_monitoring_schedule( @@ -1507,7 +1518,7 @@ def _update_monitoring_schedule(self, job_definition_name, schedule_cron_express } # Not using value from sagemaker - # config key PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION here + # config key MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH here # because no MonitoringJobDefinition is set for this call self.sagemaker_session.sagemaker_client.update_monitoring_schedule( diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 81ef07b080..5008558c63 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -17,19 +17,19 @@ import sagemaker from sagemaker import ModelMetrics, Model -from sagemaker.drift_check_baselines import DriftCheckBaselines -from sagemaker.metadata_properties import MetadataProperties -from sagemaker.session import ( - Session, +from sagemaker.config import ( ENDPOINT_CONFIG_KMS_KEY_ID_PATH, MODEL_VPC_CONFIG_PATH, MODEL_ENABLE_NETWORK_ISOLATION_PATH, MODEL_EXECUTION_ROLE_ARN_PATH, ) - +from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.metadata_properties import MetadataProperties +from sagemaker.session import Session from sagemaker.utils import ( name_from_image, update_container_with_inference_params, + resolve_value_from_config, ) from sagemaker.transformer import Transformer from sagemaker.workflow.entities import PipelineVariable @@ -91,27 +91,18 @@ def __init__( self.name = name self.sagemaker_session = sagemaker_session self.endpoint_name = None - self.role = ( - self.sagemaker_session.resolve_value_from_config(role, MODEL_EXECUTION_ROLE_ARN_PATH) - if self.sagemaker_session - else role + self.role = resolve_value_from_config( + role, 
MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session ) - self.vpc_config = ( - self.sagemaker_session.resolve_value_from_config(vpc_config, MODEL_VPC_CONFIG_PATH) - if self.sagemaker_session - else vpc_config + self.vpc_config = resolve_value_from_config( + vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session + ) + self.enable_network_isolation = resolve_value_from_config( + direct_input=enable_network_isolation, + config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH, + default_value=False, + sagemaker_session=self.sagemaker_session, ) - - if self.sagemaker_session is not None: - self.enable_network_isolation = self.sagemaker_session.resolve_value_from_config( - direct_input=enable_network_isolation, - config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH, - default_value=False, - ) - else: - self.enable_network_isolation = ( - False if enable_network_isolation is None else enable_network_isolation - ) if not self.role: # Originally IAM role was a required parameter. 
@@ -244,8 +235,8 @@ def deploy( container_startup_health_check_timeout=container_startup_health_check_timeout, ) self.endpoint_name = endpoint_name or self.name - kms_key = self.sagemaker_session.resolve_value_from_config( - kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH + kms_key = resolve_value_from_config( + kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session ) data_capture_config_dict = None diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 4ad59d3fbd..6dba6ef6fa 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -30,6 +30,15 @@ from six.moves.urllib.parse import urlparse from six.moves.urllib.request import url2pathname from sagemaker import s3 +from sagemaker.config import ( + PROCESSING_JOB_KMS_KEY_ID_PATH, + PROCESSING_JOB_SECURITY_GROUP_IDS_PATH, + PROCESSING_JOB_SUBNETS_PATH, + PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH, + PROCESSING_JOB_ROLE_ARN_PATH, + PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, +) from sagemaker.job import _Job from sagemaker.local import LocalSession from sagemaker.network import NetworkConfig @@ -38,17 +47,10 @@ get_config_value, name_from_base, check_and_get_run_experiment_config, + resolve_value_from_config, + resolve_class_attribute_from_config, ) -from sagemaker.session import ( - Session, - PROCESSING_JOB_KMS_KEY_ID_PATH, - PROCESSING_JOB_SECURITY_GROUP_IDS_PATH, - PROCESSING_JOB_SUBNETS_PATH, - PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH, - PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH, - PROCESSING_JOB_ROLE_ARN_PATH, - PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION, -) +from sagemaker.session import Session from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.functions import Join from sagemaker.workflow.pipeline_context import runnable_by_pipeline @@ -148,38 +150,44 @@ def __init__( sagemaker_session = LocalSession(disable_local_code=True) self.sagemaker_session = sagemaker_session 
or Session() - self.output_kms_key = self.sagemaker_session.resolve_value_from_config( - output_kms_key, PROCESSING_JOB_KMS_KEY_ID_PATH + self.output_kms_key = resolve_value_from_config( + output_kms_key, PROCESSING_JOB_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session ) - self.volume_kms_key = self.sagemaker_session.resolve_value_from_config( - volume_kms_key, PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH + self.volume_kms_key = resolve_value_from_config( + volume_kms_key, + PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH, + sagemaker_session=self.sagemaker_session, ) - self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + self.network_config = resolve_class_attribute_from_config( NetworkConfig, network_config, "subnets", PROCESSING_JOB_SUBNETS_PATH, + sagemaker_session=self.sagemaker_session, ) - self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + self.network_config = resolve_class_attribute_from_config( NetworkConfig, self.network_config, "security_group_ids", PROCESSING_JOB_SECURITY_GROUP_IDS_PATH, + sagemaker_session=self.sagemaker_session, ) - self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + self.network_config = resolve_class_attribute_from_config( NetworkConfig, self.network_config, "enable_network_isolation", PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + sagemaker_session=self.sagemaker_session, ) - self.network_config = self.sagemaker_session.resolve_class_attribute_from_config( + self.network_config = resolve_class_attribute_from_config( NetworkConfig, self.network_config, "encrypt_inter_container_traffic", - PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION, + PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, + sagemaker_session=self.sagemaker_session, ) - self.role = self.sagemaker_session.resolve_value_from_config( - role, PROCESSING_JOB_ROLE_ARN_PATH + self.role = resolve_value_from_config( + role, PROCESSING_JOB_ROLE_ARN_PATH, 
sagemaker_session=self.sagemaker_session ) if not self.role: # Originally IAM role was a required parameter. diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 8518593b95..401f5c15de 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -14,7 +14,6 @@ from __future__ import absolute_import, print_function import copy -import inspect import json import logging import os @@ -36,61 +35,65 @@ import sagemaker.logs from sagemaker import vpc_utils from sagemaker._studio import _append_project_tags -from sagemaker.config import ( # noqa: F401 - SageMakerConfig, - SAGEMAKER, +from sagemaker.config import SageMakerConfig # noqa: F401 +from sagemaker.config import ( + KEY, TRAINING_JOB, - ENABLE_NETWORK_ISOLATION, - KMS_KEY_ID, - RESOURCE_CONFIG, - VOLUME_KMS_KEY_ID, - ROLE_ARN, - VPC_CONFIG, - SECURITY_GROUP_IDS, - SUBNETS, - EDGE_PACKAGING_JOB, - OUTPUT_CONFIG, - FEATURE_GROUP, - OFFLINE_STORE_CONFIG, - ONLINE_STORE_CONFIG, - AUTO_ML, - AUTO_ML_JOB_CONFIG, - SECURITY_CONFIG, - OUTPUT_DATA_CONFIG, + TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, + TRAINING_JOB_ROLE_ARN_PATH, + TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + TRAINING_JOB_VPC_CONFIG_PATH, + TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH, + TRAINING_JOB_RESOURCE_CONFIG_PATH, + PROCESSING_JOB_INPUTS_PATH, + PROCESSING_JOB, + PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, + PROCESSING_JOB_ROLE_ARN_PATH, + PROCESSING_JOB_NETWORK_CONFIG_PATH, + PROCESSING_OUTPUT_CONFIG_PATH, + PROCESSING_JOB_PROCESSING_RESOURCES_PATH, + MONITORING_JOB_ROLE_ARN_PATH, + MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, + MONITORING_JOB_NETWORK_CONFIG_PATH, + MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH, MONITORING_SCHEDULE, - MONITORING_SCHEDULE_CONFIG, - MONITORING_JOB_DEFINITION, - MONITORING_OUTPUT_CONFIG, - MONITORING_RESOURCES, - CLUSTER_CONFIG, - NETWORK_CONFIG, + MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH, + AUTO_ML_ROLE_ARN_PATH, + AUTO_ML_OUTPUT_CONFIG_PATH, + AUTO_ML_JOB_CONFIG_PATH, + AUTO_ML, + 
COMPILATION_JOB_ROLE_ARN_PATH, + COMPILATION_JOB_OUTPUT_CONFIG_PATH, + COMPILATION_JOB_VPC_CONFIG_PATH, + COMPILATION_JOB, + EDGE_PACKAGING_ROLE_ARN_PATH, + EDGE_PACKAGING_OUTPUT_CONFIG_PATH, + EDGE_PACKAGING_JOB, TRANSFORM_JOB, - TRANSFORM_OUTPUT, - TRANSFORM_RESOURCES, - DATA_CAPTURE_CONFIG, + TRANSFORM_JOB_KMS_KEY_ID_PATH, + TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, + VOLUME_KMS_KEY_ID, + TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH, MODEL, - EXECUTION_ROLE_ARN, - S3_STORAGE_CONFIG, - ENDPOINT_CONFIG, - PIPELINE, - COMPILATION_JOB, - PROCESSING_JOB, - PROCESSING_INPUTS, - DATASET_DEFINITION, - REDSHIFT_DATASET_DEFINITION, - ATHENA_DATASET_DEFINITION, - CLUSTER_ROLE_ARN, - PROCESSING_OUTPUT_CONFIG, - PROCESSING_RESOURCES, - ASYNC_INFERENCE_CONFIG, - ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, - TAGS, - KEY, - PRODUCTION_VARIANTS, + MODEL_EXECUTION_ROLE_ARN_PATH, + MODEL_ENABLE_NETWORK_ISOLATION_PATH, + MODEL_VPC_CONFIG_PATH, + MODEL_PACKAGE_VALIDATION_ROLE_PATH, VALIDATION_ROLE, VALIDATION_PROFILES, - MODEL_PACKAGE, - VALIDATION_SPECIFICATION, + MODEL_PACKAGE_VALIDATION_PROFILES_PATH, + ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, + KMS_KEY_ID, + ENDPOINT_CONFIG_KMS_KEY_ID_PATH, + ENDPOINT_CONFIG, + ENDPOINT_CONFIG_DATA_CAPTURE_PATH, + ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH, + SAGEMAKER, + FEATURE_GROUP, + TAGS, + FEATURE_GROUP_ROLE_ARN_PATH, + FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, + FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, ) from sagemaker.deprecations import deprecated_class from sagemaker.inputs import ShuffleConfig, TrainingInput, BatchDataCaptureConfig @@ -101,9 +104,11 @@ secondary_training_status_message, sts_regional_endpoint, retries, - get_config_value, - get_nested_value, - set_nested_value, + resolve_value_from_config, + get_sagemaker_config_value, + resolve_class_attribute_from_config, + resolve_nested_dict_value_from_config, + update_nested_dictionary_with_values_from_config, ) from sagemaker import exceptions from sagemaker.session_settings import SessionSettings @@ 
-123,173 +128,6 @@ } -def _simple_path(*args: str): - """Appends an arbitrary number of strings to use as path constants""" - return ".".join(args) - - -COMPILATION_JOB_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, COMPILATION_JOB, VPC_CONFIG) -COMPILATION_JOB_KMS_KEY_ID_PATH = _simple_path( - SAGEMAKER, COMPILATION_JOB, OUTPUT_CONFIG, KMS_KEY_ID -) -COMPILATION_JOB_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, COMPILATION_JOB, OUTPUT_CONFIG) -COMPILATION_JOB_ROLE_ARN_PATH = _simple_path(SAGEMAKER, COMPILATION_JOB, ROLE_ARN) -TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH = _simple_path( - SAGEMAKER, TRAINING_JOB, ENABLE_NETWORK_ISOLATION -) -TRAINING_JOB_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, OUTPUT_DATA_CONFIG, KMS_KEY_ID) -TRAINING_JOB_RESOURCE_CONFIG_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, RESOURCE_CONFIG) -TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, OUTPUT_DATA_CONFIG) -TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path( - SAGEMAKER, TRAINING_JOB, RESOURCE_CONFIG, VOLUME_KMS_KEY_ID -) -TRAINING_JOB_ROLE_ARN_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, ROLE_ARN) -TRAINING_JOB_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, VPC_CONFIG) -TRAINING_JOB_SECURITY_GROUP_IDS_PATH = _simple_path( - TRAINING_JOB_VPC_CONFIG_PATH, SECURITY_GROUP_IDS -) -TRAINING_JOB_SUBNETS_PATH = _simple_path(TRAINING_JOB_VPC_CONFIG_PATH, SUBNETS) -EDGE_PACKAGING_KMS_KEY_ID_PATH = _simple_path( - SAGEMAKER, EDGE_PACKAGING_JOB, OUTPUT_CONFIG, KMS_KEY_ID -) -EDGE_PACKAGING_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, OUTPUT_CONFIG) -EDGE_PACKAGING_ROLE_ARN_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, ROLE_ARN) -ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH = _simple_path( - SAGEMAKER, ENDPOINT_CONFIG, DATA_CAPTURE_CONFIG, KMS_KEY_ID -) -ENDPOINT_CONFIG_DATA_CAPTURE_PATH = _simple_path(SAGEMAKER, ENDPOINT_CONFIG, DATA_CAPTURE_CONFIG) -ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH = _simple_path( - SAGEMAKER, 
ENDPOINT_CONFIG, ASYNC_INFERENCE_CONFIG -) -ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH = _simple_path( - SAGEMAKER, ENDPOINT_CONFIG, PRODUCTION_VARIANTS -) -ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH = _simple_path( - SAGEMAKER, ENDPOINT_CONFIG, ASYNC_INFERENCE_CONFIG, OUTPUT_CONFIG, KMS_KEY_ID -) -ENDPOINT_CONFIG_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, ENDPOINT_CONFIG, KMS_KEY_ID) -FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH = _simple_path(SAGEMAKER, FEATURE_GROUP, ONLINE_STORE_CONFIG) -FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH = _simple_path( - SAGEMAKER, FEATURE_GROUP, OFFLINE_STORE_CONFIG -) -FEATURE_GROUP_ROLE_ARN_PATH = _simple_path(SAGEMAKER, FEATURE_GROUP, ROLE_ARN) -FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH = _simple_path( - FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, S3_STORAGE_CONFIG, KMS_KEY_ID -) -FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH = _simple_path( - FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, SECURITY_CONFIG, KMS_KEY_ID -) -AUTO_ML_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML, OUTPUT_DATA_CONFIG) -AUTO_ML_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, AUTO_ML, OUTPUT_DATA_CONFIG, KMS_KEY_ID) -AUTO_ML_VOLUME_KMS_KEY_ID_PATH = _simple_path( - SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VOLUME_KMS_KEY_ID -) -AUTO_ML_ROLE_ARN_PATH = _simple_path(SAGEMAKER, AUTO_ML, ROLE_ARN) -AUTO_ML_VPC_CONFIG_PATH = _simple_path( - SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VPC_CONFIG -) -AUTO_ML_JOB_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG) -MONITORING_JOB_DEFINITION_PREFIX = _simple_path( - SAGEMAKER, MONITORING_SCHEDULE, MONITORING_SCHEDULE_CONFIG, MONITORING_JOB_DEFINITION -) -MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH = _simple_path( - MONITORING_JOB_DEFINITION_PREFIX, MONITORING_OUTPUT_CONFIG, KMS_KEY_ID -) -MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path( - MONITORING_JOB_DEFINITION_PREFIX, MONITORING_RESOURCES, CLUSTER_CONFIG, VOLUME_KMS_KEY_ID -) -MONITORING_JOB_NETWORK_CONFIG_PATH = 
_simple_path(MONITORING_JOB_DEFINITION_PREFIX, NETWORK_CONFIG) -MONITORING_JOB_ENABLE_NETWORK_ISOLATION_PATH = _simple_path( - MONITORING_JOB_DEFINITION_PREFIX, NETWORK_CONFIG, ENABLE_NETWORK_ISOLATION -) -MONITORING_JOB_VPC_CONFIG_PATH = _simple_path( - MONITORING_JOB_DEFINITION_PREFIX, NETWORK_CONFIG, VPC_CONFIG -) -MONITORING_JOB_SECURITY_GROUP_IDS_PATH = _simple_path( - MONITORING_JOB_VPC_CONFIG_PATH, SECURITY_GROUP_IDS -) -MONITORING_JOB_SUBNETS_PATH = _simple_path(MONITORING_JOB_VPC_CONFIG_PATH, SUBNETS) -MONITORING_JOB_ROLE_ARN_PATH = _simple_path(MONITORING_JOB_DEFINITION_PREFIX, ROLE_ARN) -PIPELINE_ROLE_ARN_PATH = _simple_path(SAGEMAKER, PIPELINE, ROLE_ARN) -PIPELINE_TAGS_PATH = _simple_path(SAGEMAKER, PIPELINE, TAGS) -TRANSFORM_OUTPUT_KMS_KEY_ID_PATH = _simple_path( - SAGEMAKER, TRANSFORM_JOB, TRANSFORM_OUTPUT, KMS_KEY_ID -) -TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH = _simple_path( - SAGEMAKER, TRANSFORM_JOB, TRANSFORM_RESOURCES, VOLUME_KMS_KEY_ID -) -TRANSFORM_JOB_KMS_KEY_ID_PATH = _simple_path( - SAGEMAKER, TRANSFORM_JOB, DATA_CAPTURE_CONFIG, KMS_KEY_ID -) -TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path( - SAGEMAKER, TRANSFORM_JOB, TRANSFORM_RESOURCES, VOLUME_KMS_KEY_ID -) - -MODEL_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, MODEL, VPC_CONFIG) -MODEL_ENABLE_NETWORK_ISOLATION_PATH = _simple_path(SAGEMAKER, MODEL, ENABLE_NETWORK_ISOLATION) -MODEL_EXECUTION_ROLE_ARN_PATH = _simple_path(SAGEMAKER, MODEL, EXECUTION_ROLE_ARN) -PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH = _simple_path( - SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, ENABLE_NETWORK_ISOLATION -) -PROCESSING_JOB_INPUTS_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, PROCESSING_INPUTS) -REDSHIFT_DATASET_DEFINITION_KMS_KEY_ID_PATH = _simple_path( - DATASET_DEFINITION, REDSHIFT_DATASET_DEFINITION, KMS_KEY_ID -) -ATHENA_DATASET_DEFINITION_KMS_KEY_ID_PATH = _simple_path( - DATASET_DEFINITION, ATHENA_DATASET_DEFINITION, KMS_KEY_ID -) -REDSHIFT_DATASET_DEFINITION_CLUSTER_ROLE_ARN_PATH = _simple_path( 
- DATASET_DEFINITION, REDSHIFT_DATASET_DEFINITION, CLUSTER_ROLE_ARN -) -PROCESSING_JOB_NETWORK_CONFIG_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG) -PROCESSING_JOB_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, VPC_CONFIG) -PROCESSING_JOB_SUBNETS_PATH = _simple_path(PROCESSING_JOB_VPC_CONFIG_PATH, SUBNETS) -PROCESSING_JOB_SECURITY_GROUP_IDS_PATH = _simple_path( - PROCESSING_JOB_VPC_CONFIG_PATH, SECURITY_GROUP_IDS -) -PROCESSING_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, PROCESSING_OUTPUT_CONFIG) -PROCESSING_JOB_KMS_KEY_ID_PATH = _simple_path( - SAGEMAKER, PROCESSING_JOB, PROCESSING_OUTPUT_CONFIG, KMS_KEY_ID -) -PROCESSING_JOB_PROCESSING_RESOURCES_PATH = _simple_path( - SAGEMAKER, PROCESSING_JOB, PROCESSING_RESOURCES -) -PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path( - SAGEMAKER, PROCESSING_JOB, PROCESSING_RESOURCES, CLUSTER_CONFIG, VOLUME_KMS_KEY_ID -) -PROCESSING_JOB_ROLE_ARN_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, ROLE_ARN) -MODEL_PACKAGE_VALIDATION_ROLE_PATH = _simple_path( - SAGEMAKER, MODEL_PACKAGE, VALIDATION_SPECIFICATION, VALIDATION_ROLE -) -MODEL_PACKAGE_VALIDATION_PROFILES_PATH = _simple_path( - SAGEMAKER, MODEL_PACKAGE, VALIDATION_SPECIFICATION, VALIDATION_PROFILES -) - -# Paths for reference elsewhere in the SDK. 
-# Names include the schema version since the paths could change with other schema versions -PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION = _simple_path( - SAGEMAKER, - MONITORING_SCHEDULE, - MONITORING_SCHEDULE_CONFIG, - MONITORING_JOB_DEFINITION, - NETWORK_CONFIG, - ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, -) -PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION = _simple_path( - SAGEMAKER, - AUTO_ML, - AUTO_ML_JOB_CONFIG, - SECURITY_CONFIG, - ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, -) -PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION = _simple_path( - SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION -) -PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION = _simple_path( - SAGEMAKER, TRAINING_JOB, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION -) - - class LogState(object): """Placeholder docstring""" @@ -646,20 +484,6 @@ def default_bucket(self): return self._default_bucket - def get_sagemaker_config_value(self, key): - """Util method that fetches a particular key path in the SageMakerConfig and returns it. - - Args: - key: Key Path of the config file entry. - - Returns: - object: The corresponding value in the Config file/ the default value. - """ - config_value = get_config_value(key, self.sagemaker_config.config) - - # Copy the value so any modifications to the output will not modify the source config - return copy.deepcopy(config_value) - def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): """Creates an S3 Bucket if it does not exist. @@ -769,154 +593,6 @@ def _print_message_on_sagemaker_config_usage( # There is no print statement needed if nothing was specified in the config and nothing is # being automatically applied - def resolve_value_from_config( - self, direct_input=None, config_path: str = None, default_value=None - ): - """Makes a decision of which value is the right value for the caller to use. - - Note: This method also incorporates info from the sagemaker config. 
- - Uses this order of prioritization: - (1) direct_input, (2) config value, (3) default_value, (4) None - - Args: - direct_input: the value that the caller of this method started with. Usually this is an - input to the caller's class or method - config_path (str): a string denoting the path to use to lookup the config value in the - sagemaker config - default_value: the value to use if not present elsewhere - - Returns: - The value that should be used by the caller - """ - config_value = self.get_sagemaker_config_value(config_path) - self._print_message_on_sagemaker_config_usage(direct_input, config_value, config_path) - - if direct_input is not None: - return direct_input - - if config_value is not None: - return config_value - - return default_value - - def resolve_class_attribute_from_config( - self, - clazz: Optional[type], - instance: Optional[object], - attribute: str, - config_path: str, - default_value=None, - ): - """Utility method that merges config values to data classes. - - Takes an instance of a class and, if not already set, sets the instance's attribute to a - value fetched from the sagemaker_config or the default_value. - - Uses this order of prioritization to determine what the value of the attribute should be: - (1) current value of attribute, (2) config value, (3) default_value, (4) does not set it - - Args: - clazz (Optional[type]): Class of 'instance'. Used to generate a new instance if the - instance is None. If None is provided here, no new object will be created - if 'instance' doesnt exist. Note: if provided, the constructor should set default - values to None; Otherwise, the constructor's non-None default will be left - as-is even if a config value was defined. 
- instance (Optional[object]): instance of the Class 'clazz' that has an attribute - of 'attribute' to set - attribute (str): attribute of the instance to set if not already set - config_path (str): a string denoting the path to use to lookup the config value in the - sagemaker config - default_value: the value to use if not present elsewhere - - Returns: - The updated class instance that should be used by the caller instead of the - 'instance' parameter that was passed in. - """ - config_value = self.get_sagemaker_config_value(config_path) - - if config_value is None and default_value is None: - # return instance unmodified. Could be None or populated - return instance - - if instance is None: - if clazz is None or not inspect.isclass(clazz): - return instance - # construct a new instance if the instance does not exist - instance = clazz() - - if not hasattr(instance, attribute): - raise TypeError( - "Unexpected structure of object.", - "Expected attribute {} to be present inside instance {} of class {}".format( - attribute, instance, clazz - ), - ) - - current_value = getattr(instance, attribute) - if current_value is None: - # only set value if object does not already have a value set - if config_value is not None: - setattr(instance, attribute, config_value) - elif default_value is not None: - setattr(instance, attribute, default_value) - - self._print_message_on_sagemaker_config_usage(current_value, config_value, config_path) - - return instance - - def resolve_nested_dict_value_from_config( - self, - dictionary: dict, - nested_keys: List[str], - config_path: str, - default_value: object = None, - ): - """Utility method that sets the value of a key path in a nested dictionary . - - This method takes a dictionary and, if not already set, sets the value for the provided - list of nested keys to the value fetched from the sagemaker_config or the default_value. 
- - Uses this order of prioritization to determine what the value of the attribute should be: - (1) current value of nested key, (2) config value, (3) default_value, (4) does not set it - - Args: - dictionary: dict to update - nested_keys: path of keys at which the value should be checked (and set if needed) - config_path (str): a string denoting the path to use to lookup the config value in the - sagemaker config - default_value: the value to use if not present elsewhere - - Returns: - The updated dictionary that should be used by the caller instead of the - 'dictionary' parameter that was passed in. - """ - config_value = self.get_sagemaker_config_value(config_path) - - if config_value is None and default_value is None: - # if there is nothing to set, return early. And there is no need to traverse through - # the dictionary or add nested dicts to it - return dictionary - - try: - current_nested_value = get_nested_value(dictionary, nested_keys) - except ValueError as e: - logging.error("Failed to check dictionary for applying sagemaker config: %s", e) - return dictionary - - if current_nested_value is None: - # only set value if not already set - if config_value is not None: - dictionary = set_nested_value(dictionary, nested_keys, config_value) - elif default_value is not None: - dictionary = set_nested_value(dictionary, nested_keys, default_value) - - self._print_message_on_sagemaker_config_usage( - current_nested_value, config_value, config_path - ) - - return dictionary - def _append_sagemaker_config_tags(self, tags: list, config_path_to_tags: str): """Appends tags specified in the sagemaker_config to the given list of tags. @@ -931,7 +607,7 @@ def _append_sagemaker_config_tags(self, tags: list, config_path_to_tags: str): Returns: A potentially extended list of tags. 
""" - config_tags = self.get_sagemaker_config_value(config_path_to_tags) + config_tags = get_sagemaker_config_value(self, config_path_to_tags) if config_tags is None or len(config_tags) == 0: return tags @@ -1103,25 +779,27 @@ def train( # noqa: C901 tags, "{}.{}.{}".format(SAGEMAKER, TRAINING_JOB, TAGS) ) - _encrypt_inter_container_traffic = self.resolve_value_from_config( + _encrypt_inter_container_traffic = resolve_value_from_config( direct_input=encrypt_inter_container_traffic, - config_path=PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION, + config_path=TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, default_value=False, + sagemaker_session=self, ) - role = self.resolve_value_from_config(role, TRAINING_JOB_ROLE_ARN_PATH) - enable_network_isolation = self.resolve_value_from_config( + role = resolve_value_from_config(role, TRAINING_JOB_ROLE_ARN_PATH, sagemaker_session=self) + enable_network_isolation = resolve_value_from_config( direct_input=enable_network_isolation, config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, default_value=False, + sagemaker_session=self, ) - inferred_vpc_config = self._update_nested_dictionary_with_values_from_config( - vpc_config, TRAINING_JOB_VPC_CONFIG_PATH + inferred_vpc_config = update_nested_dictionary_with_values_from_config( + vpc_config, TRAINING_JOB_VPC_CONFIG_PATH, sagemaker_session=self ) - inferred_output_config = self._update_nested_dictionary_with_values_from_config( - output_config, TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH + inferred_output_config = update_nested_dictionary_with_values_from_config( + output_config, TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH, sagemaker_session=self ) - inferred_resource_config = self._update_nested_dictionary_with_values_from_config( - resource_config, TRAINING_JOB_RESOURCE_CONFIG_PATH + inferred_resource_config = update_nested_dictionary_with_values_from_config( + resource_config, TRAINING_JOB_RESOURCE_CONFIG_PATH, sagemaker_session=self ) train_request = self._get_train_request( 
input_mode=input_mode, @@ -1469,8 +1147,8 @@ def _update_processing_input_from_config(self, inputs): """ inputs_copy = copy.deepcopy(inputs) - processing_inputs_from_config = self.resolve_value_from_config( - config_path=PROCESSING_JOB_INPUTS_PATH, default_value=[] + processing_inputs_from_config = resolve_value_from_config( + config_path=PROCESSING_JOB_INPUTS_PATH, default_value=[], sagemaker_session=self ) for i in range(min(len(inputs), len(processing_inputs_from_config))): dict_from_inputs = inputs[i] @@ -1551,24 +1229,25 @@ def process( tags, "{}.{}.{}".format(SAGEMAKER, PROCESSING_JOB, TAGS) ) - network_config = self.resolve_nested_dict_value_from_config( + network_config = resolve_nested_dict_value_from_config( network_config, ["EnableInterContainerTrafficEncryption"], - PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION, + PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, + sagemaker_session=self, ) self._update_processing_input_from_config(inputs) - role_arn = self.resolve_value_from_config(role_arn, PROCESSING_JOB_ROLE_ARN_PATH) - inferred_network_config_from_config = ( - self._update_nested_dictionary_with_values_from_config( - network_config, PROCESSING_JOB_NETWORK_CONFIG_PATH - ) + role_arn = resolve_value_from_config( + role_arn, PROCESSING_JOB_ROLE_ARN_PATH, sagemaker_session=self ) - inferred_output_config = self._update_nested_dictionary_with_values_from_config( - output_config, PROCESSING_OUTPUT_CONFIG_PATH + inferred_network_config_from_config = update_nested_dictionary_with_values_from_config( + network_config, PROCESSING_JOB_NETWORK_CONFIG_PATH, sagemaker_session=self ) - inferred_resources_config = self._update_nested_dictionary_with_values_from_config( - resources, PROCESSING_JOB_PROCESSING_RESOURCES_PATH + inferred_output_config = update_nested_dictionary_with_values_from_config( + output_config, PROCESSING_OUTPUT_CONFIG_PATH, sagemaker_session=self + ) + inferred_resources_config = update_nested_dictionary_with_values_from_config( + resources, 
PROCESSING_JOB_PROCESSING_RESOURCES_PATH, sagemaker_session=self ) process_request = self._get_process_request( inputs=inputs, @@ -1738,14 +1417,14 @@ def create_monitoring_schedule( tags ([dict[str,str]]): A list of dictionaries containing key-value pairs. """ - role_arn = self.resolve_value_from_config(role_arn, MONITORING_JOB_ROLE_ARN_PATH) - volume_kms_key = self.resolve_value_from_config( - volume_kms_key, MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH + role_arn = resolve_value_from_config( + role_arn, MONITORING_JOB_ROLE_ARN_PATH, sagemaker_session=self ) - inferred_network_config_from_config = ( - self._update_nested_dictionary_with_values_from_config( - network_config, MONITORING_JOB_NETWORK_CONFIG_PATH - ) + volume_kms_key = resolve_value_from_config( + volume_kms_key, MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, sagemaker_session=self + ) + inferred_network_config_from_config = update_nested_dictionary_with_values_from_config( + network_config, MONITORING_JOB_NETWORK_CONFIG_PATH, sagemaker_session=self ) monitoring_schedule_request = { "MonitoringScheduleName": monitoring_schedule_name, @@ -1771,8 +1450,8 @@ def create_monitoring_schedule( } if monitoring_output_config is not None: - kms_key_from_config = self.resolve_value_from_config( - config_path=MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH + kms_key_from_config = resolve_value_from_config( + config_path=MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH, sagemaker_session=self ) if KMS_KEY_ID not in monitoring_output_config and kms_key_from_config: monitoring_output_config[KMS_KEY_ID] = kms_key_from_config @@ -2115,10 +1794,11 @@ def update_monitoring_schedule( ].get("NetworkConfig") _network_config = network_config or existing_network_config - _network_config = self.resolve_nested_dict_value_from_config( + _network_config = resolve_nested_dict_value_from_config( _network_config, ["EnableInterContainerTrafficEncryption"], - PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION, + MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH, 
+ sagemaker_session=self, ) if _network_config is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ @@ -2441,12 +2121,12 @@ def auto_ml( Contains "AutoGenerateEndpointName" and "EndpointName" """ - role = self.resolve_value_from_config(role, AUTO_ML_ROLE_ARN_PATH) - inferred_output_config = self._update_nested_dictionary_with_values_from_config( - output_config, AUTO_ML_OUTPUT_CONFIG_PATH + role = resolve_value_from_config(role, AUTO_ML_ROLE_ARN_PATH, sagemaker_session=self) + inferred_output_config = update_nested_dictionary_with_values_from_config( + output_config, AUTO_ML_OUTPUT_CONFIG_PATH, sagemaker_session=self ) - inferred_automl_job_config = self._update_nested_dictionary_with_values_from_config( - auto_ml_job_config, AUTO_ML_JOB_CONFIG_PATH + inferred_automl_job_config = update_nested_dictionary_with_values_from_config( + auto_ml_job_config, AUTO_ML_JOB_CONFIG_PATH, sagemaker_session=self ) auto_ml_job_request = self._get_auto_ml_request( input_config=input_config, @@ -2720,11 +2400,15 @@ def compile_model( Returns: str: ARN of the compile model job, if it is created. 
""" - role = self.resolve_value_from_config(role, COMPILATION_JOB_ROLE_ARN_PATH) - inferred_output_model_config = self._update_nested_dictionary_with_values_from_config( - output_model_config, COMPILATION_JOB_OUTPUT_CONFIG_PATH + role = resolve_value_from_config( + role, COMPILATION_JOB_ROLE_ARN_PATH, sagemaker_session=self + ) + inferred_output_model_config = update_nested_dictionary_with_values_from_config( + output_model_config, COMPILATION_JOB_OUTPUT_CONFIG_PATH, sagemaker_session=self + ) + vpc_config = resolve_value_from_config( + config_path=COMPILATION_JOB_VPC_CONFIG_PATH, sagemaker_session=self ) - vpc_config = self.resolve_value_from_config(config_path=COMPILATION_JOB_VPC_CONFIG_PATH) compilation_job_request = { "InputConfig": input_model_config, "OutputConfig": inferred_output_model_config, @@ -2770,9 +2454,9 @@ def package_model_for_edge( tags (list[dict]): List of tags for labeling a compile model job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. 
""" - role = self.resolve_value_from_config(role, EDGE_PACKAGING_ROLE_ARN_PATH) - inferred_output_model_config = self._update_nested_dictionary_with_values_from_config( - output_model_config, EDGE_PACKAGING_OUTPUT_CONFIG_PATH + role = resolve_value_from_config(role, EDGE_PACKAGING_ROLE_ARN_PATH, sagemaker_session=self) + inferred_output_model_config = update_nested_dictionary_with_values_from_config( + output_model_config, EDGE_PACKAGING_OUTPUT_CONFIG_PATH, sagemaker_session=self ) edge_packaging_job_request = { "OutputConfig": inferred_output_model_config, @@ -3530,14 +3214,21 @@ def transform( tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, TRANSFORM_JOB, TAGS) ) - batch_data_capture_config = self.resolve_class_attribute_from_config( - None, batch_data_capture_config, "kms_key_id", TRANSFORM_JOB_KMS_KEY_ID_PATH + batch_data_capture_config = resolve_class_attribute_from_config( + None, + batch_data_capture_config, + "kms_key_id", + TRANSFORM_JOB_KMS_KEY_ID_PATH, + sagemaker_session=self, ) - output_config = self.resolve_nested_dict_value_from_config( - output_config, [KMS_KEY_ID], TRANSFORM_OUTPUT_KMS_KEY_ID_PATH + output_config = resolve_nested_dict_value_from_config( + output_config, [KMS_KEY_ID], TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, sagemaker_session=self ) - resource_config = self.resolve_nested_dict_value_from_config( - resource_config, [VOLUME_KMS_KEY_ID], TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH + resource_config = resolve_nested_dict_value_from_config( + resource_config, + [VOLUME_KMS_KEY_ID], + TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH, + sagemaker_session=self, ) transform_request = self._get_transform_request( @@ -3666,12 +3357,17 @@ def create_model( """ tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags(tags, "{}.{}.{}".format(SAGEMAKER, MODEL, TAGS)) - role = self.resolve_value_from_config(role, MODEL_EXECUTION_ROLE_ARN_PATH) - vpc_config = self.resolve_value_from_config(vpc_config, MODEL_VPC_CONFIG_PATH) - 
enable_network_isolation = self.resolve_value_from_config( + role = resolve_value_from_config( + role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self + ) + vpc_config = resolve_value_from_config( + vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self + ) + enable_network_isolation = resolve_value_from_config( direct_input=enable_network_isolation, config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH, default_value=False, + sagemaker_session=self, ) create_model_request = self._create_model_request( name=name, @@ -3745,13 +3441,14 @@ def create_model_from_job( ) name = name or training_job_name role = role or training_job["RoleArn"] - role = self.resolve_value_from_config( - role, MODEL_EXECUTION_ROLE_ARN_PATH, training_job["RoleArn"] + role = resolve_value_from_config( + role, MODEL_EXECUTION_ROLE_ARN_PATH, training_job["RoleArn"], self ) - enable_network_isolation = self.resolve_value_from_config( + enable_network_isolation = resolve_value_from_config( direct_input=enable_network_isolation, config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH, default_value=False, + sagemaker_session=self, ) env = env or {} primary_container = container_def( @@ -3760,7 +3457,9 @@ def create_model_from_job( env=env, ) vpc_config = _vpc_config_from_training_job(training_job, vpc_config_override) - vpc_config = self.resolve_value_from_config(vpc_config, MODEL_VPC_CONFIG_PATH) + vpc_config = resolve_value_from_config( + vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self + ) return self.create_model( name, role, @@ -3862,13 +3561,16 @@ def create_model_package_from_containers( # not supported by the config now. So if we merge values from config, then API will # throw an exception. In the future, when SageMaker Config starts supporting other # parameters we can add that. 
- validation_role = self.resolve_value_from_config( + validation_role = resolve_value_from_config( validation_specification.get(VALIDATION_ROLE, None), MODEL_PACKAGE_VALIDATION_ROLE_PATH, + sagemaker_session=self, ) validation_specification[VALIDATION_ROLE] = validation_role - validation_profiles_from_config = self.resolve_value_from_config( - config_path=MODEL_PACKAGE_VALIDATION_PROFILES_PATH, default_value=[] + validation_profiles_from_config = resolve_value_from_config( + config_path=MODEL_PACKAGE_VALIDATION_PROFILES_PATH, + default_value=[], + sagemaker_session=self, ) validation_profiles = validation_specification.get(VALIDATION_PROFILES, []) for i in range(min(len(validation_profiles), len(validation_profiles_from_config))): @@ -4038,8 +3740,10 @@ def create_endpoint_config( model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, ) - inferred_production_variants_from_config = self.resolve_value_from_config( - config_path=ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, default_value=[] + inferred_production_variants_from_config = resolve_value_from_config( + config_path=ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, + default_value=[], + sagemaker_session=self, ) if inferred_production_variants_from_config: inferred_production_variant_from_config = ( @@ -4070,15 +3774,15 @@ def create_endpoint_config( ) if tags is not None: request["Tags"] = tags - kms_key = self.resolve_value_from_config(kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH) + kms_key = resolve_value_from_config( + kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self + ) if kms_key is not None: request["KmsKeyId"] = kms_key if data_capture_config_dict is not None: - inferred_data_capture_config_dict = ( - self._update_nested_dictionary_with_values_from_config( - data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH - ) + inferred_data_capture_config_dict = update_nested_dictionary_with_values_from_config( + 
data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH, sagemaker_session=self ) request["DataCaptureConfig"] = inferred_data_capture_config_dict @@ -4138,8 +3842,10 @@ def create_endpoint_config_from_existing( new_production_variants or existing_endpoint_config_desc["ProductionVariants"] ) if production_variants: - inferred_production_variants_from_config = self.resolve_value_from_config( - config_path=ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, default_value=[] + inferred_production_variants_from_config = resolve_value_from_config( + config_path=ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, + default_value=[], + sagemaker_session=self, ) for i in range( min(len(production_variants), len(inferred_production_variants_from_config)) @@ -4172,8 +3878,8 @@ def create_endpoint_config_from_existing( if new_kms_key is not None or existing_endpoint_config_desc.get("KmsKeyId") is not None: request["KmsKeyId"] = new_kms_key or existing_endpoint_config_desc.get("KmsKeyId") if KMS_KEY_ID not in request: - kms_key_from_config = self.resolve_value_from_config( - config_path=ENDPOINT_CONFIG_KMS_KEY_ID_PATH + kms_key_from_config = resolve_value_from_config( + config_path=ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self ) if kms_key_from_config: request[KMS_KEY_ID] = kms_key_from_config @@ -4183,10 +3889,10 @@ def create_endpoint_config_from_existing( ) if request_data_capture_config_dict is not None: - inferred_data_capture_config_dict = ( - self._update_nested_dictionary_with_values_from_config( - request_data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH - ) + inferred_data_capture_config_dict = update_nested_dictionary_with_values_from_config( + request_data_capture_config_dict, + ENDPOINT_CONFIG_DATA_CAPTURE_PATH, + sagemaker_session=self, ) request["DataCaptureConfig"] = inferred_data_capture_config_dict @@ -4194,10 +3900,10 @@ def create_endpoint_config_from_existing( "AsyncInferenceConfig", None ) if async_inference_config_dict is not None: - 
inferred_async_inference_config_dict = ( - self._update_nested_dictionary_with_values_from_config( - async_inference_config_dict, ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH - ) + inferred_async_inference_config_dict = update_nested_dictionary_with_values_from_config( + async_inference_config_dict, + ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH, + sagemaker_session=self, ) request["AsyncInferenceConfig"] = inferred_async_inference_config_dict @@ -4738,7 +4444,9 @@ def endpoint_from_production_variants( str: The name of the created ``Endpoint``. """ config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants} - kms_key = self.resolve_value_from_config(kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH) + kms_key = resolve_value_from_config( + kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self + ) tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS) @@ -4748,17 +4456,15 @@ def endpoint_from_production_variants( if kms_key: config_options["KmsKeyId"] = kms_key if data_capture_config_dict is not None: - inferred_data_capture_config_dict = ( - self._update_nested_dictionary_with_values_from_config( - data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH - ) + inferred_data_capture_config_dict = update_nested_dictionary_with_values_from_config( + data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH, sagemaker_session=self ) config_options["DataCaptureConfig"] = inferred_data_capture_config_dict if async_inference_config_dict is not None: - inferred_async_inference_config_dict = ( - self._update_nested_dictionary_with_values_from_config( - async_inference_config_dict, ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH - ) + inferred_async_inference_config_dict = update_nested_dictionary_with_values_from_config( + async_inference_config_dict, + ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH, + sagemaker_session=self, ) config_options["AsyncInferenceConfig"] = 
inferred_async_inference_config_dict @@ -5198,16 +4904,18 @@ def create_feature_group( tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, FEATURE_GROUP, TAGS) ) - role_arn = self.resolve_value_from_config(role_arn, FEATURE_GROUP_ROLE_ARN_PATH) - inferred_online_store_from_config = self._update_nested_dictionary_with_values_from_config( - online_store_config, FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH + role_arn = resolve_value_from_config( + role_arn, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=self + ) + inferred_online_store_from_config = update_nested_dictionary_with_values_from_config( + online_store_config, FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, sagemaker_session=self ) if inferred_online_store_from_config is not None: # OnlineStore should be handled differently because if you set KmsKeyId, then you # need to set EnableOnlineStore key as well inferred_online_store_from_config["EnableOnlineStore"] = True - inferred_offline_store_from_config = self._update_nested_dictionary_with_values_from_config( - offline_store_config, FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH + inferred_offline_store_from_config = update_nested_dictionary_with_values_from_config( + offline_store_config, FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, sagemaker_session=self ) kwargs = dict( FeatureGroupName=feature_group_name, @@ -5549,53 +5257,6 @@ def wait_for_athena_query(self, query_execution_id: str, poll: int = 5): else: LOGGER.error("Failed to execute query %s.", query_execution_id) - def _update_nested_dictionary_with_values_from_config( - self, source_dict, config_key_path - ) -> dict: - """Updates a given nested dictionary with missing values which are present in Config. - - Args: - source_dict: The input nested dictionary that was provided as method parameter. - config_key_path: The Key Path in the Config file which corresponds to this - source_dict parameter. 
- - Returns: - dict: The merged nested dictionary which includes missings values that are present - in the Config file. - """ - inferred_config_dict = self.get_sagemaker_config_value(config_key_path) or {} - original_config_dict_value = copy.deepcopy(inferred_config_dict) - merge_dicts(inferred_config_dict, source_dict or {}) - - if original_config_dict_value == {}: - # The config value is empty. That means either - # (1) inferred_config_dict equals source_dict, or - # (2) if source_dict was None, inferred_config_dict equals {} - # We should return whatever source_dict was to be safe. Because if for example, - # a VpcConfig is set to {} instead of None, some boto calls will fail due to - # ParamValidationError (because a VpcConfig was specified but required parameters for - # the VpcConfig were missing.) - - # Dont need to print because no config value was used or defined - return source_dict - - if source_dict == inferred_config_dict: - # We didnt use any values from the config, but we should print if any of the config - # values were defined - self._print_message_on_sagemaker_config_usage( - source_dict, original_config_dict_value, config_key_path - ) - else: - # Something from the config was merged in - print( - "[Sagemaker Config - applied value]\n", - "config key = {}\n".format(config_key_path), - "config value = {}\n".format(original_config_dict_value), - "source value = {}\n".format(source_dict), - "combined value that will be used = {}\n".format(inferred_config_dict), - ) - return inferred_config_dict - def download_athena_query_result( self, bucket: str, diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index f9bd16369b..1dc891ea82 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -19,15 +19,14 @@ import time from botocore import exceptions -from sagemaker.job import _Job -from sagemaker.session import ( - Session, - get_execution_role, +from sagemaker.config import ( TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, 
TRANSFORM_JOB_KMS_KEY_ID_PATH, TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH, PIPELINE_ROLE_ARN_PATH, ) +from sagemaker.job import _Job +from sagemaker.session import Session, get_execution_role from sagemaker.inputs import BatchDataCaptureConfig from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.functions import Join @@ -38,6 +37,8 @@ base_name_from_image, name_from_base, check_and_get_run_experiment_config, + resolve_value_from_config, + resolve_class_attribute_from_config, ) @@ -126,11 +127,15 @@ def __init__( self._reset_output_path = False self.sagemaker_session = sagemaker_session or Session() - self.volume_kms_key = self.sagemaker_session.resolve_value_from_config( - volume_kms_key, TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH + self.volume_kms_key = resolve_value_from_config( + volume_kms_key, + TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH, + sagemaker_session=self.sagemaker_session, ) - self.output_kms_key = self.sagemaker_session.resolve_value_from_config( - output_kms_key, TRANSFORM_OUTPUT_KMS_KEY_ID_PATH + self.output_kms_key = resolve_value_from_config( + output_kms_key, + TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, + sagemaker_session=self.sagemaker_session, ) @runnable_by_pipeline @@ -268,8 +273,12 @@ def transform( experiment_config = check_and_get_run_experiment_config(experiment_config) - batch_data_capture_config = self.sagemaker_session.resolve_class_attribute_from_config( - None, batch_data_capture_config, "kms_key_id", TRANSFORM_JOB_KMS_KEY_ID_PATH + batch_data_capture_config = resolve_class_attribute_from_config( + None, + batch_data_capture_config, + "kms_key_id", + TRANSFORM_JOB_KMS_KEY_ID_PATH, + sagemaker_session=self.sagemaker_session, ) self.latest_transform_job = _TransformJob.start_new( @@ -394,8 +403,12 @@ def transform_with_monitoring( transformer.sagemaker_session = PipelineSession() self.sagemaker_session = sagemaker_session - batch_data_capture_config = self.sagemaker_session.resolve_class_attribute_from_config( - None,
batch_data_capture_config, "kms_key_id", TRANSFORM_JOB_KMS_KEY_ID_PATH + batch_data_capture_config = resolve_class_attribute_from_config( + None, + batch_data_capture_config, + "kms_key_id", + TRANSFORM_JOB_KMS_KEY_ID_PATH, + sagemaker_session=sagemaker_session, ) transform_step_args = transformer.transform( @@ -439,8 +452,10 @@ def transform_with_monitoring( pipeline_role_arn = ( role if role - else transformer.sagemaker_session.resolve_value_from_config( - get_execution_role(), PIPELINE_ROLE_ARN_PATH + else resolve_value_from_config( + get_execution_role(), + PIPELINE_ROLE_ARN_PATH, + sagemaker_session=transformer.sagemaker_session, ) ) pipeline.upsert(pipeline_role_arn) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 762fbf0d3e..1073d72b54 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -14,7 +14,9 @@ from __future__ import absolute_import import contextlib +import copy import errno +import inspect import logging import os import random @@ -31,6 +33,7 @@ from importlib import import_module import botocore +from botocore.utils import merge_dicts from six.moves.urllib import parse from sagemaker import deprecations @@ -1030,3 +1033,271 @@ def check_and_get_run_experiment_config(experiment_config: Optional[dict] = None return experiment_config return run_obj.experiment_config if run_obj else None + + +def resolve_value_from_config( + direct_input=None, + config_path: str = None, + default_value=None, + sagemaker_session=None, +): + """Makes a decision of which value is the right value for the caller to use. + + Note: This method also incorporates info from the sagemaker config. + + Uses this order of prioritization: + (1) direct_input, (2) config value, (3) default_value, (4) None + + Args: + direct_input: the value that the caller of this method started with. 
Usually this is an + input to the caller's class or method + config_path (str): a string denoting the path to use to lookup the config value in the + sagemaker config + default_value: the value to use if not present elsewhere + sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for + SageMaker interactions (default: None). + + Returns: + The value that should be used by the caller + """ + config_value = ( + get_sagemaker_config_value(sagemaker_session, config_path) if config_path else None + ) + _print_message_on_sagemaker_config_usage(direct_input, config_value, config_path) + + if direct_input is not None: + return direct_input + + if config_value is not None: + return config_value + + return default_value + + +def get_sagemaker_config_value(sagemaker_session, key): + """Util method that fetches a particular key path in the SageMakerConfig and returns it. + + Args: + key: Key Path of the config file entry. + sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for + SageMaker interactions. + + Returns: + object: The corresponding value in the Config file/ the default value. + """ + if not sagemaker_session: + return None + config_value = get_config_value(key, sagemaker_session.sagemaker_config.config) + # Copy the value so any modifications to the output will not modify the source config + return copy.deepcopy(config_value) + + +def _print_message_on_sagemaker_config_usage(direct_input, config_value, config_path: str): + """Informs the SDK user whether a config value was present and automatically substituted + + Args: + direct_input: the value that would be used if no sagemaker_config or default values + existed. Usually this will be user-provided input to a Class or to a + session.py method, or None if no input was provided. + config_value: the value fetched from sagemaker_config. This is usually the value that + will be used if direct_input is None. 
+ config_path: a string denoting the path of keys that point to the config value in the + sagemaker_config + + Returns: + No output (just prints information) + """ + + if config_value is not None: + + if direct_input is not None and config_value != direct_input: + # Sagemaker Config had a value defined that is NOT going to be used + # and the config value has not already been applied earlier + print( + "[Sagemaker Config - skipped value]\n", + "config key = {}\n".format(config_path), + "config value = {}\n".format(config_value), + "specified value that will be used = {}\n".format(direct_input), + ) + + elif direct_input is None: + # Sagemaker Config value is going to be used + print( + "[Sagemaker Config - applied value]\n", + "config key = {}\n".format(config_path), + "config value that will be used = {}\n".format(config_value), + ) + + # There is no print statement needed if nothing was specified in the config and nothing is + # being automatically applied + + +def resolve_class_attribute_from_config( + clazz: Optional[type], + instance: Optional[object], + attribute: str, + config_path: str, + default_value=None, + sagemaker_session=None, +): + """Utility method that merges config values to data classes. + + Takes an instance of a class and, if not already set, sets the instance's attribute to a + value fetched from the sagemaker_config or the default_value. + + Uses this order of prioritization to determine what the value of the attribute should be: + (1) current value of attribute, (2) config value, (3) default_value, (4) does not set it + + Args: + clazz (Optional[type]): Class of 'instance'. Used to generate a new instance if the + instance is None. If None is provided here, no new object will be created + if 'instance' doesnt exist. Note: if provided, the constructor should set default + values to None; Otherwise, the constructor's non-None default will be left + as-is even if a config value was defined. 
+ instance (Optional[object]): instance of the Class 'clazz' that has an attribute + of 'attribute' to set + attribute (str): attribute of the instance to set if not already set + config_path (str): a string denoting the path to use to lookup the config value in the + sagemaker config + default_value: the value to use if not present elsewhere + sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for + SageMaker interactions (default: None). + + Returns: + The updated class instance that should be used by the caller instead of the + 'instance' parameter that was passed in. + """ + config_value = get_sagemaker_config_value(sagemaker_session, config_path) + + if config_value is None and default_value is None: + # return instance unmodified. Could be None or populated + return instance + + if instance is None: + if clazz is None or not inspect.isclass(clazz): + return instance + # construct a new instance if the instance does not exist + instance = clazz() + + if not hasattr(instance, attribute): + raise TypeError( + "Unexpected structure of object.", + "Expected attribute {} to be present inside instance {} of class {}".format( + attribute, instance, clazz + ), + ) + + current_value = getattr(instance, attribute) + if current_value is None: + # only set value if object does not already have a value set + if config_value is not None: + setattr(instance, attribute, config_value) + elif default_value is not None: + setattr(instance, attribute, default_value) + + _print_message_on_sagemaker_config_usage(current_value, config_value, config_path) + + return instance + + +def resolve_nested_dict_value_from_config( + dictionary: dict, + nested_keys: List[str], + config_path: str, + default_value: object = None, + sagemaker_session=None, +): + """Utility method that sets the value of a key path in a nested dictionary . 
+ + This method takes a dictionary and, if not already set, sets the value for the provided + list of nested keys to the value fetched from the sagemaker_config or the default_value. + + Uses this order of prioritization to determine what the value of the attribute should be: + (1) current value of nested key, (2) config value, (3) default_value, (4) does not set it + + Args: + dictionary: dict to update + nested_keys: path of keys at which the value should be checked (and set if needed) + config_path (str): a string denoting the path to use to lookup the config value in the + sagemaker config + default_value: the value to use if not present elsewhere + sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for + SageMaker interactions (default: None). + + Returns: + The updated dictionary that should be used by the caller instead of the + 'dictionary' parameter that was passed in. + """ + config_value = get_sagemaker_config_value(sagemaker_session, config_path) + + if config_value is None and default_value is None: + # if there is nothing to set, return early. 
And there is no need to traverse through + # the dictionary or add nested dicts to it + return dictionary + + try: + current_nested_value = get_nested_value(dictionary, nested_keys) + except ValueError as e: + logging.error("Failed to check dictionary for applying sagemaker config: %s", e) + return dictionary + + if current_nested_value is None: + # only set value if not already set + if config_value is not None: + dictionary = set_nested_value(dictionary, nested_keys, config_value) + elif default_value is not None: + dictionary = set_nested_value(dictionary, nested_keys, default_value) + + _print_message_on_sagemaker_config_usage(current_nested_value, config_value, config_path) + + return dictionary + + +def update_nested_dictionary_with_values_from_config( + source_dict, config_key_path, sagemaker_session=None +) -> dict: + """Updates a given nested dictionary with missing values which are present in Config. + + Args: + source_dict: The input nested dictionary that was provided as method parameter. + config_key_path: The Key Path in the Config file which corresponds to this + source_dict parameter. + sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for + SageMaker interactions (default: None). + + Returns: + dict: The merged nested dictionary which includes missing values that are present + in the Config file. + """ + inferred_config_dict = get_sagemaker_config_value(sagemaker_session, config_key_path) or {} + original_config_dict_value = copy.deepcopy(inferred_config_dict) + merge_dicts(inferred_config_dict, source_dict or {}) + + if original_config_dict_value == {}: + # The config value is empty. That means either + # (1) inferred_config_dict equals source_dict, or + # (2) if source_dict was None, inferred_config_dict equals {} + # We should return whatever source_dict was to be safe.
Because if for example, + # a VpcConfig is set to {} instead of None, some boto calls will fail due to + # ParamValidationError (because a VpcConfig was specified but required parameters for + # the VpcConfig were missing.) + + # Don't need to print because no config value was used or defined + return source_dict + + if source_dict == inferred_config_dict: + # We didn't use any values from the config, but we should print if any of the config + # values were defined + _print_message_on_sagemaker_config_usage( + source_dict, original_config_dict_value, config_key_path + ) + else: + # Something from the config was merged in + print( + "[Sagemaker Config - applied value]\n", + "config key = {}\n".format(config_key_path), + "config value = {}\n".format(original_config_dict_value), + "source value = {}\n".format(source_dict), + "combined value that will be used = {}\n".format(inferred_config_dict), + ) + return inferred_config_dict diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 41f234e1c0..6e3449e1a4 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -25,8 +25,9 @@ from sagemaker import s3 from sagemaker._studio import _append_project_tags -from sagemaker.session import Session, PIPELINE_ROLE_ARN_PATH, PIPELINE_TAGS_PATH -from sagemaker.utils import retry_with_backoff +from sagemaker.config import PIPELINE_ROLE_ARN_PATH, PIPELINE_TAGS_PATH +from sagemaker.session import Session +from sagemaker.utils import resolve_value_from_config, retry_with_backoff from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep from sagemaker.workflow.lambda_step import LambdaOutput, LambdaStep from sagemaker.workflow.entities import ( @@ -127,8 +128,8 @@ def create( Returns: A response dict from the service. 
""" - role_arn = self.sagemaker_session.resolve_value_from_config( - role_arn, PIPELINE_ROLE_ARN_PATH + role_arn = resolve_value_from_config( + role_arn, PIPELINE_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session ) if not role_arn: # Originally IAM role was a required parameter. @@ -221,8 +222,8 @@ def update( Returns: A response dict from the service. """ - role_arn = self.sagemaker_session.resolve_value_from_config( - role_arn, PIPELINE_ROLE_ARN_PATH + role_arn = resolve_value_from_config( + role_arn, PIPELINE_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session ) if not role_arn: # Originally IAM role was a required parameter. @@ -261,8 +262,8 @@ def upsert( Returns: response dict from service """ - role_arn = self.sagemaker_session.resolve_value_from_config( - role_arn, PIPELINE_ROLE_ARN_PATH + role_arn = resolve_value_from_config( + role_arn, PIPELINE_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session ) if not role_arn: # Originally IAM role was a required parameter. diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 39d055091d..f12d543b9c 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -71,12 +71,6 @@ def sagemaker_session(boto_session, client): default_bucket=_DEFAULT_BUCKET, sagemaker_metrics_client=client, ) - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) return session diff --git a/tests/unit/sagemaker/automl/test_auto_ml.py b/tests/unit/sagemaker/automl/test_auto_ml.py index ddcedc24e1..cf5e669fa6 100644 --- a/tests/unit/sagemaker/automl/test_auto_ml.py +++ b/tests/unit/sagemaker/automl/test_auto_ml.py @@ -273,13 +273,8 @@ def sagemaker_session(): ) sms.list_candidates = Mock(name="list_candidates", return_value={"Candidates": []}) sms.sagemaker_client.list_tags = Mock(name="list_tags", return_value=LIST_TAGS_RESULT) - # For the 
purposes of unit tests, no values should be fetched from sagemaker config - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/sagemaker/feature_store/test_feature_group.py b/tests/unit/sagemaker/feature_store/test_feature_group.py index 9cd5c0f6c4..a53f1da680 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_group.py +++ b/tests/unit/sagemaker/feature_store/test_feature_group.py @@ -53,14 +53,9 @@ def s3_uri(): @pytest.fixture def sagemaker_session_mock(): - session_mock = Mock() - session_mock.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) - return session_mock + sagemaker_session_mock = Mock() + sagemaker_session_mock.sagemaker_config.config = {} + return sagemaker_session_mock @pytest.fixture diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index 2fac0ca735..4ed014f577 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -75,13 +75,8 @@ def fixture_sagemaker_session(): session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) - # For the purposes of unit tests, no values should be fetched from sagemaker config - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: 
direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/sagemaker/huggingface/test_processing.py b/tests/unit/sagemaker/huggingface/test_processing.py index 46f8fcb775..e20228aa4f 100644 --- a/tests/unit/sagemaker/huggingface/test_processing.py +++ b/tests/unit/sagemaker/huggingface/test_processing.py @@ -50,17 +50,8 @@ def sagemaker_session(): session_mock.upload_data = Mock(name="upload_data", return_value=MOCKED_S3_URI) session_mock.download_data = Mock(name="download_data") session_mock.expand_role.return_value = ROLE - # For the purposes of unit tests, no values should be fetched from sagemaker config - session_mock.resolve_class_attribute_from_config = Mock( - name="resolve_class_attribute_from_config", - side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, - ) - session_mock.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session_mock.sagemaker_config.config = {} return session_mock diff --git a/tests/unit/sagemaker/image_uris/jumpstart/conftest.py b/tests/unit/sagemaker/image_uris/jumpstart/conftest.py index a67dca6ab4..386d8df94c 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/conftest.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/conftest.py @@ -31,4 +31,6 @@ def session(): settings=SessionSettings(), ) sms.default_bucket = Mock(return_value=BUCKET_NAME) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py 
b/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py index fe22ad7d8e..f8b0a7c581 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py @@ -24,6 +24,8 @@ @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_catboost_image_uri(patched_get_model_specs, session): + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} patched_get_model_specs.side_effect = get_prototype_model_spec diff --git a/tests/unit/sagemaker/local/test_local_pipeline.py b/tests/unit/sagemaker/local/test_local_pipeline.py index 43a69e71f1..0bd0d381ea 100644 --- a/tests/unit/sagemaker/local/test_local_pipeline.py +++ b/tests/unit/sagemaker/local/test_local_pipeline.py @@ -140,24 +140,18 @@ def pipeline_session(boto_session, client): sagemaker_client=client, default_bucket=BUCKET, ) - pipeline_session_mock.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + pipeline_session_mock.sagemaker_config = Mock() + pipeline_session_mock.sagemaker_config.config = {} return pipeline_session_mock @pytest.fixture() def local_sagemaker_session(boto_session): local_session_mock = LocalSession(boto_session=boto_session, default_bucket="my-bucket") - local_session_mock.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + local_session_mock.sagemaker_config = Mock() + local_session_mock.sagemaker_config.config = {} return
local_session_mock diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 16894dfbc2..6d39a88a35 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -68,12 +68,8 @@ @pytest.fixture def sagemaker_session(): session = Mock() - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session @@ -443,18 +439,8 @@ def test_deploy_wrong_serverless_config(sagemaker_session): @patch("sagemaker.session.Session") @patch("sagemaker.local.LocalSession") def test_deploy_creates_correct_session(local_session, session): - local_session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + local_session.sagemaker_config.config = {} + session.sagemaker_config.config = {} # We expect a LocalSession when deploying to instance_type = 'local' model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE) model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1) diff --git a/tests/unit/sagemaker/model/test_edge.py b/tests/unit/sagemaker/model/test_edge.py index 21fb6d710a..1e8a110ce8 100644 --- a/tests/unit/sagemaker/model/test_edge.py +++ b/tests/unit/sagemaker/model/test_edge.py @@ -31,12 +31,8 @@ @pytest.fixture def sagemaker_session(): session = Mock(boto_region_name=REGION) - 
session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/sagemaker/model/test_framework_model.py b/tests/unit/sagemaker/model/test_framework_model.py index 36171477b5..d2c364c087 100644 --- a/tests/unit/sagemaker/model/test_framework_model.py +++ b/tests/unit/sagemaker/model/test_framework_model.py @@ -94,12 +94,8 @@ def sagemaker_session(): settings=SessionSettings(), ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 692bab1c4f..70506338c2 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -109,12 +109,8 @@ def sagemaker_session(): s3_resource=None, ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index 068bc164fc..4633b10eb5 
100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -59,12 +59,8 @@ def sagemaker_session(): session.sagemaker_client.describe_model_package = Mock( return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE ) - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/sagemaker/model/test_neo.py b/tests/unit/sagemaker/model/test_neo.py index 1e135e221e..baa0aef4b0 100644 --- a/tests/unit/sagemaker/model/test_neo.py +++ b/tests/unit/sagemaker/model/test_neo.py @@ -35,12 +35,8 @@ @pytest.fixture def sagemaker_session(): session = Mock(boto_region_name=REGION) - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py index fc47144df2..221b26a513 100644 --- a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py +++ b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py @@ -416,22 +416,8 @@ def sagemaker_session(sagemaker_client): session_mock._append_sagemaker_config_tags = Mock( name="_append_sagemaker_config_tags", side_effect=lambda tags, config_path_to_tags: tags ) - - # For the purposes of unit tests, no values should be fetched from sagemaker config - session_mock.resolve_nested_dict_value_from_config = Mock( - 
name="resolve_nested_dict_value_from_config", - side_effect=lambda dictionary, nested_keys, config_path, default_value=None: dictionary, - ) - session_mock.resolve_class_attribute_from_config = Mock( - name="resolve_class_attribute_from_config", - side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, - ) - session_mock.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session_mock.sagemaker_config.config = {} return session_mock diff --git a/tests/unit/sagemaker/monitor/test_data_capture_config.py b/tests/unit/sagemaker/monitor/test_data_capture_config.py index 8587851266..24c7a1ddd3 100644 --- a/tests/unit/sagemaker/monitor/test_data_capture_config.py +++ b/tests/unit/sagemaker/monitor/test_data_capture_config.py @@ -56,12 +56,7 @@ def test_init_when_non_defaults_provided(): def test_init_when_optionals_not_provided(): sagemaker_session = Mock() sagemaker_session.default_bucket.return_value = DEFAULT_BUCKET_NAME - sagemaker_session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + sagemaker_session.sagemaker_config.config = {} data_capture_config = DataCaptureConfig( enable_capture=DEFAULT_ENABLE_CAPTURE, sagemaker_session=sagemaker_session diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index f258c8ab54..b1cf601c01 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -463,21 +463,8 @@ def sagemaker_session(): ) session_mock.expand_role.return_value = ROLE - # For 
the purposes of unit tests, no values should be fetched from sagemaker config - session_mock.resolve_nested_dict_value_from_config = Mock( - name="resolve_nested_dict_value_from_config", - side_effect=lambda dictionary, nested_keys, config_path, default_value=None: dictionary, - ) - session_mock.resolve_class_attribute_from_config = Mock( - name="resolve_class_attribute_from_config", - side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, - ) - session_mock.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session_mock.sagemaker_config.config = {} session_mock._append_sagemaker_config_tags = Mock( name="_append_sagemaker_config_tags", side_effect=lambda tags, config_path_to_tags: tags ) diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index 1de1d304d5..548e0643f6 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -82,22 +82,8 @@ def sagemaker_session(): session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) - - # For the purposes of unit tests, no values should be fetched from sagemaker config - session.resolve_nested_dict_value_from_config = Mock( - name="resolve_nested_dict_value_from_config", - side_effect=lambda dictionary, nested_keys, config_path, default_value=None: dictionary, - ) - session.resolve_class_attribute_from_config = Mock( - name="resolve_class_attribute_from_config", - side_effect=lambda clazz, instance, attribute, config_path, 
default_value=None: instance, - ) - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/sagemaker/tensorflow/test_estimator_attach.py b/tests/unit/sagemaker/tensorflow/test_estimator_attach.py index caa08f8e3c..1ba953b76b 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator_attach.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator_attach.py @@ -42,12 +42,8 @@ def sagemaker_session(): describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/sagemaker/tensorflow/test_estimator_init.py b/tests/unit/sagemaker/tensorflow/test_estimator_init.py index 71764b0e93..238f405765 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator_init.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator_init.py @@ -25,7 +25,9 @@ @pytest.fixture() def sagemaker_session(): - return Mock(name="sagemaker_session", boto_region_name=REGION) + session_mock = Mock(name="sagemaker_session", boto_region_name=REGION) + session_mock.sagemaker_config.config = {} + return session_mock def _build_tf(sagemaker_session, **kwargs): diff --git a/tests/unit/sagemaker/tensorflow/test_tfs.py 
b/tests/unit/sagemaker/tensorflow/test_tfs.py index ecbdf87a2e..b403a7f07b 100644 --- a/tests/unit/sagemaker/tensorflow/test_tfs.py +++ b/tests/unit/sagemaker/tensorflow/test_tfs.py @@ -67,12 +67,8 @@ def sagemaker_session(): session.sagemaker_client.describe_training_job = Mock(return_value=describe) session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index 40e3241333..59adb31f71 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -82,13 +82,8 @@ def fixture_sagemaker_session(): session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) - # For the purposes of unit tests, no values should be fetched from sagemaker config - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py 
b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index a0abeb2f29..110066bd49 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -79,13 +79,8 @@ def fixture_sagemaker_session(): session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) - # For the purposes of unit tests, no values should be fetched from sagemaker config - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py index 643dc6337c..87d614fd20 100644 --- a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -81,13 +81,8 @@ def fixture_sagemaker_session(): session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) - # For the purposes of unit tests, no values should be fetched from sagemaker config - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git 
a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index 5f67a4c46b..1c3b4c252d 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -86,13 +86,8 @@ def fixture_sagemaker_session(): session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) - # For the purposes of unit tests, no values should be fetched from sagemaker config - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/sagemaker/workflow/conftest.py b/tests/unit/sagemaker/workflow/conftest.py index c8d9c44320..91318b610f 100644 --- a/tests/unit/sagemaker/workflow/conftest.py +++ b/tests/unit/sagemaker/workflow/conftest.py @@ -64,10 +64,7 @@ def pipeline_session(mock_boto_session, mock_client): sagemaker_client=mock_client, default_bucket=BUCKET, ) - pipeline_session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + pipeline_session.sagemaker_config = Mock() + pipeline_session.sagemaker_config.config = {} return pipeline_session diff --git a/tests/unit/sagemaker/workflow/test_airflow.py b/tests/unit/sagemaker/workflow/test_airflow.py index e53efbb30b..9ee09054f5 100644 --- a/tests/unit/sagemaker/workflow/test_airflow.py +++ 
b/tests/unit/sagemaker/workflow/test_airflow.py @@ -42,17 +42,8 @@ def sagemaker_session(): session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session._default_bucket = BUCKET_NAME - # For the purposes of unit tests, no values should be fetched from sagemaker config - session.resolve_class_attribute_from_config = Mock( - name="resolve_class_attribute_from_config", - side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, - ) - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index 9b4e0d006a..97e7f17549 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -46,12 +46,8 @@ def sagemaker_session_mock(): session_mock = Mock() session_mock.default_bucket = Mock(name="default_bucket", return_value="s3_bucket") session_mock.local_mode = False - session_mock.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session_mock.sagemaker_config.config = {} session_mock._append_sagemaker_config_tags = Mock( name="_append_sagemaker_config_tags", side_effect=lambda tags, config_path_to_tags: tags ) @@ -74,10 +70,10 @@ def test_pipeline_create_and_update_without_role_arn(sagemaker_session_mock): def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock): - 
sagemaker_session_mock.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: "ConfigRoleArn", - ) + # Inject a pipeline role via sagemaker config so this test verifies config injection + sagemaker_session_mock.sagemaker_config.config = { + "SageMaker": {"Pipeline": {"RoleArn": "ConfigRoleArn"}} + } sagemaker_session_mock.sagemaker_client.describe_pipeline.return_value = { "PipelineArn": "pipeline-arn" } diff --git a/tests/unit/sagemaker/wrangler/test_processing.py b/tests/unit/sagemaker/wrangler/test_processing.py index 2f0d9a4875..d13292ad9e 100644 --- a/tests/unit/sagemaker/wrangler/test_processing.py +++ b/tests/unit/sagemaker/wrangler/test_processing.py @@ -41,17 +41,8 @@ def sagemaker_session(): ) session_mock.expand_role.return_value = ROLE - # For the purposes of unit tests, no values should be fetched from sagemaker config - session_mock.resolve_class_attribute_from_config = Mock( - name="resolve_class_attribute_from_config", - side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, - ) - session_mock.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session_mock.sagemaker_config.config = {} return session_mock diff --git a/tests/unit/test_algorithm.py b/tests/unit/test_algorithm.py index cc23cabfa6..fb443c6432 100644 --- a/tests/unit/test_algorithm.py +++ b/tests/unit/test_algorithm.py @@ -909,12 +909,6 @@ def test_algorithm_enable_network_isolation_with_product_id(session): @patch("sagemaker.Session") def test_algorithm_encrypt_inter_container_traffic(session): - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, 
config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) response["encrypt_inter_container_traffic"] = True @@ -954,12 +948,7 @@ def test_algorithm_no_required_hyperparameters(session): def test_algorithm_attach_from_hyperparameter_tuning(): session = Mock() - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + session.sagemaker_config.config = {} job_name = "training-job-that-is-part-of-a-tuning-job" algo_arn = "arn:aws:sagemaker:us-east-2:000000000000:algorithm/scikit-decision-trees" role_arn = "arn:aws:iam::123412341234:role/SageMakerRole" diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index 0d51c7261f..e8e71f97ee 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -74,12 +74,8 @@ def sagemaker_session(): sms.sagemaker_client.describe_training_job = Mock( name="describe_training_job", return_value=returned_job_description ) - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 16e7958821..7d9c24cf3d 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -74,13 +74,8 @@ def sagemaker_session(): session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) - # For the purposes of unit tests, no values should be fetched from sagemaker config 
- session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index cdec02e194..cfed30f076 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -238,13 +238,8 @@ def sagemaker_session(): sms.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) sms.upload_data = Mock(return_value=OUTPUT_PATH) - # For the purposes of unit tests, no values should be fetched from sagemaker config - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms @@ -1783,6 +1778,8 @@ def test_local_code_location(): local_mode=True, spec=sagemaker.local.LocalSession, ) + sms.sagemaker_config = Mock() + sms.sagemaker_config.config = {} t = DummyFramework( entry_point=SCRIPT_PATH, role=ROLE, @@ -3745,9 +3742,13 @@ def test_register_under_pipeline_session(pipeline_session): def test_local_mode(session_class, local_session_class): local_session = Mock(spec=sagemaker.local.LocalSession) local_session.local_mode = True + local_session.sagemaker_config = Mock() + local_session.sagemaker_config.config = {} session = Mock() session.local_mode = False + session.sagemaker_config = Mock() + session.sagemaker_config.config = {} local_session_class.return_value = local_session session_class.return_value = session @@ -3773,6 +3774,8 @@ def test_local_mode_file_output_path(local_session_class): local_session = 
Mock(spec=sagemaker.local.LocalSession) local_session.local_mode = True local_session_class.return_value = local_session + local_session.sagemaker_config = Mock() + local_session.sagemaker_config.config = {} e = Estimator(IMAGE_URI, ROLE, INSTANCE_COUNT, "local", output_path="file:///tmp/model/") assert e.output_path == "file:///tmp/model/" @@ -3999,12 +4002,8 @@ def test_estimator_local_mode_error(sagemaker_session): def test_estimator_local_mode_ok(sagemaker_local_session): - sagemaker_local_session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + sagemaker_local_session.sagemaker_config = Mock() + sagemaker_local_session.sagemaker_config.config = {} # When using instance local with a session which is not LocalSession we should error out Estimator( image_uri="some-image", diff --git a/tests/unit/test_fm.py b/tests/unit/test_fm.py index db0f5c7075..0f1cca2d87 100644 --- a/tests/unit/test_fm.py +++ b/tests/unit/test_fm.py @@ -68,12 +68,8 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_ipinsights.py b/tests/unit/test_ipinsights.py index 0d9d4a1e8f..4197ca77da 100644 --- a/tests/unit/test_ipinsights.py +++ b/tests/unit/test_ipinsights.py @@ -65,12 +65,8 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) 
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index f9dc9f1c40..5e0d7748bb 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -80,12 +80,8 @@ def sagemaker_session(): name="sagemaker_session", boto_session=boto_mock, s3_client=None, s3_resource=None ) mock_session.expand_role = Mock(name="expand_role", return_value=ROLE) - mock_session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + mock_session.sagemaker_config.config = {} return mock_session diff --git a/tests/unit/test_kmeans.py b/tests/unit/test_kmeans.py index 13ef923fc2..0ac457a6a7 100644 --- a/tests/unit/test_kmeans.py +++ b/tests/unit/test_kmeans.py @@ -62,12 +62,8 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_knn.py b/tests/unit/test_knn.py index 2d630dc5c5..828bd73cd0 
100644 --- a/tests/unit/test_knn.py +++ b/tests/unit/test_knn.py @@ -68,12 +68,8 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_lda.py b/tests/unit/test_lda.py index e1e9110923..8ef73047e5 100644 --- a/tests/unit/test_lda.py +++ b/tests/unit/test_lda.py @@ -57,12 +57,8 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_linear_learner.py b/tests/unit/test_linear_learner.py index d6a823674c..d055d20066 100644 --- a/tests/unit/test_linear_learner.py +++ b/tests/unit/test_linear_learner.py @@ -63,12 +63,8 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't 
verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_multidatamodel.py b/tests/unit/test_multidatamodel.py index 9777d1505d..c0a6bb7ed3 100644 --- a/tests/unit/test_multidatamodel.py +++ b/tests/unit/test_multidatamodel.py @@ -80,12 +80,8 @@ def sagemaker_session(): name="upload_data", return_value=os.path.join(VALID_MULTI_MODEL_DATA_PREFIX, "mleap_model.tar.gz"), ) - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} s3_mock = Mock() boto_mock.client("s3").return_value = s3_mock diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 0e3bad1dea..d6dcb339fb 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -101,13 +101,8 @@ def sagemaker_session(): session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) - # For the purposes of unit tests, no values should be fetched from sagemaker config - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/test_ntm.py b/tests/unit/test_ntm.py index 20aaa53590..93f9a13336 100644 --- a/tests/unit/test_ntm.py +++ b/tests/unit/test_ntm.py @@ -62,12 +62,8 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = 
Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_object2vec.py b/tests/unit/test_object2vec.py index f72bce977e..623a616c3f 100644 --- a/tests/unit/test_object2vec.py +++ b/tests/unit/test_object2vec.py @@ -70,12 +70,8 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_pca.py b/tests/unit/test_pca.py index 555aa508dc..f12a527f37 100644 --- a/tests/unit/test_pca.py +++ b/tests/unit/test_pca.py @@ -62,12 +62,8 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py index 7d4dd8389c..7f7ca3b00e 100644 --- 
a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -74,12 +74,8 @@ def sagemaker_session(): settings=SessionSettings(), ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index f3a7652bba..0149e17fcd 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -84,17 +84,8 @@ def sagemaker_session(): name="describe_processing_job", return_value=_get_describe_response_inputs_and_ouputs() ) - # For the purposes of unit tests, no values should be fetched from sagemaker config - session_mock.resolve_class_attribute_from_config = Mock( - name="resolve_class_attribute_from_config", - side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, - ) - session_mock.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session_mock.sagemaker_config.config = {} return session_mock @@ -117,19 +108,10 @@ def pipeline_session(): session_mock.describe_processing_job = MagicMock( name="describe_processing_job", return_value=_get_describe_response_inputs_and_ouputs() ) - session_mock.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) 
session_mock.__class__ = PipelineSession - # For the purposes of unit tests, no values should be fetched from sagemaker config - session_mock.resolve_class_attribute_from_config = Mock( - name="resolve_class_attribute_from_config", - side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, - ) + # For tests which doesn't verify config file injection, operate with empty config + session_mock.sagemaker_config.config = {} return session_mock diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index a1c9ffde3b..d31ddcd587 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -83,13 +83,8 @@ def fixture_sagemaker_session(): session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) - # For the purposes of unit tests, no values should be fetched from sagemaker config - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/test_randomcutforest.py b/tests/unit/test_randomcutforest.py index 31ab045ffe..e7662ce88a 100644 --- a/tests/unit/test_randomcutforest.py +++ b/tests/unit/test_randomcutforest.py @@ -62,12 +62,8 @@ def sagemaker_session(): ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with 
empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index f4369f09e6..c4b66d5490 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -76,13 +76,8 @@ def fixture_sagemaker_session(): session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) - # For the purposes of unit tests, no values should be fetched from sagemaker config - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 91750907c3..92147b7f17 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -4513,13 +4513,14 @@ def test_append_sagemaker_config_tags(sagemaker_session): def sort(tags): return tags.sort(key=lambda tag: tag["Key"]) - sagemaker_session.get_sagemaker_config_value = MagicMock( - return_value=[ - {"Key": "tagkey1", "Value": "tagvalue1"}, - {"Key": "tagkey2", "Value": "tagvalue2"}, - {"Key": "tagkey3", "Value": "tagvalue3"}, - ] - ) + config_tag_value = [ + {"Key": "tagkey1", "Value": "tagvalue1"}, + {"Key": "tagkey2", "Value": "tagvalue2"}, + {"Key": "tagkey3", "Value": "tagvalue3"}, + ] + + sagemaker_session.sagemaker_config = Mock() + sagemaker_session.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": config_tag_value}}} base_case = sagemaker_session._append_sagemaker_config_tags(tags_base, "DUMMY.CONFIG.PATH") assert sort(base_case) == sort( @@ -4561,7 +4562,9 @@ def sort(tags): ] ) - sagemaker_session.get_sagemaker_config_value = MagicMock(return_value=tags_none) + sagemaker_session.sagemaker_config.config = { + "DUMMY": 
{"CONFIG": {"OTHER_PATH": config_tag_value}} + } config_tags_none = sagemaker_session._append_sagemaker_config_tags( tags_base, "DUMMY.CONFIG.PATH" ) @@ -4572,7 +4575,7 @@ def sort(tags): ] ) - sagemaker_session.get_sagemaker_config_value = MagicMock(return_value=tags_empty) + sagemaker_session.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": config_tag_value}}} config_tags_empty = sagemaker_session._append_sagemaker_config_tags( tags_base, "DUMMY.CONFIG.PATH" ) @@ -4582,275 +4585,3 @@ def sort(tags): {"Key": "tagkey5", "Value": "000"}, ] ) - - -def test_resolve_value_from_config(sagemaker_session): - # using a shorter name for inside the test - ss = sagemaker_session - - # direct_input should be respected - ss.get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") - assert ss.resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", "DEFAULT_VALUE") == "INPUT" - - ss.get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") - assert ss.resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", None) == "INPUT" - - ss.get_sagemaker_config_value = MagicMock(return_value=None) - assert ss.resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", None) == "INPUT" - - # Config or default values should be returned if no direct_input - ss.get_sagemaker_config_value = MagicMock(return_value=None) - assert ss.resolve_value_from_config(None, None, "DEFAULT_VALUE") == "DEFAULT_VALUE" - - ss.get_sagemaker_config_value = MagicMock(return_value=None) - assert ( - ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", "DEFAULT_VALUE") == "DEFAULT_VALUE" - ) - - ss.get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") - assert ( - ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", "DEFAULT_VALUE") == "CONFIG_VALUE" - ) - - ss.get_sagemaker_config_value = MagicMock(return_value=None) - assert ss.resolve_value_from_config(None, None, None) is None - - # Different falsy direct_inputs - ss.get_sagemaker_config_value = 
MagicMock(return_value=None) - assert ss.resolve_value_from_config("", "DUMMY.CONFIG.PATH", None) == "" - - ss.get_sagemaker_config_value = MagicMock(return_value=None) - assert ss.resolve_value_from_config([], "DUMMY.CONFIG.PATH", None) == [] - - ss.get_sagemaker_config_value = MagicMock(return_value=None) - assert ss.resolve_value_from_config(False, "DUMMY.CONFIG.PATH", None) is False - - ss.get_sagemaker_config_value = MagicMock(return_value=None) - assert ss.resolve_value_from_config({}, "DUMMY.CONFIG.PATH", None) == {} - - # Different falsy config_values - ss.get_sagemaker_config_value = MagicMock(return_value="") - assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) == "" - - ss.get_sagemaker_config_value = MagicMock(return_value=[]) - assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) == [] - - ss.get_sagemaker_config_value = MagicMock(return_value=False) - assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) is False - - ss.get_sagemaker_config_value = MagicMock(return_value={}) - assert ss.resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None) == {} - - -@pytest.mark.parametrize( - "existing_value, config_value, default_value", - [ - ("EXISTING_VALUE", "CONFIG_VALUE", "DEFAULT_VALUE"), - (False, True, False), - (False, False, True), - (0, 1, 2), - ], -) -def test_resolve_class_attribute_from_config( - sagemaker_session, existing_value, config_value, default_value -): - # using a shorter name for inside the test - ss = sagemaker_session - - class TestClass(object): - def __init__(self, test_attribute=None, extra=None): - self.test_attribute = test_attribute - # the presence of an extra value that is set to None by default helps make sure a brand new - # TestClass object is being created only in the right scenarios - self.extra_attribute = extra - - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.__dict__ == other.__dict__ - else: - return False - - dummy_config_path = 
["DUMMY", "CONFIG", "PATH"] - - # with an existing config value - ss.get_sagemaker_config_value = MagicMock(return_value=config_value) - - # instance exists and has value; config has value - test_instance = TestClass(test_attribute=existing_value, extra="EXTRA_VALUE") - assert ss.resolve_class_attribute_from_config( - TestClass, test_instance, "test_attribute", dummy_config_path - ) == TestClass(test_attribute=existing_value, extra="EXTRA_VALUE") - - # instance exists but doesnt have value; config has value - test_instance = TestClass(extra="EXTRA_VALUE") - assert ss.resolve_class_attribute_from_config( - TestClass, test_instance, "test_attribute", dummy_config_path - ) == TestClass(test_attribute=config_value, extra="EXTRA_VALUE") - - # instance doesnt exist; config has value - test_instance = None - assert ss.resolve_class_attribute_from_config( - TestClass, test_instance, "test_attribute", dummy_config_path - ) == TestClass(test_attribute=config_value, extra=None) - - # wrong attribute used - test_instance = TestClass() - with pytest.raises(TypeError): - ss.resolve_class_attribute_from_config( - TestClass, test_instance, "other_attribute", dummy_config_path - ) - - # instance doesnt exist; clazz doesnt exist - test_instance = None - assert ( - ss.resolve_class_attribute_from_config( - None, test_instance, "test_attribute", dummy_config_path - ) - is None - ) - - # instance doesnt exist; clazz isnt a class - test_instance = None - assert ( - ss.resolve_class_attribute_from_config( - "CLASS", test_instance, "test_attribute", dummy_config_path - ) - is None - ) - - # without an existing config value - ss.get_sagemaker_config_value = MagicMock(return_value=None) - - # instance exists but doesnt have value; config doesnt have value - test_instance = TestClass(extra="EXTRA_VALUE") - assert ss.resolve_class_attribute_from_config( - TestClass, test_instance, "test_attribute", dummy_config_path - ) == TestClass(test_attribute=None, extra="EXTRA_VALUE") - - # instance 
exists but doesnt have value; config doesnt have value; default_value passed in - test_instance = TestClass(extra="EXTRA_VALUE") - assert ss.resolve_class_attribute_from_config( - TestClass, test_instance, "test_attribute", dummy_config_path, default_value=default_value - ) == TestClass(test_attribute=default_value, extra="EXTRA_VALUE") - - # instance doesnt exist; config doesnt have value - test_instance = None - assert ( - ss.resolve_class_attribute_from_config( - TestClass, test_instance, "test_attribute", dummy_config_path - ) - is None - ) - - # instance doesnt exist; config doesnt have value; default_value passed in - test_instance = None - assert ss.resolve_class_attribute_from_config( - TestClass, test_instance, "test_attribute", dummy_config_path, default_value=default_value - ) == TestClass(test_attribute=default_value, extra=None) - - -def test_resolve_nested_dict_value_from_config(sagemaker_session): - # using a shorter name for inside the test - ss = sagemaker_session - - dummy_config_path = ["DUMMY", "CONFIG", "PATH"] - - # with an existing config value - ss.get_sagemaker_config_value = MagicMock(return_value="CONFIG_VALUE") - - # happy cases: return existing dict with existing values - assert ss.resolve_nested_dict_value_from_config( - {"local": {"region_name": "us-west-2", "port": "123"}}, - ["local", "region_name"], - dummy_config_path, - default_value="DEFAULT_VALUE", - ) == {"local": {"region_name": "us-west-2", "port": "123"}} - assert ss.resolve_nested_dict_value_from_config( - {"local": {"region_name": "us-west-2", "port": "123"}}, - ["local", "region_name"], - dummy_config_path, - default_value=None, - ) == {"local": {"region_name": "us-west-2", "port": "123"}} - - # happy case: return dict with config_value when it wasnt set in dict or was None - assert ss.resolve_nested_dict_value_from_config( - {"local": {"port": "123"}}, - ["local", "region_name"], - dummy_config_path, - default_value="DEFAULT_VALUE", - ) == {"local": {"region_name": 
"CONFIG_VALUE", "port": "123"}} - assert ss.resolve_nested_dict_value_from_config( - {}, ["local", "region_name"], dummy_config_path, default_value=None - ) == {"local": {"region_name": "CONFIG_VALUE"}} - assert ss.resolve_nested_dict_value_from_config( - None, ["local", "region_name"], dummy_config_path, default_value=None - ) == {"local": {"region_name": "CONFIG_VALUE"}} - assert ss.resolve_nested_dict_value_from_config( - { - "local": {"region_name": "us-west-2", "port": "123"}, - "other": {"key": 1}, - "nest1": {"nest2": {"nest3": {"nest4a": "value", "nest4b": None}}}, - }, - ["nest1", "nest2", "nest3", "nest4b", "does_not", "exist"], - dummy_config_path, - default_value="DEFAULT_VALUE", - ) == { - "local": {"region_name": "us-west-2", "port": "123"}, - "other": {"key": 1}, - "nest1": { - "nest2": { - "nest3": {"nest4a": "value", "nest4b": {"does_not": {"exist": "CONFIG_VALUE"}}} - } - }, - } - - # edge case: doesnt overwrite non-None and non-dict values - dictionary = { - "local": {"region_name": "us-west-2", "port": "123"}, - "other": {"key": 1}, - "nest1": {"nest2": {"nest3": {"nest4a": "value", "nest4b": None}}}, - } - dictionary_copy = copy.deepcopy(dictionary) - assert ( - ss.resolve_nested_dict_value_from_config( - dictionary, - ["nest1", "nest2", "nest3", "nest4a", "does_not", "exist"], - dummy_config_path, - default_value="DEFAULT_VALUE", - ) - == dictionary_copy - ) - assert ( - ss.resolve_nested_dict_value_from_config( - dictionary, ["other", "key"], dummy_config_path, default_value="DEFAULT_VALUE" - ) - == dictionary_copy - ) - - # without an existing config value - ss.get_sagemaker_config_value = MagicMock(return_value=None) - - # happy case: return dict with default_value when it wasnt set in dict and in config - assert ss.resolve_nested_dict_value_from_config( - {"local": {"port": "123"}}, - ["local", "region_name"], - dummy_config_path, - default_value="DEFAULT_VALUE", - ) == {"local": {"region_name": "DEFAULT_VALUE", "port": "123"}} - - # happy 
case: return dict as-is when value wasnt set in dict, in config, and as default - assert ss.resolve_nested_dict_value_from_config( - {"local": {"port": "123"}}, ["local", "region_name"], dummy_config_path, default_value=None - ) == {"local": {"port": "123"}} - assert ( - ss.resolve_nested_dict_value_from_config( - {}, ["local", "region_name"], dummy_config_path, default_value=None - ) - == {} - ) - assert ( - ss.resolve_nested_dict_value_from_config( - None, ["local", "region_name"], dummy_config_path, default_value=None - ) - is None - ) diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index 7095095a2d..d7d9f45c48 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -78,13 +78,8 @@ def sagemaker_session(): session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) - # For the purposes of unit tests, no values should be fetched from sagemaker config - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/test_sparkml_serving.py b/tests/unit/test_sparkml_serving.py index 3fb21d62d2..f81ba3dde0 100644 --- a/tests/unit/test_sparkml_serving.py +++ b/tests/unit/test_sparkml_serving.py @@ -46,6 +46,7 @@ def sagemaker_session(): sms.boto_region_name = REGION sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_timeout.py b/tests/unit/test_timeout.py index bded6ce2cc..09d2669876 100644 --- a/tests/unit/test_timeout.py +++ 
b/tests/unit/test_timeout.py @@ -60,6 +60,8 @@ def session(): settings=SessionSettings(), ) sms.default_bucket = Mock(name=DEFAULT_BUCKET_NAME, return_value=BUCKET_NAME) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index 8fadcab52f..240c894fa2 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -66,18 +66,8 @@ def mock_create_tar_file(): def sagemaker_session(): boto_mock = Mock(name="boto_session") session = Mock(name="sagemaker_session", boto_session=boto_mock, local_mode=False) - - # For the purposes of unit tests, no values should be fetched from sagemaker config - session.resolve_class_attribute_from_config = Mock( - name="resolve_class_attribute_from_config", - side_effect=lambda clazz, instance, attribute, config_path, default_value=None: instance, - ) - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 1d6599c0ff..e6d50eeee9 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -66,13 +66,8 @@ def sagemaker_session(): sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - # For the purposes of unit tests, no values should be fetched from sagemaker config - sms.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else 
default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + sms.sagemaker_config.config = {} return sms diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 2e57328339..0a539e9fe5 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -31,7 +31,13 @@ import sagemaker from sagemaker.experiments._run_context import _RunContext from sagemaker.session_settings import SessionSettings -from sagemaker.utils import retry_with_backoff, check_and_get_run_experiment_config +from sagemaker.utils import ( + retry_with_backoff, + check_and_get_run_experiment_config, + resolve_value_from_config, + resolve_class_attribute_from_config, + resolve_nested_dict_value_from_config, +) from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -979,6 +985,310 @@ def test_retry_with_backoff(patched_sleep): assert retry_with_backoff(callable_func, 2) == func_return_val +def test_resolve_value_from_config(): + # using a shorter name for inside the test + sagemaker_session = MagicMock() + sagemaker_session.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": "CONFIG_VALUE"}}} + + # direct_input should be respected + assert ( + resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", "DEFAULT_VALUE", sagemaker_session) + == "INPUT" + ) + + assert ( + resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", None, sagemaker_session) == "INPUT" + ) + + assert ( + resolve_value_from_config("INPUT", "DUMMY.CONFIG.INVALID_PATH", None, sagemaker_session) + == "INPUT" + ) + + # Config or default values should be returned if no direct_input + assert ( + resolve_value_from_config(None, None, "DEFAULT_VALUE", sagemaker_session) == "DEFAULT_VALUE" + ) + + assert ( + resolve_value_from_config( + None, "DUMMY.CONFIG.INVALID_PATH", "DEFAULT_VALUE", sagemaker_session + ) + == "DEFAULT_VALUE" + ) + + assert ( + resolve_value_from_config(None, 
"DUMMY.CONFIG.PATH", "DEFAULT_VALUE", sagemaker_session) + == "CONFIG_VALUE" + ) + + assert resolve_value_from_config(None, None, None, sagemaker_session) is None + + # Different falsy direct_inputs + assert resolve_value_from_config("", "DUMMY.CONFIG.PATH", None, sagemaker_session) == "" + + assert resolve_value_from_config([], "DUMMY.CONFIG.PATH", None, sagemaker_session) == [] + + assert resolve_value_from_config(False, "DUMMY.CONFIG.PATH", None, sagemaker_session) is False + + assert resolve_value_from_config({}, "DUMMY.CONFIG.PATH", None, sagemaker_session) == {} + + # Different falsy config_values + sagemaker_session.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": ""}}} + assert resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None, sagemaker_session) == "" + + sagemaker_session.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": []}}} + assert resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None, sagemaker_session) == [] + + sagemaker_session.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": False}}} + assert resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None, sagemaker_session) is False + + sagemaker_session.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": {}}}} + assert resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None, sagemaker_session) == {} + + +@pytest.mark.parametrize( + "existing_value, config_value, default_value", + [ + ("EXISTING_VALUE", "CONFIG_VALUE", "DEFAULT_VALUE"), + (False, True, False), + (False, False, True), + (0, 1, 2), + ], +) +def test_resolve_class_attribute_from_config(existing_value, config_value, default_value): + # using a shorter name for inside the test + ss = MagicMock() + + class TestClass(object): + def __init__(self, test_attribute=None, extra=None): + self.test_attribute = test_attribute + # the presence of an extra value that is set to None by default helps make sure a brand new + # TestClass object is being created only in the right scenarios + self.extra_attribute = 
extra + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.__dict__ == other.__dict__ + else: + return False + + dummy_config_path = "DUMMY.CONFIG.PATH" + + # with an existing config value + ss.sagemaker_config = Mock() + ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": config_value}}} + + # instance exists and has value; config has value + test_instance = TestClass(test_attribute=existing_value, extra="EXTRA_VALUE") + assert resolve_class_attribute_from_config( + TestClass, test_instance, "test_attribute", dummy_config_path, sagemaker_session=ss + ) == TestClass(test_attribute=existing_value, extra="EXTRA_VALUE") + + # instance exists but doesnt have value; config has value + test_instance = TestClass(extra="EXTRA_VALUE") + assert resolve_class_attribute_from_config( + TestClass, test_instance, "test_attribute", dummy_config_path, sagemaker_session=ss + ) == TestClass(test_attribute=config_value, extra="EXTRA_VALUE") + + # instance doesnt exist; config has value + test_instance = None + assert resolve_class_attribute_from_config( + TestClass, test_instance, "test_attribute", dummy_config_path, sagemaker_session=ss + ) == TestClass(test_attribute=config_value, extra=None) + + # wrong attribute used + test_instance = TestClass() + with pytest.raises(TypeError): + resolve_class_attribute_from_config( + TestClass, test_instance, "other_attribute", dummy_config_path, sagemaker_session=ss + ) + + # instance doesnt exist; clazz doesnt exist + test_instance = None + assert ( + resolve_class_attribute_from_config( + None, test_instance, "test_attribute", dummy_config_path, sagemaker_session=ss + ) + is None + ) + + # instance doesnt exist; clazz isnt a class + test_instance = None + assert ( + resolve_class_attribute_from_config( + "CLASS", test_instance, "test_attribute", dummy_config_path, sagemaker_session=ss + ) + is None + ) + + # without an existing config value + ss.sagemaker_config.config = {"DUMMY": {"CONFIG": 
{"SOMEOTHERPATH": config_value}}} + # instance exists but doesnt have value; config doesnt have value + test_instance = TestClass(extra="EXTRA_VALUE") + assert resolve_class_attribute_from_config( + TestClass, test_instance, "test_attribute", dummy_config_path, sagemaker_session=ss + ) == TestClass(test_attribute=None, extra="EXTRA_VALUE") + + # instance exists but doesnt have value; config doesnt have value; default_value passed in + test_instance = TestClass(extra="EXTRA_VALUE") + assert resolve_class_attribute_from_config( + TestClass, + test_instance, + "test_attribute", + dummy_config_path, + default_value=default_value, + sagemaker_session=ss, + ) == TestClass(test_attribute=default_value, extra="EXTRA_VALUE") + + # instance doesnt exist; config doesnt have value + test_instance = None + assert ( + resolve_class_attribute_from_config( + TestClass, test_instance, "test_attribute", dummy_config_path, sagemaker_session=ss + ) + is None + ) + + # instance doesnt exist; config doesnt have value; default_value passed in + test_instance = None + assert resolve_class_attribute_from_config( + TestClass, + test_instance, + "test_attribute", + dummy_config_path, + default_value=default_value, + sagemaker_session=ss, + ) == TestClass(test_attribute=default_value, extra=None) + + +def test_resolve_nested_dict_value_from_config(): + # using a shorter name for inside the test + ss = MagicMock() + + dummy_config_path = "DUMMY.CONFIG.PATH" + # happy cases: return existing dict with existing values + assert resolve_nested_dict_value_from_config( + {"local": {"region_name": "us-west-2", "port": "123"}}, + ["local", "region_name"], + dummy_config_path, + default_value="DEFAULT_VALUE", + sagemaker_session=ss, + ) == {"local": {"region_name": "us-west-2", "port": "123"}} + assert resolve_nested_dict_value_from_config( + {"local": {"region_name": "us-west-2", "port": "123"}}, + ["local", "region_name"], + dummy_config_path, + default_value=None, + sagemaker_session=ss, + ) == 
{"local": {"region_name": "us-west-2", "port": "123"}} + + # happy case: return dict with config_value when it wasnt set in dict or was None + ss.sagemaker_config = Mock() + ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": "CONFIG_VALUE"}}} + assert resolve_nested_dict_value_from_config( + {"local": {"port": "123"}}, + ["local", "region_name"], + dummy_config_path, + default_value="DEFAULT_VALUE", + sagemaker_session=ss, + ) == {"local": {"region_name": "CONFIG_VALUE", "port": "123"}} + assert resolve_nested_dict_value_from_config( + {}, ["local", "region_name"], dummy_config_path, default_value=None, sagemaker_session=ss + ) == {"local": {"region_name": "CONFIG_VALUE"}} + assert resolve_nested_dict_value_from_config( + None, ["local", "region_name"], dummy_config_path, default_value=None, sagemaker_session=ss + ) == {"local": {"region_name": "CONFIG_VALUE"}} + assert resolve_nested_dict_value_from_config( + { + "local": {"region_name": "us-west-2", "port": "123"}, + "other": {"key": 1}, + "nest1": {"nest2": {"nest3": {"nest4a": "value", "nest4b": None}}}, + }, + ["nest1", "nest2", "nest3", "nest4b", "does_not", "exist"], + dummy_config_path, + default_value="DEFAULT_VALUE", + sagemaker_session=ss, + ) == { + "local": {"region_name": "us-west-2", "port": "123"}, + "other": {"key": 1}, + "nest1": { + "nest2": { + "nest3": {"nest4a": "value", "nest4b": {"does_not": {"exist": "CONFIG_VALUE"}}} + } + }, + } + + # edge case: doesnt overwrite non-None and non-dict values + dictionary = { + "local": {"region_name": "us-west-2", "port": "123"}, + "other": {"key": 1}, + "nest1": {"nest2": {"nest3": {"nest4a": "value", "nest4b": None}}}, + } + dictionary_copy = copy.deepcopy(dictionary) + assert ( + resolve_nested_dict_value_from_config( + dictionary, + ["nest1", "nest2", "nest3", "nest4a", "does_not", "exist"], + dummy_config_path, + default_value="DEFAULT_VALUE", + sagemaker_session=ss, + ) + == dictionary_copy + ) + assert ( + 
resolve_nested_dict_value_from_config( + dictionary, + ["other", "key"], + dummy_config_path, + default_value="DEFAULT_VALUE", + sagemaker_session=ss, + ) + == dictionary_copy + ) + + # without an existing config value + ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"ANOTHER_PATH": "CONFIG_VALUE"}}} + + # happy case: return dict with default_value when it wasnt set in dict and in config + assert resolve_nested_dict_value_from_config( + {"local": {"port": "123"}}, + ["local", "region_name"], + dummy_config_path, + default_value="DEFAULT_VALUE", + sagemaker_session=ss, + ) == {"local": {"region_name": "DEFAULT_VALUE", "port": "123"}} + + # happy case: return dict as-is when value wasnt set in dict, in config, and as default + assert resolve_nested_dict_value_from_config( + {"local": {"port": "123"}}, + ["local", "region_name"], + dummy_config_path, + default_value=None, + sagemaker_session=ss, + ) == {"local": {"port": "123"}} + assert ( + resolve_nested_dict_value_from_config( + {}, + ["local", "region_name"], + dummy_config_path, + default_value=None, + sagemaker_session=ss, + ) + == {} + ) + assert ( + resolve_nested_dict_value_from_config( + None, + ["local", "region_name"], + dummy_config_path, + default_value=None, + sagemaker_session=ss, + ) + is None + ) + + def test_check_and_get_run_experiment_config(): supplied_exp_cfg = {"ExperimentName": "my-supplied-exp-name", "RunName": "my-supplied-run-name"} run_exp_cfg = {"ExperimentName": "my-run-exp-name", "RunName": "my-run-run-name"} diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 2c35ad8584..76ad51896f 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -80,14 +80,8 @@ def sagemaker_session(): session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) - - # For the purposes of unit tests, no values 
should be fetched from sagemaker config - session.resolve_value_from_config = Mock( - name="resolve_value_from_config", - side_effect=lambda direct_input=None, config_path=None, default_value=None: direct_input - if direct_input is not None - else default_value, - ) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config.config = {} return session diff --git a/tests/unit/tuner_test_utils.py b/tests/unit/tuner_test_utils.py index 8e12f11e8d..215cc80d26 100644 --- a/tests/unit/tuner_test_utils.py +++ b/tests/unit/tuner_test_utils.py @@ -72,6 +72,8 @@ ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"} SAGEMAKER_SESSION = Mock() +# For tests which doesn't verify config file injection, operate with empty config +SAGEMAKER_SESSION.sagemaker_config.config = {} ESTIMATOR = Estimator( IMAGE_NAME, From f02131b60b9038d178271f84cf328253e8de82ce Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Tue, 21 Mar 2023 12:52:16 -0700 Subject: [PATCH 27/40] change: Add a separate helper to merge list of objects --- src/sagemaker/session.py | 143 +++++++++++-------------------------- src/sagemaker/utils.py | 95 ++++++++++++++++++++++++ tests/unit/test_session.py | 53 +++++++++----- tests/unit/test_utils.py | 122 +++++++++++++++++++++++++++++++ 4 files changed, 294 insertions(+), 119 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 401f5c15de..2d48eac179 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -13,7 +13,6 @@ """Placeholder docstring""" from __future__ import absolute_import, print_function -import copy import json import logging import os @@ -28,7 +27,6 @@ import boto3 import botocore import botocore.config -from botocore.utils import merge_dicts from botocore.exceptions import ClientError import six @@ -109,6 +107,7 @@ resolve_class_attribute_from_config, resolve_nested_dict_value_from_config, 
update_nested_dictionary_with_values_from_config, + update_list_of_dicts_with_values_from_config, ) from sagemaker import exceptions from sagemaker.session_settings import SessionSettings @@ -1138,39 +1137,6 @@ def _get_update_training_job_request( return update_training_job_request - # TODO: unit tests or make a more generic version - def _update_processing_input_from_config(self, inputs): - """Updates Processor Inputs to fetch values from SageMakerConfig wherever applicable. - - Args: - inputs (list[dict]): A list of Processing Input objects. - - """ - inputs_copy = copy.deepcopy(inputs) - processing_inputs_from_config = resolve_value_from_config( - config_path=PROCESSING_JOB_INPUTS_PATH, default_value=[], sagemaker_session=self - ) - for i in range(min(len(inputs), len(processing_inputs_from_config))): - dict_from_inputs = inputs[i] - dict_from_config = processing_inputs_from_config[i] - - # The Dataset Definition input must specify exactly one of either - # AthenaDatasetDefinition or RedshiftDatasetDefinition types (source: API reference). - # So to prevent API failure because of sagemaker_config, we will only populate from the - # config for the ones already present in dict_from_inputs. - # If BOTH are present, we will still add to both and let the API call fail as it would - # have even without injection from sagemaker_config. 
- merge_dicts(dict_from_config, dict_from_inputs) - inputs[i] = dict_from_config - if processing_inputs_from_config: - print( - "[Sagemaker Config - applied value]\n", - "config key = {}\n".format(PROCESSING_JOB_INPUTS_PATH), - "config value = {}\n".format(processing_inputs_from_config), - "source value = {}\n".format(inputs_copy), - "combined value that will be used = {}\n".format(inputs), - ) - def process( self, inputs, @@ -1235,8 +1201,20 @@ def process( PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, sagemaker_session=self, ) - - self._update_processing_input_from_config(inputs) + # Processing Input can either have AthenaDatasetDefinition or RedshiftDatasetDefinition + # or neither, but not both + union_key_paths_for_dataset_definition = [ + [ + "DatasetDefinition.AthenaDatasetDefinition", + "DatasetDefinition.RedshiftDatasetDefinition", + ] + ] + update_list_of_dicts_with_values_from_config( + inputs, + PROCESSING_JOB_INPUTS_PATH, + union_key_paths=union_key_paths_for_dataset_definition, + sagemaker_session=self, + ) role_arn = resolve_value_from_config( role_arn, PROCESSING_JOB_ROLE_ARN_PATH, sagemaker_session=self ) @@ -3567,28 +3545,13 @@ def create_model_package_from_containers( sagemaker_session=self, ) validation_specification[VALIDATION_ROLE] = validation_role - validation_profiles_from_config = resolve_value_from_config( - config_path=MODEL_PACKAGE_VALIDATION_PROFILES_PATH, - default_value=[], + validation_profiles = validation_specification.get(VALIDATION_PROFILES, []) + update_list_of_dicts_with_values_from_config( + validation_profiles, + MODEL_PACKAGE_VALIDATION_PROFILES_PATH, + required_key_paths=["ProfileName", "TransformJobDefinition"], sagemaker_session=self, ) - validation_profiles = validation_specification.get(VALIDATION_PROFILES, []) - for i in range(min(len(validation_profiles), len(validation_profiles_from_config))): - original_config_dict_value = copy.deepcopy(validation_profiles_from_config[i]) - 
merge_dicts(validation_profiles_from_config[i], validation_profiles[i]) - if validation_profiles[i] != validation_profiles_from_config[i]: - print( - "Config value {} at config path {} was fetched first for " - "index {}.".format( - original_config_dict_value, - MODEL_PACKAGE_VALIDATION_PROFILES_PATH, - i, - ), - "It was then merged with the existing value {} to give {}".format( - validation_profiles[i], validation_profiles_from_config[i] - ), - ) - validation_profiles[i].update(validation_profiles_from_config[i]) model_pkg_request = get_create_model_package_request( model_package_name, model_package_group_name, @@ -3740,32 +3703,20 @@ def create_endpoint_config( model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, ) - inferred_production_variants_from_config = resolve_value_from_config( - config_path=ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, - default_value=[], + production_variants = [provided_production_variant] + # Currently we just inject CoreDumpConfig.KmsKeyId from the config for production variant. + # But if that parameter is injected, then CoreDumpConfig.DestinationS3Uri needs to be + # present. + # But SageMaker Python SDK doesn't support CoreDumpConfig.DestinationS3Uri. 
+ update_list_of_dicts_with_values_from_config( + production_variants, + ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, + required_key_paths=["CoreDumpConfig.DestinationS3Uri"], sagemaker_session=self, ) - if inferred_production_variants_from_config: - inferred_production_variant_from_config = ( - inferred_production_variants_from_config[0] or {} - ) - original_config_dict_value = inferred_production_variant_from_config.copy() - merge_dicts(inferred_production_variant_from_config, provided_production_variant) - if provided_production_variant != inferred_production_variant_from_config: - print( - "Config value {} at config path {} was fetched first for " - "index: 0.".format( - original_config_dict_value, ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH - ), - "It was then merged with the existing value {} to give {}".format( - provided_production_variant, inferred_production_variant_from_config - ), - ) - provided_production_variant.update(inferred_production_variant_from_config) - request = { "EndpointConfigName": name, - "ProductionVariants": [provided_production_variant], + "ProductionVariants": production_variants, } tags = _append_project_tags(tags) @@ -3841,28 +3792,12 @@ def create_endpoint_config_from_existing( production_variants = ( new_production_variants or existing_endpoint_config_desc["ProductionVariants"] ) - if production_variants: - inferred_production_variants_from_config = resolve_value_from_config( - config_path=ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, - default_value=[], - sagemaker_session=self, - ) - for i in range( - min(len(production_variants), len(inferred_production_variants_from_config)) - ): - original_config_dict_value = inferred_production_variants_from_config[i].copy() - merge_dicts(inferred_production_variants_from_config[i], production_variants[i]) - if production_variants[i] != inferred_production_variants_from_config[i]: - print( - "Config value {} at config path {} was fetched first for " - "index: 0.".format( - original_config_dict_value, 
ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH - ), - "It was then merged with the existing value {} to give {}".format( - production_variants[i], inferred_production_variants_from_config[i] - ), - ) - production_variants[i].update(inferred_production_variants_from_config[i]) + update_list_of_dicts_with_values_from_config( + production_variants, + ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, + required_key_paths=["CoreDumpConfig.DestinationS3Uri"], + sagemaker_session=self, + ) request["ProductionVariants"] = production_variants request_tags = new_tags or self.list_tags( @@ -4443,6 +4378,12 @@ def endpoint_from_production_variants( Returns: str: The name of the created ``Endpoint``. """ + update_list_of_dicts_with_values_from_config( + production_variants, + ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, + required_key_paths=["CoreDumpConfig.DestinationS3Uri"], + sagemaker_session=self, + ) config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants} kms_key = resolve_value_from_config( kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 1073d72b54..d349035858 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1253,6 +1253,101 @@ def resolve_nested_dict_value_from_config( return dictionary +def update_list_of_dicts_with_values_from_config( + input_list, + config_key_path, + required_key_paths: List[str] = None, + union_key_paths: List[List[str]] = None, + sagemaker_session=None, +): + """Helper method for updating Lists with corresponding values in the Config + + In some cases, config file might introduce new parameters which requires certain other + parameters to be provided as part of the input list. Without those parameters, the underlying + service will throw an exception. This method provides the capability to specify required key + paths. 
+ + In some other cases, config file might introduce new parameters but the service API requires + either an existing parameter or the new parameter that was supplied by config but not both + + Args: + input_list: The input list that was provided as a method parameter. + config_key_path: The Key Path in the Config file that corresponds to the input_list + parameter. + required_key_paths (List[str]): List of required key paths that should be verified in the + merged output. If a required key path is missing, we will not perform the merge for that + item. + union_key_paths (List[List[str]]): List of List of Key paths for which we need to verify + whether exactly zero/one of the parameters exist. + For example: If the resultant dictionary can have either 'X1' or 'X2' as parameter but + not both, then pass [['X1', 'X2']] + sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for + SageMaker interactions (default: None). + + Returns: + No output. In place merge happens. + """ + if not input_list: + return + inputs_copy = copy.deepcopy(input_list) + inputs_from_config = resolve_value_from_config( + config_path=config_key_path, default_value=[], sagemaker_session=sagemaker_session + ) + for i in range(min(len(input_list), len(inputs_from_config))): + dict_from_inputs = input_list[i] + dict_from_config = inputs_from_config[i] + merge_dicts(dict_from_config, dict_from_inputs) + # Check if required key paths are present in merged dict (dict_from_config) + required_key_path_check_failed = _validate_required_paths_in_a_dict( + dict_from_config, required_key_paths + ) + if required_key_path_check_failed: + # Don't do the merge, config is introducing a new parameter which needs a + # corresponding required parameter. + continue + union_key_path_check_failed = _validate_union_key_paths_in_a_dict( + dict_from_config, union_key_paths + ) + if union_key_path_check_failed: + # Don't do the merge, Union parameters are not obeyed. 
+ continue + input_list[i] = dict_from_config + if inputs_from_config: + print( + "[Sagemaker Config - applied value]\n", + "config key = {}\n".format(config_key_path), + "config value = {}\n".format(inputs_from_config), + "source value = {}\n".format(inputs_copy), + "combined value that will be used = {}\n".format(input_list), + ) + + +def _validate_required_paths_in_a_dict(source_dict, required_key_paths: List[str] = None) -> bool: + """Placeholder docstring""" + if not required_key_paths: + return False + for required_key_path in required_key_paths: + if get_config_value(required_key_path, source_dict) is None: + return True + return False + + +def _validate_union_key_paths_in_a_dict( + source_dict, union_key_paths: List[List[str]] = None +) -> bool: + """Placeholder docstring""" + if not union_key_paths: + return False + for union_key_path in union_key_paths: + union_parameter_present = False + for key_path in union_key_path: + if get_config_value(key_path, source_dict): + if union_parameter_present: + return True + union_parameter_present = True + return False + + def update_nested_dictionary_with_values_from_config( source_dict, config_key_path, sagemaker_session=None ) -> dict: diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 92147b7f17..1ac8d2782f 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -263,7 +263,13 @@ def test_process(boto_session): def test_create_process_with_sagemaker_config_injection(sagemaker_config_session): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_PROCESSING_JOB + processing_job_config = copy.deepcopy(SAGEMAKER_CONFIG_PROCESSING_JOB) + # deleting RedshiftDatasetDefinition. 
API can take either RedshiftDatasetDefinition or + # AthenaDatasetDefinition + del processing_job_config["SageMaker"]["ProcessingJob"]["ProcessingInputs"][0][ + "DatasetDefinition" + ]["RedshiftDatasetDefinition"] + sagemaker_config_session.sagemaker_config.config = processing_job_config processing_inputs = [ { @@ -276,13 +282,8 @@ def test_create_process_with_sagemaker_config_injection(sagemaker_config_session "S3DataDistributionType": "FullyReplicated", "S3CompressionType": "None", }, - # DatasetDefinition and (AthenaDatasetDefinition or RedshiftDatasetDefinition) need - # to be present for config injection. Included both AthenaDatasetDefinition and - # RedshiftDatasetDefinition at the same time to test injection (even though this API - # call would fail normally) "DatasetDefinition": { "AthenaDatasetDefinition": {}, - "RedshiftDatasetDefinition": {}, }, } ] @@ -368,7 +369,7 @@ def test_create_process_with_sagemaker_config_injection(sagemaker_config_session "ExperimentConfig": {"ExperimentName": "AnExperiment"}, } ) - expected_request["ProcessingInputs"][0]["DatasetDefinition"] = SAGEMAKER_CONFIG_PROCESSING_JOB[ + expected_request["ProcessingInputs"][0]["DatasetDefinition"] = processing_job_config[ "SageMaker" ]["ProcessingJob"]["ProcessingInputs"][0]["DatasetDefinition"] expected_request["ProcessingOutputConfig"]["KmsKeyId"] = expected_output_kms_key_id @@ -2520,9 +2521,6 @@ def test_create_endpoint_config_with_sagemaker_config_injection(sagemaker_config "local", data_capture_config_dict=data_capture_config_dict, ) - expected_production_variant_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ - "EndpointConfig" - ]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"] expected_data_capture_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ "EndpointConfig" ]["DataCaptureConfig"]["KmsKeyId"] @@ -2535,7 +2533,6 @@ def test_create_endpoint_config_with_sagemaker_config_injection(sagemaker_config EndpointConfigName="endpoint-test", 
ProductionVariants=[ { - "CoreDumpConfig": {"KmsKeyId": expected_production_variant_kms_key_id}, "ModelName": "simple-model", "VariantName": "AllTraffic", "InitialVariantWeight": 1, @@ -2562,6 +2559,8 @@ def test_create_endpoint_config_from_existing_with_sagemaker_config_injection( sagemaker.production_variant("B", "ml.p2.xlarge"), sagemaker.production_variant("C", "ml.p2.xlarge"), ] + # Add DestinationS3Uri to only one production variant + pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"} existing_endpoint_arn = "arn:aws:sagemaker:us-west-2:123412341234:endpoint-config/foo" existing_endpoint_name = "foo" new_endpoint_name = "new-foo" @@ -2579,9 +2578,6 @@ def test_create_endpoint_config_from_existing_with_sagemaker_config_injection( expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ "EndpointConfig" ]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"] - expected_production_variant_1_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ - "EndpointConfig" - ]["ProductionVariants"][1]["CoreDumpConfig"]["KmsKeyId"] expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][ "AsyncInferenceConfig" ]["OutputConfig"]["KmsKeyId"] @@ -2594,11 +2590,14 @@ def test_create_endpoint_config_from_existing_with_sagemaker_config_injection( EndpointConfigName=new_endpoint_name, ProductionVariants=[ { - "CoreDumpConfig": {"KmsKeyId": expected_production_variant_0_kms_key_id}, + "CoreDumpConfig": { + "KmsKeyId": expected_production_variant_0_kms_key_id, + "DestinationS3Uri": pvs[0]["CoreDumpConfig"]["DestinationS3Uri"], + }, **sagemaker.production_variant("A", "ml.p2.xlarge"), }, { - "CoreDumpConfig": {"KmsKeyId": expected_production_variant_1_kms_key_id}, + # Merge shouldn't happen because input for this index doesn't have DestinationS3Uri **sagemaker.production_variant("B", "ml.p2.xlarge"), }, sagemaker.production_variant("C", "ml.p2.xlarge"), @@ -2622,6 +2621,8 @@ def 
test_endpoint_from_production_variants_with_sagemaker_config_injection( sagemaker.production_variant("B", "ml.p2.xlarge"), sagemaker.production_variant("C", "ml.p2.xlarge"), ] + # Add DestinationS3Uri to only one production variant + pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"} sagemaker_config_session.endpoint_from_production_variants( "some-endpoint", pvs, @@ -2641,9 +2642,22 @@ def test_endpoint_from_production_variants_with_sagemaker_config_injection( expected_async_inference_config_dict = AsyncInferenceConfig()._to_request_dict() expected_async_inference_config_dict["OutputConfig"]["KmsKeyId"] = expected_inference_kms_key_id + expected_pvs = [ + sagemaker.production_variant("A", "ml.p2.xlarge"), + sagemaker.production_variant("B", "ml.p2.xlarge"), + sagemaker.production_variant("C", "ml.p2.xlarge"), + ] + # Add DestinationS3Uri, KmsKeyId to only one production variant + expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ + "EndpointConfig" + ]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"] + expected_pvs[0]["CoreDumpConfig"] = { + "DestinationS3Uri": "s3://test", + "KmsKeyId": expected_production_variant_0_kms_key_id, + } sagemaker_config_session.sagemaker_client.create_endpoint_config.assert_called_with( EndpointConfigName="some-endpoint", - ProductionVariants=pvs, + ProductionVariants=expected_pvs, Tags=expected_tags, # from config KmsKeyId=expected_kms_key_id, # from config AsyncInferenceConfig=expected_async_inference_config_dict, @@ -3429,7 +3443,10 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_config_s } } validation_profiles = [ - {"TransformJobDefinition": {"TransformOutput": {"S3OutputPath": "s3://test"}}} + { + "ProfileName": "profileName", + "TransformJobDefinition": {"TransformOutput": {"S3OutputPath": "s3://test"}}, + } ] validation_specification = {"ValidationProfiles": validation_profiles} diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py 
index 0a539e9fe5..80fb4af05c 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -37,6 +37,7 @@ resolve_value_from_config, resolve_class_attribute_from_config, resolve_nested_dict_value_from_config, + update_list_of_dicts_with_values_from_config, ) from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -106,6 +107,127 @@ def test_get_nested_value(): assert sagemaker.utils.get_nested_value(dictionary, []) is None +def test_update_list_of_dicts_with_values_from_config(): + input_list = [{"a": 1, "b": 2}] + input_config_list = [ + { + "a": 4, # This should not be used. Use values from Input. + "c": 3, + } + ] + # Using short form for sagemaker_session + ss = MagicMock() + ss.sagemaker_config = Mock() + ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + config_path = "DUMMY.CONFIG.PATH" + # happy case - both inputs and config have same number of elements + update_list_of_dicts_with_values_from_config(input_list, config_path, sagemaker_session=ss) + assert input_list == [{"a": 1, "b": 2, "c": 3}] + # Case where Input has more entries compared to Config + input_list = [ + {"a": 1, "b": 2}, + {"a": 5, "b": 6}, + ] + input_config_list = [ + { + "a": 4, # This should not be used. Use values from Input. + "c": 3, + } + ] + ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + update_list_of_dicts_with_values_from_config(input_list, config_path, sagemaker_session=ss) + assert input_list == [ + {"a": 1, "b": 2, "c": 3}, + {"a": 5, "b": 6}, + ] + # Case where Config has more entries when compared to the input + input_list = [{"a": 1, "b": 2}] + input_config_list = [ + { + "a": 4, # This should not be used. Use values from Input. 
+ "c": 3, + }, + {"a": 5, "b": 6}, + ] + ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + update_list_of_dicts_with_values_from_config(input_list, config_path, sagemaker_session=ss) + assert input_list == [{"a": 1, "b": 2, "c": 3}] + # Testing required parameters. If required parameters are not present, don't do the merge + input_list = [{"a": 1, "b": 2}] + input_config_list = [ + { + "a": 4, # This should not be used. Use values from Input. + "c": 3, + }, + ] + ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + update_list_of_dicts_with_values_from_config( + input_list, config_path, required_key_paths=["d"], sagemaker_session=ss + ) + # since 'd' is not there , merge shouldn't have happened + assert input_list == [{"a": 1, "b": 2}] + # Testing required parameters. If required parameters are present, do the merge + input_list = [{"a": 1, "b": 2}, {"a": 5, "c": 6}] + input_config_list = [ + { + "a": 4, # This should not be used. Use values from Input. + "c": 3, + }, + { + "a": 7, # This should not be used. Use values from Input. + "b": 8, + }, + ] + ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + update_list_of_dicts_with_values_from_config( + input_list, config_path, required_key_paths=["c"], sagemaker_session=ss + ) + assert input_list == [ + {"a": 1, "b": 2, "c": 3}, + {"a": 5, "b": 8, "c": 6}, + ] + # Testing union parameters: If both parameters are present don't do the merge + input_list = [{"a": 1, "b": 2}, {"a": 5, "c": 6}] + input_config_list = [ + { + "a": 4, # This should not be used. Use values from Input. + "c": 3, + }, + { + "a": 7, # This should not be used. Use values from Input. + "d": 8, # c is present in the original list and d is present in this list. 
+ }, + ] + ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + update_list_of_dicts_with_values_from_config( + input_list, config_path, union_key_paths=[["c", "d"]], sagemaker_session=ss + ) + assert input_list == [ + {"a": 1, "b": 2, "c": 3}, + {"a": 5, "c": 6}, # merge didn't happen + ] + # Testing union parameters: Happy case + input_list = [{"a": 1, "b": 2}, {"a": 5, "c": 6}] + input_config_list = [ + { + "a": 4, # This should not be used. Use values from Input. + "c": 3, + }, + { + "a": 7, # This should not be used. Use values from Input. + "d": 8, + }, + ] + ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + update_list_of_dicts_with_values_from_config( + input_list, config_path, union_key_paths=[["c", "e"], ["d", "e"]], sagemaker_session=ss + ) + assert input_list == [ + {"a": 1, "b": 2, "c": 3}, + {"a": 5, "c": 6, "d": 8}, + ] + + def test_set_nested_value(): dictionary = { "local": {"region_name": "us-west-2", "port": "123"}, From ada7ddc532c2bd3efba34126b635f78d35692415 Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Tue, 21 Mar 2023 15:02:18 -0700 Subject: [PATCH 28/40] fix: Documentation updates for SageMakerConfig --- src/sagemaker/automl/automl.py | 2 +- src/sagemaker/automl/candidate_estimator.py | 2 +- src/sagemaker/config/config.py | 113 ++++++++---------- src/sagemaker/dataset_definition/inputs.py | 80 ------------- src/sagemaker/estimator.py | 2 +- src/sagemaker/feature_store/feature_group.py | 2 +- src/sagemaker/inputs.py | 21 +--- src/sagemaker/model.py | 2 +- .../model_monitor/model_monitoring.py | 2 +- src/sagemaker/pipeline.py | 2 +- src/sagemaker/processing.py | 2 +- src/sagemaker/session.py | 14 +-- src/sagemaker/transformer.py | 4 +- src/sagemaker/utils.py | 88 +++++++------- src/sagemaker/workflow/pipeline.py | 7 +- tests/data/config/config.yaml | 68 +++++------ .../expected_output_config_after_merge.yaml | 6 +- .../sample_additional_config_for_merge.yaml | 4 +- 
.../data/config/sample_config_for_merge.yaml | 4 +- tests/integ/test_sagemaker_config.py | 6 - tests/unit/sagemaker/config/conftest.py | 48 ++++---- tests/unit/test_djl_inference.py | 3 + 22 files changed, 182 insertions(+), 300 deletions(-) diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py index 9a9db0a307..236f6e0a10 100644 --- a/src/sagemaker/automl/automl.py +++ b/src/sagemaker/automl/automl.py @@ -224,7 +224,7 @@ def __init__( # Now we marked that as Optional because we can fetch it from SageMakerConfig # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. - raise ValueError("IAM role should be provided for creating AutoML jobs.") + raise ValueError("An AWS IAM role is required to create an AutoML job.") self.encrypt_inter_container_traffic = resolve_value_from_config( direct_input=encrypt_inter_container_traffic, diff --git a/src/sagemaker/automl/candidate_estimator.py b/src/sagemaker/automl/candidate_estimator.py index 3ec5f6995b..73d8be5ba7 100644 --- a/src/sagemaker/automl/candidate_estimator.py +++ b/src/sagemaker/automl/candidate_estimator.py @@ -93,7 +93,7 @@ def fit( the ML compute instance(s). encrypt_inter_container_traffic (bool): To encrypt all communications between ML compute instances in distributed training. If not passed, will be fetched from - sagemaker_config. Default: False. + sagemaker_config if a value is defined there. Default: False. vpc_config (dict): Specifies a VPC that jobs and hosted models have access to. Control access to and from training and model containers by configuring the VPC wait (bool): Whether the call should wait until all jobs completes (default: True). 
diff --git a/src/sagemaker/config/config.py b/src/sagemaker/config/config.py index 66640c7f1a..7da7b8f274 100644 --- a/src/sagemaker/config/config.py +++ b/src/sagemaker/config/config.py @@ -10,10 +10,10 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""This module configures the default values for SageMaker Python SDK. +"""This module configures the default values defined by the user for SageMaker Python SDK calls. -It supports loading Config files from local file system/S3. -The schema of the Config file is dictated in config_schema.py in the same module. +It supports loading config files from the local file system and Amazon S3. +The schema of the config file is dictated in config_schema.py in the same module. """ from __future__ import absolute_import @@ -33,10 +33,10 @@ logger = logging.getLogger("sagemaker") _APP_NAME = "sagemaker" -# The default config file location of the Administrator provided Config file. This path can be +# The default config file location of the Administrator provided config file. This path can be # overridden with `SAGEMAKER_ADMIN_CONFIG_OVERRIDE` environment variable. _DEFAULT_ADMIN_CONFIG_FILE_PATH = os.path.join(site_config_dir(_APP_NAME), "config.yaml") -# The default config file location of the user provided Config file. This path can be +# The default config file location of the user provided config file. This path can be # overridden with `SAGEMAKER_USER_CONFIG_OVERRIDE` environment variable. _DEFAULT_USER_CONFIG_FILE_PATH = os.path.join(user_config_dir(_APP_NAME), "config.yaml") @@ -61,7 +61,7 @@ class SageMakerConfig(object): create its own SageMakerConfig object. Note: Once sagemaker.session.Session is initialized, it will operate with the configuration - values at that instant. 
If the users wish to alter configuration files/file paths after + values at that instant. If the users wish to alter config files/file paths after sagemaker.session.Session is initialized, then that will not be reflected in sagemaker.session.Session. They would have to re-initialize sagemaker.session.Session to pick the latest changes. @@ -71,45 +71,45 @@ class SageMakerConfig(object): def __init__(self, additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE): """Initializes the SageMakerConfig object. - By default, it will first look for Config files in the default locations as dictated by - the SDK. + By default, this method first searches for config files in the default locations + defined by the SDK. - Users can override the default Admin Config file path and the default User Config file path - by using environment variables - SAGEMAKER_ADMIN_CONFIG_OVERRIDE and - SAGEMAKER_USER_CONFIG_OVERRIDE + Users can override the default admin and user config file paths using the + SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE environment variables, + respectively. - Additional Configuration file paths can also be provided as a constructor parameter. + Additional config file paths can also be provided as a constructor parameter. - This __init__ method will then - * Load each config file (Can be in S3/Local File system). - * It will validate the schema of the config files. - * It will perform the merge operation in the same order. + This method then: + * Loads each config file, whether it is Amazon S3 or the local file system. + * Validates the schema of the config files. + * Merges the files in the same order. - This __init__ method will throw exceptions for the following cases: - * jsonschema.exceptions.ValidationError: Schema validation fails for one/more config files. - * RuntimeError: If the method is unable to retrieve the list of all the S3 files with the - same prefix/Unable to retrieve the file. 
- * ValueError: If an S3 URI is provided and there are no S3 files with that prefix. - * ValueError: If a folder in S3 bucket is provided as s3_uri, and if it doesn't have - config.yaml. - * ValueError: A file doesn't exist in a path that was specified by the user as part of - environment variable/ additional_config_paths. This doesn't include the default config - file locations. + This method throws exceptions in the following cases: + * jsonschema.exceptions.ValidationError: Schema validation fails for one or more config + files. + * RuntimeError: The method is unable to retrieve the list of all S3 files with the + same prefix or is unable to retrieve the file. + * ValueError: There are no S3 files with the prefix when an S3 URI is provided. + * ValueError: There is no config.yaml file in the S3 bucket when an S3 URI is provided. + * ValueError: A file doesn't exist in a path that was specified by the user as part of an + environment variable or additional configuration file path. This doesn't include the default + config file locations. Args: - additional_config_paths: List of Config file paths. - These paths can be one of the following: + additional_config_paths: List of config file paths. + These paths can be one of the following. In the case of a directory, this method + searches for a config.yaml file in that directory. This method does not perform a + recursive search of folders in that directory. * Local file path - * Local directory path (in this case, we will look for config.yaml in that - directory) + * Local directory path * S3 URI of the config file - * S3 URI of the directory containing the config file (in this case, we will look for - config.yaml in that directory) + * S3 URI of the directory containing the config file Note: S3 URI follows the format s3:/// - s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This will be used to fetch - Config files from S3. 
If it is not provided, we will create a default s3 resource - See :py:meth:`boto3.session.Session.resource`. This argument is not needed if the - config files are present in the local file system + s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This is used to fetch + config files from S3. If it is not provided, this method creates a default S3 resource + See :py:meth:boto3.session.Session.resource. This argument is not needed if the + config files are present in the local file system. """ default_config_path = os.getenv( @@ -118,30 +118,11 @@ def __init__(self, additional_config_paths: List[str] = None, s3_resource=_DEFAU user_config_path = os.getenv( ENV_VARIABLE_USER_CONFIG_OVERRIDE, _DEFAULT_USER_CONFIG_FILE_PATH ) - self._config_paths = [default_config_path, user_config_path] + self.config_paths = [default_config_path, user_config_path] if additional_config_paths: - self._config_paths += additional_config_paths - self._config_paths = list(filter(lambda item: item is not None, self._config_paths)) - self._config = _load_config_files(self._config_paths, s3_resource) - - @property - def config_paths(self) -> List[str]: - """Getter for Config paths. - - Returns: - List[str]: This method returns the list of config file paths. - """ - return self._config_paths - - @property - def config(self) -> dict: - """Getter for the configuration object. - - Returns: - dict: A dictionary representing the configurations that were loaded from the config - file(s). 
- """ - return self._config + self.config_paths += additional_config_paths + self.config_paths = list(filter(lambda item: item is not None, self.config_paths)) + self.config = _load_config_files(self.config_paths, s3_resource) def _load_config_files(file_paths: List[str], s3_resource_for_config) -> dict: @@ -176,18 +157,18 @@ def _load_config_from_file(file_path: str) -> dict: inferred_file_path = os.path.join(file_path, "config.yaml") if not os.path.exists(inferred_file_path): raise ValueError( - f"Unable to load config file from the location: {file_path} Please" - f" provide a valid file path" + f"Unable to load the config file from the location: {file_path}" + f"Provide a valid file path" ) - logger.debug("Fetching configuration file from the path: %s", file_path) + logger.debug("Fetching config file from the path: %s", file_path) return yaml.safe_load(open(inferred_file_path, "r")) def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict: """Placeholder docstring""" if not s3_resource_for_config: - raise RuntimeError("Please provide a S3 client for loading the config") - logger.debug("Fetching configuration file from the S3 URI: %s", s3_uri) + raise RuntimeError("No S3 client found. Provide a S3 client to load the config file.") + logger.debug("Fetching config file from the S3 URI: %s", s3_uri) inferred_s3_uri = _get_inferred_s3_uri(s3_uri, s3_resource_for_config) parsed_url = urlparse(inferred_s3_uri) bucket, key_prefix = parsed_url.netloc, parsed_url.path.lstrip("/") @@ -212,14 +193,14 @@ def _get_inferred_s3_uri(s3_uri, s3_resource_for_config): raise RuntimeError(f"Unable to read from S3 with URI: {s3_uri} due to {e}") if len(s3_files_with_same_prefix) == 0: # Customer provided us with an incorrect s3 path. 
- raise ValueError("Please provide a valid s3 path instead of {}".format(s3_uri)) + raise ValueError("Provide a valid S3 path instead of {}".format(s3_uri)) if len(s3_files_with_same_prefix) > 1: # Customer has provided us with a S3 URI which points to a directory # search for s3:///directory-key-prefix/config.yaml inferred_s3_uri = str(pathlib.PurePosixPath(s3_uri, "config.yaml")).replace("s3:/", "s3://") if inferred_s3_uri not in s3_files_with_same_prefix: # We don't know which file we should be operating with. - raise ValueError("Please provide a S3 URI which has config.yaml in the directory") + raise ValueError("Provide an S3 URI of a directory that has a config.yaml file.") # Customer has a config.yaml present in the directory that was provided as the S3 URI return inferred_s3_uri return s3_uri diff --git a/src/sagemaker/dataset_definition/inputs.py b/src/sagemaker/dataset_definition/inputs.py index e037e6aee2..468be22ac3 100644 --- a/src/sagemaker/dataset_definition/inputs.py +++ b/src/sagemaker/dataset_definition/inputs.py @@ -71,46 +71,6 @@ def __init__( output_compression=output_compression, ) - @property - def kms_key_id(self): - """Getter for KMSKeyId. - - Returns: - str: The KMS Key ID. - - """ - return self.__dict__["kms_key_id"] - - @kms_key_id.setter - def kms_key_id(self, kms_key_id: str): - """Setter for KMSKeyId. - - Args: - kms_key_id: The KMSKeyId to be used. - - """ - self.__dict__["kms_key_id"] = kms_key_id - - @property - def cluster_role_arn(self): - """Getter for Cluster Role ARN. - - Returns: - str: The cluster Role ARN. - - """ - return self.__dict__["cluster_role_arn"] - - @cluster_role_arn.setter - def cluster_role_arn(self, cluster_role_arn: str): - """Setter for Cluster Role ARN. - - Args: - cluster_role_arn: The ClusterRoleArn to be used. - - """ - self.__dict__["cluster_role_arn"] = cluster_role_arn - class AthenaDatasetDefinition(ApiObject): """DatasetDefinition for Athena. 
@@ -159,26 +119,6 @@ def __init__( output_compression=output_compression, ) - @property - def kms_key_id(self): - """Getter for KMSKeyId. - - Returns: - str: The KMS Key ID. - - """ - return self.__dict__["kms_key_id"] - - @kms_key_id.setter - def kms_key_id(self, kms_key_id: str): - """Setter for KMSKeyId. - - Args: - kms_key_id: The KMSKeyId to be used. - - """ - self.__dict__["kms_key_id"] = kms_key_id - class DatasetDefinition(ApiObject): """DatasetDefinition input.""" @@ -230,26 +170,6 @@ def __init__( athena_dataset_definition=athena_dataset_definition, ) - @property - def redshift_dataset_definition(self): - """Getter for RedshiftDatasetDefinition - - Returns: - RedshiftDatasetDefinition: RedshiftDatasetDefinition object. - - """ - return self.__dict__["redshift_dataset_definition"] - - @property - def athena_dataset_definition(self): - """Getter for AthenaDatasetDefinition - - Returns: - AthenaDatasetDefinition: AthenaDatasetDefinition object. - - """ - return self.__dict__["athena_dataset_definition"] - class S3Input(ApiObject): """Metadata of data objects stored in S3. diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 51fb3c99b5..922150b901 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -595,7 +595,7 @@ def __init__( # Now we marked that as Optional because we can fetch it from SageMakerConfig # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. 
- raise ValueError("IAM role should be provided for creating estimators.") + raise ValueError("An AWS IAM role is required to create an estimator.") self.output_kms_key = resolve_value_from_config( output_kms_key, TRAINING_JOB_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session ) diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index 68e6b9dfef..67b643c475 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -576,7 +576,7 @@ def create( # Now we marked that as Optional because we can fetch it from SageMakerConfig, # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. - raise ValueError("IAM role should be provided for creating Feature Groups.") + raise ValueError("An AWS IAM role is required to create a Feature Group.") create_feature_store_args = dict( feature_group_name=self.name, record_identifier_name=record_identifier_name, diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index bdedb5dea5..f0c678c623 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -260,28 +260,9 @@ def __init__( (default: None) """ self.destination_s3_uri = destination_s3_uri - self._kms_key_id = kms_key_id + self.kms_key_id = kms_key_id self.generate_inference_id = generate_inference_id - @property - def kms_key_id(self): - """Getter for KmsKeyId - - Returns: - str: The KMS Key ID. - """ - return self._kms_key_id - - @kms_key_id.setter - def kms_key_id(self, kms_key_id: str): - """Setter for KmsKeyId - - Args: - kms_key_id: The KMS Key ID to set. 
- - """ - self._kms_key_id = kms_key_id - def _to_request_dict(self): """Generates a request dictionary using the parameters provided to the class.""" batch_data_capture_config = { diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 0357c76f64..0ca017e1f4 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1021,7 +1021,7 @@ def compile( # Now we marked that as Optional because we can fetch it from SageMakerConfig # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. - raise ValueError("IAM role should be provided for creating compilation jobs.") + raise ValueError("An AWS IAM role is required to create a compilation job.") config = self._compilation_job_config( target_instance_family, input_shape, diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index ca1f98714e..46481e9fbe 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -188,7 +188,7 @@ def __init__( # Now we marked that as Optional because we can fetch it from SageMakerConfig # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. - raise ValueError("IAM role should be provided for creating Monitoring Schedule.") + raise ValueError("An AWS IAM role is required to create a Monitoring Schedule.") self.volume_kms_key = resolve_value_from_config( volume_kms_key, MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 5008558c63..2667166366 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -109,7 +109,7 @@ def __init__( # Now we marked that as Optional because we can fetch it from SageMakerConfig # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. 
- raise ValueError("IAM role should be provided for creating Pipeline Model.") + raise ValueError("An AWS IAM role is required to create a Pipeline Model.") def pipeline_container_def(self, instance_type=None): """The pipeline definition for deploying this model. diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 6dba6ef6fa..2fae52f166 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -194,7 +194,7 @@ def __init__( # Now we marked that as Optional because we can fetch it from SageMakerConfig # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. - raise ValueError("IAM role should be provided for creating Processing jobs.") + raise ValueError("An AWS IAM role is required to create a Processing job.") @runnable_by_pipeline def run( diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 2d48eac179..e5ce5bced0 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -191,7 +191,7 @@ def __init__( this instance's ``boto_session``. sagemaker_config (sagemaker.config.SageMakerConfig): The SageMakerConfig object which holds the default values for the SageMaker Python SDK. (default: None). If not - provided, This class will create its own SageMakerConfig object. + provided, one will be created. """ self._default_bucket = None self._default_bucket_name_override = default_bucket @@ -595,16 +595,16 @@ def _print_message_on_sagemaker_config_usage( def _append_sagemaker_config_tags(self, tags: list, config_path_to_tags: str): """Appends tags specified in the sagemaker_config to the given list of tags. 
- To minimize the chance of duplicate tags being applied, this is intended - to be used right before calls to sagemaker_client (rather - than during initialization of classes like EstimatorBase) + To minimize the chance of duplicate tags being applied, this is intended to be used + immediately before calls to sagemaker_client, rather than during initialization of + classes like EstimatorBase. Args: - tags: the list of tags to append to. - config_path_to_tags: the path to look up in the config + tags: The list of tags to append to. + config_path_to_tags: The path to look up tags in the config. Returns: - A potentially extended list of tags. + A list of tags. """ config_tags = get_sagemaker_config_value(self, config_path_to_tags) diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 1dc891ea82..34685e786b 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -130,7 +130,7 @@ def __init__( self.volume_kms_key = resolve_value_from_config( volume_kms_key, TRANSFORM_RESOURCES_VOLUME_KMS_KEY_ID_PATH, - sagemaker_session=sagemaker_session, + sagemaker_session=self.sagemaker_session, ) self.output_kms_key = resolve_value_from_config( output_kms_key, @@ -408,7 +408,7 @@ def transform_with_monitoring( batch_data_capture_config, "kms_key_id", TRANSFORM_JOB_KMS_KEY_ID_PATH, - sagemaker_session=sagemaker_session, + sagemaker_session=self.sagemaker_session, ) transform_step_args = transformer.transform( diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index d349035858..203fe24bde 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1041,19 +1041,22 @@ def resolve_value_from_config( default_value=None, sagemaker_session=None, ): - """Makes a decision of which value is the right value for the caller to use. + """Decides which value for the caller to use. - Note: This method also incorporates info from the sagemaker config. + Note: This method incorporates information from the sagemaker config. 
Uses this order of prioritization: - (1) direct_input, (2) config value, (3) default_value, (4) None + 1. direct_input + 2. config value + 3. default_value + 4. None Args: - direct_input: the value that the caller of this method started with. Usually this is an - input to the caller's class or method - config_path (str): a string denoting the path to use to lookup the config value in the - sagemaker config - default_value: the value to use if not present elsewhere + direct_input: The value that the caller of this method starts with. Usually this is an + input to the caller's class or method. + config_path (str): A string denoting the path used to lookup the value in the + sagemaker config. + default_value: The value used if not present elsewhere. sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions (default: None). @@ -1075,7 +1078,7 @@ def resolve_value_from_config( def get_sagemaker_config_value(sagemaker_session, key): - """Util method that fetches a particular key path in the SageMakerConfig and returns it. + """Returns the value that corresponds to the provided key from the configuration file. Args: key: Key Path of the config file entry. @@ -1083,7 +1086,7 @@ def get_sagemaker_config_value(sagemaker_session, key): SageMaker interactions. Returns: - object: The corresponding value in the Config file/ the default value. + object: The corresponding default value in the configuration file. """ if not sagemaker_session: return None @@ -1096,16 +1099,16 @@ def _print_message_on_sagemaker_config_usage(direct_input, config_value, config_ """Informs the SDK user whether a config value was present and automatically substituted Args: - direct_input: the value that would be used if no sagemaker_config or default values - existed. Usually this will be user-provided input to a Class or to a - session.py method, or None if no input was provided. - config_value: the value fetched from sagemaker_config. 
This is usually the value that - will be used if direct_input is None. - config_path: a string denoting the path of keys that point to the config value in the - sagemaker_config + direct_input: The value that is used if no default values exist. Usually, + this is user-provided input to a Class or to a session.py method, or None if no input + was provided. + config_value: The value fetched from sagemaker_config. This is usually the value that + will be used if direct_input is None. + config_path: A string denoting the path of keys that point to the config value in the + sagemaker_config. Returns: - No output (just prints information) + None. Prints information. """ if config_value is not None: @@ -1146,7 +1149,10 @@ def resolve_class_attribute_from_config( value fetched from the sagemaker_config or the default_value. Uses this order of prioritization to determine what the value of the attribute should be: - (1) current value of attribute, (2) config value, (3) default_value, (4) does not set it + 1. current value of attribute + 2. config value + 3. default_value + 4. does not set it Args: clazz (Optional[type]): Class of 'instance'. Used to generate a new instance if the @@ -1216,11 +1222,11 @@ def resolve_nested_dict_value_from_config( (1) current value of nested key, (2) config value, (3) default_value, (4) does not set it Args: - dictionary: dict to update - nested_keys: path of keys at which the value should be checked (and set if needed) - config_path (str): a string denoting the path to use to lookup the config value in the - sagemaker config - default_value: the value to use if not present elsewhere + dictionary: The dict to update. + nested_keys: The paths of keys where the value should be checked and set if needed. + config_path (str): A string denoting the path used to find the config value in the + sagemaker config. + default_value: The value to use if not present elsewhere. 
sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions (default: None). @@ -1260,7 +1266,7 @@ def update_list_of_dicts_with_values_from_config( union_key_paths: List[List[str]] = None, sagemaker_session=None, ): - """Helper method for updating Lists with corresponding values in the Config + """Updates a list of dictionaries with missing values that are present in Config. In some cases, config file might introduce new parameters which requires certain other parameters to be provided as part of the input list. Without those parameters, the underlying @@ -1279,8 +1285,8 @@ def update_list_of_dicts_with_values_from_config( item. union_key_paths (List[List[str]]): List of List of Key paths for which we need to verify whether exactly zero/one of the parameters exist. - For example: If the resultant dictionary can have either 'X1' or 'X2' as parameter but - not both, then pass [['X1', 'X2']] + For example: If the resultant dictionary can have either 'X1' or 'X2' as parameter or + neither but not both, then pass [['X1', 'X2']] sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions (default: None). 
@@ -1290,25 +1296,23 @@ def update_list_of_dicts_with_values_from_config( if not input_list: return inputs_copy = copy.deepcopy(input_list) - inputs_from_config = resolve_value_from_config( - config_path=config_key_path, default_value=[], sagemaker_session=sagemaker_session - ) + inputs_from_config = get_sagemaker_config_value(sagemaker_session, config_key_path) or [] for i in range(min(len(input_list), len(inputs_from_config))): dict_from_inputs = input_list[i] dict_from_config = inputs_from_config[i] merge_dicts(dict_from_config, dict_from_inputs) # Check if required key paths are present in merged dict (dict_from_config) - required_key_path_check_failed = _validate_required_paths_in_a_dict( + required_key_path_check_passed = _validate_required_paths_in_a_dict( dict_from_config, required_key_paths ) - if required_key_path_check_failed: + if not required_key_path_check_passed: # Don't do the merge, config is introducing a new parameter which needs a # corresponding required parameter. continue - union_key_path_check_failed = _validate_union_key_paths_in_a_dict( + union_key_path_check_passed = _validate_union_key_paths_in_a_dict( dict_from_config, union_key_paths ) - if union_key_path_check_failed: + if not union_key_path_check_passed: # Don't do the merge, Union parameters are not obeyed. 
continue input_list[i] = dict_from_config @@ -1325,11 +1329,11 @@ def update_list_of_dicts_with_values_from_config( def _validate_required_paths_in_a_dict(source_dict, required_key_paths: List[str] = None) -> bool: """Placeholder docstring""" if not required_key_paths: - return False + return True for required_key_path in required_key_paths: if get_config_value(required_key_path, source_dict) is None: - return True - return False + return False + return True def _validate_union_key_paths_in_a_dict( @@ -1337,21 +1341,21 @@ def _validate_union_key_paths_in_a_dict( ) -> bool: """Placeholder docstring""" if not union_key_paths: - return False + return True for union_key_path in union_key_paths: union_parameter_present = False for key_path in union_key_path: if get_config_value(key_path, source_dict): if union_parameter_present: - return True + return False union_parameter_present = True - return False + return True def update_nested_dictionary_with_values_from_config( source_dict, config_key_path, sagemaker_session=None ) -> dict: - """Updates a given nested dictionary with missing values which are present in Config. + """Updates a nested dictionary with missing values that are present in Config. Args: source_dict: The input nested dictionary that was provided as method parameter. @@ -1361,7 +1365,7 @@ def update_nested_dictionary_with_values_from_config( SageMaker interactions (default: None). Returns: - dict: The merged nested dictionary which includes missings values that are present + dict: The merged nested dictionary that is updated with missing values that are present in the Config file. 
""" inferred_config_dict = get_sagemaker_config_value(sagemaker_session, config_key_path) or {} diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 6e3449e1a4..73c115f84a 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -136,7 +136,7 @@ def create( # Now we marked that as Optional because we can fetch it from SageMakerConfig # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. - raise ValueError("IAM role should be provided for creating Pipeline.") + raise ValueError("An AWS IAM role is required to create a Pipeline.") if self.sagemaker_session.local_mode: if parallelism_config: logger.warning("Pipeline parallelism config is not supported in the local mode.") @@ -230,7 +230,7 @@ def update( # Now we marked that as Optional because we can fetch it from SageMakerConfig # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. - raise ValueError("IAM role should be provided for updating Pipeline.") + raise ValueError("An AWS IAM role is required to update a Pipeline.") if self.sagemaker_session.local_mode: if parallelism_config: logger.warning("Pipeline parallelism config is not supported in the local mode.") @@ -270,8 +270,7 @@ def upsert( # Now we marked that as Optional because we can fetch it from SageMakerConfig # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. 
- raise ValueError("IAM role should be provided for creating/updating Pipeline.") - exists = True + raise ValueError("An AWS IAM role is required to create or update a Pipeline.") try: response = self.create(role_arn, description, tags, parallelism_config) except ClientError as ce: diff --git a/tests/data/config/config.yaml b/tests/data/config/config.yaml index 8e0e40019a..46d55cf8f0 100644 --- a/tests/data/config/config.yaml +++ b/tests/data/config/config.yaml @@ -3,19 +3,19 @@ SageMaker: FeatureGroup: OnlineStoreConfig: SecurityConfig: - KmsKeyId: 'someotherkmskeyid' + KmsKeyId: 'kmskeyid1' OfflineStoreConfig: S3StorageConfig: - KmsKeyId: 'somekmskeyid' - RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + KmsKeyId: 'kmskeyid2' + RoleArn: 'arn:aws:iam::555555555555:role/IMRole' MonitoringSchedule: MonitoringScheduleConfig: MonitoringJobDefinition: MonitoringOutputConfig: - KmsKeyId: 'somekmskey' + KmsKeyId: 'kmskeyid1' MonitoringResources: ClusterConfig: - VolumeKmsKeyId: 'somevolumekmskey' + VolumeKmsKeyId: 'volumekmskeyid1' NetworkConfig: EnableNetworkIsolation: true VpcConfig: @@ -23,50 +23,50 @@ SageMaker: - 'sg123' Subnets: - 'subnet-1234' - RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + RoleArn: 'arn:aws:iam::555555555555:role/IMRole' EndpointConfig: AsyncInferenceConfig: OutputConfig: - KmsKeyId: 'somekmskey' + KmsKeyId: 'kmskeyid1' DataCaptureConfig: - KmsKeyId: 'somekmskey2' - KmsKeyId: 'somekmskey3' + KmsKeyId: 'kmskeyid2' + KmsKeyId: 'kmskeyid3' ProductionVariants: - CoreDumpConfig: - KmsKeyId: 'somekmskey4' + KmsKeyId: 'kmskeyid4' AutoML: AutoMLJobConfig: SecurityConfig: - VolumeKmsKeyId: 'somevolumekmskey' + VolumeKmsKeyId: 'volumekmskeyid1' VpcConfig: SecurityGroupIds: - 'sg123' Subnets: - 'subnet-1234' OutputDataConfig: - KmsKeyId: 'somekmskey' - RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + KmsKeyId: 'kmskeyid1' + RoleArn: 'arn:aws:iam::555555555555:role/IMRole' TransformJob: DataCaptureConfig: - KmsKeyId: 'somekmskey' + KmsKeyId: 
'kmskeyid1' TransformOutput: - KmsKeyId: 'somekmskey2' + KmsKeyId: 'kmskeyid2' TransformResources: - VolumeKmsKeyId: 'somevolumekmskey' + VolumeKmsKeyId: 'volumekmskeyid1' CompilationJob: OutputConfig: - KmsKeyId: 'somekmskey' - RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + KmsKeyId: 'kmskeyid1' + RoleArn: 'arn:aws:iam::555555555555:role/IMRole' VpcConfig: SecurityGroupIds: - 'sg123' Subnets: - 'subnet-1234' Pipeline: - RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + RoleArn: 'arn:aws:iam::555555555555:role/IMRole' Model: EnableNetworkIsolation: true - ExecutionRoleArn: 'arn:aws:iam::366666666666:role/IMRole' + ExecutionRoleArn: 'arn:aws:iam::555555555555:role/IMRole' VpcConfig: SecurityGroupIds: - 'sg123' @@ -77,10 +77,10 @@ SageMaker: ValidationProfiles: - TransformJobDefinition: TransformOutput: - KmsKeyId: 'somerandomkmskeyid' + KmsKeyId: 'kmskeyid1' TransformResources: - VolumeKmsKeyId: 'somerandomkmskeyid' - ValidationRole: 'arn:aws:iam::366666666666:role/IMRole' + VolumeKmsKeyId: 'volumekmskeyid1' + ValidationRole: 'arn:aws:iam::555555555555:role/IMRole' ProcessingJob: NetworkConfig: EnableNetworkIsolation: true @@ -92,23 +92,23 @@ SageMaker: ProcessingInputs: - DatasetDefinition: AthenaDatasetDefinition: - KmsKeyId: 'somekmskeyid' + KmsKeyId: 'kmskeyid1' RedshiftDatasetDefinition: - KmsKeyId: 'someotherkmskeyid' - ClusterRoleArn: 'arn:aws:iam::366666666666:role/IMRole' + KmsKeyId: 'kmskeyid2' + ClusterRoleArn: 'arn:aws:iam::555555555555:role/IMRole' ProcessingOutputConfig: - KmsKeyId: 'somerandomkmskeyid' + KmsKeyId: 'kmskeyid3' ProcessingResources: ClusterConfig: - VolumeKmsKeyId: 'somerandomkmskeyid' - RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + VolumeKmsKeyId: 'volumekmskeyid1' + RoleArn: 'arn:aws:iam::555555555555:role/IMRole' TrainingJob: EnableNetworkIsolation: true OutputDataConfig: - KmsKeyId: 'somekmskey' + KmsKeyId: 'kmskeyid1' ResourceConfig: - VolumeKmsKeyId: 'somevolumekmskey' - RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + 
VolumeKmsKeyId: 'volumekmskeyid1' + RoleArn: 'arn:aws:iam::555555555555:role/IMRole' VpcConfig: SecurityGroupIds: - 'sg123' @@ -116,5 +116,5 @@ SageMaker: - 'subnet-1234' EdgePackagingJob: OutputConfig: - KmsKeyId: 'somekeyid' - RoleArn: 'arn:aws:iam::366666666666:role/IMRole' + KmsKeyId: 'kmskeyid1' + RoleArn: 'arn:aws:iam::555555555555:role/IMRole' diff --git a/tests/data/config/expected_output_config_after_merge.yaml b/tests/data/config/expected_output_config_after_merge.yaml index b3556b4bcf..9cc5c88632 100644 --- a/tests/data/config/expected_output_config_after_merge.yaml +++ b/tests/data/config/expected_output_config_after_merge.yaml @@ -5,10 +5,10 @@ SageMaker: SecurityConfig: # Present in the additional override as well as default config. # Pick the additional config value. - KmsKeyId: 'additionalConfigKmsKeyId' + KmsKeyId: 'kmskeyid3' OfflineStoreConfig: S3StorageConfig: # Present only in the default config - KmsKeyId: 'somekmskeyid' + KmsKeyId: 'kmskeyid2' # Present only in the additional config - RoleArn: 'arn:aws:iam::377777777777:role/additionalConfigRole' + RoleArn: 'arn:aws:iam::555555555555:role/additionalConfigRole' diff --git a/tests/data/config/sample_additional_config_for_merge.yaml b/tests/data/config/sample_additional_config_for_merge.yaml index 06ed08aca0..40e6e6347d 100644 --- a/tests/data/config/sample_additional_config_for_merge.yaml +++ b/tests/data/config/sample_additional_config_for_merge.yaml @@ -3,5 +3,5 @@ SageMaker: FeatureGroup: OnlineStoreConfig: SecurityConfig: - KmsKeyId: 'additionalConfigKmsKeyId' - RoleArn: 'arn:aws:iam::377777777777:role/additionalConfigRole' + KmsKeyId: 'kmskeyid3' + RoleArn: 'arn:aws:iam::555555555555:role/additionalConfigRole' diff --git a/tests/data/config/sample_config_for_merge.yaml b/tests/data/config/sample_config_for_merge.yaml index 49d4a5b3ee..c8bebb463a 100644 --- a/tests/data/config/sample_config_for_merge.yaml +++ b/tests/data/config/sample_config_for_merge.yaml @@ -3,7 +3,7 @@ SageMaker: 
FeatureGroup: OnlineStoreConfig: SecurityConfig: - KmsKeyId: 'someotherkmskeyid' + KmsKeyId: 'kmskeyid1' OfflineStoreConfig: S3StorageConfig: - KmsKeyId: 'somekmskeyid' + KmsKeyId: 'kmskeyid2' diff --git a/tests/integ/test_sagemaker_config.py b/tests/integ/test_sagemaker_config.py index 6148199c70..da9bf685ab 100644 --- a/tests/integ/test_sagemaker_config.py +++ b/tests/integ/test_sagemaker_config.py @@ -100,12 +100,6 @@ def sagemaker_session_with_dynamically_generated_sagemaker_config( "AsyncInferenceConfig": {"OutputConfig": {"KmsKeyId": kms_key_arn}}, "DataCaptureConfig": {"KmsKeyId": kms_key_arn}, "KmsKeyId": kms_key_arn, - # TODO: re-enable after ProductionVariants injection is complete - # "ProductionVariants": [{ - # "CoreDumpConfig": { - # "KmsKeyId": kms_key_arn - # } - # }], "Tags": ENDPOINT_CONFIG_TAGS, }, "Model": { diff --git a/tests/unit/sagemaker/config/conftest.py b/tests/unit/sagemaker/config/conftest.py index fc517016c2..70a497e708 100644 --- a/tests/unit/sagemaker/config/conftest.py +++ b/tests/unit/sagemaker/config/conftest.py @@ -29,13 +29,13 @@ def valid_vpc_config(): @pytest.fixture() def valid_iam_role_arn(): - return "arn:aws:iam::366666666666:role/IMRole" + return "arn:aws:iam::555555555555:role/IMRole" @pytest.fixture() def valid_feature_group_config(valid_iam_role_arn): - s3_storage_config = {"KmsKeyId": "somekmskeyid"} - security_storage_config = {"KmsKeyId": "someotherkmskeyid"} + security_storage_config = {"KmsKeyId": "kmskeyid1"} + s3_storage_config = {"KmsKeyId": "kmskeyid2"} online_store_config = {"SecurityConfig": security_storage_config} offline_store_config = {"S3StorageConfig": s3_storage_config} return { @@ -48,7 +48,7 @@ def valid_feature_group_config(valid_iam_role_arn): @pytest.fixture() def valid_edge_packaging_config(valid_iam_role_arn): return { - "OutputConfig": {"KmsKeyId": "somekeyid"}, + "OutputConfig": {"KmsKeyId": "kmskeyid1"}, "RoleArn": valid_iam_role_arn, } @@ -65,8 +65,8 @@ def 
valid_model_config(valid_iam_role_arn, valid_vpc_config): @pytest.fixture() def valid_model_package_config(valid_iam_role_arn): transform_job_definition = { - "TransformOutput": {"KmsKeyId": "somerandomkmskeyid"}, - "TransformResources": {"VolumeKmsKeyId": "somerandomkmskeyid"}, + "TransformOutput": {"KmsKeyId": "kmskeyid1"}, + "TransformResources": {"VolumeKmsKeyId": "volumekmskeyid1"}, } validation_specification = { "ValidationProfiles": [{"TransformJobDefinition": transform_job_definition}], @@ -79,17 +79,17 @@ def valid_model_package_config(valid_iam_role_arn): def valid_processing_job_config(valid_iam_role_arn, valid_vpc_config): network_config = {"EnableNetworkIsolation": True, "VpcConfig": valid_vpc_config} dataset_definition = { - "AthenaDatasetDefinition": {"KmsKeyId": "somekmskeyid"}, + "AthenaDatasetDefinition": {"KmsKeyId": "kmskeyid1"}, "RedshiftDatasetDefinition": { - "KmsKeyId": "someotherkmskeyid", + "KmsKeyId": "kmskeyid2", "ClusterRoleArn": valid_iam_role_arn, }, } return { "NetworkConfig": network_config, "ProcessingInputs": [{"DatasetDefinition": dataset_definition}], - "ProcessingOutputConfig": {"KmsKeyId": "somerandomkmskeyid"}, - "ProcessingResources": {"ClusterConfig": {"VolumeKmsKeyId": "somerandomkmskeyid"}}, + "ProcessingOutputConfig": {"KmsKeyId": "kmskeyid3"}, + "ProcessingResources": {"ClusterConfig": {"VolumeKmsKeyId": "volumekmskeyid1"}}, "RoleArn": valid_iam_role_arn, } @@ -98,8 +98,8 @@ def valid_processing_job_config(valid_iam_role_arn, valid_vpc_config): def valid_training_job_config(valid_iam_role_arn, valid_vpc_config): return { "EnableNetworkIsolation": True, - "OutputDataConfig": {"KmsKeyId": "somekmskey"}, - "ResourceConfig": {"VolumeKmsKeyId": "somevolumekmskey"}, + "OutputDataConfig": {"KmsKeyId": "kmskeyid1"}, + "ResourceConfig": {"VolumeKmsKeyId": "volumekmskeyid1"}, "RoleArn": valid_iam_role_arn, "VpcConfig": valid_vpc_config, } @@ -113,7 +113,7 @@ def valid_pipeline_config(valid_iam_role_arn): @pytest.fixture() def 
valid_compilation_job_config(valid_iam_role_arn, valid_vpc_config): return { - "OutputConfig": {"KmsKeyId": "somekmskey"}, + "OutputConfig": {"KmsKeyId": "kmskeyid1"}, "RoleArn": valid_iam_role_arn, "VpcConfig": valid_vpc_config, } @@ -122,9 +122,9 @@ def valid_compilation_job_config(valid_iam_role_arn, valid_vpc_config): @pytest.fixture() def valid_transform_job_config(): return { - "DataCaptureConfig": {"KmsKeyId": "somekmskey"}, - "TransformOutput": {"KmsKeyId": "somekmskey2"}, - "TransformResources": {"VolumeKmsKeyId": "somevolumekmskey"}, + "DataCaptureConfig": {"KmsKeyId": "kmskeyid1"}, + "TransformOutput": {"KmsKeyId": "kmskeyid2"}, + "TransformResources": {"VolumeKmsKeyId": "volumekmskeyid1"}, } @@ -132,9 +132,9 @@ def valid_transform_job_config(): def valid_automl_config(valid_iam_role_arn, valid_vpc_config): return { "AutoMLJobConfig": { - "SecurityConfig": {"VolumeKmsKeyId": "somevolumekmskey", "VpcConfig": valid_vpc_config} + "SecurityConfig": {"VolumeKmsKeyId": "volumekmskeyid1", "VpcConfig": valid_vpc_config} }, - "OutputDataConfig": {"KmsKeyId": "somekmskey"}, + "OutputDataConfig": {"KmsKeyId": "kmskeyid1"}, "RoleArn": valid_iam_role_arn, } @@ -142,10 +142,10 @@ def valid_automl_config(valid_iam_role_arn, valid_vpc_config): @pytest.fixture() def valid_endpointconfig_config(): return { - "AsyncInferenceConfig": {"OutputConfig": {"KmsKeyId": "somekmskey"}}, - "DataCaptureConfig": {"KmsKeyId": "somekmskey2"}, - "KmsKeyId": "somekmskey3", - "ProductionVariants": [{"CoreDumpConfig": {"KmsKeyId": "somekmskey4"}}], + "AsyncInferenceConfig": {"OutputConfig": {"KmsKeyId": "kmskeyid1"}}, + "DataCaptureConfig": {"KmsKeyId": "kmskeyid2"}, + "KmsKeyId": "kmskeyid3", + "ProductionVariants": [{"CoreDumpConfig": {"KmsKeyId": "kmskeyid4"}}], } @@ -155,8 +155,8 @@ def valid_monitoring_schedule_config(valid_iam_role_arn, valid_vpc_config): return { "MonitoringScheduleConfig": { "MonitoringJobDefinition": { - "MonitoringOutputConfig": {"KmsKeyId": "somekmskey"}, - 
"MonitoringResources": {"ClusterConfig": {"VolumeKmsKeyId": "somevolumekmskey"}}, + "MonitoringOutputConfig": {"KmsKeyId": "kmskeyid1"}, + "MonitoringResources": {"ClusterConfig": {"VolumeKmsKeyId": "volumekmskeyid1"}}, "NetworkConfig": network_config, "RoleArn": valid_iam_role_arn, } diff --git a/tests/unit/test_djl_inference.py b/tests/unit/test_djl_inference.py index 65d031921c..66fc967540 100644 --- a/tests/unit/test_djl_inference.py +++ b/tests/unit/test_djl_inference.py @@ -57,6 +57,9 @@ def sagemaker_session(): endpoint_from_production_variants=Mock(name="endpoint_from_production_variants"), ) session.default_bucket = Mock(name="default_bucket", return_valie=BUCKET) + # For tests which doesn't verify config file injection, operate with empty config + session.sagemaker_config = Mock() + session.sagemaker_config.config = {} return session From 1ec0563dba251bc0f1bdebe9ae7ea0d47a60d085 Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Mon, 27 Mar 2023 10:15:09 -0700 Subject: [PATCH 29/40] fix: bubble up exceptions from S3 while fetching the Config --- src/sagemaker/config/config.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/config/config.py b/src/sagemaker/config/config.py index 7da7b8f274..6b87f2e678 100644 --- a/src/sagemaker/config/config.py +++ b/src/sagemaker/config/config.py @@ -181,16 +181,11 @@ def _get_inferred_s3_uri(s3_uri, s3_resource_for_config): """Placeholder docstring""" parsed_url = urlparse(s3_uri) bucket, key_prefix = parsed_url.netloc, parsed_url.path.lstrip("/") - try: - s3_bucket = s3_resource_for_config.Bucket(name=bucket) - s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all() - s3_files_with_same_prefix = [ - "{}{}/{}".format(S3_PREFIX, bucket, s3_object.key) for s3_object in s3_objects - ] - except Exception as e: # pylint: disable=W0703 - # if customers didn't provide us with a valid S3 File/insufficient read permission, - # We will fail hard. 
- raise RuntimeError(f"Unable to read from S3 with URI: {s3_uri} due to {e}") + s3_bucket = s3_resource_for_config.Bucket(name=bucket) + s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all() + s3_files_with_same_prefix = [ + "{}{}/{}".format(S3_PREFIX, bucket, s3_object.key) for s3_object in s3_objects + ] if len(s3_files_with_same_prefix) == 0: # Customer provided us with an incorrect s3 path. raise ValueError("Provide a valid S3 path instead of {}".format(s3_uri)) From 263594a94b5c4bc0fb259863b3635f067d1622a9 Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Mon, 27 Mar 2023 11:10:49 -0700 Subject: [PATCH 30/40] fix: Added additional test cases for config helper methods. Also made minor documentation updates. --- src/sagemaker/config/config.py | 22 +++++------ tests/unit/test_utils.py | 67 ++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 11 deletions(-) diff --git a/src/sagemaker/config/config.py b/src/sagemaker/config/config.py index 6b87f2e678..63f4d32205 100644 --- a/src/sagemaker/config/config.py +++ b/src/sagemaker/config/config.py @@ -52,20 +52,20 @@ class SageMakerConfig(object): - """SageMakerConfig class encapsulates the Config for SageMaker Python SDK. + """A class that encapsulates the configuration for the SageMaker Python SDK. - Usages: - This class will be integrated with sagemaker.session.Session. Users of SageMaker Python SDK - will have the ability to pass a SageMakerConfig object to sagemaker.session.Session. If - SageMakerConfig object is not provided by the user, then sagemaker.session.Session will - create its own SageMakerConfig object. + This class is used to define default values provided by the user. - Note: Once sagemaker.session.Session is initialized, it will operate with the configuration - values at that instant. If the users wish to alter config files/file paths after - sagemaker.session.Session is initialized, then that will not be reflected in - sagemaker.session.Session. 
They would have to re-initialize sagemaker.session.Session to - pick the latest changes. + This class is integrated with sagemaker.session.Session. Users of the SageMaker Python SDK + have the ability to pass a SageMakerConfig object to sagemaker.session.Session. If a + SageMakerConfig object is not provided by the user, then sagemaker.session.Session + creates its own SageMakerConfig object. + Note: After sagemaker.session.Session is initialized, it operates with the configuration + values defined at that instant. If you modify the configuration files or file paths after + sagemaker.session.Session is initialized, those changes are not reflected in + sagemaker.session.Session. To incorporate the changes in the configuration files, + initialize sagemaker.session.Session again. """ def __init__(self, additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE): diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 80fb4af05c..888adc3799 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -226,6 +226,73 @@ def test_update_list_of_dicts_with_values_from_config(): {"a": 1, "b": 2, "c": 3}, {"a": 5, "c": 6, "d": 8}, ] + # Same happy case with different order of items in union_key_paths + input_list = [{"a": 1, "b": 2}, {"a": 5, "c": 6}] + input_config_list = [ + { + "a": 4, # This should not be used. Use values from Input. + "c": 3, + }, + { + "a": 7, # This should not be used. Use values from Input. + "d": 8, + }, + ] + ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + update_list_of_dicts_with_values_from_config( + input_list, config_path, union_key_paths=[["d", "e"], ["c", "e"]], sagemaker_session=ss + ) + assert input_list == [ + {"a": 1, "b": 2, "c": 3}, + {"a": 5, "c": 6, "d": 8}, + ] + # Testing the combination of union parameter and required parameter. i.e. A parameter is both + # required and part of Union. 
+ input_list = [{"a": 1, "b": 2}, {"a": 5, "c": 6}] + input_config_list = [ + { + "a": 4, # This should not be used. Use values from Input. + "c": 3, + }, + { + "a": 7, # This should not be used. Use values from Input. + "d": 8, + }, + ] + ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + update_list_of_dicts_with_values_from_config( + input_list, + config_path, + required_key_paths=["e"], + union_key_paths=[["d", "e"], ["c", "e"]], + sagemaker_session=ss, + ) + # No merge should happen since 'e' is not present, even though union is obeyed. + assert input_list == [{"a": 1, "b": 2}, {"a": 5, "c": 6}] + # Same test but the required parameter is present. + input_list = [{"a": 1, "e": 2}, {"a": 5, "e": 6}] + input_config_list = [ + { + "a": 4, # This should not be used. Use values from Input. + "f": 3, + }, + { + "a": 7, # This should not be used. Use values from Input. + "g": 8, + }, + ] + ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + update_list_of_dicts_with_values_from_config( + input_list, + config_path, + required_key_paths=["e"], + union_key_paths=[["d", "e"], ["c", "e"]], + sagemaker_session=ss, + ) + assert input_list == [ + {"a": 1, "e": 2, "f": 3}, + {"a": 5, "e": 6, "g": 8}, + ] def test_set_nested_value(): From cd2181b7b9382deb64caad9b1947d76ec931da8a Mon Sep 17 00:00:00 2001 From: Ruban Hussain Date: Mon, 27 Mar 2023 16:13:35 -0700 Subject: [PATCH 31/40] fix: small bug fix to print statements for update_list_of_dicts_with_values_from_config --- src/sagemaker/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 203fe24bde..5ce2581122 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1297,6 +1297,8 @@ def update_list_of_dicts_with_values_from_config( return inputs_copy = copy.deepcopy(input_list) inputs_from_config = get_sagemaker_config_value(sagemaker_session, config_key_path) or [] + 
unmodified_inputs_from_config = copy.deepcopy(inputs_from_config) + for i in range(min(len(input_list), len(inputs_from_config))): dict_from_inputs = input_list[i] dict_from_config = inputs_from_config[i] @@ -1316,11 +1318,12 @@ def update_list_of_dicts_with_values_from_config( # Don't do the merge, Union parameters are not obeyed. continue input_list[i] = dict_from_config - if inputs_from_config: + + if unmodified_inputs_from_config: print( "[Sagemaker Config - applied value]\n", "config key = {}\n".format(config_key_path), - "config value = {}\n".format(inputs_from_config), + "config value = {}\n".format(unmodified_inputs_from_config), "source value = {}\n".format(inputs_copy), "combined value that will be used = {}\n".format(input_list), ) From 6086451eb2f2bb8d16893cc19a8e942f526fd1f5 Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Mon, 27 Mar 2023 16:19:07 -0700 Subject: [PATCH 32/40] fix: Replace SageMakerConfig class with just method invocations --- src/sagemaker/config/__init__.py | 3 +- src/sagemaker/config/config.py | 147 ++++++++--------- src/sagemaker/local/local_session.py | 37 +++-- src/sagemaker/processing.py | 12 +- src/sagemaker/session.py | 23 ++- src/sagemaker/utils.py | 5 +- src/sagemaker/workflow/pipeline_context.py | 15 +- tests/conftest.py | 1 + tests/integ/test_local_mode.py | 31 +--- tests/integ/test_sagemaker_config.py | 10 +- tests/unit/__init__.py | 30 ++-- tests/unit/conftest.py | 52 ------ tests/unit/sagemaker/automl/test_auto_ml.py | 24 ++- tests/unit/sagemaker/config/test_config.py | 56 +++---- .../feature_store/test_feature_group.py | 12 +- .../sagemaker/huggingface/test_estimator.py | 2 +- .../sagemaker/huggingface/test_processing.py | 2 +- .../image_uris/jumpstart/conftest.py | 2 +- .../image_uris/jumpstart/test_catboost.py | 2 +- .../test_inference_recommender_mixin.py | 1 + .../sagemaker/local/test_local_pipeline.py | 8 +- tests/unit/sagemaker/model/test_deploy.py | 6 +- tests/unit/sagemaker/model/test_edge.py | 2 +- 
.../sagemaker/model/test_framework_model.py | 2 +- tests/unit/sagemaker/model/test_model.py | 2 +- .../sagemaker/model/test_model_package.py | 2 +- tests/unit/sagemaker/model/test_neo.py | 2 +- .../monitor/test_clarify_model_monitor.py | 2 +- .../monitor/test_data_capture_config.py | 2 +- .../monitor/test_model_monitoring.py | 12 +- .../sagemaker/tensorflow/test_estimator.py | 2 +- .../tensorflow/test_estimator_attach.py | 2 +- .../tensorflow/test_estimator_init.py | 2 +- tests/unit/sagemaker/tensorflow/test_tfs.py | 2 +- .../test_huggingface_pytorch_compiler.py | 2 +- .../test_huggingface_tensorflow_compiler.py | 2 +- .../test_pytorch_compiler.py | 2 +- .../test_tensorflow_compiler.py | 2 +- tests/unit/sagemaker/workflow/conftest.py | 4 +- tests/unit/sagemaker/workflow/test_airflow.py | 2 +- .../unit/sagemaker/workflow/test_pipeline.py | 4 +- .../sagemaker/wrangler/test_processing.py | 2 +- tests/unit/test_algorithm.py | 2 +- tests/unit/test_amazon_estimator.py | 2 +- tests/unit/test_chainer.py | 2 +- tests/unit/test_djl_inference.py | 4 +- tests/unit/test_estimator.py | 34 ++-- tests/unit/test_fm.py | 2 +- tests/unit/test_ipinsights.py | 2 +- tests/unit/test_job.py | 2 +- tests/unit/test_kmeans.py | 2 +- tests/unit/test_knn.py | 2 +- tests/unit/test_lda.py | 2 +- tests/unit/test_linear_learner.py | 2 +- tests/unit/test_multidatamodel.py | 2 +- tests/unit/test_mxnet.py | 2 +- tests/unit/test_ntm.py | 2 +- tests/unit/test_object2vec.py | 2 +- tests/unit/test_pca.py | 2 +- tests/unit/test_pipeline_model.py | 20 +-- tests/unit/test_processing.py | 22 +-- tests/unit/test_pytorch.py | 2 +- tests/unit/test_randomcutforest.py | 2 +- tests/unit/test_rl.py | 2 +- tests/unit/test_session.py | 150 +++++++++--------- tests/unit/test_sklearn.py | 2 +- tests/unit/test_sparkml_serving.py | 2 +- tests/unit/test_timeout.py | 2 +- tests/unit/test_transformer.py | 8 +- tests/unit/test_tuner.py | 2 +- tests/unit/test_utils.py | 84 +++++----- tests/unit/test_xgboost.py | 2 +- 
tests/unit/tuner_test_utils.py | 2 +- 73 files changed, 414 insertions(+), 491 deletions(-) diff --git a/src/sagemaker/config/__init__.py b/src/sagemaker/config/__init__.py index c4e92545b6..002768911e 100644 --- a/src/sagemaker/config/__init__.py +++ b/src/sagemaker/config/__init__.py @@ -13,7 +13,7 @@ """This module configures the default values for SageMaker Python SDK.""" from __future__ import absolute_import -from sagemaker.config.config import SageMakerConfig # noqa: F401 +from sagemaker.config.config import fetch_sagemaker_config, validate_sagemaker_config # noqa: F401 from sagemaker.config.config_schema import ( # noqa: F401 KEY, TRAINING_JOB, @@ -130,4 +130,5 @@ RESOURCE_CONFIG, EXECUTION_ROLE_ARN, ASYNC_INFERENCE_CONFIG, + SCHEMA_VERSION, ) diff --git a/src/sagemaker/config/config.py b/src/sagemaker/config/config.py index 63f4d32205..5d6bac24ad 100644 --- a/src/sagemaker/config/config.py +++ b/src/sagemaker/config/config.py @@ -24,7 +24,7 @@ from typing import List import boto3 import yaml -from jsonschema import validate +import jsonschema from platformdirs import site_config_dir, user_config_dir from botocore.utils import merge_dicts from six.moves.urllib.parse import urlparse @@ -51,87 +51,65 @@ S3_PREFIX = "s3://" -class SageMakerConfig(object): - """A class that encapsulates the configuration for the SageMaker Python SDK. +def fetch_sagemaker_config( + additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE +) -> dict: + """Helper method that loads config files and merges them. + + By default, this method first searches for config files in the default locations + defined by the SDK. + + Users can override the default admin and user config file paths using the + SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE environment variables, + respectively. + + Additional config file paths can also be provided as a parameter. 
+ + This method then: + * Loads each config file, whether it is Amazon S3 or the local file system. + * Validates the schema of the config files. + * Merges the files in the same order. + + This method throws exceptions in the following cases: + * jsonschema.exceptions.ValidationError: Schema validation fails for one or more config + files. + * RuntimeError: The method is unable to retrieve the list of all S3 files with the + same prefix or is unable to retrieve the file. + * ValueError: There are no S3 files with the prefix when an S3 URI is provided. + * ValueError: There is no config.yaml file in the S3 bucket when an S3 URI is provided. + * ValueError: A file doesn't exist in a path that was specified by the user as part of an + environment variable or additional configuration file path. This doesn't include the default + config file locations. + + Args: + additional_config_paths: List of config file paths. + These paths can be one of the following. In the case of a directory, this method + searches for a config.yaml file in that directory. This method does not perform a + recursive search of folders in that directory. + * Local file path + * Local directory path + * S3 URI of the config file + * S3 URI of the directory containing the config file + Note: S3 URI follows the format s3:/// + s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This is used to fetch + config files from S3. If it is not provided, this method creates a default S3 resource + See :py:meth:boto3.session.Session.resource. This argument is not needed if the + config files are present in the local file system. - This class is used to define default values provided by the user. - - This class is integrated with sagemaker.session.Session. Users of the SageMaker Python SDK - have the ability to pass a SageMakerConfig object to sagemaker.session.Session. If a - SageMakerConfig object is not provided by the user, then sagemaker.session.Session - creates its own SageMakerConfig object. 
- - Note: After sagemaker.session.Session is initialized, it operates with the configuration - values defined at that instant. If you modify the configuration files or file paths after - sagemaker.session.Session is initialized, those changes are not reflected in - sagemaker.session.Session. To incorporate the changes in the configuration files, - initialize sagemaker.session.Session again. """ - - def __init__(self, additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE): - """Initializes the SageMakerConfig object. - - By default, this method first searches for config files in the default locations - defined by the SDK. - - Users can override the default admin and user config file paths using the - SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE environment variables, - respectively. - - Additional config file paths can also be provided as a constructor parameter. - - This method then: - * Loads each config file, whether it is Amazon S3 or the local file system. - * Validates the schema of the config files. - * Merges the files in the same order. - - This method throws exceptions in the following cases: - * jsonschema.exceptions.ValidationError: Schema validation fails for one or more config - files. - * RuntimeError: The method is unable to retrieve the list of all S3 files with the - same prefix or is unable to retrieve the file. - * ValueError: There are no S3 files with the prefix when an S3 URI is provided. - * ValueError: There is no config.yaml file in the S3 bucket when an S3 URI is provided. - * ValueError: A file doesn't exist in a path that was specified by the user as part of an - environment variable or additional configuration file path. This doesn't include the default - config file locations. - - Args: - additional_config_paths: List of config file paths. - These paths can be one of the following. In the case of a directory, this method - searches for a config.yaml file in that directory. 
This method does not perform a - recursive search of folders in that directory. - * Local file path - * Local directory path - * S3 URI of the config file - * S3 URI of the directory containing the config file - Note: S3 URI follows the format s3:/// - s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This is used to fetch - config files from S3. If it is not provided, this method creates a default S3 resource - See :py:meth:boto3.session.Session.resource. This argument is not needed if the - config files are present in the local file system. - - """ - default_config_path = os.getenv( - ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE, _DEFAULT_ADMIN_CONFIG_FILE_PATH - ) - user_config_path = os.getenv( - ENV_VARIABLE_USER_CONFIG_OVERRIDE, _DEFAULT_USER_CONFIG_FILE_PATH - ) - self.config_paths = [default_config_path, user_config_path] - if additional_config_paths: - self.config_paths += additional_config_paths - self.config_paths = list(filter(lambda item: item is not None, self.config_paths)) - self.config = _load_config_files(self.config_paths, s3_resource) - - -def _load_config_files(file_paths: List[str], s3_resource_for_config) -> dict: - """Placeholder docstring""" + default_config_path = os.getenv( + ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE, _DEFAULT_ADMIN_CONFIG_FILE_PATH + ) + user_config_path = os.getenv(ENV_VARIABLE_USER_CONFIG_OVERRIDE, _DEFAULT_USER_CONFIG_FILE_PATH) + config_paths = [default_config_path, user_config_path] + if additional_config_paths: + config_paths += additional_config_paths + config_paths = list(filter(lambda item: item is not None, config_paths)) merged_config = {} - for file_path in file_paths: + for file_path in config_paths: config_from_file = {} if file_path.startswith(S3_PREFIX): - config_from_file = _load_config_from_s3(file_path, s3_resource_for_config) + config_from_file = _load_config_from_s3(file_path, s3_resource) else: try: config_from_file = _load_config_from_file(file_path) @@ -145,11 +123,24 @@ def _load_config_files(file_paths: 
List[str], s3_resource_for_config) -> dict: # Exceptions. raise if config_from_file: - validate(config_from_file, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + validate_sagemaker_config(config_from_file) merge_dicts(merged_config, config_from_file) return merged_config +def validate_sagemaker_config(sagemaker_config: dict = None): + """Helper method that validates whether the schema of a given dictionary. + + This method will validate whether the dictionary adheres to the schema + defined at `~sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA` + + Args: + sagemaker_config: A dictionary containing default values for the + SageMaker Python SDK. (default: None). + """ + jsonschema.validate(sagemaker_config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + def _load_config_from_file(file_path: str) -> dict: """Placeholder docstring""" inferred_file_path = file_path diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index d088dcf0c8..cb41548d04 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -21,7 +21,7 @@ import boto3 from botocore.exceptions import ClientError -from sagemaker.config import SageMakerConfig +from sagemaker.config import fetch_sagemaker_config, validate_sagemaker_config from sagemaker.local.image import _SageMakerContainer from sagemaker.local.utils import get_docker_host from sagemaker.local.entities import ( @@ -605,7 +605,7 @@ def __init__( default_bucket=None, s3_endpoint_url=None, disable_local_code=False, - sagemaker_config: SageMakerConfig = None, + sagemaker_config: dict = None, ): """Create a Local SageMaker Session. @@ -618,6 +618,16 @@ def __init__( disable_local_code (bool): Set ``True`` to override the default AWS configuration chain to disable the ``local.local_code`` setting, which may not be supported for some SDK features (default: False). + sagemaker_config: A dictionary containing default values for the + SageMaker Python SDK. (default: None). 
The dictionary must adhere to the schema + defined at `~sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`. + If sagemaker_config is not provided and configuration files exist (at the default + paths for admins and users, or paths set through the environment variables + SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE), + a new dictionary will be generated from those configuration files. Alternatively, + this dictionary can be generated by calling + :func:`~sagemaker.config.fetch_sagemaker_config` and then be provided to the + Session. """ self.s3_endpoint_url = s3_endpoint_url # We use this local variable to avoid disrupting the __init__->_initialize API of the @@ -635,12 +645,7 @@ def __init__( logger.warning("Windows Support for Local Mode is Experimental") def _initialize( - self, - boto_session, - sagemaker_client, - sagemaker_runtime_client, - sagemaker_config: SageMakerConfig = None, - **kwargs + self, boto_session, sagemaker_client, sagemaker_runtime_client, **kwargs ): # pylint: disable=unused-argument """Initialize this Local SageMaker Session. 
@@ -669,20 +674,20 @@ def _initialize( self.sagemaker_client = LocalSagemakerClient(self) self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config) self.local_mode = True + sagemaker_config = kwargs.get("sagemaker_config", None) + if sagemaker_config: + validate_sagemaker_config(sagemaker_config) if self.s3_endpoint_url is not None: self.s3_resource = boto_session.resource("s3", endpoint_url=self.s3_endpoint_url) self.s3_client = boto_session.client("s3", endpoint_url=self.s3_endpoint_url) - self.sagemaker_config = sagemaker_config or ( - SageMakerConfig(s3_resource=self.s3_resource) - if "sagemaker_config" not in kwargs - else kwargs.get("sagemaker_config") + self.sagemaker_config = ( + sagemaker_config if sagemaker_config else fetch_sagemaker_config( + s3_resource=self.s3_resource) ) else: - self.sagemaker_config = sagemaker_config or ( - SageMakerConfig() - if "sagemaker_config" not in kwargs - else kwargs.get("sagemaker_config") + self.sagemaker_config = ( + sagemaker_config if sagemaker_config else fetch_sagemaker_config() ) sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml") diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 2fae52f166..228f5068a1 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -1246,20 +1246,10 @@ def __init__( self.s3_data_distribution_type = s3_data_distribution_type self.s3_compression_type = s3_compression_type self.s3_input = s3_input - self._dataset_definition = dataset_definition + self.dataset_definition = dataset_definition self.app_managed = app_managed self._create_s3_input() - @property - def dataset_definition(self): - """Getter for DataSetDefinition - - Returns: - DatasetDefinition: The DatasetDefinition Object. 
- - """ - return self._dataset_definition - def _to_request_dict(self): """Generates a request dictionary using the parameters provided to the class.""" diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index e5ce5bced0..f580d67eae 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -33,7 +33,7 @@ import sagemaker.logs from sagemaker import vpc_utils from sagemaker._studio import _append_project_tags -from sagemaker.config import SageMakerConfig # noqa: F401 +from sagemaker.config import fetch_sagemaker_config, validate_sagemaker_config # noqa: F401 from sagemaker.config import ( KEY, TRAINING_JOB, @@ -157,7 +157,7 @@ def __init__( default_bucket=None, settings=SessionSettings(), sagemaker_metrics_client=None, - sagemaker_config: SageMakerConfig = None, + sagemaker_config: dict = None, ): """Initialize a SageMaker ``Session``. @@ -189,9 +189,16 @@ def __init__( Client which makes SageMaker Metrics related calls to Amazon SageMaker (default: None). If not provided, one will be created using this instance's ``boto_session``. - sagemaker_config (sagemaker.config.SageMakerConfig): The SageMakerConfig object which - holds the default values for the SageMaker Python SDK. (default: None). If not - provided, one will be created. + sagemaker_config (dict): A dictionary containing default values for the + SageMaker Python SDK. (default: None). The dictionary must adhere to the schema + defined at `~sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`. + If sagemaker_config is not provided and configuration files exist (at the default + paths for admins and users, or paths set through the environment variables + SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE), + a new dictionary will be generated from those configuration files. Alternatively, + this dictionary can be generated by calling + :func:`~sagemaker.config.fetch_sagemaker_config` and then be provided to the + Session. 
""" self._default_bucket = None self._default_bucket_name_override = default_bucket @@ -217,7 +224,7 @@ def _initialize( sagemaker_runtime_client, sagemaker_featurestore_runtime_client, sagemaker_metrics_client, - sagemaker_config: SageMakerConfig = None, + sagemaker_config: dict = None, ): """Initialize this SageMaker Session. @@ -260,13 +267,13 @@ def _initialize( self.local_mode = False if sagemaker_config: - self.sagemaker_config = sagemaker_config + validate_sagemaker_config(sagemaker_config) else: if self.s3_resource is None: s3 = self.boto_session.resource("s3", region_name=self.boto_region_name) else: s3 = self.s3_resource - self.sagemaker_config = SageMakerConfig(s3_resource=s3) + self.sagemaker_config = fetch_sagemaker_config(s3_resource=s3) @property def boto_region_name(self): diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 5ce2581122..b416a48b5a 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -37,6 +37,7 @@ from six.moves.urllib import parse from sagemaker import deprecations +from sagemaker.config import validate_sagemaker_config from sagemaker.session_settings import SessionSettings from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string @@ -1090,7 +1091,9 @@ def get_sagemaker_config_value(sagemaker_session, key): """ if not sagemaker_session: return None - config_value = get_config_value(key, sagemaker_session.sagemaker_config.config) + if sagemaker_session.sagemaker_config: + validate_sagemaker_config(sagemaker_session.sagemaker_config) + config_value = get_config_value(key, sagemaker_session.sagemaker_config) # Copy the value so any modifications to the output will not modify the source config return copy.deepcopy(config_value) diff --git a/src/sagemaker/workflow/pipeline_context.py b/src/sagemaker/workflow/pipeline_context.py index 3dcb8905c4..cd1d07189d 100644 --- a/src/sagemaker/workflow/pipeline_context.py +++ b/src/sagemaker/workflow/pipeline_context.py @@ -18,7 +18,6 @@ from 
functools import wraps from typing import Dict, Optional, Callable -from sagemaker.config import SageMakerConfig from sagemaker.session import Session, SessionSettings from sagemaker.local import LocalSession @@ -113,7 +112,7 @@ def __init__( sagemaker_client=None, default_bucket=None, settings=SessionSettings(), - sagemaker_config: SageMakerConfig = None, + sagemaker_config: dict = None, ): """Initialize a ``PipelineSession``. @@ -133,8 +132,16 @@ def __init__( Example: "sagemaker-my-custom-bucket". settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional parameters to apply to the session. - sagemaker_config (sagemaker.config.SageMakerConfig): The SageMakerConfig object which - holds the default values for the SageMaker Python SDK. (default: None). + sagemaker_config: A dictionary containing default values for the + SageMaker Python SDK. (default: None). The dictionary must adhere to the schema + defined at `~sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`. + If sagemaker_config is not provided and configuration files exist (at the default + paths for admins and users, or paths set through the environment variables + SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE), + a new dictionary will be generated from those configuration files. Alternatively, + this dictionary can be generated by calling + :func:`~sagemaker.config.fetch_sagemaker_config` and then be provided to the + Session. 
""" super().__init__( boto_session=boto_session, diff --git a/tests/conftest.py b/tests/conftest.py index ebaf3db4ee..68761046bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -166,6 +166,7 @@ def sagemaker_session( sagemaker_client=sagemaker_client, sagemaker_runtime_client=runtime_client, sagemaker_metrics_client=metrics_client, + sagemaker_config={"SchemaVersion": "1.0"}, ) diff --git a/tests/integ/test_local_mode.py b/tests/integ/test_local_mode.py index f1a6488ab4..0d07ed68c4 100644 --- a/tests/integ/test_local_mode.py +++ b/tests/integ/test_local_mode.py @@ -23,7 +23,6 @@ import stopit import tests.integ.lock as lock -from sagemaker.config import SageMakerConfig from tests.integ import DATA_DIR from mock import Mock, ANY @@ -59,14 +58,7 @@ class LocalNoS3Session(LocalSession): def __init__(self): super(LocalSession, self).__init__() - def _initialize( - self, - boto_session, - sagemaker_client, - sagemaker_runtime_client, - sagemaker_config: SageMakerConfig = None, - **kwargs - ): + def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, **kwargs): self.boto_session = boto3.Session(region_name=DEFAULT_REGION) if self.config is None: self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}} @@ -76,11 +68,7 @@ def _initialize( self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config) self.local_mode = True - self.sagemaker_config = sagemaker_config or ( - SageMakerConfig() - if "sagemaker_config" not in kwargs - else kwargs.get("sagemaker_config") - ) + self.sagemaker_config = kwargs.get("sagemaker_config", None) class LocalPipelineNoS3Session(LocalPipelineSession): @@ -91,14 +79,7 @@ class LocalPipelineNoS3Session(LocalPipelineSession): def __init__(self): super(LocalPipelineSession, self).__init__() - def _initialize( - self, - boto_session, - sagemaker_client, - sagemaker_runtime_client, - sagemaker_config: SageMakerConfig = None, - **kwargs - ): + def _initialize(self, boto_session, 
sagemaker_client, sagemaker_runtime_client, **kwargs): self.boto_session = boto3.Session(region_name=DEFAULT_REGION) if self.config is None: self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}} @@ -108,11 +89,7 @@ def _initialize( self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config) self.local_mode = True - self.sagemaker_config = sagemaker_config or ( - SageMakerConfig() - if "sagemaker_config" not in kwargs - else kwargs.get("sagemaker_config") - ) + self.sagemaker_config = kwargs.get("sagemaker_config", None) @pytest.fixture(scope="module") diff --git a/tests/integ/test_sagemaker_config.py b/tests/integ/test_sagemaker_config.py index da9bf685ab..b0cf92841b 100644 --- a/tests/integ/test_sagemaker_config.py +++ b/tests/integ/test_sagemaker_config.py @@ -27,7 +27,7 @@ Predictor, Session, ) -from sagemaker.config import SageMakerConfig +from sagemaker.config import fetch_sagemaker_config from sagemaker.model_monitor import DataCaptureConfig from sagemaker.s3 import S3Uploader from sagemaker.sparkml import SparkMLModel @@ -144,9 +144,7 @@ def sagemaker_session_with_dynamically_generated_sagemaker_config( sagemaker_client=sagemaker_client, sagemaker_runtime_client=runtime_client, sagemaker_metrics_client=metrics_client, - sagemaker_config=SageMakerConfig( - additional_config_paths=[dynamic_sagemaker_config_yaml_path] - ), + sagemaker_config=fetch_sagemaker_config([dynamic_sagemaker_config_yaml_path]), ) return session @@ -175,11 +173,11 @@ def test_config_download_from_s3_and_merge( ) # The thing being tested. 
- sagemaker_config = SageMakerConfig( + sagemaker_config = fetch_sagemaker_config( additional_config_paths=[s3_uri_config_1, config_file_2_local_path] ) - assert sagemaker_config.config == expected_merged_config + assert sagemaker_config == expected_merged_config @pytest.mark.slow_test diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index e818f4743c..0f54dd9e0d 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -69,6 +69,7 @@ EXECUTION_ROLE_ARN, MODEL, ASYNC_INFERENCE_CONFIG, + SCHEMA_VERSION, ) DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") @@ -76,6 +77,7 @@ SAGEMAKER_CONFIG_MONITORING_SCHEDULE = { + SCHEMA_VERSION: "1.0", SAGEMAKER: { MONITORING_SCHEDULE: { MONITORING_SCHEDULE_CONFIG: { @@ -94,10 +96,11 @@ }, TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], } - } + }, } SAGEMAKER_CONFIG_COMPILATION_JOB = { + SCHEMA_VERSION: "1.0", SAGEMAKER: { COMPILATION_JOB: { OUTPUT_CONFIG: {KMS_KEY_ID: "TestKms"}, @@ -105,10 +108,11 @@ VPC_CONFIG: {SUBNETS: ["subnets-123"], SECURITY_GROUP_IDS: ["sg-123"]}, TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], }, - } + }, } SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB = { + SCHEMA_VERSION: "1.0", SAGEMAKER: { EDGE_PACKAGING_JOB: { OUTPUT_CONFIG: { @@ -121,6 +125,7 @@ } SAGEMAKER_CONFIG_ENDPOINT_CONFIG = { + SCHEMA_VERSION: "1.0", SAGEMAKER: { ENDPOINT_CONFIG: { ASYNC_INFERENCE_CONFIG: { @@ -138,10 +143,11 @@ ], TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], }, - } + }, } SAGEMAKER_CONFIG_AUTO_ML = { + SCHEMA_VERSION: "1.0", SAGEMAKER: { AUTO_ML: { AUTO_ML_JOB_CONFIG: { @@ -155,10 +161,11 @@ ROLE_ARN: "arn:aws:iam::111111111111:role/ConfigRole", TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], }, - } + }, } SAGEMAKER_CONFIG_MODEL_PACKAGE = { + SCHEMA_VERSION: "1.0", SAGEMAKER: { MODEL_PACKAGE: { VALIDATION_SPECIFICATION: { @@ -175,10 +182,11 @@ # TODO - does SDK not support tags for this API? 
# TAGS: EXAMPLE_TAGS, }, - } + }, } SAGEMAKER_CONFIG_FEATURE_GROUP = { + SCHEMA_VERSION: "1.0", SAGEMAKER: { FEATURE_GROUP: { OFFLINE_STORE_CONFIG: { @@ -194,10 +202,11 @@ ROLE_ARN: "arn:aws:iam::111111111111:role/ConfigRole", TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], }, - } + }, } SAGEMAKER_CONFIG_PROCESSING_JOB = { + SCHEMA_VERSION: "1.0", SAGEMAKER: { PROCESSING_JOB: { NETWORK_CONFIG: { @@ -227,10 +236,11 @@ ROLE_ARN: "arn:aws:iam::111111111111:role/ConfigRole", TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], }, - } + }, } SAGEMAKER_CONFIG_TRAINING_JOB = { + SCHEMA_VERSION: "1.0", SAGEMAKER: { TRAINING_JOB: { ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: True, @@ -245,6 +255,7 @@ } SAGEMAKER_CONFIG_TRANSFORM_JOB = { + SCHEMA_VERSION: "1.0", SAGEMAKER: { TRANSFORM_JOB: { DATA_CAPTURE_CONFIG: {KMS_KEY_ID: "jobKmsKeyId"}, @@ -252,10 +263,11 @@ TRANSFORM_RESOURCES: {VOLUME_KMS_KEY_ID: "volumeKmsKeyId"}, TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], } - } + }, } SAGEMAKER_CONFIG_MODEL = { + SCHEMA_VERSION: "1.0", SAGEMAKER: { MODEL: { ENABLE_NETWORK_ISOLATION: True, @@ -263,5 +275,5 @@ VPC_CONFIG: {SUBNETS: ["subnets-123"], SECURITY_GROUP_IDS: ["sg-123"]}, TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}], }, - } + }, } diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index f12d543b9c..00c0e0354a 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -14,16 +14,11 @@ import pytest -from jsonschema.validators import validate -from mock.mock import MagicMock import sagemaker from mock import Mock, PropertyMock -from sagemaker.config import SageMakerConfig -from sagemaker.config.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA, SCHEMA_VERSION - _ROLE = "DummyRole" _REGION = "us-west-2" _DEFAULT_BUCKET = "my-bucket" @@ -72,50 +67,3 @@ def sagemaker_session(boto_session, client): sagemaker_metrics_client=client, ) return session - - -@pytest.fixture() -def sagemaker_config_session(): - """ - Returns: a sagemaker.Session to use 
for tests of injection of default parameters from the - sagemaker_config. - - This session has a custom SageMakerConfig that allows us to set the sagemaker_config.config - dict manually. This allows us to test in unit tests without tight coupling to the exact - sagemaker_config related helpers/utils/methods used. (And those helpers/utils/methods should - have their own separate and specific unit tests.) - - An alternative would be to mock each call to a sagemaker_config-related method, but that would - be harder to maintain/update over time, and be less readable. - """ - - class SageMakerConfigWithSetter(SageMakerConfig): - """ - Version of SageMakerConfig that allows the config to be set - """ - - def __init__(self): - self._config = {} - # no need to call super - - @property - def config(self) -> dict: - return self._config - - @config.setter - def config(self, new_config): - """Validates and sets a new config.""" - # Add schema version if not already there since that is required - if SCHEMA_VERSION not in new_config: - new_config[SCHEMA_VERSION] = "1.0" - # Validate to make sure unit tests are not accidentally testing with a wrong config - validate(new_config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) - self._config = new_config - - boto_mock = MagicMock(name="boto_session", region_name="us-west-2") - session_with_custom_sagemaker_config = sagemaker.Session( - boto_session=boto_mock, - sagemaker_client=MagicMock(), - sagemaker_config=SageMakerConfigWithSetter(), - ) - return session_with_custom_sagemaker_config diff --git a/tests/unit/sagemaker/automl/test_auto_ml.py b/tests/unit/sagemaker/automl/test_auto_ml.py index cf5e669fa6..8ccaa609d0 100644 --- a/tests/unit/sagemaker/automl/test_auto_ml.py +++ b/tests/unit/sagemaker/automl/test_auto_ml.py @@ -274,7 +274,7 @@ def sagemaker_session(): sms.list_candidates = Mock(name="list_candidates", return_value={"Candidates": []}) sms.sagemaker_client.list_tags = Mock(name="list_tags", return_value=LIST_TAGS_RESULT) # For 
tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms @@ -297,12 +297,12 @@ def test_auto_ml_without_role_parameter(sagemaker_session): ) -def test_framework_initialization_with_sagemaker_config_injection(sagemaker_config_session): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_AUTO_ML +def test_framework_initialization_with_sagemaker_config_injection(sagemaker_session): + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_AUTO_ML auto_ml = AutoML( target_attribute_name=TARGET_ATTRIBUTE_NAME, - sagemaker_session=sagemaker_config_session, + sagemaker_session=sagemaker_session, ) expected_volume_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["AutoMLJobConfig"][ @@ -835,23 +835,21 @@ def test_deploy_optional_args(candidate_estimator, sagemaker_session, candidate_ def test_candidate_estimator_fit_initialization_with_sagemaker_config_injection( - sagemaker_config_session, + sagemaker_session, ): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_TRAINING_JOB - sagemaker_config_session.train = Mock() - sagemaker_config_session.transform = Mock() + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRAINING_JOB + sagemaker_session.train = Mock() + sagemaker_session.transform = Mock() desc_training_job_response = copy.deepcopy(TRAINING_JOB) del desc_training_job_response["VpcConfig"] del desc_training_job_response["OutputDataConfig"]["KmsKeyId"] - sagemaker_config_session.sagemaker_client.describe_training_job = Mock( + sagemaker_session.sagemaker_client.describe_training_job = Mock( name="describe_training_job", return_value=desc_training_job_response ) - candidate_estimator = CandidateEstimator( - CANDIDATE_DICT, sagemaker_session=sagemaker_config_session - ) + candidate_estimator = CandidateEstimator(CANDIDATE_DICT, sagemaker_session=sagemaker_session) candidate_estimator._check_all_job_finished = Mock( 
name="_check_all_job_finished", return_value=True ) @@ -865,7 +863,7 @@ def test_candidate_estimator_fit_initialization_with_sagemaker_config_injection( "TrainingJob" ]["EnableInterContainerTrafficEncryption"] - for train_call in sagemaker_config_session.train.call_args_list: + for train_call in sagemaker_session.train.call_args_list: train_args = train_call.kwargs assert train_args["vpc_config"] == expected_vpc_config assert train_args["resource_config"]["VolumeKmsKeyId"] == expected_volume_kms_key_id diff --git a/tests/unit/sagemaker/config/test_config.py b/tests/unit/sagemaker/config/test_config.py index aaf7ef3b09..00732e5320 100644 --- a/tests/unit/sagemaker/config/test_config.py +++ b/tests/unit/sagemaker/config/test_config.py @@ -17,7 +17,7 @@ import yaml from mock import Mock, MagicMock -from sagemaker.config.config import SageMakerConfig +from sagemaker.config.config import fetch_sagemaker_config from jsonschema import exceptions from yaml.constructor import ConstructorError @@ -37,14 +37,14 @@ def expected_merged_config(get_data_dir): def test_config_when_default_config_file_and_user_config_file_is_not_found(): - assert SageMakerConfig().config == {} + assert fetch_sagemaker_config() == {} def test_config_when_overriden_default_config_file_is_not_found(get_data_dir): fake_config_file_path = os.path.join(get_data_dir, "config-not-found.yaml") os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = fake_config_file_path with pytest.raises(ValueError): - SageMakerConfig() + fetch_sagemaker_config() del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] @@ -55,14 +55,14 @@ def test_invalid_config_file_which_has_python_code(get_data_dir): # PyYAML will throw exceptions for yaml.safe_load. 
SageMaker Config is using # yaml.safe_load internally with pytest.raises(ConstructorError) as exception_info: - SageMakerConfig(additional_config_paths=[invalid_config_file_path]) + fetch_sagemaker_config(additional_config_paths=[invalid_config_file_path]) assert "python/object/apply:eval" in str(exception_info.value) def test_config_when_additional_config_file_path_is_not_found(get_data_dir): fake_config_file_path = os.path.join(get_data_dir, "config-not-found.yaml") with pytest.raises(ValueError): - SageMakerConfig(additional_config_paths=[fake_config_file_path]) + fetch_sagemaker_config(additional_config_paths=[fake_config_file_path]) def test_config_factory_when_override_user_config_file_is_not_found(get_data_dir): @@ -71,7 +71,7 @@ def test_config_factory_when_override_user_config_file_is_not_found(get_data_dir ) os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = fake_additional_override_config_file_path with pytest.raises(ValueError): - SageMakerConfig() + fetch_sagemaker_config() del os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] @@ -79,7 +79,7 @@ def test_default_config_file_with_invalid_schema(get_data_dir): config_file_path = os.path.join(get_data_dir, "invalid_config_file.yaml") os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = config_file_path with pytest.raises(exceptions.ValidationError): - SageMakerConfig() + fetch_sagemaker_config() del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] @@ -90,7 +90,7 @@ def test_default_config_file_when_directory_is_provided_as_the_path( expected_config = base_config_with_schema expected_config["SageMaker"] = valid_config_with_all_the_scopes os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = get_data_dir - assert expected_config == SageMakerConfig().config + assert expected_config == fetch_sagemaker_config() del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] @@ -100,7 +100,7 @@ def test_additional_config_paths_when_directory_is_provided( # This will try to load config.yaml file from that directory if present. 
expected_config = base_config_with_schema expected_config["SageMaker"] = valid_config_with_all_the_scopes - assert expected_config == SageMakerConfig(additional_config_paths=[get_data_dir]).config + assert expected_config == fetch_sagemaker_config(additional_config_paths=[get_data_dir]) def test_default_config_file_when_path_is_provided_as_environment_variable( @@ -110,7 +110,7 @@ def test_default_config_file_when_path_is_provided_as_environment_variable( # This will try to load config.yaml file from that directory if present. expected_config = base_config_with_schema expected_config["SageMaker"] = valid_config_with_all_the_scopes - assert expected_config == SageMakerConfig().config + assert expected_config == fetch_sagemaker_config() del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] @@ -123,7 +123,7 @@ def test_merge_behavior_when_additional_config_file_path_is_not_found( ) os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = valid_config_file_path with pytest.raises(ValueError): - SageMakerConfig(additional_config_paths=[fake_additional_override_config_file_path]) + fetch_sagemaker_config(additional_config_paths=[fake_additional_override_config_file_path]) del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] @@ -133,12 +133,11 @@ def test_merge_behavior(get_data_dir, expected_merged_config): get_data_dir, "sample_additional_config_for_merge.yaml" ) os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = valid_config_file_path - assert ( - expected_merged_config - == SageMakerConfig(additional_config_paths=[additional_override_config_file_path]).config + assert expected_merged_config == fetch_sagemaker_config( + additional_config_paths=[additional_override_config_file_path] ) os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = additional_override_config_file_path - assert expected_merged_config == SageMakerConfig().config + assert expected_merged_config == fetch_sagemaker_config() del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] del os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] @@ 
-161,11 +160,8 @@ def test_s3_config_file( config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix) expected_config = base_config_with_schema expected_config["SageMaker"] = valid_config_with_all_the_scopes - assert ( - expected_config - == SageMakerConfig( - additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock - ).config + assert expected_config == fetch_sagemaker_config( + additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock ) @@ -178,7 +174,9 @@ def test_config_factory_when_default_s3_config_file_is_not_found(s3_resource_moc ).all.return_value = [] config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix) with pytest.raises(ValueError): - SageMakerConfig(additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock) + fetch_sagemaker_config( + additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock + ) def test_s3_config_file_when_uri_provided_corresponds_to_a_path( @@ -206,11 +204,8 @@ def test_s3_config_file_when_uri_provided_corresponds_to_a_path( config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix) expected_config = base_config_with_schema expected_config["SageMaker"] = valid_config_with_all_the_scopes - assert ( - expected_config - == SageMakerConfig( - additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock - ).config + assert expected_config == fetch_sagemaker_config( + additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock ) @@ -236,11 +231,8 @@ def test_merge_of_s3_default_config_file_and_regular_config_file( get_data_dir, "sample_additional_config_for_merge.yaml" ) os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = config_file_s3_uri - assert ( - expected_merged_config - == SageMakerConfig( - additional_config_paths=[additional_override_config_file_path], - s3_resource=s3_resource_mock, - ).config + assert expected_merged_config == fetch_sagemaker_config( + 
additional_config_paths=[additional_override_config_file_path], + s3_resource=s3_resource_mock, ) del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] diff --git a/tests/unit/sagemaker/feature_store/test_feature_group.py b/tests/unit/sagemaker/feature_store/test_feature_group.py index a53f1da680..8ffd82ee1b 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_group.py +++ b/tests/unit/sagemaker/feature_store/test_feature_group.py @@ -54,7 +54,7 @@ def s3_uri(): @pytest.fixture def sagemaker_session_mock(): sagemaker_session_mock = Mock() - sagemaker_session_mock.sagemaker_config.config = {} + sagemaker_session_mock.sagemaker_config = {} return sagemaker_session_mock @@ -106,13 +106,13 @@ def test_feature_group_create_without_role( def test_feature_store_create_with_config_injection( - sagemaker_config_session, role_arn, feature_group_dummy_definitions, s3_uri + sagemaker_session, role_arn, feature_group_dummy_definitions, s3_uri ): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_FEATURE_GROUP - sagemaker_config_session.create_feature_group = Mock() + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_FEATURE_GROUP + sagemaker_session.create_feature_group = Mock() - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_config_session) + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session) feature_group.feature_definitions = feature_group_dummy_definitions feature_group.create( s3_uri=s3_uri, @@ -127,7 +127,7 @@ def test_feature_store_create_with_config_injection( expected_online_store_kms_key_id = SAGEMAKER_CONFIG_FEATURE_GROUP["SageMaker"]["FeatureGroup"][ "OnlineStoreConfig" ]["SecurityConfig"]["KmsKeyId"] - sagemaker_config_session.create_feature_group.assert_called_with( + sagemaker_session.create_feature_group.assert_called_with( feature_group_name="MyFeatureGroup", record_identifier_name="feature1", event_time_feature_name="feature2", diff --git 
a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index 4ed014f577..666a142543 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -76,7 +76,7 @@ def fixture_sagemaker_session(): session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/sagemaker/huggingface/test_processing.py b/tests/unit/sagemaker/huggingface/test_processing.py index e20228aa4f..f5ea031143 100644 --- a/tests/unit/sagemaker/huggingface/test_processing.py +++ b/tests/unit/sagemaker/huggingface/test_processing.py @@ -51,7 +51,7 @@ def sagemaker_session(): session_mock.download_data = Mock(name="download_data") session_mock.expand_role.return_value = ROLE # For tests which doesn't verify config file injection, operate with empty config - session_mock.sagemaker_config.config = {} + session_mock.sagemaker_config = {} return session_mock diff --git a/tests/unit/sagemaker/image_uris/jumpstart/conftest.py b/tests/unit/sagemaker/image_uris/jumpstart/conftest.py index 386d8df94c..5b4227750d 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/conftest.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/conftest.py @@ -32,5 +32,5 @@ def session(): ) sms.default_bucket = Mock(return_value=BUCKET_NAME) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py b/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py index f8b0a7c581..9261fd561e 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py +++ 
b/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py @@ -25,7 +25,7 @@ @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_catboost_image_uri(patched_get_model_specs, session): # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} patched_get_model_specs.side_effect = get_prototype_model_spec diff --git a/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py b/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py index 49417ae47e..c936655ae8 100644 --- a/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py +++ b/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py @@ -183,6 +183,7 @@ def sagemaker_session(): session.create_inference_recommendations_job.return_value = IR_JOB_NAME session.wait_for_inference_recommendations_job.return_value = IR_SAMPLE_INFERENCE_RESPONSE + session.sagemaker_config = {} return session diff --git a/tests/unit/sagemaker/local/test_local_pipeline.py b/tests/unit/sagemaker/local/test_local_pipeline.py index 0bd0d381ea..0fdf13509a 100644 --- a/tests/unit/sagemaker/local/test_local_pipeline.py +++ b/tests/unit/sagemaker/local/test_local_pipeline.py @@ -141,8 +141,8 @@ def pipeline_session(boto_session, client): default_bucket=BUCKET, ) # For tests which doesn't verify config file injection, operate with empty config - pipeline_session.sagemaker_config = Mock() - pipeline_session.sagemaker_config.config = {} + + pipeline_session.sagemaker_config = {} return pipeline_session_mock @@ -150,8 +150,8 @@ def pipeline_session(boto_session, client): def local_sagemaker_session(boto_session): local_session_mock = LocalSession(boto_session=boto_session, default_bucket="my-bucket") # For tests which doesn't verify config file injection, operate with empty config - local_session_mock.sagemaker_config = 
Mock() - local_session_mock.sagemaker_config.config = {} + + local_session_mock.sagemaker_config = {} return local_session_mock diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 6d39a88a35..5c045dcae4 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -69,7 +69,7 @@ def sagemaker_session(): session = Mock() # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session @@ -439,8 +439,8 @@ def test_deploy_wrong_serverless_config(sagemaker_session): @patch("sagemaker.session.Session") @patch("sagemaker.local.LocalSession") def test_deploy_creates_correct_session(local_session, session): - local_session.sagemaker_config.config = {} - session.sagemaker_config.config = {} + local_session.sagemaker_config = {} + session.sagemaker_config = {} # We expect a LocalSession when deploying to instance_type = 'local' model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE) model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1) diff --git a/tests/unit/sagemaker/model/test_edge.py b/tests/unit/sagemaker/model/test_edge.py index 1e8a110ce8..9e2c10c586 100644 --- a/tests/unit/sagemaker/model/test_edge.py +++ b/tests/unit/sagemaker/model/test_edge.py @@ -32,7 +32,7 @@ def sagemaker_session(): session = Mock(boto_region_name=REGION) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/sagemaker/model/test_framework_model.py b/tests/unit/sagemaker/model/test_framework_model.py index d2c364c087..ca48e93187 100644 --- a/tests/unit/sagemaker/model/test_framework_model.py +++ b/tests/unit/sagemaker/model/test_framework_model.py @@ -95,7 +95,7 @@ def sagemaker_session(): ) sms.default_bucket = Mock(name="default_bucket", 
return_value=BUCKET_NAME) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 70506338c2..ee412e1399 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -110,7 +110,7 @@ def sagemaker_session(): ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index 4633b10eb5..a87a2b74f4 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -60,7 +60,7 @@ def sagemaker_session(): return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE ) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/sagemaker/model/test_neo.py b/tests/unit/sagemaker/model/test_neo.py index baa0aef4b0..e6912476d5 100644 --- a/tests/unit/sagemaker/model/test_neo.py +++ b/tests/unit/sagemaker/model/test_neo.py @@ -36,7 +36,7 @@ def sagemaker_session(): session = Mock(boto_region_name=REGION) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py index 221b26a513..33800c9a1d 100644 --- a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py +++ b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py @@ -417,7 +417,7 @@ def 
sagemaker_session(sagemaker_client): name="_append_sagemaker_config_tags", side_effect=lambda tags, config_path_to_tags: tags ) # For tests which doesn't verify config file injection, operate with empty config - session_mock.sagemaker_config.config = {} + session_mock.sagemaker_config = {} return session_mock diff --git a/tests/unit/sagemaker/monitor/test_data_capture_config.py b/tests/unit/sagemaker/monitor/test_data_capture_config.py index 24c7a1ddd3..474c63f09a 100644 --- a/tests/unit/sagemaker/monitor/test_data_capture_config.py +++ b/tests/unit/sagemaker/monitor/test_data_capture_config.py @@ -56,7 +56,7 @@ def test_init_when_non_defaults_provided(): def test_init_when_optionals_not_provided(): sagemaker_session = Mock() sagemaker_session.default_bucket.return_value = DEFAULT_BUCKET_NAME - sagemaker_session.sagemaker_config.config = {} + sagemaker_session.sagemaker_config = {} data_capture_config = DataCaptureConfig( enable_capture=DEFAULT_ENABLE_CAPTURE, sagemaker_session=sagemaker_session diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index b1cf601c01..04c67082fd 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -464,7 +464,7 @@ def sagemaker_session(): session_mock.expand_role.return_value = ROLE # For tests which doesn't verify config file injection, operate with empty config - session_mock.sagemaker_config.config = {} + session_mock.sagemaker_config = {} session_mock._append_sagemaker_config_tags = Mock( name="_append_sagemaker_config_tags", side_effect=lambda tags, config_path_to_tags: tags ) @@ -875,13 +875,13 @@ def _test_data_quality_batch_transform_monitor_create_schedule( def test_data_quality_batch_transform_monitor_create_schedule_with_sagemaker_config_injection( data_quality_monitor, - sagemaker_config_session, + sagemaker_session, ): - sagemaker_config_session.sagemaker_config.config = 
SAGEMAKER_CONFIG_MONITORING_SCHEDULE + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_MONITORING_SCHEDULE - sagemaker_config_session.sagemaker_client.create_monitoring_schedule = Mock() - data_quality_monitor.sagemaker_session = sagemaker_config_session + sagemaker_session.sagemaker_client.create_monitoring_schedule = Mock() + data_quality_monitor.sagemaker_session = sagemaker_session # for batch transform input data_quality_monitor.create_monitoring_schedule( @@ -902,7 +902,7 @@ def test_data_quality_batch_transform_monitor_create_schedule_with_sagemaker_con "MonitoringSchedule" ]["Tags"][0] - sagemaker_config_session.sagemaker_client.create_monitoring_schedule.assert_called_with( + sagemaker_session.sagemaker_client.create_monitoring_schedule.assert_called_with( MonitoringScheduleName=SCHEDULE_NAME, MonitoringScheduleConfig={ "MonitoringJobDefinitionName": data_quality_monitor.job_definition_name, diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index 548e0643f6..e384f21f92 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -83,7 +83,7 @@ def sagemaker_session(): session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/sagemaker/tensorflow/test_estimator_attach.py b/tests/unit/sagemaker/tensorflow/test_estimator_attach.py index 1ba953b76b..1a8762ea5d 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator_attach.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator_attach.py @@ -43,7 +43,7 @@ def sagemaker_session(): session.sagemaker_client.describe_training_job = Mock(return_value=describe) 
session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/sagemaker/tensorflow/test_estimator_init.py b/tests/unit/sagemaker/tensorflow/test_estimator_init.py index 238f405765..9f4ee47034 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator_init.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator_init.py @@ -26,7 +26,7 @@ @pytest.fixture() def sagemaker_session(): session_mock = Mock(name="sagemaker_session", boto_region_name=REGION) - session_mock.sagemaker_config.config = {} + session_mock.sagemaker_config = {} return session_mock diff --git a/tests/unit/sagemaker/tensorflow/test_tfs.py b/tests/unit/sagemaker/tensorflow/test_tfs.py index b403a7f07b..2c8b9f3ff4 100644 --- a/tests/unit/sagemaker/tensorflow/test_tfs.py +++ b/tests/unit/sagemaker/tensorflow/test_tfs.py @@ -68,7 +68,7 @@ def sagemaker_session(): session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index 59adb31f71..cc8e2af0d2 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -83,7 +83,7 @@ def fixture_sagemaker_session(): session.expand_role = Mock(name="expand_role", return_value=ROLE) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + 
session.sagemaker_config = {} return session diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index 110066bd49..852d7ee372 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -80,7 +80,7 @@ def fixture_sagemaker_session(): session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py index 87d614fd20..4d46fba62c 100644 --- a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -82,7 +82,7 @@ def fixture_sagemaker_session(): session.expand_role = Mock(name="expand_role", return_value=ROLE) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index 1c3b4c252d..8f48cf1fb7 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -87,7 +87,7 @@ def fixture_sagemaker_session(): session.expand_role = Mock(name="expand_role", return_value=ROLE) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return 
session diff --git a/tests/unit/sagemaker/workflow/conftest.py b/tests/unit/sagemaker/workflow/conftest.py index 91318b610f..140e8508ac 100644 --- a/tests/unit/sagemaker/workflow/conftest.py +++ b/tests/unit/sagemaker/workflow/conftest.py @@ -65,6 +65,6 @@ def pipeline_session(mock_boto_session, mock_client): default_bucket=BUCKET, ) # For tests which doesn't verify config file injection, operate with empty config - pipeline_session.sagemaker_config = Mock() - pipeline_session.sagemaker_config.config = {} + + pipeline_session.sagemaker_config = {} return pipeline_session diff --git a/tests/unit/sagemaker/workflow/test_airflow.py b/tests/unit/sagemaker/workflow/test_airflow.py index 9ee09054f5..32afc6d7b5 100644 --- a/tests/unit/sagemaker/workflow/test_airflow.py +++ b/tests/unit/sagemaker/workflow/test_airflow.py @@ -43,7 +43,7 @@ def sagemaker_session(): session._default_bucket = BUCKET_NAME # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index 97e7f17549..c92db39791 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -47,7 +47,7 @@ def sagemaker_session_mock(): session_mock.default_bucket = Mock(name="default_bucket", return_value="s3_bucket") session_mock.local_mode = False # For tests which doesn't verify config file injection, operate with empty config - session_mock.sagemaker_config.config = {} + session_mock.sagemaker_config = {} session_mock._append_sagemaker_config_tags = Mock( name="_append_sagemaker_config_tags", side_effect=lambda tags, config_path_to_tags: tags ) @@ -71,7 +71,7 @@ def test_pipeline_create_and_update_without_role_arn(sagemaker_session_mock): def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock): # For tests which doesn't 
verify config file injection, operate with empty config - sagemaker_session_mock.sagemaker_config.config = { + sagemaker_session_mock.sagemaker_config = { "SageMaker": {"Pipeline": {"RoleArn": "ConfigRoleArn"}} } sagemaker_session_mock.sagemaker_client.describe_pipeline.return_value = { diff --git a/tests/unit/sagemaker/wrangler/test_processing.py b/tests/unit/sagemaker/wrangler/test_processing.py index d13292ad9e..e61dee4e37 100644 --- a/tests/unit/sagemaker/wrangler/test_processing.py +++ b/tests/unit/sagemaker/wrangler/test_processing.py @@ -42,7 +42,7 @@ def sagemaker_session(): session_mock.expand_role.return_value = ROLE # For tests which doesn't verify config file injection, operate with empty config - session_mock.sagemaker_config.config = {} + session_mock.sagemaker_config = {} return session_mock diff --git a/tests/unit/test_algorithm.py b/tests/unit/test_algorithm.py index fb443c6432..d4502a0890 100644 --- a/tests/unit/test_algorithm.py +++ b/tests/unit/test_algorithm.py @@ -948,7 +948,7 @@ def test_algorithm_no_required_hyperparameters(session): def test_algorithm_attach_from_hyperparameter_tuning(): session = Mock() - session.sagemaker_config.config = {} + session.sagemaker_config = {} job_name = "training-job-that-is-part-of-a-tuning-job" algo_arn = "arn:aws:sagemaker:us-east-2:000000000000:algorithm/scikit-decision-trees" role_arn = "arn:aws:iam::123412341234:role/SageMakerRole" diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index e8e71f97ee..f571f4cbc2 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -75,7 +75,7 @@ def sagemaker_session(): name="describe_training_job", return_value=returned_job_description ) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 7d9c24cf3d..fdf29fef2e 100644 
--- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -75,7 +75,7 @@ def sagemaker_session(): session.expand_role = Mock(name="expand_role", return_value=ROLE) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/test_djl_inference.py b/tests/unit/test_djl_inference.py index 66fc967540..9f19625239 100644 --- a/tests/unit/test_djl_inference.py +++ b/tests/unit/test_djl_inference.py @@ -58,8 +58,8 @@ def sagemaker_session(): ) session.default_bucket = Mock(name="default_bucket", return_valie=BUCKET) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config = Mock() - session.sagemaker_config.config = {} + + session.sagemaker_config = {} return session diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index cfed30f076..c39a27495b 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -239,7 +239,7 @@ def sagemaker_session(): sms.upload_data = Mock(return_value=OUTPUT_PATH) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms @@ -366,13 +366,13 @@ def test_default_value_of_enable_network_isolation(sagemaker_session): assert framework.enable_network_isolation() is False -def test_framework_initialization_with_sagemaker_config_injection(sagemaker_config_session): +def test_framework_initialization_with_sagemaker_config_injection(sagemaker_session): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_TRAINING_JOB + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRAINING_JOB framework = DummyFramework( entry_point=SCRIPT_PATH, - sagemaker_session=sagemaker_config_session, + sagemaker_session=sagemaker_session, instance_groups=[ InstanceGroup("group1", "ml.c4.xlarge", 1), InstanceGroup("group2", 
"ml.m4.xlarge", 2), @@ -409,9 +409,9 @@ def test_framework_initialization_with_sagemaker_config_injection(sagemaker_conf assert framework.subnets == expected_subnets -def test_estimator_initialization_with_sagemaker_config_injection(sagemaker_config_session): +def test_estimator_initialization_with_sagemaker_config_injection(sagemaker_session): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_TRAINING_JOB + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRAINING_JOB estimator = Estimator( image_uri="some-image", @@ -419,7 +419,7 @@ def test_estimator_initialization_with_sagemaker_config_injection(sagemaker_conf InstanceGroup("group1", "ml.c4.xlarge", 1), InstanceGroup("group2", "ml.p3.16xlarge", 2), ], - sagemaker_session=sagemaker_config_session, + sagemaker_session=sagemaker_session, base_job_name="base_job_name", ) expected_volume_kms_key_id = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ @@ -1778,8 +1778,8 @@ def test_local_code_location(): local_mode=True, spec=sagemaker.local.LocalSession, ) - sms.sagemaker_config = Mock() - sms.sagemaker_config.config = {} + + sms.sagemaker_config = {} t = DummyFramework( entry_point=SCRIPT_PATH, role=ROLE, @@ -3742,13 +3742,13 @@ def test_register_under_pipeline_session(pipeline_session): def test_local_mode(session_class, local_session_class): local_session = Mock(spec=sagemaker.local.LocalSession) local_session.local_mode = True - local_session.sagemaker_config = Mock() - local_session.sagemaker_config.config = {} + + local_session.sagemaker_config = {} session = Mock() session.local_mode = False - session.sagemaker_config = Mock() - session.sagemaker_config.config = {} + + session.sagemaker_config = {} local_session_class.return_value = local_session session_class.return_value = session @@ -3774,8 +3774,8 @@ def test_local_mode_file_output_path(local_session_class): local_session = Mock(spec=sagemaker.local.LocalSession) local_session.local_mode = True 
local_session_class.return_value = local_session - local_session.sagemaker_config = Mock() - local_session.sagemaker_config.config = {} + + local_session.sagemaker_config = {} e = Estimator(IMAGE_URI, ROLE, INSTANCE_COUNT, "local", output_path="file:///tmp/model/") assert e.output_path == "file:///tmp/model/" @@ -4002,8 +4002,8 @@ def test_estimator_local_mode_error(sagemaker_session): def test_estimator_local_mode_ok(sagemaker_local_session): - sagemaker_local_session.sagemaker_config = Mock() - sagemaker_local_session.sagemaker_config.config = {} + + sagemaker_local_session.sagemaker_config = {} # When using instance local with a session which is not LocalSession we should error out Estimator( image_uri="some-image", diff --git a/tests/unit/test_fm.py b/tests/unit/test_fm.py index 0f1cca2d87..61f8079396 100644 --- a/tests/unit/test_fm.py +++ b/tests/unit/test_fm.py @@ -69,7 +69,7 @@ def sagemaker_session(): sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/test_ipinsights.py b/tests/unit/test_ipinsights.py index 4197ca77da..5fe3882ed3 100644 --- a/tests/unit/test_ipinsights.py +++ b/tests/unit/test_ipinsights.py @@ -66,7 +66,7 @@ def sagemaker_session(): sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index 5e0d7748bb..c151bc8174 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -81,7 +81,7 @@ def sagemaker_session(): ) 
mock_session.expand_role = Mock(name="expand_role", return_value=ROLE) # For tests which doesn't verify config file injection, operate with empty config - mock_session.sagemaker_config.config = {} + mock_session.sagemaker_config = {} return mock_session diff --git a/tests/unit/test_kmeans.py b/tests/unit/test_kmeans.py index 0ac457a6a7..e966f4024c 100644 --- a/tests/unit/test_kmeans.py +++ b/tests/unit/test_kmeans.py @@ -63,7 +63,7 @@ def sagemaker_session(): sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/test_knn.py b/tests/unit/test_knn.py index 828bd73cd0..704d7a665f 100644 --- a/tests/unit/test_knn.py +++ b/tests/unit/test_knn.py @@ -69,7 +69,7 @@ def sagemaker_session(): sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/test_lda.py b/tests/unit/test_lda.py index 8ef73047e5..f0adbc4d0a 100644 --- a/tests/unit/test_lda.py +++ b/tests/unit/test_lda.py @@ -58,7 +58,7 @@ def sagemaker_session(): sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/test_linear_learner.py b/tests/unit/test_linear_learner.py index d055d20066..1d0d5d08dc 100644 --- a/tests/unit/test_linear_learner.py +++ 
b/tests/unit/test_linear_learner.py @@ -64,7 +64,7 @@ def sagemaker_session(): sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/test_multidatamodel.py b/tests/unit/test_multidatamodel.py index c0a6bb7ed3..45216f0a62 100644 --- a/tests/unit/test_multidatamodel.py +++ b/tests/unit/test_multidatamodel.py @@ -81,7 +81,7 @@ def sagemaker_session(): return_value=os.path.join(VALID_MULTI_MODEL_DATA_PREFIX, "mleap_model.tar.gz"), ) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} s3_mock = Mock() boto_mock.client("s3").return_value = s3_mock diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index d6dcb339fb..d5dd02ba05 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -102,7 +102,7 @@ def sagemaker_session(): session.expand_role = Mock(name="expand_role", return_value=ROLE) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/test_ntm.py b/tests/unit/test_ntm.py index 93f9a13336..6b0339bf74 100644 --- a/tests/unit/test_ntm.py +++ b/tests/unit/test_ntm.py @@ -63,7 +63,7 @@ def sagemaker_session(): sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/test_object2vec.py b/tests/unit/test_object2vec.py index 623a616c3f..9a5eac7931 
100644 --- a/tests/unit/test_object2vec.py +++ b/tests/unit/test_object2vec.py @@ -71,7 +71,7 @@ def sagemaker_session(): sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/test_pca.py b/tests/unit/test_pca.py index f12a527f37..4a00f1ea0d 100644 --- a/tests/unit/test_pca.py +++ b/tests/unit/test_pca.py @@ -63,7 +63,7 @@ def sagemaker_session(): sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py index 7f7ca3b00e..57a909a73d 100644 --- a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -75,7 +75,7 @@ def sagemaker_session(): ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms @@ -304,14 +304,14 @@ def test_pipeline_model_without_role(sagemaker_session): @patch("tarfile.open") @patch("time.strftime", return_value=TIMESTAMP) -def test_pipeline_model_with_config_injection(tfo, time, sagemaker_config_session): +def test_pipeline_model_with_config_injection(tfo, time, sagemaker_session): combined_config = copy.deepcopy(SAGEMAKER_CONFIG_MODEL) endpoint_config = copy.deepcopy(SAGEMAKER_CONFIG_ENDPOINT_CONFIG) merge_dicts(combined_config, endpoint_config) - sagemaker_config_session.sagemaker_config.config = combined_config + 
sagemaker_session.sagemaker_config = combined_config - sagemaker_config_session.create_model = Mock() - sagemaker_config_session.endpoint_from_production_variants = Mock() + sagemaker_session.create_model = Mock() + sagemaker_session.endpoint_from_production_variants = Mock() expected_role_arn = SAGEMAKER_CONFIG_MODEL["SageMaker"]["Model"]["ExecutionRoleArn"] expected_enable_network_isolation = SAGEMAKER_CONFIG_MODEL["SageMaker"]["Model"][ @@ -322,12 +322,12 @@ def test_pipeline_model_with_config_injection(tfo, time, sagemaker_config_sessio "KmsKeyId" ] - framework_model = DummyFrameworkModel(sagemaker_config_session) + framework_model = DummyFrameworkModel(sagemaker_session) sparkml_model = SparkMLModel( - model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_config_session + model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session ) pipeline_model = PipelineModel( - [framework_model, sparkml_model], sagemaker_session=sagemaker_config_session + [framework_model, sparkml_model], sagemaker_session=sagemaker_session ) assert pipeline_model.role == expected_role_arn assert pipeline_model.vpc_config == expected_vpc_config @@ -335,14 +335,14 @@ def test_pipeline_model_with_config_injection(tfo, time, sagemaker_config_sessio pipeline_model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) - sagemaker_config_session.create_model.assert_called_with( + sagemaker_session.create_model.assert_called_with( ANY, expected_role_arn, ANY, vpc_config=expected_vpc_config, enable_network_isolation=expected_enable_network_isolation, ) - sagemaker_config_session.endpoint_from_production_variants.assert_called_with( + sagemaker_session.endpoint_from_production_variants.assert_called_with( name="mi-1-2017-10-10-14-14-15", production_variants=[ { diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 0149e17fcd..d6265fa0ac 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -85,7 +85,7 @@ def 
sagemaker_session(): ) # For tests which doesn't verify config file injection, operate with empty config - session_mock.sagemaker_config.config = {} + session_mock.sagemaker_config = {} return session_mock @@ -111,7 +111,7 @@ def pipeline_session(): session_mock.__class__ = PipelineSession # For tests which doesn't verify config file injection, operate with empty config - session_mock.sagemaker_config.config = {} + session_mock.sagemaker_config = {} return session_mock @@ -646,17 +646,17 @@ def test_script_processor_without_role(exists_mock, isfile_mock, sagemaker_sessi @patch("os.path.exists", return_value=True) @patch("os.path.isfile", return_value=True) def test_script_processor_with_sagemaker_config_injection( - exists_mock, isfile_mock, sagemaker_config_session + exists_mock, isfile_mock, sagemaker_session ): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_PROCESSING_JOB + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_PROCESSING_JOB - sagemaker_config_session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) - sagemaker_config_session.upload_data = Mock(name="upload_data", return_value=MOCKED_S3_URI) - sagemaker_config_session.wait_for_processing_job = MagicMock( + sagemaker_session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + sagemaker_session.upload_data = Mock(name="upload_data", return_value=MOCKED_S3_URI) + sagemaker_session.wait_for_processing_job = MagicMock( name="wait_for_processing_job", return_value=_get_describe_response_inputs_and_ouputs() ) - sagemaker_config_session.process = Mock() - sagemaker_config_session.expand_role = Mock(name="expand_role", side_effect=lambda a: a) + sagemaker_session.process = Mock() + sagemaker_session.expand_role = Mock(name="expand_role", side_effect=lambda a: a) processor = ScriptProcessor( image_uri=CUSTOM_IMAGE_URI, @@ -668,7 +668,7 @@ def test_script_processor_with_sagemaker_config_injection( base_job_name="my_sklearn_processor", 
env={"my_env_variable": "my_env_variable_value"}, tags=[{"Key": "my-tag", "Value": "my-tag-value"}], - sagemaker_session=sagemaker_config_session, + sagemaker_session=sagemaker_session, ) processor.run( code="/local/path/to/processing_code.py", @@ -707,7 +707,7 @@ def test_script_processor_with_sagemaker_config_injection( "EnableInterContainerTrafficEncryption" ] = expected_enable_inter_containter_traffic_encryption - sagemaker_config_session.process.assert_called_with(**expected_args) + sagemaker_session.process.assert_called_with(**expected_args) assert "my_job_name" in processor._current_job_name diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index d31ddcd587..e0c49ea328 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -84,7 +84,7 @@ def fixture_sagemaker_session(): session.expand_role = Mock(name="expand_role", return_value=ROLE) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/test_randomcutforest.py b/tests/unit/test_randomcutforest.py index e7662ce88a..daa6f8cacc 100644 --- a/tests/unit/test_randomcutforest.py +++ b/tests/unit/test_randomcutforest.py @@ -63,7 +63,7 @@ def sagemaker_session(): sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index c4b66d5490..2d74d3919e 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -77,7 +77,7 @@ def fixture_sagemaker_session(): session.expand_role = Mock(name="expand_role", return_value=ROLE) # For tests which doesn't verify config file injection, operate with empty config - 
session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 1ac8d2782f..81d92a47b4 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -262,14 +262,14 @@ def test_process(boto_session): session.sagemaker_client.create_processing_job.assert_called_with(**expected_request) -def test_create_process_with_sagemaker_config_injection(sagemaker_config_session): +def test_create_process_with_sagemaker_config_injection(sagemaker_session): processing_job_config = copy.deepcopy(SAGEMAKER_CONFIG_PROCESSING_JOB) # deleting RedshiftDatasetDefinition. API can take either RedshiftDatasetDefinition or # AthenaDatasetDefinition del processing_job_config["SageMaker"]["ProcessingJob"]["ProcessingInputs"][0][ "DatasetDefinition" ]["RedshiftDatasetDefinition"] - sagemaker_config_session.sagemaker_config.config = processing_job_config + sagemaker_session.sagemaker_config = processing_job_config processing_inputs = [ { @@ -331,7 +331,7 @@ def test_create_process_with_sagemaker_config_injection(sagemaker_config_session "environment": {"my_env_variable": 20}, "experiment_config": {"ExperimentName": "AnExperiment"}, } - sagemaker_config_session.process(**process_request_args) + sagemaker_session.process(**process_request_args) expected_volume_kms_key_id = SAGEMAKER_CONFIG_PROCESSING_JOB["SageMaker"]["ProcessingJob"][ "ProcessingResources" ]["ClusterConfig"]["VolumeKmsKeyId"] @@ -377,9 +377,7 @@ def test_create_process_with_sagemaker_config_injection(sagemaker_config_session "VolumeKmsKeyId" ] = expected_volume_kms_key_id - sagemaker_config_session.sagemaker_client.create_processing_job.assert_called_with( - **expected_request - ) + sagemaker_session.sagemaker_client.create_processing_job.assert_called_with(**expected_request) def mock_exists(filepath_to_mock, exists_result): @@ -1539,8 +1537,8 @@ def test_stop_tuning_job_client_error(sagemaker_session): ) -def 
test_train_with_sagemaker_config_injection(sagemaker_config_session): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_TRAINING_JOB +def test_train_with_sagemaker_config_injection(sagemaker_session): + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRAINING_JOB in_config = [ { @@ -1573,7 +1571,7 @@ def test_train_with_sagemaker_config_injection(sagemaker_config_session): }, } - sagemaker_config_session.train( + sagemaker_session.train( image_uri=IMAGE, input_mode="File", input_config=in_config, @@ -1592,7 +1590,7 @@ def test_train_with_sagemaker_config_injection(sagemaker_config_session): training_image_config=TRAINING_IMAGE_CONFIG, ) - _, _, actual_train_args = sagemaker_config_session.sagemaker_client.method_calls[0] + _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0] expected_volume_kms_key_id = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ "ResourceConfig" @@ -1714,9 +1712,9 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): ) -def test_create_transform_job_with_sagemaker_config_injection(sagemaker_config_session): +def test_create_transform_job_with_sagemaker_config_injection(sagemaker_session): # Config to test injection for - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_TRANSFORM_JOB + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRANSFORM_JOB model_name = "my-model" in_config = { @@ -1761,7 +1759,7 @@ def test_create_transform_job_with_sagemaker_config_injection(sagemaker_config_s assert "VolumeKmsKeyId" not in resource_config # injection should happen during this method - sagemaker_config_session.transform( + sagemaker_session.transform( job_name=JOB_NAME, model_name=model_name, strategy=None, @@ -1778,7 +1776,7 @@ def test_create_transform_job_with_sagemaker_config_injection(sagemaker_config_s batch_data_capture_config=data_capture_config, ) - _, _, actual_args = sagemaker_config_session.sagemaker_client.method_calls[0] + _, _, 
actual_args = sagemaker_session.sagemaker_client.method_calls[0] assert actual_args == expected_args @@ -2114,14 +2112,14 @@ def test_logs_for_transform_job_full_lifecycle(time, cw, sagemaker_session_full_ } -def test_create_model_with_sagemaker_config_injection(sagemaker_config_session): +def test_create_model_with_sagemaker_config_injection(sagemaker_session): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_MODEL + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_MODEL - sagemaker_config_session.expand_role = Mock( + sagemaker_session.expand_role = Mock( name="expand_role", side_effect=lambda role_name: role_name ) - model = sagemaker_config_session.create_model( + model = sagemaker_session.create_model( MODEL_NAME, container_defs=PRIMARY_CONTAINER, ) @@ -2132,7 +2130,7 @@ def test_create_model_with_sagemaker_config_injection(sagemaker_config_session): expected_vpc_config = SAGEMAKER_CONFIG_MODEL["SageMaker"]["Model"]["VpcConfig"] expected_tags = SAGEMAKER_CONFIG_MODEL["SageMaker"]["Model"]["Tags"] assert model == MODEL_NAME - sagemaker_config_session.sagemaker_client.create_model.assert_called_with( + sagemaker_session.sagemaker_client.create_model.assert_called_with( ExecutionRoleArn=expected_role_arn, ModelName=MODEL_NAME, PrimaryContainer=PRIMARY_CONTAINER, @@ -2310,12 +2308,12 @@ def test_create_model_from_job_with_tags(sagemaker_session): ) -def test_create_edge_packaging_with_sagemaker_config_injection(sagemaker_config_session): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB +def test_create_edge_packaging_with_sagemaker_config_injection(sagemaker_session): + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB output_config = {"S3OutputLocation": S3_OUTPUT} - sagemaker_config_session.package_model_for_edge( + sagemaker_session.package_model_for_edge( output_config, ) expected_role_arn = SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB["SageMaker"]["EdgePackagingJob"][ @@ -2325,7 
+2323,7 @@ def test_create_edge_packaging_with_sagemaker_config_injection(sagemaker_config_ "OutputConfig" ]["KmsKeyId"] expected_tags = SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB["SageMaker"]["EdgePackagingJob"]["Tags"] - sagemaker_config_session.sagemaker_client.create_edge_packaging_job.assert_called_with( + sagemaker_session.sagemaker_client.create_edge_packaging_job.assert_called_with( RoleArn=expected_role_arn, # provided from config OutputConfig={ "S3OutputLocation": S3_OUTPUT, # provided as param @@ -2339,12 +2337,12 @@ def test_create_edge_packaging_with_sagemaker_config_injection(sagemaker_config_ ) -def test_create_monitoring_schedule_with_sagemaker_config_injection(sagemaker_config_session): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_MONITORING_SCHEDULE +def test_create_monitoring_schedule_with_sagemaker_config_injection(sagemaker_session): + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_MONITORING_SCHEDULE monitoring_output_config = {"MonitoringOutputs": [{"S3Output": {"S3Uri": S3_OUTPUT}}]} - sagemaker_config_session.create_monitoring_schedule( + sagemaker_session.create_monitoring_schedule( JOB_NAME, schedule_expression=None, statistics_s3_uri=None, @@ -2383,7 +2381,7 @@ def test_create_monitoring_schedule_with_sagemaker_config_injection(sagemaker_co ]["MonitoringSchedule"]["MonitoringScheduleConfig"]["MonitoringJobDefinition"]["NetworkConfig"][ "EnableInterContainerTrafficEncryption" ] - sagemaker_config_session.sagemaker_client.create_monitoring_schedule.assert_called_with( + sagemaker_session.sagemaker_client.create_monitoring_schedule.assert_called_with( MonitoringScheduleName=JOB_NAME, MonitoringScheduleConfig={ "MonitoringJobDefinition": { @@ -2419,10 +2417,10 @@ def test_create_monitoring_schedule_with_sagemaker_config_injection(sagemaker_co ) -def test_compile_with_sagemaker_config_injection(sagemaker_config_session): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_COMPILATION_JOB +def 
test_compile_with_sagemaker_config_injection(sagemaker_session): + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_COMPILATION_JOB - sagemaker_config_session.compile_model( + sagemaker_session.compile_model( input_model_config={}, output_model_config={"S3OutputLocation": "s3://test"}, job_name="TestJob", @@ -2435,7 +2433,7 @@ def test_compile_with_sagemaker_config_injection(sagemaker_config_session): "VpcConfig" ] expected_tags = SAGEMAKER_CONFIG_COMPILATION_JOB["SageMaker"]["CompilationJob"]["Tags"] - sagemaker_config_session.sagemaker_client.create_compilation_job.assert_called_with( + sagemaker_session.sagemaker_client.create_compilation_job.assert_called_with( InputConfig={}, OutputConfig={"S3OutputLocation": "s3://test", "KmsKeyId": expected_kms_key_id}, RoleArn=expected_role_arn, @@ -2508,13 +2506,13 @@ def test_endpoint_from_production_variants(sagemaker_session): ) -def test_create_endpoint_config_with_sagemaker_config_injection(sagemaker_config_session): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG +def test_create_endpoint_config_with_sagemaker_config_injection(sagemaker_session): + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG data_capture_config_dict = {"DestinationS3Uri": "s3://test"} # This method does not support ASYNC_INFERENCE_CONFIG or multiple PRODUCTION_VARIANTS - sagemaker_config_session.create_endpoint_config( + sagemaker_session.create_endpoint_config( "endpoint-test", "simple-model", 1, @@ -2529,7 +2527,7 @@ def test_create_endpoint_config_with_sagemaker_config_injection(sagemaker_config ] expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"] - sagemaker_config_session.sagemaker_client.create_endpoint_config.assert_called_with( + sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( EndpointConfigName="endpoint-test", ProductionVariants=[ { @@ -2550,9 +2548,9 @@ def 
test_create_endpoint_config_with_sagemaker_config_injection(sagemaker_config def test_create_endpoint_config_from_existing_with_sagemaker_config_injection( - sagemaker_config_session, + sagemaker_session, ): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG pvs = [ sagemaker.production_variant("A", "ml.p2.xlarge"), @@ -2564,14 +2562,14 @@ def test_create_endpoint_config_from_existing_with_sagemaker_config_injection( existing_endpoint_arn = "arn:aws:sagemaker:us-west-2:123412341234:endpoint-config/foo" existing_endpoint_name = "foo" new_endpoint_name = "new-foo" - sagemaker_config_session.sagemaker_client.describe_endpoint_config.return_value = { + sagemaker_session.sagemaker_client.describe_endpoint_config.return_value = { "ProductionVariants": [sagemaker.production_variant("A", "ml.m4.xlarge")], "EndpointConfigArn": existing_endpoint_arn, "AsyncInferenceConfig": {}, } - sagemaker_config_session.sagemaker_client.list_tags.return_value = {"Tags": []} + sagemaker_session.sagemaker_client.list_tags.return_value = {"Tags": []} - sagemaker_config_session.create_endpoint_config_from_existing( + sagemaker_session.create_endpoint_config_from_existing( existing_endpoint_name, new_endpoint_name, new_production_variants=pvs ) @@ -2586,7 +2584,7 @@ def test_create_endpoint_config_from_existing_with_sagemaker_config_injection( ] expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"] - sagemaker_config_session.sagemaker_client.create_endpoint_config.assert_called_with( + sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( EndpointConfigName=new_endpoint_name, ProductionVariants=[ { @@ -2609,11 +2607,11 @@ def test_create_endpoint_config_from_existing_with_sagemaker_config_injection( def test_endpoint_from_production_variants_with_sagemaker_config_injection( - sagemaker_config_session, + sagemaker_session, ): - 
sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG - sagemaker_config_session.sagemaker_client.describe_endpoint = Mock( + sagemaker_session.sagemaker_client.describe_endpoint = Mock( return_value={"EndpointStatus": "InService"} ) pvs = [ @@ -2623,7 +2621,7 @@ def test_endpoint_from_production_variants_with_sagemaker_config_injection( ] # Add DestinationS3Uri to only one production variant pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"} - sagemaker_config_session.endpoint_from_production_variants( + sagemaker_session.endpoint_from_production_variants( "some-endpoint", pvs, data_capture_config_dict={}, @@ -2655,7 +2653,7 @@ def test_endpoint_from_production_variants_with_sagemaker_config_injection( "DestinationS3Uri": "s3://test", "KmsKeyId": expected_production_variant_0_kms_key_id, } - sagemaker_config_session.sagemaker_client.create_endpoint_config.assert_called_with( + sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( EndpointConfigName="some-endpoint", ProductionVariants=expected_pvs, Tags=expected_tags, # from config @@ -2663,7 +2661,7 @@ def test_endpoint_from_production_variants_with_sagemaker_config_injection( AsyncInferenceConfig=expected_async_inference_config_dict, DataCaptureConfig={"KmsKeyId": expected_data_capture_kms_key_id}, ) - sagemaker_config_session.sagemaker_client.create_endpoint.assert_called_with( + sagemaker_session.sagemaker_client.create_endpoint.assert_called_with( EndpointConfigName="some-endpoint", EndpointName="some-endpoint", Tags=expected_tags, # from config @@ -3163,8 +3161,8 @@ def test_auto_ml_pack_to_request(sagemaker_session): ) -def test_create_auto_ml_with_sagemaker_config_injection(sagemaker_config_session): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_AUTO_ML +def test_create_auto_ml_with_sagemaker_config_injection(sagemaker_session): + 
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_AUTO_ML input_config = [ { @@ -3184,9 +3182,7 @@ def test_create_auto_ml_with_sagemaker_config_injection(sagemaker_config_session } job_name = JOB_NAME - sagemaker_config_session.auto_ml( - input_config, output_config, auto_ml_job_config, job_name=job_name - ) + sagemaker_session.auto_ml(input_config, output_config, auto_ml_job_config, job_name=job_name) expected_call_args = copy.deepcopy(DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS) expected_volume_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["AutoMLJobConfig"][ "SecurityConfig" @@ -3211,7 +3207,7 @@ def test_create_auto_ml_with_sagemaker_config_injection(sagemaker_config_session expected_call_args["AutoMLJobConfig"]["SecurityConfig"][ "VolumeKmsKeyId" ] = expected_volume_kms_key_id - sagemaker_config_session.sagemaker_client.create_auto_ml_job.assert_called_with( + sagemaker_session.sagemaker_client.create_auto_ml_job.assert_called_with( AutoMLJobName=expected_call_args["AutoMLJobName"], InputDataConfig=expected_call_args["InputDataConfig"], OutputDataConfig=expected_call_args["OutputDataConfig"], @@ -3419,8 +3415,8 @@ def test_create_model_package_from_containers_without_model_package_group_name( ) -def test_create_model_package_with_sagemaker_config_injection(sagemaker_config_session): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_MODEL_PACKAGE +def test_create_model_package_with_sagemaker_config_injection(sagemaker_session): + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_MODEL_PACKAGE model_package_name = "sagemaker-model-package" containers = ["dummy-container"] @@ -3463,7 +3459,7 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_config_s domain = "COMPUTER_VISION" task = "IMAGE_CLASSIFICATION" sample_payload_url = "s3://test-bucket/model" - sagemaker_config_session.create_model_package_from_containers( + sagemaker_session.create_model_package_from_containers( containers=containers, 
content_types=content_types, response_types=response_types, @@ -3522,9 +3518,7 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_config_s "TransformOutput" ]["KmsKeyId"] = expected_kms_key_id - sagemaker_config_session.sagemaker_client.create_model_package.assert_called_with( - **expected_args - ) + sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) def test_create_model_package_from_containers_all_args(sagemaker_session): @@ -3736,12 +3730,12 @@ def feature_group_dummy_definitions(): def test_feature_group_create_with_sagemaker_config_injection( - sagemaker_config_session, feature_group_dummy_definitions + sagemaker_session, feature_group_dummy_definitions ): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_FEATURE_GROUP + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_FEATURE_GROUP - sagemaker_config_session.create_feature_group( + sagemaker_session.create_feature_group( feature_group_name="MyFeatureGroup", record_identifier_name="feature1", event_time_feature_name="feature2", @@ -3771,9 +3765,7 @@ def test_feature_group_create_with_sagemaker_config_injection( }, "Tags": expected_tags, } - sagemaker_config_session.sagemaker_client.create_feature_group.assert_called_with( - **expected_request - ) + sagemaker_session.sagemaker_client.create_feature_group.assert_called_with(**expected_request) def test_feature_group_create(sagemaker_session, feature_group_dummy_definitions): @@ -4535,11 +4527,13 @@ def sort(tags): {"Key": "tagkey2", "Value": "tagvalue2"}, {"Key": "tagkey3", "Value": "tagvalue3"}, ] + sagemaker_session.sagemaker_config = {"SchemaVersion": "1.0"} + sagemaker_session.sagemaker_config.update( + {"SageMaker": {"ProcessingJob": {"Tags": config_tag_value}}} + ) + config_key_path = "SageMaker.ProcessingJob.Tags" - sagemaker_session.sagemaker_config = Mock() - sagemaker_session.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": config_tag_value}}} - - base_case 
= sagemaker_session._append_sagemaker_config_tags(tags_base, "DUMMY.CONFIG.PATH") + base_case = sagemaker_session._append_sagemaker_config_tags(tags_base, config_key_path) assert sort(base_case) == sort( [ {"Key": "tagkey1", "Value": "tagvalue1"}, @@ -4551,7 +4545,7 @@ def sort(tags): ) duplicate_case = sagemaker_session._append_sagemaker_config_tags( - tags_duplicate, "DUMMY.CONFIG.PATH" + tags_duplicate, config_key_path ) assert sort(duplicate_case) == sort( [ @@ -4561,7 +4555,7 @@ def sort(tags): ] ) - none_case = sagemaker_session._append_sagemaker_config_tags(tags_none, "DUMMY.CONFIG.PATH") + none_case = sagemaker_session._append_sagemaker_config_tags(tags_none, config_key_path) assert sort(none_case) == sort( [ {"Key": "tagkey1", "Value": "tagvalue1"}, @@ -4570,7 +4564,7 @@ def sort(tags): ] ) - empty_case = sagemaker_session._append_sagemaker_config_tags(tags_empty, "DUMMY.CONFIG.PATH") + empty_case = sagemaker_session._append_sagemaker_config_tags(tags_empty, config_key_path) assert sort(empty_case) == sort( [ {"Key": "tagkey1", "Value": "tagvalue1"}, @@ -4579,12 +4573,10 @@ def sort(tags): ] ) - sagemaker_session.sagemaker_config.config = { - "DUMMY": {"CONFIG": {"OTHER_PATH": config_tag_value}} - } - config_tags_none = sagemaker_session._append_sagemaker_config_tags( - tags_base, "DUMMY.CONFIG.PATH" + sagemaker_session.sagemaker_config.update( + {"SageMaker": {"TrainingJob": {"Tags": config_tag_value}}} ) + config_tags_none = sagemaker_session._append_sagemaker_config_tags(tags_base, config_key_path) assert sort(config_tags_none) == sort( [ {"Key": "tagkey4", "Value": "000"}, @@ -4592,7 +4584,9 @@ def sort(tags): ] ) - sagemaker_session.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": config_tag_value}}} + sagemaker_session.sagemaker_config.update( + {"SageMaker": {"ProcessingJob": {"Tags": config_tag_value}}} + ) config_tags_empty = sagemaker_session._append_sagemaker_config_tags( tags_base, "DUMMY.CONFIG.PATH" ) diff --git 
a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index d7d9f45c48..6f0ce35319 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -79,7 +79,7 @@ def sagemaker_session(): session.expand_role = Mock(name="expand_role", return_value=ROLE) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/test_sparkml_serving.py b/tests/unit/test_sparkml_serving.py index f81ba3dde0..1e178dfc57 100644 --- a/tests/unit/test_sparkml_serving.py +++ b/tests/unit/test_sparkml_serving.py @@ -46,7 +46,7 @@ def sagemaker_session(): sms.boto_region_name = REGION sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/test_timeout.py b/tests/unit/test_timeout.py index 09d2669876..76ebeee53d 100644 --- a/tests/unit/test_timeout.py +++ b/tests/unit/test_timeout.py @@ -61,7 +61,7 @@ def session(): ) sms.default_bucket = Mock(name=DEFAULT_BUCKET_NAME, return_value=BUCKET_NAME) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index 240c894fa2..3611ebd4a2 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -67,7 +67,7 @@ def sagemaker_session(): boto_mock = Mock(name="boto_session") session = Mock(name="sagemaker_session", boto_session=boto_mock, local_mode=False) # For tests which doesn't verify config file injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session @@ -103,15 +103,15 @@ def transformer(sagemaker_session): 
@patch("sagemaker.transformer._TransformJob.start_new") -def test_transform_with_sagemaker_config_injection(start_new_job, sagemaker_config_session): - sagemaker_config_session.sagemaker_config.config = SAGEMAKER_CONFIG_TRANSFORM_JOB +def test_transform_with_sagemaker_config_injection(start_new_job, sagemaker_session): + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRANSFORM_JOB transformer = Transformer( MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH, - sagemaker_session=sagemaker_config_session, + sagemaker_session=sagemaker_session, ) # volume kms key and output kms key are inserted from the config assert ( diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index e6d50eeee9..74ad25aaea 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -67,7 +67,7 @@ def sagemaker_session(): sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) # For tests which doesn't verify config file injection, operate with empty config - sms.sagemaker_config.config = {} + sms.sagemaker_config = {} return sms diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 888adc3799..0a9be66d7a 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -107,7 +107,8 @@ def test_get_nested_value(): assert sagemaker.utils.get_nested_value(dictionary, []) is None -def test_update_list_of_dicts_with_values_from_config(): +@patch("jsonschema.validate") +def test_update_list_of_dicts_with_values_from_config(mock_json_schema_validation): input_list = [{"a": 1, "b": 2}] input_config_list = [ { @@ -117,8 +118,8 @@ def test_update_list_of_dicts_with_values_from_config(): ] # Using short form for sagemaker_session ss = MagicMock() - ss.sagemaker_config = Mock() - ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + + ss.sagemaker_config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} config_path = "DUMMY.CONFIG.PATH" # happy case - both inputs and 
config have same number of elements update_list_of_dicts_with_values_from_config(input_list, config_path, sagemaker_session=ss) @@ -134,7 +135,7 @@ def test_update_list_of_dicts_with_values_from_config(): "c": 3, } ] - ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + ss.sagemaker_config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} update_list_of_dicts_with_values_from_config(input_list, config_path, sagemaker_session=ss) assert input_list == [ {"a": 1, "b": 2, "c": 3}, @@ -149,7 +150,7 @@ def test_update_list_of_dicts_with_values_from_config(): }, {"a": 5, "b": 6}, ] - ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + ss.sagemaker_config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} update_list_of_dicts_with_values_from_config(input_list, config_path, sagemaker_session=ss) assert input_list == [{"a": 1, "b": 2, "c": 3}] # Testing required parameters. If required parameters are not present, don't do the merge @@ -160,7 +161,7 @@ def test_update_list_of_dicts_with_values_from_config(): "c": 3, }, ] - ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + ss.sagemaker_config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} update_list_of_dicts_with_values_from_config( input_list, config_path, required_key_paths=["d"], sagemaker_session=ss ) @@ -178,7 +179,7 @@ def test_update_list_of_dicts_with_values_from_config(): "b": 8, }, ] - ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + ss.sagemaker_config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} update_list_of_dicts_with_values_from_config( input_list, config_path, required_key_paths=["c"], sagemaker_session=ss ) @@ -198,7 +199,7 @@ def test_update_list_of_dicts_with_values_from_config(): "d": 8, # c is present in the original list and d is present in this list. 
}, ] - ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + ss.sagemaker_config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} update_list_of_dicts_with_values_from_config( input_list, config_path, union_key_paths=[["c", "d"]], sagemaker_session=ss ) @@ -218,7 +219,7 @@ def test_update_list_of_dicts_with_values_from_config(): "d": 8, }, ] - ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + ss.sagemaker_config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} update_list_of_dicts_with_values_from_config( input_list, config_path, union_key_paths=[["c", "e"], ["d", "e"]], sagemaker_session=ss ) @@ -238,7 +239,7 @@ def test_update_list_of_dicts_with_values_from_config(): "d": 8, }, ] - ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + ss.sagemaker_config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} update_list_of_dicts_with_values_from_config( input_list, config_path, union_key_paths=[["d", "e"], ["c", "e"]], sagemaker_session=ss ) @@ -259,7 +260,7 @@ def test_update_list_of_dicts_with_values_from_config(): "d": 8, }, ] - ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + ss.sagemaker_config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} update_list_of_dicts_with_values_from_config( input_list, config_path, @@ -281,7 +282,7 @@ def test_update_list_of_dicts_with_values_from_config(): "g": 8, }, ] - ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} + ss.sagemaker_config = {"DUMMY": {"CONFIG": {"PATH": input_config_list}}} update_list_of_dicts_with_values_from_config( input_list, config_path, @@ -1177,20 +1178,22 @@ def test_retry_with_backoff(patched_sleep): def test_resolve_value_from_config(): # using a shorter name for inside the test sagemaker_session = MagicMock() - sagemaker_session.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": "CONFIG_VALUE"}}} + sagemaker_session.sagemaker_config 
= {"SchemaVersion": "1.0"} + config_key_path = "SageMaker.EndpointConfig.KmsKeyId" + sagemaker_session.sagemaker_config.update( + {"SageMaker": {"EndpointConfig": {"KmsKeyId": "CONFIG_VALUE"}}} + ) # direct_input should be respected assert ( - resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", "DEFAULT_VALUE", sagemaker_session) + resolve_value_from_config("INPUT", config_key_path, "DEFAULT_VALUE", sagemaker_session) == "INPUT" ) - assert ( - resolve_value_from_config("INPUT", "DUMMY.CONFIG.PATH", None, sagemaker_session) == "INPUT" - ) + assert resolve_value_from_config("INPUT", config_key_path, None, sagemaker_session) == "INPUT" assert ( - resolve_value_from_config("INPUT", "DUMMY.CONFIG.INVALID_PATH", None, sagemaker_session) + resolve_value_from_config("INPUT", "SageMaker.EndpointConfig.Tags", None, sagemaker_session) == "INPUT" ) @@ -1201,41 +1204,33 @@ def test_resolve_value_from_config(): assert ( resolve_value_from_config( - None, "DUMMY.CONFIG.INVALID_PATH", "DEFAULT_VALUE", sagemaker_session + None, "SageMaker.EndpointConfig.Tags", "DEFAULT_VALUE", sagemaker_session ) == "DEFAULT_VALUE" ) assert ( - resolve_value_from_config(None, "DUMMY.CONFIG.PATH", "DEFAULT_VALUE", sagemaker_session) + resolve_value_from_config(None, config_key_path, "DEFAULT_VALUE", sagemaker_session) == "CONFIG_VALUE" ) assert resolve_value_from_config(None, None, None, sagemaker_session) is None # Different falsy direct_inputs - assert resolve_value_from_config("", "DUMMY.CONFIG.PATH", None, sagemaker_session) == "" + assert resolve_value_from_config("", config_key_path, None, sagemaker_session) == "" - assert resolve_value_from_config([], "DUMMY.CONFIG.PATH", None, sagemaker_session) == [] + assert resolve_value_from_config([], config_key_path, None, sagemaker_session) == [] - assert resolve_value_from_config(False, "DUMMY.CONFIG.PATH", None, sagemaker_session) is False + assert resolve_value_from_config(False, config_key_path, None, sagemaker_session) is False - assert 
resolve_value_from_config({}, "DUMMY.CONFIG.PATH", None, sagemaker_session) == {} + assert resolve_value_from_config({}, config_key_path, None, sagemaker_session) == {} # Different falsy config_values - sagemaker_session.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": ""}}} - assert resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None, sagemaker_session) == "" - - sagemaker_session.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": []}}} - assert resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None, sagemaker_session) == [] - - sagemaker_session.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": False}}} - assert resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None, sagemaker_session) is False - - sagemaker_session.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": {}}}} - assert resolve_value_from_config(None, "DUMMY.CONFIG.PATH", None, sagemaker_session) == {} + sagemaker_session.sagemaker_config.update({"SageMaker": {"EndpointConfig": {"KmsKeyId": ""}}}) + assert resolve_value_from_config(None, config_key_path, None, sagemaker_session) == "" +@patch("jsonschema.validate") @pytest.mark.parametrize( "existing_value, config_value, default_value", [ @@ -1245,7 +1240,9 @@ def test_resolve_value_from_config(): (0, 1, 2), ], ) -def test_resolve_class_attribute_from_config(existing_value, config_value, default_value): +def test_resolve_class_attribute_from_config( + mock_validate, existing_value, config_value, default_value +): # using a shorter name for inside the test ss = MagicMock() @@ -1265,8 +1262,8 @@ def __eq__(self, other): dummy_config_path = "DUMMY.CONFIG.PATH" # with an existing config value - ss.sagemaker_config = Mock() - ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": config_value}}} + + ss.sagemaker_config = {"DUMMY": {"CONFIG": {"PATH": config_value}}} # instance exists and has value; config has value test_instance = TestClass(test_attribute=existing_value, extra="EXTRA_VALUE") @@ -1312,7 +1309,7 @@ def 
__eq__(self, other): ) # without an existing config value - ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"SOMEOTHERPATH": config_value}}} + ss.sagemaker_config = {"DUMMY": {"CONFIG": {"SOMEOTHERPATH": config_value}}} # instance exists but doesnt have value; config doesnt have value test_instance = TestClass(extra="EXTRA_VALUE") assert resolve_class_attribute_from_config( @@ -1351,7 +1348,8 @@ def __eq__(self, other): ) == TestClass(test_attribute=default_value, extra=None) -def test_resolve_nested_dict_value_from_config(): +@patch("jsonschema.validate") +def test_resolve_nested_dict_value_from_config(mock_validate): # using a shorter name for inside the test ss = MagicMock() @@ -1373,8 +1371,8 @@ def test_resolve_nested_dict_value_from_config(): ) == {"local": {"region_name": "us-west-2", "port": "123"}} # happy case: return dict with config_value when it wasnt set in dict or was None - ss.sagemaker_config = Mock() - ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"PATH": "CONFIG_VALUE"}}} + + ss.sagemaker_config = {"DUMMY": {"CONFIG": {"PATH": "CONFIG_VALUE"}}} assert resolve_nested_dict_value_from_config( {"local": {"port": "123"}}, ["local", "region_name"], @@ -1437,7 +1435,7 @@ def test_resolve_nested_dict_value_from_config(): ) # without an existing config value - ss.sagemaker_config.config = {"DUMMY": {"CONFIG": {"ANOTHER_PATH": "CONFIG_VALUE"}}} + ss.sagemaker_config = {"DUMMY": {"CONFIG": {"ANOTHER_PATH": "CONFIG_VALUE"}}} # happy case: return dict with default_value when it wasnt set in dict and in config assert resolve_nested_dict_value_from_config( diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 76ad51896f..4476b3b5ff 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -81,7 +81,7 @@ def sagemaker_session(): session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) # For tests which doesn't verify config file 
injection, operate with empty config - session.sagemaker_config.config = {} + session.sagemaker_config = {} return session diff --git a/tests/unit/tuner_test_utils.py b/tests/unit/tuner_test_utils.py index 215cc80d26..b9fa4f2ff3 100644 --- a/tests/unit/tuner_test_utils.py +++ b/tests/unit/tuner_test_utils.py @@ -73,7 +73,7 @@ SAGEMAKER_SESSION = Mock() # For tests which doesn't verify config file injection, operate with empty config -SAGEMAKER_SESSION.sagemaker_config.config = {} +SAGEMAKER_SESSION.sagemaker_config = {} ESTIMATOR = Estimator( IMAGE_NAME, From d73bb9742f39f3f9eaa47c2eea195036ccc20034 Mon Sep 17 00:00:00 2001 From: Balaji Sankar Date: Mon, 27 Mar 2023 17:42:01 -0700 Subject: [PATCH 33/40] fix: fix broken unit tests due to refactoring --- src/sagemaker/local/local_session.py | 5 ++-- tests/conftest.py | 2 +- tests/unit/sagemaker/model/test_deploy.py | 6 +++-- .../sagemaker/model/test_framework_model.py | 4 ++++ tests/unit/sagemaker/model/test_model.py | 4 ++++ tests/unit/sagemaker/model/test_neo.py | 3 +++ .../monitor/test_model_monitoring.py | 6 +++++ tests/unit/sagemaker/spark/test_processing.py | 1 + .../unit/sagemaker/workflow/test_pipeline.py | 19 +++++++++++---- tests/unit/sagemaker/workflow/test_utils.py | 4 +++- tests/unit/test_algorithm.py | 24 ++++++++++++++++++- tests/unit/test_clarify.py | 1 + tests/unit/test_estimator.py | 14 +++++++++++ tests/unit/test_tuner.py | 2 ++ 14 files changed, 84 insertions(+), 11 deletions(-) diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index cb41548d04..ac34555327 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -682,8 +682,9 @@ def _initialize( self.s3_resource = boto_session.resource("s3", endpoint_url=self.s3_endpoint_url) self.s3_client = boto_session.client("s3", endpoint_url=self.s3_endpoint_url) self.sagemaker_config = ( - sagemaker_config if sagemaker_config else fetch_sagemaker_config( - 
s3_resource=self.s3_resource) + sagemaker_config + if sagemaker_config + else fetch_sagemaker_config(s3_resource=self.s3_resource) ) else: self.sagemaker_config = ( diff --git a/tests/conftest.py b/tests/conftest.py index 68761046bd..45d3cd3ff6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -166,7 +166,7 @@ def sagemaker_session( sagemaker_client=sagemaker_client, sagemaker_runtime_client=runtime_client, sagemaker_metrics_client=metrics_client, - sagemaker_config={"SchemaVersion": "1.0"}, + sagemaker_config={}, ) diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 5c045dcae4..cecf1ecc62 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -168,6 +168,7 @@ def test_deploy_accelerator_type( @patch("sagemaker.model.Model._create_sagemaker_model", Mock()) @patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) def test_deploy_endpoint_name(sagemaker_session): + sagemaker_session.sagemaker_config = {} model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session) endpoint_name = "blah" @@ -371,6 +372,7 @@ def test_deploy_async_inference(production_variant, name_from_base, sagemaker_se @patch("sagemaker.model.Model._create_sagemaker_model") @patch("sagemaker.production_variant") def test_deploy_serverless_inference(production_variant, create_sagemaker_model, sagemaker_session): + sagemaker_session.sagemaker_config = {} model = Model( MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session ) @@ -439,8 +441,8 @@ def test_deploy_wrong_serverless_config(sagemaker_session): @patch("sagemaker.session.Session") @patch("sagemaker.local.LocalSession") def test_deploy_creates_correct_session(local_session, session): - local_session.sagemaker_config = {} - session.sagemaker_config = {} + local_session.return_value.sagemaker_config = {} + session.return_value.sagemaker_config = {} # We 
expect a LocalSession when deploying to instance_type = 'local' model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE) model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1) diff --git a/tests/unit/sagemaker/model/test_framework_model.py b/tests/unit/sagemaker/model/test_framework_model.py index ca48e93187..613ebefd64 100644 --- a/tests/unit/sagemaker/model/test_framework_model.py +++ b/tests/unit/sagemaker/model/test_framework_model.py @@ -208,6 +208,7 @@ def test_git_support_repo_not_provided(sagemaker_session): ), ) def test_git_support_git_clone_fail(sagemaker_session): + sagemaker_session.sagemaker_config = {} entry_point = "source_dir/entry_point" git_config = {"repo": "https://github.com/aws/no-such-repo.git", "branch": BRANCH} with pytest.raises(subprocess.CalledProcessError) as error: @@ -257,6 +258,7 @@ def test_git_support_commit_not_exist(git_clone_repo, sagemaker_session): side_effect=ValueError("Entry point does not exist in the repo."), ) def test_git_support_entry_point_not_exist(sagemaker_session): + sagemaker_session.sagemaker_config = {} entry_point = "source_dir/entry_point" git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} with pytest.raises(ValueError) as error: @@ -272,6 +274,7 @@ def test_git_support_entry_point_not_exist(sagemaker_session): side_effect=ValueError("Source directory does not exist in the repo."), ) def test_git_support_source_dir_not_exist(sagemaker_session): + sagemaker_session.sagemaker_config = {} entry_point = "entry_point" source_dir = "source_dir_that_does_not_exist" git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} @@ -291,6 +294,7 @@ def test_git_support_source_dir_not_exist(sagemaker_session): side_effect=ValueError("Dependency no-such-dir does not exist in the repo."), ) def test_git_support_dependencies_not_exist(sagemaker_session): + sagemaker_session.sagemaker_config = {} entry_point = "entry_point" dependencies = ["foo", "no_such_dir"] git_config = 
{"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index ee412e1399..daa9d46763 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -326,6 +326,8 @@ def test_create_sagemaker_model_generates_model_name_each_time( @patch("sagemaker.session.Session") @patch("sagemaker.local.LocalSession") def test_create_sagemaker_model_creates_correct_session(local_session, session): + local_session.return_value.sagemaker_config = {} + session.return_value.sagemaker_config = {} model = Model(MODEL_IMAGE, MODEL_DATA) model._create_sagemaker_model("local") assert model.sagemaker_session == local_session.return_value @@ -433,6 +435,8 @@ def test_model_create_transformer_base_name(sagemaker_session): @patch("sagemaker.session.Session") @patch("sagemaker.local.LocalSession") def test_transformer_creates_correct_session(local_session, session): + local_session.return_value.sagemaker_config = {} + session.return_value.sagemaker_config = {} model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=None) transformer = model.transformer(instance_count=1, instance_type="local") assert model.sagemaker_session == local_session.return_value diff --git a/tests/unit/sagemaker/model/test_neo.py b/tests/unit/sagemaker/model/test_neo.py index e6912476d5..5aa1468fd6 100644 --- a/tests/unit/sagemaker/model/test_neo.py +++ b/tests/unit/sagemaker/model/test_neo.py @@ -170,6 +170,7 @@ def test_compile_model_for_cloud_tflite(sagemaker_session): @patch("sagemaker.session.Session") def test_compile_creates_session(session): session.return_value.boto_region_name = REGION + session.return_value.sagemaker_config = {} model = _create_model() model.compile( @@ -313,6 +314,7 @@ def test_compile_with_framework_version_16(sagemaker_session): @patch("sagemaker.session.Session") def test_compile_with_pytorch_neo_in_ml_inf(session): session.return_value.boto_region_name 
= REGION + session.return_value.sagemaker_config = {} model = _create_model() model.compile( @@ -336,6 +338,7 @@ def test_compile_with_pytorch_neo_in_ml_inf(session): @patch("sagemaker.session.Session") def test_compile_with_tensorflow_neo_in_ml_inf(session): session.return_value.boto_region_name = REGION + session.return_value.sagemaker_config = {} model = _create_model() model.compile( diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index 04c67082fd..77eb8ae506 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -877,8 +877,14 @@ def test_data_quality_batch_transform_monitor_create_schedule_with_sagemaker_con data_quality_monitor, sagemaker_session, ): + from sagemaker.utils import get_config_value sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_MONITORING_SCHEDULE + sagemaker_session._append_sagemaker_config_tags = Mock( + name="_append_sagemaker_config_tags", + side_effect=lambda tags, config_path_to_tags: tags + + get_config_value(config_path_to_tags, SAGEMAKER_CONFIG_MONITORING_SCHEDULE), + ) sagemaker_session.sagemaker_client.create_monitoring_schedule = Mock() data_quality_monitor.sagemaker_session = sagemaker_session diff --git a/tests/unit/sagemaker/spark/test_processing.py b/tests/unit/sagemaker/spark/test_processing.py index a079f477b4..16583d33ae 100644 --- a/tests/unit/sagemaker/spark/test_processing.py +++ b/tests/unit/sagemaker/spark/test_processing.py @@ -61,6 +61,7 @@ def sagemaker_session(): settings=SessionSettings(), ) session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + session_mock.sagemaker_config = {} return session_mock diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index c92db39791..7308cf0b14 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ 
b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -71,8 +71,10 @@ def test_pipeline_create_and_update_without_role_arn(sagemaker_session_mock): def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock): # For tests which doesn't verify config file injection, operate with empty config + pipeline_role_arn = "arn:aws:iam::111111111111:role/ConfigRole" sagemaker_session_mock.sagemaker_config = { - "SageMaker": {"Pipeline": {"RoleArn": "ConfigRoleArn"}} + "SchemaVersion": "1.0", + "SageMaker": {"Pipeline": {"RoleArn": pipeline_role_arn}}, } sagemaker_session_mock.sagemaker_client.describe_pipeline.return_value = { "PipelineArn": "pipeline-arn" @@ -85,15 +87,21 @@ def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock ) pipeline.create() sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_with( - PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn="ConfigRoleArn" + PipelineName="MyPipeline", + PipelineDefinition=pipeline.definition(), + RoleArn=pipeline_role_arn, ) pipeline.update() sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with( - PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn="ConfigRoleArn" + PipelineName="MyPipeline", + PipelineDefinition=pipeline.definition(), + RoleArn=pipeline_role_arn, ) pipeline.upsert() assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( - PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn="ConfigRoleArn" + PipelineName="MyPipeline", + PipelineDefinition=pipeline.definition(), + RoleArn=pipeline_role_arn, ) @@ -129,6 +137,7 @@ def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_ar @patch("sagemaker.s3.S3Uploader.upload_string_as_file_body") def test_large_pipeline_create(sagemaker_session_mock, role_arn): + sagemaker_session_mock.sagemaker_config = {} parameter = ParameterString("MyStr") pipeline = Pipeline( 
name="MyPipeline", @@ -151,6 +160,7 @@ def test_large_pipeline_create(sagemaker_session_mock, role_arn): def test_pipeline_update(sagemaker_session_mock, role_arn): + sagemaker_session_mock.sagemaker_config = {} pipeline = Pipeline( name="MyPipeline", parameters=[], @@ -201,6 +211,7 @@ def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_ar @patch("sagemaker.s3.S3Uploader.upload_string_as_file_body") def test_large_pipeline_update(sagemaker_session_mock, role_arn): + sagemaker_session_mock.sagemaker_config = {} parameter = ParameterString("MyStr") pipeline = Pipeline( name="MyPipeline", diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py index d1b81f3148..48b1d762c3 100644 --- a/tests/unit/sagemaker/workflow/test_utils.py +++ b/tests/unit/sagemaker/workflow/test_utils.py @@ -226,9 +226,11 @@ def test_inject_repack_script_s3(estimator, tmp, fake_s3): model_data = Properties(step_name="MyStep", shape_name="DescribeModelOutput") entry_point = "inference.py" source_dir_path = "s3://fake/location" + session_mock = fake_s3.sagemaker_session + session_mock.sagemaker_config = {} step = _RepackModelStep( name="MyRepackModelStep", - sagemaker_session=fake_s3.sagemaker_session, + sagemaker_session=session_mock, role=estimator.role, image_uri="foo", model_data=model_data, diff --git a/tests/unit/test_algorithm.py b/tests/unit/test_algorithm.py index d4502a0890..0e15981b24 100644 --- a/tests/unit/test_algorithm.py +++ b/tests/unit/test_algorithm.py @@ -158,6 +158,7 @@ def test_algorithm_supported_input_mode_with_valid_input_types(session): # verify that the Estimator verifies the # input mode that an Algorithm supports. 
+ session.sagemaker_config = {} file_mode_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) file_mode_algo["TrainingSpecification"]["TrainingChannels"] = [ @@ -259,6 +260,7 @@ def test_algorithm_supported_input_mode_with_valid_input_types(session): def test_algorithm_supported_input_mode_with_bad_input_types(session): # verify that the Estimator verifies raises exceptions when # attempting to train with an incorrect input type + session.sagemaker_config = {} file_mode_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) file_mode_algo["TrainingSpecification"]["TrainingChannels"] = [ @@ -329,6 +331,7 @@ def test_algorithm_supported_input_mode_with_bad_input_types(session): @patch("sagemaker.estimator.EstimatorBase.fit", Mock()) @patch("sagemaker.Session") def test_algorithm_trainining_channels_with_expected_channels(session): + session.sagemaker_config = {} training_channels = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) training_channels["TrainingSpecification"]["TrainingChannels"] = [ @@ -370,6 +373,7 @@ def test_algorithm_trainining_channels_with_expected_channels(session): @patch("sagemaker.estimator.EstimatorBase.fit", Mock()) @patch("sagemaker.Session") def test_algorithm_trainining_channels_with_invalid_channels(session): + session.sagemaker_config = {} training_channels = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) training_channels["TrainingSpecification"]["TrainingChannels"] = [ @@ -412,6 +416,7 @@ def test_algorithm_trainining_channels_with_invalid_channels(session): @patch("sagemaker.Session") def test_algorithm_train_instance_types_valid_instance_types(session): + session.sagemaker_config = {} describe_algo_response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) instance_types = ["ml.m4.xlarge", "ml.m5.2xlarge"] @@ -440,6 +445,7 @@ def test_algorithm_train_instance_types_valid_instance_types(session): @patch("sagemaker.Session") def test_algorithm_train_instance_types_invalid_instance_types(session): + session.sagemaker_config = {} describe_algo_response = 
copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) instance_types = ["ml.m4.xlarge", "ml.m5.2xlarge"] @@ -462,6 +468,7 @@ def test_algorithm_train_instance_types_invalid_instance_types(session): @patch("sagemaker.Session") def test_algorithm_distributed_training_validation(session): + session.sagemaker_config = {} distributed_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) distributed_algo["TrainingSpecification"]["SupportsDistributedTraining"] = True @@ -502,6 +509,7 @@ def test_algorithm_distributed_training_validation(session): @patch("sagemaker.Session") def test_algorithm_hyperparameter_integer_range_valid_range(session): + session.sagemaker_config = {} hyperparameters = [ { "Description": "Grow a tree with max_leaf_nodes in best-first fashion.", @@ -535,6 +543,7 @@ def test_algorithm_hyperparameter_integer_range_valid_range(session): @patch("sagemaker.Session") def test_algorithm_hyperparameter_integer_range_invalid_range(session): + session.sagemaker_config = {} hyperparameters = [ { "Description": "Grow a tree with max_leaf_nodes in best-first fashion.", @@ -571,6 +580,7 @@ def test_algorithm_hyperparameter_integer_range_invalid_range(session): @patch("sagemaker.Session") def test_algorithm_hyperparameter_continuous_range_valid_range(session): + session.sagemaker_config = {} hyperparameters = [ { "Description": "A continuous hyperparameter", @@ -606,6 +616,7 @@ def test_algorithm_hyperparameter_continuous_range_valid_range(session): @patch("sagemaker.Session") def test_algorithm_hyperparameter_continuous_range_invalid_range(session): + session.sagemaker_config = {} hyperparameters = [ { "Description": "A continuous hyperparameter", @@ -642,6 +653,7 @@ def test_algorithm_hyperparameter_continuous_range_invalid_range(session): @patch("sagemaker.Session") def test_algorithm_hyperparameter_categorical_range(session): + session.sagemaker_config = {} hyperparameters = [ { "Description": "A continuous hyperparameter", @@ -679,6 +691,7 @@ def 
test_algorithm_hyperparameter_categorical_range(session): @patch("sagemaker.Session") def test_algorithm_required_hyperparameters_not_provided(session): + session.sagemaker_config = {} hyperparameters = [ { "Description": "A continuous hyperparameter", @@ -723,6 +736,7 @@ def test_algorithm_required_hyperparameters_not_provided(session): @patch("sagemaker.Session") @patch("sagemaker.estimator.EstimatorBase.fit", Mock()) def test_algorithm_required_hyperparameters_are_provided(session): + session.sagemaker_config = {} hyperparameters = [ { "Description": "A categorical hyperparameter", @@ -767,6 +781,7 @@ def test_algorithm_required_hyperparameters_are_provided(session): @patch("sagemaker.Session") def test_algorithm_required_free_text_hyperparameter_not_provided(session): + session.sagemaker_config = {} hyperparameters = [ { "Name": "free_text_hp1", @@ -810,6 +825,7 @@ def test_algorithm_required_free_text_hyperparameter_not_provided(session): @patch("sagemaker.Session") @patch("sagemaker.algorithm.AlgorithmEstimator.create_model") def test_algorithm_create_transformer(create_model, session): + session.sagemaker_config = {} session.sagemaker_client.describe_algorithm = Mock(return_value=DESCRIBE_ALGORITHM_RESPONSE) estimator = AlgorithmEstimator( @@ -834,6 +850,7 @@ def test_algorithm_create_transformer(create_model, session): @patch("sagemaker.Session") def test_algorithm_create_transformer_without_completed_training_job(session): + session.sagemaker_config = {} session.sagemaker_client.describe_algorithm = Mock(return_value=DESCRIBE_ALGORITHM_RESPONSE) estimator = AlgorithmEstimator( @@ -852,6 +869,7 @@ def test_algorithm_create_transformer_without_completed_training_job(session): @patch("sagemaker.algorithm.AlgorithmEstimator.create_model") @patch("sagemaker.Session") def test_algorithm_create_transformer_with_product_id(create_model, session): + session.sagemaker_config = {} response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) response["ProductId"] = 
"some-product-id" session.sagemaker_client.describe_algorithm = Mock(return_value=response) @@ -875,6 +893,7 @@ def test_algorithm_create_transformer_with_product_id(create_model, session): @patch("sagemaker.Session") def test_algorithm_enable_network_isolation_no_product_id(session): + session.sagemaker_config = {} session.sagemaker_client.describe_algorithm = Mock(return_value=DESCRIBE_ALGORITHM_RESPONSE) estimator = AlgorithmEstimator( @@ -891,6 +910,7 @@ def test_algorithm_enable_network_isolation_no_product_id(session): @patch("sagemaker.Session") def test_algorithm_enable_network_isolation_with_product_id(session): + session.sagemaker_config = {} response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) response["ProductId"] = "some-product-id" session.sagemaker_client.describe_algorithm = Mock(return_value=response) @@ -909,7 +929,7 @@ def test_algorithm_enable_network_isolation_with_product_id(session): @patch("sagemaker.Session") def test_algorithm_encrypt_inter_container_traffic(session): - + session.sagemaker_config = {} response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) response["encrypt_inter_container_traffic"] = True session.sagemaker_client.describe_algorithm = Mock(return_value=response) @@ -929,6 +949,7 @@ def test_algorithm_encrypt_inter_container_traffic(session): @patch("sagemaker.Session") def test_algorithm_no_required_hyperparameters(session): + session.sagemaker_config = {} some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) del some_algo["TrainingSpecification"]["SupportedHyperParameters"] @@ -1021,6 +1042,7 @@ def test_algorithm_attach_from_hyperparameter_tuning(): @patch("sagemaker.Session") def test_algorithm_supported_with_spot_instances(session): + session.sagemaker_config = {} session.sagemaker_client.describe_algorithm = Mock(return_value=DESCRIBE_ALGORITHM_RESPONSE) assert AlgorithmEstimator( diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 429fd91b64..f0e8ffad57 100644 --- a/tests/unit/test_clarify.py 
+++ b/tests/unit/test_clarify.py @@ -713,6 +713,7 @@ def sagemaker_session(): ) session_mock.download_data = Mock(name="download_data") session_mock.expand_role.return_value = "arn:aws:iam::012345678901:role/SageMakerRole" + session_mock.sagemaker_config = {} return session_mock diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index c39a27495b..4e5c26d64e 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -913,6 +913,7 @@ def test_framework_with_no_default_profiler_in_unsupported_region(region): s3_resource=None, settings=SessionSettings(), ) + sms.sagemaker_config = {} f = DummyFramework( entry_point=SCRIPT_PATH, role=ROLE, @@ -943,6 +944,7 @@ def test_framework_with_debugger_config_set_up_in_unsupported_region(region): s3_resource=None, settings=SessionSettings(), ) + sms.sagemaker_config = {} f = DummyFramework( entry_point=SCRIPT_PATH, role=ROLE, @@ -970,6 +972,7 @@ def test_framework_enable_profiling_in_unsupported_region(region): s3_resource=None, settings=SessionSettings(), ) + sms.sagemaker_config = {} f = DummyFramework( entry_point=SCRIPT_PATH, role=ROLE, @@ -997,6 +1000,7 @@ def test_framework_update_profiling_in_unsupported_region(region): s3_resource=None, settings=SessionSettings(), ) + sms.sagemaker_config = {} f = DummyFramework( entry_point=SCRIPT_PATH, role=ROLE, @@ -1024,6 +1028,7 @@ def test_framework_disable_profiling_in_unsupported_region(region): s3_resource=None, settings=SessionSettings(), ) + sms.sagemaker_config = {} f = DummyFramework( entry_point=SCRIPT_PATH, role=ROLE, @@ -4129,6 +4134,7 @@ def test_script_mode_estimator_same_calls_as_framework( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" ) sagemaker_session.boto_region_name = REGION + sagemaker_session.sagemaker_config = {} script_uri = "s3://codebucket/someprefix/sourcedir.tar.gz" @@ -4197,6 +4203,7 @@ def test_script_mode_estimator_tags_jumpstart_estimators_and_models( s3_prefix="s3://%s/%s" % ("bucket", 
"key"), script_name="script_name" ) sagemaker_session.boto_region_name = REGION + sagemaker_session.sagemaker_config = {} instance_type = "ml.p2.xlarge" instance_count = 1 @@ -4269,6 +4276,7 @@ def test_script_mode_estimator_tags_jumpstart_models( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" ) sagemaker_session.boto_region_name = REGION + sagemaker_session.sagemaker_config = {} instance_type = "ml.p2.xlarge" instance_count = 1 @@ -4328,6 +4336,7 @@ def test_script_mode_estimator_tags_jumpstart_models_with_no_estimator_js_tags( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" ) sagemaker_session.boto_region_name = REGION + sagemaker_session.sagemaker_config = {} instance_type = "ml.p2.xlarge" instance_count = 1 @@ -4385,6 +4394,7 @@ def test_all_framework_estimators_add_jumpstart_tags( patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session ): sagemaker_session.boto_region_name = REGION + sagemaker_session.sagemaker_config = {} sagemaker_session.sagemaker_client.describe_training_job.return_value = { "ModelArtifacts": {"S3ModelArtifacts": "some-uri"} } @@ -4464,6 +4474,7 @@ def test_script_mode_estimator_uses_jumpstart_base_name_with_js_models( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" ) sagemaker_session.boto_region_name = REGION + sagemaker_session.sagemaker_config = {} instance_type = "ml.p2.xlarge" instance_count = 1 @@ -4523,6 +4534,7 @@ def test_all_framework_estimators_add_jumpstart_base_name( patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session ): sagemaker_session.boto_region_name = REGION + sagemaker_session.sagemaker_config = {} sagemaker_session.sagemaker_client.describe_training_job.return_value = { "ModelArtifacts": {"S3ModelArtifacts": "some-uri"} } @@ -4673,6 +4685,7 @@ def test_script_mode_estimator_escapes_hyperparameters_as_json( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" ) 
sagemaker_session.boto_region_name = REGION + sagemaker_session.sagemaker_config = {} instance_type = "ml.p2.xlarge" instance_count = 1 @@ -4721,6 +4734,7 @@ def test_estimator_local_download_dir( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" ) sagemaker_session.boto_region_name = REGION + sagemaker_session.sagemaker_config = {} local_download_dir = "some/download/dir" diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 74ad25aaea..8bd5127dd6 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -1722,6 +1722,7 @@ def test_tags_prefixes_jumpstart_models( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" ) sagemaker_session.boto_region_name = REGION + sagemaker_session.sagemaker_config = {} sagemaker_session.sagemaker_client.describe_training_job.return_value = { "AlgorithmSpecification": { @@ -1850,6 +1851,7 @@ def test_no_tags_prefixes_non_jumpstart_models( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" ) sagemaker_session.boto_region_name = REGION + sagemaker_session.sagemaker_config = {} sagemaker_session.sagemaker_client.describe_training_job.return_value = { "AlgorithmSpecification": { From fc18316cbaefa9a017e943f43142b62fd33aa51a Mon Sep 17 00:00:00 2001 From: Ruban Hussain Date: Tue, 28 Mar 2023 09:04:34 -0700 Subject: [PATCH 34/40] fix: bug where a user-provided sagemaker_config wasn't set --- src/sagemaker/session.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index f580d67eae..a2ce33f7d2 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -268,6 +268,7 @@ def _initialize( self.local_mode = False if sagemaker_config: validate_sagemaker_config(sagemaker_config) + self.sagemaker_config = sagemaker_config else: if self.s3_resource is None: s3 = self.boto_session.resource("s3", region_name=self.boto_region_name) From d8e33ea76ab8f213f6f1606b99790137c1b9d59c Mon Sep 17 00:00:00 2001
From: Ruban Hussain Date: Tue, 28 Mar 2023 10:35:17 -0700 Subject: [PATCH 35/40] change: rename fetch_sagemaker_config to load_sagemaker_config --- src/sagemaker/config/__init__.py | 2 +- src/sagemaker/config/config.py | 2 +- src/sagemaker/local/local_session.py | 8 ++--- src/sagemaker/session.py | 6 ++-- src/sagemaker/workflow/pipeline_context.py | 2 +- tests/integ/test_sagemaker_config.py | 6 ++-- tests/unit/sagemaker/config/test_config.py | 34 +++++++++++----------- 7 files changed, 30 insertions(+), 30 deletions(-) diff --git a/src/sagemaker/config/__init__.py b/src/sagemaker/config/__init__.py index 002768911e..01b47f3cc5 100644 --- a/src/sagemaker/config/__init__.py +++ b/src/sagemaker/config/__init__.py @@ -13,7 +13,7 @@ """This module configures the default values for SageMaker Python SDK.""" from __future__ import absolute_import -from sagemaker.config.config import fetch_sagemaker_config, validate_sagemaker_config # noqa: F401 +from sagemaker.config.config import load_sagemaker_config, validate_sagemaker_config # noqa: F401 from sagemaker.config.config_schema import ( # noqa: F401 KEY, TRAINING_JOB, diff --git a/src/sagemaker/config/config.py b/src/sagemaker/config/config.py index 5d6bac24ad..c964e2670c 100644 --- a/src/sagemaker/config/config.py +++ b/src/sagemaker/config/config.py @@ -51,7 +51,7 @@ S3_PREFIX = "s3://" -def fetch_sagemaker_config( +def load_sagemaker_config( additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE ) -> dict: """Helper method that loads config files and merges them. 
diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index ac34555327..b6a1c39807 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -21,7 +21,7 @@ import boto3 from botocore.exceptions import ClientError -from sagemaker.config import fetch_sagemaker_config, validate_sagemaker_config +from sagemaker.config import load_sagemaker_config, validate_sagemaker_config from sagemaker.local.image import _SageMakerContainer from sagemaker.local.utils import get_docker_host from sagemaker.local.entities import ( @@ -626,7 +626,7 @@ def __init__( SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE), a new dictionary will be generated from those configuration files. Alternatively, this dictionary can be generated by calling - :func:`~sagemaker.config.fetch_sagemaker_config` and then be provided to the + :func:`~sagemaker.config.load_sagemaker_config` and then be provided to the Session. """ self.s3_endpoint_url = s3_endpoint_url @@ -684,11 +684,11 @@ def _initialize( self.sagemaker_config = ( sagemaker_config if sagemaker_config - else fetch_sagemaker_config(s3_resource=self.s3_resource) + else load_sagemaker_config(s3_resource=self.s3_resource) ) else: self.sagemaker_config = ( - sagemaker_config if sagemaker_config else fetch_sagemaker_config() + sagemaker_config if sagemaker_config else load_sagemaker_config() ) sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml") diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index a2ce33f7d2..0117a8a1e9 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -33,7 +33,7 @@ import sagemaker.logs from sagemaker import vpc_utils from sagemaker._studio import _append_project_tags -from sagemaker.config import fetch_sagemaker_config, validate_sagemaker_config # noqa: F401 +from sagemaker.config import load_sagemaker_config, validate_sagemaker_config # noqa: F401 from 
sagemaker.config import ( KEY, TRAINING_JOB, @@ -197,7 +197,7 @@ def __init__( SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE), a new dictionary will be generated from those configuration files. Alternatively, this dictionary can be generated by calling - :func:`~sagemaker.config.fetch_sagemaker_config` and then be provided to the + :func:`~sagemaker.config.load_sagemaker_config` and then be provided to the Session. """ self._default_bucket = None @@ -274,7 +274,7 @@ def _initialize( s3 = self.boto_session.resource("s3", region_name=self.boto_region_name) else: s3 = self.s3_resource - self.sagemaker_config = fetch_sagemaker_config(s3_resource=s3) + self.sagemaker_config = load_sagemaker_config(s3_resource=s3) @property def boto_region_name(self): diff --git a/src/sagemaker/workflow/pipeline_context.py b/src/sagemaker/workflow/pipeline_context.py index cd1d07189d..2ca1e11484 100644 --- a/src/sagemaker/workflow/pipeline_context.py +++ b/src/sagemaker/workflow/pipeline_context.py @@ -140,7 +140,7 @@ def __init__( SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE), a new dictionary will be generated from those configuration files. Alternatively, this dictionary can be generated by calling - :func:`~sagemaker.config.fetch_sagemaker_config` and then be provided to the + :func:`~sagemaker.config.load_sagemaker_config` and then be provided to the Session. 
""" super().__init__( diff --git a/tests/integ/test_sagemaker_config.py b/tests/integ/test_sagemaker_config.py index b0cf92841b..33627e8723 100644 --- a/tests/integ/test_sagemaker_config.py +++ b/tests/integ/test_sagemaker_config.py @@ -27,7 +27,7 @@ Predictor, Session, ) -from sagemaker.config import fetch_sagemaker_config +from sagemaker.config import load_sagemaker_config from sagemaker.model_monitor import DataCaptureConfig from sagemaker.s3 import S3Uploader from sagemaker.sparkml import SparkMLModel @@ -144,7 +144,7 @@ def sagemaker_session_with_dynamically_generated_sagemaker_config( sagemaker_client=sagemaker_client, sagemaker_runtime_client=runtime_client, sagemaker_metrics_client=metrics_client, - sagemaker_config=fetch_sagemaker_config([dynamic_sagemaker_config_yaml_path]), + sagemaker_config=load_sagemaker_config([dynamic_sagemaker_config_yaml_path]), ) return session @@ -173,7 +173,7 @@ def test_config_download_from_s3_and_merge( ) # The thing being tested. - sagemaker_config = fetch_sagemaker_config( + sagemaker_config = load_sagemaker_config( additional_config_paths=[s3_uri_config_1, config_file_2_local_path] ) diff --git a/tests/unit/sagemaker/config/test_config.py b/tests/unit/sagemaker/config/test_config.py index 00732e5320..e07606000c 100644 --- a/tests/unit/sagemaker/config/test_config.py +++ b/tests/unit/sagemaker/config/test_config.py @@ -17,7 +17,7 @@ import yaml from mock import Mock, MagicMock -from sagemaker.config.config import fetch_sagemaker_config +from sagemaker.config.config import load_sagemaker_config from jsonschema import exceptions from yaml.constructor import ConstructorError @@ -37,14 +37,14 @@ def expected_merged_config(get_data_dir): def test_config_when_default_config_file_and_user_config_file_is_not_found(): - assert fetch_sagemaker_config() == {} + assert load_sagemaker_config() == {} def test_config_when_overriden_default_config_file_is_not_found(get_data_dir): fake_config_file_path = os.path.join(get_data_dir, 
"config-not-found.yaml") os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = fake_config_file_path with pytest.raises(ValueError): - fetch_sagemaker_config() + load_sagemaker_config() del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] @@ -55,14 +55,14 @@ def test_invalid_config_file_which_has_python_code(get_data_dir): # PyYAML will throw exceptions for yaml.safe_load. SageMaker Config is using # yaml.safe_load internally with pytest.raises(ConstructorError) as exception_info: - fetch_sagemaker_config(additional_config_paths=[invalid_config_file_path]) + load_sagemaker_config(additional_config_paths=[invalid_config_file_path]) assert "python/object/apply:eval" in str(exception_info.value) def test_config_when_additional_config_file_path_is_not_found(get_data_dir): fake_config_file_path = os.path.join(get_data_dir, "config-not-found.yaml") with pytest.raises(ValueError): - fetch_sagemaker_config(additional_config_paths=[fake_config_file_path]) + load_sagemaker_config(additional_config_paths=[fake_config_file_path]) def test_config_factory_when_override_user_config_file_is_not_found(get_data_dir): @@ -71,7 +71,7 @@ def test_config_factory_when_override_user_config_file_is_not_found(get_data_dir ) os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = fake_additional_override_config_file_path with pytest.raises(ValueError): - fetch_sagemaker_config() + load_sagemaker_config() del os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] @@ -79,7 +79,7 @@ def test_default_config_file_with_invalid_schema(get_data_dir): config_file_path = os.path.join(get_data_dir, "invalid_config_file.yaml") os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = config_file_path with pytest.raises(exceptions.ValidationError): - fetch_sagemaker_config() + load_sagemaker_config() del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] @@ -90,7 +90,7 @@ def test_default_config_file_when_directory_is_provided_as_the_path( expected_config = base_config_with_schema expected_config["SageMaker"] = valid_config_with_all_the_scopes 
os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = get_data_dir - assert expected_config == fetch_sagemaker_config() + assert expected_config == load_sagemaker_config() del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] @@ -100,7 +100,7 @@ def test_additional_config_paths_when_directory_is_provided( # This will try to load config.yaml file from that directory if present. expected_config = base_config_with_schema expected_config["SageMaker"] = valid_config_with_all_the_scopes - assert expected_config == fetch_sagemaker_config(additional_config_paths=[get_data_dir]) + assert expected_config == load_sagemaker_config(additional_config_paths=[get_data_dir]) def test_default_config_file_when_path_is_provided_as_environment_variable( @@ -110,7 +110,7 @@ def test_default_config_file_when_path_is_provided_as_environment_variable( # This will try to load config.yaml file from that directory if present. expected_config = base_config_with_schema expected_config["SageMaker"] = valid_config_with_all_the_scopes - assert expected_config == fetch_sagemaker_config() + assert expected_config == load_sagemaker_config() del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] @@ -123,7 +123,7 @@ def test_merge_behavior_when_additional_config_file_path_is_not_found( ) os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = valid_config_file_path with pytest.raises(ValueError): - fetch_sagemaker_config(additional_config_paths=[fake_additional_override_config_file_path]) + load_sagemaker_config(additional_config_paths=[fake_additional_override_config_file_path]) del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] @@ -133,11 +133,11 @@ def test_merge_behavior(get_data_dir, expected_merged_config): get_data_dir, "sample_additional_config_for_merge.yaml" ) os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = valid_config_file_path - assert expected_merged_config == fetch_sagemaker_config( + assert expected_merged_config == load_sagemaker_config( additional_config_paths=[additional_override_config_file_path] ) 
os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = additional_override_config_file_path - assert expected_merged_config == fetch_sagemaker_config() + assert expected_merged_config == load_sagemaker_config() del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] del os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] @@ -160,7 +160,7 @@ def test_s3_config_file( config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix) expected_config = base_config_with_schema expected_config["SageMaker"] = valid_config_with_all_the_scopes - assert expected_config == fetch_sagemaker_config( + assert expected_config == load_sagemaker_config( additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock ) @@ -174,7 +174,7 @@ def test_config_factory_when_default_s3_config_file_is_not_found(s3_resource_moc ).all.return_value = [] config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix) with pytest.raises(ValueError): - fetch_sagemaker_config( + load_sagemaker_config( additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock ) @@ -204,7 +204,7 @@ def test_s3_config_file_when_uri_provided_corresponds_to_a_path( config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix) expected_config = base_config_with_schema expected_config["SageMaker"] = valid_config_with_all_the_scopes - assert expected_config == fetch_sagemaker_config( + assert expected_config == load_sagemaker_config( additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock ) @@ -231,7 +231,7 @@ def test_merge_of_s3_default_config_file_and_regular_config_file( get_data_dir, "sample_additional_config_for_merge.yaml" ) os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = config_file_s3_uri - assert expected_merged_config == fetch_sagemaker_config( + assert expected_merged_config == load_sagemaker_config( additional_config_paths=[additional_override_config_file_path], s3_resource=s3_resource_mock, ) From 
523d01d6f02fc55147a05e4aba0ecc429adc304d Mon Sep 17 00:00:00 2001 From: Ruban Hussain Date: Tue, 28 Mar 2023 12:37:43 -0700 Subject: [PATCH 36/40] fix: update Schema to match exactly with APIs --- src/sagemaker/config/__init__.py | 2 +- src/sagemaker/config/config_schema.py | 110 ++++++++++-------- src/sagemaker/session.py | 12 +- tests/data/config/config.yaml | 2 +- tests/unit/__init__.py | 4 +- tests/unit/sagemaker/automl/test_auto_ml.py | 14 +-- tests/unit/sagemaker/config/conftest.py | 2 +- .../sagemaker/config/test_config_schema.py | 2 +- tests/unit/test_session.py | 25 ++-- 9 files changed, 95 insertions(+), 78 deletions(-) diff --git a/src/sagemaker/config/__init__.py b/src/sagemaker/config/__init__.py index 01b47f3cc5..a7c685fa08 100644 --- a/src/sagemaker/config/__init__.py +++ b/src/sagemaker/config/__init__.py @@ -39,7 +39,7 @@ AUTO_ML_ROLE_ARN_PATH, AUTO_ML_OUTPUT_CONFIG_PATH, AUTO_ML_JOB_CONFIG_PATH, - AUTO_ML, + AUTO_ML_JOB, COMPILATION_JOB_ROLE_ARN_PATH, COMPILATION_JOB_OUTPUT_CONFIG_PATH, COMPILATION_JOB_VPC_CONFIG_PATH, diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index ae4104020f..cd1ce48baf 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -67,7 +67,7 @@ MODEL = "Model" MONITORING_SCHEDULE = "MonitoringSchedule" ENDPOINT_CONFIG = "EndpointConfig" -AUTO_ML = "AutoML" +AUTO_ML_JOB = "AutoMLJob" COMPILATION_JOB = "CompilationJob" CUSTOM_PARAMETERS = "CustomParameters" PIPELINE = "Pipeline" @@ -136,16 +136,16 @@ def _simple_path(*args: str): FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH = _simple_path( FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, SECURITY_CONFIG, KMS_KEY_ID ) -AUTO_ML_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML, OUTPUT_DATA_CONFIG) -AUTO_ML_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, AUTO_ML, OUTPUT_DATA_CONFIG, KMS_KEY_ID) +AUTO_ML_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, OUTPUT_DATA_CONFIG) +AUTO_ML_KMS_KEY_ID_PATH = 
_simple_path(SAGEMAKER, AUTO_ML_JOB, OUTPUT_DATA_CONFIG, KMS_KEY_ID) AUTO_ML_VOLUME_KMS_KEY_ID_PATH = _simple_path( - SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VOLUME_KMS_KEY_ID + SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VOLUME_KMS_KEY_ID ) -AUTO_ML_ROLE_ARN_PATH = _simple_path(SAGEMAKER, AUTO_ML, ROLE_ARN) +AUTO_ML_ROLE_ARN_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, ROLE_ARN) AUTO_ML_VPC_CONFIG_PATH = _simple_path( - SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VPC_CONFIG + SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VPC_CONFIG ) -AUTO_ML_JOB_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML, AUTO_ML_JOB_CONFIG) +AUTO_ML_JOB_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG) MONITORING_JOB_DEFINITION_PREFIX = _simple_path( SAGEMAKER, MONITORING_SCHEDULE, MONITORING_SCHEDULE_CONFIG, MONITORING_JOB_DEFINITION ) @@ -233,7 +233,7 @@ def _simple_path(*args: str): ) AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path( SAGEMAKER, - AUTO_ML, + AUTO_ML_JOB, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, @@ -256,9 +256,23 @@ def _simple_path(*args: str): # Schema for IAM Role. This includes a Regex validator. TYPE: "string", "pattern": r"^arn:aws[a-z\-]*:iam::\d{12}:role/?[a-zA-Z_0-9+=,.@\-_/]+$", + "minLength": 20, + "maxLength": 2048, + }, + "kmsKeyId": { + TYPE: "string", + "maxLength": 2048, + }, + "securityGroupId": { + TYPE: "string", + "pattern": r"[-0-9a-zA-Z]+", + "maxLength": 32, + }, + "subnet": { + TYPE: "string", + "pattern": r"[-0-9a-zA-Z]+", + "maxLength": 32, }, - "securityGroupId": {TYPE: "string", "pattern": r"[-0-9a-zA-Z]+"}, - "subnet": {TYPE: "string", "pattern": r"[-0-9a-zA-Z]+"}, "vpcConfig": { # Schema for VPC Configs. 
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference @@ -269,8 +283,15 @@ def _simple_path(*args: str): SECURITY_GROUP_IDS: { TYPE: "array", "items": {"$ref": "#/definitions/securityGroupId"}, + "minItems": 1, + "maxItems": 5, + }, + SUBNETS: { + TYPE: "array", + "items": {"$ref": "#/definitions/subnet"}, + "minItems": 1, + "maxItems": 16, }, - SUBNETS: {TYPE: "array", "items": {"$ref": "#/definitions/subnet"}}, }, }, "productionVariant": { @@ -280,7 +301,7 @@ def _simple_path(*args: str): CORE_DUMP_CONFIG: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, } }, }, @@ -295,12 +316,12 @@ def _simple_path(*args: str): TRANSFORM_OUTPUT: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, }, TRANSFORM_RESOURCES: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {VOLUME_KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, }, }, } @@ -317,19 +338,13 @@ def _simple_path(*args: str): ATHENA_DATASET_DEFINITION: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: { - KMS_KEY_ID: { - TYPE: "string", - } - }, + PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, }, REDSHIFT_DATASET_DEFINITION: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, PROPERTIES: { - KMS_KEY_ID: { - TYPE: "string", - }, + KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}, CLUSTER_ROLE_ARN: {"$ref": "#/definitions/roleArn"}, }, }, @@ -362,7 +377,6 @@ def _simple_path(*args: str): "minItems": 0, "maxItems": 50, }, - SUBNETS: {TYPE: "array", "items": {"$ref": "#/definitions/subnet"}}, }, PROPERTIES: { SCHEMA_VERSION: { @@ -409,7 +423,7 @@ def _simple_path(*args: str): S3_STORAGE_CONFIG: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {KMS_KEY_ID: 
{"$ref": "#/definitions/kmsKeyId"}}, } }, }, @@ -420,7 +434,7 @@ def _simple_path(*args: str): SECURITY_CONFIG: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, } }, }, @@ -445,7 +459,9 @@ def _simple_path(*args: str): MONITORING_OUTPUT_CONFIG: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: { + KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"} + }, }, MONITORING_RESOURCES: { TYPE: OBJECT, @@ -455,7 +471,9 @@ def _simple_path(*args: str): TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, PROPERTIES: { - VOLUME_KMS_KEY_ID: {TYPE: "string"} + VOLUME_KMS_KEY_ID: { + "$ref": "#/definitions/kmsKeyId" + } }, } }, @@ -482,8 +500,6 @@ def _simple_path(*args: str): # Endpoint Config # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html # Note: there is a separate API for creating Endpoints. - # That will be added later to schema once we start - # supporting other parameters such as Tags ENDPOINT_CONFIG: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, @@ -495,16 +511,16 @@ def _simple_path(*args: str): OUTPUT_CONFIG: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, } }, }, DATA_CAPTURE_CONFIG: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, }, - KMS_KEY_ID: {TYPE: "string"}, + KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}, PRODUCTION_VARIANTS: { TYPE: "array", "items": {"$ref": "#/definitions/productionVariant"}, @@ -514,7 +530,7 @@ def _simple_path(*args: str): }, # Auto ML # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateAutoMLJob.html - AUTO_ML: { + AUTO_ML_JOB: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, PROPERTIES: { @@ -529,9 +545,7 @@ def 
_simple_path(*args: str): ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: { TYPE: "boolean" }, - VOLUME_KMS_KEY_ID: { - TYPE: "string", - }, + VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}, VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, }, } @@ -540,7 +554,7 @@ def _simple_path(*args: str): OUTPUT_DATA_CONFIG: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, }, ROLE_ARN: {"$ref": "#/definitions/roleArn"}, TAGS: {"$ref": "#/definitions/tags"}, @@ -555,17 +569,17 @@ def _simple_path(*args: str): DATA_CAPTURE_CONFIG: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, }, TRANSFORM_OUTPUT: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, }, TRANSFORM_RESOURCES: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {VOLUME_KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, }, TAGS: {"$ref": "#/definitions/tags"}, }, @@ -580,7 +594,7 @@ def _simple_path(*args: str): OUTPUT_CONFIG: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, }, ROLE_ARN: {"$ref": "#/definitions/roleArn"}, VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, @@ -651,11 +665,13 @@ def _simple_path(*args: str): PROCESSING_INPUTS: { TYPE: "array", "items": {"$ref": "#/definitions/processingInput"}, + "minItems": 0, + "maxItems": 10, }, PROCESSING_OUTPUT_CONFIG: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, }, PROCESSING_RESOURCES: { TYPE: OBJECT, @@ -664,7 +680,9 @@ def _simple_path(*args: str): CLUSTER_CONFIG: { TYPE: 
OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {VOLUME_KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: { + VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"} + }, } }, }, @@ -683,12 +701,12 @@ def _simple_path(*args: str): OUTPUT_DATA_CONFIG: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, }, RESOURCE_CONFIG: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {VOLUME_KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, }, ROLE_ARN: {"$ref": "#/definitions/roleArn"}, VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, @@ -704,7 +722,7 @@ def _simple_path(*args: str): OUTPUT_CONFIG: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, - PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}}, + PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, }, ROLE_ARN: {"$ref": "#/definitions/roleArn"}, TAGS: {"$ref": "#/definitions/tags"}, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 0117a8a1e9..15a5fc9b77 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -59,7 +59,7 @@ AUTO_ML_ROLE_ARN_PATH, AUTO_ML_OUTPUT_CONFIG_PATH, AUTO_ML_JOB_CONFIG_PATH, - AUTO_ML, + AUTO_ML_JOB, COMPILATION_JOB_ROLE_ARN_PATH, COMPILATION_JOB_OUTPUT_CONFIG_PATH, COMPILATION_JOB_VPC_CONFIG_PATH, @@ -1212,10 +1212,14 @@ def process( # Processing Input can either have AthenaDatasetDefinition or RedshiftDatasetDefinition # or neither, but not both union_key_paths_for_dataset_definition = [ + [ + "DatasetDefinition", + "S3Input", + ], [ "DatasetDefinition.AthenaDatasetDefinition", "DatasetDefinition.RedshiftDatasetDefinition", - ] + ], ] update_list_of_dicts_with_values_from_config( inputs, @@ -2193,7 +2197,9 @@ def _get_auto_ml_request( auto_ml_job_request["ProblemType"] = problem_type tags = _append_project_tags(tags) - tags = self._append_sagemaker_config_tags(tags, "{}.{}.{}".format(SAGEMAKER, AUTO_ML, 
TAGS)) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, AUTO_ML_JOB, TAGS) + ) if tags is not None: auto_ml_job_request["Tags"] = tags diff --git a/tests/data/config/config.yaml b/tests/data/config/config.yaml index 46d55cf8f0..0abb47e70e 100644 --- a/tests/data/config/config.yaml +++ b/tests/data/config/config.yaml @@ -34,7 +34,7 @@ SageMaker: ProductionVariants: - CoreDumpConfig: KmsKeyId: 'kmskeyid4' - AutoML: + AutoMLJob: AutoMLJobConfig: SecurityConfig: VolumeKmsKeyId: 'volumekmskeyid1' diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index 0f54dd9e0d..df465fd31f 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -40,7 +40,7 @@ ENDPOINT_CONFIG, DATA_CAPTURE_CONFIG, PRODUCTION_VARIANTS, - AUTO_ML, + AUTO_ML_JOB, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, OUTPUT_DATA_CONFIG, @@ -149,7 +149,7 @@ SAGEMAKER_CONFIG_AUTO_ML = { SCHEMA_VERSION: "1.0", SAGEMAKER: { - AUTO_ML: { + AUTO_ML_JOB: { AUTO_ML_JOB_CONFIG: { SECURITY_CONFIG: { ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: True, diff --git a/tests/unit/sagemaker/automl/test_auto_ml.py b/tests/unit/sagemaker/automl/test_auto_ml.py index 8ccaa609d0..df501e3166 100644 --- a/tests/unit/sagemaker/automl/test_auto_ml.py +++ b/tests/unit/sagemaker/automl/test_auto_ml.py @@ -305,18 +305,18 @@ def test_framework_initialization_with_sagemaker_config_injection(sagemaker_sess sagemaker_session=sagemaker_session, ) - expected_volume_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["AutoMLJobConfig"][ - "SecurityConfig" - ]["VolumeKmsKeyId"] - expected_role_arn = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["RoleArn"] - expected_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["OutputDataConfig"][ + expected_volume_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"][ + "AutoMLJobConfig" + ]["SecurityConfig"]["VolumeKmsKeyId"] + expected_role_arn = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"]["RoleArn"] + expected_kms_key_id = 
SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"]["OutputDataConfig"][ "KmsKeyId" ] - expected_vpc_config = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["AutoMLJobConfig"][ + expected_vpc_config = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"]["AutoMLJobConfig"][ "SecurityConfig" ]["VpcConfig"] expected_enable_inter_container_traffic_encryption = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"][ - "AutoML" + "AutoMLJob" ]["AutoMLJobConfig"]["SecurityConfig"]["EnableInterContainerTrafficEncryption"] assert auto_ml.role == expected_role_arn assert auto_ml.output_kms_key == expected_kms_key_id diff --git a/tests/unit/sagemaker/config/conftest.py b/tests/unit/sagemaker/config/conftest.py index 70a497e708..ef51538dc9 100644 --- a/tests/unit/sagemaker/config/conftest.py +++ b/tests/unit/sagemaker/config/conftest.py @@ -183,7 +183,7 @@ def valid_config_with_all_the_scopes( "FeatureGroup": valid_feature_group_config, "MonitoringSchedule": valid_monitoring_schedule_config, "EndpointConfig": valid_endpointconfig_config, - "AutoML": valid_automl_config, + "AutoMLJob": valid_automl_config, "TransformJob": valid_transform_job_config, "CompilationJob": valid_compilation_job_config, "Pipeline": valid_pipeline_config, diff --git a/tests/unit/sagemaker/config/test_config_schema.py b/tests/unit/sagemaker/config/test_config_schema.py index e008d670e9..f7170c3d3e 100644 --- a/tests/unit/sagemaker/config/test_config_schema.py +++ b/tests/unit/sagemaker/config/test_config_schema.py @@ -80,7 +80,7 @@ def test_valid_transform_job_schema(base_config_with_schema, valid_transform_job def test_valid_automl_schema(base_config_with_schema, valid_automl_config): - _validate_config(base_config_with_schema, {"AutoML": valid_automl_config}) + _validate_config(base_config_with_schema, {"AutoMLJob": valid_automl_config}) def test_valid_endpoint_config_schema(base_config_with_schema, valid_endpointconfig_config): diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 
81d92a47b4..0242ea6feb 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -274,14 +274,7 @@ def test_create_process_with_sagemaker_config_injection(sagemaker_session): processing_inputs = [ { "InputName": "input-1", - "S3Input": { - "S3Uri": "mocked_s3_uri_from_upload_data", - "LocalPath": "/container/path/", - "S3DataType": "Archive", - "S3InputMode": "File", - "S3DataDistributionType": "FullyReplicated", - "S3CompressionType": "None", - }, + # No S3Input because the API expects only one of S3Input or DatasetDefinition "DatasetDefinition": { "AthenaDatasetDefinition": {}, }, @@ -3184,19 +3177,19 @@ def test_create_auto_ml_with_sagemaker_config_injection(sagemaker_session): job_name = JOB_NAME sagemaker_session.auto_ml(input_config, output_config, auto_ml_job_config, job_name=job_name) expected_call_args = copy.deepcopy(DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS) - expected_volume_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["AutoMLJobConfig"][ - "SecurityConfig" - ]["VolumeKmsKeyId"] - expected_role_arn = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["RoleArn"] - expected_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["OutputDataConfig"][ + expected_volume_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"][ + "AutoMLJobConfig" + ]["SecurityConfig"]["VolumeKmsKeyId"] + expected_role_arn = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"]["RoleArn"] + expected_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"]["OutputDataConfig"][ "KmsKeyId" ] - expected_vpc_config = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["AutoMLJobConfig"][ + expected_vpc_config = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"]["AutoMLJobConfig"][ "SecurityConfig" ]["VpcConfig"] - expected_tags = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["Tags"] + expected_tags = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"]["Tags"] expected_enable_inter_container_traffic_encryption = 
SAGEMAKER_CONFIG_AUTO_ML["SageMaker"][ - "AutoML" + "AutoMLJob" ]["AutoMLJobConfig"]["SecurityConfig"]["EnableInterContainerTrafficEncryption"] expected_call_args["OutputDataConfig"]["KmsKeyId"] = expected_kms_key_id expected_call_args["RoleArn"] = expected_role_arn From 99660cab1ef18bf1eab268d7f4ca6590e3c63d0b Mon Sep 17 00:00:00 2001 From: Ivy Bazan Date: Tue, 28 Mar 2023 13:44:58 -0700 Subject: [PATCH 37/40] add documentation for default configuration support --- doc/api/utility/config.rst | 7 + doc/overview.rst | 572 +++++++++++++++++++++++++++++++++ src/sagemaker/config/config.py | 52 ++- 3 files changed, 603 insertions(+), 28 deletions(-) create mode 100644 doc/api/utility/config.rst diff --git a/doc/api/utility/config.rst b/doc/api/utility/config.rst new file mode 100644 index 0000000000..b88fd3c223 --- /dev/null +++ b/doc/api/utility/config.rst @@ -0,0 +1,7 @@ +Config +------- + +.. automodule:: sagemaker.config.config + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/overview.rst b/doc/overview.rst index 8cbac7632f..9cfcfbadaa 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -1836,6 +1836,578 @@ You can use Amazon SageMaker Processing with "Processors" to perform data proces amazon_sagemaker_processing +************************************************************ +Configuring and using defaults with the SageMaker Python SDK +************************************************************ + +The Amazon SageMaker Python SDK supports the setting of default +values for AWS infrastructure primitive types. After administrators configure +these defaults, they are automatically passed when SageMaker Python +SDK calls supported APIs. Amazon SageMaker APIs and primitives may +not have a direct correspondence to the SageMaker Python SDK +abstractions that you are using. The parameters you specify are +automatically passed when the SageMaker Python SDK makes calls to the +API on your behalf. 
With the use of defaults, developers can use the +SageMaker Python SDK without having to specify infrastructure +parameters. + +Configuration file structure +============================ + +The SageMaker Python SDK uses YAML configuration files to define the +default values that are automatically passed to APIs. Admins can +create these configuration files and populate them with default +values defined for their desired API parameters. Your configuration +file should adhere to the structure outlined in the following sample +config file. This config outlines some of the parameters that you can +set default values for. For the full schema, see ``sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA``. + +:: + + SchemaVersion: '1.0' + CustomParameters: +   AnyStringKey: 'AnyStringValue' + SageMaker: +   FeatureGroup: +     # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateFeatureGroup.html +     OnlineStoreConfig: +       SecurityConfig: +         KmsKeyId: 'kmskeyid1' +     OfflineStoreConfig: +       S3StorageConfig: +         KmsKeyId: 'kmskeyid2' +     RoleArn: 'arn:aws:iam::555555555555:role/IMRole' +     Tags: +     - Key: 'tag_key' +       Value: 'tag_value' +   MonitoringSchedule: +   # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateMonitoringSchedule.html +     MonitoringScheduleConfig: +       MonitoringJobDefinition: +         MonitoringOutputConfig: +           KmsKeyId: 'kmskeyid3' +         MonitoringResources: +           ClusterConfig: +             VolumeKmsKeyId: 'volumekmskeyid1' +         NetworkConfig: +           EnableNetworkIsolation: true +           VpcConfig: +             SecurityGroupIds: +               - 'sg123' +             Subnets: +               - 'subnet-1234' +         RoleArn: 'arn:aws:iam::555555555555:role/IMRole' +     Tags: +     - Key: 'tag_key' +       Value: 'tag_value' +   EndpointConfig: +   # 
https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html +     AsyncInferenceConfig: +       OutputConfig: +         KmsKeyId: 'kmskeyid4' +     DataCaptureConfig: +       KmsKeyId: 'kmskeyid5' +     KmsKeyId: 'kmskeyid6' +     Tags: +     - Key: 'tag_key' +       Value: 'tag_value' +   AutoMLJob: +   # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateAutoMLJob.html +     AutoMLJobConfig: +       SecurityConfig: +         VolumeKmsKeyId: 'volumekmskeyid2' +         VpcConfig: +           SecurityGroupIds: +             - 'sg123' +           Subnets: +             - 'subnet-1234' +     OutputDataConfig: +       KmsKeyId: 'kmskeyid7' +     RoleArn: 'arn:aws:iam::555555555555:role/IMRole' +     Tags: +     - Key: 'tag_key' +       Value: 'tag_value' +   TransformJob: +   # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTransformJob.html +     DataCaptureConfig: +       KmsKeyId: 'kmskeyid8' +     TransformOutput: +       KmsKeyId: 'kmskeyid9' +     TransformResources: +       VolumeKmsKeyId: 'volumekmskeyid3' +   CompilationJob: +   # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateCompilationJob.html +     OutputConfig: +       # Currently not supported by the SageMaker Python SDK +       KmsKeyId: 'kmskeyid10' +     RoleArn: 'arn:aws:iam::555555555555:role/IMRole' +     # Currently not supported by the SageMaker Python SDK +     VpcConfig: +       SecurityGroupIds: +         - 'sg123' +       Subnets: +         - 'subnet-1234' +     Tags: +     - Key: 'tag_key' +       Value: 'tag_value' +   Pipeline: +   # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreatePipeline.html +     RoleArn: 'arn:aws:iam::555555555555:role/IMRole' +     Tags: +     - Key: 'tag_key' +       Value: 'tag_value' +   Model: +   # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html +     EnableNetworkIsolation: true +     ExecutionRoleArn: 
'arn:aws:iam::555555555555:role/IMRole' +     VpcConfig: +       SecurityGroupIds: +         - 'sg123' +       Subnets: +         - 'subnet-1234' +     Tags: +     - Key: 'tag_key' +       Value: 'tag_value' +   ProcessingJob: +   # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html +     NetworkConfig: +       EnableNetworkIsolation: true +       VpcConfig: +         SecurityGroupIds: +           - 'sg123' +         Subnets: +           - 'subnet-1234' +     ProcessingInputs: +       - DatasetDefinition: +           AthenaDatasetDefinition: +             KmsKeyId: 'kmskeyid11' +           RedshiftDatasetDefinition: +             KmsKeyId: 'kmskeyid12' +             ClusterRoleArn: 'arn:aws:iam::555555555555:role/IMRole' +     ProcessingOutputConfig: +       KmsKeyId: 'kmskeyid13' +     ProcessingResources: +       ClusterConfig: +         VolumeKmsKeyId: 'volumekmskeyid4' +     RoleArn: 'arn:aws:iam::555555555555:role/IMRole' +     Tags: +     - Key: 'tag_key' +       Value: 'tag_value' +   TrainingJob: +   # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html +     EnableNetworkIsolation: true +     OutputDataConfig: +       KmsKeyId: 'kmskeyid14' +     ResourceConfig: +       VolumeKmsKeyId: 'volumekmskeyid5' +     RoleArn: 'arn:aws:iam::555555555555:role/IMRole' +     VpcConfig: +       SecurityGroupIds: +         - 'sg123' +       Subnets: +         - 'subnet-1234' +     Tags: +     - Key: 'tag_key' +       Value: 'tag_value' +   EdgePackagingJob: +   # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEdgePackagingJob.html +     OutputConfig: +       KmsKeyId: 'kmskeyid15' +     RoleArn: 'arn:aws:iam::555555555555:role/IMRole' +     Tags: +     - Key: 'tag_key' +       Value: 'tag_value' + +Configuration file locations +============================ + +The SageMaker Python SDK searches for configuration files at two +locations based on the platform that you are using. 
You can also +modify the default locations by overriding them using environment +variables. The following sections give information about these +configuration file locations. + +Default configuration file location +----------------------------------- + +By default, the SageMaker Python SDK uses two configuration files. +One for the admin and one for the user. Using the admin config file, +admins can define a set of default values. Users can use the user +configuration file to override values set in the admin configuration +file, as well as set other default parameter values. Users can also +set additional configuration file locations. For more information +about setting additional configuration file locations, see `Specify additional configuration files`_. + +The location of your default configuration paths depends on the +platform that you’re using the SageMaker Python SDK on. These default +locations are relative to the environment that you are using the +SageMaker Python SDK on. + +The following code block returns the default locations of your admin +and user configuration files. These commands must be run from the +environment that you’re using the SageMaker Python SDK in. + +Note: The directories returned by these commands may not exist. In +that case, you must create these directories with the required +permissions. + +.. code:: python + + import os + from platformdirs import site_config_dir, user_config_dir + + #Prints the location of the admin config file + print(os.path.join(site_config_dir("sagemaker"), "config.yaml")) + + #Prints the location of the user config file + print(os.path.join(user_config_dir("sagemaker"), "config.yaml")) + +Default Notebook instances locations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The following code sample lists the default locations of the +configuration files when using the SageMaker Python SDK on Amazon +SageMaker Notebook instances. + +.. 
code:: python + + #Location of the admin config file + /etc/xdg/sagemaker/config.yaml + + #Location of the user config file + /home/ec2-user/.config/sagemaker/config.yaml + +Default Studio notebook locations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The following code sample lists the default locations of the +configuration files when using the SageMaker Python SDK on Amazon +SageMaker Studio notebooks. + +.. code:: python + + #Location of the admin config file + /etc/xdg/sagemaker/config.yaml + + #Location of the user config file + /root/.config/sagemaker/config.yaml + +Override the configuration file location +---------------------------------------- + +To change the default configuration file locations used by the +SageMaker Python SDK, set one or both of the following environment +variables from the environment where you are using the SageMaker +Python SDK. When you modify these environment variables, the +SageMaker Python SDK searches for configuration files in the +locations that you specify instead of the default configuration file +locations. + +- ``SAGEMAKER_ADMIN_CONFIG_OVERRIDE`` overrides the default + location where the SageMaker Python SDK searches for the admin + config. +- ``SAGEMAKER_USER_CONFIG_OVERRIDE`` overrides the default + location where the SageMaker Python SDK searches for the user + config. + +Using these environment variables, you can set the config location to +either a local config location or a config location in an Amazon S3 +bucket. If a directory is provided as the path, the SageMaker Python +SDK searches for the ``config.yaml`` file in that directory. +The SageMaker Python SDK does not do any recursive searches for the +file. + +The following options are available if the config is saved locally. + +- Local file path:   ``/config.yaml`` +- Path of the directory containing the config file : + ``/`` + +The following options are available if the config is saved on Amazon +S3. 
+ +- S3 URI of the config file: ``s3://<bucket-name>/<path>/config.yaml`` +- S3 URI of the directory containing the config file: ``s3://<bucket-name>/<path>/`` + +For example, the following example sets the default user +configuration location to a local directory from within a Jupyter +notebook environment. + +.. code:: python + + import os + os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = "<path-to-config-directory>" + +If you’re using Studio or a Notebook instance, you can automatically +set this value for all instances with a lifecycle configuration +script. For more information about lifecycle configuration scripts, +see `Use Lifecycle Configurations with Amazon SageMaker +Studio <https://docs.aws.amazon.com/sagemaker/latest/dg/studio-lcc.html>`__. + +Supported APIs and parameters +============================= + +The following sections give information about the APIs and parameters +that the SageMaker Python SDK supports setting defaults for. To set +defaults for these parameters, create key/value pairs in your +configuration file as shown in `Configuration file structure`_. +For the full schema, see ``sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA``. + +List of parameters supported +---------------------------- + +In the supported APIs, only parameters for the following primitive +types support setting defaults with a configuration file. + +- AWS IAM Role ARNs +- Enable network isolation +- Amazon VPC subnets and security groups +- AWS KMS key IDs +- Tags +- Enable inter-container traffic encryption + +List of APIs supported +---------------------- + +Default values for the supported parameters of these APIs apply to +all create and update calls for that API. For example, if a supported +parameter is set for ``TrainingJob``, then the parameter is used for +all ``CreateTrainingJob`` and ``UpdateTrainingJob`` API calls. +The parameter is not used in any other API calls unless it is +specified for that API as well.
However, the default value passed for +the ``TrainingJob`` API is present for any artifacts generated by that +API, so any subsequent calls that use these artifacts will also use +the value. + +The following groups of APIs support setting defaults with a +configuration file. + +- Feature Group: ``CreateFeatureGroup``, ``UpdateFeatureGroup`` +- Monitoring Schedule: ``CreateMonitoringSchedule``, ``UpdateMonitoringSchedule`` +- Endpoint Config: ``CreateEndpointConfig``, ``UpdateEndpointConfig`` +- Auto ML: ``CreateAutoMLJob``, ``UpdateAutoMLJob`` +- Transform Job: ``CreateTransformJob``, ``UpdateTransformJob`` +- Compilation Job: ``CreateCompilationJob``, ``UpdateCompilationJob`` +- Pipeline: ``CreatePipeline``, ``UpdatePipeline`` +- Model: ``CreateModel``, ``UpdateModel`` +- Model Package: ``CreateModelPackage``, ``UpdateModelPackage`` +- Processing Jobs: ``CreateProcessingJob``, ``UpdateProcessingJob`` +- Training Job: ``CreateTrainingJob``, ``UpdateTrainingJob`` +- Edge Packaging Job: ``CreateEdgePackagingJob``, ``UpdateEdgePackagingJob`` + +Hyperparameter Tuning Job: Supported indirectly via ``TrainingJob`` API. While this API is not directly supported, it includes the training job definition as a parameter. +If you provide defaults for this parameter as part of the ``TrainingJob`` API, these defaults are also used for Hyperparameter Tuning Job. + +Configuration file resolution +============================= + +To create a consistent experience when using defaults with multiple +configuration files, the SageMaker Python SDK merges all of the +configuration files into a single configuration dictionary that defines all +of the default values set in the environment. The configuration files +for defaults are loaded and merged during the initialization of +the ``Session`` object. To access the configuration files, the user +must have read access to any local paths set and read access to any +S3 URIs that are set. 
These permissions can be set using the IAM role
+or other AWS credentials for the user.
+
+If a configuration dictionary is not specified during ``Session``
+initialization, the ``Session`` automatically calls ``load_sagemaker_config()`` to load, merge, and validate configuration files from
+the default locations.
+
+If a configuration dictionary is specified, the ``Session`` uses the supplied dictionary.
+
+The following sections give information about how the merging of
+configuration files happens.
+
+Default configuration files
+---------------------------
+
+The ``load_sagemaker_config()`` method first checks the default location of the
+admin config file. If one is found, it serves as the basis for
+further merging. If a config file is not found, then the merged
+config dictionary is empty. The ``load_sagemaker_config()`` method then checks the default
+location for the user config file. If a config file is found, then
+the values are merged on top of the existing configuration dictionary. This
+means that the values specified in the user config override the
+corresponding values specified in the admin config file. If there is not an
+existing entry for a user config value in the existing configuration dictionary, then a new
+entry is added.
+
+Specify additional configuration files
+--------------------------------------
+
+In addition to the default locations for your admin and user config
+files, you can also specify other locations for configuration files.
+To specify these additional config locations, pass a list of these additional locations as part of the
+``load_sagemaker_config()`` call and pass the resulting dictionary
+to any ``Session`` objects you create as shown in the
+following code sample. These additional configuration file
+locations are checked in the order specified in the ``load_sagemaker_config()`` call. When a configuration file is found, it is merged on top of
+the existing configuration dictionary.
All of the values specified in the +first additional config override the corresponding values in the +default configs. Subsequent additional configuration files are merged +on top of the existing configuration dictionary using the same method. + +If you are building a dictionary with custom configuration file locations, we recommend that you +use ``load_sagemaker_config()`` and ``validate_sagemaker_config()`` iteratively to verify the construction +of your dictionary before you pass it to a ``Session`` object. + +.. code:: python + + from sagemaker.session import Session + from sagemaker.config import load_sagemaker_config, validate_sagemaker_config + + # Create a configuration dictionary manually + custom_sagemaker_config = load_sagemaker_config( +     additional_config_paths=[ +         'path1', +         'path2', +         'path3' +     ] + ) + + # Then validate that the dictionary adheres to the configuration schema + validate_sagemaker_config(custom_sagemaker_config) + + # Then initialize the Session object with the configuration dictionary + sm_session = Session( +     sagemaker_config = custom_sagemaker_config + ) + +Tags +---- + +Any tags specified in the configuration dictionary are appended to the set +of tags set by the SageMaker Python SDK and specified by the user. +Each of the tags in the combined list must have a unique key or the +API call fails. If a user provides a tag with the same name as a tag +in the configuration dictionary, the user tag is used and the config tag is +skipped. This behavior applies to all config keys that follow +the ``SageMaker.*.Tags`` pattern. + +Object Arrays +------------- + +For the following keys, the configuration file may contain an array +of elements, where each element contains one or more values. When a +configuration file is merged with an existing configuration dictionary and +both contain a value for these keys, the elements in the array +defined in the existing configuration dictionary are overridden in order. 
+If there are more elements in the array being merged than in the +existing configuration dictionary, then the size of the array is increased. + +- ``SageMaker.EndpointConfig.ProductionVariants`` +- ``SageMaker.ModelPackage.ValidationSpecification.ValidationProfiles`` +- ``SageMaker.ProcessingJob.ProcessingInputs`` + +When a user passes values for these keys, the behavior depends on the +size of the array. If values are not explicitly defined inside the +user input array but are defined inside the config array, then those values from the config array are added to the user array. If the user input array +contains more elements than the config array, the extra elements of +the user input array are not substituted with values from the config. +Alternatively, if the config array contains more elements than the +user input array, the extra elements of the config array are not +used. + +View the merged configuration dict +---------------------------------- + +When the SageMaker Python SDK creates your ``Session``, it merges +together the config files found at the default locations and the +additional locations specified in the ``load_sagemaker_config()`` call. In this +process, a new config dictionary is created that aggregates the defaults in all +of the config files. To see the full merged config, inspect the +config of the session object as follows. + +.. code:: python + + session=Session() + session.sagemaker_config + +Inherit default values from the configuration file +================================================== + +After the ``Session`` is created, if a value for a supported parameter is +present in the merged configuration dictionary and the user does not pass +that parameter as part of a SageMaker Python SDK method, +the SageMaker Python SDK automatically passes the corresponding value +from the configuration dictionary as part of the API call. 
If a user
+explicitly passes a value for a parameter that supports default
+values, the SageMaker Python SDK overrides the value present in the
+merged configuration dictionary and uses the value passed by the user
+instead.
+
+Reference values from config
+----------------------------
+
+You can manually reference values from the merged configuration dictionary
+using the corresponding key. This makes it possible to pass these
+default values to an AWS SDK for Python (Boto3) request using the SageMaker Python
+SDK. To reference a value from the configuration dictionary, pass the
+corresponding key as follows.
+
+.. code:: python
+
+   from sagemaker.session import Session
+   from sagemaker.utils import get_sagemaker_config_value
+   session=Session()
+
+   # Option 1
+   get_sagemaker_config_value(session, "key1.key2")
+   # Option 2
+   session.sagemaker_config["key1"]["key2"]
+
+You can also specify custom parameters as part of the
+``CustomParameters`` section in a configuration file by setting key and
+value pairs as shown in the `Configuration file structure`_.
+Values set in ``CustomParameters`` are not automatically used. You can
+only use these values by manually referencing them with the
+corresponding key.
+
+For example, the following code block references the ``VpcConfig`` parameter
+specified as part of the ``Model`` API in the configuration file and sets
+a variable with that value. It also references the ``JobName`` value
+specified as part of ``CustomParameters``.
+
+..
code:: python + + from sagemaker.session import Session + from sagemaker.utils import get_sagemaker_config_value + session = Session() + + vpc_config_option_1 = get_sagemaker_config_value(session, "SageMaker.Model.VpcConfig") + vpc_config_option_2 = session.sagemaker_config["SageMaker"]["Model"]["VpcConfig"] + + custom_param_option_1 = get_sagemaker_config_value(session, "CustomParameters.JobName") + custom_param_option_2 = session.sagemaker_config["CustomParameters"]["JobName"] + +Debug default values +-------------------- + +When a default value is being used as part of an API call, the +SageMaker Python SDK prints out information about the default value, +the configuration dictionary that it came from, the keys that are being +looked at, and whether they are used or skipped. To get information, +enable Boto3 debug logging that shows the HTTP request info. + +To turn on this logging, run the following commands in the +environment where you’re using the SageMaker PySDK. + +.. code:: python + + import boto3 + import logging + boto3.set_stream_logger(name='botocore.endpoint', level=logging.DEBUG) + +The following log lines offer the most relevant information, +specifically the contents of ``'body': b'{...}`` . + +:: + + botocore.endpoint [DEBUG] Making request for OperationModel(name=) with params: {'url_path': ..., + 'query_string': ..., 'method': 'POST', 'headers': {...}, 'body': b'{...}', 'url': 'https://api.sagemaker.us-west-2.amazonaws.com/', + 'context': {...}} *** FAQ diff --git a/src/sagemaker/config/config.py b/src/sagemaker/config/config.py index c964e2670c..5a17ffb6a2 100644 --- a/src/sagemaker/config/config.py +++ b/src/sagemaker/config/config.py @@ -54,48 +54,47 @@ def load_sagemaker_config( additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE ) -> dict: - """Helper method that loads config files and merges them. + """Loads config files and merges them. 
By default, this method first searches for config files in the default locations defined by the SDK. Users can override the default admin and user config file paths using the - SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE environment variables, + ``SAGEMAKER_ADMIN_CONFIG_OVERRIDE`` and ``SAGEMAKER_USER_CONFIG_OVERRIDE`` environment variables, respectively. Additional config file paths can also be provided as a parameter. This method then: - * Loads each config file, whether it is Amazon S3 or the local file system. - * Validates the schema of the config files. - * Merges the files in the same order. + * Loads each config file, whether it is Amazon S3 or the local file system. + * Validates the schema of the config files. + * Merges the files in the same order. This method throws exceptions in the following cases: - * jsonschema.exceptions.ValidationError: Schema validation fails for one or more config - files. - * RuntimeError: The method is unable to retrieve the list of all S3 files with the - same prefix or is unable to retrieve the file. - * ValueError: There are no S3 files with the prefix when an S3 URI is provided. - * ValueError: There is no config.yaml file in the S3 bucket when an S3 URI is provided. - * ValueError: A file doesn't exist in a path that was specified by the user as part of an - environment variable or additional configuration file path. This doesn't include the default - config file locations. + * ``jsonschema.exceptions.ValidationError``: Schema validation fails for one or more config files. + * ``RuntimeError``: The method is unable to retrieve the list of all S3 files with the same prefix or is unable to retrieve the file. + * ``ValueError``: There are no S3 files with the prefix when an S3 URI is provided. + * ``ValueError``: There is no config.yaml file in the S3 bucket when an S3 URI is provided. 
+ * ``ValueError``: A file doesn't exist in a path that was specified by the user as part of an + environment variable or additional configuration file path. This doesn't include the default + config file locations. Args: additional_config_paths: List of config file paths. These paths can be one of the following. In the case of a directory, this method - searches for a config.yaml file in that directory. This method does not perform a + searches for a ``config.yaml`` file in that directory. This method does not perform a recursive search of folders in that directory. - * Local file path - * Local directory path - * S3 URI of the config file - * S3 URI of the directory containing the config file - Note: S3 URI follows the format s3:/// - s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This is used to fetch - config files from S3. If it is not provided, this method creates a default S3 resource - See :py:meth:boto3.session.Session.resource. This argument is not needed if the - config files are present in the local file system. + * Local file path + * Local directory path + * S3 URI of the config file + * S3 URI of the directory containing the config file + + Note: S3 URI follows the format ``s3:///`` + s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This is used to fetch + config files from S3. If it is not provided, this method creates a default S3 resource. + See `Boto3 Session documentation `__. + This argument is not needed if the config files are present in the local file system. """ default_config_path = os.getenv( ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE, _DEFAULT_ADMIN_CONFIG_FILE_PATH @@ -129,10 +128,7 @@ def load_sagemaker_config( def validate_sagemaker_config(sagemaker_config: dict = None): - """Helper method that validates whether the schema of a given dictionary. 
- - This method will validate whether the dictionary adheres to the schema - defined at `~sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA` + """Validates whether a given dictionary adheres to the schema defined at ``sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`` Args: sagemaker_config: A dictionary containing default values for the From 17f290ba56e55db3cfea82b3044618cbc029af15 Mon Sep 17 00:00:00 2001 From: Ivy Bazan Date: Tue, 28 Mar 2023 14:17:14 -0700 Subject: [PATCH 38/40] fix linting errors --- src/sagemaker/config/config.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/config/config.py b/src/sagemaker/config/config.py index 5a17ffb6a2..f4836abe39 100644 --- a/src/sagemaker/config/config.py +++ b/src/sagemaker/config/config.py @@ -60,8 +60,8 @@ def load_sagemaker_config( defined by the SDK. Users can override the default admin and user config file paths using the - ``SAGEMAKER_ADMIN_CONFIG_OVERRIDE`` and ``SAGEMAKER_USER_CONFIG_OVERRIDE`` environment variables, - respectively. + ``SAGEMAKER_ADMIN_CONFIG_OVERRIDE`` and ``SAGEMAKER_USER_CONFIG_OVERRIDE`` environment + variables, respectively. Additional config file paths can also be provided as a parameter. @@ -71,13 +71,16 @@ def load_sagemaker_config( * Merges the files in the same order. This method throws exceptions in the following cases: - * ``jsonschema.exceptions.ValidationError``: Schema validation fails for one or more config files. - * ``RuntimeError``: The method is unable to retrieve the list of all S3 files with the same prefix or is unable to retrieve the file. + * ``jsonschema.exceptions.ValidationError``: Schema validation fails for one or more + config files. + * ``RuntimeError``: The method is unable to retrieve the list of all S3 files with the + same prefix or is unable to retrieve the file. * ``ValueError``: There are no S3 files with the prefix when an S3 URI is provided. 
- * ``ValueError``: There is no config.yaml file in the S3 bucket when an S3 URI is provided. - * ``ValueError``: A file doesn't exist in a path that was specified by the user as part of an - environment variable or additional configuration file path. This doesn't include the default - config file locations. + * ``ValueError``: There is no config.yaml file in the S3 bucket when an S3 URI is + provided. + * ``ValueError``: A file doesn't exist in a path that was specified by the user as + part of an environment variable or additional configuration file path. This doesn't + include the default config file locations. Args: additional_config_paths: List of config file paths. @@ -128,7 +131,8 @@ def load_sagemaker_config( def validate_sagemaker_config(sagemaker_config: dict = None): - """Validates whether a given dictionary adheres to the schema defined at ``sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`` + """Validates whether a given dictionary adheres to the schema defined at + ``sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`` Args: sagemaker_config: A dictionary containing default values for the From dc38ce58913afd69856c14f88ddf66f00e1752e9 Mon Sep 17 00:00:00 2001 From: Ivy Bazan Date: Tue, 28 Mar 2023 14:36:04 -0700 Subject: [PATCH 39/40] fix link lint --- src/sagemaker/config/config.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/config/config.py b/src/sagemaker/config/config.py index f4836abe39..21eb13db90 100644 --- a/src/sagemaker/config/config.py +++ b/src/sagemaker/config/config.py @@ -96,7 +96,8 @@ def load_sagemaker_config( Note: S3 URI follows the format ``s3:///`` s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This is used to fetch config files from S3. If it is not provided, this method creates a default S3 resource. - See `Boto3 Session documentation `__. + See `Boto3 Session documentation `__. 
This argument is not needed if the config files are present in the local file system. """ default_config_path = os.getenv( @@ -131,7 +132,7 @@ def load_sagemaker_config( def validate_sagemaker_config(sagemaker_config: dict = None): - """Validates whether a given dictionary adheres to the schema defined at + """Validates whether a given dictionary adheres to the schema defined at ``sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`` Args: From ee93e3f241d137f1f99f51fc29669d3e96d90067 Mon Sep 17 00:00:00 2001 From: Ivy Bazan Date: Tue, 28 Mar 2023 14:45:13 -0700 Subject: [PATCH 40/40] fix lint --- src/sagemaker/config/config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/config/config.py b/src/sagemaker/config/config.py index 21eb13db90..730b11396b 100644 --- a/src/sagemaker/config/config.py +++ b/src/sagemaker/config/config.py @@ -132,8 +132,10 @@ def load_sagemaker_config( def validate_sagemaker_config(sagemaker_config: dict = None): - """Validates whether a given dictionary adheres to the schema defined at - ``sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`` + """Validates whether a given dictionary adheres to the schema. + + The schema is defined at + ``sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA``. Args: sagemaker_config: A dictionary containing default values for the