Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix, repack=True)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
346 changes: 232 additions & 114 deletions src/sagemaker/model.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix, self._is_mms_version())
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix, repack=self._is_mms_version())
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation())
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/workflow/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def prepare_framework_container_def(model, instance_type, s3_operations):
]

deploy_env = dict(model.env)
deploy_env.update(model._framework_env_vars())
deploy_env.update(model._script_mode_env_vars())

try:
if model.model_server_workers:
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/xgboost/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation())
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
99 changes: 99 additions & 0 deletions tests/unit/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import


import pytest
from mock import MagicMock, Mock, patch
from sagemaker.model import FrameworkModel, Model


# Shared argument values for the model tests in this module.
ENTRY_POINT_INFERENCE = "inference.py"
REGION = "us-west-2"
# Value returned by the patched ``time.strftime`` in the test below.
TIMESTAMP = "2017-11-06-14:14:15.671"
# Value returned by the mocked session's ``default_bucket``.
BUCKET_NAME = "mybucket"
# Endpoint sizing passed to ``deploy``.
INSTANCE_COUNT = 1
INSTANCE_TYPE = "ml.p2.xlarge"
ROLE = "DummyRole"
# Passed as ``source_dir`` when constructing the models under test.
SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz"
IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
MODEL_DATA = "s3://someprefix2/models/model.tar.gz"


class DummyFrameworkModel(FrameworkModel):
    """Bare ``FrameworkModel`` subclass used as the framework-side model in tests."""

    def __init__(self, **kwargs):
        """Forward every keyword argument unchanged to ``FrameworkModel``."""
        super().__init__(**kwargs)


@pytest.fixture()
def sagemaker_session():
    """Return a ``MagicMock`` standing in for a SageMaker session.

    The mock reports ``REGION`` as its region, is not in local mode, has no
    config or S3 client/resource, and returns ``BUCKET_NAME`` from
    ``default_bucket()``.
    """
    session = MagicMock(
        name="sagemaker_session",
        boto_session=Mock(name="boto_session", region_name=REGION),
        boto_region_name=REGION,
        config=None,
        local_mode=False,
        s3_client=None,
        s3_resource=None,
    )
    session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
    return session


@patch("time.strftime", MagicMock(return_value=TIMESTAMP))
@patch("sagemaker.utils.repack_model")
def test_script_mode_model_same_calls_as_framework(repack_model, sagemaker_session):
    """Check that ``Model`` and a ``FrameworkModel`` subclass deploy identically.

    Both models are built from the same arguments and deployed against the
    same mocked session; the calls recorded on ``create_model``,
    ``endpoint_from_production_variants`` and ``repack_model`` must match.
    """
    model_kwargs = {
        "entry_point": ENTRY_POINT_INFERENCE,
        "role": ROLE,
        "sagemaker_session": sagemaker_session,
        "source_dir": SCRIPT_URI,
        "image_uri": IMAGE_URI,
        "model_data": MODEL_DATA,
    }
    deploy_kwargs = {
        "instance_type": INSTANCE_TYPE,
        "initial_instance_count": INSTANCE_COUNT,
    }

    # First pass: the generic, script-mode Model.
    generic_model = Model(**model_kwargs)
    generic_model.deploy(**deploy_kwargs)

    assert len(sagemaker_session.create_model.call_args_list) == 1
    assert len(sagemaker_session.endpoint_from_production_variants.call_args_list) == 1
    assert len(repack_model.call_args_list) == 1

    # Capture the recorded calls before the mocks are reset so they can be
    # compared against the second deployment.
    expected_create_model_calls = sagemaker_session.create_model.call_args_list
    expected_endpoint_calls = (
        sagemaker_session.endpoint_from_production_variants.call_args_list
    )
    expected_repack_calls = repack_model.call_args_list

    sagemaker_session.create_model.reset_mock()
    sagemaker_session.endpoint_from_production_variants.reset_mock()
    repack_model.reset_mock()

    # Second pass: the FrameworkModel subclass with identical arguments.
    framework_model = DummyFrameworkModel(**model_kwargs)
    framework_model.deploy(**deploy_kwargs)

    assert expected_create_model_calls == sagemaker_session.create_model.call_args_list
    assert (
        expected_endpoint_calls
        == sagemaker_session.endpoint_from_production_variants.call_args_list
    )
    assert expected_repack_calls == repack_model.call_args_list