Skip to content

Commit 8566a33

Browse files
authored
fix: cleanup old testset generator (explodinggradients#500)
1 parent 8969173 commit 8566a33

15 files changed

+94
-628
lines changed

pyproject.toml

+1
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ dependencies = [
66
"tiktoken",
77
"langchain",
88
"langchain-core",
9+
"langchain-community",
910
"langchain_openai",
1011
"openai>1",
1112
"pysbd>=0.3.4",

src/ragas/evaluation.py

+5 -1
Original file line number | Diff line number | Diff line change
@@ -154,7 +154,11 @@ def evaluate(
154154
[m.init_model() for m in metrics]
155155

156156
executor = Executor(
157-
is_async=is_async, max_workers=max_workers, raise_exceptions=raise_exceptions
157+
desc="Evaluating",
158+
keep_progress_bar=True,
159+
is_async=is_async,
160+
max_workers=max_workers,
161+
raise_exceptions=raise_exceptions,
158162
)
159163
# new evaluation chain
160164
row_run_managers = []

src/ragas/executor.py

+5
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
@dataclass
1111
class Executor:
1212
desc: str = "Evaluating"
13+
keep_progress_bar: bool = True
1314
is_async: bool = True
1415
max_workers: t.Optional[int] = None
1516
futures: t.List[t.Any] = field(default_factory=list, repr=False)
@@ -74,6 +75,8 @@ async def _aresults(self) -> t.List[t.Any]:
7475
asyncio.as_completed(self.futures),
7576
desc=self.desc,
7677
total=len(self.futures),
78+
# whether you want to keep the progress bar after completion
79+
leave=self.keep_progress_bar,
7780
):
7881
r = (-1, np.nan)
7982
try:
@@ -109,6 +112,8 @@ def results(self) -> t.List[t.Any]:
109112
as_completed(self.futures),
110113
desc=self.desc,
111114
total=len(self.futures),
115+
# whether you want to keep the progress bar after completion
116+
leave=self.keep_progress_bar,
112117
):
113118
r = (-1, np.nan)
114119
try:

src/ragas/llms/base.py

+2 -17
Original file line number | Diff line number | Diff line change
@@ -4,14 +4,13 @@
44
from abc import ABC, abstractmethod
55
from dataclasses import dataclass
66

7-
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI, ChatVertexAI
8-
from langchain.llms import AzureOpenAI, OpenAI, VertexAI
7+
from langchain_community.chat_models import AzureChatOpenAI, ChatOpenAI, ChatVertexAI
8+
from langchain_community.llms import AzureOpenAI, OpenAI, VertexAI
99
from langchain_core.language_models import BaseLanguageModel
1010
from langchain_core.outputs import LLMResult
1111

1212
if t.TYPE_CHECKING:
1313
from langchain_core.callbacks import Callbacks
14-
from langchain_core.prompts import ChatPromptTemplate
1514

1615
from ragas.llms.prompt import PromptValue
1716

@@ -62,20 +61,6 @@ async def agenerate_text(
6261
) -> LLMResult:
6362
...
6463

65-
# TODO: remove after testset generator is refactored
66-
def generate_text_with_hmpt(
67-
self,
68-
prompts: t.List[ChatPromptTemplate],
69-
n: int = 1,
70-
temperature: float = 1e-8,
71-
stop: t.Optional[t.List[str]] = None,
72-
callbacks: Callbacks = [],
73-
) -> LLMResult:
74-
from ragas.llms.prompt import PromptValue
75-
76-
prompt = PromptValue(prompt_str=prompts[0].format())
77-
return self.generate_text(prompt, n, temperature, stop, callbacks)
78-
7964

8065
@dataclass
8166
class LangchainLLMWrapper(BaseRagasLLM):

src/ragas/metrics/_answer_similarity.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from ragas.metrics.base import EvaluationMode, MetricWithEmbeddings, MetricWithLLM
1111

1212
if t.TYPE_CHECKING:
13-
from langchain.callbacks.base import Callbacks
13+
from langchain_core.callbacks.base import Callbacks
1414

1515

1616
logger = logging.getLogger(__name__)

src/ragas/metrics/_context_relevancy.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from ragas.metrics.base import EvaluationMode, MetricWithLLM
1212

1313
if t.TYPE_CHECKING:
14-
from langchain.callbacks.base import Callbacks
14+
from langchain_core.callbacks.base import Callbacks
1515

1616
logger = logging.getLogger(__name__)
1717

src/ragas/metrics/critique.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
from ragas.metrics.base import EvaluationMode, MetricWithLLM
1313

1414
if t.TYPE_CHECKING:
15-
from langchain.callbacks.base import Callbacks
15+
from langchain_core.callbacks.base import Callbacks
1616

1717
from ragas.llms import BaseRagasLLM
1818

src/ragas/testset/__init__.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
1-
from ragas.testset.testset_generator import TestsetGenerator
1+
from ragas.testset.generator import TestsetGenerator
22

33
__all__ = ["TestsetGenerator"]

src/ragas/testset/docstore.py

+17 -10
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,8 @@
1414
from langchain_core.pydantic_v1 import Field
1515
from llama_index.readers.schema import Document as LlamaindexDocument
1616

17-
from ragas.async_utils import run_async_tasks
1817
from ragas.embeddings.base import BaseRagasEmbeddings
18+
from ragas.executor import Executor
1919

2020
Embedding = t.Union[t.List[float], npt.NDArray[np.float64]]
2121
logger = logging.getLogger(__name__)
@@ -204,22 +204,29 @@ def add_nodes(
204204
assert self.embeddings is not None, "Embeddings must be set"
205205

206206
# NOTE: Adds everything in async mode for now.
207-
embed_tasks = []
208-
docs_to_embed = []
207+
nodes_to_embed = []
209208
# get embeddings for the docs
210-
for n in nodes:
209+
executor = Executor(
210+
desc="embedding nodes",
211+
keep_progress_bar=False,
212+
is_async=True,
213+
raise_exceptions=True,
214+
)
215+
for i, n in enumerate(nodes):
211216
if n.embedding is None:
212-
embed_tasks.append(self.embeddings.aembed_query(n.page_content))
213-
docs_to_embed.append(n)
217+
nodes_to_embed.append(n)
218+
executor.submit(
219+
self.embeddings.aembed_query,
220+
n.page_content,
221+
name=f"embed_node_task[{i}]",
222+
)
214223
else:
215224
self.nodes.append(n)
216225
self.node_map[n.doc_id] = n
217226
self.node_embeddings_list.append(n.embedding)
218227

219-
embeddings = run_async_tasks(
220-
embed_tasks, show_progress=show_progress, progress_bar_desc=desc
221-
)
222-
for n, embedding in zip(docs_to_embed, embeddings):
228+
embeddings = executor.results()
229+
for n, embedding in zip(nodes_to_embed, embeddings):
223230
n.embedding = embedding
224231
self.nodes.append(n)
225232
self.node_map[n.doc_id] = n

src/ragas/testset/evolutions.py

+41 -32
Original file line number | Diff line number | Diff line change
@@ -3,12 +3,14 @@
33
import logging
44
import typing as t
55
from abc import abstractmethod
6-
from collections import namedtuple
76
from dataclasses import dataclass, field
87

98
from fsspec.exceptions import asyncio
9+
from langchain_core.pydantic_v1 import BaseModel
1010
from numpy.random import default_rng
1111

12+
from ragas.llms import BaseRagasLLM
13+
from ragas.llms.prompt import Prompt
1214
from ragas.testset.docstore import Direction, DocumentStore, Node
1315
from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter
1416
from ragas.testset.prompts import (
@@ -22,32 +24,27 @@
2224
rng = default_rng()
2325
logger = logging.getLogger(__name__)
2426

25-
if t.TYPE_CHECKING:
26-
from ragas.llms import BaseRagasLLM
27-
from ragas.llms.prompt import Prompt
28-
2927

3028
@dataclass
3129
class CurrentNodes:
3230
root_node: Node
3331
nodes: t.List[Node] = field(default_factory=list)
3432

3533

36-
DataRow = namedtuple(
37-
"DataRow",
38-
[
39-
"question",
40-
"context",
41-
"answer",
42-
"question_type",
43-
"evolution_elimination",
44-
],
45-
)
34+
# (question, current_nodes, evolution_type)
35+
EvolutionOutput = t.Tuple[str, CurrentNodes, str]
36+
37+
38+
class DataRow(BaseModel):
39+
question: str
40+
context: str
41+
answer: str
42+
evolution_type: str
4643

4744

4845
@dataclass
4946
class Evolution:
50-
generator_llm: t.Optional[BaseRagasLLM] = None
47+
generator_llm: BaseRagasLLM = t.cast(BaseRagasLLM, None)
5148
docstore: t.Optional[DocumentStore] = None
5249
node_filter: t.Optional[NodeFilter] = None
5350
question_filter: t.Optional[QuestionFilter] = None
@@ -61,7 +58,7 @@ def merge_nodes(nodes: CurrentNodes) -> Node:
6158

6259
async def aretry_evolve(
6360
self, current_tries: int, current_nodes: CurrentNodes, update_count: bool = True
64-
) -> str:
61+
) -> EvolutionOutput:
6562
if update_count:
6663
current_tries += 1
6764
logger.info("retrying evolution: %s times", current_tries)
@@ -112,22 +109,29 @@ def evolve(self, current_nodes: CurrentNodes) -> DataRow:
112109
async def aevolve(self, current_nodes: CurrentNodes) -> DataRow:
113110
# init tries with 0 when first called
114111
current_tries = 0
115-
evolved_question = await self._aevolve(current_tries, current_nodes)
112+
(
113+
evolved_question,
114+
current_nodes,
115+
evolution_type,
116+
) = await self._aevolve(current_tries, current_nodes)
117+
116118
return self.generate_datarow(
117119
question=evolved_question,
118120
current_nodes=current_nodes,
121+
evolution_type=evolution_type,
119122
)
120123

121124
@abstractmethod
122-
async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
125+
async def _aevolve(
126+
self, current_tries: int, current_nodes: CurrentNodes
127+
) -> EvolutionOutput:
123128
...
124129

125130
def generate_datarow(
126131
self,
127132
question: str,
128133
current_nodes: CurrentNodes,
129-
question_type: str = "",
130-
evolution_elimination: bool = False,
134+
evolution_type: str,
131135
):
132136
assert self.generator_llm is not None, "generator_llm cannot be None"
133137

@@ -146,15 +150,16 @@ def generate_datarow(
146150
return DataRow(
147151
question=question,
148152
context=merged_nodes.page_content,
149-
answer=answer,
150-
question_type=question_type,
151-
evolution_elimination=evolution_elimination,
153+
answer="" if answer is None else answer,
154+
evolution_type=evolution_type,
152155
)
153156

154157

155158
@dataclass
156159
class SimpleEvolution(Evolution):
157-
async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
160+
async def _aevolve(
161+
self, current_tries: int, current_nodes: CurrentNodes
162+
) -> EvolutionOutput:
158163
assert self.docstore is not None, "docstore cannot be None"
159164
assert self.node_filter is not None, "node filter cannot be None"
160165
assert self.generator_llm is not None, "generator_llm cannot be None"
@@ -183,7 +188,7 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
183188
return await self.aretry_evolve(current_tries, current_nodes)
184189
else:
185190
# if valid question
186-
return seed_question
191+
return seed_question, current_nodes, "simple"
187192

188193
def __hash__(self):
189194
return hash(self.__class__.__name__)
@@ -209,13 +214,15 @@ def init_evolution(self):
209214

210215
@dataclass
211216
class MultiContextEvolution(ComplexEvolution):
212-
async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
217+
async def _aevolve(
218+
self, current_tries: int, current_nodes: CurrentNodes
219+
) -> EvolutionOutput:
213220
assert self.docstore is not None, "docstore cannot be None"
214221
assert self.generator_llm is not None, "generator_llm cannot be None"
215222
assert self.question_filter is not None, "question_filter cannot be None"
216223
assert self.se is not None, "simple evolution cannot be None"
217224

218-
simple_question = await self.se._aevolve(current_tries, current_nodes)
225+
simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
219226
logger.debug(
220227
"[MultiContextEvolution] simple question generated: %s", simple_question
221228
)
@@ -254,20 +261,22 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
254261
current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
255262
return await self.aretry_evolve(current_tries, current_nodes)
256263

257-
return compressed_question
264+
return compressed_question, current_nodes, "multi_context"
258265

259266
def __hash__(self):
260267
return hash(self.__class__.__name__)
261268

262269

263270
@dataclass
264271
class ReasoningEvolution(ComplexEvolution):
265-
async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
272+
async def _aevolve(
273+
self, current_tries: int, current_nodes: CurrentNodes
274+
) -> EvolutionOutput:
266275
assert self.generator_llm is not None, "generator_llm cannot be None"
267276
assert self.question_filter is not None, "question_filter cannot be None"
268277
assert self.se is not None, "simple evolution cannot be None"
269278

270-
simple_question = await self.se._aevolve(current_tries, current_nodes)
279+
simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
271280
logger.debug(
272281
"[ReasoningEvolution] simple question generated: %s", simple_question
273282
)
@@ -304,7 +313,7 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
304313
)
305314
return await self.aretry_evolve(current_tries, current_nodes)
306315

307-
return reasoning_question
316+
return reasoning_question, current_nodes, "reasoning"
308317

309318
def __hash__(self):
310319
return hash(self.__class__.__name__)

0 commit comments

Comments
 (0)