
Commit 8e3513c

use common custom llm call
1 parent 11874b9 commit 8e3513c


2 files changed: 41 additions & 58 deletions


src/lightspeed_evaluation/core/llm/ragas.py

Lines changed: 16 additions & 19 deletions
@@ -2,19 +2,20 @@
 
 from typing import Any, Optional
 
-import litellm
 from ragas.llms.base import BaseRagasLLM, Generation, LLMResult
 from ragas.metrics import answer_relevancy, faithfulness
 
+from lightspeed_evaluation.core.llm.custom import BaseCustomLLM
+from lightspeed_evaluation.core.system.exceptions import LLMError
 
-class RagasCustomLLM(BaseRagasLLM):
+
+class RagasCustomLLM(BaseRagasLLM, BaseCustomLLM):
     """Custom LLM for Ragas using LiteLLM parameters."""
 
     def __init__(self, model_name: str, litellm_params: dict[str, Any]):
         """Initialize Ragas custom LLM with model name and LiteLLM parameters."""
-        super().__init__()
-        self.model_name = model_name
-        self.litellm_params = litellm_params
+        BaseRagasLLM.__init__(self)
+        BaseCustomLLM.__init__(self, model_name, litellm_params)
         print(f"✅ Ragas Custom LLM: {self.model_name}")
 
     def generate_text(  # pylint: disable=too-many-arguments,too-many-positional-arguments
@@ -36,31 +37,27 @@ def generate_text(  # pylint: disable=too-many-arguments,too-many-positional-arg
         )
 
         try:
-            response = litellm.completion(
-                model=self.model_name,
-                messages=[{"role": "user", "content": prompt_text}],
-                n=n,
-                temperature=temp,
-                max_tokens=self.litellm_params.get("max_tokens"),
-                timeout=self.litellm_params.get("timeout"),
-                num_retries=self.litellm_params.get("num_retries"),
+            # Use inherited BaseCustomLLM functionality
+            call_kwargs = {}
+            if stop is not None:
+                call_kwargs["stop"] = stop
+
+            responses = self.call(
+                prompt_text, n=n, temperature=temp, return_single=False, **call_kwargs
             )
 
             # Convert to Ragas format
             generations = []
-            for choice in response.choices:  # type: ignore
-                content = choice.message.content  # type: ignore
-                if content is None:
-                    content = ""
-                gen = Generation(text=content.strip())
+            for response_text in responses:
+                gen = Generation(text=response_text)
                 generations.append(gen)
 
             result = LLMResult(generations=[generations])
             return result
 
         except Exception as e:
             print(f"❌ Ragas LLM failed: {e}")
-            raise RuntimeError(f"Ragas LLM evaluation failed: {str(e)}") from e
+            raise LLMError(f"Ragas LLM evaluation failed: {str(e)}") from e
 
     async def agenerate_text(  # pylint: disable=too-many-arguments,too-many-positional-arguments
         self,
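
Note: the shared BaseCustomLLM that both files now delegate to lives in lightspeed_evaluation/core/llm/custom.py and is not shown in this commit. Inferring only from how it is used here (a constructor taking model_name and litellm_params, a call() method accepting n, temperature, return_single, and extra keyword arguments such as stop, and raising LLMError on failure), a minimal sketch of that interface might look roughly like the following; the body, parameter defaults, and error message are assumptions, not the repository's actual implementation.

# Assumed sketch of BaseCustomLLM -- not part of this diff; inferred from its call sites.
from typing import Any, Optional, Union

import litellm

from lightspeed_evaluation.core.system.exceptions import LLMError


class BaseCustomLLM:
    """Shared LiteLLM wrapper (assumed interface, not the real source)."""

    def __init__(self, model_name: str, litellm_params: dict[str, Any]):
        self.model_name = model_name
        self.litellm_params = litellm_params

    def call(
        self,
        prompt: str,
        n: int = 1,
        temperature: Optional[float] = None,
        return_single: bool = True,
        **kwargs: Any,
    ) -> Union[str, list[str]]:
        """Run one completion; return stripped text(s) or raise LLMError."""
        try:
            response = litellm.completion(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                n=n,
                temperature=(
                    temperature
                    if temperature is not None
                    else self.litellm_params.get("temperature", 0.0)
                ),
                max_tokens=self.litellm_params.get("max_tokens"),
                timeout=self.litellm_params.get("timeout"),
                num_retries=self.litellm_params.get("num_retries", 3),
                **kwargs,
            )
            texts = [(choice.message.content or "").strip() for choice in response.choices]
            return texts[0] if return_single else texts
        except Exception as e:
            raise LLMError(f"LLM call failed: {str(e)}") from e

With an interface like this, RagasCustomLLM only has to adapt call() results into Ragas Generation objects. The explicit BaseRagasLLM.__init__ / BaseCustomLLM.__init__ calls in the diff avoid relying on super() across two unrelated base classes that take different constructor arguments.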

src/lightspeed_evaluation/core/metrics/custom.py

Lines changed: 25 additions & 39 deletions
@@ -3,12 +3,13 @@
 import re
 from typing import Any, Optional
 
-import litellm
 from pydantic import BaseModel, Field
 
+from lightspeed_evaluation.core.llm.custom import BaseCustomLLM
 from lightspeed_evaluation.core.llm.manager import LLMManager
 from lightspeed_evaluation.core.metrics.tool_eval import evaluate_tool_calls
 from lightspeed_evaluation.core.models import EvaluationScope, TurnData
+from lightspeed_evaluation.core.system.exceptions import LLMError
 
 
 class EvaluationPromptParams(BaseModel):
@@ -35,15 +36,16 @@ def __init__(self, llm_manager: LLMManager):
         Args:
             llm_manager: Pre-configured LLMManager with validated parameters
         """
-        self.model_name = llm_manager.get_model_name()
-        self.litellm_params = llm_manager.get_litellm_params()
+        self.llm = BaseCustomLLM(
+            llm_manager.get_model_name(), llm_manager.get_litellm_params()
+        )
 
         self.supported_metrics = {
             "answer_correctness": self._evaluate_answer_correctness,
             "tool_eval": self._evaluate_tool_calls,
         }
 
-        print(f"✅ Custom Metrics initialized: {self.model_name}")
+        print(f"✅ Custom Metrics initialized: {self.llm.model_name}")
 
     def evaluate(
         self,
@@ -62,31 +64,12 @@ def evaluate(
         except (ValueError, AttributeError, KeyError) as e:
             return None, f"Custom {metric_name} evaluation failed: {str(e)}"
 
-    def _call_llm(self, prompt: str, system_prompt: Optional[str] = None) -> str:
-        """Make a LiteLLM call with the configured parameters."""
-        # Prepare messages
-        messages = []
-        if system_prompt:
-            messages.append({"role": "system", "content": system_prompt})
-        messages.append({"role": "user", "content": prompt})
-
-        try:
-            response = litellm.completion(
-                model=self.model_name,
-                messages=messages,
-                temperature=self.litellm_params.get("temperature", 0.0),
-                max_tokens=self.litellm_params.get("max_tokens"),
-                timeout=self.litellm_params.get("timeout"),
-                num_retries=self.litellm_params.get("num_retries", 3),
-            )
-
-            content = response.choices[0].message.content  # type: ignore
-            if content is None:
-                raise RuntimeError("LLM returned empty response")
-            return content.strip()
-
-        except Exception as e:
-            raise RuntimeError(f"LiteLLM call failed: {str(e)}") from e
+    def _call_llm(self, prompt: str) -> str:
+        """Make an LLM call with the configured parameters."""
+        result = self.llm.call(prompt, return_single=True)
+        if isinstance(result, list):
+            return result[0] if result else ""
+        return result
 
     def _parse_score_response(self, response: str) -> tuple[Optional[float], str]:
         r"""Parse LLM response to extract score and reason.
@@ -232,16 +215,19 @@ def _evaluate_answer_correctness(
         prompt += "- Absence of contradictory information"
 
         # Make LLM call and parse response
-        llm_response = self._call_llm(prompt)
-        score, reason = self._parse_score_response(llm_response)
-
-        if score is None:
-            return (
-                None,
-                f"Could not parse score from LLM response: {llm_response[:100]}...",
-            )
-
-        return score, f"Custom answer correctness: {score:.2f} - {reason}"
+        try:
+            llm_response = self._call_llm(prompt)
+            score, reason = self._parse_score_response(llm_response)
+
+            if score is None:
+                return (
+                    None,
+                    f"Could not parse score from LLM response: {llm_response[:100]}...",
+                )
+
+            return score, f"Custom answer correctness: {score:.2f} - {reason}"
+        except LLMError as e:
+            return None, f"Answer correctness evaluation failed: {str(e)}"
 
     def _evaluate_tool_calls(
         self,
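
For context, the behavioral change in _evaluate_answer_correctness is that a failed LLM call now surfaces as a (None, reason) result instead of escaping as an unhandled RuntimeError. A self-contained illustration of that error path, using stand-in names rather than the repository's modules, might look like this:

# Stand-in illustration of the new error path; LLMError here mimics
# lightspeed_evaluation.core.system.exceptions.LLMError.
from typing import Optional


class LLMError(Exception):
    """Stand-in for the project's LLMError (assumed to be raised on failed completions)."""


def call_llm(prompt: str) -> str:
    # Simulate a provider failure inside the shared LLM wrapper.
    raise LLMError("provider timed out")


def evaluate_answer_correctness(prompt: str) -> tuple[Optional[float], str]:
    try:
        response = call_llm(prompt)
        return 1.0, f"Custom answer correctness: 1.00 - {response}"
    except LLMError as e:
        return None, f"Answer correctness evaluation failed: {str(e)}"


print(evaluate_answer_correctness("Score this answer against the ground truth."))
# -> (None, 'Answer correctness evaluation failed: provider timed out')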
