Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class constructor
init_params[attribute] = init_params["hyperparameters"][value.name]

del init_params["hyperparameters"]
del init_params["image"]
del init_params["image_uri"]
return init_params

def prepare_workflow_for_training(self, records=None, mini_batch_size=None, job_name=None):
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/amazon/factorization_machines.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,9 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
"""
sagemaker_session = sagemaker_session or Session()
repo = "{}:{}".format(FactorizationMachines.repo_name, FactorizationMachines.repo_version)
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
image_uri = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
super(FactorizationMachinesModel, self).__init__(
image,
image_uri,
model_data,
role,
predictor_cls=FactorizationMachinesPredictor,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/amazon/ipinsights.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,12 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
"""
sagemaker_session = sagemaker_session or Session()
repo = "{}:{}".format(IPInsights.repo_name, IPInsights.repo_version)
image = "{}/{}".format(
image_uri = "{}/{}".format(
registry(sagemaker_session.boto_session.region_name, IPInsights.repo_name), repo
)

super(IPInsightsModel, self).__init__(
image,
image_uri,
model_data,
role,
predictor_cls=IPInsightsPredictor,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/amazon/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,9 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
"""
sagemaker_session = sagemaker_session or Session()
repo = "{}:{}".format(KMeans.repo_name, KMeans.repo_version)
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
image_uri = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
super(KMeansModel, self).__init__(
image,
image_uri,
model_data,
role,
predictor_cls=KMeansPredictor,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/amazon/lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,11 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
"""
sagemaker_session = sagemaker_session or Session()
repo = "{}:{}".format(LDA.repo_name, LDA.repo_version)
image = "{}/{}".format(
image_uri = "{}/{}".format(
registry(sagemaker_session.boto_session.region_name, LDA.repo_name), repo
)
super(LDAModel, self).__init__(
image,
image_uri,
model_data,
role,
predictor_cls=LDAPredictor,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/amazon/linear_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,9 +476,9 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
"""
sagemaker_session = sagemaker_session or Session()
repo = "{}:{}".format(LinearLearner.repo_name, LinearLearner.repo_version)
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
image_uri = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
super(LinearLearnerModel, self).__init__(
image,
image_uri,
model_data,
role,
predictor_cls=LinearLearnerPredictor,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/amazon/ntm.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,11 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
"""
sagemaker_session = sagemaker_session or Session()
repo = "{}:{}".format(NTM.repo_name, NTM.repo_version)
image = "{}/{}".format(
image_uri = "{}/{}".format(
registry(sagemaker_session.boto_session.region_name, NTM.repo_name), repo
)
super(NTMModel, self).__init__(
image,
image_uri,
model_data,
role,
predictor_cls=NTMPredictor,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/amazon/object2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,11 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
"""
sagemaker_session = sagemaker_session or Session()
repo = "{}:{}".format(Object2Vec.repo_name, Object2Vec.repo_version)
image = "{}/{}".format(
image_uri = "{}/{}".format(
registry(sagemaker_session.boto_session.region_name, Object2Vec.repo_name), repo
)
super(Object2VecModel, self).__init__(
image,
image_uri,
model_data,
role,
predictor_cls=Predictor,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/amazon/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,9 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
"""
sagemaker_session = sagemaker_session or Session()
repo = "{}:{}".format(PCA.repo_name, PCA.repo_version)
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
image_uri = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
super(PCAModel, self).__init__(
image,
image_uri,
model_data,
role,
predictor_cls=PCAPredictor,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/amazon/randomcutforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,11 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
"""
sagemaker_session = sagemaker_session or Session()
repo = "{}:{}".format(RandomCutForest.repo_name, RandomCutForest.repo_version)
image = "{}/{}".format(
image_uri = "{}/{}".format(
registry(sagemaker_session.boto_session.region_name, RandomCutForest.repo_name), repo
)
super(RandomCutForestModel, self).__init__(
image,
image_uri,
model_data,
role,
predictor_cls=RandomCutForestPredictor,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/automl/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,12 @@ def create_model(
models = []

for container in inference_containers:
image = container["Image"]
image_uri = container["Image"]
model_data = container["ModelDataUrl"]
env = container["Environment"]

model = Model(
image=image,
image_uri=image_uri,
model_data=model_data,
role=self.role,
env=env,
Expand Down
37 changes: 19 additions & 18 deletions src/sagemaker/automl/candidate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,24 +211,25 @@ def _get_train_args(
Returns (dcit): a dictionary that can be used as args of
sagemaker_session.train method.
"""
train_args = {}
train_args["input_config"] = inputs
train_args["job_name"] = name
train_args["input_mode"] = desc["AlgorithmSpecification"]["TrainingInputMode"]
train_args["role"] = desc["RoleArn"]
train_args["output_config"] = desc["OutputDataConfig"]
train_args["resource_config"] = desc["ResourceConfig"]
train_args["image"] = desc["AlgorithmSpecification"]["TrainingImage"]
train_args["enable_network_isolation"] = desc["EnableNetworkIsolation"]
train_args["encrypt_inter_container_traffic"] = encrypt_inter_container_traffic
train_args["train_use_spot_instances"] = desc["EnableManagedSpotTraining"]
train_args["hyperparameters"] = {}
train_args["stop_condition"] = {}
train_args["metric_definitions"] = None
train_args["checkpoint_s3_uri"] = None
train_args["checkpoint_local_path"] = None
train_args["tags"] = []
train_args["vpc_config"] = None
train_args = {
"input_config": inputs,
"job_name": name,
"input_mode": desc["AlgorithmSpecification"]["TrainingInputMode"],
"role": desc["RoleArn"],
"output_config": desc["OutputDataConfig"],
"resource_config": desc["ResourceConfig"],
"image_uri": desc["AlgorithmSpecification"]["TrainingImage"],
"enable_network_isolation": desc["EnableNetworkIsolation"],
"encrypt_inter_container_traffic": encrypt_inter_container_traffic,
"train_use_spot_instances": desc["EnableManagedSpotTraining"],
"hyperparameters": {},
"stop_condition": {},
"metric_definitions": None,
"checkpoint_s3_uri": None,
"checkpoint_local_path": None,
"tags": [],
"vpc_config": None,
}

if volume_kms_key is not None:
train_args["resource_config"]["VolumeKmsKeyId"] = volume_kms_key
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ def create_model(
"""
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))

if "image" not in kwargs:
kwargs["image"] = self.image_uri
if "image_uri" not in kwargs:
kwargs["image_uri"] = self.image_uri

return ChainerModel(
self.model_data,
Expand Down Expand Up @@ -257,7 +257,7 @@ class constructor
if value:
init_params[argument[len("sagemaker_") :]] = value

image_uri = init_params.pop("image")
image_uri = init_params.pop("image_uri")
framework, py_version, tag, _ = framework_name_from_image(image_uri)

if tag is None:
Expand Down
16 changes: 8 additions & 8 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
model_data,
role,
entry_point,
image=None,
image_uri=None,
framework_version=None,
py_version=None,
predictor_cls=ChainerPredictor,
Expand All @@ -85,16 +85,16 @@ def __init__(
file which should be executed as the entry point to model
hosting. If ``source_dir`` is specified, then ``entry_point``
must point to a file located at the root of ``source_dir``.
image (str): A Docker image URI (default: None). If not specified, a
image_uri (str): A Docker image URI (default: None). If not specified, a
default image for Chainer will be used. If ``framework_version``
or ``py_version`` are ``None``, then ``image`` is required. If
or ``py_version`` are ``None``, then ``image_uri`` is required. If
also ``None``, then a ``ValueError`` will be raised.
framework_version (str): Chainer version you want to use for
executing your model training code. Defaults to ``None``. Required
unless ``image`` is provided.
unless ``image_uri`` is provided.
py_version (str): Python version you want to use for executing your
model training code. Defaults to ``None``. Required unless
``image`` is provided.
``image_uri`` is provided.
predictor_cls (callable[str, sagemaker.session.Session]): A function
to call to create a predictor with an endpoint name and
SageMaker ``Session``. If specified, ``deploy()`` returns the
Expand All @@ -111,7 +111,7 @@ def __init__(
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
validate_version_or_image_args(framework_version, py_version, image)
validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
Expand All @@ -120,7 +120,7 @@ def __init__(
self.py_version = py_version

super(ChainerModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
model_data, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
)

self.model_server_workers = model_server_workers
Expand All @@ -140,7 +140,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
dict[str, str]: A container definition object usable with the
CreateModel API.
"""
deploy_image = self.image
deploy_image = self.image_uri
if not deploy_image:
if instance_type is None:
raise ValueError(
Expand Down
31 changes: 5 additions & 26 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ class constructor
if "AlgorithmName" in job_details["AlgorithmSpecification"]:
init_params["algorithm_arn"] = job_details["AlgorithmSpecification"]["AlgorithmName"]
elif "TrainingImage" in job_details["AlgorithmSpecification"]:
init_params["image"] = job_details["AlgorithmSpecification"]["TrainingImage"]
init_params["image_uri"] = job_details["AlgorithmSpecification"]["TrainingImage"]
else:
raise RuntimeError(
"Invalid AlgorithmSpecification. Either TrainingImage or "
Expand Down Expand Up @@ -1037,7 +1037,7 @@ def start_new(cls, estimator, inputs, experiment_config):
if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
train_args["algorithm_arn"] = estimator.algorithm_arn
else:
train_args["image"] = estimator.train_image()
train_args["image_uri"] = estimator.train_image()

if estimator.debugger_rule_configs:
train_args["debugger_rule_configs"] = estimator.debugger_rule_configs
Expand Down Expand Up @@ -1331,7 +1331,7 @@ def hyperparameters(self):
def create_model(
self,
role=None,
image=None,
image_uri=None,
predictor_cls=None,
serializer=None,
deserializer=None,
Expand All @@ -1350,7 +1350,7 @@ def create_model(
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
which is also used during transform jobs. If not specified, the
role from the Estimator will be used.
image (str): An container image to use for deploying the model.
image_uri (str): A Docker image URI to use for deploying the model.
Defaults to the image used for training.
predictor_cls (Predictor): The predictor class to use when
deploying the model.
Expand Down Expand Up @@ -1393,7 +1393,7 @@ def predict_wrapper(endpoint, session):
kwargs["enable_network_isolation"] = self.enable_network_isolation()

return Model(
image or self.train_image(),
image_uri or self.train_image(),
self.model_data,
role,
vpc_config=self.get_vpc_config(vpc_config_override),
Expand All @@ -1402,27 +1402,6 @@ def predict_wrapper(endpoint, session):
**kwargs
)

@classmethod
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
"""Convert the job description to init params that can be handled by the
class constructor

Args:
job_details: the returned job details from a describe_training_job
API call.
model_channel_name (str): Name of the channel where pre-trained
model data will be downloaded

Returns:
dictionary: The transformed init_params
"""
init_params = super(Estimator, cls)._prepare_init_params_from_job_description(
job_details, model_channel_name
)

init_params["image_uri"] = init_params.pop("image")
return init_params


class Framework(EstimatorBase):
"""Base class that cannot be instantiated directly.
Expand Down
Loading