3
3
import logging
4
4
import typing as t
5
5
from abc import abstractmethod
6
- from collections import namedtuple
7
6
from dataclasses import dataclass , field
8
7
9
8
from fsspec .exceptions import asyncio
9
+ from langchain_core .pydantic_v1 import BaseModel
10
10
from numpy .random import default_rng
11
11
12
+ from ragas .llms import BaseRagasLLM
13
+ from ragas .llms .prompt import Prompt
12
14
from ragas .testset .docstore import Direction , DocumentStore , Node
13
15
from ragas .testset .filters import EvolutionFilter , NodeFilter , QuestionFilter
14
16
from ragas .testset .prompts import (
22
24
rng = default_rng ()
23
25
logger = logging .getLogger (__name__ )
24
26
25
- if t .TYPE_CHECKING :
26
- from ragas .llms import BaseRagasLLM
27
- from ragas .llms .prompt import Prompt
28
-
29
27
30
28
@dataclass
31
29
class CurrentNodes :
32
30
root_node : Node
33
31
nodes : t .List [Node ] = field (default_factory = list )
34
32
35
33
36
- DataRow = namedtuple (
37
- "DataRow" ,
38
- [
39
- "question" ,
40
- "context" ,
41
- "answer" ,
42
- "question_type" ,
43
- "evolution_elimination" ,
44
- ],
45
- )
34
+ # (question, current_nodes, evolution_type)
35
+ EvolutionOutput = t .Tuple [str , CurrentNodes , str ]
36
+
37
+
38
+ class DataRow (BaseModel ):
39
+ question : str
40
+ context : str
41
+ answer : str
42
+ evolution_type : str
46
43
47
44
48
45
@dataclass
49
46
class Evolution :
50
- generator_llm : t . Optional [ BaseRagasLLM ] = None
47
+ generator_llm : BaseRagasLLM = t . cast ( BaseRagasLLM , None )
51
48
docstore : t .Optional [DocumentStore ] = None
52
49
node_filter : t .Optional [NodeFilter ] = None
53
50
question_filter : t .Optional [QuestionFilter ] = None
@@ -61,7 +58,7 @@ def merge_nodes(nodes: CurrentNodes) -> Node:
61
58
62
59
async def aretry_evolve (
63
60
self , current_tries : int , current_nodes : CurrentNodes , update_count : bool = True
64
- ) -> str :
61
+ ) -> EvolutionOutput :
65
62
if update_count :
66
63
current_tries += 1
67
64
logger .info ("retrying evolution: %s times" , current_tries )
@@ -112,22 +109,29 @@ def evolve(self, current_nodes: CurrentNodes) -> DataRow:
112
109
async def aevolve (self , current_nodes : CurrentNodes ) -> DataRow :
113
110
# init tries with 0 when first called
114
111
current_tries = 0
115
- evolved_question = await self ._aevolve (current_tries , current_nodes )
112
+ (
113
+ evolved_question ,
114
+ current_nodes ,
115
+ evolution_type ,
116
+ ) = await self ._aevolve (current_tries , current_nodes )
117
+
116
118
return self .generate_datarow (
117
119
question = evolved_question ,
118
120
current_nodes = current_nodes ,
121
+ evolution_type = evolution_type ,
119
122
)
120
123
121
124
@abstractmethod
122
- async def _aevolve (self , current_tries : int , current_nodes : CurrentNodes ) -> str :
125
+ async def _aevolve (
126
+ self , current_tries : int , current_nodes : CurrentNodes
127
+ ) -> EvolutionOutput :
123
128
...
124
129
125
130
def generate_datarow (
126
131
self ,
127
132
question : str ,
128
133
current_nodes : CurrentNodes ,
129
- question_type : str = "" ,
130
- evolution_elimination : bool = False ,
134
+ evolution_type : str ,
131
135
):
132
136
assert self .generator_llm is not None , "generator_llm cannot be None"
133
137
@@ -146,15 +150,16 @@ def generate_datarow(
146
150
return DataRow (
147
151
question = question ,
148
152
context = merged_nodes .page_content ,
149
- answer = answer ,
150
- question_type = question_type ,
151
- evolution_elimination = evolution_elimination ,
153
+ answer = "" if answer is None else answer ,
154
+ evolution_type = evolution_type ,
152
155
)
153
156
154
157
155
158
@dataclass
156
159
class SimpleEvolution (Evolution ):
157
- async def _aevolve (self , current_tries : int , current_nodes : CurrentNodes ) -> str :
160
+ async def _aevolve (
161
+ self , current_tries : int , current_nodes : CurrentNodes
162
+ ) -> EvolutionOutput :
158
163
assert self .docstore is not None , "docstore cannot be None"
159
164
assert self .node_filter is not None , "node filter cannot be None"
160
165
assert self .generator_llm is not None , "generator_llm cannot be None"
@@ -183,7 +188,7 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
183
188
return await self .aretry_evolve (current_tries , current_nodes )
184
189
else :
185
190
# if valid question
186
- return seed_question
191
+ return seed_question , current_nodes , "simple"
187
192
188
193
def __hash__ (self ):
189
194
return hash (self .__class__ .__name__ )
@@ -209,13 +214,15 @@ def init_evolution(self):
209
214
210
215
@dataclass
211
216
class MultiContextEvolution (ComplexEvolution ):
212
- async def _aevolve (self , current_tries : int , current_nodes : CurrentNodes ) -> str :
217
+ async def _aevolve (
218
+ self , current_tries : int , current_nodes : CurrentNodes
219
+ ) -> EvolutionOutput :
213
220
assert self .docstore is not None , "docstore cannot be None"
214
221
assert self .generator_llm is not None , "generator_llm cannot be None"
215
222
assert self .question_filter is not None , "question_filter cannot be None"
216
223
assert self .se is not None , "simple evolution cannot be None"
217
224
218
- simple_question = await self .se ._aevolve (current_tries , current_nodes )
225
+ simple_question , _ , _ = await self .se ._aevolve (current_tries , current_nodes )
219
226
logger .debug (
220
227
"[MultiContextEvolution] simple question generated: %s" , simple_question
221
228
)
@@ -254,20 +261,22 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
254
261
current_nodes = self .se ._get_more_adjacent_nodes (current_nodes )
255
262
return await self .aretry_evolve (current_tries , current_nodes )
256
263
257
- return compressed_question
264
+ return compressed_question , current_nodes , "multi_context"
258
265
259
266
def __hash__ (self ):
260
267
return hash (self .__class__ .__name__ )
261
268
262
269
263
270
@dataclass
264
271
class ReasoningEvolution (ComplexEvolution ):
265
- async def _aevolve (self , current_tries : int , current_nodes : CurrentNodes ) -> str :
272
+ async def _aevolve (
273
+ self , current_tries : int , current_nodes : CurrentNodes
274
+ ) -> EvolutionOutput :
266
275
assert self .generator_llm is not None , "generator_llm cannot be None"
267
276
assert self .question_filter is not None , "question_filter cannot be None"
268
277
assert self .se is not None , "simple evolution cannot be None"
269
278
270
- simple_question = await self .se ._aevolve (current_tries , current_nodes )
279
+ simple_question , _ , _ = await self .se ._aevolve (current_tries , current_nodes )
271
280
logger .debug (
272
281
"[ReasoningEvolution] simple question generated: %s" , simple_question
273
282
)
@@ -304,7 +313,7 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
304
313
)
305
314
return await self .aretry_evolve (current_tries , current_nodes )
306
315
307
- return reasoning_question
316
+ return reasoning_question , current_nodes , "reasoning"
308
317
309
318
def __hash__ (self ):
310
319
return hash (self .__class__ .__name__ )
0 commit comments