diff --git a/bench/hallucination/README.md b/bench/hallucination/README.md new file mode 100644 index 0000000000..bffd5427e4 --- /dev/null +++ b/bench/hallucination/README.md @@ -0,0 +1,137 @@ +# Hallucination Detection Benchmark + +E2E evaluation of hallucination detection through the semantic router. + +## Quick Start + +```bash +# 1. Start vLLM (if not running) +docker run -d --gpus all -p 8083:8000 vllm/vllm-openai:latest \ + vllm serve --model Qwen/Qwen2.5-14B-Instruct-AWQ + +# 2. Start semantic router with hallucination config +cd /path/to/semantic-router +export LD_LIBRARY_PATH=$PWD/candle-binding/target/release +./bin/router -config=bench/hallucination/config-7b.yaml + +# 3. Start Envoy +make run-envoy + +# 4. Run benchmark +python3 -m bench.hallucination.evaluate \ + --endpoint http://localhost:8801 \ + --dataset halueval \ + --max-samples 50 + +# Or use the Makefile target: +make bench-hallucination MAX_SAMPLES=50 +``` + +## Using the Large Model + +The large model (`lettucedect-large-modernbert-en-v1`, 395M params) provides better detection accuracy than the base model. 
+ +### Step 1: Download the Large Model + +```bash +cd /path/to/semantic-router + +# Download from HuggingFace +hf download KRLabsOrg/lettucedect-large-modernbert-en-v1 \ + --local-dir models/lettucedect-large-modernbert-en-v1 +``` + +### Step 2: Update Config + +Edit `bench/hallucination/config-7b.yaml`: + +```yaml +hallucination_mitigation: + enabled: true + + hallucination_model: + model_id: "models/lettucedect-large-modernbert-en-v1" # Use large model + threshold: 0.5 + use_cpu: true # Set to false for GPU +``` + +### Step 3: Restart Router + +```bash +# Kill existing router +pkill -f "router.*config" + +# Start with updated config +export LD_LIBRARY_PATH=$PWD/candle-binding/target/release +./bin/router -config=bench/hallucination/config-7b.yaml +``` + +## Supported Models + +| Model | Params | HuggingFace ID | +|-------|--------|----------------| +| Base | 149M | `KRLabsOrg/lettucedect-base-modernbert-en-v1` | +| Large | 395M | `KRLabsOrg/lettucedect-large-modernbert-en-v1` | + +Both use `ModernBertForTokenClassification` architecture supported by candle-binding. 
+ +## Config Reference + +Key settings in `config-7b.yaml`: + +```yaml +# vLLM endpoint +vllm_endpoints: + - name: "vllm-general" + address: "127.0.0.1" + port: 8083 + +# Hallucination detection +hallucination_mitigation: + enabled: true + hallucination_model: + model_id: "models/lettucedect-large-modernbert-en-v1" + threshold: 0.5 + use_cpu: true + on_hallucination_detected: "warn" # or "block" +``` + +## Datasets + +| Dataset | Command | +|---------|---------| +| HaluEval | `--dataset halueval` | +| Custom | `--dataset /path/to/data.jsonl` | + +## Output + +Results saved to `bench/hallucination/results/` with: + +- Precision, Recall, F1 (when ground truth available) +- Latency metrics (avg, p50, p99) +- Per-sample detection results + +### Two-Stage Pipeline Efficiency Metrics + +The benchmark tracks the computational savings from the two-stage detection pipeline: + +``` +⚡ Two-Stage Pipeline Efficiency: +---------------------------------------- + Fact-check needed: 65/100 queries + Detection skipped: 35/100 queries + Avg context length: 4500 chars + Estimated detect time: 6500.00 ms (if all ran) + Actual detect time: 4225.00 ms + Time saved: 2275.00 ms + Efficiency gain: 35.0% + + 💡 Pre-filtering skipped 35.0% of requests, + saving 2275ms of detection compute. +``` + +This demonstrates the value of the HaluGate Sentinel pre-classifier: + +- **O(1) filtering** before **O(n) detection** (n = context length) +- Non-factual queries (creative, opinion, brainstorming) skip expensive token classification +- Critical for RAG applications with large contexts (8K+ tokens) diff --git a/bench/hallucination/__init__.py b/bench/hallucination/__init__.py new file mode 100644 index 0000000000..b8e3d07a13 --- /dev/null +++ b/bench/hallucination/__init__.py @@ -0,0 +1,10 @@ +"""Hallucination Detection Benchmark for Semantic Router. + +This package provides end-to-end evaluation of the hallucination detection pipeline +through the router + Envoy stack.
+""" + +from .evaluate import HallucinationBenchmark +from .datasets import HaluEvalDataset, CustomDataset, get_dataset + +__all__ = ["HallucinationBenchmark", "HaluEvalDataset", "CustomDataset", "get_dataset"] diff --git a/bench/hallucination/config-7b.yaml b/bench/hallucination/config-7b.yaml new file mode 100644 index 0000000000..88f46031cf --- /dev/null +++ b/bench/hallucination/config-7b.yaml @@ -0,0 +1,182 @@ +# Configuration for hallucination detection benchmark +# Connects to real vLLM server at port 8083 + +bert_model: + model_id: models/all-MiniLM-L12-v2 + threshold: 0.6 + use_cpu: true + +semantic_cache: + enabled: false # Disable cache for benchmarking + +# Classifier configuration +classifier: + category_model: + model_id: "models/category_classifier_modernbert-base_model" + use_modernbert: true + threshold: 0.6 + use_cpu: true + category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json" + pii_model: + model_id: "models/pii_classifier_modernbert-base_presidio_token_model" + use_modernbert: true + threshold: 0.7 + use_cpu: true + pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json" + +# Hallucination mitigation configuration +hallucination_mitigation: + enabled: true + + # Fact-check classifier: determines if a prompt needs fact verification + fact_check_model: + model_id: "models/halugate-sentinel" + threshold: 0.6 + use_cpu: true + mapping_path: "models/halugate-sentinel/fact_check_mapping.json" + + # Hallucination detector: verifies if LLM response is grounded in context + # Using large model (395M params) for better accuracy + hallucination_model: + model_id: "models/lettucedect-large-modernbert-en-v1" + threshold: 0.5 + use_cpu: true + + # NLI model: provides explanations for hallucinated spans + nli_model: + model_id: "models/ModernBERT-base-nli" + threshold: 0.7 + use_cpu: true + + # Action when hallucination detected: "warn" adds headers, "block" returns error + 
on_hallucination_detected: "warn" + +# Fact-check rules for signal classification +# The classifier outputs one of these signals that can be referenced in decision conditions +fact_check_rules: + - name: needs_fact_check + description: "Query contains factual claims that should be verified against context" + - name: no_fact_check_needed + description: "Query is creative, code-related, or opinion-based - no fact verification needed" + +# Prompt guard +prompt_guard: + enabled: true + use_modernbert: true + model_id: "models/jailbreak_classifier_modernbert-base_model" + threshold: 0.7 + use_cpu: true + jailbreak_mapping_path: "models/jailbreak_classifier_modernbert-base_model/jailbreak_type_mapping.json" + +# vLLM endpoint - real vLLM server +vllm_endpoints: + - name: "vllm-general" + address: "127.0.0.1" + port: 8083 + weight: 1 + health_check_path: "/health" + +# Model configuration - use the actual model from vLLM +model_config: + "Qwen/Qwen2.5-14B-Instruct-AWQ": + reasoning_family: "qwen3" + preferred_endpoints: ["vllm-general"] + +# Categories for routing +categories: + - name: general + description: "General questions" + mmlu_categories: ["other"] + - name: math + description: "Mathematics and quantitative reasoning" + mmlu_categories: ["math"] + - name: science + description: "Science questions" + mmlu_categories: ["physics", "chemistry", "biology"] + +strategy: "priority" + +decisions: + - name: "math_decision" + description: "Mathematics and quantitative reasoning" + priority: 100 + rules: + operator: "AND" + conditions: + - type: "domain" + name: "math" + modelRefs: + - model: "Qwen/Qwen2.5-14B-Instruct-AWQ" + use_reasoning: true + plugins: + - type: "pii" + configuration: + enabled: true + pii_types_allowed: [] + - type: "hallucination" + configuration: + enabled: true + use_nli: true + hallucination_action: "header" + unverified_factual_action: "header" + include_hallucination_details: false + + - name: "science_decision" + description: "Science questions" 
+ priority: 100 + rules: + operator: "AND" + conditions: + - type: "domain" + name: "science" + modelRefs: + - model: "Qwen/Qwen2.5-14B-Instruct-AWQ" + use_reasoning: true + plugins: + - type: "pii" + configuration: + enabled: true + pii_types_allowed: [] + - type: "hallucination" + configuration: + enabled: true + use_nli: true + hallucination_action: "header" + unverified_factual_action: "header" + include_hallucination_details: false + + - name: "general_decision" + description: "General questions" + priority: 50 + rules: + operator: "AND" + conditions: + - type: "domain" + name: "general" + modelRefs: + - model: "Qwen/Qwen2.5-14B-Instruct-AWQ" + use_reasoning: false + plugins: + - type: "pii" + configuration: + enabled: true + pii_types_allowed: [] + - type: "hallucination" + configuration: + enabled: true + use_nli: true + hallucination_action: "header" + unverified_factual_action: "header" + include_hallucination_details: false + +default_model: "Qwen/Qwen2.5-14B-Instruct-AWQ" + +# API Configuration +api: + batch_classification: + metrics: + enabled: true + detailed_goroutine_tracking: true + high_resolution_timing: false + sample_rate: 1.0 + diff --git a/bench/hallucination/config.yaml b/bench/hallucination/config.yaml new file mode 100644 index 0000000000..29f8edfa61 --- /dev/null +++ b/bench/hallucination/config.yaml @@ -0,0 +1,182 @@ +# Configuration for hallucination detection benchmark +# Connects to real vLLM server at port 8083 + +bert_model: + model_id: models/all-MiniLM-L12-v2 + threshold: 0.6 + use_cpu: true + +semantic_cache: + enabled: false # Disable cache for benchmarking + +# Classifier configuration +classifier: + category_model: + model_id: "models/category_classifier_modernbert-base_model" + use_modernbert: true + threshold: 0.6 + use_cpu: true + category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json" + pii_model: + model_id: "models/pii_classifier_modernbert-base_presidio_token_model" + 
use_modernbert: true + threshold: 0.7 + use_cpu: true + pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json" + +# Hallucination mitigation configuration +hallucination_mitigation: + enabled: true + + # Fact-check classifier: determines if a prompt needs fact verification + fact_check_model: + model_id: "models/halugate-sentinel" + threshold: 0.6 + use_cpu: true + mapping_path: "models/halugate-sentinel/fact_check_mapping.json" + + # Hallucination detector: verifies if LLM response is grounded in context + # Using large model (395M params) for better accuracy + hallucination_model: + model_id: "models/lettucedect-large-modernbert-en-v1" + threshold: 0.5 + use_cpu: true + + # NLI model: provides explanations for hallucinated spans + nli_model: + model_id: "models/ModernBERT-base-nli" + threshold: 0.7 + use_cpu: true + + # Action when hallucination detected: "warn" adds headers, "block" returns error + on_hallucination_detected: "warn" + +# Fact-check rules for signal classification +# The classifier outputs one of these signals that can be referenced in decision conditions +fact_check_rules: + - name: needs_fact_check + description: "Query contains factual claims that should be verified against context" + - name: no_fact_check_needed + description: "Query is creative, code-related, or opinion-based - no fact verification needed" + +# Prompt guard +prompt_guard: + enabled: true + use_modernbert: true + model_id: "models/jailbreak_classifier_modernbert-base_model" + threshold: 0.7 + use_cpu: true + jailbreak_mapping_path: "models/jailbreak_classifier_modernbert-base_model/jailbreak_type_mapping.json" + +# vLLM endpoint - real vLLM server +vllm_endpoints: + - name: "vllm-qwen" + address: "127.0.0.1" + port: 8083 + weight: 1 + health_check_path: "/health" + +# Model configuration - use the actual model from vLLM +model_config: + "Qwen/Qwen2.5-14B-Instruct-AWQ": + reasoning_family: "qwen3" + preferred_endpoints: 
["vllm-qwen"] + +# Categories for routing +categories: + - name: general + description: "General questions" + mmlu_categories: ["other"] + - name: math + description: "Mathematics and quantitative reasoning" + mmlu_categories: ["math"] + - name: science + description: "Science questions" + mmlu_categories: ["physics", "chemistry", "biology"] + +strategy: "priority" + +decisions: + - name: "math_decision" + description: "Mathematics and quantitative reasoning" + priority: 100 + rules: + operator: "AND" + conditions: + - type: "domain" + name: "math" + modelRefs: + - model: "Qwen/Qwen2.5-14B-Instruct-AWQ" + use_reasoning: true + plugins: + - type: "pii" + configuration: + enabled: true + pii_types_allowed: [] + - type: "hallucination" + configuration: + enabled: true + use_nli: true + hallucination_action: "header" + unverified_factual_action: "header" + include_hallucination_details: false + + - name: "science_decision" + description: "Science questions" + priority: 100 + rules: + operator: "AND" + conditions: + - type: "domain" + name: "science" + modelRefs: + - model: "Qwen/Qwen2.5-14B-Instruct-AWQ" + use_reasoning: true + plugins: + - type: "pii" + configuration: + enabled: true + pii_types_allowed: [] + - type: "hallucination" + configuration: + enabled: true + use_nli: true + hallucination_action: "header" + unverified_factual_action: "header" + include_hallucination_details: false + + - name: "general_decision" + description: "General questions" + priority: 50 + rules: + operator: "AND" + conditions: + - type: "domain" + name: "general" + modelRefs: + - model: "Qwen/Qwen2.5-14B-Instruct-AWQ" + use_reasoning: false + plugins: + - type: "pii" + configuration: + enabled: true + pii_types_allowed: [] + - type: "hallucination" + configuration: + enabled: true + use_nli: true + hallucination_action: "header" + unverified_factual_action: "header" + include_hallucination_details: false + +default_model: "Qwen/Qwen2.5-14B-Instruct-AWQ" + +# API Configuration +api: + 
@dataclass
class HallucinationSample:
    """A single sample for hallucination detection evaluation."""

    id: str
    context: str  # Ground truth context (tool results / RAG context)
    question: str  # User question
    gold_answer: str  # Correct answer from context
    llm_response: Optional[str] = None  # Generated LLM response
    hallucination_spans: Optional[List[dict]] = None  # Annotated hallucination spans
    is_faithful: Optional[bool] = None  # Whether response is faithful to context
    metadata: Optional[dict] = None  # Additional metadata


class DatasetInterface(ABC):
    """Abstract interface implemented by every hallucination dataset loader."""

    @abstractmethod
    def load(self, max_samples: Optional[int] = None) -> List[HallucinationSample]:
        """Load the dataset, returning at most ``max_samples`` samples when set."""

    @abstractmethod
    def name(self) -> str:
        """Return the dataset name."""


class HaluEvalDataset(DatasetInterface):
    """HaluEval dataset loader - general hallucination evaluation."""

    def name(self) -> str:
        return "halueval"

    def load(self, max_samples: Optional[int] = None) -> List[HallucinationSample]:
        """Load the HaluEval ``qa_samples`` configuration from HuggingFace.

        Relies on the module-level ``HAS_DATASETS`` flag and ``load_dataset``
        name set up by the guarded ``datasets`` import at the top of the file.

        Args:
            max_samples: Stop after this many rows. ``None`` (or 0) loads all.

        Raises:
            ImportError: If the ``datasets`` package is not installed.
            Exception: Whatever ``load_dataset`` raised, re-raised after logging.
        """
        if not HAS_DATASETS:
            raise ImportError(
                "datasets package not installed. Run: pip install datasets"
            )

        print("Loading HaluEval dataset...")

        try:
            dataset = load_dataset("pminervini/HaluEval", "qa_samples", split="data")
        except Exception as e:
            # Log for the CLI user, then let the caller see the real error.
            print(f"Error loading HaluEval: {e}")
            raise

        samples = []
        for i, item in enumerate(dataset):
            if max_samples and i >= max_samples:
                break

            samples.append(
                HallucinationSample(
                    id=f"halueval_{i}",
                    context=item.get("knowledge", ""),
                    question=item.get("question", ""),
                    gold_answer=item.get("right_answer", ""),
                    llm_response=item.get(
                        "hallucinated_answer", item.get("answer", "")
                    ),
                    # A missing "hallucination" field is treated as faithful.
                    is_faithful=item.get("hallucination", "no") == "no",
                    metadata={
                        "dataset": "halueval",
                    },
                )
            )

        print(f"Loaded {len(samples)} samples from HaluEval")
        return samples


class CustomDataset(DatasetInterface):
    """Load a custom dataset from a JSONL file (one JSON object per line)."""

    def __init__(self, file_path: str):
        self.file_path = Path(file_path)
        # The dataset is named after the file, without its extension.
        self._name = self.file_path.stem

    def name(self) -> str:
        return self._name

    def load(self, max_samples: Optional[int] = None) -> List[HallucinationSample]:
        """Load samples from the JSONL file.

        Blank lines (for example a trailing newline) are skipped instead of
        being passed to ``json.loads``. Each record may use either the primary
        or the fallback key for a field: ``context``/``knowledge``,
        ``question``/``query``, ``gold_answer``/``answer``,
        ``llm_response``/``response``.

        Args:
            max_samples: Stop after this many samples. ``None`` (or 0) loads all.

        Raises:
            FileNotFoundError: If the dataset file does not exist.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        if not self.file_path.exists():
            raise FileNotFoundError(f"Dataset file not found: {self.file_path}")

        samples: List[HallucinationSample] = []
        # Read as UTF-8 explicitly so results do not depend on the platform locale.
        with open(self.file_path, encoding="utf-8") as f:
            for raw_line in f:
                if max_samples and len(samples) >= max_samples:
                    break

                line = raw_line.strip()
                if not line:
                    continue

                item = json.loads(line)
                samples.append(
                    HallucinationSample(
                        id=item.get("id", f"custom_{len(samples)}"),
                        context=item.get("context", item.get("knowledge", "")),
                        question=item.get("question", item.get("query", "")),
                        gold_answer=item.get("gold_answer", item.get("answer", "")),
                        llm_response=item.get("llm_response", item.get("response", "")),
                        hallucination_spans=item.get("hallucination_spans", []),
                        is_faithful=item.get("is_faithful"),
                        metadata=item.get("metadata", {}),
                    )
                )

        print(f"Loaded {len(samples)} samples from {self.file_path}")
        return samples


def get_dataset(name: str, **kwargs) -> DatasetInterface:
    """Return a dataset loader for a built-in name or a JSONL file path.

    Built-in names take precedence over a file of the same name. ``kwargs`` is
    accepted for call-site compatibility and is currently unused.

    Raises:
        ValueError: If ``name`` is neither a built-in dataset nor an existing path.
    """
    # Named ``registry`` so it does not shadow the HuggingFace ``datasets`` package.
    registry = {
        "halueval": HaluEvalDataset,
    }

    loader_cls = registry.get(name)
    if loader_cls is not None:
        return loader_cls()
    if Path(name).exists():
        return CustomDataset(name)
    raise ValueError(f"Unknown dataset: {name}. Available: {list(registry.keys())}")
@dataclass
class BenchmarkResults:
    """Aggregated benchmark results for one run.

    ``results`` holds the per-sample records as plain dicts; passing ``None``
    (the default) yields an empty list.
    """

    dataset_name: str
    endpoint: str
    model: str
    total_samples: int
    successful_requests: int
    failed_requests: int

    # Detection metrics
    total_hallucinations_detected: int

    # Accuracy (when ground truth available)
    true_positives: int  # Correctly detected hallucinations
    false_positives: int  # Incorrectly flagged as hallucination
    true_negatives: int  # Correctly identified faithful responses
    false_negatives: int  # Missed hallucinations

    precision: float
    recall: float
    f1_score: float
    accuracy: float

    # Latency
    avg_latency_ms: float
    p50_latency_ms: float
    p99_latency_ms: float

    # Individual results; None is normalized to [] in __post_init__
    results: Optional[List[dict]] = None

    # Efficiency metrics (two-stage pipeline savings)
    fact_check_needed_count: int = 0  # Queries that needed fact-checking
    detection_skipped_count: int = 0  # Queries where detection was skipped
    avg_context_length: float = 0.0  # Average context size in chars
    estimated_detection_time_ms: float = 0.0  # Estimated time if all ran detection
    actual_detection_time_ms: float = 0.0  # Actual detection time (skipped = 0)
    time_saved_ms: float = 0.0  # Time saved by pre-filtering
    efficiency_gain_percent: float = 0.0  # Percentage improvement

    def __post_init__(self):
        if self.results is None:
            self.results = []


class HallucinationBenchmark:
    """End-to-end hallucination detection benchmark.

    Sends chat-completion requests to a running router endpoint, reads the
    ``x-vsr-*`` response headers, and aggregates them into ``BenchmarkResults``.
    """

    def __init__(
        self,
        endpoint: str = "http://localhost:8801",
        timeout: int = 120,
    ):
        """Store the endpoint (trailing slashes removed) and per-request timeout in seconds."""
        self.endpoint = endpoint.rstrip("/")
        self.timeout = timeout
        self.results: List[DetectionResult] = []
        self.available_models: List[str] = []

    def discover_models(self) -> List[str]:
        """Discover available models from the /v1/models endpoint.

        Returns:
            The model ids (also stored on ``self.available_models``), or an
            empty list when the request fails, returns a non-200 status, or
            the body is not valid JSON.
        """
        try:
            resp = requests.get(f"{self.endpoint}/v1/models", timeout=10)
            if resp.status_code != 200:
                print(f"❌ Failed to get models: HTTP {resp.status_code}")
                return []
            data = resp.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers a non-JSON body on requests versions whose
            # JSON error is not a RequestException subclass.
            print(f"❌ Failed to discover models: {e}")
            return []

        models = [m.get("id", m.get("name", "unknown")) for m in data.get("data", [])]
        self.available_models = models

        print(f"✓ Discovered {len(models)} models:")
        for m in models:
            print(f" - {m}")

        return models

    def check_health(self) -> bool:
        """Return True when GET /v1/models answers with HTTP 200."""
        try:
            # /v1/models doubles as the health check
            resp = requests.get(f"{self.endpoint}/v1/models", timeout=10)
        except requests.exceptions.RequestException as e:
            print(f"❌ Health check failed: {e}")
            return False

        if resp.status_code == 200:
            print(f"✓ Endpoint {self.endpoint} is healthy")
            return True

        print(f"❌ Endpoint returned HTTP {resp.status_code}")
        return False

    def send_request(
        self,
        sample: "HallucinationSample",
        model: str,
        include_context: bool = True,
    ) -> "DetectionResult":
        """Send one chat-completion request and collect the detection headers.

        Args:
            sample: The sample whose question (and optionally context) is sent.
            model: Model name placed in the request body.
            include_context: When True and the sample has context, it is
                appended as a ``tool`` message.

        Returns:
            A ``DetectionResult``. Failures never raise; they are recorded in
            ``result.error``. A missing header is read as "false".
        """
        # perf_counter is monotonic, so latency is immune to wall-clock changes.
        start_time = time.perf_counter()

        result = DetectionResult(
            sample_id=sample.id,
            question=sample.question[:100],
            context_length=len(sample.context),
            response_length=0,
            model_used=model,
            gt_is_faithful=sample.is_faithful,
        )

        try:
            # Build messages - include tool context if available
            messages = [{"role": "user", "content": sample.question}]

            if include_context and sample.context:
                # Add tool result with context (enables hallucination detection)
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": f"ctx_{sample.id}",
                        "content": sample.context,
                    }
                )

            response = requests.post(
                f"{self.endpoint}/v1/chat/completions",
                json={
                    "model": model,
                    "messages": messages,
                    "max_tokens": 512,
                    "temperature": 0.7,
                },
                timeout=self.timeout,
            )

            result.latency_ms = (time.perf_counter() - start_time) * 1000

            if response.status_code != 200:
                result.error = f"HTTP {response.status_code}: {response.text[:200]}"
                return result

            data = response.json()
            if "choices" in data and data["choices"]:
                content = data["choices"][0].get("message", {}).get("content", "")
                result.response_length = len(content)

            # Extract hallucination detection headers
            headers = response.headers
            result.hallucination_detected = (
                headers.get("x-vsr-hallucination-detected", "").lower() == "true"
            )
            result.hallucination_spans = headers.get("x-vsr-hallucination-spans", "")

            # Extract efficiency-related headers
            result.fact_check_needed = (
                headers.get("x-vsr-fact-check-needed", "").lower() == "true"
            )
            # Detection is skipped if fact-check says not needed, or if unverified (no context)
            unverified = (
                headers.get("x-vsr-unverified-factual-response", "").lower() == "true"
            )
            result.detection_skipped = not result.fact_check_needed or unverified

        except requests.exceptions.Timeout:
            result.error = "Request timeout"
        except requests.exceptions.RequestException as e:
            result.error = str(e)
        except Exception as e:
            # Best-effort: one bad sample must not abort the whole benchmark.
            result.error = f"Unexpected error: {e}"

        return result

    def run_benchmark(
        self,
        samples: List["HallucinationSample"],
        model: str,
        include_context: bool = True,
        verbose: bool = True,
    ) -> BenchmarkResults:
        """Run the benchmark over ``samples`` and return aggregated metrics.

        Args:
            samples: Samples to send, one request each, in order.
            model: Model name used for every request.
            include_context: Forwarded to ``send_request``.
            verbose: When True, detections and errors are printed as they occur.
                The progress-bar counter is updated either way.
        """
        from tqdm import tqdm

        print(f"\nRunning hallucination detection benchmark...")
        print(f" Endpoint: {self.endpoint}")
        print(f" Model: {model}")
        print(f" Samples: {len(samples)}")
        print(f" Context: {'included' if include_context else 'excluded'}")
        print()

        self.results = []
        detected_count = 0

        pbar = tqdm(
            samples,
            desc="Evaluating",
            unit="sample",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}] Hallucinations: {postfix}",
        )
        pbar.set_postfix_str("0")

        for sample in pbar:
            result = self.send_request(sample, model, include_context=include_context)
            self.results.append(result)

            if result.hallucination_detected:
                # Count every detection so the bar stays accurate in quiet mode.
                detected_count += 1
                pbar.set_postfix_str(str(detected_count))

                if verbose:
                    # tqdm.write prints below the bar without corrupting it
                    tqdm.write("")
                    tqdm.write(f"🚨 HALLUCINATION DETECTED [{detected_count}]")
                    tqdm.write(f" Question: {result.question[:80]}...")
                    if result.hallucination_spans:
                        spans_preview = result.hallucination_spans[:150]
                        tqdm.write(
                            f" Spans: {spans_preview}{'...' if len(result.hallucination_spans) > 150 else ''}"
                        )
                    if result.gt_is_faithful is not None:
                        gt_label = (
                            "✅ Correct"
                            if not result.gt_is_faithful
                            else "❌ False Positive"
                        )
                        tqdm.write(f" Ground Truth: {gt_label}")
                    tqdm.write(f" Latency: {result.latency_ms:.0f}ms")
                    tqdm.write("")

            elif result.error and verbose:
                tqdm.write(f"❌ ERROR: {result.error[:100]}")

        # metadata is Optional on the sample, so guard against None.
        first_metadata = (samples[0].metadata or {}) if samples else {}
        dataset_name = first_metadata.get("dataset", "unknown")
        return self.calculate_metrics(dataset_name, model)

    def calculate_metrics(self, dataset_name: str, model: str) -> BenchmarkResults:
        """Aggregate ``self.results`` into a ``BenchmarkResults``.

        Only results without an error count towards detection, accuracy,
        latency and efficiency figures. Accuracy uses only results that carry
        a ground-truth label. Every ratio falls back to 0.0 when its
        denominator is zero.
        """
        successful = [r for r in self.results if r.error is None]
        failed_count = len(self.results) - len(successful)

        hallucinations_detected = sum(1 for r in successful if r.hallucination_detected)

        # Confusion matrix over labelled samples; "positive" means hallucination.
        tp = fp = tn = fn = 0
        for r in successful:
            if r.gt_is_faithful is None:
                continue
            detected = bool(r.hallucination_detected)
            is_hallucination = not r.gt_is_faithful
            if detected and is_hallucination:
                tp += 1
            elif detected:
                fp += 1
            elif is_hallucination:
                fn += 1
            else:
                tn += 1

        total_with_gt = tp + fp + tn + fn
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (
            2 * precision * recall / (precision + recall)
            if (precision + recall) > 0
            else 0.0
        )
        accuracy = (tp + tn) / total_with_gt if total_with_gt > 0 else 0.0

        # Latency metrics
        latencies = sorted(r.latency_ms for r in successful if r.latency_ms is not None)
        if latencies:
            count = len(latencies)
            avg_latency = sum(latencies) / count
            p50 = latencies[count // 2]
            # Clamp so a small sample cannot index past the end.
            p99 = latencies[min(int(count * 0.99), count - 1)]
        else:
            avg_latency = p50 = p99 = 0.0

        # Efficiency metrics - savings from the two-stage pipeline
        fact_check_needed_count = sum(1 for r in successful if r.fact_check_needed)
        detection_skipped_count = sum(1 for r in successful if r.detection_skipped)

        context_lengths = [r.context_length for r in successful if r.context_length > 0]
        avg_context_length = (
            sum(context_lengths) / len(context_lengths) if context_lengths else 0.0
        )

        # Detection time is modelled from context length, not measured:
        #   LettuceDetect-large: ~45ms base + ~0.02ms per character.
        # The original comments describe these as empirical CPU measurements.
        # The fact-check classifier (~12ms per the original comments) is not
        # part of either total.
        DETECTION_BASE_MS = 45.0
        DETECTION_PER_CHAR_MS = 0.02

        estimated_total = 0.0  # cost if detection ran on every request
        actual_total = 0.0  # cost with skipped requests counted as 0
        for r in successful:
            est_time = DETECTION_BASE_MS + (r.context_length * DETECTION_PER_CHAR_MS)
            estimated_total += est_time
            if not r.detection_skipped:
                actual_total += est_time

        time_saved = estimated_total - actual_total
        efficiency_gain = (
            (time_saved / estimated_total * 100) if estimated_total > 0 else 0.0
        )

        return BenchmarkResults(
            dataset_name=dataset_name,
            endpoint=self.endpoint,
            model=model,
            total_samples=len(self.results),
            successful_requests=len(successful),
            failed_requests=failed_count,
            total_hallucinations_detected=hallucinations_detected,
            true_positives=tp,
            false_positives=fp,
            true_negatives=tn,
            false_negatives=fn,
            precision=precision,
            recall=recall,
            f1_score=f1,
            accuracy=accuracy,
            avg_latency_ms=avg_latency,
            p50_latency_ms=p50,
            p99_latency_ms=p99,
            # Efficiency metrics
            fact_check_needed_count=fact_check_needed_count,
            detection_skipped_count=detection_skipped_count,
            avg_context_length=avg_context_length,
            estimated_detection_time_ms=estimated_total,
            actual_detection_time_ms=actual_total,
            time_saved_ms=time_saved,
            efficiency_gain_percent=efficiency_gain,
            results=[asdict(r) for r in self.results],
        )
print(f" Successful: {results.successful_requests}") + print(f" Failed: {results.failed_requests}") + + print(f"\nšŸ“‹ Detection Results:") + print("-" * 40) + print(f" Hallucinations detected: {results.total_hallucinations_detected}") + + total_gt = ( + results.true_positives + + results.false_positives + + results.true_negatives + + results.false_negatives + ) + if total_gt > 0: + print(f"\nšŸŽÆ Accuracy Metrics (vs ground truth):") + print("-" * 40) + print(f" True Positives: {results.true_positives}") + print(f" False Positives: {results.false_positives}") + print(f" True Negatives: {results.true_negatives}") + print(f" False Negatives: {results.false_negatives}") + print(f" Precision: {results.precision:.4f}") + print(f" Recall: {results.recall:.4f}") + print(f" F1 Score: {results.f1_score:.4f}") + print(f" Accuracy: {results.accuracy:.4f}") + else: + print(f"\nāš ļø No ground truth labels available for accuracy metrics") + + print(f"\nā±ļø Latency:") + print("-" * 40) + print(f" Average: {results.avg_latency_ms:.2f} ms") + print(f" P50: {results.p50_latency_ms:.2f} ms") + print(f" P99: {results.p99_latency_ms:.2f} ms") + + # Efficiency metrics - two-stage pipeline savings + print(f"\n⚔ Two-Stage Pipeline Efficiency:") + print("-" * 40) + print( + f" Fact-check needed: {results.fact_check_needed_count}/{results.successful_requests} queries" + ) + print( + f" Detection skipped: {results.detection_skipped_count}/{results.successful_requests} queries" + ) + print(f" Avg context length: {results.avg_context_length:.0f} chars") + print( + f" Estimated detect time: {results.estimated_detection_time_ms:.2f} ms (if all ran)" + ) + print(f" Actual detect time: {results.actual_detection_time_ms:.2f} ms") + print(f" Time saved: {results.time_saved_ms:.2f} ms") + print(f" Efficiency gain: {results.efficiency_gain_percent:.1f}%") + + if results.detection_skipped_count > 0: + skip_rate = results.detection_skipped_count / results.successful_requests * 100 + print(f"\n 
def main():
    """Command-line entry point: run the benchmark and save a JSON summary.

    Parses the CLI options, verifies that the endpoint is reachable and
    serves at least one model, loads the requested dataset, runs the
    benchmark, prints the summary and writes it to a timestamped JSON file in
    ``--output-dir``.

    Exits with status 1 when the endpoint is unreachable, no models are
    available, or the dataset cannot be loaded; exits with status 0 after
    listing models when ``--list-models`` is given.
    """
    parser = argparse.ArgumentParser(description="Hallucination Detection Benchmark")

    # Endpoint options
    parser.add_argument(
        "--endpoint",
        type=str,
        default="http://localhost:8801",
        help="Router/Envoy endpoint URL",
    )

    # Dataset options
    parser.add_argument(
        "--dataset",
        type=str,
        default="halueval",
        help="Dataset to use (halueval, or path to JSONL)",
    )
    parser.add_argument(
        "--max-samples", type=int, default=50, help="Maximum samples to evaluate"
    )

    # Model options
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Model to use (auto-discovered if not specified)",
    )

    # Test modes
    parser.add_argument(
        "--no-context",
        action="store_true",
        help="Don't include tool context (disables hallucination detection)",
    )
    parser.add_argument(
        "--list-models", action="store_true", help="Just list available models and exit"
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Don't print hallucinations as they're detected",
    )

    # Output options
    parser.add_argument(
        "--output-dir",
        type=str,
        default="bench/hallucination/results",
        help="Output directory for results",
    )

    args = parser.parse_args()

    # Create benchmark
    benchmark = HallucinationBenchmark(endpoint=args.endpoint)

    # Check health and discover models
    if not benchmark.check_health():
        print(f"\n❌ Cannot connect to {args.endpoint}")
        print(" Make sure the semantic router and vLLM are running.")
        sys.exit(1)

    models = benchmark.discover_models()
    if not models:
        print("\n❌ No models available")
        sys.exit(1)

    if args.list_models:
        print("\nAvailable models:")
        for m in models:
            print(f" - {m}")
        sys.exit(0)

    # Select model: an explicit --model wins, otherwise the first discovered
    # one. An unknown name is only warned about, not rejected.
    model = args.model or models[0]
    if model not in models and model != "auto":
        print(f"\n⚠️ Model '{model}' not in available models, using anyway...")

    print(f"\n📦 Using model: {model}")

    # Load dataset. This is the CLI boundary, so any loader failure is
    # reported to the user and turned into a non-zero exit status.
    try:
        dataset = get_dataset(args.dataset)
        samples = dataset.load(max_samples=args.max_samples)
    except Exception as e:
        print(f"\n❌ Failed to load dataset: {e}")
        sys.exit(1)

    # Run benchmark
    results = benchmark.run_benchmark(
        samples,
        model=model,
        include_context=not args.no_context,
        verbose=not args.quiet,
    )

    # Print results
    print_results(results)

    # Save results
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    timestamp = time.strftime("%Y%m%d_%H%M%S")
    # --dataset may be a path to a JSONL file. Use only its stem so the file
    # name never contains path separators, which would make open() fail after
    # the whole benchmark has already run.
    dataset_label = Path(args.dataset).stem or "dataset"
    results_file = output_dir / f"results_{dataset_label}_{timestamp}.json"

    # Keep only the first 10 per-sample results so the summary file stays small.
    summary = asdict(results)
    summary["results"] = summary["results"][:10]
    with open(results_file, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print(f"\n📁 Results saved to {results_file}")


if __name__ == "__main__":
    main()
# Declared phony so these targets always run, even if a file with the same
# name exists in the working directory.
.PHONY: bench-hallucination bench-hallucination-full
# MAX_SAMPLES is expanded by the recipe's shell ($$ is make's escape for a
# literal $) and falls back to 50 when it is unset or empty.
bench-hallucination:
	@echo "Running hallucination detection benchmark..."
	@echo "⚠️ Prerequisites:"
	@echo " 1. vLLM server running (e.g., docker container on port 8083)"
	@echo " 2. Router: make run-router CONFIG_FILE=bench/hallucination/config-7b.yaml"
	@echo " 3. Envoy: make run-envoy"
	@echo ""
	python3 -m bench.hallucination.evaluate \
		--endpoint http://localhost:8801 \
		--dataset halueval \
		--max-samples $${MAX_SAMPLES:-50}

# Run hallucination benchmark with custom parameters
# Usage: make bench-hallucination-full MAX_SAMPLES=200 DATASET=halueval
# DATASET and MAX_SAMPLES are expanded by the recipe's shell and default to
# halueval and 200 when unset or empty.
bench-hallucination-full: ## Run full hallucination benchmark with more samples
bench-hallucination-full:
	@echo "Running full hallucination detection benchmark..."
	python3 -m bench.hallucination.evaluate \
		--endpoint http://localhost:8801 \
		--dataset $${DATASET:-halueval} \
		--max-samples $${MAX_SAMPLES:-200}

# Note: Use the manual workflow: make start-llm-katan in one terminal, then run tests in another