Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 44 additions & 42 deletions src/lightspeed_evaluation/pipeline/evaluation/amender.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Optional

from lightspeed_evaluation.core.api import APIClient
from lightspeed_evaluation.core.models import EvaluationData
from lightspeed_evaluation.core.models import EvaluationData, TurnData
from lightspeed_evaluation.core.system.exceptions import APIError

logger = logging.getLogger(__name__)
Expand All @@ -17,54 +17,56 @@ def __init__(self, api_client: Optional[APIClient]):
"""Initialize with API client."""
self.api_client = api_client

def amend_single_turn(
    self, turn_data: TurnData, conversation_id: Optional[str] = None
) -> tuple[Optional[str], Optional[str]]:
    """Amend single turn data with API response.

    Args:
        turn_data: The turn data to amend (modified in-place on success).
        conversation_id: Optional conversation ID carried over from previous turns.

    Returns:
        tuple: (error_message, updated_conversation_id)
            - error_message: None if successful, error string if the API call failed
            - updated_conversation_id: The conversation ID to pass to the next turn
    """
    # No API client configured: nothing to amend, keep the caller's ID unchanged.
    if not self.api_client:
        return None, conversation_id

    logger.debug("Amending turn %s with API data", turn_data.turn_id)

    try:
        api_response = self.api_client.query(
            query=turn_data.query,
            conversation_id=conversation_id,
            attachments=turn_data.attachments,
        )

        # AMEND EVALUATION DATA: This modifies the loaded TurnData object in-place
        # Update response from API
        turn_data.response = api_response.response
        turn_data.conversation_id = api_response.conversation_id

        # Update contexts from API output
        if api_response.contexts:
            turn_data.contexts = api_response.contexts

        # Update tool calls from API output
        if api_response.tool_calls:
            logger.debug(
                "Tool calls provided: %d sequences",
                len(api_response.tool_calls),
            )
            turn_data.tool_calls = api_response.tool_calls

        logger.debug("Data amended for turn %s", turn_data.turn_id)
        # Propagate the server-assigned conversation ID to subsequent turns.
        return None, api_response.conversation_id

    except APIError as e:
        error_msg = f"API Error for turn {turn_data.turn_id}: {e}"
        logger.error(error_msg)
        # On failure, return the last known conversation ID so callers can decide
        # whether to cascade the failure or retry.
        return error_msg, conversation_id

def get_amendment_summary(self, conv_data: EvaluationData) -> dict[str, Any]:
"""Get summary of what would be amended for a conversation."""
Expand Down
118 changes: 117 additions & 1 deletion src/lightspeed_evaluation/pipeline/evaluation/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging

from lightspeed_evaluation.core.models import EvaluationData, EvaluationResult
from lightspeed_evaluation.core.models import EvaluationData, EvaluationResult, TurnData

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -76,6 +76,122 @@ def mark_all_metrics_as_error(
self.results.extend(error_results)
return error_results

def mark_turn_metrics_as_error(  # pylint: disable=too-many-arguments,too-many-positional-arguments
    self,
    conv_data: EvaluationData,
    turn_idx: int,
    turn_data: TurnData,
    turn_metrics: list[str],
    error_reason: str,
) -> list[EvaluationResult]:
    """Produce ERROR results for every metric of one turn.

    Args:
        conv_data: Conversation data the turn belongs to.
        turn_idx: Index of the turn within the conversation.
        turn_data: The turn whose metrics are being marked.
        turn_metrics: Metric identifiers resolved for this turn.
        error_reason: Human-readable reason recorded on each result.

    Returns:
        list[EvaluationResult]: One ERROR result per metric in ``turn_metrics``.
    """
    logger.warning(
        "Marking turn %d metrics as ERROR for conversation %s: %s",
        turn_idx,
        conv_data.conversation_group_id,
        error_reason,
    )

    # One ERROR result per turn-level metric; no score/threshold since nothing ran.
    error_results = [
        EvaluationResult(
            conversation_group_id=conv_data.conversation_group_id,
            turn_id=turn_data.turn_id,
            metric_identifier=metric_identifier,
            result="ERROR",
            score=None,
            threshold=None,
            reason=error_reason,
            query=turn_data.query,
            response="",
            execution_time=0.0,
        )
        for metric_identifier in turn_metrics
    ]

    # Record internally so error-summary tracking sees these results too.
    self.results.extend(error_results)
    return error_results

def mark_cascade_failure(  # pylint: disable=too-many-arguments,too-many-positional-arguments
    self,
    conv_data: EvaluationData,
    failed_turn_idx: int,
    resolved_turn_metrics: list[list[str]],
    resolved_conversation_metrics: list[str],
    error_reason: str,
) -> list[EvaluationResult]:
    """Cascade a turn failure: mark all later turns and conversation metrics as ERROR.

    Args:
        conv_data: Conversation data being processed.
        failed_turn_idx: Index of the turn that failed; turns after it are cascaded.
        resolved_turn_metrics: Per-turn resolved metric identifiers (parallel to turns).
        resolved_conversation_metrics: Conversation-level metric identifiers.
        error_reason: Human-readable reason recorded on each result.

    Returns:
        list[EvaluationResult]: ERROR results for every metric of every turn after
        ``failed_turn_idx`` plus every conversation-level metric.
    """
    logger.warning(
        "Marking remaining turns (%d onwards) and conversation metrics as ERROR for %s: %s",
        failed_turn_idx + 1,
        conv_data.conversation_group_id,
        error_reason,
    )

    group_id = conv_data.conversation_group_id
    error_results: list[EvaluationResult] = []

    # Every metric of every turn after the failed one becomes an ERROR result.
    for turn_idx in range(failed_turn_idx + 1, len(conv_data.turns)):
        turn = conv_data.turns[turn_idx]
        error_results.extend(
            EvaluationResult(
                conversation_group_id=group_id,
                turn_id=turn.turn_id,
                metric_identifier=metric_identifier,
                result="ERROR",
                score=None,
                threshold=None,
                reason=error_reason,
                query=turn.query,
                response="",
                execution_time=0.0,
            )
            for metric_identifier in resolved_turn_metrics[turn_idx]
        )

    # Conversation-level metrics: turn_id=None marks conversation scope.
    error_results.extend(
        EvaluationResult(
            conversation_group_id=group_id,
            turn_id=None,
            metric_identifier=metric_identifier,
            result="ERROR",
            score=None,
            threshold=None,
            reason=error_reason,
            query="",
            response="",
            execution_time=0.0,
        )
        for metric_identifier in resolved_conversation_metrics
    )

    # Record internally so error-summary tracking sees these results too.
    self.results.extend(error_results)
    return error_results

def get_error_summary(self) -> dict[str, int]:
"""Get summary of error results collected."""
return {
Expand Down
81 changes: 60 additions & 21 deletions src/lightspeed_evaluation/pipeline/evaluation/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def __init__(self, config_loader: ConfigLoader, components: ProcessorComponents)
self.config = config_loader.system_config
self.components = components

def process_conversation(self, conv_data: EvaluationData) -> list[EvaluationResult]:
def process_conversation( # pylint: disable=too-many-locals
self, conv_data: EvaluationData
) -> list[EvaluationResult]:
"""Process single conversation - handle turn and conversation level metrics.

Returns:
Expand Down Expand Up @@ -88,31 +90,68 @@ def process_conversation(self, conv_data: EvaluationData) -> list[EvaluationResu
return error_results

try:
# Step 2: Amend with API data if enabled
if self.config is None:
raise ValueError("SystemConfig must be loaded")
api_error_message = None
if self.config.api.enabled:
logger.debug("Amending data via API")
api_error_message = self.components.api_amender.amend_conversation_data(
conv_data
)

# If API error occurred, mark all metrics as ERROR and skip evaluation
if api_error_message:
logger.error("API error detected - marking all metrics as ERROR")
error_results = self.components.error_handler.mark_all_metrics_as_error(
conv_data,
api_error_message,
resolved_turn_metrics=resolved_turn_metrics,
resolved_conversation_metrics=resolved_conversation_metrics,
)
return error_results
# Step 2: Process each turn individually (API call + evaluation)
conversation_id: Optional[str] = None

# Step 3: Process turn-level metrics for each turn
for turn_idx, (turn_data, turn_metrics) in enumerate(
zip(conv_data.turns, resolved_turn_metrics)
):
# Step 2a: Amend with API data if enabled (per turn)
if self.config.api.enabled:
logger.debug("Processing turn %d: %s", turn_idx, turn_data.turn_id)
api_error_message, conversation_id = (
self.components.api_amender.amend_single_turn(
turn_data, conversation_id
)
)
logger.debug(
"✅ API Call completed for turn %d: %s",
turn_idx,
turn_data.turn_id,
)

# If API error occurred, mark current turn + remaining + conversation as ERROR
if api_error_message:
logger.error(
"API error for turn %d - marking current turn, "
"remaining turns, and conversation as ERROR",
turn_idx,
)
# Mark current turn as ERROR
current_turn_errors = (
self.components.error_handler.mark_turn_metrics_as_error(
conv_data,
turn_idx,
turn_data,
turn_metrics,
api_error_message,
)
)
results.extend(current_turn_errors)

# Mark remaining turns and conversation metrics as ERROR
cascade_error_reason = (
f"Cascade failure from turn {turn_idx + 1} API error: "
f"{api_error_message}"
)
remaining_errors = (
self.components.error_handler.mark_cascade_failure(
conv_data,
turn_idx,
resolved_turn_metrics,
resolved_conversation_metrics,
cascade_error_reason,
)
)
results.extend(remaining_errors)

# Stop processing - API failure cascades to all remaining
return results

# Step 2b: Process turn-level metrics for this turn
if turn_metrics:
logger.debug(
"Processing turn %d metrics: %s", turn_idx, turn_metrics
Expand All @@ -122,7 +161,7 @@ def process_conversation(self, conv_data: EvaluationData) -> list[EvaluationResu
)
results.extend(turn_results)

# Step 4: Process conversation-level metrics
# Step 3: Process conversation-level metrics
if resolved_conversation_metrics:
logger.debug(
"Processing conversation-level metrics: %s",
Expand All @@ -136,7 +175,7 @@ def process_conversation(self, conv_data: EvaluationData) -> list[EvaluationResu
return results

finally:
# Step 5: Always run cleanup script (if provided) regardless of results
# Step 4: Always run cleanup script (if provided) regardless of results
self._run_cleanup_script(conv_data)

def _evaluate_turn(
Expand Down
Loading
Loading