Skip to content
This repository was archived by the owner on Nov 20, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def read_version():

packages = setuptools.find_packages(where="src", exclude=("test",))

required_packages = ["numpy", "six", "psutil", "retrying>=1.3.3,<1.4", "scipy"]
required_packages = ["boto3", "numpy", "six", "psutil", "retrying>=1.3.3,<1.4", "scipy"]

# enum is introduced in Python 3.4. Installing enum back port
if sys.version_info < (3, 4):
Expand Down
57 changes: 56 additions & 1 deletion src/sagemaker_inference/model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
from __future__ import absolute_import

import os
import re
import signal
import subprocess
import sys

import boto3
import pkg_resources
import psutil
from retrying import retry
Expand Down Expand Up @@ -199,14 +201,67 @@ def _terminate(signo, frame): # pylint: disable=unused-argument
def _install_requirements():
logger.info("installing packages from requirements.txt...")
pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_PATH]

if os.getenv("CA_REPOSITORY_ARN"):
index = _get_codeartifact_index()
pip_install_cmd.append("-i")
pip_install_cmd.append(index)
try:
subprocess.check_call(pip_install_cmd)
except subprocess.CalledProcessError:
logger.error("failed to install required packages, exiting")
raise ValueError("failed to install required packages")


def _get_codeartifact_index():
"""
Build the authenticated codeartifact index url
https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html
https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies
:return: authenticated codeartifact index url
"""
repository_arn = os.getenv("CA_REPOSITORY_ARN")
arn_regex = (
"arn:(?P<partition>[^:]+):codeartifact:(?P<region>[^:]+):(?P<account>[^:]+)"
":repository/(?P<domain>[^/]+)/(?P<repository>.+)"
)
m = re.match(arn_regex, repository_arn)
if not m:
raise Exception("invalid CodeArtifact repository arn {}".format(repository_arn))
domain = m.group("domain")
owner = m.group("account")
repository = m.group("repository")
region = m.group("region")

logger.info(
"configuring pip to use codeartifact "
"(domain: %s, domain owner: %s, repository: %s, region: %s)",
domain,
owner,
repository,
region,
)
try:
client = boto3.client("codeartifact", region_name=region)
auth_token_response = client.get_authorization_token(domain=domain, domainOwner=owner)
token = auth_token_response["authorizationToken"]
endpoint_response = client.get_repository_endpoint(
domain=domain, domainOwner=owner, repository=repository, format="pypi"
)
unauthenticated_index = endpoint_response["repositoryEndpoint"]
return re.sub(
"https://",
"https://aws:{}@".format(token),
re.sub(
"{}/?$".format(repository),
"{}/simple/".format(repository),
unauthenticated_index,
),
)
except Exception:
logger.error("failed to configure pip to use codeartifact")
raise Exception("failed to configure pip to use codeartifact")


def _retry_retrieve_mms_server_process(startup_timeout):
retrieve_mms_server_process = retry(wait_fixed=1000, stop_max_delay=startup_timeout * 1000)(
_retrieve_mms_server_process
Expand Down
58 changes: 57 additions & 1 deletion test/unit/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
import os
import signal
import subprocess
import sys
import types

from mock import ANY, Mock, patch
import botocore.session
from botocore.stub import Stubber
from mock import ANY, MagicMock, Mock, patch
import pytest

from sagemaker_inference import environment, model_server
Expand Down Expand Up @@ -224,6 +227,16 @@ def test_add_sigterm_handler(signal_call):
def test_install_requirements(check_call):
model_server._install_requirements()

install_cmd = [
sys.executable,
"-m",
"pip",
"install",
"-r",
"/opt/ml/model/code/requirements.txt",
]
check_call.assert_called_once_with(install_cmd)


@patch("subprocess.check_call", side_effect=subprocess.CalledProcessError(0, "cmd"))
def test_install_requirements_installation_failed(check_call):
Expand All @@ -233,6 +246,49 @@ def test_install_requirements_installation_failed(check_call):
assert "failed to install required packages" in str(e.value)


@patch.dict(os.environ, {"CA_REPOSITORY_ARN": "invalid_arn"}, clear=True)
def test_install_requirements_codeartifact_invalid_arn_installation_failed():
with pytest.raises(Exception) as e:
model_server._install_requirements()

assert "invalid CodeArtifact repository arn invalid_arn" in str(e.value)


@patch("subprocess.check_call")
@patch.dict(
os.environ,
{
"CA_REPOSITORY_ARN": "arn:aws:codeartifact:my_region:012345678900:repository/my_domain/my_repo"
},
clear=True,
)
def test_install_requirements_codeartifact(check_call):
# mock/stub codeartifact client and its responses
endpoint = "https://domain-012345678900.d.codeartifact.region.amazonaws.com/pypi/my_repo/"
codeartifact = botocore.session.get_session().create_client(
"codeartifact", region_name="myregion"
)
stubber = Stubber(codeartifact)
stubber.add_response("get_authorization_token", {"authorizationToken": "the-auth-token"})
stubber.add_response("get_repository_endpoint", {"repositoryEndpoint": endpoint})
stubber.activate()

with patch("boto3.client", MagicMock(return_value=codeartifact)):
model_server._install_requirements()

install_cmd = [
sys.executable,
"-m",
"pip",
"install",
"-r",
"/opt/ml/model/code/requirements.txt",
"-i",
"https://aws:the-auth-token@domain-012345678900.d.codeartifact.region.amazonaws.com/pypi/my_repo/simple/",
]
check_call.assert_called_once_with(install_cmd)


@patch("psutil.process_iter")
def test_retrieve_mms_server_process(process_iter):
server = Mock()
Expand Down