Skip to content
15 changes: 11 additions & 4 deletions src/sagemaker/analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ class TrainingJobAnalytics(AnalyticsMetricsBase):

CLOUDWATCH_NAMESPACE = '/aws/sagemaker/TrainingJobs'

def __init__(self, training_job_name, metric_names=None, sagemaker_session=None):
def __init__(self, training_job_name, metric_names=None, sagemaker_session=None,
start_time=None, end_time=None, period=None):
"""Initialize a ``TrainingJobAnalytics`` instance.

Args:
Expand All @@ -216,6 +217,10 @@ def __init__(self, training_job_name, metric_names=None, sagemaker_session=None)
self._sage_client = sagemaker_session.sagemaker_client
self._cloudwatch = sagemaker_session.boto_session.client('cloudwatch')
self._training_job_name = training_job_name
self._start_time = start_time
self._end_time = end_time
self._period = period if period else 60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make the default period (60) a named constant instead of an inline literal?


if metric_names:
self._metric_names = metric_names
else:
Expand Down Expand Up @@ -245,13 +250,15 @@ def _determine_timeinterval(self):
covering the interval of the training job
"""
description = self._sage_client.describe_training_job(TrainingJobName=self.name)
start_time = description[u'TrainingStartTime'] # datetime object
start_time = self._start_time if self._start_time else description[u'TrainingStartTime'] # datetime object
# Incrementing end time by 1 min since CloudWatch drops seconds before finding the logs.
# This results in logs being searched in the time range in which the correct log line was not present.
# Example - Log time - 2018-10-22 08:25:55
# Here calculated end time would also be 2018-10-22 08:25:55 (without 1 min addition)
# CW will consider end time as 2018-10-22 08:25 and will not be able to search the correct log.
end_time = description.get(u'TrainingEndTime', datetime.datetime.utcnow()) + datetime.timedelta(minutes=1)
end_time = self._end_time if self._end_time else description.get(
u'TrainingEndTime', datetime.datetime.utcnow()) + datetime.timedelta(minutes=1)

return {
'start_time': start_time,
'end_time': end_time,
Expand All @@ -276,7 +283,7 @@ def _fetch_metric(self, metric_name):
],
'StartTime': self._time_interval['start_time'],
'EndTime': self._time_interval['end_time'],
'Period': 60,
'Period': self._period,
'Statistics': ['Average'],
}
raw_cwm_data = self._cloudwatch.get_metric_statistics(**request)['Datapoints']
Expand Down
4 changes: 4 additions & 0 deletions tests/integ/test_tf_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_mnist(sagemaker_session, instance_type):
sagemaker_session=sagemaker_session,
py_version='py3',
framework_version=TensorFlow.LATEST_VERSION,
metric_definitions=[{'Name': 'train:global_steps', 'Regex': 'global_step\/sec:\s(.*)'}],
base_job_name='test-tf-sm-mnist')
inputs = estimator.sagemaker_session.upload_data(
path=os.path.join(RESOURCE_PATH, 'data'),
Expand All @@ -56,6 +57,9 @@ def test_mnist(sagemaker_session, instance_type):
estimator.fit(inputs)
_assert_s3_files_exist(estimator.model_dir,
['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta'])
df = estimator.training_job_analytics.dataframe()
print(df)
assert df.size > 0


def test_server_side_encryption(sagemaker_session):
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,20 @@ def test_trainer_dataframe():
trainer.export_csv(tmp_name)
assert os.path.isfile(tmp_name)
os.unlink(tmp_name)


def test_start_time_end_time_and_period_specified():
    """Check that caller-supplied ``start_time``, ``end_time`` and ``period`` are honored.

    The described training job reports different start and end times than the
    ones passed to the constructor, so the interval assertions only hold if the
    explicit values are used.
    """
    describe_training_result = {
        'TrainingStartTime': datetime.datetime(2018, 5, 16, 1, 2, 3),
        'TrainingEndTime': datetime.datetime(2018, 5, 16, 5, 6, 7),
    }
    session = create_sagemaker_session(describe_training_result)
    start_time = datetime.datetime(2018, 5, 16, 1, 3, 4)
    end_time = datetime.datetime(2018, 5, 16, 5, 1, 1)
    period = 300
    # Single-quoted literals, matching the quoting used by the surrounding tests.
    trainer = TrainingJobAnalytics(
        'my-training-job', ['metric'], sagemaker_session=session,
        start_time=start_time, end_time=end_time, period=period)

    assert trainer._time_interval['start_time'] == start_time
    assert trainer._time_interval['end_time'] == end_time
    assert trainer._period == period