diff --git a/src/sagemaker/analytics.py b/src/sagemaker/analytics.py index 085f5d8375..d2a4fb78d2 100644 --- a/src/sagemaker/analytics.py +++ b/src/sagemaker/analytics.py @@ -29,6 +29,8 @@ # Any subsequent attempt to use pandas will raise the ImportError pd = DeferredError(e) +METRICS_PERIOD_DEFAULT = 60 # seconds + class AnalyticsMetricsBase(with_metaclass(ABCMeta, object)): """Base class for tuning job or training job analytics classes. @@ -201,7 +203,8 @@ class TrainingJobAnalytics(AnalyticsMetricsBase): CLOUDWATCH_NAMESPACE = '/aws/sagemaker/TrainingJobs' - def __init__(self, training_job_name, metric_names=None, sagemaker_session=None): + def __init__(self, training_job_name, metric_names=None, sagemaker_session=None, + start_time=None, end_time=None, period=None): """Initialize a ``TrainingJobAnalytics`` instance. Args: @@ -216,6 +219,10 @@ def __init__(self, training_job_name, metric_names=None, sagemaker_session=None) self._sage_client = sagemaker_session.sagemaker_client self._cloudwatch = sagemaker_session.boto_session.client('cloudwatch') self._training_job_name = training_job_name + self._start_time = start_time + self._end_time = end_time + self._period = period or METRICS_PERIOD_DEFAULT + if metric_names: self._metric_names = metric_names else: @@ -245,13 +252,15 @@ def _determine_timeinterval(self): covering the interval of the training job """ description = self._sage_client.describe_training_job(TrainingJobName=self.name) - start_time = description[u'TrainingStartTime'] # datetime object + start_time = self._start_time or description[u'TrainingStartTime'] # datetime object # Incrementing end time by 1 min since CloudWatch drops seconds before finding the logs. # This results in logs being searched in the time range in which the correct log line was not present. # Example - Log time - 2018-10-22 08:25:55 # Here calculated end time would also be 2018-10-22 08:25:55 (without 1 min addition) # CW will consider end time as 2018-10-22 08:25 and will not be able to search the correct log. - end_time = description.get(u'TrainingEndTime', datetime.datetime.utcnow()) + datetime.timedelta(minutes=1) + end_time = self._end_time or description.get( + u'TrainingEndTime', datetime.datetime.utcnow()) + datetime.timedelta(minutes=1) + return { 'start_time': start_time, 'end_time': end_time, @@ -276,7 +285,7 @@ def _fetch_metric(self, metric_name): ], 'StartTime': self._time_interval['start_time'], 'EndTime': self._time_interval['end_time'], - 'Period': 60, + 'Period': self._period, 'Statistics': ['Average'], } raw_cwm_data = self._cloudwatch.get_metric_statistics(**request)['Datapoints'] diff --git a/tests/integ/test_tf_script_mode.py b/tests/integ/test_tf_script_mode.py index 924dc3e308..32c99681f4 100644 --- a/tests/integ/test_tf_script_mode.py +++ b/tests/integ/test_tf_script_mode.py @@ -47,6 +47,7 @@ def test_mnist(sagemaker_session, instance_type): sagemaker_session=sagemaker_session, py_version='py3', framework_version=TensorFlow.LATEST_VERSION, + metric_definitions=[{'Name': 'train:global_steps', 'Regex': r'global_step\/sec:\s(.*)'}], base_job_name='test-tf-sm-mnist') inputs = estimator.sagemaker_session.upload_data( path=os.path.join(RESOURCE_PATH, 'data'), @@ -56,6 +57,9 @@ def test_mnist(sagemaker_session, instance_type): estimator.fit(inputs) _assert_s3_files_exist(estimator.model_dir, ['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta']) + df = estimator.training_job_analytics.dataframe() + print(df) + assert df.size > 0 def test_server_side_encryption(sagemaker_session): diff --git a/tests/unit/test_analytics.py b/tests/unit/test_analytics.py index 05eab2089c..df432a9dd5 100644 --- a/tests/unit/test_analytics.py +++ b/tests/unit/test_analytics.py @@ -245,3 +245,20 @@ def test_trainer_dataframe(): trainer.export_csv(tmp_name) assert os.path.isfile(tmp_name) os.unlink(tmp_name) + + +def test_start_time_end_time_and_period_specified(): + describe_training_result = { + 'TrainingStartTime': datetime.datetime(2018, 5, 16, 1, 2, 3), + 'TrainingEndTime': datetime.datetime(2018, 5, 16, 5, 6, 7), + } + session = create_sagemaker_session(describe_training_result) + start_time = datetime.datetime(2018, 5, 16, 1, 3, 4) + end_time = datetime.datetime(2018, 5, 16, 5, 1, 1) + period = 300 + trainer = TrainingJobAnalytics('my-training-job', ['metric'], + sagemaker_session=session, start_time=start_time, end_time=end_time, period=period) + + assert trainer._time_interval['start_time'] == start_time + assert trainer._time_interval['end_time'] == end_time + assert trainer._period == period