2121import os
2222import pathlib
2323import logging
24+ from textwrap import dedent
2425from typing import Dict , List , Optional , Tuple
2526import attr
2627
@@ -1217,18 +1218,7 @@ class FeatureStoreOutput(ApiObject):
12171218class FrameworkProcessor (ScriptProcessor ):
12181219 """Handles Amazon SageMaker processing tasks for jobs using a machine learning framework."""
12191220
1220- runproc_sh = """#!/bin/bash
1221-
1222- cd /opt/ml/processing/input/code/
1223- tar -xzf sourcedir.tar.gz
1224-
1225- # Exit on any error. SageMaker uses error code to mark failed job.
1226- set -e
1227-
1228- [[ -f 'requirements.txt' ]] && pip install -r requirements.txt
1229-
1230- python {entry_point} "$@"
1231- """
1221+ framework_entrypoint_command = ["/bin/bash" ]
12321222
12331223 # Added new (kw)args for estimator. The rest are from ScriptProcessor with same defaults.
12341224 def __init__ (
@@ -1240,6 +1230,7 @@ def __init__(
12401230 instance_type ,
12411231 py_version = "py3" , # New kwarg
12421232 image_uri = None ,
1233+ command = ["python" ],
12431234 volume_size_in_gb = 30 ,
12441235 volume_kms_key = None ,
12451236 output_kms_key = None ,
@@ -1272,6 +1263,8 @@ def __init__(
12721263 is ignored when ``image_uri`` is provided.
12731264 image_uri (str): The URI of the Docker image to use for the
12741265 processing jobs (default: None).
1266+ command ([str]): The command to run, along with any command-line flags
1267+ to *precede* the ```entry_point script``` (default: ['python']).
12751268 volume_size_in_gb (int): Size in GB of the EBS volume
12761269 to use for storing data during processing (default: 30).
12771270 volume_kms_key (str): A KMS key for the processing volume (default: None).
@@ -1312,7 +1305,7 @@ def __init__(
13121305 super ().__init__ (
13131306 role = role ,
13141307 image_uri = image_uri ,
1315- command = [ "/bin/bash" ] ,
1308+ command = command ,
13161309 instance_count = instance_count ,
13171310 instance_type = instance_type ,
13181311 volume_size_in_gb = volume_size_in_gb ,
@@ -1493,7 +1486,7 @@ def run( # type: ignore[override]
14931486 )
14941487 script = estimator .uploaded_code .script_name
14951488 s3_runproc_sh = S3Uploader .upload_string_as_file_body (
1496- self .runproc_sh . format ( entry_point = script ),
1489+ self ._generate_framework_script ( script ),
14971490 desired_s3_uri = entrypoint_s3_uri ,
14981491 sagemaker_session = self .sagemaker_session ,
14991492 )
@@ -1512,6 +1505,35 @@ def run( # type: ignore[override]
15121505 kms_key = kms_key ,
15131506 )
15141507
1508+ def _generate_framework_script (self , user_script : str ) -> str :
1509+ """Generate the framework entrypoint file (as text) for a processing job.
1510+
1511+ This script implements the "framework" functionality for setting up your code:
1512+ Untar-ing the sourcedir bundle in the ```code``` input; installing extra
1513+ runtime dependencies if specified; and then invoking the ```command``` and
1514+ ```entry_point``` configured for the job.
1515+
1516+ Args:
1517+ user_script (str): Relative path to ```entry_point``` in the source bundle
1518+ - e.g. 'process.py'.
1519+ """
1520+ return dedent ("""\
1521+ #!/bin/bash
1522+
1523+ cd /opt/ml/processing/input/code/
1524+ tar -xzf sourcedir.tar.gz
1525+
1526+ # Exit on any error. SageMaker uses error code to mark failed job.
1527+ set -e
1528+
1529+ [[ -f 'requirements.txt' ]] && pip install -r requirements.txt
1530+
1531+ {entry_point_command} {entry_point} "$@"
1532+ """ ).format (
1533+ entry_point_command = " " .join (self .command ),
1534+ entry_point = user_script ,
1535+ )
1536+
15151537 def _upload_payload (
15161538 self ,
15171539 entry_point : str ,
@@ -1575,3 +1597,18 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput
15751597 )
15761598 )
15771599 return inputs
1600+
1601+ def _set_entrypoint (self , command , user_script_name ):
1602+ """FrameworkProcessor override for setting processing job entrypoint.
1603+
1604+ Args:
1605+ command ([str]): Ignored in favor of self.framework_entrypoint_command
1606+ user_script_name (str): A filename with an extension.
1607+ """
1608+
1609+ user_script_location = str (
1610+ pathlib .PurePosixPath (
1611+ self ._CODE_CONTAINER_BASE_PATH , self ._CODE_CONTAINER_INPUT_NAME , user_script_name
1612+ )
1613+ )
1614+ self .entrypoint = self .framework_entrypoint_command + [user_script_location ]
0 commit comments