Skip to content

Commit a1c716c

Browse files
committed
mark remaining metrics as error
1 parent 6785231 commit a1c716c

File tree

3 files changed

+163
-8
lines changed

3 files changed

+163
-8
lines changed

src/lightspeed_evaluation/pipeline/evaluation/errors.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,74 @@ def mark_turn_metrics_as_error( # pylint: disable=too-many-arguments,too-many-p
124124
self.results.extend(error_results)
125125
return error_results
126126

127+
def mark_remaining_turns_and_conversation_as_error(  # pylint: disable=too-many-arguments,too-many-positional-arguments
    self,
    conv_data: EvaluationData,
    failed_turn_idx: int,
    resolved_turn_metrics: list[list[str]],
    resolved_conversation_metrics: list[str],
    error_reason: str,
) -> list[EvaluationResult]:
    """Emit ERROR results for every turn after a failed one, plus conversation metrics.

    The turn at ``failed_turn_idx`` itself is not covered here; only turns at
    index ``failed_turn_idx + 1`` and later receive results, followed by one
    result per conversation-level metric.

    Args:
        conv_data: Conversation whose later turns are being marked.
        failed_turn_idx: Index of the turn that failed.
        resolved_turn_metrics: Metric identifiers per turn, indexed like
            ``conv_data.turns``.
        resolved_conversation_metrics: Conversation-level metric identifiers.
        error_reason: Text stored as the ``reason`` of every generated result.

    Returns:
        list[EvaluationResult]: Turn-level ERROR results in turn order,
        followed by the conversation-level ERROR results.
    """
    first_marked_idx = failed_turn_idx + 1
    group_id = conv_data.conversation_group_id

    logger.warning(
        "Marking remaining turns (%d onwards) and conversation metrics as ERROR for %s: %s",
        first_marked_idx,
        group_id,
        error_reason,
    )

    # Fields that are identical on every generated ERROR result.
    shared_fields = {
        "conversation_group_id": group_id,
        "result": "ERROR",
        "score": None,
        "threshold": None,
        "reason": error_reason,
        "response": "",
        "execution_time": 0.0,
    }

    cascaded = []

    # Turn-level results: one per resolved metric of each later turn.
    for idx in range(first_marked_idx, len(conv_data.turns)):
        later_turn = conv_data.turns[idx]
        cascaded.extend(
            EvaluationResult(
                turn_id=later_turn.turn_id,
                metric_identifier=metric_name,
                query=later_turn.query,
                **shared_fields,
            )
            for metric_name in resolved_turn_metrics[idx]
        )

    # Conversation-level results carry no turn id and an empty query.
    cascaded.extend(
        EvaluationResult(
            turn_id=None,
            metric_identifier=metric_name,
            query="",
            **shared_fields,
        )
        for metric_name in resolved_conversation_metrics
    )

    # Keep the handler's own record in sync so the error summary counts these.
    self.results.extend(cascaded)
    return cascaded
194+
127195
def get_error_summary(self) -> dict[str, int]:
128196
"""Get summary of error results collected."""
129197
return {

src/lightspeed_evaluation/pipeline/evaluation/processor.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,7 @@ def process_conversation( # pylint: disable=too-many-locals
101101
):
102102
# Step 2a: Amend with API data if enabled (per turn)
103103
if self.config.api.enabled:
104-
logger.debug(
105-
"Processing turn %d: %s", turn_idx, turn_data.turn_id
106-
)
104+
logger.debug("Processing turn %d: %s", turn_idx, turn_data.turn_id)
107105
api_error_message, conversation_id = (
108106
self.components.api_amender.amend_single_turn(
109107
turn_data, conversation_id
@@ -115,13 +113,15 @@ def process_conversation( # pylint: disable=too-many-locals
115113
turn_data.turn_id,
116114
)
117115

118-
# If API error occurred for this turn, mark its metrics as ERROR
116+
# If API error occurred, mark current turn + remaining + conversation as ERROR
119117
if api_error_message:
120118
logger.error(
121-
"API error for turn %d - marking turn metrics as ERROR",
119+
"API error for turn %d - marking current turn, "
120+
"remaining turns, and conversation as ERROR",
122121
turn_idx,
123122
)
124-
error_results = (
123+
# Mark current turn as ERROR
124+
current_turn_errors = (
125125
self.components.error_handler.mark_turn_metrics_as_error(
126126
conv_data,
127127
turn_idx,
@@ -130,8 +130,25 @@ def process_conversation( # pylint: disable=too-many-locals
130130
api_error_message,
131131
)
132132
)
133-
results.extend(error_results)
134-
continue # Skip to next turn
133+
results.extend(current_turn_errors)
134+
135+
# Mark remaining turns and conversation metrics as ERROR
136+
cascade_error_reason = (
137+
f"Cascade failure from turn {turn_idx + 1} API error: "
138+
f"{api_error_message}"
139+
)
140+
error_handler = self.components.error_handler
141+
remaining_errors = error_handler.mark_remaining_turns_and_conversation_as_error(
142+
conv_data,
143+
turn_idx,
144+
resolved_turn_metrics,
145+
resolved_conversation_metrics,
146+
cascade_error_reason,
147+
)
148+
results.extend(remaining_errors)
149+
150+
# Stop processing - API failure cascades to all remaining
151+
return results
135152

136153
# Step 2b: Process turn-level metrics for this turn
137154
if turn_metrics:

tests/unit/pipeline/evaluation/test_errors.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,3 +222,73 @@ def test_mark_turn_metrics_as_error(self):
222222
assert summary["total_errors"] == 2
223223
assert summary["turn_errors"] == 2
224224
assert summary["conversation_errors"] == 0
225+
226+
def test_mark_remaining_turns_and_conversation_as_error(self):
    """Cascade an API failure on the first turn to later turns and conversation metrics."""
    # Three-turn conversation; the failure is simulated on the first turn (index 0).
    conversation = EvaluationData(
        conversation_group_id="test_conv",
        turns=[
            TurnData(turn_id="turn1", query="Query 1", response="Response 1"),
            TurnData(turn_id="turn2", query="Query 2", response="Response 2"),
            TurnData(turn_id="turn3", query="Query 3", response="Response 3"),
        ],
    )
    per_turn_metrics = [
        ["ragas:faithfulness"],  # turn1
        ["custom:answer_correctness"],  # turn2
        ["ragas:response_relevancy"],  # turn3
    ]
    conversation_metrics = [
        "deepeval:conversation_completeness",
        "deepeval:conversation_relevancy",
    ]
    cascade_reason = "Cascade failure from turn 1 API error: Connection timeout"

    error_handler = EvaluationErrorHandler()
    results = error_handler.mark_remaining_turns_and_conversation_as_error(
        conversation,
        0,
        per_turn_metrics,
        conversation_metrics,
        cascade_reason,
    )

    # turn2 (1 metric) + turn3 (1 metric) + conversation (2 metrics) = 4 results
    assert len(results) == 4
    turn2_error, turn3_error, first_conv_error, second_conv_error = results

    # Turn-level results come first, in turn order.
    assert turn2_error.conversation_group_id == "test_conv"
    assert turn2_error.turn_id == "turn2"
    assert turn2_error.metric_identifier == "custom:answer_correctness"
    assert turn2_error.result == "ERROR"
    assert turn2_error.reason == cascade_reason

    assert turn3_error.turn_id == "turn3"
    assert turn3_error.metric_identifier == "ragas:response_relevancy"
    assert turn3_error.result == "ERROR"

    # Conversation-level results follow and carry no turn id.
    assert first_conv_error.turn_id is None
    assert first_conv_error.metric_identifier == "deepeval:conversation_completeness"
    assert first_conv_error.result == "ERROR"

    assert second_conv_error.turn_id is None
    assert second_conv_error.metric_identifier == "deepeval:conversation_relevancy"
    assert second_conv_error.result == "ERROR"

    # The handler's own summary must reflect the same split.
    summary = error_handler.get_error_summary()
    assert summary["total_errors"] == 4
    assert summary["turn_errors"] == 2  # turn2 + turn3
    assert summary["conversation_errors"] == 2

0 commit comments

Comments
 (0)