Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ def read_version():
install_requires=required_packages,
extras_require={
"test": [
"tox==3.13.1",
"tox==4.6.4",
"pytest==4.4.1",
"pytest-cov",
"mock",
"sagemaker[local]<2",
"sagemaker[local]>=2.172.0,<3",
"black==22.3.0 ; python_version >= '3.7'",
]
},
Expand Down
61 changes: 61 additions & 0 deletions src/sagemaker_training/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,20 @@

import importlib
import os
import re
import shlex
import sys
import textwrap

import boto3
import six

from sagemaker_training import environment, errors, files, logging_config, process

logger = logging_config.get_logger()

DEFAULT_MODULE_NAME = "default_user_module_name"
CA_REPOSITORY_ARN_ENV = "CA_REPOSITORY_ARN"


def exists(name): # type: (str) -> bool
Expand Down Expand Up @@ -121,6 +124,9 @@ def install(path, capture_error=False): # type: (str, bool) -> None

if has_requirements(path):
cmd += "-r requirements.txt"
if os.getenv(CA_REPOSITORY_ARN_ENV):
index = _get_codeartifact_index()
cmd += " -i {}".format(index)

logger.info("Installing module with the following command:\n%s", cmd)

Expand All @@ -138,6 +144,9 @@ def install_requirements(path, capture_error=False): # type: (str, bool) -> Non
stderr, and appends it to the returned Exception message in case of errors.
"""
cmd = "{} -m pip install -r requirements.txt".format(process.python_executable())
if os.getenv(CA_REPOSITORY_ARN_ENV):
index = _get_codeartifact_index()
cmd += " -i {}".format(index)

logger.info("Installing dependencies from requirements.txt:\n{}".format(cmd))

Expand Down Expand Up @@ -172,3 +181,55 @@ def import_module(uri, name=DEFAULT_MODULE_NAME): # type: (str, str) -> module
return module
except Exception as e: # pylint: disable=broad-except
six.reraise(errors.ImportModuleError, errors.ImportModuleError(e), sys.exc_info()[2])


def _get_codeartifact_index():
    """
    Build the authenticated codeartifact index url based on the arn provided
    via CA_REPOSITORY_ARN environment variable following the form
    `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${DomainName}/${RepositoryName}`
    e.g. `arn:aws:codeartifact:my-region:012345678900:repository/my_domain/my_repo`
    https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html
    https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies

    NOTE: the returned url embeds the authorization token, so it should be
    treated as a secret by whoever handles it.

    :return: authenticated codeartifact index url
    :raises ValueError: if CA_REPOSITORY_ARN is unset or is not a valid repository arn
    :raises Exception: if the authorization token or the repository endpoint
        cannot be retrieved from CodeArtifact
    """
    repository_arn = os.getenv(CA_REPOSITORY_ARN_ENV)
    # Anchored on both ends; neither the domain nor the repository segment may
    # contain "/" so trailing garbage is rejected here rather than by the API.
    arn_regex = (
        r"^arn:(?P<partition>[^:]+):codeartifact:(?P<region>[^:]+):(?P<account>[^:]+)"
        r":repository/(?P<domain>[^/]+)/(?P<repository>[^/]+)$"
    )
    # `or ""` keeps an unset variable on the "invalid arn" path instead of a TypeError.
    m = re.match(arn_regex, repository_arn or "")
    if not m:
        raise ValueError("invalid CodeArtifact repository arn {}".format(repository_arn))
    domain = m.group("domain")
    owner = m.group("account")
    repository = m.group("repository")
    region = m.group("region")

    logger.info(
        "configuring pip to use codeartifact "
        "(domain: %s, domain owner: %s, repository: %s, region: %s)",
        domain,
        owner,
        repository,
        region,
    )
    try:
        client = boto3.client("codeartifact", region_name=region)
        auth_token_response = client.get_authorization_token(domain=domain, domainOwner=owner)
        token = auth_token_response["authorizationToken"]
        endpoint_response = client.get_repository_endpoint(
            domain=domain, domainOwner=owner, repository=repository, format="pypi"
        )
        unauthenticated_index = endpoint_response["repositoryEndpoint"]
        return _authenticate_index_url(unauthenticated_index, repository, token)
    except Exception:
        # logger.exception records the underlying error and traceback, which
        # the generic message raised below does not carry on its own.
        logger.exception("failed to configure pip to use codeartifact")
        raise Exception("failed to configure pip to use codeartifact")


def _authenticate_index_url(endpoint, repository, token):  # type: (str, str, str) -> str
    """
    Turn a repository endpoint into a pip index url carrying basic-auth credentials.

    Plain string operations are used instead of regular expressions so that
    special characters in the repository name (e.g. ".") or in the token are
    never interpreted as pattern or replacement syntax.

    :param endpoint: repository endpoint, e.g. `https://host/pypi/my_repo/`
    :param repository: repository name the endpoint is expected to end with
    :param token: CodeArtifact authorization token
    :return: `https://aws:<token>@host/pypi/my_repo/simple/`; the `simple/`
        suffix is only added when the endpoint ends with the repository name,
        and the credentials only when the endpoint starts with `https://`
    """
    scheme = "https://"
    index = endpoint
    with_slash = endpoint if endpoint.endswith("/") else endpoint + "/"
    if with_slash.endswith(repository + "/"):
        index = with_slash + "simple/"
    if index.startswith(scheme):
        index = "{}aws:{}@{}".format(scheme, token, index[len(scheme) :])
    return index
31 changes: 31 additions & 0 deletions test/integration/local/test_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,34 @@ def test_install_requirements(capsys):
assert "Installing collected packages: pyfiglet" in stdout
assert "Successfully installed pyfiglet-0.8.post1" in stdout
assert "Reporting training SUCCESS" in stdout


# def test_install_requirements_from_codeartifact(capsys):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will you enable this test later?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, for this to really be enabled, the build infrastructure for this package would need to actually have a CodeArtifact repo, which I believe it doesn't.
I've written the integ test, that can be enabled if the build infra has such access.

So, at the moment, the answer is no, I don't think I can enable this on my own.

# # TODO: fill in details for CA
# ca_domain = "..."
# ca_domain_owner = "..."
# ca_repository = "..."
# ca_region = "..."
# ca_repository_arn = "..."
#
# estimator = Estimator(
# image_uri="sagemaker-training-toolkit-test:dummy",
# # TODO: Grant the role permissions to access CodeArtifact repo (repo resource policy + role policy)
# # https://docs.aws.amazon.com/codeartifact/latest/ug/security-iam.html
# # https://docs.aws.amazon.com/codeartifact/latest/ug/repo-policies.html
# role="SageMakerRole",
# instance_count=1,
# instance_type="local",
# environment={
# "CA_REPOSITORY_ARN": ca_repository_arn,
# }
# )
#
# estimator.fit()
#
# stdout = capsys.readouterr().out
#
# assert "{}-{}.d.codeartifact.{}.amazonaws.com/pypi/{}/simple/".format(ca_domain, ca_domain_owner, ca_region, ca_repository) in stdout
# assert "Installing collected packages: pyfiglet" in stdout
# assert "Successfully installed pyfiglet-0.8.post1" in stdout
# assert "Reporting training SUCCESS" in stdout
54 changes: 53 additions & 1 deletion test/unit/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import sys
import textwrap

from mock import call, Mock, mock_open, patch
from mock import call, MagicMock, Mock, mock_open, patch
import pytest

from sagemaker_training import environment, errors, files, modules, params
Expand Down Expand Up @@ -90,6 +90,58 @@ def test_install_requirements(check_error):
)


def mock_codeartifact_client(client_name, **kwargs):
    """Replacement for ``boto3.client`` that returns a canned CodeArtifact client.

    The arguments are accepted only to mirror the ``boto3.client`` call and are
    ignored; the returned client answers the token and endpoint calls with
    fixed responses.
    """
    endpoint = "https://domain-012345678900.d.codeartifact.my-region.amazonaws.com/pypi/my_repo/"
    token_response = {"authorizationToken": "the-auth-token"}
    endpoint_response = {"repositoryEndpoint": endpoint}

    fake_client = MagicMock()
    fake_client.get_authorization_token = MagicMock(return_value=token_response)
    fake_client.get_repository_endpoint = MagicMock(return_value=endpoint_response)
    return fake_client


@patch("sagemaker_training.process.check_error", autospec=True)
@patch("boto3.client", new=mock_codeartifact_client)
def test_install_requirements_codeartifact(check_error):
    """With CA_REPOSITORY_ARN set, the pip command carries the authenticated index url."""
    module_dir = "c://sagemaker-pytorch-container"
    ca_environment = {
        "CA_REPOSITORY_ARN": "arn:aws:codeartifact:my-region:012345678900:repository/my_domain/my_repo"
    }
    authenticated_index = (
        "https://aws:the-auth-token@domain-012345678900.d.codeartifact.my-region.amazonaws.com"
        "/pypi/my_repo/simple/"
    )
    expected_cmd = [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]
    expected_cmd += ["-i", authenticated_index]

    # Only the CodeArtifact variable is visible while the install runs.
    with patch.dict(os.environ, ca_environment, clear=True), patch(
        "os.path.exists", return_value=True
    ):
        modules.install_requirements(module_dir)

    check_error.assert_called_with(
        expected_cmd, errors.InstallRequirementsError, 1, cwd=module_dir, capture_error=False
    )


@patch.dict(os.environ, {"CA_REPOSITORY_ARN": "invalid_arn"}, clear=True)
@patch("sagemaker_training.process.check_error", autospec=True)
def test_install_requirements_codeartifact_missing_environment_variables(check_error):
    """A malformed CA_REPOSITORY_ARN makes install_requirements raise an invalid-arn error."""
    module_dir = "c://sagemaker-pytorch-container"
    expected_message = "invalid CodeArtifact repository arn invalid_arn"

    with patch("os.path.exists", return_value=True), pytest.raises(Exception) as raised:
        modules.install_requirements(module_dir)

    assert expected_message in str(raised.value)


@patch("sagemaker_training.process.check_error", autospec=True)
def test_install_fails(check_error):
check_error.side_effect = errors.ClientError()
Expand Down