Skip to content

Commit 3cc40a3

Browse files
authored
Wait timeout exception handling (#5400)
1 parent 36f96be commit 3cc40a3

File tree

5 files changed

+28
-9
lines changed

5 files changed

+28
-9
lines changed

sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _calculate_transition_duration(trans) -> Tuple[str, str]:
157157
def wait(
158158
training_job: TrainingJob,
159159
poll: int = 5,
160-
timeout: Optional[int] = None
160+
timeout: Optional[int] = 3000
161161
) -> None:
162162
"""Wait for training job to complete with progress tracking.
163163
@@ -192,8 +192,10 @@ def wait(
192192
iteration = 0
193193
while True:
194194
iteration += 1
195-
time.sleep(poll)
196-
training_job.refresh()
195+
time.sleep(1)
196+
if iteration == poll:
197+
training_job.refresh()
198+
iteration = 0
197199
clear_output(wait=True)
198200

199201
status = training_job.training_job_status
@@ -302,7 +304,7 @@ def wait(
302304
raise FailedStatusError(resource_type="TrainingJob", status=status, reason=failure_reason)
303305

304306
if timeout and elapsed >= timeout:
305-
raise TimeoutExceededError(resouce_type="TrainingJob", status=status)
307+
raise TimeoutExceededError(resource_type="TrainingJob", status=status)
306308

307309
else:
308310
print(f"\nTrainingJob Name: {training_job.training_job_name}")
@@ -363,7 +365,7 @@ def wait(
363365
raise FailedStatusError(resource_type="TrainingJob", status=status, reason=failure_reason)
364366

365367
if timeout and elapsed >= timeout:
366-
raise TimeoutExceededError(resouce_type="TrainingJob", status=status)
368+
raise TimeoutExceededError(resource_type="TrainingJob", status=status)
367369

368370

369371
except (FailedStatusError, TimeoutExceededError):

sagemaker-train/src/sagemaker/train/dpo_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,11 @@ def train(self,
261261

262262
if wait:
263263
from sagemaker.train.common_utils.trainer_wait import wait as _wait
264-
_wait(training_job)
264+
from sagemaker.core.utils.exceptions import TimeoutExceededError
265+
try :
266+
_wait(training_job)
267+
except TimeoutExceededError as e:
268+
logger.error("Error: %s", e)
265269

266270
self.latest_training_job = training_job
267271
return training_job

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,11 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
268268

269269
if wait:
270270
from sagemaker.train.common_utils.trainer_wait import wait as _wait
271-
_wait(training_job)
271+
from sagemaker.core.utils.exceptions import TimeoutExceededError
272+
try :
273+
_wait(training_job)
274+
except TimeoutExceededError as e:
275+
logger.error("Error: %s", e)
272276

273277
self.latest_training_job = training_job
274278
return training_job

sagemaker-train/src/sagemaker/train/rlvr_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,11 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
268268

269269
if wait:
270270
from sagemaker.train.common_utils.trainer_wait import wait as _wait
271-
_wait(training_job)
271+
from sagemaker.core.utils.exceptions import TimeoutExceededError
272+
try:
273+
_wait(training_job)
274+
except TimeoutExceededError as e:
275+
logger.error("Error: %s", e)
272276

273277
self.latest_training_job = training_job
274278
return training_job

sagemaker-train/src/sagemaker/train/sft_trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from logging import exception
12
from typing import Optional, Union
23
import logging
34
from sagemaker.train.base_trainer import BaseTrainer
@@ -261,7 +262,11 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
261262

262263
if wait:
263264
from sagemaker.train.common_utils.trainer_wait import wait as _wait
264-
_wait(training_job)
265+
from sagemaker.core.utils.exceptions import TimeoutExceededError
266+
try :
267+
_wait(training_job)
268+
except TimeoutExceededError as e:
269+
logger.error("Error: %s", e)
265270

266271
self.latest_training_job = training_job
267272
return training_job

0 commit comments

Comments
 (0)