Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ CHANGELOG
* enhancement: Frameworks: update warning for not setting framework_version as we aren't planning a breaking change anymore
* enhancement: Session: remove hardcoded 'training' from job status error message
* bug-fix: Updated Cloudwatch namespace for metrics in TrainingJobsAnalytics

* bug-fix: Changes to use correct s3 bucket and time range for dataframes in TrainingJobAnalytics.
* enhancement: Remove MetricDefinition lookup via tuning job in TrainingJobAnalytics

1.14.1
======
Expand Down
15 changes: 3 additions & 12 deletions src/sagemaker/analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,18 +310,9 @@ def _add_single_metric(self, timestamp, metric_name, value):
def _metric_names_for_training_job(self):
"""Helper method to discover the metrics defined for a training job.
"""
# First look up the tuning job
training_description = self._sage_client.describe_training_job(TrainingJobName=self._training_job_name)
tuning_job_arn = training_description.get('TuningJobArn', None)
if not tuning_job_arn:
raise ValueError(
"No metrics available. Training Job Analytics only available through Hyperparameter Tuning Jobs"
)
tuning_job_name = extract_name_from_job_arn(tuning_job_arn)
tuning_job_description = self._sage_client.describe_hyper_parameter_tuning_job(
HyperParameterTuningJobName=tuning_job_name
)
training_job_definition = tuning_job_description['TrainingJobDefinition']
metric_definitions = training_job_definition['AlgorithmSpecification']['MetricDefinitions']

metric_definitions = training_description['AlgorithmSpecification']['MetricDefinitions']
metric_names = [md['Name'] for md in metric_definitions]

return metric_names
35 changes: 14 additions & 21 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,28 +824,21 @@ def test_generic_training_job_analytics(sagemaker_session):
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value={
'TuningJobArn': 'arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/mock-tuner',
'TrainingStartTime': 1530562991.299,
})
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
name='describe_hyper_parameter_tuning_job',
return_value={
'TrainingJobDefinition': {
"AlgorithmSpecification": {
"TrainingImage": "some-image-url",
"TrainingInputMode": "File",
"MetricDefinitions": [
{
"Name": "train:loss",
"Regex": "train_loss=([0-9]+\\.[0-9]+)"
},
{
"Name": "validation:loss",
"Regex": "valid_loss=([0-9]+\\.[0-9]+)"
}
]
"AlgorithmSpecification": {
"TrainingImage": "some-image-url",
"TrainingInputMode": "File",
"MetricDefinitions": [
{
"Name": "train:loss",
"Regex": "train_loss=([0-9]+\\.[0-9]+)"
},
{
"Name": "validation:loss",
"Regex": "valid_loss=([0-9]+\\.[0-9]+)"
}
}
}
)
]
},
})

e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
sagemaker_session=sagemaker_session)
Expand Down