2
2
3
3
import asyncio
4
4
import logging
5
- import sys
6
- import threading
7
5
import typing as t
8
6
from dataclasses import dataclass , field
9
7
16
14
logger = logging .getLogger (__name__ )
17
15
18
16
19
- def runner_exception_hook (args : threading .ExceptHookArgs ):
20
- raise args .exc_type
17
+ def is_event_loop_running () -> bool :
18
+ try :
19
+ loop = asyncio .get_running_loop ()
20
+ except RuntimeError :
21
+ return False
22
+ else :
23
+ return loop .is_running ()
21
24
22
25
23
- # set a custom exception hook
24
- # threading.excepthook = runner_exception_hook
25
-
26
-
27
- def as_completed (loop , coros , max_workers ):
28
- loop_arg_dict = {"loop" : loop } if sys .version_info [:2 ] < (3 , 10 ) else {}
26
+ def as_completed (coros , max_workers ):
29
27
if max_workers == - 1 :
30
- return asyncio .as_completed (coros , ** loop_arg_dict )
28
+ return asyncio .as_completed (coros )
31
29
32
- # loop argument is removed since Python 3.10
33
- semaphore = asyncio .Semaphore (max_workers , ** loop_arg_dict )
30
+ semaphore = asyncio .Semaphore (max_workers )
34
31
35
32
async def sema_coro (coro ):
36
33
async with semaphore :
37
34
return await coro
38
35
39
36
sema_coros = [sema_coro (c ) for c in coros ]
40
- return asyncio .as_completed (sema_coros , ** loop_arg_dict )
41
-
42
-
43
- class Runner (threading .Thread ):
44
- def __init__ (
45
- self ,
46
- jobs : t .List [t .Tuple [t .Coroutine , str ]],
47
- desc : str ,
48
- keep_progress_bar : bool = True ,
49
- raise_exceptions : bool = True ,
50
- run_config : t .Optional [RunConfig ] = None ,
51
- ):
52
- super ().__init__ ()
53
- self .jobs = jobs
54
- self .desc = desc
55
- self .keep_progress_bar = keep_progress_bar
56
- self .raise_exceptions = raise_exceptions
57
- self .run_config = run_config or RunConfig ()
58
-
59
- # create task
60
- try :
61
- self .loop = asyncio .get_event_loop ()
62
- except RuntimeError :
63
- self .loop = asyncio .new_event_loop ()
64
- self .futures = as_completed (
65
- loop = self .loop ,
66
- coros = [coro for coro , _ in self .jobs ],
67
- max_workers = self .run_config .max_workers ,
68
- )
69
-
70
- async def _aresults (self ) -> t .List [t .Any ]:
71
- results = []
72
- for future in tqdm (
73
- self .futures ,
74
- desc = self .desc ,
75
- total = len (self .jobs ),
76
- # whether you want to keep the progress bar after completion
77
- leave = self .keep_progress_bar ,
78
- ):
79
- r = await future
80
- results .append (r )
81
37
82
- return results
83
-
84
- def run (self ):
85
- results = []
86
- try :
87
- results = self .loop .run_until_complete (self ._aresults ())
88
- finally :
89
- self .results = results
38
+ return asyncio .as_completed (sema_coros )
90
39
91
40
92
41
@dataclass
@@ -95,21 +44,22 @@ class Executor:
95
44
keep_progress_bar : bool = True
96
45
jobs : t .List [t .Any ] = field (default_factory = list , repr = False )
97
46
raise_exceptions : bool = False
98
- run_config : t .Optional [RunConfig ] = field (default_factory = RunConfig , repr = False )
47
+ run_config : t .Optional [RunConfig ] = field (default = None , repr = False )
99
48
100
49
def wrap_callable_with_index (self , callable : t .Callable , counter ):
101
50
async def wrapped_callable_async (* args , ** kwargs ):
102
51
result = np .nan
103
52
try :
104
53
result = await callable (* args , ** kwargs )
105
54
except MaxRetriesExceeded as e :
55
+ # this only for testset generation v2
106
56
logger .warning (f"max retries exceeded for { e .evolution } " )
107
57
except Exception as e :
108
58
if self .raise_exceptions :
109
59
raise e
110
60
else :
111
61
logger .error (
112
- "Runner in Executor raised an exception" , exc_info = True
62
+ "Runner in Executor raised an exception" , exc_info = False
113
63
)
114
64
115
65
return counter , result
@@ -120,29 +70,40 @@ def submit(
120
70
self , callable : t .Callable , * args , name : t .Optional [str ] = None , ** kwargs
121
71
):
122
72
callable_with_index = self .wrap_callable_with_index (callable , len (self .jobs ))
123
- self .jobs .append ((callable_with_index ( * args , ** kwargs ) , name ))
73
+ self .jobs .append ((callable_with_index , args , kwargs , name ))
124
74
125
75
def results (self ) -> t .List [t .Any ]:
126
- executor_job = Runner (
127
- jobs = self .jobs ,
128
- desc = self .desc ,
129
- keep_progress_bar = self .keep_progress_bar ,
130
- raise_exceptions = self .raise_exceptions ,
131
- run_config = self .run_config ,
132
- )
133
- executor_job .start ()
134
- try :
135
- executor_job .join ()
136
- finally :
137
- ...
138
-
139
- if executor_job .results is None :
140
- if self .raise_exceptions :
141
- raise RuntimeError (
142
- "Executor failed to complete. Please check logs above for full info."
76
+ if is_event_loop_running ():
77
+ # an event loop is running so call nested_asyncio to fix this
78
+ try :
79
+ import nest_asyncio
80
+ except ImportError :
81
+ raise ImportError (
82
+ "It seems like your running this in a jupyter-like environment. Please install nest_asyncio with `pip install nest_asyncio` to make it work."
143
83
)
144
- else :
145
- logger .error ("Executor failed to complete. Please check logs above." )
146
- return []
147
- sorted_results = sorted (executor_job .results , key = lambda x : x [0 ])
84
+
85
+ nest_asyncio .apply ()
86
+
87
+ # create a generator for which returns tasks as they finish
88
+ futures_as_they_finish = as_completed (
89
+ coros = [afunc (* args , ** kwargs ) for afunc , args , kwargs , _ in self .jobs ],
90
+ max_workers = (self .run_config or RunConfig ()).max_workers ,
91
+ )
92
+
93
+ async def _aresults () -> t .List [t .Any ]:
94
+ results = []
95
+ for future in tqdm (
96
+ futures_as_they_finish ,
97
+ desc = self .desc ,
98
+ total = len (self .jobs ),
99
+ # whether you want to keep the progress bar after completion
100
+ leave = self .keep_progress_bar ,
101
+ ):
102
+ r = await future
103
+ results .append (r )
104
+
105
+ return results
106
+
107
+ results = asyncio .run (_aresults ())
108
+ sorted_results = sorted (results , key = lambda x : x [0 ])
148
109
return [r [1 ] for r in sorted_results ]
0 commit comments