@@ -98,6 +98,7 @@ def __init__(
9898 rules = None ,
9999 debugger_hook_config = None ,
100100 tensorboard_output_config = None ,
101+ enable_sagemaker_metrics = None ,
101102 ):
102103 """Initialize an ``EstimatorBase`` instance.
103104
@@ -195,6 +196,10 @@ def __init__(
195196 started. If the path is unset then SageMaker assumes the
196197 checkpoints will be provided under `/opt/ml/checkpoints/`.
197198 (default: ``None``).
199+ enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
200+ Series. For more information see:
201+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
202+ (default: ``None``).
198203 """
199204 self .role = role
200205 self .train_instance_count = train_instance_count
@@ -250,6 +255,8 @@ def __init__(
250255 self .debugger_rule_configs = None
251256 self .collection_configs = None
252257
258+ self .enable_sagemaker_metrics = enable_sagemaker_metrics
259+
253260 @abstractmethod
254261 def train_image (self ):
255262 """Return the Docker image to use for training.
@@ -958,6 +965,9 @@ def start_new(cls, estimator, inputs):
958965
959966 cls ._add_spot_checkpoint_args (local_mode , estimator , train_args )
960967
968+ if estimator .enable_sagemaker_metrics is not None :
969+ train_args ["enable_sagemaker_metrics" ] = estimator .enable_sagemaker_metrics
970+
961971 estimator .sagemaker_session .train (** train_args )
962972
963973 return cls (estimator .sagemaker_session , estimator ._current_job_name )
@@ -1060,6 +1070,7 @@ def __init__(
10601070 rules = None ,
10611071 debugger_hook_config = None ,
10621072 tensorboard_output_config = None ,
1073+ enable_sagemaker_metrics = None ,
10631074 ):
10641075 """Initialize an ``Estimator`` instance.
10651076
@@ -1171,6 +1182,10 @@ def __init__(
11711182 user entry script for training. The user entry script, files in
11721183 source_dir (if specified), and dependencies will be uploaded in
11731184 a tar to S3. Also known as internet-free mode (default: ``False``).
1185+ enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
1186+ Series. For more information see:
1187+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
1188+ (default: ``None``).
11741189 """
11751190 self .image_name = image_name
11761191 self .hyperparam_dict = hyperparameters .copy () if hyperparameters else {}
@@ -1201,6 +1216,7 @@ def __init__(
12011216 rules = rules ,
12021217 debugger_hook_config = debugger_hook_config ,
12031218 tensorboard_output_config = tensorboard_output_config ,
1219+ enable_sagemaker_metrics = enable_sagemaker_metrics ,
12041220 )
12051221
12061222 def enable_network_isolation (self ):
@@ -1354,6 +1370,7 @@ def __init__(
13541370 git_config = None ,
13551371 checkpoint_s3_uri = None ,
13561372 checkpoint_local_path = None ,
1373+ enable_sagemaker_metrics = None ,
13571374 ** kwargs
13581375 ):
13591376 """Base class initializer. Subclasses which override ``__init__`` should
@@ -1500,6 +1517,10 @@ def __init__(
15001517 started. If the path is unset then SageMaker assumes the
15011518 checkpoints will be provided under `/opt/ml/checkpoints/`.
15021519 (default: ``None``).
1520+ enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
1521+ Series. For more information see:
1522+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
1523+ (default: ``None``).
15031524 **kwargs: Additional kwargs passed to the ``EstimatorBase``
15041525 constructor.
15051526 """
@@ -1530,6 +1551,7 @@ def __init__(
15301551 self ._hyperparameters = hyperparameters or {}
15311552 self .checkpoint_s3_uri = checkpoint_s3_uri
15321553 self .checkpoint_local_path = checkpoint_local_path
1554+ self .enable_sagemaker_metrics = enable_sagemaker_metrics
15331555
15341556 def enable_network_isolation (self ):
15351557 """Return True if this Estimator can use network isolation to run.
0 commit comments