diff --git a/setup.py b/setup.py
index 8610a97..6b174cd 100644
--- a/setup.py
+++ b/setup.py
@@ -29,7 +29,7 @@ def read_version():
 
 packages = setuptools.find_packages(where="src", exclude=("test",))
 
-required_packages = ["numpy", "six", "psutil", "retrying>=1.3.3,<1.4", "scipy"]
+required_packages = ["boto3", "numpy", "six", "psutil", "retrying>=1.3.3,<1.4", "scipy"]
 
 # enum is introduced in Python 3.4. Installing enum back port
 if sys.version_info < (3, 4):
diff --git a/src/sagemaker_inference/model_server.py b/src/sagemaker_inference/model_server.py
index e529749..b34d5d5 100644
--- a/src/sagemaker_inference/model_server.py
+++ b/src/sagemaker_inference/model_server.py
@@ -15,10 +15,12 @@
 from __future__ import absolute_import
 
 import os
+import re
 import signal
 import subprocess
 import sys
 
+import boto3
 import pkg_resources
 import psutil
 from retrying import retry
@@ -199,7 +201,10 @@ def _terminate(signo, frame):  # pylint: disable=unused-argument
 def _install_requirements():
     logger.info("installing packages from requirements.txt...")
     pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_PATH]
-
+    if os.getenv("CA_REPOSITORY_ARN"):
+        index = _get_codeartifact_index()
+        pip_install_cmd.append("-i")
+        pip_install_cmd.append(index)
     try:
         subprocess.check_call(pip_install_cmd)
     except subprocess.CalledProcessError:
@@ -207,6 +212,56 @@ def _install_requirements():
         raise ValueError("failed to install required packages")
 
 
+def _get_codeartifact_index():
+    """
+    Build the authenticated codeartifact index url
+    https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html
+    https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies
+    :return: authenticated codeartifact index url
+    """
+    repository_arn = os.getenv("CA_REPOSITORY_ARN")
+    arn_regex = (
+        "arn:(?P<partition>[^:]+):codeartifact:(?P<region>[^:]+):(?P<account>[^:]+)"
+        ":repository/(?P<domain>[^/]+)/(?P<repository>.+)"
+    )
+    m = re.match(arn_regex, repository_arn)
+    if not m:
+        raise Exception("invalid CodeArtifact repository arn {}".format(repository_arn))
+    domain = m.group("domain")
+    owner = m.group("account")
+    repository = m.group("repository")
+    region = m.group("region")
+
+    logger.info(
+        "configuring pip to use codeartifact "
+        "(domain: %s, domain owner: %s, repository: %s, region: %s)",
+        domain,
+        owner,
+        repository,
+        region,
+    )
+    try:
+        client = boto3.client("codeartifact", region_name=region)
+        auth_token_response = client.get_authorization_token(domain=domain, domainOwner=owner)
+        token = auth_token_response["authorizationToken"]
+        endpoint_response = client.get_repository_endpoint(
+            domain=domain, domainOwner=owner, repository=repository, format="pypi"
+        )
+        unauthenticated_index = endpoint_response["repositoryEndpoint"]
+        return re.sub(
+            "https://",
+            "https://aws:{}@".format(token),
+            re.sub(
+                "{}/?$".format(repository),
+                "{}/simple/".format(repository),
+                unauthenticated_index,
+            ),
+        )
+    except Exception:
+        logger.error("failed to configure pip to use codeartifact")
+        raise Exception("failed to configure pip to use codeartifact")
+
+
 def _retry_retrieve_mms_server_process(startup_timeout):
     retrieve_mms_server_process = retry(wait_fixed=1000, stop_max_delay=startup_timeout * 1000)(
         _retrieve_mms_server_process
diff --git a/test/unit/test_model_server.py b/test/unit/test_model_server.py
index 76247ed..37301e5 100644
--- a/test/unit/test_model_server.py
+++ b/test/unit/test_model_server.py
@@ -13,9 +13,12 @@
 import os
 import signal
 import subprocess
+import sys
 import types
 
-from mock import ANY, Mock, patch
+import botocore.session
+from botocore.stub import Stubber
+from mock import ANY, MagicMock, Mock, patch
 import pytest
 
 from sagemaker_inference import environment, model_server
@@ -224,6 +227,16 @@ def test_add_sigterm_handler(signal_call):
 def test_install_requirements(check_call):
     model_server._install_requirements()
 
+    install_cmd = [
+        sys.executable,
+        "-m",
+        "pip",
+        "install",
+        "-r",
+        "/opt/ml/model/code/requirements.txt",
+    ]
+    check_call.assert_called_once_with(install_cmd)
+
 
 @patch("subprocess.check_call", side_effect=subprocess.CalledProcessError(0, "cmd"))
 def test_install_requirements_installation_failed(check_call):
@@ -233,6 +246,49 @@ def test_install_requirements_installation_failed(check_call):
     assert "failed to install required packages" in str(e.value)
 
 
+@patch.dict(os.environ, {"CA_REPOSITORY_ARN": "invalid_arn"}, clear=True)
+def test_install_requirements_codeartifact_invalid_arn_installation_failed():
+    with pytest.raises(Exception) as e:
+        model_server._install_requirements()
+
+    assert "invalid CodeArtifact repository arn invalid_arn" in str(e.value)
+
+
+@patch("subprocess.check_call")
+@patch.dict(
+    os.environ,
+    {
+        "CA_REPOSITORY_ARN": "arn:aws:codeartifact:my_region:012345678900:repository/my_domain/my_repo"
+    },
+    clear=True,
+)
+def test_install_requirements_codeartifact(check_call):
+    # mock/stub codeartifact client and its responses
+    endpoint = "https://domain-012345678900.d.codeartifact.region.amazonaws.com/pypi/my_repo/"
+    codeartifact = botocore.session.get_session().create_client(
+        "codeartifact", region_name="myregion"
+    )
+    stubber = Stubber(codeartifact)
+    stubber.add_response("get_authorization_token", {"authorizationToken": "the-auth-token"})
+    stubber.add_response("get_repository_endpoint", {"repositoryEndpoint": endpoint})
+    stubber.activate()
+
+    with patch("boto3.client", MagicMock(return_value=codeartifact)):
+        model_server._install_requirements()
+
+    install_cmd = [
+        sys.executable,
+        "-m",
+        "pip",
+        "install",
+        "-r",
+        "/opt/ml/model/code/requirements.txt",
+        "-i",
+        "https://aws:the-auth-token@domain-012345678900.d.codeartifact.region.amazonaws.com/pypi/my_repo/simple/",
+    ]
+    check_call.assert_called_once_with(install_cmd)
+
+
 @patch("psutil.process_iter")
 def test_retrieve_mms_server_process(process_iter):
     server = Mock()