
Commit 2ededad

Merge pull request #70 from asamal4/common-llm-call
add common custom llm
2 parents 0c80be6 + 0871015 commit 2ededad

9 files changed: +152 -92 lines changed

src/lightspeed_evaluation/core/llm/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,6 @@
 """LLM management for Evaluation Framework."""
 
+from lightspeed_evaluation.core.llm.custom import BaseCustomLLM
 from lightspeed_evaluation.core.llm.deepeval import DeepEvalLLMManager
 from lightspeed_evaluation.core.llm.manager import LLMManager
 from lightspeed_evaluation.core.llm.ragas import RagasLLMManager
@@ -11,6 +12,7 @@
     "LLMConfig",
     "LLMError",
     "LLMManager",
+    "BaseCustomLLM",
     "DeepEvalLLMManager",
     "RagasLLMManager",
     "validate_provider_env",
src/lightspeed_evaluation/core/llm/custom.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+"""Base Custom LLM class for evaluation framework."""
+
+from typing import Any, Optional, Union
+
+import litellm
+
+from lightspeed_evaluation.core.system.exceptions import LLMError
+
+
+class BaseCustomLLM:  # pylint: disable=too-few-public-methods
+    """Base LLM class with core calling functionality."""
+
+    def __init__(self, model_name: str, llm_params: dict[str, Any]):
+        """Initialize with model configuration."""
+        self.model_name = model_name
+        self.llm_params = llm_params
+
+    def call(
+        self,
+        prompt: str,
+        n: int = 1,
+        temperature: Optional[float] = None,
+        return_single: bool = True,
+        **kwargs: Any,
+    ) -> Union[str, list[str]]:
+        """Make LLM call and return response(s).
+
+        Args:
+            prompt: Text prompt to send
+            n: Number of responses to generate (default 1)
+            temperature: Override temperature (uses config default if None)
+            return_single: If True and n=1, return single string. If False, always return list.
+            **kwargs: Additional LLM parameters
+
+        Returns:
+            Single string if return_single=True and n=1, otherwise list of strings
+        """
+        temp = (
+            temperature
+            if temperature is not None
+            else self.llm_params.get("temperature", 0.0)
+        )
+
+        call_params = {
+            "model": self.model_name,
+            "messages": [{"role": "user", "content": prompt}],
+            "temperature": temp,
+            "n": n,
+            "max_tokens": self.llm_params.get("max_tokens"),
+            "timeout": self.llm_params.get("timeout"),
+            "num_retries": self.llm_params.get("num_retries", 3),
+            **kwargs,
+        }
+
+        try:
+            response = litellm.completion(**call_params)
+
+            # Extract content from all choices
+            results = []
+            for choice in response.choices:  # type: ignore
+                content = choice.message.content  # type: ignore
+                if content is None:
+                    content = ""
+                results.append(content.strip())
+
+            # Return format based on parameters
+            if return_single and n == 1:
+                if not results:
+                    raise LLMError("LLM returned empty response")
+                return results[0]
+
+            return results
+
+        except Exception as e:
+            raise LLMError(f"LLM call failed: {str(e)}") from e
src/lightspeed_evaluation/core/llm/deepeval.py

Lines changed: 13 additions & 13 deletions
@@ -1,4 +1,4 @@
-"""DeepEval LLM Manager - DeepEval-specific LLM wrapper that takes LiteLLM parameters."""
+"""DeepEval LLM Manager - DeepEval-specific LLM wrapper."""
 
 from typing import Any
 
@@ -11,32 +11,32 @@ class DeepEvalLLMManager:
     This manager focuses solely on DeepEval-specific LLM integration.
     """
 
-    def __init__(self, model_name: str, litellm_params: dict[str, Any]):
+    def __init__(self, model_name: str, llm_params: dict[str, Any]):
         """Initialize with LLM parameters from LLMManager."""
         self.model_name = model_name
-        self.litellm_params = litellm_params
+        self.llm_params = llm_params
 
-        # Create DeepEval's LiteLLMModel with provided parameters
+        # Create DeepEval's LLM model with provided parameters
         self.llm_model = LiteLLMModel(
             model=self.model_name,
-            temperature=litellm_params.get("temperature", 0.0),
-            max_tokens=litellm_params.get("max_tokens"),
-            timeout=litellm_params.get("timeout"),
-            num_retries=litellm_params.get("num_retries", 3),
+            temperature=llm_params.get("temperature", 0.0),
+            max_tokens=llm_params.get("max_tokens"),
+            timeout=llm_params.get("timeout"),
+            num_retries=llm_params.get("num_retries", 3),
         )
 
         print(f"✅ DeepEval LLM Manager: {self.model_name}")
 
     def get_llm(self) -> LiteLLMModel:
-        """Get the configured DeepEval LiteLLM model."""
+        """Get the configured DeepEval LLM model."""
        return self.llm_model
 
     def get_model_info(self) -> dict[str, Any]:
         """Get information about the configured model."""
         return {
             "model_name": self.model_name,
-            "temperature": self.litellm_params.get("temperature", 0.0),
-            "max_tokens": self.litellm_params.get("max_tokens"),
-            "timeout": self.litellm_params.get("timeout"),
-            "num_retries": self.litellm_params.get("num_retries", 3),
+            "temperature": self.llm_params.get("temperature", 0.0),
+            "max_tokens": self.llm_params.get("max_tokens"),
+            "timeout": self.llm_params.get("timeout"),
+            "num_retries": self.llm_params.get("num_retries", 3),
         }
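
A construction sketch for the renamed parameter; the model id and parameter dict are illustrative assumptions:

# Hypothetical sketch -- values are assumptions.
manager = DeepEvalLLMManager(
    "openai/gpt-4o-mini",
    {"temperature": 0.0, "max_tokens": 512, "timeout": 60, "num_retries": 3},
)
eval_model = manager.get_llm()  # LiteLLMModel instance, ready for DeepEval metrics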

src/lightspeed_evaluation/core/llm/manager.py

Lines changed: 5 additions & 5 deletions
@@ -13,7 +13,7 @@ class LLMManager:
     Responsibilities:
     - Environment validation for multiple providers
     - Model name construction
-    - Provides LiteLLM parameters for consumption by framework-specific managers
+    - Provides LLM parameters for consumption by framework-specific managers
     """
 
     def __init__(self, config: LLMConfig):
@@ -25,7 +25,7 @@ def __init__(self, config: LLMConfig):
         )
 
     def _construct_model_name_and_validate(self) -> str:
-        """Construct model name for LiteLLM and validate required environment variables."""
+        """Construct model name and validate required environment variables."""
         provider = self.config.provider.lower()
 
         # Provider-specific validation and model name construction
@@ -89,11 +89,11 @@ def _handle_ollama_provider(self) -> str:
         return f"ollama/{self.config.model}"
 
     def get_model_name(self) -> str:
-        """Get the constructed LiteLLM model name."""
+        """Get the constructed model name."""
         return self.model_name
 
-    def get_litellm_params(self) -> dict[str, Any]:
-        """Get parameters for LiteLLM completion calls."""
+    def get_llm_params(self) -> dict[str, Any]:
+        """Get parameters for LLM completion calls."""
         return {
             "model": self.model_name,
             "temperature": self.config.temperature,

src/lightspeed_evaluation/core/llm/ragas.py

Lines changed: 26 additions & 29 deletions
@@ -1,20 +1,21 @@
-"""Ragas LLM Manager - Ragas-specific LLM wrapper that takes LiteLLM parameters."""
+"""Ragas LLM Manager - Ragas-specific LLM wrapper."""
 
 from typing import Any, Optional
 
-import litellm
 from ragas.llms.base import BaseRagasLLM, Generation, LLMResult
 from ragas.metrics import answer_relevancy, faithfulness
 
+from lightspeed_evaluation.core.llm.custom import BaseCustomLLM
+from lightspeed_evaluation.core.system.exceptions import LLMError
+
 
-class RagasCustomLLM(BaseRagasLLM):
-    """Custom LLM for Ragas using LiteLLM parameters."""
+class RagasCustomLLM(BaseRagasLLM, BaseCustomLLM):
+    """Custom LLM for Ragas."""
 
-    def __init__(self, model_name: str, litellm_params: dict[str, Any]):
-        """Initialize Ragas custom LLM with model name and LiteLLM parameters."""
-        super().__init__()
-        self.model_name = model_name
-        self.litellm_params = litellm_params
+    def __init__(self, model_name: str, llm_params: dict[str, Any]):
+        """Initialize Ragas custom LLM with model name and LLM parameters."""
+        BaseRagasLLM.__init__(self)
+        BaseCustomLLM.__init__(self, model_name, llm_params)
         print(f"✅ Ragas Custom LLM: {self.model_name}")
 
     def generate_text(  # pylint: disable=too-many-arguments,too-many-positional-arguments
@@ -25,42 +26,38 @@ def generate_text(  # pylint: disable=too-many-arguments,too-many-positional-arguments
         stop: Optional[list[str]] = None,
         callbacks: Optional[Any] = None,
     ) -> LLMResult:
-        """Generate text using LiteLLM with provided parameters."""
+        """Generate text using LLM with provided parameters."""
         prompt_text = str(prompt)
 
         # Use temperature from params unless explicitly overridden
         temp = (
             temperature
             if temperature != 1e-08
-            else self.litellm_params.get("temperature", 0.0)
+            else self.llm_params.get("temperature", 0.0)
         )
 
         try:
-            response = litellm.completion(
-                model=self.model_name,
-                messages=[{"role": "user", "content": prompt_text}],
-                n=n,
-                temperature=temp,
-                max_tokens=self.litellm_params.get("max_tokens"),
-                timeout=self.litellm_params.get("timeout"),
-                num_retries=self.litellm_params.get("num_retries"),
+            # Use inherited BaseCustomLLM functionality
+            call_kwargs = {}
+            if stop is not None:
+                call_kwargs["stop"] = stop
+
+            responses = self.call(
+                prompt_text, n=n, temperature=temp, return_single=False, **call_kwargs
             )
 
             # Convert to Ragas format
             generations = []
-            for choice in response.choices:  # type: ignore
-                content = choice.message.content  # type: ignore
-                if content is None:
-                    content = ""
-                gen = Generation(text=content.strip())
+            for response_text in responses:
+                gen = Generation(text=response_text)
                 generations.append(gen)
 
             result = LLMResult(generations=[generations])
             return result
 
         except Exception as e:
             print(f"❌ Ragas LLM failed: {e}")
-            raise RuntimeError(f"Ragas LLM evaluation failed: {str(e)}") from e
+            raise LLMError(f"Ragas LLM evaluation failed: {str(e)}") from e
 
     async def agenerate_text(  # pylint: disable=too-many-arguments,too-many-positional-arguments
         self,
@@ -87,11 +84,11 @@ class RagasLLMManager:
     This manager focuses solely on Ragas-specific LLM integration.
     """
 
-    def __init__(self, model_name: str, litellm_params: dict[str, Any]):
+    def __init__(self, model_name: str, llm_params: dict[str, Any]):
         """Initialize with LLM parameters from LLMManager."""
         self.model_name = model_name
-        self.litellm_params = litellm_params
-        self.custom_llm = RagasCustomLLM(model_name, litellm_params)
+        self.llm_params = llm_params
+        self.custom_llm = RagasCustomLLM(model_name, llm_params)
 
         # Configure Ragas metrics to use our custom LLM
         answer_relevancy.llm = self.custom_llm
@@ -107,5 +104,5 @@ def get_model_info(self) -> dict[str, Any]:
         """Get information about the configured model."""
         return {
             "model_name": self.model_name,
-            "temperature": self.litellm_params.get("temperature", 0.0),
+            "temperature": self.llm_params.get("temperature", 0.0),
         }
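
One design note on the RagasCustomLLM change: with two bases whose constructors take different arguments, the PR calls each base __init__ explicitly rather than relying on cooperative super(). A generic sketch of that pattern (illustrative names, not repo code):

# Generic sketch of the explicit-init pattern -- names are illustrative.
class NoArgBase:
    def __init__(self) -> None:
        self.ready = True

class ParamBase:
    def __init__(self, name: str) -> None:
        self.name = name

class Child(NoArgBase, ParamBase):
    def __init__(self, name: str) -> None:
        NoArgBase.__init__(self)        # each base initialized explicitly,
        ParamBase.__init__(self, name)  # so mismatched signatures never collide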

src/lightspeed_evaluation/core/metrics/custom.py

Lines changed: 26 additions & 40 deletions
@@ -3,12 +3,13 @@
 import re
 from typing import Any, Optional
 
-import litellm
 from pydantic import BaseModel, Field
 
+from lightspeed_evaluation.core.llm.custom import BaseCustomLLM
 from lightspeed_evaluation.core.llm.manager import LLMManager
 from lightspeed_evaluation.core.metrics.tool_eval import evaluate_tool_calls
 from lightspeed_evaluation.core.models import EvaluationScope, TurnData
+from lightspeed_evaluation.core.system.exceptions import LLMError
 
 
 class EvaluationPromptParams(BaseModel):
@@ -27,23 +28,24 @@
 
 
 class CustomMetrics:  # pylint: disable=too-few-public-methods
-    """Handles custom metrics using LLMManager for direct LiteLLM calls."""
+    """Handles custom metrics using LLMManager for direct LLM calls."""
 
     def __init__(self, llm_manager: LLMManager):
         """Initialize with LLM Manager.
 
         Args:
             llm_manager: Pre-configured LLMManager with validated parameters
         """
-        self.model_name = llm_manager.get_model_name()
-        self.litellm_params = llm_manager.get_litellm_params()
+        self.llm = BaseCustomLLM(
+            llm_manager.get_model_name(), llm_manager.get_llm_params()
+        )
 
         self.supported_metrics = {
             "answer_correctness": self._evaluate_answer_correctness,
             "tool_eval": self._evaluate_tool_calls,
         }
 
-        print(f"✅ Custom Metrics initialized: {self.model_name}")
+        print(f"✅ Custom Metrics initialized: {self.llm.model_name}")
 
     def evaluate(
         self,
@@ -62,31 +64,12 @@ def evaluate(
         except (ValueError, AttributeError, KeyError) as e:
             return None, f"Custom {metric_name} evaluation failed: {str(e)}"
 
-    def _call_llm(self, prompt: str, system_prompt: Optional[str] = None) -> str:
-        """Make a LiteLLM call with the configured parameters."""
-        # Prepare messages
-        messages = []
-        if system_prompt:
-            messages.append({"role": "system", "content": system_prompt})
-        messages.append({"role": "user", "content": prompt})
-
-        try:
-            response = litellm.completion(
-                model=self.model_name,
-                messages=messages,
-                temperature=self.litellm_params.get("temperature", 0.0),
-                max_tokens=self.litellm_params.get("max_tokens"),
-                timeout=self.litellm_params.get("timeout"),
-                num_retries=self.litellm_params.get("num_retries", 3),
-            )
-
-            content = response.choices[0].message.content  # type: ignore
-            if content is None:
-                raise RuntimeError("LLM returned empty response")
-            return content.strip()
-
-        except Exception as e:
-            raise RuntimeError(f"LiteLLM call failed: {str(e)}") from e
+    def _call_llm(self, prompt: str) -> str:
+        """Make an LLM call with the configured parameters."""
+        result = self.llm.call(prompt, return_single=True)
+        if isinstance(result, list):
+            return result[0] if result else ""
+        return result
 
 
     def _parse_score_response(self, response: str) -> tuple[Optional[float], str]:
@@ -232,16 +215,19 @@ def _evaluate_answer_correctness(
         prompt += "- Absence of contradictory information"
 
         # Make LLM call and parse response
-        llm_response = self._call_llm(prompt)
-        score, reason = self._parse_score_response(llm_response)
-
-        if score is None:
-            return (
-                None,
-                f"Could not parse score from LLM response: {llm_response[:100]}...",
-            )
-
-        return score, f"Custom answer correctness: {score:.2f} - {reason}"
+        try:
+            llm_response = self._call_llm(prompt)
+            score, reason = self._parse_score_response(llm_response)
+
+            if score is None:
+                return (
+                    None,
+                    f"Could not parse score from LLM response: {llm_response[:100]}...",
+                )
+
+            return score, f"Custom answer correctness: {score:.2f} - {reason}"
+        except LLMError as e:
+            return None, f"Answer correctness evaluation failed: {str(e)}"
 
     def _evaluate_tool_calls(
         self,
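
A closing note on _call_llm: BaseCustomLLM.call is annotated Union[str, list[str]], so the isinstance branch exists to narrow the type for static checkers, even though return_single=True with the default n=1 yields a plain string at runtime. A minimal sketch of the same narrowing:

from typing import Union

def first_text(result: Union[str, list[str]]) -> str:
    # Narrow the Union for type checkers; mirrors _call_llm above.
    if isinstance(result, list):
        return result[0] if result else ""
    return result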
