Commit d2486f1
Allow Metric.score to work within an existing asyncio loop (explodinggradients#1161)
I got errors when running metrics within an existing asyncio loop; the `asyncio.run` call makes the code fail in those cases. Submitting a bugfix.

Co-authored-by: jjmachan <[email protected]>
1 parent 9061638 commit d2486f1
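For context on the failure this commit fixes: `asyncio.run()` always creates and owns a brand-new event loop, so it raises a `RuntimeError` when a loop is already running in the current thread (as it is inside Jupyter kernels, web frameworks, or any other async application). The snippet below is a minimal, self-contained illustration of that behavior and is not part of the ragas codebase.

import asyncio

async def main():
    try:
        # This mirrors what the old Metric.score did when invoked from async code.
        asyncio.run(asyncio.sleep(0))
    except RuntimeError as err:
        # asyncio refuses to start a second loop in the same thread:
        # "asyncio.run() cannot be called from a running event loop"
        print(err)

asyncio.run(main())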

File tree

.github/workflows/ci.yaml
src/ragas/metrics/base.py
tests/e2e/test_evaluation_in_jupyter.ipynb
tests/unit/test_executor_in_jupyter.ipynb
tests/unit/test_metric.py

5 files changed: +68 −26 lines

.github/workflows/ci.yaml

+1 −1

@@ -97,7 +97,7 @@ jobs:
 OPTS=(--dist loadfile -n auto)
 fi
 # Now run the unit tests
-pytest tests/unit "${OPTS[@]}"
+pytest --nbmake tests/unit "${OPTS[@]}"
 env:
 __RAGAS_DEBUG_TRACKING: true
 RAGAS_DO_NOT_TRACK: true
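The new `--nbmake` flag comes from the nbmake pytest plugin, which collects and executes `.ipynb` files as test cases, so the notebook assertions touched in this commit are actually run in CI rather than only stored in the repository. A rough local equivalent, assuming nbmake is installed in the test environment:

pip install nbmake
pytest --nbmake tests/unit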

src/ragas/metrics/base.py

+2 −1

@@ -96,7 +96,8 @@ def score(self: t.Self, row: t.Dict, callbacks: Callbacks = None) -> float:
         callbacks = callbacks or []
         rm, group_cm = new_group(self.name, inputs=row, callbacks=callbacks)
         try:
-            score = asyncio.run(self._ascore(row=row, callbacks=group_cm))
+            loop = asyncio.get_event_loop()
+            score = loop.run_until_complete(self._ascore(row=row, callbacks=group_cm))
         except Exception as e:
             if not group_cm.ended:
                 rm.on_chain_error(e)
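The core of the fix is to reuse the event loop the process already has rather than letting `asyncio.run()` spin up a competing one. A minimal standalone sketch of the same pattern, assuming either that no loop is currently running or that the running loop has been made re-entrant (for example via `nest_asyncio`, which notebook environments typically need for `run_until_complete` to succeed):

import asyncio

async def _fake_ascore() -> float:
    # Hypothetical stand-in for Metric._ascore; yields control once and returns a score.
    await asyncio.sleep(0)
    return 0.0

# Reuse (or lazily create) the current thread's event loop instead of calling asyncio.run().
loop = asyncio.get_event_loop()
score = loop.run_until_complete(_fake_ascore())
assert score == 0.0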

tests/e2e/test_evaluation_in_jupyter.ipynb

−7

@@ -103,13 +103,6 @@
     "\n",
     "result"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {

tests/unit/test_executor_in_jupyter.ipynb

+48 −17

@@ -1,24 +1,21 @@
 {
  "cells": [
   {
-   "cell_type": "code",
-   "execution_count": 1,
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "%load_ext autoreload\n",
-    "%autoreload 2"
+    "# Test Executor "
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "b78a418208b84895b03c93c54f1d1d61",
+       "model_id": "ebb0705d6a05459a89f4ae87cbbbfd84",
        "version_major": 2,
        "version_minor": 0
       },
@@ -36,14 +33,14 @@
     "\n",
     "exec = Executor(raise_exceptions=True)\n",
     "for i in range(10):\n",
-    "    exec.submit(sleep, i)\n",
+    "    exec.submit(sleep, i/10)\n",
     "\n",
     "assert exec.results(), \"didn't get anything from results\""
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 15,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -54,7 +51,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 16,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -83,13 +80,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "9bb608f8b2de42628fb525581d496d3a",
+       "model_id": "985b8a189c9047c29d6ccebf7c5a938b",
        "version_major": 2,
        "version_minor": 0
       },
@@ -103,7 +100,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "128ca1d600b3457c863ddf376d24c44e",
+       "model_id": "ff3097e24dc249fbab6e610e59ccc9b6",
        "version_major": 2,
        "version_minor": 0
       },
@@ -118,22 +115,56 @@
    "source": [
     "exec = Executor(raise_exceptions=True)\n",
     "for i in range(1000):\n",
-    "    exec.submit(sleep, 1)\n",
+    "    exec.submit(sleep, 0.01)\n",
     "\n",
     "assert exec.results(), \"didn't get anything from results\"\n",
     "\n",
     "for i in range(1000):\n",
-    "    exec.submit(sleep, 1)\n",
+    "    exec.submit(sleep, 0.01)\n",
     "\n",
     "assert exec.results(), \"didn't get anything from results\""
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Test Metric"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from ragas.metrics.base import Metric, EvaluationMode\n",
+    "\n",
+    "class FakeMetric(Metric):\n",
+    "    name = \"fake_metric\"\n",
+    "    evaluation_mode = EvaluationMode.qa\n",
+    "\n",
+    "    def init(self):\n",
+    "        pass\n",
+    "\n",
+    "    async def _ascore(self, row, callbacks)->float:\n",
+    "        return 0\n",
+    "\n",
+    "fm = FakeMetric()"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 24,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "score = fm.score(\n",
+    "    row={\"question\": [\"q\"], \"answer\": [\"a\"]},\n",
+    "    callbacks=None,\n",
+    ")\n",
+    "assert score == 0"
+   ]
   }
  ],
  "metadata": {

tests/unit/test_metric.py

+17 −0

@@ -27,3 +27,20 @@ def test_get_available_metrics():
             for metric in get_available_metrics(ds)
         ]
     ), "All metrics should have evaluation mode qa"
+
+
+def test_metric():
+    from ragas.metrics.base import Metric
+
+    class FakeMetric(Metric):
+        name = "fake_metric"  # type: ignore
+        evaluation_mode = EvaluationMode.qa  # type: ignore
+
+        def init(self, run_config):
+            pass
+
+        async def _ascore(self, row, callbacks) -> float:
+            return 0
+
+    fm = FakeMetric()
+    assert fm.score({"question": "a", "answer": "b"}) == 0
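This unit test covers the synchronous `Metric.score()` path with a stub metric; the notebook variant executed under nbmake additionally covers the case where an event loop is already running. Assuming pytest is installed, the new test can be run on its own with:

python -m pytest tests/unit/test_metric.py::test_metric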
