-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathdispatcher.py
304 lines (278 loc) · 13.4 KB
/
dispatcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
import logging
import multiprocessing
import numbers
import os
import time
from queue import Queue
from queue import LifoQueue
from ramp_database.tools.submission import get_submissions
from ramp_database.tools.submission import get_submission_by_id
from ramp_database.tools.submission import get_submission_state
from ramp_database.tools.submission import set_bagged_scores
# from ramp_database.tools.submission import set_predictions
from ramp_database.tools.submission import set_time
from ramp_database.tools.submission import set_scores
from ramp_database.tools.submission import set_submission_error_msg
from ramp_database.tools.submission import set_submission_state
from ramp_database.tools.leaderboard import update_all_user_leaderboards
from ramp_database.tools.leaderboard import update_leaderboards
from ramp_database.tools.leaderboard import update_user_leaderboards
from ramp_database.utils import session_scope
from ramp_utils import generate_ramp_config
from ramp_utils import generate_worker_config
from ramp_utils import read_config
from .local import CondaEnvWorker
# Module-level logger shared by all dispatchers; Dispatcher.__init__ derives a
# per-event child logger from it via ``logger.getChild(event_name)``.
logger = logging.getLogger('RAMP-DISPATCHER')
log_file = 'dispatcher.log'
formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s %(message)s') # noqa
# Log both to a file (appended across runs) and to the console, with the same
# record format.  NOTE(review): handlers are attached at import time, so
# importing this module twice in exotic setups could duplicate log lines.
fileHandler = logging.FileHandler(log_file, mode='a')
fileHandler.setFormatter(formatter)
streamHandler = logging.StreamHandler()
streamHandler.setFormatter(formatter)
logger.setLevel(logging.DEBUG)
logger.addHandler(fileHandler)
logger.addHandler(streamHandler)
class Dispatcher:
    """Dispatcher which schedules workers and communicates with the database.

    The dispatcher uses two queues: a queue containing the workers which
    should be launched and a queue containing the workers which are being
    processed. The latter queue has a limited size defined by ``n_workers``.
    Note that these workers can run simultaneously.

    Parameters
    ----------
    config : dict or str
        A configuration YAML file containing the information about the
        database.
    event_config : dict or str
        A RAMP configuration YAML file with information regarding the worker
        and the ramp event.
    worker : Worker, default=CondaEnvWorker
        The type of worker to launch. By default, we launch local worker which
        uses ``conda``.
    n_workers : int, default=1
        Maximum number of workers which can run submissions simultaneously.
    n_threads : None or int
        The number of threads that each worker can use. By default, there is
        no limit imposed.
    hunger_policy : {None, 'sleep', 'exit'}
        Policy to apply when there are no more workers to be processed:

        * if None: the dispatcher will work without interruption;
        * if 'sleep': the dispatcher will sleep for 5 seconds before to check
          for new submission;
        * if 'exit': the dispatcher will stop after collecting the results of
          the last submissions.
    time_between_collection : int, default=1
        The amount of time in seconds to wait before checking if we can
        collect results from worker.

        .. note::
            This parameter is important when using a cloud platform to run
            submissions, as the check for collection will be done through SSH.
            Thus, if the time between checks is too small, the repetitive
            SSH requests may be potentially blocked by the cloud provider.
    """

    def __init__(self, config, event_config, worker=None, n_workers=1,
                 n_threads=None, hunger_policy=None,
                 time_between_collection=1):
        self.worker = CondaEnvWorker if worker is None else worker
        # negative n_workers follows the joblib convention: -1 means "all
        # CPUs", -2 "all CPUs but one", etc.; clamp to at least one worker
        self.n_workers = (max(multiprocessing.cpu_count() + 1 + n_workers, 1)
                          if n_workers < 0 else n_workers)
        self.hunger_policy = hunger_policy
        self.time_between_collection = time_between_collection
        # init the poison pill to kill the dispatcher
        self._poison_pill = False
        # create the different dispatcher queues
        self._awaiting_worker_queue = Queue()
        self._processing_worker_queue = LifoQueue(maxsize=self.n_workers)
        self._processed_submission_queue = Queue()
        # split the different configuration required; str arguments are paths
        # to YAML files, otherwise we expect already-parsed dicts
        if (isinstance(config, str) and
                isinstance(event_config, str)):
            self._database_config = read_config(config,
                                                filter_section='sqlalchemy')
            self._ramp_config = generate_ramp_config(event_config, config)
        else:
            self._database_config = config['sqlalchemy']
            self._ramp_config = event_config['ramp']
        self._worker_config = generate_worker_config(event_config, config)
        # set the number of threads for openmp, openblas, and mkl
        self.n_threads = n_threads
        if self.n_threads is not None:
            # NOTE: only integrality is validated here, not positivity
            if not isinstance(self.n_threads, numbers.Integral):
                raise TypeError(
                    "The parameter 'n_threads' should be a positive integer. "
                    "Got {} instead.".format(repr(self.n_threads))
                )
            for lib in ('OMP', 'MKL', 'OPENBLAS'):
                os.environ[lib + '_NUM_THREADS'] = str(self.n_threads)
        self._logger = logger.getChild(self._ramp_config['event_name'])

    def fetch_from_db(self, session):
        """Fetch the submission from the database and create the workers."""
        submissions = get_submissions(session,
                                      self._ramp_config['event_name'],
                                      state='new')
        if not submissions:
            return
        for submission_id, submission_name, _ in submissions:
            # do not train the sandbox submission
            submission = get_submission_by_id(session, submission_id)
            if not submission.is_not_sandbox:
                continue
            # create the worker
            worker = self.worker(self._worker_config, submission_name)
            set_submission_state(session, submission_id, 'sent_to_training')
            update_user_leaderboards(
                session, self._ramp_config['event_name'],
                submission.team.name, new_only=True,
            )
            self._awaiting_worker_queue.put_nowait((worker, (submission_id,
                                                             submission_name)))
            self._logger.info(
                f'Submission {submission_name} added to the queue of '
                'submission to be processed'
            )

    def launch_workers(self, session):
        """Launch the awaiting workers if possible."""
        while (not self._processing_worker_queue.full() and
               not self._awaiting_worker_queue.empty()):
            worker, (submission_id, submission_name) = \
                self._awaiting_worker_queue.get()
            self._logger.info(f'Starting worker: {worker}')
            try:
                worker.setup()
                if worker.status != "error":
                    worker.launch_submission()
            except Exception as e:
                # broad catch on purpose: a crashing worker must not take the
                # dispatcher down; the submission is flagged as errored below
                self._logger.error(
                    f'Worker finished with unhandled exception:\n {e}'
                )
                worker.status = 'error'
            if worker.status == 'error':
                set_submission_state(session, submission_id, 'checking_error')
                # skip the rest of the loop: the worker is not processing
                continue
            set_submission_state(session, submission_id, 'training')
            submission = get_submission_by_id(session, submission_id)
            update_user_leaderboards(
                session, self._ramp_config['event_name'],
                submission.team.name, new_only=True,
            )
            self._processing_worker_queue.put_nowait(
                (worker, (submission_id, submission_name)))
            self._logger.info(
                f'Store the worker {worker} into the processing queue'
            )

    def collect_result(self, session):
        """Collect result from processed workers."""
        try:
            workers, submissions = zip(
                *[self._processing_worker_queue.get()
                  for _ in range(self._processing_worker_queue.qsize())]
            )
        except ValueError:
            # nothing to collect (empty queue unpacking fails): apply the
            # hunger policy instead
            if self.hunger_policy == 'sleep':
                time.sleep(5)
            elif self.hunger_policy == 'exit':
                self._poison_pill = True
            return
        for worker, (submission_id, submission_name) in zip(workers,
                                                            submissions):
            # throttle status checks: on cloud workers each check goes
            # through SSH and too-frequent requests may be blocked
            dt = worker.time_since_last_status_check()
            if dt is not None and dt < self.time_between_collection:
                self._processing_worker_queue.put_nowait(
                    (worker, (submission_id, submission_name)))
                time.sleep(0)
                continue
            elif worker.status == 'running':
                # not finished yet: put the worker back and yield the CPU
                self._processing_worker_queue.put_nowait(
                    (worker, (submission_id, submission_name)))
                time.sleep(0)
            elif worker.status == 'retry':
                # interrupted worker: reset the submission so that it is
                # picked up again by fetch_from_db on the next cycle
                set_submission_state(session, submission_id, 'new')
                self._logger.info(
                    f'Submission: {submission_id} has been interrupted. '
                    'It will be added to queue again and retried.'
                )
                worker.teardown()
            else:
                self._logger.info(f'Collecting results from worker {worker}')
                returncode, stderr = worker.collect_results()
                if returncode:
                    if returncode == 124:
                        # 124 is the exit code of the `timeout` command
                        self._logger.info(
                            f'Worker {worker} killed due to timeout.'
                        )
                    else:
                        self._logger.info(
                            f'Worker {worker} killed due to an error '
                            f'during training: {stderr}'
                        )
                    submission_status = 'training_error'
                else:
                    submission_status = 'tested'
                set_submission_state(
                    session, submission_id, submission_status
                )
                set_submission_error_msg(session, submission_id, stderr)
                self._processed_submission_queue.put_nowait(
                    (submission_id, submission_name))
                worker.teardown()

    def update_database_results(self, session):
        """Update the database with the results of ramp_test_submission."""
        make_update_leaderboard = False
        while not self._processed_submission_queue.empty():
            make_update_leaderboard = True
            submission_id, submission_name = \
                self._processed_submission_queue.get_nowait()
            # errored submissions have nothing to score
            if 'error' in get_submission_state(session, submission_id):
                continue
            self._logger.info(
                f'Write info in database for submission {submission_name}'
            )
            path_predictions = os.path.join(
                self._worker_config['predictions_dir'], submission_name
            )
            # NOTE: In the past we were adding the predictions into the
            # database. Since they require too much space, we stop to store
            # them in the database and instead, keep it onto the disk.
            # set_predictions(session, submission_id, path_predictions)
            set_time(session, submission_id, path_predictions)
            set_scores(session, submission_id, path_predictions)
            set_bagged_scores(session, submission_id, path_predictions)
            set_submission_state(session, submission_id, 'scored')
        if make_update_leaderboard:
            self._logger.info('Update all leaderboards')
            update_leaderboards(session, self._ramp_config['event_name'])
            update_all_user_leaderboards(session,
                                         self._ramp_config['event_name'])
            self._logger.info('Leaderboards updated')

    @staticmethod
    def _reset_submission_after_failure(session, event_name):
        """Reset unfinished submissions of an event to the 'new' state."""
        submissions = get_submissions(session, event_name, state=None)
        for submission_id, _, _ in submissions:
            submission_state = get_submission_state(session, submission_id)
            if submission_state in ('training', 'sent_to_training'):
                set_submission_state(session, submission_id, 'new')

    def launch(self):
        """Launch the dispatcher."""
        self._logger.info('Starting the RAMP dispatcher')
        with session_scope(self._database_config) as session:
            self._logger.info('Open a session to the database')
            self._logger.info(
                'Reset unfinished trained submission from previous session'
            )
            self._reset_submission_after_failure(
                session, self._ramp_config['event_name']
            )
            try:
                # main scheduling loop: fetch -> launch -> collect -> score,
                # until the poison pill is set (hunger_policy == 'exit')
                while not self._poison_pill:
                    self.fetch_from_db(session)
                    self.launch_workers(session)
                    self.collect_result(session)
                    self.update_database_results(session)
            finally:
                # reset the submissions to 'new' in case of error or unfinished
                # training
                self._reset_submission_after_failure(
                    session, self._ramp_config['event_name']
                )
            self._logger.info('Dispatcher killed by the poison pill')