Show training error msg #495

Merged (3 commits) on Dec 16, 2020
4 changes: 2 additions & 2 deletions ramp-engine/ramp_engine/aws/worker.py
@@ -166,7 +166,7 @@ def collect_results(self):
except Exception as e:
logger.error("Error occurred when downloading the logs"
f" from the submission: {e}")
exit_status = 1
exit_status = 2
error_msg = str(e)
self.status = 'error'
if exit_status == 0:
@@ -189,7 +189,7 @@ def collect_results(self):
error_msg = _get_traceback(
aws._get_log_content(self.config, self.submission))
self.status = 'collected'
exit_status, error_msg = 1, ""
exit_status = 1
logger.info(repr(self))
return exit_status, error_msg
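
The two hunks above change the return contract of collect_results: a failure to
download the logs is now reported with exit status 2 instead of 1, and a genuine
training error keeps the traceback in error_msg instead of wiping it with ''.
Below is a minimal sketch of how a caller might interpret the returned pair; the
worker object and the helper name are hypothetical, and the snippet only
illustrates the convention, it is not code from the repository.

def report_collect_results(worker):
    # Illustrative helper only; exit codes follow the convention above:
    #   0 -> training succeeded
    #   1 -> training failed; error_msg holds the traceback from the log
    #   2 -> the log could not be downloaded; error_msg is str(exception)
    exit_status, error_msg = worker.collect_results()
    if exit_status == 0:
        print("submission trained successfully")
    elif exit_status == 2:
        print(f"could not retrieve the training log: {error_msg}")
    else:
        # Before this change, collect_results reset error_msg to "" here,
        # so the Python error raised by the submission was lost.
        print(f"training error:\n{error_msg}")
    return exit_status, error_msg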

10 changes: 7 additions & 3 deletions ramp-engine/ramp_engine/dispatcher.py
@@ -212,7 +212,7 @@ def collect_result(self, session):
for worker, (submission_id, submission_name) in zip(workers,
submissions):
dt = worker.time_since_last_status_check()
if dt is not None and dt < self.time_between_collection:
if (dt is not None) and (dt < self.time_between_collection):
self._processing_worker_queue.put_nowait(
(worker, (submission_id, submission_name)))
time.sleep(0)
@@ -231,20 +231,24 @@ def collect_result(self, session):
else:
self._logger.info(f'Collecting results from worker {worker}')
returncode, stderr = worker.collect_results()

if returncode:
if returncode == 124:
self._logger.info(
f'Worker {worker} killed due to timeout.'
)
submission_status = 'checking_error'
elif returncode == 2:
# Error occurred when downloading the logs
submission_status = 'checking_error'
else:
self._logger.info(
f'Worker {worker} killed due to an error '
f'during training: {stderr}'
)
submission_status = 'training_error'
submission_status = 'training_error'
Collaborator


Ok so basically the issue was that all checking_error were set as training_error?

Contributor Author


No, that was actually another issue. The issue here was that the error message was being set to ''.

else:
submission_status = 'tested'

set_submission_state(
session, submission_id, submission_status
)
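
Condensing the branch shown above, the dispatcher now maps the worker's return
code to a submission state roughly as follows. This is a sketch for readability
only, not the actual implementation (which also logs the timeout and the stderr
from training):

def submission_status_from(returncode):
    # Illustrative summary of the branching in collect_result above.
    if returncode == 0:
        return 'tested'
    if returncode == 124:
        # worker killed due to timeout
        return 'checking_error'
    if returncode == 2:
        # error occurred when downloading the logs
        return 'checking_error'
    # any other non-zero code: error during training, stderr is kept
    return 'training_error'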
3 changes: 2 additions & 1 deletion ramp-engine/ramp_engine/tests/test_aws.py
@@ -105,8 +105,9 @@ class DummyInstance:
exit_status, error_msg = worker.collect_results()
assert 'Error occurred when downloading the logs' in caplog.text
assert 'Trying to download the log once again' in caplog.text
assert exit_status == 1
assert exit_status == 2
assert 'test' in error_msg
assert worker.status == 'error'


@mock.patch('ramp_engine.aws.api._rsync')
82 changes: 78 additions & 4 deletions ramp-engine/ramp_engine/tests/test_dispatcher.py
@@ -1,7 +1,7 @@
import shutil
import os

import pytest
import shutil
from unittest import mock

from ramp_utils import read_config
from ramp_utils.testing import database_config_template
@@ -228,7 +228,6 @@ def test_dispatcher_worker_retry(session_toy):

while not dispatcher._processing_worker_queue.empty():
dispatcher.collect_result(session_toy)

submissions = get_submissions(session_toy, 'iris_test', 'new')
assert submission_name in [sub[1] for sub in submissions]

@@ -253,7 +252,82 @@ def test_dispatcher_aws_not_launching(session_toy_aws, caplog):
assert 'training' not in caplog.text
num_running_workers = dispatcher._processing_worker_queue.qsize()
assert num_running_workers == 0

submissions2 = get_submissions(session_toy_aws, 'iris_aws_test', 'new')
# assert that all the submissions are still in the 'new' state
assert len(submissions) == len(submissions2)


@mock.patch('ramp_engine.aws.api.download_log')
@mock.patch('ramp_engine.aws.api.check_instance_status')
@mock.patch('ramp_engine.aws.api._get_log_content')
@mock.patch('ramp_engine.aws.api._training_successful')
@mock.patch('ramp_engine.aws.api._training_finished')
@mock.patch('ramp_engine.aws.api.is_spot_terminated')
@mock.patch('ramp_engine.aws.api.launch_train')
@mock.patch('ramp_engine.aws.api.upload_submission')
@mock.patch('ramp_engine.aws.api.launch_ec2_instances')
def test_info_on_training_error(test_launch_ec2_instances, upload_submission,
launch_train,
is_spot_terminated, training_finished,
training_successful,
get_log_content, check_instance_status,
download_log,
session_toy_aws,
caplog):
# make sure that the Python error from the solution is passed to the
# dispatcher
# everything should be mocked as correct output from AWS instances
# on setting up the instance and loading the submission
# mock dummy AWS instance
class DummyInstance:
id = 1
test_launch_ec2_instances.return_value = (DummyInstance(),), 0
upload_submission.return_value = 0
launch_train.return_value = 0
is_spot_terminated.return_value = 0
training_finished.return_value = False
download_log.return_value = 0

config = read_config(database_config_template())
event_config = read_config(ramp_aws_config_template())

dispatcher = Dispatcher(config=config,
event_config=event_config,
worker=AWSWorker, n_workers=10,
hunger_policy='exit')
dispatcher.fetch_from_db(session_toy_aws)
dispatcher.launch_workers(session_toy_aws)
num_running_workers = dispatcher._processing_worker_queue.qsize()
# worker, (submission_id, submission_name) = \
# dispatcher._processing_worker_queue.get()
# assert worker.status == 'running'
submissions = get_submissions(session_toy_aws,
'iris_aws_test',
'training')
ids = [submissions[idx][0] for idx in range(len(submissions))]
assert len(submissions) > 1
assert num_running_workers == len(ids)

dispatcher.time_between_collection = 0
training_successful.return_value = False

# now we will end the submission with training error
training_finished.return_value = True
training_error_msg = 'Python error here'
get_log_content.return_value = training_error_msg
check_instance_status.return_value = 'finished'

dispatcher.collect_result(session_toy_aws)

# the worker which we were using should have been torn down
num_running_workers = dispatcher._processing_worker_queue.qsize()

assert num_running_workers == 0

submissions = get_submissions(session_toy_aws,
'iris_aws_test',
'training_error')
assert len(submissions) == len(ids)

submission = get_submission_by_id(session_toy_aws, submissions[0][0])
assert training_error_msg in submission.error_msg
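
A side note on the decorator stack in this new test: unittest.mock applies
stacked mock.patch decorators bottom-up, so the patch closest to the function
(launch_ec2_instances) arrives as the first mock argument and the topmost one
(download_log) as the last, which is why the parameter list reads in the reverse
order of the decorators. A small self-contained sketch of the same pattern,
using arbitrary standard-library targets rather than the ramp_engine ones:

from unittest import mock

@mock.patch('os.path.exists')  # outermost decorator -> last mock argument
@mock.patch('os.getcwd')       # innermost decorator -> first mock argument
def check_patch_order(mock_getcwd, mock_path_exists):
    import os
    mock_getcwd.return_value = '/tmp/ramp'
    mock_path_exists.return_value = True
    # both calls hit the mocks rather than the real os functions
    assert os.getcwd() == '/tmp/ramp'
    assert os.path.exists('/tmp/ramp')

check_patch_order()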