Skip to content

Commit f2470c5

Browse files
authored
fix(executor): simplify executor by removing threading (explodinggradients#1093)
ragas `0.1.10` broke how the executor was functioning, and there were a few hard-to-reproduce bugs overall. This fixes all of those. - uses `nest_asyncio` in Jupyter notebooks - added tests in a Jupyter notebook with `nbmake` - removed the `Runner`; this means all the `Tasks` will be run in the same event loop as main. For Jupyter notebooks we will use the `nest_asyncio` lib. Takes inspiration from explodinggradients#689. fixes: explodinggradients#681
1 parent 0172bda commit f2470c5

File tree

5 files changed

+205
-89
lines changed

5 files changed

+205
-89
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ clean: ## Clean all generated files
2727
run-ci: format lint type ## Running all CI checks
2828
test: ## Run tests
2929
@echo "Running tests..."
30-
@pytest tests/unit $(shell if [ -n "$(k)" ]; then echo "-k $(k)"; fi)
30+
@pytest --nbmake tests/unit $(shell if [ -n "$(k)" ]; then echo "-k $(k)"; fi)
3131
test-e2e: ## Run end2end tests
3232
echo "running end2end tests..."
3333
@pytest tests/e2e -s

requirements/test.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ pytest
22
pytest-xdist[psutil]
33
pytest-asyncio
44
llama_index
5+
nbmake

src/ragas/executor.py

+47-86
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
import asyncio
44
import logging
5-
import sys
6-
import threading
75
import typing as t
86
from dataclasses import dataclass, field
97

@@ -16,77 +14,28 @@
1614
logger = logging.getLogger(__name__)
1715

1816

19-
def runner_exception_hook(args: threading.ExceptHookArgs):
20-
raise args.exc_type
17+
def is_event_loop_running() -> bool:
18+
try:
19+
loop = asyncio.get_running_loop()
20+
except RuntimeError:
21+
return False
22+
else:
23+
return loop.is_running()
2124

2225

23-
# set a custom exception hook
24-
# threading.excepthook = runner_exception_hook
25-
26-
27-
def as_completed(loop, coros, max_workers):
28-
loop_arg_dict = {"loop": loop} if sys.version_info[:2] < (3, 10) else {}
26+
def as_completed(coros, max_workers):
2927
if max_workers == -1:
30-
return asyncio.as_completed(coros, **loop_arg_dict)
28+
return asyncio.as_completed(coros)
3129

32-
# loop argument is removed since Python 3.10
33-
semaphore = asyncio.Semaphore(max_workers, **loop_arg_dict)
30+
semaphore = asyncio.Semaphore(max_workers)
3431

3532
async def sema_coro(coro):
3633
async with semaphore:
3734
return await coro
3835

3936
sema_coros = [sema_coro(c) for c in coros]
40-
return asyncio.as_completed(sema_coros, **loop_arg_dict)
41-
42-
43-
class Runner(threading.Thread):
44-
def __init__(
45-
self,
46-
jobs: t.List[t.Tuple[t.Coroutine, str]],
47-
desc: str,
48-
keep_progress_bar: bool = True,
49-
raise_exceptions: bool = True,
50-
run_config: t.Optional[RunConfig] = None,
51-
):
52-
super().__init__()
53-
self.jobs = jobs
54-
self.desc = desc
55-
self.keep_progress_bar = keep_progress_bar
56-
self.raise_exceptions = raise_exceptions
57-
self.run_config = run_config or RunConfig()
58-
59-
# create task
60-
try:
61-
self.loop = asyncio.get_event_loop()
62-
except RuntimeError:
63-
self.loop = asyncio.new_event_loop()
64-
self.futures = as_completed(
65-
loop=self.loop,
66-
coros=[coro for coro, _ in self.jobs],
67-
max_workers=self.run_config.max_workers,
68-
)
69-
70-
async def _aresults(self) -> t.List[t.Any]:
71-
results = []
72-
for future in tqdm(
73-
self.futures,
74-
desc=self.desc,
75-
total=len(self.jobs),
76-
# whether you want to keep the progress bar after completion
77-
leave=self.keep_progress_bar,
78-
):
79-
r = await future
80-
results.append(r)
8137

82-
return results
83-
84-
def run(self):
85-
results = []
86-
try:
87-
results = self.loop.run_until_complete(self._aresults())
88-
finally:
89-
self.results = results
38+
return asyncio.as_completed(sema_coros)
9039

9140

9241
@dataclass
@@ -95,21 +44,22 @@ class Executor:
9544
keep_progress_bar: bool = True
9645
jobs: t.List[t.Any] = field(default_factory=list, repr=False)
9746
raise_exceptions: bool = False
98-
run_config: t.Optional[RunConfig] = field(default_factory=RunConfig, repr=False)
47+
run_config: t.Optional[RunConfig] = field(default=None, repr=False)
9948

10049
def wrap_callable_with_index(self, callable: t.Callable, counter):
10150
async def wrapped_callable_async(*args, **kwargs):
10251
result = np.nan
10352
try:
10453
result = await callable(*args, **kwargs)
10554
except MaxRetriesExceeded as e:
55+
# this is only for testset generation v2
10656
logger.warning(f"max retries exceeded for {e.evolution}")
10757
except Exception as e:
10858
if self.raise_exceptions:
10959
raise e
11060
else:
11161
logger.error(
112-
"Runner in Executor raised an exception", exc_info=True
62+
"Runner in Executor raised an exception", exc_info=False
11363
)
11464

11565
return counter, result
@@ -120,29 +70,40 @@ def submit(
12070
self, callable: t.Callable, *args, name: t.Optional[str] = None, **kwargs
12171
):
12272
callable_with_index = self.wrap_callable_with_index(callable, len(self.jobs))
123-
self.jobs.append((callable_with_index(*args, **kwargs), name))
73+
self.jobs.append((callable_with_index, args, kwargs, name))
12474

12575
def results(self) -> t.List[t.Any]:
126-
executor_job = Runner(
127-
jobs=self.jobs,
128-
desc=self.desc,
129-
keep_progress_bar=self.keep_progress_bar,
130-
raise_exceptions=self.raise_exceptions,
131-
run_config=self.run_config,
132-
)
133-
executor_job.start()
134-
try:
135-
executor_job.join()
136-
finally:
137-
...
138-
139-
if executor_job.results is None:
140-
if self.raise_exceptions:
141-
raise RuntimeError(
142-
"Executor failed to complete. Please check logs above for full info."
76+
if is_event_loop_running():
77+
# an event loop is already running, so apply nest_asyncio to allow the nested asyncio.run() call below
78+
try:
79+
import nest_asyncio
80+
except ImportError:
81+
raise ImportError(
82+
"It seems like your running this in a jupyter-like environment. Please install nest_asyncio with `pip install nest_asyncio` to make it work."
14383
)
144-
else:
145-
logger.error("Executor failed to complete. Please check logs above.")
146-
return []
147-
sorted_results = sorted(executor_job.results, key=lambda x: x[0])
84+
85+
nest_asyncio.apply()
86+
87+
# create an iterator that yields the submitted jobs' awaitables as they finish
88+
futures_as_they_finish = as_completed(
89+
coros=[afunc(*args, **kwargs) for afunc, args, kwargs, _ in self.jobs],
90+
max_workers=(self.run_config or RunConfig()).max_workers,
91+
)
92+
93+
async def _aresults() -> t.List[t.Any]:
94+
results = []
95+
for future in tqdm(
96+
futures_as_they_finish,
97+
desc=self.desc,
98+
total=len(self.jobs),
99+
# whether you want to keep the progress bar after completion
100+
leave=self.keep_progress_bar,
101+
):
102+
r = await future
103+
results.append(r)
104+
105+
return results
106+
107+
results = asyncio.run(_aresults())
108+
sorted_results = sorted(results, key=lambda x: x[0])
148109
return [r[1] for r in sorted_results]

tests/unit/test_executor.py

+76-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
def test_order_of_execution():
1+
import asyncio
2+
import pytest
3+
4+
5+
@pytest.mark.asyncio
6+
async def test_order_of_execution():
27
from ragas.executor import Executor
38

4-
async def echo_order(index):
9+
async def echo_order(index: int):
510
return index
611

712
# Arrange
@@ -14,3 +19,72 @@ async def echo_order(index):
1419
results = executor.results()
1520
# Assert
1621
assert results == list(range(1, 11))
22+
23+
24+
@pytest.mark.asyncio
25+
async def test_executor_in_script():
26+
from ragas.executor import Executor
27+
28+
async def echo_order(index: int):
29+
await asyncio.sleep(0.1)
30+
return index
31+
32+
# Arrange
33+
executor = Executor()
34+
35+
# Act
36+
# add 3 jobs to the executor
37+
for i in range(1, 4):
38+
executor.submit(echo_order, i, name=f"echo_order_{i}")
39+
results = executor.results()
40+
# Assert
41+
assert results == list(range(1, 4))
42+
43+
44+
@pytest.mark.asyncio
45+
async def test_executor_with_running_loop():
46+
import asyncio
47+
from ragas.executor import Executor
48+
49+
loop = asyncio.new_event_loop()
50+
loop.run_until_complete(asyncio.sleep(0.1))
51+
52+
async def echo_order(index: int):
53+
await asyncio.sleep(0.1)
54+
return index
55+
56+
# Arrange
57+
executor = Executor()
58+
59+
# Act
60+
# add 3 jobs to the executor
61+
for i in range(1, 4):
62+
executor.submit(echo_order, i, name=f"echo_order_{i}")
63+
results = executor.results()
64+
# Assert
65+
assert results == list(range(1, 4))
66+
67+
68+
def test_is_event_loop_running_in_script():
69+
from ragas.executor import is_event_loop_running
70+
71+
assert is_event_loop_running() is False
72+
73+
74+
def test_as_completed_in_script():
75+
from ragas.executor import as_completed
76+
77+
async def echo_order(index: int):
78+
await asyncio.sleep(index)
79+
return index
80+
81+
async def _run():
82+
results = []
83+
for t in as_completed([echo_order(1), echo_order(2), echo_order(3)], 3):
84+
r = await t
85+
results.append(r)
86+
return results
87+
88+
results = asyncio.run(_run())
89+
90+
assert results == [1, 2, 3]
+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"%load_ext autoreload\n",
10+
"%autoreload 2"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": 16,
16+
"metadata": {},
17+
"outputs": [
18+
{
19+
"data": {
20+
"application/vnd.jupyter.widget-view+json": {
21+
"model_id": "e89bde88d0054ede891c3659b439a02f",
22+
"version_major": 2,
23+
"version_minor": 0
24+
},
25+
"text/plain": [
26+
"Evaluating: 0%| | 0/10 [00:00<?, ?it/s]"
27+
]
28+
},
29+
"metadata": {},
30+
"output_type": "display_data"
31+
}
32+
],
33+
"source": [
34+
"# from ragas.executor import Executor\n",
35+
"# from asyncio import sleep\n",
36+
"\n",
37+
"# exec = Executor(raise_exceptions=True)\n",
38+
"# for i in range(10):\n",
39+
"# exec.submit(sleep, i)\n",
40+
" \n",
41+
"# try:\n",
42+
"# exec.results()\n",
43+
"# except Exception:\n",
44+
"# print(\"error\")"
45+
]
46+
},
47+
{
48+
"cell_type": "code",
49+
"execution_count": null,
50+
"metadata": {},
51+
"outputs": [],
52+
"source": [
53+
"from ragas.executor import is_event_loop_running\n",
54+
"\n",
55+
"assert is_event_loop_running() is True"
56+
]
57+
}
58+
],
59+
"metadata": {
60+
"kernelspec": {
61+
"display_name": "ragas",
62+
"language": "python",
63+
"name": "python3"
64+
},
65+
"language_info": {
66+
"codemirror_mode": {
67+
"name": "ipython",
68+
"version": 3
69+
},
70+
"file_extension": ".py",
71+
"mimetype": "text/x-python",
72+
"name": "python",
73+
"nbconvert_exporter": "python",
74+
"pygments_lexer": "ipython3",
75+
"version": "3.10.12"
76+
}
77+
},
78+
"nbformat": 4,
79+
"nbformat_minor": 2
80+
}

0 commit comments

Comments
 (0)