Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,
job_name=init_params['base_job_name'])
estimator._current_job_name = estimator.latest_training_job.name
estimator.latest_training_job.wait()
return estimator

Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ def test_attach_framework(sagemaker_session):
return_value=returned_job_description)

framework_estimator = DummyFramework.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
assert framework_estimator._current_job_name == 'neo'
assert framework_estimator.latest_training_job.job_name == 'neo'
assert framework_estimator.role == 'arn:aws:iam::366:role/SageMakerRole'
assert framework_estimator.train_instance_count == 1
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def test_attach(sagemaker_session, sklearn_version):
return_value=returned_job_description)

estimator = SKLearn.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
assert estimator._current_job_name == 'neo'
assert estimator.latest_training_job.job_name == 'neo'
assert estimator.py_version == PYTHON_VERSION
assert estimator.framework_version == sklearn_version
Expand Down