diff --git a/setup.py b/setup.py index 2ed74101..76ebedc3 100644 --- a/setup.py +++ b/setup.py @@ -80,11 +80,11 @@ def read_version(): install_requires=required_packages, extras_require={ "test": [ - "tox==3.13.1", + "tox==4.6.4", "pytest==4.4.1", "pytest-cov", "mock", - "sagemaker[local]<2", + "sagemaker[local]>=2.172.0,<3", "black==22.3.0 ; python_version >= '3.7'", ] }, diff --git a/src/sagemaker_training/modules.py b/src/sagemaker_training/modules.py index 48245874..c791f4e5 100644 --- a/src/sagemaker_training/modules.py +++ b/src/sagemaker_training/modules.py @@ -17,10 +17,12 @@ import importlib import os +import re import shlex import sys import textwrap +import boto3 import six from sagemaker_training import environment, errors, files, logging_config, process @@ -28,6 +30,7 @@ logger = logging_config.get_logger() DEFAULT_MODULE_NAME = "default_user_module_name" +CA_REPOSITORY_ARN_ENV = "CA_REPOSITORY_ARN" def exists(name): # type: (str) -> bool @@ -121,6 +124,9 @@ def install(path, capture_error=False): # type: (str, bool) -> None if has_requirements(path): cmd += "-r requirements.txt" + if os.getenv(CA_REPOSITORY_ARN_ENV): + index = _get_codeartifact_index() + cmd += " -i {}".format(index) logger.info("Installing module with the following command:\n%s", cmd) @@ -138,6 +144,9 @@ def install_requirements(path, capture_error=False): # type: (str, bool) -> Non stderr, and appends it to the returned Exception message in case of errors. """ cmd = "{} -m pip install -r requirements.txt".format(process.python_executable()) + if os.getenv(CA_REPOSITORY_ARN_ENV): + index = _get_codeartifact_index() + cmd += " -i {}".format(index) logger.info("Installing dependencies from requirements.txt:\n{}".format(cmd)) @@ -172,3 +181,55 @@ def import_module(uri, name=DEFAULT_MODULE_NAME): # type: (str, str) -> module return module except Exception as e: # pylint: disable=broad-except six.reraise(errors.ImportModuleError, errors.ImportModuleError(e), sys.exc_info()[2]) + + +def _get_codeartifact_index(): + """ + Build the authenticated codeartifact index url based on the arn provided + via CA_REPOSITORY_ARN environment variable following the form + `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${DomainName}/${RepositoryName}` + https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html + https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies + :return: authenticated codeartifact index url + """ + repository_arn = os.getenv(CA_REPOSITORY_ARN_ENV) + arn_regex = ( + "arn:(?P[^:]+):codeartifact:(?P[^:]+):(?P[^:]+)" + ":repository/(?P[^/]+)/(?P.+)" + ) + m = re.match(arn_regex, repository_arn) + if not m: + raise Exception("invalid CodeArtifact repository arn {}".format(repository_arn)) + domain = m.group("domain") + owner = m.group("account") + repository = m.group("repository") + region = m.group("region") + + logger.info( + "configuring pip to use codeartifact " + "(domain: %s, domain owner: %s, repository: %s, region: %s)", + domain, + owner, + repository, + region, + ) + try: + client = boto3.client("codeartifact", region_name=region) + auth_token_response = client.get_authorization_token(domain=domain, domainOwner=owner) + token = auth_token_response["authorizationToken"] + endpoint_response = client.get_repository_endpoint( + domain=domain, domainOwner=owner, repository=repository, format="pypi" + ) + unauthenticated_index = endpoint_response["repositoryEndpoint"] + return re.sub( + "https://", + "https://aws:{}@".format(token), + re.sub( + "{}/?$".format(repository), + "{}/simple/".format(repository), + unauthenticated_index, + ), + ) + except Exception: + logger.error("failed to configure pip to use codeartifact") + raise Exception("failed to configure pip to use codeartifact") diff --git a/test/integration/local/test_dummy.py b/test/integration/local/test_dummy.py index f4db94e9..596bc29f 100644 --- a/test/integration/local/test_dummy.py +++ b/test/integration/local/test_dummy.py @@ -50,3 +50,34 @@ def test_install_requirements(capsys): assert "Installing collected packages: pyfiglet" in stdout assert "Successfully installed pyfiglet-0.8.post1" in stdout assert "Reporting training SUCCESS" in stdout + + +# def test_install_requirements_from_codeartifact(capsys): +# # TODO: fill in details for CA +# ca_domain = "..." +# ca_domain_owner = "..." +# ca_repository = "..." +# ca_region = "..." +# ca_repository_arn = "..." +# +# estimator = Estimator( +# image_uri="sagemaker-training-toolkit-test:dummy", +# # TODO: Grant the role permissions to access CodeArtifact repo (repo resource policy + role policy) +# # https://docs.aws.amazon.com/codeartifact/latest/ug/security-iam.html +# # https://docs.aws.amazon.com/codeartifact/latest/ug/repo-policies.html +# role="SageMakerRole", +# instance_count=1, +# instance_type="local", +# environment={ +# "CA_REPOSITORY_ARN": ca_repository_arn, +# } +# ) +# +# estimator.fit() +# +# stdout = capsys.readouterr().out +# +# assert "{}-{}.d.codeartifact.{}.amazonaws.com/pypi/{}/simple/".format(ca_domain, ca_domain_owner, ca_region, ca_repository) in stdout +# assert "Installing collected packages: pyfiglet" in stdout +# assert "Successfully installed pyfiglet-0.8.post1" in stdout +# assert "Reporting training SUCCESS" in stdout diff --git a/test/unit/test_modules.py b/test/unit/test_modules.py index 6c2f88fd..8f52d69c 100644 --- a/test/unit/test_modules.py +++ b/test/unit/test_modules.py @@ -17,7 +17,7 @@ import sys import textwrap -from mock import call, Mock, mock_open, patch +from mock import call, MagicMock, Mock, mock_open, patch import pytest from sagemaker_training import environment, errors, files, modules, params @@ -90,6 +90,58 @@ def test_install_requirements(check_error): ) +def mock_codeartifact_client(client_name, **kwargs): + endpoint = "https://domain-012345678900.d.codeartifact.my-region.amazonaws.com/pypi/my_repo/" + client = MagicMock() + client.get_authorization_token = MagicMock( + return_value={"authorizationToken": "the-auth-token"} + ) + client.get_repository_endpoint = MagicMock(return_value={"repositoryEndpoint": endpoint}) + return client + + +@patch("sagemaker_training.process.check_error", autospec=True) +@patch("boto3.client", new=mock_codeartifact_client) +def test_install_requirements_codeartifact(check_error): + path = "c://sagemaker-pytorch-container" + cmd = [ + sys.executable, + "-m", + "pip", + "install", + "-r", + "requirements.txt", + "-i", + "https://aws:the-auth-token@domain-012345678900.d.codeartifact.my-region.amazonaws.com/pypi/my_repo/simple/", + ] + + with patch.dict( + os.environ, + { + "CA_REPOSITORY_ARN": "arn:aws:codeartifact:my-region:012345678900:repository/my_domain/my_repo" + }, + clear=True, + ): + with patch("os.path.exists", return_value=True): + modules.install_requirements(path) + + check_error.assert_called_with( + cmd, errors.InstallRequirementsError, 1, cwd=path, capture_error=False + ) + + +@patch.dict(os.environ, {"CA_REPOSITORY_ARN": "invalid_arn"}, clear=True) +@patch("sagemaker_training.process.check_error", autospec=True) +def test_install_requirements_codeartifact_missing_environment_variables(check_error): + path = "c://sagemaker-pytorch-container" + + with patch("os.path.exists", return_value=True): + with pytest.raises(Exception) as e: + modules.install_requirements(path) + + assert "invalid CodeArtifact repository arn invalid_arn" in str(e.value) + + @patch("sagemaker_training.process.check_error", autospec=True) def test_install_fails(check_error): check_error.side_effect = errors.ClientError()