2424from sagemaker .job import _Job
2525from sagemaker .utils import base_name_from_image , name_from_base
2626from sagemaker .session import Session
27- from sagemaker .s3 import (
28- S3CompressionType ,
29- S3DataDistributionType ,
30- S3DataType ,
31- S3DownloadMode ,
32- S3InputMode ,
33- S3UploadMode ,
34- S3Uploader ,
35- )
27+ from sagemaker .s3 import S3Uploader
3628from sagemaker .network import NetworkConfig # noqa: F401 # pylint: disable=unused-import
3729
3830
@@ -46,7 +38,6 @@ def __init__(
4638 instance_count ,
4739 instance_type ,
4840 entrypoint = None ,
49- arguments = None ,
5041 volume_size_in_gb = 30 ,
5142 volume_kms_key = None ,
5243 max_runtime_in_seconds = 24 * 60 * 60 ,
@@ -72,8 +63,6 @@ def __init__(
7263 instance_type (str): Type of EC2 instance to use for
7364 processing, for example, 'ml.c4.xlarge'.
7465 entrypoint (str): The entrypoint for the processing job.
75- arguments ([str]): A list of string arguments to be passed to a
76- processing job.
7766 volume_size_in_gb (int): Size in GB of the EBS volume
7867 to use for storing data during processing (default: 30).
7968 volume_kms_key (str): A KMS key for the processing
@@ -99,7 +88,6 @@ def __init__(
9988 self .instance_count = instance_count
10089 self .instance_type = instance_type
10190 self .entrypoint = entrypoint
102- self .arguments = arguments
10391 self .volume_size_in_gb = volume_size_in_gb
10492 self .volume_kms_key = volume_kms_key
10593 self .max_runtime_in_seconds = max_runtime_in_seconds
@@ -112,8 +100,9 @@ def __init__(
112100 self .jobs = []
113101 self .latest_job = None
114102 self ._current_job_name = None
103+ self .arguments = None
115104
116- def run (self , inputs = None , outputs = None , wait = True , logs = True , job_name = None ):
105+ def run (self , inputs = None , outputs = None , arguments = None , wait = True , logs = True , job_name = None ):
117106 """Run a processing job.
118107
119108 Args:
@@ -122,6 +111,8 @@ def run(self, inputs=None, outputs=None, wait=True, logs=True, job_name=None):
122111 outputs ([sagemaker.processor.ProcessingOutput]): Outputs for the processing
123112 job. These can be specified as either a path string or a ProcessingOutput
124113 object.
114+ arguments ([str]): A list of string arguments to be passed to a
115+ processing job.
125116 wait (bool): Whether the call should wait until the job completes (default: True).
126117 logs (bool): Whether to show the logs produced by the job.
127118 Only meaningful when wait is True (default: True).
@@ -138,6 +129,7 @@ def run(self, inputs=None, outputs=None, wait=True, logs=True, job_name=None):
138129
139130 normalized_inputs = self ._normalize_inputs (inputs )
140131 normalized_outputs = self ._normalize_outputs (outputs )
132+ self .arguments = arguments
141133
142134 self .latest_job = ProcessingJob .start_new (self , normalized_inputs , normalized_outputs )
143135 self .jobs .append (self .latest_job )
@@ -243,7 +235,7 @@ def _normalize_outputs(self, outputs=None):
243235 return normalized_outputs
244236
245237
246- class ScriptModeProcessor (Processor ):
238+ class ScriptProcessor (Processor ):
247239 """Handles Amazon SageMaker processing tasks for jobs using a machine learning framework."""
248240
249241 def __init__ (
@@ -252,8 +244,6 @@ def __init__(
252244 image_uri ,
253245 instance_count ,
254246 instance_type ,
255- py_version = "py3" ,
256- arguments = None ,
257247 volume_size_in_gb = 30 ,
258248 volume_kms_key = None ,
259249 max_runtime_in_seconds = 24 * 60 * 60 ,
@@ -263,7 +253,7 @@ def __init__(
263253 tags = None ,
264254 network_config = None ,
265255 ):
266- """Initialize a ``ScriptModeProcessor `` instance. The ScriptModeProcessor
256+ """Initialize a ``ScriptProcessor `` instance. The ScriptProcessor
267257 handles Amazon SageMaker processing tasks for jobs using script mode.
268258
269259 Args:
@@ -279,8 +269,6 @@ def __init__(
279269 instance_type (str): Type of EC2 instance to use for
280270 processing, for example, 'ml.c4.xlarge'.
281271 py_version (str): The python version to use, for example, 'py3'.
282- arguments ([str]): A list of string arguments to be passed to a
283- processing job.
284272 volume_size_in_gb (int): Size in GB of the EBS volume
285273 to use for storing data during processing (default: 30).
286274 volume_kms_key (str): A KMS key for the processing
@@ -301,16 +289,14 @@ def __init__(
301289 object that configures network isolation, encryption of
302290 inter-container traffic, security group IDs, and subnets.
303291 """
304- self .py_version = py_version
305292 self ._CODE_CONTAINER_BASE_PATH = "/input/"
306293 self ._CODE_CONTAINER_INPUT_NAME = "code"
307294
308- super (ScriptModeProcessor , self ).__init__ (
295+ super (ScriptProcessor , self ).__init__ (
309296 role = role ,
310297 image_uri = image_uri ,
311298 instance_count = instance_count ,
312299 instance_type = instance_type ,
313- arguments = arguments ,
314300 volume_size_in_gb = volume_size_in_gb ,
315301 volume_kms_key = volume_kms_key ,
316302 max_runtime_in_seconds = max_runtime_in_seconds ,
@@ -322,11 +308,22 @@ def __init__(
322308 )
323309
324310 def run (
325- self , code , script_name = None , inputs = None , outputs = None , wait = True , logs = True , job_name = None
311+ self ,
312+ command ,
313+ code ,
314+ script_name = None ,
315+ inputs = None ,
316+ outputs = None ,
317+ arguments = None ,
318+ wait = True ,
319+ logs = True ,
320+ job_name = None ,
326321 ):
327322 """Run a processing job with Script Mode.
328323
329324 Args:
325+ command ([str]): This is a list of strings that includes the executable, along
326+ with any command-line flags. For example: ["python3", "-v"]
330327 code (str): This can be an S3 uri or a local path to either
331328 a directory or a file with the user's script to run.
332329 script_name (str): If the user provides a directory for source,
@@ -337,6 +334,8 @@ def run(
337334 outputs ([str or sagemaker.processor.ProcessingOutput]): Outputs for the processing
338335 job. These can be specified as either a path string or a ProcessingOutput
339336 object.
337+ arguments ([str]): A list of string arguments to be passed to a
338+ processing job.
340339 wait (bool): Whether the call should wait until the job completes (default: True).
341340 logs (bool): Whether to show the logs produced by the job.
342341 Only meaningful when wait is True (default: True).
@@ -349,10 +348,15 @@ def run(
349348 customer_code_s3_uri = self ._upload_code (code )
350349 inputs_with_code = self ._convert_code_and_add_to_inputs (inputs , customer_code_s3_uri )
351350
352- self ._set_entrypoint (customer_script_name )
351+ self ._set_entrypoint (command , customer_script_name )
353352
354- super (ScriptModeProcessor , self ).run (
355- inputs = inputs_with_code , outputs = outputs , wait = wait , logs = logs , job_name = job_name
353+ super (ScriptProcessor , self ).run (
354+ inputs = inputs_with_code ,
355+ outputs = outputs ,
356+ arguments = arguments ,
357+ wait = wait ,
358+ logs = logs ,
359+ job_name = job_name ,
356360 )
357361
358362 def _get_customer_script_name (self , code , script_name ):
@@ -418,43 +422,16 @@ def _convert_code_and_add_to_inputs(self, inputs, s3_uri):
418422 the ProcessingInput object created from s3_uri appended to the list.
419423
420424 """
421- input_list = inputs
422425 code_file_input = ProcessingInput (
423426 source = s3_uri ,
424427 destination = os .path .join (
425428 self ._CODE_CONTAINER_BASE_PATH , self ._CODE_CONTAINER_INPUT_NAME
426429 ),
427430 input_name = self ._CODE_CONTAINER_INPUT_NAME ,
428431 )
429- input_list .append (code_file_input )
430- return input_list
431-
432- def _get_execution_program (self , script_name ):
433- """Determine which executable to run the user's script with
434- based on the file extension.
432+ return inputs + [code_file_input ]
435433
436- Args:
437- script_name (str): A filename with an extension.
438-
439- Returns:
440- str: A name of an executable to run the user's script with.
441- """
442- file_extension = os .path .splitext (script_name )[1 ]
443- if file_extension == ".py" :
444- if self .py_version == "py3" :
445- return "python3"
446- if self .py_version == "py2" :
447- return "python2"
448- return "python"
449- if file_extension == ".sh" :
450- return "bash"
451- raise ValueError (
452- """Script Mode supports Python or Bash scripts.
453- To use a custom entrypoint, please use Processor.
454- """
455- )
456-
457- def _set_entrypoint (self , customer_script_name ):
434+ def _set_entrypoint (self , command , customer_script_name ):
458435 """Sets the entrypoint based on the customer's script and corresponding executable.
459436
460437 Args:
@@ -463,8 +440,7 @@ def _set_entrypoint(self, customer_script_name):
463440 customer_script_location = os .path .join (
464441 self ._CODE_CONTAINER_BASE_PATH , self ._CODE_CONTAINER_INPUT_NAME , customer_script_name
465442 )
466- execution_program = self ._get_execution_program (customer_script_name )
467- self .entrypoint = [execution_program , customer_script_location ]
443+ self .entrypoint = command + [customer_script_location ]
468444
469445
470446class ProcessingJob (_Job ):
@@ -564,11 +540,11 @@ def __init__(
564540 source ,
565541 destination ,
566542 input_name = None ,
567- s3_data_type = S3DataType . MANIFEST_FILE ,
568- s3_input_mode = S3InputMode . FILE ,
569- s3_download_mode = S3DownloadMode . CONTINUOUS ,
570- s3_data_distribution_type = S3DataDistributionType . FULLY_REPLICATED ,
571- s3_compression_type = S3CompressionType . NONE ,
543+ s3_data_type = "ManifestFile" ,
544+ s3_input_mode = "File" ,
545+ s3_download_mode = "Continuous" ,
546+ s3_data_distribution_type = "FullyReplicated" ,
547+ s3_compression_type = "None" ,
572548 ):
573549 """Initialize a ``ProcessingInput`` instance. ProcessingInput accepts parameters
574550 that specify an S3 input for a processing job and provides a method
@@ -579,11 +555,12 @@ def __init__(
579555 destination (str): The destination of the input.
580556 input_name (str): The user-provided name for the input. If a name
581557 is not provided, one will be generated.
582- s3_data_type (sagemaker.s3.S3DataType):
583- s3_input_mode (sagemaker.s3.S3InputMode):
584- s3_download_mode (sagemaker.s3.S3DownloadMode):
585- s3_data_distribution_type (sagemaker.s3.S3DataDistributionType):
586- s3_compression_type (sagemaker.s3.S3CompressionType):
558+ s3_data_type (str): Valid options are "ManifestFile" or "S3Prefix".
559+ s3_input_mode (str): Valid options are "Pipe" or "File".
560+ s3_download_mode (str): Valid options are "StartOfJob" or "Continuous".
561+ s3_data_distribution_type (str): Valid options are "FullyReplicated"
562+ or "ShardedByS3Key".
563+ s3_compression_type (str): Valid options are "None" or "Gzip".
587564 """
588565 self .source = source
589566 self .destination = destination
@@ -602,21 +579,18 @@ def to_request_dict(self):
602579 "S3Input" : {
603580 "S3Uri" : self .source ,
604581 "LocalPath" : self .destination ,
605- "S3DataType" : self .s3_data_type . value ,
606- "S3InputMode" : self .s3_input_mode . value ,
607- "S3DownloadMode" : self .s3_download_mode . value ,
608- "S3DataDistributionType" : self .s3_data_distribution_type . value ,
582+ "S3DataType" : self .s3_data_type ,
583+ "S3InputMode" : self .s3_input_mode ,
584+ "S3DownloadMode" : self .s3_download_mode ,
585+ "S3DataDistributionType" : self .s3_data_distribution_type ,
609586 },
610587 }
611588
612589 # Check the compression type, then add it to the dictionary.
613- if (
614- self .s3_compression_type == S3CompressionType .GZIP
615- and self .s3_input_mode != S3InputMode .PIPE
616- ):
590+ if self .s3_compression_type == "Gzip" and self .s3_input_mode != "Pipe" :
617591 raise ValueError ("Data can only be gzipped when the input mode is Pipe." )
618592 if self .s3_compression_type is not None :
619- s3_input_request ["S3Input" ]["S3CompressionType" ] = self .s3_compression_type . value
593+ s3_input_request ["S3Input" ]["S3CompressionType" ] = self .s3_compression_type
620594
621595 # Return the request dictionary.
622596 return s3_input_request
@@ -627,12 +601,7 @@ class ProcessingOutput(object):
627601 a method to turn those parameters into a dictionary."""
628602
629603 def __init__ (
630- self ,
631- source ,
632- destination ,
633- output_name = None ,
634- kms_key_id = None ,
635- s3_upload_mode = S3UploadMode .CONTINUOUS ,
604+ self , source , destination , output_name = None , kms_key_id = None , s3_upload_mode = "Continuous"
636605 ):
637606 """Initialize a ``ProcessingOutput`` instance. ProcessingOutput accepts parameters that
638607 specify an S3 output for a processing job and provides a method to turn
@@ -643,7 +612,7 @@ def __init__(
643612 destination (str): The destination of the output.
644613 output_name (str): The name of the output.
645614 kms_key_id (str): The KMS key id for the output.
646- s3_upload_mode (sagemaker.s3.S3UploadMode):
615+ s3_upload_mode (str): Valid options are "EndOfJob" or "Continuous".
647616 """
648617 self .source = source
649618 self .destination = destination
@@ -659,7 +628,7 @@ def to_request_dict(self):
659628 "S3Output" : {
660629 "S3Uri" : self .destination ,
661630 "LocalPath" : self .source ,
662- "S3UploadMode" : self .s3_upload_mode . value ,
631+ "S3UploadMode" : self .s3_upload_mode ,
663632 },
664633 }
665634
0 commit comments