2323from sagemaker .estimator import Framework
2424import sagemaker .fw_utils as fw
2525from sagemaker .tensorflow import defaults
26- from sagemaker .tensorflow .model import TensorFlowModel
2726from sagemaker .tensorflow .serving import Model
2827from sagemaker .transformer import Transformer
2928from sagemaker .vpc_utils import VPC_CONFIG_DEFAULT
@@ -252,10 +251,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
252251
253252 def create_model (
254253 self ,
255- model_server_workers = None ,
256254 role = None ,
257255 vpc_config_override = VPC_CONFIG_DEFAULT ,
258- endpoint_type = None ,
259256 entry_point = None ,
260257 source_dir = None ,
261258 dependencies = None ,
@@ -266,43 +263,25 @@ def create_model(
266263
267264 Args:
268265 role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
269- used during transform jobs. If not specified, the role from the Estimator will be
270- used.
271- model_server_workers (int): Optional. The number of worker processes used by the
272- inference server. If None, server will use one worker per vCPU.
266+ used during transform jobs. If not specified, the role from the Estimator is used.
273267 vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on the
274- model.
275- Default: use subnets and security groups from this Estimator.
268+ model. Default: use subnets and security groups from this Estimator.
269+
276270 * 'Subnets' (list[str]): List of subnet ids.
277271 * 'SecurityGroupIds' (list[str]): List of security group ids.
278- endpoint_type (str): Optional. Selects the software stack used by the inference server.
279- If not specified, the model will be configured to use the default
280- SageMaker model server. If 'tensorflow-serving', the model will be configured to
281- use the SageMaker Tensorflow Serving container.
272+
282273 entry_point (str): Path (absolute or relative) to the local Python source file which
283- should be executed as the entry point to training. If not specified and
284- ``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
285- ``endpoint_type`` is also ``None``, then the training entry point is used.
274+ should be executed as the entry point to training (default: None).
286275 source_dir (str): Path (absolute or relative) to a directory with any other serving
287- source code dependencies aside from the entry point file. If not specified and
288- ``endpoint_type`` is 'tensorflow-serving', no source_dir is used. If
289- ``endpoint_type`` is also ``None``, then the model source directory from training
290- is used.
276+ source code dependencies aside from the entry point file (default: None).
291277 dependencies (list[str]): A list of paths to directories (absolute or relative) with
292- any additional libraries that will be exported to the container.
293- If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is
294- set to ``None``.
295- If ``endpoint_type`` is also ``None``, then the dependencies from training are used.
296- **kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`
297- and :class:`~sagemaker.tensorflow.model.TensorFlowModel` constructors.
278+ any additional libraries that will be exported to the container (default: None).
279+ **kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`.
298280
299281 Returns:
300- sagemaker.tensorflow.model.TensorFlowModel or sagemaker.tensorflow.serving.Model: A
301- ``Model`` object. See :class:`~sagemaker.tensorflow.serving.Model` or
302- :class:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
282+ sagemaker.tensorflow.serving.Model: A ``Model`` object.
283+ See :class:`~sagemaker.tensorflow.serving.Model` for full details.
303284 """
304- role = role or self .role
305-
306285 if "image" not in kwargs :
307286 kwargs ["image" ] = self .image_name
308287
@@ -312,41 +291,11 @@ def create_model(
312291 if "enable_network_isolation" not in kwargs :
313292 kwargs ["enable_network_isolation" ] = self .enable_network_isolation ()
314293
315- if endpoint_type == "tensorflow-serving" or self ._script_mode_enabled :
316- return self ._create_tfs_model (
317- role = role ,
318- vpc_config_override = vpc_config_override ,
319- entry_point = entry_point ,
320- source_dir = source_dir ,
321- dependencies = dependencies ,
322- ** kwargs
323- )
324-
325- return self ._create_default_model (
326- model_server_workers = model_server_workers ,
327- role = role ,
328- vpc_config_override = vpc_config_override ,
329- entry_point = entry_point ,
330- source_dir = source_dir ,
331- dependencies = dependencies ,
332- ** kwargs
333- )
334-
335- def _create_tfs_model (
336- self ,
337- role = None ,
338- vpc_config_override = VPC_CONFIG_DEFAULT ,
339- entry_point = None ,
340- source_dir = None ,
341- dependencies = None ,
342- ** kwargs
343- ):
344- """Placeholder docstring"""
345294 return Model (
346295 model_data = self .model_data ,
347- role = role ,
296+ role = role or self . role ,
348297 container_log_level = self .container_log_level ,
349- framework_version = utils . get_short_version ( self .framework_version ) ,
298+ framework_version = self .framework_version ,
350299 sagemaker_session = self .sagemaker_session ,
351300 vpc_config = self .get_vpc_config (vpc_config_override ),
352301 entry_point = entry_point ,
@@ -355,34 +304,6 @@ def _create_tfs_model(
355304 ** kwargs
356305 )
357306
358- def _create_default_model (
359- self ,
360- model_server_workers ,
361- role ,
362- vpc_config_override ,
363- entry_point = None ,
364- source_dir = None ,
365- dependencies = None ,
366- ** kwargs
367- ):
368- """Placeholder docstring"""
369- return TensorFlowModel (
370- self .model_data ,
371- role ,
372- entry_point or self .entry_point ,
373- source_dir = source_dir or self ._model_source_dir (),
374- enable_cloudwatch_metrics = self .enable_cloudwatch_metrics ,
375- container_log_level = self .container_log_level ,
376- code_location = self .code_location ,
377- py_version = self .py_version ,
378- framework_version = self .framework_version ,
379- model_server_workers = model_server_workers ,
380- sagemaker_session = self .sagemaker_session ,
381- vpc_config = self .get_vpc_config (vpc_config_override ),
382- dependencies = dependencies or self .dependencies ,
383- ** kwargs
384- )
385-
386307 def hyperparameters (self ):
387308 """Return hyperparameters used by your custom TensorFlow code during model training."""
388309 hyperparameters = super (TensorFlow , self ).hyperparameters ()
@@ -479,9 +400,7 @@ def transformer(
479400 max_payload = None ,
480401 tags = None ,
481402 role = None ,
482- model_server_workers = None ,
483403 volume_kms_key = None ,
484- endpoint_type = None ,
485404 entry_point = None ,
486405 vpc_config_override = VPC_CONFIG_DEFAULT ,
487406 enable_network_isolation = None ,
@@ -515,15 +434,8 @@ def transformer(
515434 role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
516435 used during transform jobs. If not specified, the role from the Estimator will be
517436 used.
518- model_server_workers (int): Optional. The number of worker processes used by the
519- inference server. If None, server will use one worker per vCPU.
520437 volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
521438 compute instance (default: None).
522- endpoint_type (str): Optional. Selects the software stack used by the inference server.
523- If not specified, the model will be configured to use the default
524- SageMaker model server.
525- If 'tensorflow-serving', the model will be configured to
526- use the SageMaker Tensorflow Serving container.
527439 entry_point (str): Path (absolute or relative) to the local Python source file which
528440 should be executed as the entry point to training. If not specified and
529441 ``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
@@ -575,10 +487,8 @@ def transformer(
575487 enable_network_isolation = self .enable_network_isolation ()
576488
577489 model = self .create_model (
578- model_server_workers = model_server_workers ,
579490 role = role ,
580491 vpc_config_override = vpc_config_override ,
581- endpoint_type = endpoint_type ,
582492 entry_point = entry_point ,
583493 enable_network_isolation = enable_network_isolation ,
584494 name = model_name ,
0 commit comments