From e5a4135001f2b88fb024c2b6ead73ba84705b94b Mon Sep 17 00:00:00 2001 From: asaadbalum Date: Fri, 16 Jan 2026 10:28:57 +0200 Subject: [PATCH] feat(selection): implement advanced model selection methods Add pluggable model selection algorithms for intelligent routing: - Elo rating system with Bradley-Terry model for preference-based selection - RouterDC for query-to-model embedding matching - AutoMix for POMDP-based cost-quality optimization - Hybrid selector combining multiple methods with configurable weights - Static selector for backwards compatibility Integration: - OpenAIRouter initializes selection registry on startup - req_filter_classification uses configured selector instead of hardcoded first model - Prometheus metrics for selection tracking Signed-off-by: asaadbalum --- .../in-tree/model_selection_demo.yaml | 129 ++++ .../examples/selection/main.go | 270 ++++++++ src/semantic-router/pkg/config/config.go | 133 +++- .../pkg/extproc/req_filter_classification.go | 139 +++- src/semantic-router/pkg/extproc/router.go | 65 ++ src/semantic-router/pkg/selection/automix.go | 482 ++++++++++++++ src/semantic-router/pkg/selection/elo.go | 494 ++++++++++++++ src/semantic-router/pkg/selection/factory.go | 216 ++++++ src/semantic-router/pkg/selection/hybrid.go | 466 +++++++++++++ src/semantic-router/pkg/selection/metrics.go | 174 +++++ .../pkg/selection/router_dc.go | 364 ++++++++++ src/semantic-router/pkg/selection/selector.go | 194 ++++++ .../pkg/selection/selector_test.go | 626 ++++++++++++++++++ src/semantic-router/pkg/selection/static.go | 171 +++++ 14 files changed, 3910 insertions(+), 13 deletions(-) create mode 100644 config/intelligent-routing/in-tree/model_selection_demo.yaml create mode 100644 src/semantic-router/examples/selection/main.go create mode 100644 src/semantic-router/pkg/selection/automix.go create mode 100644 src/semantic-router/pkg/selection/elo.go create mode 100644 src/semantic-router/pkg/selection/factory.go create mode 100644 
src/semantic-router/pkg/selection/hybrid.go create mode 100644 src/semantic-router/pkg/selection/metrics.go create mode 100644 src/semantic-router/pkg/selection/router_dc.go create mode 100644 src/semantic-router/pkg/selection/selector.go create mode 100644 src/semantic-router/pkg/selection/selector_test.go create mode 100644 src/semantic-router/pkg/selection/static.go diff --git a/config/intelligent-routing/in-tree/model_selection_demo.yaml b/config/intelligent-routing/in-tree/model_selection_demo.yaml new file mode 100644 index 0000000000..3687c0a6ec --- /dev/null +++ b/config/intelligent-routing/in-tree/model_selection_demo.yaml @@ -0,0 +1,129 @@ +--- +# Demo: Advanced Model Selection Methods +# Algorithms: Elo, RouterDC, AutoMix, Hybrid +# +# Reference papers: +# - Elo: RouteLLM (arXiv:2406.18665) +# - RouterDC: arXiv:2409.19886 +# - AutoMix: arXiv:2310.12963 +# - Hybrid: arXiv:2404.14618 + +bert_model: + model_id: sentence-transformers/all-MiniLM-L12-v2 + threshold: 0.6 + use_cpu: true + +classifier: + category_model: + model_id: "models/mom-domain-classifier" + use_modernbert: true + threshold: 0.6 + use_cpu: true + category_mapping_path: "models/mom-domain-classifier/category_mapping.json" + +# Backend models with pricing info for cost-aware selection +backend_models: + model_config: + "llama3.2:3b": + pricing: + prompt_per_1m: 0.05 + completion_per_1m: 0.10 + "llama3.2:8b": + pricing: + prompt_per_1m: 0.15 + completion_per_1m: 0.30 + "phi4": + pricing: + prompt_per_1m: 0.10 + completion_per_1m: 0.20 + "gemma3:27b": + pricing: + prompt_per_1m: 0.50 + completion_per_1m: 1.00 + "mistral-small3.1": + pricing: + prompt_per_1m: 0.25 + completion_per_1m: 0.50 + +# Categories for domain classification +categories: + - name: tech + mmlu_categories: ["computer science", "engineering"] + - name: finance + mmlu_categories: ["economics"] + - name: general + +# Decisions with PER-DECISION algorithm (aligned with looper pattern) +# Each decision specifies its own 
algorithm. No global model_selection needed. +decisions: + - name: tech + description: "Tech queries using Elo ratings" + priority: 10 + rules: + operator: "OR" + conditions: + - type: "domain" + name: "tech" + modelRefs: + - model: "llama3.2:3b" + use_reasoning: false + - model: "phi4" + use_reasoning: true + - model: "gemma3:27b" + use_reasoning: true + algorithm: + type: "elo" + elo: + k_factor: 32 + category_weighted: true + cost_scaling_factor: 0.2 + + - name: finance + description: "Finance queries using AutoMix" + priority: 10 + rules: + operator: "OR" + conditions: + - type: "domain" + name: "finance" + modelRefs: + - model: "llama3.2:8b" + use_reasoning: false + - model: "mistral-small3.1" + use_reasoning: true + - model: "gemma3:27b" + use_reasoning: true + algorithm: + type: "automix" + automix: + cost_quality_tradeoff: 0.4 + cost_aware_routing: true + + - name: general + description: "General queries using hybrid approach" + priority: 5 + rules: + operator: "OR" + conditions: + - type: "domain" + name: "general" + modelRefs: + - model: "llama3.2:3b" + use_reasoning: false + - model: "llama3.2:8b" + use_reasoning: false + - model: "mistral-small3.1" + use_reasoning: true + algorithm: + type: "hybrid" + hybrid: + elo_weight: 0.3 + router_dc_weight: 0.3 + automix_weight: 0.2 + cost_weight: 0.2 + +default_model: llama3.2:3b + +metrics: + enabled: true + path: /metrics diff --git a/src/semantic-router/examples/selection/main.go b/src/semantic-router/examples/selection/main.go new file mode 100644 index 0000000000..e9acb41e71 --- /dev/null +++ b/src/semantic-router/examples/selection/main.go @@ -0,0 +1,270 @@ +/* +Selection Demo - Demonstrates advanced model selection methods + +This demo exercises the actual selection package code. +Run with: cd src/semantic-router && go run ./examples/selection/main.go + +Logs are printed to show the decision-making process. 
+*/ +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/logging" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/selection" +) + +func init() { + // Initialize logging to see selection decisions + os.Setenv("LOG_LEVEL", "info") + _, _ = logging.InitLoggerFromEnv() +} + +func main() { + fmt.Println("════════════════════════════════════════════════════════════════════════════════") + fmt.Println("Demo: Advanced Model Selection Methods") + fmt.Println("════════════════════════════════════════════════════════════════════════════════") + fmt.Println() + fmt.Println("Command: cd src/semantic-router && go run ./examples/selection/main.go") + fmt.Println() + fmt.Println("This demo calls the actual selection package code.") + fmt.Println("Log lines starting with [EloSelector], [AutoMix], [RouterDC], [HybridSelector]") + fmt.Println("show the real decision-making process.") + fmt.Println() + + // Define test models (these are model IDs, not requiring actual models to be running) + candidates := []config.ModelRef{ + {Model: "llama3.2:3b"}, + {Model: "phi4"}, + {Model: "gemma3:27b"}, + } + + fmt.Println("Candidate Models (for demonstration):") + fmt.Println(" - llama3.2:3b (small, cheap)") + fmt.Println(" - phi4 (medium)") + fmt.Println(" - gemma3:27b (large, expensive)") + fmt.Println() + + // Demo 1: Static Selection + fmt.Println("┌────────────────────────────────────────────────────────────────────────────────┐") + fmt.Println("│ DEMO 1: Static Selection (Baseline - BEFORE) │") + fmt.Println("│ Always picks first model or highest configured score │") + fmt.Println("└────────────────────────────────────────────────────────────────────────────────┘") + staticSelector := selection.NewStaticSelector(&selection.StaticConfig{}) + demoSelector(staticSelector, candidates, "How do I fix a 
memory leak?") + + // Demo 2: Elo Selection + fmt.Println() + fmt.Println("┌────────────────────────────────────────────────────────────────────────────────┐") + fmt.Println("│ DEMO 2: Elo Rating Selection │") + fmt.Println("│ Models have ratings based on user preference feedback │") + fmt.Println("└────────────────────────────────────────────────────────────────────────────────┘") + eloSelector := selection.NewEloSelector(&selection.EloConfig{ + InitialRating: 1500, + KFactor: 32, + CategoryWeighted: true, + }) + + // Simulate some feedback to adjust ratings + fmt.Println("\nSimulating user feedback to adjust Elo ratings...") + fmt.Println(" - gemma3:27b beats llama3.2:3b") + fmt.Println(" - gemma3:27b beats phi4") + fmt.Println(" - phi4 beats llama3.2:3b") + ctx := context.Background() + _ = eloSelector.UpdateFeedback(ctx, &selection.Feedback{ + WinnerModel: "gemma3:27b", + LoserModel: "llama3.2:3b", + DecisionName: "tech", + }) + _ = eloSelector.UpdateFeedback(ctx, &selection.Feedback{ + WinnerModel: "gemma3:27b", + LoserModel: "phi4", + DecisionName: "tech", + }) + _ = eloSelector.UpdateFeedback(ctx, &selection.Feedback{ + WinnerModel: "phi4", + LoserModel: "llama3.2:3b", + DecisionName: "tech", + }) + + // Show current ratings + leaderboard := eloSelector.GetLeaderboard("tech") + fmt.Println("\nCurrent Elo Ratings (after feedback):") + for _, entry := range leaderboard { + fmt.Printf(" %s: %.0f (W:%d L:%d)\n", entry.Model, entry.Rating, entry.Wins, entry.Losses) + } + + demoSelector(eloSelector, candidates, "How do I fix a memory leak?") + + // Demo 3: AutoMix Selection - Cost vs Quality + fmt.Println() + fmt.Println("┌────────────────────────────────────────────────────────────────────────────────┐") + fmt.Println("│ DEMO 3: AutoMix Selection - Cost-Quality Tradeoff │") + fmt.Println("│ Shows how different cost_quality_tradeoff values affect selection │") + fmt.Println("└────────────────────────────────────────────────────────────────────────────────┘") + + 
// Set model capabilities for AutoMix + fmt.Println("\nModel capabilities (for cost-quality tradeoff):") + fmt.Println(" llama3.2:3b: cost=$0.05/1M, quality=0.70 (small, cheap)") + fmt.Println(" phi4: cost=$0.15/1M, quality=0.85 (medium)") + fmt.Println(" gemma3:27b: cost=$0.50/1M, quality=0.95 (large, expensive)") + + // Helper to set up capabilities + setupAutoMix := func(tradeoff float64) *selection.AutoMixSelector { + am := selection.NewAutoMixSelector(&selection.AutoMixConfig{ + CostQualityTradeoff: tradeoff, + CostAwareRouting: true, + }) + am.SetCapability("llama3.2:3b", &selection.ModelCapability{ + Model: "llama3.2:3b", ParamSize: 3.0, Cost: 0.05, AvgQuality: 0.70, + }) + am.SetCapability("phi4", &selection.ModelCapability{ + Model: "phi4", ParamSize: 14.0, Cost: 0.15, AvgQuality: 0.85, + }) + am.SetCapability("gemma3:27b", &selection.ModelCapability{ + Model: "gemma3:27b", ParamSize: 27.0, Cost: 0.50, AvgQuality: 0.95, + }) + return am + } + + // Low cost weight (prefer quality) + fmt.Println("\n>>> Config: cost_quality_tradeoff = 0.2 (PREFER QUALITY)") + autoMixQuality := setupAutoMix(0.2) + demoSelector(autoMixQuality, candidates, "Explain quantum computing in detail") + + // High cost weight (prefer cost) + fmt.Println("\n>>> Config: cost_quality_tradeoff = 0.8 (PREFER COST)") + autoMixCost := setupAutoMix(0.8) + demoSelector(autoMixCost, candidates, "What is 2+2?") + + // Demo 4: RouterDC Selection + fmt.Println() + fmt.Println("┌────────────────────────────────────────────────────────────────────────────────┐") + fmt.Println("│ DEMO 4: RouterDC Selection - Query-to-Model Matching │") + fmt.Println("│ Matches query embeddings to model capability embeddings │") + fmt.Println("└────────────────────────────────────────────────────────────────────────────────┘") + routerDC := selection.NewRouterDCSelector(&selection.RouterDCConfig{ + Temperature: 0.07, + MinSimilarity: 0.3, + }) + + // Set model embeddings (normally learned, here we configure them) + 
fmt.Println("\nSetting model capability embeddings:") + fmt.Println(" phi4 → optimized for [code, debugging, technical] queries") + fmt.Println(" gemma3:27b → optimized for [reasoning, analysis, complex] queries") + fmt.Println(" llama3.2:3b → general purpose (balanced)") + routerDC.SetModelEmbedding("phi4", []float32{0.9, 0.85, 0.88, 0.5, 0.4}) // Code-focused + routerDC.SetModelEmbedding("gemma3:27b", []float32{0.5, 0.4, 0.3, 0.95, 0.92}) // Reasoning-focused + routerDC.SetModelEmbedding("llama3.2:3b", []float32{0.6, 0.6, 0.6, 0.6, 0.6}) // Balanced + + // Test with code query (embedding similar to code domain) + fmt.Println("\n>>> Query: Code/Debugging (embedding: [0.85, 0.9, 0.88, 0.4, 0.3])") + codeQuery := "Debug this Go function that has a nil pointer dereference" + demoSelectorWithEmbedding(routerDC, candidates, codeQuery, []float32{0.85, 0.9, 0.88, 0.4, 0.3}) + + // Test with reasoning query (embedding similar to reasoning domain) + fmt.Println("\n>>> Query: Reasoning/Analysis (embedding: [0.3, 0.4, 0.2, 0.92, 0.88])") + reasoningQuery := "Analyze the philosophical implications of AI consciousness" + demoSelectorWithEmbedding(routerDC, candidates, reasoningQuery, []float32{0.3, 0.4, 0.2, 0.92, 0.88}) + + fmt.Println() + fmt.Println("⚠️ RouterDC LIMITATION NOTE:") + fmt.Println(" Demo uses simple 5-dimension embeddings for illustration.") + fmt.Println(" For production, model embeddings should be:") + fmt.Println(" - Pre-computed from benchmark results / model capabilities") + fmt.Println(" - Or learned via dual-contrastive training (see RouterDC paper)") + fmt.Println(" The mechanism works - production needs real embeddings.") + + // Demo 5: Hybrid Selection + fmt.Println() + fmt.Println("┌────────────────────────────────────────────────────────────────────────────────┐") + fmt.Println("│ DEMO 5: Hybrid Selection - Combines All Methods │") + fmt.Println("│ Weights: elo=0.3, routerdc=0.3, automix=0.2, cost=0.2 │") + 
fmt.Println("└────────────────────────────────────────────────────────────────────────────────┘") + + // Create hybrid with all component selectors + hybridSelector := selection.NewHybridSelectorWithComponents(&selection.HybridConfig{ + EloWeight: 0.3, + RouterDCWeight: 0.3, + AutoMixWeight: 0.2, + CostWeight: 0.2, + }, eloSelector, routerDC, autoMixQuality) + + fmt.Println("\nCombining scores from:") + fmt.Println(" - Elo ratings (gemma3:27b has highest)") + fmt.Println(" - RouterDC similarity (depends on query)") + fmt.Println(" - AutoMix cost-quality (balanced)") + + demoSelectorWithEmbedding(hybridSelector, candidates, "Write an efficient sorting algorithm", []float32{0.8, 0.85, 0.9, 0.5, 0.4}) + + fmt.Println() + fmt.Println("════════════════════════════════════════════════════════════════════════════════") + fmt.Println("✅ DEMO COMPLETE - All selection methods demonstrated with REAL code execution") + fmt.Println("════════════════════════════════════════════════════════════════════════════════") +} + +func demoSelector(selector selection.Selector, candidates []config.ModelRef, query string) { + ctx := context.Background() + selCtx := &selection.SelectionContext{ + Query: query, + CandidateModels: candidates, + DecisionName: "tech", + } + + result, err := selector.Select(ctx, selCtx) + if err != nil { + fmt.Printf("Error: %v\n", err) + return + } + + printResult(query, result) +} + +func demoSelectorWithEmbedding(selector selection.Selector, candidates []config.ModelRef, query string, embedding []float32) { + ctx := context.Background() + selCtx := &selection.SelectionContext{ + Query: query, + QueryEmbedding: embedding, + CandidateModels: candidates, + DecisionName: "tech", + } + + result, err := selector.Select(ctx, selCtx) + if err != nil { + fmt.Printf("Error: %v\n", err) + return + } + + printResult(query, result) +} + +func printResult(query string, result *selection.SelectionResult) { + fmt.Println() + fmt.Printf("Query: \"%s\"\n", truncate(query, 60)) 
+ fmt.Println() + fmt.Println("Selection Result:") + fmt.Printf(" ✅ SELECTED MODEL: %s\n", result.SelectedModel) + fmt.Printf(" 📊 Score: %.4f\n", result.Score) + fmt.Printf(" 🎯 Confidence: %.4f\n", result.Confidence) + fmt.Printf(" 🔧 Method: %s\n", result.Method) + fmt.Printf(" 💭 Reasoning: %s\n", result.Reasoning) + + if len(result.AllScores) > 0 { + fmt.Println() + fmt.Println(" All Candidate Scores:") + scoresJSON, _ := json.MarshalIndent(result.AllScores, " ", " ") + fmt.Printf(" %s\n", scoresJSON) + } +} + +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen-3] + "..." +} diff --git a/src/semantic-router/pkg/config/config.go b/src/semantic-router/pkg/config/config.go index 016b432994..5dadcb0443 100644 --- a/src/semantic-router/pkg/config/config.go +++ b/src/semantic-router/pkg/config/config.go @@ -151,10 +151,120 @@ type IntelligentRouting struct { // "confidence" - select decision with highest confidence score Strategy string `yaml:"strategy,omitempty"` + // ModelSelection configures the algorithm used for model selection + // Supported methods: "static", "elo", "router_dc", "automix", "hybrid" + ModelSelection ModelSelectionConfig `yaml:"model_selection,omitempty"` + // Reasoning mode configuration ReasoningConfig `yaml:",inline"` } +// ModelSelectionConfig represents configuration for advanced model selection algorithms +// Reference papers: +// - Elo: RouteLLM (arXiv:2406.18665) - Weighted Elo using Bradley-Terry model +// - RouterDC: Query-Based Router by Dual Contrastive Learning (arXiv:2409.19886) +// - AutoMix: Automatically Mixing Language Models (arXiv:2310.12963) +// - Hybrid: Cost-Efficient Quality-Aware Query Routing (arXiv:2404.14618) +type ModelSelectionConfig struct { + // Method specifies the selection algorithm to use + // Options: "static", "elo", "router_dc", "automix", "hybrid" + // Default: "static" (uses static scores from configuration) + Method string `yaml:"method,omitempty"` + + // 
Elo configuration for Elo rating-based selection + Elo EloSelectionConfig `yaml:"elo,omitempty"` + + // RouterDC configuration for dual-contrastive learning selection + RouterDC RouterDCSelectionConfig `yaml:"router_dc,omitempty"` + + // AutoMix configuration for POMDP-based cascaded routing + AutoMix AutoMixSelectionConfig `yaml:"automix,omitempty"` + + // Hybrid configuration for combined selection methods + Hybrid HybridSelectionConfig `yaml:"hybrid,omitempty"` +} + +// EloSelectionConfig configures Elo rating-based model selection +type EloSelectionConfig struct { + // InitialRating is the starting Elo rating for new models (default: 1500) + InitialRating float64 `yaml:"initial_rating,omitempty"` + + // KFactor controls rating volatility (default: 32) + KFactor float64 `yaml:"k_factor,omitempty"` + + // CategoryWeighted enables per-category Elo ratings (default: true) + CategoryWeighted bool `yaml:"category_weighted,omitempty"` + + // DecayFactor applies time decay to old comparisons (0-1, default: 0) + DecayFactor float64 `yaml:"decay_factor,omitempty"` + + // MinComparisons before rating is considered stable (default: 5) + MinComparisons int `yaml:"min_comparisons,omitempty"` + + // CostScalingFactor scales cost consideration (0 = ignore cost) + CostScalingFactor float64 `yaml:"cost_scaling_factor,omitempty"` +} + +// RouterDCSelectionConfig configures dual-contrastive learning selection +type RouterDCSelectionConfig struct { + // Temperature for softmax scaling (default: 0.07) + Temperature float64 `yaml:"temperature,omitempty"` + + // DimensionSize for embeddings (default: 768) + DimensionSize int `yaml:"dimension_size,omitempty"` + + // MinSimilarity threshold for valid matches (default: 0.3) + MinSimilarity float64 `yaml:"min_similarity,omitempty"` + + // UseQueryContrastive enables query-side contrastive learning + UseQueryContrastive bool `yaml:"use_query_contrastive,omitempty"` + + // UseModelContrastive enables model-side contrastive learning + 
UseModelContrastive bool `yaml:"use_model_contrastive,omitempty"` +} + +// AutoMixSelectionConfig configures POMDP-based cascaded routing +type AutoMixSelectionConfig struct { + // VerificationThreshold for self-verification (default: 0.7) + VerificationThreshold float64 `yaml:"verification_threshold,omitempty"` + + // MaxEscalations limits escalation count (default: 2) + MaxEscalations int `yaml:"max_escalations,omitempty"` + + // CostAwareRouting enables cost-quality tradeoff (default: true) + CostAwareRouting bool `yaml:"cost_aware_routing,omitempty"` + + // CostQualityTradeoff balance (0 = quality, 1 = cost, default: 0.3) + CostQualityTradeoff float64 `yaml:"cost_quality_tradeoff,omitempty"` + + // DiscountFactor for POMDP value iteration (default: 0.95) + DiscountFactor float64 `yaml:"discount_factor,omitempty"` + + // UseLogprobVerification uses logprobs for confidence (default: true) + UseLogprobVerification bool `yaml:"use_logprob_verification,omitempty"` +} + +// HybridSelectionConfig configures combined selection methods +type HybridSelectionConfig struct { + // EloWeight for Elo rating contribution (0-1, default: 0.3) + EloWeight float64 `yaml:"elo_weight,omitempty"` + + // RouterDCWeight for embedding similarity (0-1, default: 0.3) + RouterDCWeight float64 `yaml:"router_dc_weight,omitempty"` + + // AutoMixWeight for POMDP value (0-1, default: 0.2) + AutoMixWeight float64 `yaml:"automix_weight,omitempty"` + + // CostWeight for cost consideration (0-1, default: 0.2) + CostWeight float64 `yaml:"cost_weight,omitempty"` + + // QualityGapThreshold triggers escalation (default: 0.1) + QualityGapThreshold float64 `yaml:"quality_gap_threshold,omitempty"` + + // NormalizeScores before combination (default: true) + NormalizeScores bool `yaml:"normalize_scores,omitempty"` +} + type Signals struct { // Keyword-based classification rules KeywordRules []KeywordRule `yaml:"keyword_rules,omitempty"` @@ -935,14 +1045,33 @@ type Decision struct { // AlgorithmConfig 
defines how multiple models should be executed and aggregated type AlgorithmConfig struct { - // Type specifies the algorithm type: "confidence", "ratings" + // Type specifies the algorithm type: + // Looper algorithms (multi-model execution): // - "confidence": Try smaller models first, escalate to larger models if confidence is low // - "ratings": Execute all models concurrently and return multiple choices for comparison + // Selection algorithms (single model selection from candidates): + // - "static": Use static scores from configuration (default) + // - "elo": Use Elo rating system with Bradley-Terry model + // - "router_dc": Use dual-contrastive learning for query-model matching + // - "automix": Use POMDP-based cost-quality optimization + // - "hybrid": Combine multiple selection methods with configurable weights Type string `yaml:"type"` - // Algorithm-specific configurations (only one should be set based on Type) + // Looper algorithm configurations (for multi-model execution) Confidence *ConfidenceAlgorithmConfig `yaml:"confidence,omitempty"` Ratings *RatingsAlgorithmConfig `yaml:"ratings,omitempty"` + + // Selection algorithm configurations (for single model selection) + // These align with the global ModelSelectionConfig but can be overridden per-decision + Elo *EloSelectionConfig `yaml:"elo,omitempty"` + RouterDC *RouterDCSelectionConfig `yaml:"router_dc,omitempty"` + AutoMix *AutoMixSelectionConfig `yaml:"automix,omitempty"` + Hybrid *HybridSelectionConfig `yaml:"hybrid,omitempty"` + + // OnError defines behavior when algorithm fails: "skip" or "fail" + // - "skip": Skip and use fallback (default) + // - "fail": Return error immediately + OnError string `yaml:"on_error,omitempty"` } // ConfidenceAlgorithmConfig configures the confidence algorithm diff --git a/src/semantic-router/pkg/extproc/req_filter_classification.go b/src/semantic-router/pkg/extproc/req_filter_classification.go index 9b3028dcdf..87cff9e723 100644 --- 
a/src/semantic-router/pkg/extproc/req_filter_classification.go +++ b/src/semantic-router/pkg/extproc/req_filter_classification.go @@ -1,9 +1,12 @@ package extproc import ( + "context" "strings" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/logging" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/selection" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/utils/entropy" ) @@ -120,23 +123,27 @@ func (r *OpenAIRouter) performDecisionEvaluation(originalModel string, userConte return decisionName, evaluationConfidence, reasoningDecision, "" } - // Select best model from the decision's ModelRefs (only for auto models) + // Select best model from the decision's ModelRefs using configured selection algorithm if len(result.Decision.ModelRefs) > 0 { - modelRef := result.Decision.ModelRefs[0] + // Use advanced model selection (Elo, RouterDC, AutoMix, Hybrid, or Static) + // Pass decision's algorithm config for per-decision algorithm override + selectedModelRef, usedMethod := r.selectModelFromCandidates(result.Decision.ModelRefs, decisionName, userContent, result.Decision.Algorithm) + // Use LoRA name if specified, otherwise use the base model name - selectedModel = modelRef.Model - if modelRef.LoRAName != "" { - selectedModel = modelRef.LoRAName - logging.Infof("Selected model from decision %s: %s (LoRA adapter for base model %s)", - decisionName, selectedModel, modelRef.Model) + selectedModel = selectedModelRef.Model + if selectedModelRef.LoRAName != "" { + selectedModel = selectedModelRef.LoRAName + logging.Infof("Selected model from decision %s: %s (LoRA adapter for base model %s) using %s selection", + decisionName, selectedModel, selectedModelRef.Model, usedMethod) } else { - logging.Infof("Selected model from decision %s: %s", decisionName, selectedModel) + logging.Infof("Selected model from decision %s: %s using %s selection", + 
decisionName, selectedModel, usedMethod) } ctx.VSRSelectedModel = selectedModel - // Determine reasoning mode from the best model's configuration - if result.Decision.ModelRefs[0].UseReasoning != nil { - useReasoning := *result.Decision.ModelRefs[0].UseReasoning + // Determine reasoning mode from the selected model's configuration + if selectedModelRef.UseReasoning != nil { + useReasoning := *selectedModelRef.UseReasoning reasoningDecision = entropy.ReasoningDecision{ UseReasoning: useReasoning, Confidence: evaluationConfidence, @@ -164,3 +171,113 @@ func (r *OpenAIRouter) performDecisionEvaluation(originalModel string, userConte return decisionName, evaluationConfidence, reasoningDecision, selectedModel } + +// selectModelFromCandidates uses the configured selection algorithm to choose the best model +// from the decision's candidate models. Falls back to first model if selection fails. +// The algorithm parameter allows per-decision algorithm override (aligned with looper pattern). +// Returns the selected model and the method name used for logging. 
+func (r *OpenAIRouter) selectModelFromCandidates(modelRefs []config.ModelRef, decisionName string, query string, algorithm *config.AlgorithmConfig) (*config.ModelRef, string) { + if len(modelRefs) == 0 { + return nil, "" + } + + // If only one model, no need for selection algorithm + if len(modelRefs) == 1 { + return &modelRefs[0], "single" + } + + // Determine selection method: per-decision algorithm takes precedence over global config + method := r.getSelectionMethod(algorithm) + + // Get selector from registry + var selector selection.Selector + if r.ModelSelector != nil { + selector, _ = r.ModelSelector.Get(method) + } + + // Fallback to first model if no selector available + if selector == nil { + logging.Warnf("[ModelSelection] No selector available for method %s, using first model", method) + return &modelRefs[0], string(method) + } + + // Build selection context with cost/quality weights + costWeight, qualityWeight := r.getSelectionWeights(algorithm) + + selCtx := &selection.SelectionContext{ + Query: query, + DecisionName: decisionName, + CandidateModels: modelRefs, + CostWeight: costWeight, + QualityWeight: qualityWeight, + } + + // Perform selection + result, err := selector.Select(context.Background(), selCtx) + if err != nil { + logging.Warnf("[ModelSelection] Selection failed: %v, using first model", err) + return &modelRefs[0], string(method) + } + + // Find the selected model in the candidates + for i := range modelRefs { + if modelRefs[i].Model == result.SelectedModel || + modelRefs[i].LoRAName == result.SelectedModel { + logging.Infof("[ModelSelection] Selected %s (method=%s, score=%.4f, confidence=%.2f): %s", + result.SelectedModel, method, result.Score, result.Confidence, result.Reasoning) + // Record selection metrics + selection.RecordSelection(string(method), decisionName, result.SelectedModel, result.Score) + return &modelRefs[i], string(method) + } + } + + // Fallback if selected model not found in candidates (shouldn't happen) + 
logging.Warnf("[ModelSelection] Selected model %s not found in candidates, using first model", result.SelectedModel) + return &modelRefs[0], string(method) +} + +// getSelectionMethod determines which selection algorithm to use. +// Per-decision algorithm is the primary configuration (aligned with looper pattern). +// Defaults to static selection if no algorithm is specified. +func (r *OpenAIRouter) getSelectionMethod(algorithm *config.AlgorithmConfig) selection.SelectionMethod { + // Check per-decision algorithm (aligned with looper pattern) + if algorithm != nil && algorithm.Type != "" { + switch algorithm.Type { + case "elo": + return selection.MethodElo + case "router_dc": + return selection.MethodRouterDC + case "automix": + return selection.MethodAutoMix + case "hybrid": + return selection.MethodHybrid + case "static": + return selection.MethodStatic + case "confidence", "ratings": + // These are looper algorithms, not selection algorithms + // Fall through to default + } + } + + // Default to static selection (use first model) + return selection.MethodStatic +} + +// getSelectionWeights returns cost and quality weights based on algorithm config. +// Uses per-decision config only (aligned with looper pattern). 
+func (r *OpenAIRouter) getSelectionWeights(algorithm *config.AlgorithmConfig) (float64, float64) { + // Check per-decision algorithm config + if algorithm != nil { + if algorithm.AutoMix != nil && algorithm.AutoMix.CostQualityTradeoff > 0 { + cost := algorithm.AutoMix.CostQualityTradeoff + return cost, 1.0 - cost + } + if algorithm.Hybrid != nil && algorithm.Hybrid.CostWeight > 0 { + cost := algorithm.Hybrid.CostWeight + return cost, 1.0 - cost + } + } + + // Default: equal weighting (0.5 cost, 0.5 quality) + return 0.5, 0.5 +} diff --git a/src/semantic-router/pkg/extproc/router.go b/src/semantic-router/pkg/extproc/router.go index 31feff8463..84e9e51943 100644 --- a/src/semantic-router/pkg/extproc/router.go +++ b/src/semantic-router/pkg/extproc/router.go @@ -16,6 +16,7 @@ import ( "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/logging" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/responsestore" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/routerreplay" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/selection" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/services" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/tools" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/utils/pii" @@ -31,6 +32,9 @@ type OpenAIRouter struct { ToolsDatabase *tools.ToolsDatabase ResponseAPIFilter *ResponseAPIFilter ReplayRecorder *routerreplay.Recorder + // ModelSelector is the registry of advanced model selection algorithms + // Initialized from config.IntelligentRouting.ModelSelection + ModelSelector *selection.Registry } // Ensure OpenAIRouter implements the ext_proc calls @@ -196,6 +200,66 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) { } replayRecorder := routerreplay.NewRecorder(replayMax) + // Initialize model selection registry with default configs + // Actual selection method is determined per-decision via algorithm 
config (aligned with looper) + modelSelectionCfg := &selection.ModelSelectionConfig{ + Method: "static", // Default; per-decision algorithm overrides this + } + // Copy Elo config from config package to selection package format + eloCfg := cfg.IntelligentRouting.ModelSelection.Elo + modelSelectionCfg.Elo = &selection.EloConfig{ + InitialRating: eloCfg.InitialRating, + KFactor: eloCfg.KFactor, + CategoryWeighted: eloCfg.CategoryWeighted, + DecayFactor: eloCfg.DecayFactor, + MinComparisons: eloCfg.MinComparisons, + CostScalingFactor: eloCfg.CostScalingFactor, + } + + // Copy RouterDC config + routerDCCfg := cfg.IntelligentRouting.ModelSelection.RouterDC + modelSelectionCfg.RouterDC = &selection.RouterDCConfig{ + Temperature: routerDCCfg.Temperature, + DimensionSize: routerDCCfg.DimensionSize, + MinSimilarity: routerDCCfg.MinSimilarity, + UseQueryContrastive: routerDCCfg.UseQueryContrastive, + UseModelContrastive: routerDCCfg.UseModelContrastive, + } + + // Copy AutoMix config + autoMixCfg := cfg.IntelligentRouting.ModelSelection.AutoMix + modelSelectionCfg.AutoMix = &selection.AutoMixConfig{ + VerificationThreshold: autoMixCfg.VerificationThreshold, + MaxEscalations: autoMixCfg.MaxEscalations, + CostAwareRouting: autoMixCfg.CostAwareRouting, + CostQualityTradeoff: autoMixCfg.CostQualityTradeoff, + DiscountFactor: autoMixCfg.DiscountFactor, + UseLogprobVerification: autoMixCfg.UseLogprobVerification, + } + + // Copy Hybrid config + hybridCfg := cfg.IntelligentRouting.ModelSelection.Hybrid + modelSelectionCfg.Hybrid = &selection.HybridConfig{ + EloWeight: hybridCfg.EloWeight, + RouterDCWeight: hybridCfg.RouterDCWeight, + AutoMixWeight: hybridCfg.AutoMixWeight, + CostWeight: hybridCfg.CostWeight, + QualityGapThreshold: hybridCfg.QualityGapThreshold, + NormalizeScores: hybridCfg.NormalizeScores, + } + + // Create selection factory and initialize all selectors + selectionFactory := selection.NewFactory(modelSelectionCfg) + if cfg.BackendModels.ModelConfig != nil { + 
selectionFactory = selectionFactory.WithModelConfig(cfg.BackendModels.ModelConfig) + } + if len(cfg.Categories) > 0 { + selectionFactory = selectionFactory.WithCategories(cfg.Categories) + } + modelSelectorRegistry := selectionFactory.CreateAll() + + logging.Infof("[Router] Initialized model selection registry (per-decision algorithm config)") + router := &OpenAIRouter{ Config: cfg, CategoryDescriptions: categoryDescriptions, @@ -205,6 +269,7 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) { ToolsDatabase: toolsDatabase, ResponseAPIFilter: responseAPIFilter, ReplayRecorder: replayRecorder, + ModelSelector: modelSelectorRegistry, } return router, nil diff --git a/src/semantic-router/pkg/selection/automix.go b/src/semantic-router/pkg/selection/automix.go new file mode 100644 index 0000000000..3ce620d068 --- /dev/null +++ b/src/semantic-router/pkg/selection/automix.go @@ -0,0 +1,482 @@ +/* +Copyright 2025 vLLM Semantic Router. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package selection + +import ( + "context" + "fmt" + "math" + "sort" + "strings" + "sync" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/logging" +) + +// AutoMixConfig configures the AutoMix POMDP-based selector +// Based on arXiv:2310.12963 - Automatically Mixing Language Models +// +// NOTE: This is a PRE-SELECTION implementation of AutoMix concepts. 
+// The original paper describes a CASCADED EXECUTION approach where: +// 1. Start with the smallest/cheapest model +// 2. Execute the query and perform self-verification +// 3. If confidence is below threshold, escalate to a larger model +// 4. Repeat until confidence is acceptable or max escalations reached +// +// Our implementation applies AutoMix PRINCIPLES to pre-selection: +// - We estimate which model is most likely to succeed based on learned capabilities +// - We optimize the cost-quality tradeoff using POMDP value functions +// - Feedback updates improve the selection over time +// +// For true cascaded execution with self-verification, the looper package +// would need to be extended to support multi-stage inference with confidence +// checks between stages. This is planned for a future enhancement. +type AutoMixConfig struct { + // VerificationThreshold is the confidence threshold for self-verification + // Responses below this threshold trigger escalation (default: 0.7) + VerificationThreshold float64 `yaml:"verification_threshold"` + + // MaxEscalations limits how many times to escalate (default: 2) + MaxEscalations int `yaml:"max_escalations"` + + // CostAwareRouting enables cost-quality tradeoff optimization + CostAwareRouting bool `yaml:"cost_aware_routing"` + + // CostQualityTradeoff controls balance (0 = pure quality, 1 = pure cost) + CostQualityTradeoff float64 `yaml:"cost_quality_tradeoff"` + + // DiscountFactor for POMDP value iteration (gamma, default: 0.95) + DiscountFactor float64 `yaml:"discount_factor"` + + // UseLogprobVerification uses logprobs for confidence estimation + UseLogprobVerification bool `yaml:"use_logprob_verification"` +} + +// DefaultAutoMixConfig returns the default AutoMix configuration +func DefaultAutoMixConfig() *AutoMixConfig { + return &AutoMixConfig{ + VerificationThreshold: 0.7, + MaxEscalations: 2, + CostAwareRouting: true, + CostQualityTradeoff: 0.3, + DiscountFactor: 0.95, + UseLogprobVerification: true, + } 
+} + +// ModelCapability stores learned model capabilities for POMDP states +type ModelCapability struct { + Model string `json:"model"` + ParamSize float64 `json:"param_size"` // Model size in billions of parameters + Cost float64 `json:"cost"` // Cost per 1M tokens + AvgQuality float64 `json:"avg_quality"` // Learned average quality score + VerificationProb float64 `json:"verification_prob"` // Probability of passing self-verification + EscalationReward float64 `json:"escalation_reward"` // Expected reward from escalation + QuerySuccessCount int `json:"query_success_count"` // Successful queries + QueryTotalCount int `json:"query_total_count"` // Total queries +} + +// AutoMixSelector implements POMDP-based cascaded model selection +// The algorithm routes to smaller models first and escalates based on +// self-verification confidence, optimizing the cost-quality tradeoff. +type AutoMixSelector struct { + config *AutoMixConfig + + // Model capabilities indexed by model name + capabilities map[string]*ModelCapability + capMu sync.RWMutex + + // POMDP value function V(s) for each model + valueFunction map[string]float64 + valueMu sync.RWMutex + + // Transition probabilities P(s'|s,a) for escalation decisions + transitionProbs map[string]map[string]float64 +} + +// NewAutoMixSelector creates a new AutoMix-based selector +func NewAutoMixSelector(cfg *AutoMixConfig) *AutoMixSelector { + if cfg == nil { + cfg = DefaultAutoMixConfig() + } + return &AutoMixSelector{ + config: cfg, + capabilities: make(map[string]*ModelCapability), + valueFunction: make(map[string]float64), + transitionProbs: make(map[string]map[string]float64), + } +} + +// Method returns the selection method type +func (a *AutoMixSelector) Method() SelectionMethod { + return MethodAutoMix +} + +// InitializeFromConfig sets up model capabilities from configuration +func (a *AutoMixSelector) InitializeFromConfig(modelConfig map[string]config.ModelParams) { + a.capMu.Lock() + defer a.capMu.Unlock() + + for 
model, params := range modelConfig { + cap := &ModelCapability{ + Model: model, + Cost: params.Pricing.PromptPer1M, + AvgQuality: 0.8, // Default quality estimate + VerificationProb: 0.7, // Default verification probability + ParamSize: a.estimateParamSize(model), // Estimate from model name + } + a.capabilities[model] = cap + + // Initialize value function (higher for larger/better models) + a.valueMu.Lock() + a.valueFunction[model] = cap.ParamSize / 100.0 // Normalize + a.valueMu.Unlock() + } + + logging.Infof("[AutoMix] Initialized capabilities for %d models", len(a.capabilities)) +} + +// Select chooses the best model using POMDP-based cost-quality optimization +func (a *AutoMixSelector) Select(ctx context.Context, selCtx *SelectionContext) (*SelectionResult, error) { + if len(selCtx.CandidateModels) == 0 { + return nil, fmt.Errorf("no candidate models provided") + } + + // Sort candidates by cost (cheaper first for cascaded routing) + sortedCandidates := a.sortByCost(selCtx.CandidateModels) + + // Calculate expected value for each model using POMDP + allScores := make(map[string]float64) + a.capMu.RLock() + a.valueMu.RLock() + defer a.capMu.RUnlock() + defer a.valueMu.RUnlock() + + logging.Infof("[AutoMix] Evaluating %d candidates (tradeoff=%.2f):", + len(sortedCandidates), a.config.CostQualityTradeoff) + for _, model := range sortedCandidates { + modelName := model.Model + score := a.computeExpectedValue(modelName, selCtx) + allScores[modelName] = score + if cap, ok := a.capabilities[modelName]; ok { + logging.Infof("[AutoMix] %s: cost=$%.2f, quality=%.2f, value=%.4f", + modelName, cap.Cost, cap.AvgQuality, score) + } else { + logging.Infof("[AutoMix] %s: value=%.4f (no capability data)", modelName, score) + } + } + + // Find optimal starting model (not necessarily the best, but best value) + var selectedModel *config.ModelRef + var selectedScore float64 + var reasoning string + + if a.config.CostAwareRouting { + // Cost-aware: select model with best value 
considering cost + selectedModel, selectedScore, reasoning = a.selectCostAware(sortedCandidates, allScores, selCtx) + } else { + // Quality-only: select model with highest expected quality + selectedModel, selectedScore, reasoning = a.selectQualityOnly(sortedCandidates, allScores) + } + + if selectedModel == nil { + return nil, fmt.Errorf("could not select a model") + } + + // Calculate confidence based on verification probability + confidence := a.getVerificationProbability(selectedModel.Model) + + logging.Infof("[AutoMix] Selected model %s (score=%.4f, confidence=%.2f, cost-aware=%v)", + selectedModel.Model, selectedScore, confidence, a.config.CostAwareRouting) + + return &SelectionResult{ + SelectedModel: selectedModel.Model, + LoRAName: selectedModel.LoRAName, + Score: selectedScore, + Confidence: confidence, + Method: MethodAutoMix, + Reasoning: reasoning, + AllScores: allScores, + }, nil +} + +// UpdateFeedback updates POMDP model based on verification outcomes +func (a *AutoMixSelector) UpdateFeedback(ctx context.Context, feedback *Feedback) error { + if feedback.WinnerModel == "" { + return fmt.Errorf("winner model is required") + } + + a.capMu.Lock() + defer a.capMu.Unlock() + + // Update winner model capabilities + if cap, ok := a.capabilities[feedback.WinnerModel]; ok { + cap.QuerySuccessCount++ + cap.QueryTotalCount++ + + // Update verification probability with exponential moving average + alpha := 0.1 // Learning rate + cap.VerificationProb = cap.VerificationProb*(1-alpha) + 1.0*alpha + cap.AvgQuality = cap.AvgQuality*(1-alpha) + 1.0*alpha + + logging.Debugf("[AutoMix] Updated winner %s: verification_prob=%.3f, quality=%.3f", + feedback.WinnerModel, cap.VerificationProb, cap.AvgQuality) + } + + // Update loser model capabilities (if this was a comparison) + if feedback.LoserModel != "" && !feedback.Tie { + if cap, ok := a.capabilities[feedback.LoserModel]; ok { + cap.QueryTotalCount++ + + alpha := 0.1 + cap.VerificationProb = 
cap.VerificationProb*(1-alpha) + 0.0*alpha + cap.AvgQuality = cap.AvgQuality*(1-alpha) + 0.0*alpha + + logging.Debugf("[AutoMix] Updated loser %s: verification_prob=%.3f, quality=%.3f", + feedback.LoserModel, cap.VerificationProb, cap.AvgQuality) + } + } + + // Run value iteration to update POMDP values + a.updateValueFunction() + + return nil +} + +// computeExpectedValue calculates the expected value of using a model +// V(model) = R(model) + γ * E[V(s') | escalation possible] +func (a *AutoMixSelector) computeExpectedValue(model string, selCtx *SelectionContext) float64 { + cap := a.capabilities[model] + if cap == nil { + return 0.5 // Default value for unknown models + } + + // Immediate reward: quality + quality := cap.AvgQuality + + // Cost penalty (normalized) + costPenalty := 0.0 + if a.config.CostAwareRouting && cap.Cost > 0 { + // Normalize cost to 0-1 range (assuming max cost is ~$10/1M tokens) + normalizedCost := cap.Cost / 10.0 + costPenalty = normalizedCost * a.config.CostQualityTradeoff + } + + // Expected value from potential escalation + verificationProb := cap.VerificationProb + escalationValue := 0.0 + + if verificationProb < a.config.VerificationThreshold { + // Model likely needs escalation - consider value of larger models + escalationValue = a.config.DiscountFactor * cap.EscalationReward + } + + // Combine: value = quality - cost_penalty + escalation_value + value := quality - costPenalty + escalationValue*(1-verificationProb) + + return value +} + +// selectCostAware selects model optimizing cost-quality tradeoff +func (a *AutoMixSelector) selectCostAware(candidates []config.ModelRef, scores map[string]float64, selCtx *SelectionContext) (*config.ModelRef, float64, string) { + var bestModel *config.ModelRef + bestValue := math.Inf(-1) + + for i := range candidates { + model := &candidates[i] + score := scores[model.Model] + + cap := a.capabilities[model.Model] + if cap == nil { + continue + } + + // Calculate cost-adjusted value + costFactor 
:= 1.0 + if cap.Cost > 0 { + // Prefer cheaper models when cost weight is high + costFactor = 1.0 / (1.0 + cap.Cost*selCtx.CostWeight) + } + + value := score * costFactor + + // Prefer models above verification threshold + if cap.VerificationProb >= a.config.VerificationThreshold { + value *= 1.1 // 10% bonus for likely-to-succeed models + } + + if value > bestValue { + bestValue = value + bestModel = model + } + } + + if bestModel == nil && len(candidates) > 0 { + bestModel = &candidates[0] + bestValue = scores[bestModel.Model] + } + + reasoning := fmt.Sprintf("Cost-aware POMDP selection (tradeoff=%.2f, discount=%.2f)", + a.config.CostQualityTradeoff, a.config.DiscountFactor) + + return bestModel, bestValue, reasoning +} + +// selectQualityOnly selects the highest quality model regardless of cost +func (a *AutoMixSelector) selectQualityOnly(candidates []config.ModelRef, scores map[string]float64) (*config.ModelRef, float64, string) { + var bestModel *config.ModelRef + var bestScore float64 + + for i := range candidates { + model := &candidates[i] + score := scores[model.Model] + + if score > bestScore || bestModel == nil { + bestScore = score + bestModel = model + } + } + + reasoning := fmt.Sprintf("Quality-only POMDP selection (threshold=%.2f)", + a.config.VerificationThreshold) + + return bestModel, bestScore, reasoning +} + +// sortByCost sorts models by cost (ascending) +func (a *AutoMixSelector) sortByCost(models []config.ModelRef) []config.ModelRef { + sorted := make([]config.ModelRef, len(models)) + copy(sorted, models) + + a.capMu.RLock() + defer a.capMu.RUnlock() + + sort.Slice(sorted, func(i, j int) bool { + capI := a.capabilities[sorted[i].Model] + capJ := a.capabilities[sorted[j].Model] + + costI := 0.0 + costJ := 0.0 + if capI != nil { + costI = capI.Cost + } + if capJ != nil { + costJ = capJ.Cost + } + + return costI < costJ + }) + + return sorted +} + +// getVerificationProbability returns the learned verification probability +func (a 
*AutoMixSelector) getVerificationProbability(model string) float64 { + a.capMu.RLock() + defer a.capMu.RUnlock() + + if cap, ok := a.capabilities[model]; ok { + return cap.VerificationProb + } + return 0.7 // Default +} + +// updateValueFunction performs one iteration of POMDP value update. The caller must hold a.capMu for writing. +func (a *AutoMixSelector) updateValueFunction() { + // Do not take a.capMu here: UpdateFeedback calls this while holding a.capMu.Lock(), + // and sync.RWMutex is not reentrant, so an RLock at this point would block forever. + a.valueMu.Lock() + defer a.valueMu.Unlock() + + // Simple value iteration: V(s) = R(s) + γ * max_a E[V(s')] + for model, cap := range a.capabilities { + // Current reward + reward := cap.AvgQuality + + // Expected future value (from escalation) + futureValue := 0.0 + if cap.VerificationProb < a.config.VerificationThreshold { + // Calculate expected value of escalation + for otherModel, otherCap := range a.capabilities { + if otherCap.ParamSize > cap.ParamSize { + // Larger model could be escalation target + transitionProb := (1 - cap.VerificationProb) * 0.5 // Simplified + futureValue += transitionProb * a.valueFunction[otherModel] + } + } + } + + // Update value + a.valueFunction[model] = reward + a.config.DiscountFactor*futureValue + + // Update escalation reward for capability + cap.EscalationReward = futureValue + } +} + +// estimateParamSize estimates model size from name +func (a *AutoMixSelector) estimateParamSize(model string) float64 { + // Extract size from common naming patterns (7b, 13b, 70b, etc.)
+ sizes := []struct { + pattern string + size float64 + }{ + {"1.8b", 1.8}, + {"1.5b", 1.5}, + {"0.5b", 0.5}, + {"405b", 405.0}, + {"70b", 70.0}, + {"72b", 72.0}, + {"34b", 34.0}, + {"32b", 32.0}, + {"14b", 14.0}, + {"13b", 13.0}, + {"8b", 8.0}, + {"7b", 7.0}, + {"3b", 3.0}, + } + // First match wins, so decimal sizes are listed before the integer suffixes they contain ("1.8b" before "8b"). + modelLower := strings.ToLower(model) + for _, s := range sizes { + if strings.Contains(modelLower, s.pattern) { + return s.size + } + } + + return 7.0 // Default assumption +} + +// GetCapabilities returns all model capabilities (for debugging) +func (a *AutoMixSelector) GetCapabilities() map[string]*ModelCapability { + a.capMu.RLock() + defer a.capMu.RUnlock() + + result := make(map[string]*ModelCapability) + for k, v := range a.capabilities { + capCopy := *v + result[k] = &capCopy + } + return result +} + +// SetCapability directly sets a model's capability +func (a *AutoMixSelector) SetCapability(model string, cap *ModelCapability) { + a.capMu.Lock() + defer a.capMu.Unlock() + a.capabilities[model] = cap +} diff --git a/src/semantic-router/pkg/selection/elo.go b/src/semantic-router/pkg/selection/elo.go new file mode 100644 index 0000000000..6b6da1f57e --- /dev/null +++ b/src/semantic-router/pkg/selection/elo.go @@ -0,0 +1,494 @@ +/* +Copyright 2025 vLLM Semantic Router. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+*/ + +package selection + +import ( + "context" + "fmt" + "math" + "sort" + "sync" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/logging" +) + +// DefaultEloRating is the initial Elo rating for new models +const DefaultEloRating = 1500.0 + +// EloKFactor controls how much ratings change per comparison +// Higher values = faster adaptation but more volatility +const EloKFactor = 32.0 + +// EloMinRatingFromScore is the base rating when converting static scores (0-1) to Elo +const EloMinRatingFromScore = 1000.0 + +// EloRatingRange is the rating range for score conversion (scores 0-1 map to 1000-2000) +const EloRatingRange = 1000.0 + +// EloConfig configures the Elo-based model selector +type EloConfig struct { + // InitialRating is the starting Elo rating for new models + InitialRating float64 `yaml:"initial_rating"` + + // KFactor controls rating volatility (higher = more volatile) + KFactor float64 `yaml:"k_factor"` + + // CategoryWeighted enables per-category Elo ratings + CategoryWeighted bool `yaml:"category_weighted"` + + // DecayFactor applies time decay to old comparisons (0-1, 0 = no decay) + DecayFactor float64 `yaml:"decay_factor"` + + // MinComparisons is minimum comparisons before a rating is considered stable + MinComparisons int `yaml:"min_comparisons"` + + // CostScalingFactor scales cost consideration (0 = ignore cost) + CostScalingFactor float64 `yaml:"cost_scaling_factor"` +} + +// DefaultEloConfig returns the default Elo configuration +func DefaultEloConfig() *EloConfig { + return &EloConfig{ + InitialRating: DefaultEloRating, + KFactor: EloKFactor, + CategoryWeighted: true, + DecayFactor: 0.0, + MinComparisons: 5, + CostScalingFactor: 0.0, + } +} + +// ModelRating stores the Elo rating and metadata for a model +type ModelRating struct { + Model string `json:"model"` + Rating float64 `json:"rating"` + Comparisons int 
`json:"comparisons"` + Wins int `json:"wins"` + Losses int `json:"losses"` + Ties int `json:"ties"` +} + +// EloSelector implements Elo rating-based model selection +// Based on RouteLLM paper (arXiv:2406.18665) using Bradley-Terry model +type EloSelector struct { + config *EloConfig + + // Global ratings (not category-specific) + globalRatings map[string]*ModelRating + globalMu sync.RWMutex + + // Category-specific ratings (decision name -> model -> rating) + categoryRatings map[string]map[string]*ModelRating + categoryMu sync.RWMutex + + // Model costs for cost-aware selection (model -> cost per 1M tokens) + modelCosts map[string]float64 + costMu sync.RWMutex +} + +// NewEloSelector creates a new Elo-based selector +func NewEloSelector(cfg *EloConfig) *EloSelector { + if cfg == nil { + cfg = DefaultEloConfig() + } + return &EloSelector{ + config: cfg, + globalRatings: make(map[string]*ModelRating), + categoryRatings: make(map[string]map[string]*ModelRating), + modelCosts: make(map[string]float64), + } +} + +// Method returns the selection method type +func (e *EloSelector) Method() SelectionMethod { + return MethodElo +} + +// SetModelCost sets the cost per 1M tokens for a model +func (e *EloSelector) SetModelCost(model string, costPer1M float64) { + e.costMu.Lock() + defer e.costMu.Unlock() + e.modelCosts[model] = costPer1M +} + +// InitializeFromConfig sets up initial ratings from model configuration +func (e *EloSelector) InitializeFromConfig(modelConfig map[string]config.ModelParams, categories []config.Category) { + e.globalMu.Lock() + defer e.globalMu.Unlock() + + // Initialize global ratings for all models + for model := range modelConfig { + if _, exists := e.globalRatings[model]; !exists { + e.globalRatings[model] = &ModelRating{ + Model: model, + Rating: e.config.InitialRating, + } + } + } + + // Set costs from config + e.costMu.Lock() + for model, params := range modelConfig { + if params.Pricing.PromptPer1M > 0 { + e.modelCosts[model] = 
params.Pricing.PromptPer1M + } + } + e.costMu.Unlock() + + // Initialize category ratings from ModelScores if available + if e.config.CategoryWeighted { + e.categoryMu.Lock() + for _, category := range categories { + if e.categoryRatings[category.Name] == nil { + e.categoryRatings[category.Name] = make(map[string]*ModelRating) + } + for _, ms := range category.ModelScores { + // Convert static scores to Elo ratings (scale 0-1 -> 1000-2000) + rating := EloMinRatingFromScore + (ms.Score * EloRatingRange) + e.categoryRatings[category.Name][ms.Model] = &ModelRating{ + Model: ms.Model, + Rating: rating, + } + } + } + e.categoryMu.Unlock() + } +} + +// Select chooses the best model based on Elo ratings +func (e *EloSelector) Select(ctx context.Context, selCtx *SelectionContext) (*SelectionResult, error) { + if len(selCtx.CandidateModels) == 0 { + return nil, fmt.Errorf("no candidate models provided") + } + + allScores := make(map[string]float64) + + // Get ratings for all candidates + ratings := e.getRatingsForCandidates(selCtx.DecisionName, selCtx.CandidateModels) + + logging.Infof("[EloSelector] Evaluating %d candidates for category '%s':", + len(selCtx.CandidateModels), selCtx.DecisionName) + for _, r := range ratings { + logging.Infof("[EloSelector] %s: rating=%.1f (W:%d L:%d T:%d)", + r.Model, r.Rating, r.Wins, r.Losses, r.Ties) + } + + // Calculate selection probability using Bradley-Terry model + // P(model_i wins) = rating_i / sum(all ratings) + totalRating := 0.0 + for _, r := range ratings { + totalRating += math.Pow(10, r.Rating/400.0) // Standard Elo probability scale + } + + if totalRating == 0 { + // Fallback: uniform distribution + for _, r := range ratings { + allScores[r.Model] = 1.0 / float64(len(ratings)) + } + } else { + for _, r := range ratings { + prob := math.Pow(10, r.Rating/400.0) / totalRating + allScores[r.Model] = prob + } + } + + // Apply cost adjustment if enabled + if e.config.CostScalingFactor > 0 && selCtx.CostWeight > 0 { + 
e.applyCostAdjustment(allScores, selCtx.CostWeight) + } + + // Find best model by combined score + var bestModel *config.ModelRef + var bestScore float64 + var bestRating *ModelRating + + for i := range selCtx.CandidateModels { + model := &selCtx.CandidateModels[i] + score := allScores[model.Model] + + if score > bestScore || bestModel == nil { + bestScore = score + bestModel = model + for _, r := range ratings { + if r.Model == model.Model { + bestRating = r + break + } + } + } + } + + if bestModel == nil { + return nil, fmt.Errorf("could not select a model") + } + + // Calculate confidence based on rating stability + confidence := e.calculateConfidence(bestRating) + + reasoning := fmt.Sprintf("Selected based on Elo rating %.1f (win rate: %d/%d)", + bestRating.Rating, + bestRating.Wins, + bestRating.Wins+bestRating.Losses+bestRating.Ties) + + if e.config.CategoryWeighted && selCtx.DecisionName != "" { + reasoning = fmt.Sprintf("Category '%s': %s", selCtx.DecisionName, reasoning) + } + + logging.Infof("[EloSelector] Selected model %s (rating=%.1f, score=%.4f, confidence=%.2f)", + bestModel.Model, bestRating.Rating, bestScore, confidence) + + return &SelectionResult{ + SelectedModel: bestModel.Model, + LoRAName: bestModel.LoRAName, + Score: bestScore, + Confidence: confidence, + Method: MethodElo, + Reasoning: reasoning, + AllScores: allScores, + }, nil +} + +// UpdateFeedback updates Elo ratings based on user preference feedback +func (e *EloSelector) UpdateFeedback(ctx context.Context, feedback *Feedback) error { + if feedback.WinnerModel == "" { + return fmt.Errorf("winner model is required") + } + + // Update global ratings + e.updateRating(feedback, e.getGlobalRating, e.setGlobalRating) + + // Update category ratings if applicable + if e.config.CategoryWeighted && feedback.DecisionName != "" { + e.updateRating(feedback, + func(model string) *ModelRating { + return e.getCategoryRating(feedback.DecisionName, model) + }, + func(model string, rating *ModelRating) { 
+ e.setCategoryRating(feedback.DecisionName, model, rating) + }) + } + + logging.Infof("[EloSelector] Updated ratings: winner=%s, loser=%s, tie=%v", + feedback.WinnerModel, feedback.LoserModel, feedback.Tie) + + return nil +} + +// updateRating performs the actual Elo rating update +func (e *EloSelector) updateRating(feedback *Feedback, + getRating func(string) *ModelRating, + setRating func(string, *ModelRating), +) { + winnerRating := getRating(feedback.WinnerModel) + if winnerRating == nil { + winnerRating = &ModelRating{Model: feedback.WinnerModel, Rating: e.config.InitialRating} + } + + // Handle single feedback (no loser) + if feedback.LoserModel == "" { + // Just record the comparison without rating change + winnerRating.Comparisons++ + winnerRating.Wins++ + setRating(feedback.WinnerModel, winnerRating) + return + } + + loserRating := getRating(feedback.LoserModel) + if loserRating == nil { + loserRating = &ModelRating{Model: feedback.LoserModel, Rating: e.config.InitialRating} + } + + // Calculate expected scores using Bradley-Terry model + // E_a = 1 / (1 + 10^((R_b - R_a) / 400)) + expectedWinner := 1.0 / (1.0 + math.Pow(10, (loserRating.Rating-winnerRating.Rating)/400.0)) + expectedLoser := 1.0 - expectedWinner + + // Determine actual scores + var actualWinner, actualLoser float64 + if feedback.Tie { + actualWinner = 0.5 + actualLoser = 0.5 + } else { + actualWinner = 1.0 + actualLoser = 0.0 + } + + // Update ratings: R' = R + K * (actual - expected) + winnerRating.Rating += e.config.KFactor * (actualWinner - expectedWinner) + loserRating.Rating += e.config.KFactor * (actualLoser - expectedLoser) + + // Update statistics + winnerRating.Comparisons++ + loserRating.Comparisons++ + if feedback.Tie { + winnerRating.Ties++ + loserRating.Ties++ + } else { + winnerRating.Wins++ + loserRating.Losses++ + } + + setRating(feedback.WinnerModel, winnerRating) + setRating(feedback.LoserModel, loserRating) +} + +// getRatingsForCandidates retrieves ratings for all 
candidate models +func (e *EloSelector) getRatingsForCandidates(decisionName string, candidates []config.ModelRef) []*ModelRating { + ratings := make([]*ModelRating, 0, len(candidates)) + + for _, c := range candidates { + var rating *ModelRating + + // Try category-specific rating first + if e.config.CategoryWeighted && decisionName != "" { + rating = e.getCategoryRating(decisionName, c.Model) + } + + // Fall back to global rating + if rating == nil { + rating = e.getGlobalRating(c.Model) + } + + // Create default rating if not found + if rating == nil { + rating = &ModelRating{ + Model: c.Model, + Rating: e.config.InitialRating, + } + } + + ratings = append(ratings, rating) + } + + return ratings +} + +// getGlobalRating retrieves the global rating for a model +func (e *EloSelector) getGlobalRating(model string) *ModelRating { + e.globalMu.RLock() + defer e.globalMu.RUnlock() + return e.globalRatings[model] +} + +// setGlobalRating sets the global rating for a model +func (e *EloSelector) setGlobalRating(model string, rating *ModelRating) { + e.globalMu.Lock() + defer e.globalMu.Unlock() + e.globalRatings[model] = rating +} + +// getCategoryRating retrieves the category-specific rating for a model +func (e *EloSelector) getCategoryRating(category, model string) *ModelRating { + e.categoryMu.RLock() + defer e.categoryMu.RUnlock() + if catRatings, ok := e.categoryRatings[category]; ok { + return catRatings[model] + } + return nil +} + +// setCategoryRating sets the category-specific rating for a model +func (e *EloSelector) setCategoryRating(category, model string, rating *ModelRating) { + e.categoryMu.Lock() + defer e.categoryMu.Unlock() + if e.categoryRatings[category] == nil { + e.categoryRatings[category] = make(map[string]*ModelRating) + } + e.categoryRatings[category][model] = rating +} + +// applyCostAdjustment adjusts scores based on model costs +func (e *EloSelector) applyCostAdjustment(scores map[string]float64, costWeight float64) { + e.costMu.RLock() + 
defer e.costMu.RUnlock()
+
+	if len(e.modelCosts) == 0 {
+		return
+	}
+
+	// Find min and max costs for normalization
+	minCost, maxCost := math.MaxFloat64, 0.0
+	for model := range scores {
+		if cost, ok := e.modelCosts[model]; ok {
+			if cost < minCost {
+				minCost = cost
+			}
+			if cost > maxCost {
+				maxCost = cost
+			}
+		}
+	}
+
+	if maxCost == minCost {
+		return // All same cost, no adjustment needed
+	}
+
+	// Adjust scores: cheaper models get bonus
+	for model := range scores {
+		if cost, ok := e.modelCosts[model]; ok {
+			// Normalize cost to 0-1 (0 = cheapest, 1 = most expensive)
+			normalizedCost := (cost - minCost) / (maxCost - minCost)
+			// Cost penalty: cheaper models get higher bonus
+			costBonus := (1.0 - normalizedCost) * costWeight * e.config.CostScalingFactor
+			scores[model] *= (1.0 + costBonus)
+		}
+	}
+}
+
+// calculateConfidence maps the number of comparisons a rating has seen to a
+// confidence in (0, 1) using a logistic curve centred on config.MinComparisons.
+// A rating with exactly MinComparisons comparisons yields 0.5, as does a nil
+// rating.
+func (e *EloSelector) calculateConfidence(rating *ModelRating) float64 {
+	if rating == nil {
+		return 0.5
+	}
+
+	// Confidence increases with more comparisons
+	// Sigmoid function: 1 / (1 + e^(-k*(x-threshold)))
+	k := 0.2 // Steepness
+	threshold := float64(e.config.MinComparisons)
+	confidence := 1.0 / (1.0 + math.Exp(-k*(float64(rating.Comparisons)-threshold)))
+
+	return confidence
+}
+
+// GetLeaderboard returns models sorted by rating, highest first (for
+// debugging/monitoring).
+//
+// When category is non-empty and config.CategoryWeighted is set, only that
+// category's ratings are returned (nil if the category has none); otherwise
+// the global ratings are returned. The returned pointers are the same
+// *ModelRating values held by the selector and are handed out after the read
+// lock is released, so callers should treat them as read-only. The relative
+// order of equally rated models is unspecified.
+func (e *EloSelector) GetLeaderboard(category string) []*ModelRating {
+	var ratings []*ModelRating
+
+	if category != "" && e.config.CategoryWeighted {
+		e.categoryMu.RLock()
+		if catRatings, ok := e.categoryRatings[category]; ok {
+			for _, r := range catRatings {
+				ratings = append(ratings, r)
+			}
+		}
+		e.categoryMu.RUnlock()
+	} else {
+		e.globalMu.RLock()
+		for _, r := range e.globalRatings {
+			ratings = append(ratings, r)
+		}
+		e.globalMu.RUnlock()
+	}
+
+	// Sort by rating descending
+	sort.Slice(ratings, func(i, j int) bool {
+		return ratings[i].Rating > ratings[j].Rating
+	})
+
+	return ratings
+}
diff --git a/src/semantic-router/pkg/selection/factory.go b/src/semantic-router/pkg/selection/factory.go
new file mode 100644
index 0000000000..afb49e597a
--- /dev/null
+++ b/src/semantic-router/pkg/selection/factory.go
@@ -0,0 +1,216 @@
+/*
+Copyright 2025 vLLM Semantic Router.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package selection
+
+import (
+	"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
+	"github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/logging"
+)
+
+// ModelSelectionConfig represents the configuration for model selection.
+// Only the sub-configuration matching Method is consulted by Factory.Create;
+// Factory.CreateAll builds every selector and substitutes the package default
+// for any sub-configuration left nil.
+type ModelSelectionConfig struct {
+	// Method specifies the selection algorithm to use
+	Method string `yaml:"method"`
+
+	// Elo configuration (used when method is "elo")
+	Elo *EloConfig `yaml:"elo,omitempty"`
+
+	// RouterDC configuration (used when method is "router_dc")
+	RouterDC *RouterDCConfig `yaml:"router_dc,omitempty"`
+
+	// AutoMix configuration (used when method is "automix")
+	AutoMix *AutoMixConfig `yaml:"automix,omitempty"`
+
+	// Hybrid configuration (used when method is "hybrid")
+	Hybrid *HybridConfig `yaml:"hybrid,omitempty"`
+}
+
+// DefaultModelSelectionConfig returns the default configuration, which
+// selects the static method and leaves every sub-configuration nil.
+func DefaultModelSelectionConfig() *ModelSelectionConfig {
+	return &ModelSelectionConfig{
+		Method: string(MethodStatic),
+	}
+}
+
+// Factory creates and initializes selectors based on configuration.
+// Construct it with NewFactory (which guarantees a non-nil cfg) and attach
+// the optional inputs with the With* builder methods.
+type Factory struct {
+	cfg           *ModelSelectionConfig
+	modelConfig   map[string]config.ModelParams
+	categories    []config.Category
+	embeddingFunc func(string) ([]float32, error)
+}
+
+// NewFactory creates a new selector factory. A nil cfg is replaced by
+// DefaultModelSelectionConfig().
+func NewFactory(cfg *ModelSelectionConfig) *Factory {
+	if cfg == nil {
+		cfg = DefaultModelSelectionConfig()
+	}
+	return &Factory{
+		cfg: cfg,
+	}
+}
+
+// WithModelConfig sets the model configuration and returns f for chaining.
+func (f *Factory) WithModelConfig(modelConfig map[string]config.ModelParams) *Factory {
+	f.modelConfig = modelConfig
+	return f
+}
+
+// WithCategories sets the category configuration and returns f for chaining.
+func (f *Factory) WithCategories(categories []config.Category) *Factory {
+	f.categories = categories
+	return f
+}
+
+// WithEmbeddingFunc sets the embedding function handed to the RouterDC
+// selector and returns f for chaining.
+func (f *Factory) WithEmbeddingFunc(fn func(string) ([]float32, error)) *Factory {
+	f.embeddingFunc = fn
+	return f
+}
+
+// Create creates and initializes a selector based on the configured method.
+//
+// An empty or unrecognised method yields the static selector; an unrecognised
+// (non-empty) name is reported with a warning so a typo in the configuration
+// does not silently change routing behaviour. The final log line always names
+// the method that was actually built.
+func (f *Factory) Create() Selector {
+	method := SelectionMethod(f.cfg.Method)
+
+	var selector Selector
+
+	switch method {
+	case MethodElo:
+		eloSelector := NewEloSelector(f.cfg.Elo)
+		if f.modelConfig != nil {
+			eloSelector.InitializeFromConfig(f.modelConfig, f.categories)
+		}
+		selector = eloSelector
+
+	case MethodRouterDC:
+		routerDCSelector := NewRouterDCSelector(f.cfg.RouterDC)
+		if f.embeddingFunc != nil {
+			routerDCSelector.SetEmbeddingFunc(f.embeddingFunc)
+		}
+		selector = routerDCSelector
+
+	case MethodAutoMix:
+		autoMixSelector := NewAutoMixSelector(f.cfg.AutoMix)
+		if f.modelConfig != nil {
+			autoMixSelector.InitializeFromConfig(f.modelConfig)
+		}
+		selector = autoMixSelector
+
+	case MethodHybrid:
+		// NOTE(review): NewHybridSelector builds its components from the
+		// package defaults, so cfg.Elo / cfg.RouterDC / cfg.AutoMix are not
+		// applied on this path, unlike CreateAll. Confirm this is intended.
+		hybridSelector := NewHybridSelector(f.cfg.Hybrid)
+		if f.modelConfig != nil {
+			hybridSelector.InitializeFromConfig(f.modelConfig, f.categories)
+		}
+		if f.embeddingFunc != nil && hybridSelector.routerDCSelector != nil {
+			hybridSelector.routerDCSelector.SetEmbeddingFunc(f.embeddingFunc)
+		}
+		selector = hybridSelector
+
+	default:
+		// Default to static selector. Only a non-empty name that matched no
+		// case is worth a warning; an unset method is the normal default.
+		if method != "" && method != MethodStatic {
+			logging.Warnf("[SelectionFactory] Unknown selection method %q, falling back to %s",
+				f.cfg.Method, MethodStatic)
+		}
+		method = MethodStatic
+
+		staticSelector := NewStaticSelector(DefaultStaticConfig())
+		if f.categories != nil {
+			staticSelector.InitializeFromConfig(f.categories)
+		}
+		selector = staticSelector
+	}
+
+	logging.Infof("[SelectionFactory] Created selector: method=%s", method)
+	return selector
+}
+
+// CreateAll creates all available selectors and registers them in a new
+// Registry under their method names. Sub-configurations left nil in the
+// factory config are replaced by the corresponding package defaults. The
+// hybrid selector shares the Elo, RouterDC and AutoMix instances registered
+// here rather than owning private copies.
+func (f *Factory) CreateAll() *Registry {
+	registry := NewRegistry()
+
+	// Always create static selector
+	staticSelector := NewStaticSelector(DefaultStaticConfig())
+	if f.categories != nil {
+		staticSelector.InitializeFromConfig(f.categories)
+	}
+	registry.Register(MethodStatic, staticSelector)
+
+	// Create Elo selector
+	eloCfg := f.cfg.Elo
+	if eloCfg == nil {
+		eloCfg = DefaultEloConfig()
+	}
+	eloSelector := NewEloSelector(eloCfg)
+	if f.modelConfig != nil {
+		eloSelector.InitializeFromConfig(f.modelConfig, f.categories)
+	}
+	registry.Register(MethodElo, eloSelector)
+
+	// Create RouterDC selector
+	routerDCCfg := f.cfg.RouterDC
+	if routerDCCfg == nil {
+		routerDCCfg = DefaultRouterDCConfig()
+	}
+	routerDCSelector := NewRouterDCSelector(routerDCCfg)
+	if f.embeddingFunc != nil {
+		routerDCSelector.SetEmbeddingFunc(f.embeddingFunc)
+	}
+	registry.Register(MethodRouterDC, routerDCSelector)
+
+	// Create AutoMix selector
+	autoMixCfg := f.cfg.AutoMix
+	if autoMixCfg == nil {
+		autoMixCfg = DefaultAutoMixConfig()
+	}
+	autoMixSelector := NewAutoMixSelector(autoMixCfg)
+	if f.modelConfig != nil {
+		autoMixSelector.InitializeFromConfig(f.modelConfig)
+	}
+	registry.Register(MethodAutoMix, autoMixSelector)
+
+	// Create Hybrid selector with component references
+	hybridCfg := f.cfg.Hybrid
+	if hybridCfg == nil {
+		hybridCfg = DefaultHybridConfig()
+	}
+	hybridSelector := NewHybridSelectorWithComponents(hybridCfg, eloSelector, routerDCSelector, autoMixSelector)
+	if f.modelConfig != nil {
+		// NOTE(review): HybridSelector.InitializeFromConfig forwards to the
+		// shared Elo and AutoMix selectors, which were already initialized
+		// above, so they are initialized twice. Harmless only if their
+		// InitializeFromConfig is idempotent - confirm.
+		hybridSelector.InitializeFromConfig(f.modelConfig, f.categories)
+	}
+	registry.Register(MethodHybrid, hybridSelector)
+
+	logging.Infof("[SelectionFactory] Created all selectors: static, elo, router_dc, automix, hybrid")
+	return registry
+}
+
+// Initialize sets up the global registry with all selectors. It replaces
+// GlobalRegistry wholesale; the assignment is not synchronized here, so it is
+// presumably meant to run once at startup before requests are served.
+func Initialize(cfg *ModelSelectionConfig, modelConfig map[string]config.ModelParams, categories []config.Category, embeddingFunc func(string) ([]float32, error)) {
+	factory := NewFactory(cfg).
+		WithModelConfig(modelConfig).
+		WithCategories(categories).
+		WithEmbeddingFunc(embeddingFunc)
+
+	// Create all selectors and register globally
+	GlobalRegistry = factory.CreateAll()
+
+	logging.Infof("[Selection] Initialized global selector registry")
+}
+
+// GetSelector returns a selector for the specified method from the global
+// registry, falling back to the static selector when the method is not
+// registered. It returns nil when the global registry has not been set up or
+// when not even the static selector is registered, so callers must check the
+// result.
+func GetSelector(method SelectionMethod) Selector {
+	if GlobalRegistry == nil {
+		// Initialize has not run (or the registry was cleared); there is
+		// nothing to look up.
+		return nil
+	}
+
+	selector, ok := GlobalRegistry.Get(method)
+	if !ok {
+		// Fallback to static
+		selector, _ = GlobalRegistry.Get(MethodStatic)
+	}
+	return selector
+}
diff --git a/src/semantic-router/pkg/selection/hybrid.go b/src/semantic-router/pkg/selection/hybrid.go
new file mode 100644
index 0000000000..cc7c4ffc41
--- /dev/null
+++ b/src/semantic-router/pkg/selection/hybrid.go
@@ -0,0 +1,466 @@
+/*
+Copyright 2025 vLLM Semantic Router.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package selection
+
+import (
+	"context"
+	"fmt"
+	"math"
+	"strings"
+
+	"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
+	"github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/logging"
+)
+
+// HybridConfig configures the Hybrid selector that combines multiple methods
+// Based on arXiv:2404.14618 - Hybrid LLM: Cost-Efficient Quality-Aware Query Routing
+//
+// The three component weights need not sum to 1: the selector renormalizes
+// them over the components that actually produced scores. A component whose
+// weight is not positive is skipped entirely.
+type HybridConfig struct {
+	// EloWeight is the weight for Elo rating contribution (0-1)
+	EloWeight float64 `yaml:"elo_weight"`
+
+	// RouterDCWeight is the weight for embedding similarity contribution (0-1)
+	RouterDCWeight float64 `yaml:"router_dc_weight"`
+
+	// AutoMixWeight is the weight for POMDP value contribution (0-1)
+	AutoMixWeight float64 `yaml:"automix_weight"`
+
+	// CostWeight is the weight for cost consideration (0-1). It gates the
+	// cost adjustment and is multiplied with the per-request cost weight.
+	CostWeight float64 `yaml:"cost_weight"`
+
+	// QualityGapThreshold triggers escalation to larger models
+	// NOTE(review): not read anywhere in hybrid.go - confirm it is wired up.
+	QualityGapThreshold float64 `yaml:"quality_gap_threshold"`
+
+	// UseMLP enables MLP-based quality gap prediction (advanced)
+	// NOTE(review): not read anywhere in hybrid.go - confirm it is wired up.
+	UseMLP bool `yaml:"use_mlp"`
+
+	// NormalizeScores normalizes component scores before combination
+	NormalizeScores bool `yaml:"normalize_scores"`
+}
+
+// DefaultHybridConfig returns the default Hybrid configuration
+func DefaultHybridConfig() *HybridConfig {
+	return &HybridConfig{
+		EloWeight:           0.3,
+		RouterDCWeight:      0.3,
+		AutoMixWeight:       0.2,
+		CostWeight:          0.2,
+		QualityGapThreshold: 0.1,
+		UseMLP:              false,
+		NormalizeScores:     true,
+	}
+}
+
+// HybridSelector combines multiple selection methods for robust routing
+// It uses weighted combination of Elo ratings, embedding similarity,
+// and POMDP values, with optional cost-aware optimization.
+type HybridSelector struct {
+	config *HybridConfig
+
+	// Component selectors; any of them may be nil, in which case that
+	// component is skipped during selection and feedback propagation.
+	eloSelector      *EloSelector
+	routerDCSelector *RouterDCSelector
+	autoMixSelector  *AutoMixSelector
+
+	// Model costs for cost-aware selection, keyed by model name.
+	// NOTE(review): this map is written by SetModelCost/InitializeFromConfig
+	// and read by Select without a lock; presumably all writes happen before
+	// serving starts - confirm.
+	modelCosts map[string]float64
+}
+
+// NewHybridSelector creates a new Hybrid selector whose Elo, RouterDC and
+// AutoMix components are freshly built from their package defaults. A nil cfg
+// is replaced by DefaultHybridConfig().
+func NewHybridSelector(cfg *HybridConfig) *HybridSelector {
+	if cfg == nil {
+		cfg = DefaultHybridConfig()
+	}
+
+	return &HybridSelector{
+		config:           cfg,
+		eloSelector:      NewEloSelector(DefaultEloConfig()),
+		routerDCSelector: NewRouterDCSelector(DefaultRouterDCConfig()),
+		autoMixSelector:  NewAutoMixSelector(DefaultAutoMixConfig()),
+		modelCosts:       make(map[string]float64),
+	}
+}
+
+// NewHybridSelectorWithComponents creates a Hybrid selector that uses the
+// given component selectors as-is (they are shared, not copied). A nil
+// component disables that part of the combination; a nil cfg is replaced by
+// DefaultHybridConfig().
+func NewHybridSelectorWithComponents(
+	cfg *HybridConfig,
+	elo *EloSelector,
+	routerDC *RouterDCSelector,
+	autoMix *AutoMixSelector,
+) *HybridSelector {
+	if cfg == nil {
+		cfg = DefaultHybridConfig()
+	}
+
+	return &HybridSelector{
+		config:           cfg,
+		eloSelector:      elo,
+		routerDCSelector: routerDC,
+		autoMixSelector:  autoMix,
+		modelCosts:       make(map[string]float64),
+	}
+}
+
+// Method returns the selection method type
+func (h *HybridSelector) Method() SelectionMethod {
+	return MethodHybrid
+}
+
+// SetEloSelector sets the Elo component
+func (h *HybridSelector) SetEloSelector(elo *EloSelector) {
+	h.eloSelector = elo
+}
+
+// SetRouterDCSelector sets the RouterDC component
+func (h *HybridSelector) SetRouterDCSelector(routerDC *RouterDCSelector) {
+	h.routerDCSelector = routerDC
+}
+
+// SetAutoMixSelector sets the AutoMix component
+func (h *HybridSelector) SetAutoMixSelector(autoMix *AutoMixSelector) {
+	h.autoMixSelector = autoMix
+}
+
+// SetModelCost sets the cost for a model
+func (h *HybridSelector) SetModelCost(model string, cost float64) {
+	h.modelCosts[model] = cost
+}
+
+// Select chooses the best model by combining multiple selection methods.
+//
+// Each enabled component (non-nil selector with a positive weight) is asked
+// to score the candidates; a component that returns an error is silently left
+// out of the combination. Scores are optionally min-max normalized per
+// component, combined with renormalized weights, adjusted for cost, and the
+// highest-scoring candidate wins (the first one on ties).
+//
+// It returns an error when selCtx is nil or has no candidate models.
+func (h *HybridSelector) Select(ctx context.Context, selCtx *SelectionContext) (*SelectionResult, error) {
+	if selCtx == nil || len(selCtx.CandidateModels) == 0 {
+		return nil, fmt.Errorf("no candidate models provided")
+	}
+
+	// Collect scores from each component
+	componentScores := make(map[string]map[string]float64)
+	var componentResults []*SelectionResult
+
+	// Get Elo scores
+	if h.eloSelector != nil && h.config.EloWeight > 0 {
+		result, err := h.eloSelector.Select(ctx, selCtx)
+		if err == nil && result != nil {
+			componentScores["elo"] = result.AllScores
+			componentResults = append(componentResults, result)
+		}
+	}
+
+	// Get RouterDC scores
+	if h.routerDCSelector != nil && h.config.RouterDCWeight > 0 {
+		result, err := h.routerDCSelector.Select(ctx, selCtx)
+		if err == nil && result != nil {
+			componentScores["router_dc"] = result.AllScores
+			componentResults = append(componentResults, result)
+		}
+	}
+
+	// Get AutoMix scores
+	if h.autoMixSelector != nil && h.config.AutoMixWeight > 0 {
+		result, err := h.autoMixSelector.Select(ctx, selCtx)
+		if err == nil && result != nil {
+			componentScores["automix"] = result.AllScores
+			componentResults = append(componentResults, result)
+		}
+	}
+
+	// Normalize scores if enabled
+	if h.config.NormalizeScores {
+		for component, scores := range componentScores {
+			componentScores[component] = h.normalizeScores(scores)
+		}
+	}
+
+	// Combine scores with weights
+	combinedScores := h.combineScores(componentScores, selCtx.CandidateModels)
+
+	// Apply cost adjustment
+	if h.config.CostWeight > 0 {
+		h.applyCostAdjustment(combinedScores, selCtx.CostWeight)
+	}
+
+	logging.Infof("[HybridSelector] Combining scores (weights: elo=%.2f, dc=%.2f, am=%.2f, cost=%.2f):",
+		h.config.EloWeight, h.config.RouterDCWeight, h.config.AutoMixWeight, h.config.CostWeight)
+	for _, model := range selCtx.CandidateModels {
+		var eloScore, dcScore, amScore float64
+		if scores, ok := componentScores["elo"]; ok {
+			eloScore = scores[model.Model]
+		}
+		if scores, ok := componentScores["router_dc"]; ok {
+			dcScore = scores[model.Model]
+		}
+		if scores, ok := componentScores["automix"]; ok {
+			amScore = scores[model.Model]
+		}
+		logging.Infof("[HybridSelector] %s: elo=%.4f, dc=%.4f, am=%.4f → combined=%.4f",
+			model.Model, eloScore, dcScore, amScore, combinedScores[model.Model])
+	}
+
+	// Find best model. The bestModel == nil test makes the first candidate
+	// win even when every combined score is zero or negative.
+	var bestModel *config.ModelRef
+	var bestScore float64
+
+	for i := range selCtx.CandidateModels {
+		model := &selCtx.CandidateModels[i]
+		score := combinedScores[model.Model]
+
+		if score > bestScore || bestModel == nil {
+			bestScore = score
+			bestModel = model
+		}
+	}
+
+	if bestModel == nil {
+		return nil, fmt.Errorf("could not select a model")
+	}
+
+	// Calculate confidence from component agreement
+	confidence := h.calculateConfidence(componentResults, bestModel.Model)
+
+	// Build reasoning
+	reasoning := h.buildReasoning(componentScores, bestModel.Model)
+
+	logging.Infof("[HybridSelector] Selected model %s (score=%.4f, confidence=%.2f, components=%d)",
+		bestModel.Model, bestScore, confidence, len(componentResults))
+
+	return &SelectionResult{
+		SelectedModel: bestModel.Model,
+		LoRAName:      bestModel.LoRAName,
+		Score:         bestScore,
+		Confidence:    confidence,
+		Method:        MethodHybrid,
+		Reasoning:     reasoning,
+		AllScores:     combinedScores,
+	}, nil
+}
+
+// UpdateFeedback propagates feedback to every non-nil component selector.
+//
+// All components are attempted even if an earlier one fails; the failures are
+// then reported together in a single error. It returns an error immediately
+// when feedback is nil.
+func (h *HybridSelector) UpdateFeedback(ctx context.Context, feedback *Feedback) error {
+	if feedback == nil {
+		return fmt.Errorf("feedback is required")
+	}
+
+	var errs []error
+	notified := 0 // components that were actually handed the feedback
+
+	if h.eloSelector != nil {
+		notified++
+		if err := h.eloSelector.UpdateFeedback(ctx, feedback); err != nil {
+			errs = append(errs, fmt.Errorf("elo: %w", err))
+		}
+	}
+
+	if h.routerDCSelector != nil {
+		notified++
+		if err := h.routerDCSelector.UpdateFeedback(ctx, feedback); err != nil {
+			errs = append(errs, fmt.Errorf("router_dc: %w", err))
+		}
+	}
+
+	if h.autoMixSelector != nil {
+		notified++
+		if err := h.autoMixSelector.UpdateFeedback(ctx, feedback); err != nil {
+			errs = append(errs, fmt.Errorf("automix: %w", err))
+		}
+	}
+
+	if len(errs) > 0 {
+		return fmt.Errorf("feedback update errors: %v", errs)
+	}
+
+	logging.Debugf("[HybridSelector] Propagated feedback to %d components", notified)
+	return nil
+}
+
+// combineScores combines scores from all components with weights.
+//
+// The result has exactly one entry per candidate model. Weights are
+// renormalized over the components that returned a non-empty score map, so a
+// missing component does not shrink the combined scores. Scores that a
+// component reports for models outside the candidate list are ignored. When
+// no component contributed, every candidate gets the uniform score
+// 1/len(candidates).
+func (h *HybridSelector) combineScores(componentScores map[string]map[string]float64, candidates []config.ModelRef) map[string]float64 {
+	result := make(map[string]float64, len(candidates))
+
+	// Initialize with zeros
+	for _, c := range candidates {
+		result[c.Model] = 0.0
+	}
+
+	// Weight mapping
+	weights := map[string]float64{
+		"elo":       h.config.EloWeight,
+		"router_dc": h.config.RouterDCWeight,
+		"automix":   h.config.AutoMixWeight,
+	}
+
+	// Calculate total weight for normalization
+	totalWeight := 0.0
+	for component, scores := range componentScores {
+		if len(scores) > 0 {
+			totalWeight += weights[component]
+		}
+	}
+
+	if totalWeight == 0 {
+		// No component scores available, use uniform
+		for model := range result {
+			result[model] = 1.0 / float64(len(candidates))
+		}
+		return result
+	}
+
+	// Weighted combination
+	for component, scores := range componentScores {
+		weight := weights[component]
+		for model, score := range scores {
+			if _, isCandidate := result[model]; !isCandidate {
+				// Keep non-candidates out of the result so they cannot
+				// skew the cost normalization or leak into AllScores.
+				continue
+			}
+			result[model] += (weight / totalWeight) * score
+		}
+	}
+
+	return result
+}
+
+// normalizeScores normalizes scores to [0, 1] range using min-max
+// normalization and returns them in a new map. When every score is equal,
+// each model gets 0.5. An empty input is returned unchanged.
+func (h *HybridSelector) normalizeScores(scores map[string]float64) map[string]float64 {
+	if len(scores) == 0 {
+		return scores
+	}
+
+	minScore := math.Inf(1)
+	maxScore := math.Inf(-1)
+
+	for _, s := range scores {
+		if s < minScore {
+			minScore = s
+		}
+		if s > maxScore {
+			maxScore = s
+		}
+	}
+
+	result := make(map[string]float64, len(scores))
+
+	// Avoid division by zero
+	if maxScore == minScore {
+		for model := range scores {
+			result[model] = 0.5
+		}
+		return result
+	}
+
+	for model, s := range scores {
+		result[model] = (s - minScore) / (maxScore - minScore)
+	}
+
+	return result
+}
+
+// applyCostAdjustment applies cost-based score adjustment in place.
+//
+// Among the scored models that have a known cost, the cheapest gets its score
+// multiplied by 1 + costWeight*config.CostWeight and the most expensive is
+// left unchanged, with linear interpolation in between. Models without a
+// known cost are not adjusted. Nothing happens when no costs are configured,
+// costWeight is not positive, or all known costs are equal.
+func (h *HybridSelector) applyCostAdjustment(scores map[string]float64, costWeight float64) {
+	if len(h.modelCosts) == 0 || costWeight <= 0 {
+		return
+	}
+
+	// Find min and max costs
+	minCost, maxCost := math.MaxFloat64, 0.0
+	for model := range scores {
+		if cost, ok := h.modelCosts[model]; ok {
+			if cost < minCost {
+				minCost = cost
+			}
+			if cost > maxCost {
+				maxCost = cost
+			}
+		}
+	}
+
+	if maxCost == minCost {
+		return
+	}
+
+	// Adjust scores: cheaper models get bonus
+	for model := range scores {
+		if cost, ok := h.modelCosts[model]; ok {
+			normalizedCost := (cost - minCost) / (maxCost - minCost)
+			costBonus := (1.0 - normalizedCost) * costWeight * h.config.CostWeight
+			scores[model] *= (1.0 + costBonus)
+		}
+	}
+}
+
+// calculateConfidence calculates confidence based on component agreement.
+//
+// It is the mean of two numbers: the fraction of component results whose own
+// pick equals selectedModel, and the average confidence those components
+// reported. With no component results it returns 0.5.
+func (h *HybridSelector) calculateConfidence(results []*SelectionResult, selectedModel string) float64 {
+	if len(results) == 0 {
+		return 0.5
+	}
+
+	// Count how many components agree on the selected model
+	agreements := 0
+	totalConfidence := 0.0
+
+	for _, r := range results {
+		if r.SelectedModel == selectedModel {
+			agreements++
+		}
+		totalConfidence += r.Confidence
+	}
+
+	// Agreement ratio
+	agreementRatio := float64(agreements) / float64(len(results))
+
+	// Average component confidence
+	avgConfidence := totalConfidence / float64(len(results))
+
+	// Combine: higher agreement and confidence = more confident
+	return (agreementRatio + avgConfidence) / 2.0
+}
+
+// buildReasoning creates a human-readable explanation listing the selected
+// model's score from each component that scored it, followed by the
+// configured weights.
+func (h *HybridSelector) buildReasoning(componentScores map[string]map[string]float64, selectedModel string) string {
+	var parts []string
+
+	if scores, ok := componentScores["elo"]; ok {
+		if score, ok := scores[selectedModel]; ok {
+			parts = append(parts, fmt.Sprintf("Elo=%.3f", score))
+		}
+	}
+
+	if scores, ok := componentScores["router_dc"]; ok {
+		if score, ok := scores[selectedModel]; ok {
+			parts = append(parts, fmt.Sprintf("RouterDC=%.3f", score))
+		}
+	}
+
+	if scores, ok := componentScores["automix"]; ok {
+		if score, ok := scores[selectedModel]; ok {
+			parts = append(parts, fmt.Sprintf("AutoMix=%.3f", score))
+		}
+	}
+
+	weightsStr := fmt.Sprintf("weights=[elo:%.2f, dc:%.2f, am:%.2f, cost:%.2f]",
+		h.config.EloWeight, h.config.RouterDCWeight, h.config.AutoMixWeight, h.config.CostWeight)
+
+	if len(parts) > 0 {
+		return fmt.Sprintf("Hybrid combination: [%s], %s", strings.Join(parts, " "), weightsStr)
+	}
+	return fmt.Sprintf("Hybrid selection with %s", weightsStr)
+}
+
+// InitializeFromConfig initializes the Elo and AutoMix components from
+// configuration and records each model's prompt price (per 1M tokens) as its
+// cost. Models whose prompt price is not positive get no cost entry. The
+// RouterDC component is not touched here.
+func (h *HybridSelector) InitializeFromConfig(modelConfig map[string]config.ModelParams, categories []config.Category) {
+	if h.eloSelector != nil {
+		h.eloSelector.InitializeFromConfig(modelConfig, categories)
+	}
+
+	if h.autoMixSelector != nil {
+		h.autoMixSelector.InitializeFromConfig(modelConfig)
+	}
+
+	// Set costs from config
+	for model, params := range modelConfig {
+		if params.Pricing.PromptPer1M > 0 {
+			h.modelCosts[model] = params.Pricing.PromptPer1M
+		}
+	}
+
+	logging.Infof("[HybridSelector] Initialized from config with %d models", len(modelConfig))
+}
diff --git a/src/semantic-router/pkg/selection/metrics.go b/src/semantic-router/pkg/selection/metrics.go
new file mode 100644
index 0000000000..44863a7bb0
--- /dev/null
+++ b/src/semantic-router/pkg/selection/metrics.go
@@ -0,0 +1,174 @@
+/*
+Copyright 2025 vLLM Semantic Router.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package selection
+
+import (
+	"sync"
+	"time"
+
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+// Prometheus metrics for model selection tracking.
+//
+// Every collector below is nil until InitializeMetrics has run; the Record*
+// helpers in this file check for nil and turn into no-ops in that case.
+var (
+	// ModelSelectionTotal tracks the total number of model selections
+	// (labels: method, model, decision)
+	ModelSelectionTotal *prometheus.CounterVec
+
+	// ModelSelectionDuration tracks the duration of model selection
+	// in seconds (label: method)
+	ModelSelectionDuration *prometheus.HistogramVec
+
+	// ModelSelectionScore tracks the score of selected models
+	// (labels: method, model)
+	ModelSelectionScore *prometheus.HistogramVec
+
+	// ModelSelectionConfidence tracks confidence of selections
+	// (label: method)
+	ModelSelectionConfidence *prometheus.HistogramVec
+
+	// ModelEloRating tracks current Elo ratings for models
+	// (labels: model, category)
+	ModelEloRating *prometheus.GaugeVec
+
+	// ModelFeedbackTotal tracks feedback events
+	// (labels: winner, loser, is_tie)
+	ModelFeedbackTotal *prometheus.CounterVec
+
+	// ComponentAgreement tracks how often components agree on selection
+	// (no labels)
+	ComponentAgreement *prometheus.HistogramVec
+
+	// metricsInitOnce makes InitializeMetrics safe to call more than once.
+	metricsInitOnce sync.Once
+)
+
+// InitializeMetrics initializes the Prometheus metrics for model selection.
+//
+// The collectors are created through promauto, so they are registered as part
+// of construction. The whole body runs under metricsInitOnce, so repeated
+// calls are no-ops after the first.
+func InitializeMetrics() {
+	metricsInitOnce.Do(func() {
+		ModelSelectionTotal = promauto.NewCounterVec(
+			prometheus.CounterOpts{
+				Name: "llm_model_selection_total",
+				Help: "Total number of model selections by method and selected model",
+			},
+			[]string{"method", "model", "decision"},
+		)
+
+		// Buckets span 0.1 ms to 100 ms.
+		ModelSelectionDuration = promauto.NewHistogramVec(
+			prometheus.HistogramOpts{
+				Name:    "llm_model_selection_duration_seconds",
+				Help:    "Duration of model selection in seconds",
+				Buckets: []float64{0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1},
+			},
+			[]string{"method"},
+		)
+
+		// Buckets assume scores in [0, 1].
+		// NOTE(review): selectors may report larger scores (e.g. raw
+		// ratings or cost-boosted values) - confirm the range.
+		ModelSelectionScore = promauto.NewHistogramVec(
+			prometheus.HistogramOpts{
+				Name:    "llm_model_selection_score",
+				Help:    "Score of selected models",
+				Buckets: []float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0},
+			},
+			[]string{"method", "model"},
+		)
+
+		ModelSelectionConfidence = promauto.NewHistogramVec(
+			prometheus.HistogramOpts{
+				Name:    "llm_model_selection_confidence",
+				Help:    "Confidence of model selections",
+				Buckets: []float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0},
+			},
+			[]string{"method"},
+		)
+
+		ModelEloRating = promauto.NewGaugeVec(
+			prometheus.GaugeOpts{
+				Name: "llm_model_elo_rating",
+				Help: "Current Elo rating for models by category",
+			},
+			[]string{"model", "category"},
+		)
+
+		// One series per (winner, loser, is_tie) combination.
+		ModelFeedbackTotal = promauto.NewCounterVec(
+			prometheus.CounterOpts{
+				Name: "llm_model_feedback_total",
+				Help: "Total feedback events by type",
+			},
+			[]string{"winner", "loser", "is_tie"},
+		)
+
+		ComponentAgreement = promauto.NewHistogramVec(
+			prometheus.HistogramOpts{
+				Name:    "llm_model_selection_component_agreement",
+				Help:    "Agreement ratio between selection components (for hybrid)",
+				Buckets: []float64{0.0, 0.25, 0.5, 0.75, 1.0},
+			},
+			[]string{},
+		)
+	})
+}
+
+// RecordSelection records a model selection event: it increments the
+// selection counter and observes the score. Duration and confidence are not
+// recorded here; use RecordSelectionFull for those.
+// Note the parameter order (method, decision, model) differs from
+// RecordSelectionFull (method, model, decision).
+// It is a no-op until InitializeMetrics has been called.
+func RecordSelection(method string, decision string, model string, score float64) {
+	if ModelSelectionTotal == nil {
+		return // Metrics not initialized
+	}
+
+	ModelSelectionTotal.WithLabelValues(method, model, decision).Inc()
+	ModelSelectionScore.WithLabelValues(method, model).Observe(score)
+}
+
+// RecordSelectionFull records a model selection event with all metrics:
+// the selection counter, duration (in seconds), score and confidence.
+// It is a no-op until InitializeMetrics has been called.
+func RecordSelectionFull(method SelectionMethod, model string, decision string, score, confidence float64, duration time.Duration) {
+	if ModelSelectionTotal == nil {
+		return // Metrics not initialized
+	}
+
+	methodStr := string(method)
+
+	ModelSelectionTotal.WithLabelValues(methodStr, model, decision).Inc()
+	ModelSelectionDuration.WithLabelValues(methodStr).Observe(duration.Seconds())
+	ModelSelectionScore.WithLabelValues(methodStr, model).Observe(score)
+	ModelSelectionConfidence.WithLabelValues(methodStr).Observe(confidence)
+}
+
+// RecordEloRating records the current Elo rating for a model in a category.
+// It is a no-op until InitializeMetrics has been called.
+func RecordEloRating(model, category string, rating float64) {
+	if ModelEloRating == nil {
+		return
+	}
+	ModelEloRating.WithLabelValues(model, category).Set(rating)
+}
+
+// RecordFeedback records a feedback event. An empty loser is reported under
+// the label value "none"; the winner is used as given (an empty winner is not
+// rewritten). It is a no-op until InitializeMetrics has been called.
+func RecordFeedback(winner, loser string, isTie bool) {
+	if ModelFeedbackTotal == nil {
+		return
+	}
+
+	tieStr := "false"
+	if isTie {
+		tieStr = "true"
+	}
+
+	if loser == "" {
+		loser = "none"
+	}
+
+	ModelFeedbackTotal.WithLabelValues(winner, loser, tieStr).Inc()
+}
+
+// RecordComponentAgreement records the agreement ratio between components.
+// It is a no-op until InitializeMetrics has been called.
+func RecordComponentAgreement(agreementRatio float64) {
+	if ComponentAgreement == nil {
+		return
+	}
+	ComponentAgreement.WithLabelValues().Observe(agreementRatio)
+}
diff --git a/src/semantic-router/pkg/selection/router_dc.go b/src/semantic-router/pkg/selection/router_dc.go
new file mode 100644
index 0000000000..6eada8b25a
--- /dev/null
+++ b/src/semantic-router/pkg/selection/router_dc.go
@@ -0,0 +1,364 @@
+/*
+Copyright 2025 vLLM Semantic Router.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package selection
+
+import (
+	"context"
+	"fmt"
+	"math"
+	"sync"
+
+	"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
+	"github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/logging"
+)
+
+// RouterDCConfig configures the RouterDC dual-contrastive selector
+// Based on arXiv:2409.19886 - Query-Based Router by Dual Contrastive Learning
+//
+// The defaults quoted on the fields are the values set by
+// DefaultRouterDCConfig.
+// NOTE(review): those defaults are presumably not applied to a config decoded
+// from YAML with fields omitted, which would leave the Go zero values -
+// confirm against the config loader.
+type RouterDCConfig struct {
+	// Temperature for softmax scaling (default: 0.07 as per paper).
+	// Should be > 0: the selector divides similarity scores by this value.
+	Temperature float64 `yaml:"temperature"`
+
+	// DimensionSize for embeddings (default: 768)
+	DimensionSize int `yaml:"dimension_size"`
+
+	// MinSimilarity threshold for valid matches (default: 0.3)
+	MinSimilarity float64 `yaml:"min_similarity"`
+
+	// UseQueryContrastive enables query-side contrastive learning
+	UseQueryContrastive bool `yaml:"use_query_contrastive"`
+
+	// UseModelContrastive enables model-side contrastive learning
+	UseModelContrastive bool `yaml:"use_model_contrastive"`
+}
+
+// DefaultRouterDCConfig returns the default RouterDC configuration
+func DefaultRouterDCConfig() *RouterDCConfig {
+	return &RouterDCConfig{
+		Temperature:         0.07,
+		DimensionSize:       768,
+		MinSimilarity:       0.3,
+		UseQueryContrastive: true,
+		UseModelContrastive: true,
+	}
+}
+
+// ModelEmbedding represents a model's capability embedding
+type ModelEmbedding struct {
+	Model     string    `json:"model"`
+	Embedding []float32 `json:"embedding"`
+}
+
+// RouterDCSelector implements dual-contrastive learning for query-to-model routing
+// The approach learns embeddings for both queries and models, then matches them
+// using contrastive learning to find the best model for each query type.
+type RouterDCSelector struct {
+	config *RouterDCConfig
+
+	// Model embeddings represent each model's capabilities/strengths.
+	// Guarded by embeddingMu.
+	modelEmbeddings map[string][]float32
+	embeddingMu     sync.RWMutex
+
+	// Query-model affinity matrix for contrastive learning. Guarded by
+	// affinityMu. It is written by UpdateFeedback and is not consulted by
+	// Select; it gains one entry per distinct query key and is never pruned
+	// in the code shown here.
+	affinityMatrix map[string]map[string]float64 // query_hash -> model -> affinity
+	affinityMu     sync.RWMutex
+
+	// Embedding provider function (injected dependency)
+	embeddingFunc func(text string) ([]float32, error)
+}
+
+// NewRouterDCSelector creates a new RouterDC-based selector. A nil cfg is
+// replaced by DefaultRouterDCConfig().
+func NewRouterDCSelector(cfg *RouterDCConfig) *RouterDCSelector {
+	if cfg == nil {
+		cfg = DefaultRouterDCConfig()
+	}
+	return &RouterDCSelector{
+		config:          cfg,
+		modelEmbeddings: make(map[string][]float32),
+		affinityMatrix:  make(map[string]map[string]float64),
+	}
+}
+
+// Method returns the selection method type
+func (r *RouterDCSelector) Method() SelectionMethod {
+	return MethodRouterDC
+}
+
+// SetEmbeddingFunc sets the function used to compute embeddings.
+// The field is assigned without taking a lock.
+func (r *RouterDCSelector) SetEmbeddingFunc(f func(text string) ([]float32, error)) {
+	r.embeddingFunc = f
+}
+
+// InitializeModelEmbeddings sets up model capability embeddings by embedding
+// each model's textual description. A model whose description fails to embed
+// is logged and skipped rather than aborting the whole initialization.
+//
+// It returns an error only when no embedding function has been set.
+func (r *RouterDCSelector) InitializeModelEmbeddings(modelDescriptions map[string]string) error {
+	if r.embeddingFunc == nil {
+		return fmt.Errorf("embedding function not set")
+	}
+
+	r.embeddingMu.Lock()
+	defer r.embeddingMu.Unlock()
+
+	for model, description := range modelDescriptions {
+		embedding, err := r.embeddingFunc(description)
+		if err != nil {
+			logging.Warnf("[RouterDC] Failed to embed model %s description: %v", model, err)
+			continue
+		}
+		r.modelEmbeddings[model] = embedding
+	}
+
+	logging.Infof("[RouterDC] Initialized embeddings for %d models", len(r.modelEmbeddings))
+	return nil
+}
+
+// SetModelEmbedding directly sets a model's embedding. The slice is stored
+// as-is, not copied.
+func (r *RouterDCSelector) SetModelEmbedding(model string, embedding []float32) {
+	r.embeddingMu.Lock()
+	defer r.embeddingMu.Unlock()
+	r.modelEmbeddings[model] = embedding
+}
+
+// Select chooses the best model using dual-contrastive matching.
+//
+// The query embedding is taken from selCtx.QueryEmbedding when present,
+// otherwise computed with the injected embedding function. Each candidate
+// with a stored embedding is scored by temperature-scaled cosine similarity;
+// candidates without one receive config.MinSimilarity as a placeholder score
+// and cannot win. The reported scores are the softmax of the raw scores.
+//
+// The fallback selection is used when no embedding can be obtained or when no
+// candidate reaches config.MinSimilarity. An error is returned when selCtx is
+// nil or has no candidate models.
+func (r *RouterDCSelector) Select(ctx context.Context, selCtx *SelectionContext) (*SelectionResult, error) {
+	if selCtx == nil || len(selCtx.CandidateModels) == 0 {
+		return nil, fmt.Errorf("no candidate models provided")
+	}
+
+	// Get or compute query embedding
+	queryEmbedding := selCtx.QueryEmbedding
+	if queryEmbedding == nil {
+		if r.embeddingFunc == nil {
+			// Fall back to first candidate if no embedding capability
+			return r.fallbackSelection(selCtx, "no embedding function available")
+		}
+
+		var err error
+		queryEmbedding, err = r.embeddingFunc(selCtx.Query)
+		if err != nil {
+			return r.fallbackSelection(selCtx, fmt.Sprintf("embedding error: %v", err))
+		}
+	}
+
+	allScores := make(map[string]float64, len(selCtx.CandidateModels))
+	var bestModel *config.ModelRef
+	var bestScore float64
+
+	r.embeddingMu.RLock()
+	defer r.embeddingMu.RUnlock()
+
+	logging.Infof("[RouterDC] Evaluating %d candidates by embedding similarity:",
+		len(selCtx.CandidateModels))
+
+	for i := range selCtx.CandidateModels {
+		model := &selCtx.CandidateModels[i]
+		modelEmb, exists := r.modelEmbeddings[model.Model]
+
+		if !exists {
+			// Model has no embedding, assign minimum score
+			allScores[model.Model] = r.config.MinSimilarity
+			logging.Infof("[RouterDC] %s: similarity=%.4f (no embedding, using min)", model.Model, r.config.MinSimilarity)
+			continue
+		}
+
+		// Compute contrastive similarity
+		similarity := r.computeContrastiveSimilarity(queryEmbedding, modelEmb)
+		allScores[model.Model] = similarity
+		logging.Infof("[RouterDC] %s: similarity=%.4f", model.Model, similarity)
+
+		if similarity > bestScore {
+			bestScore = similarity
+			bestModel = model
+		}
+	}
+
+	if bestModel == nil || bestScore < r.config.MinSimilarity {
+		return r.fallbackSelection(selCtx, "no model above similarity threshold")
+	}
+
+	// Apply softmax with temperature to get calibrated probabilities
+	softmaxScores := r.applySoftmax(allScores)
+
+	confidence := bestScore // Use raw similarity as confidence
+	if bestScore > 0.9 {
+		// Cap the reported confidence for very high similarities.
+		confidence = 0.95
+	}
+
+	reasoning := fmt.Sprintf("Query-model contrastive similarity: %.4f (temperature=%.3f)",
+		bestScore, r.effectiveTemperature())
+
+	logging.Infof("[RouterDC] Selected model %s (similarity=%.4f, confidence=%.2f)",
+		bestModel.Model, bestScore, confidence)
+
+	return &SelectionResult{
+		SelectedModel: bestModel.Model,
+		LoRAName:      bestModel.LoRAName,
+		Score:         softmaxScores[bestModel.Model],
+		Confidence:    confidence,
+		Method:        MethodRouterDC,
+		Reasoning:     reasoning,
+		AllScores:     softmaxScores,
+	}, nil
+}
+
+// UpdateFeedback updates model-query affinity based on feedback: the winner's
+// affinity for the query key grows by 0.1 and, unless the outcome was a tie,
+// the loser's shrinks by 0.05 (never below 0).
+//
+// It returns an error when feedback is nil or names no winner.
+func (r *RouterDCSelector) UpdateFeedback(ctx context.Context, feedback *Feedback) error {
+	if feedback == nil || feedback.WinnerModel == "" {
+		return fmt.Errorf("winner model is required")
+	}
+
+	// Create a simple hash for the query to track affinity
+	queryHash := r.hashQuery(feedback.Query)
+
+	r.affinityMu.Lock()
+	defer r.affinityMu.Unlock()
+
+	if r.affinityMatrix[queryHash] == nil {
+		r.affinityMatrix[queryHash] = make(map[string]float64)
+	}
+
+	// Increase affinity for winner
+	currentAffinity := r.affinityMatrix[queryHash][feedback.WinnerModel]
+	r.affinityMatrix[queryHash][feedback.WinnerModel] = currentAffinity + 0.1
+
+	// Decrease affinity for loser (if present)
+	if feedback.LoserModel != "" && !feedback.Tie {
+		loserAffinity := r.affinityMatrix[queryHash][feedback.LoserModel]
+		r.affinityMatrix[queryHash][feedback.LoserModel] = math.Max(0, loserAffinity-0.05)
+	}
+
+	// The key is two hex digits per query byte, so queries shorter than four
+	// bytes produce fewer than eight characters; slicing [:8] unconditionally
+	// would panic on them.
+	hashPrefix := queryHash
+	if len(hashPrefix) > 8 {
+		hashPrefix = hashPrefix[:8]
+	}
+	logging.Debugf("[RouterDC] Updated affinity for query hash %s: winner=%s (+0.1)",
+		hashPrefix, feedback.WinnerModel)
+
+	return nil
+}
+
+// effectiveTemperature returns the configured temperature, or the package
+// default when the configured value is not a positive number (for example a
+// config that left the field at its zero value), so that callers never divide
+// by zero or flip the ranking with a negative temperature.
+func (r *RouterDCSelector) effectiveTemperature() float64 {
+	if t := r.config.Temperature; t > 0 {
+		return t
+	}
+	return DefaultRouterDCConfig().Temperature
+}
+
+// computeContrastiveSimilarity calculates dual-contrastive similarity as
+// sigmoid(cosine(query, model) / temperature), a value in [0, 1]. When the
+// two embeddings differ in length, both are truncated to the shorter one.
+func (r *RouterDCSelector) computeContrastiveSimilarity(queryEmb, modelEmb []float32) float64 {
+	if len(queryEmb) != len(modelEmb) {
+		// Handle dimension mismatch by using minimum length
+		minLen := len(queryEmb)
+		if len(modelEmb) < minLen {
+			minLen = len(modelEmb)
+		}
+		queryEmb = queryEmb[:minLen]
+		modelEmb = modelEmb[:minLen]
+	}
+
+	// Compute cosine similarity
+	similarity := r.cosineSimilarity(queryEmb, modelEmb)
+
+	// Apply temperature scaling for contrastive learning
+	// Higher temperature = softer distribution
+	scaledSim := similarity / r.effectiveTemperature()
+
+	// Apply sigmoid to bound the result
+	return 1.0 / (1.0 + math.Exp(-scaledSim))
+}
+
+// cosineSimilarity computes cosine similarity between two vectors. It returns
+// 0 when the lengths differ or when either vector has zero norm (which
+// includes empty vectors).
+func (r *RouterDCSelector) cosineSimilarity(a, b []float32) float64 {
+	if len(a) != len(b) {
+		return 0.0
+	}
+
+	var dotProduct, normA, normB float64
+	for i := range a {
+		dotProduct += float64(a[i]) * float64(b[i])
+		normA += float64(a[i]) * float64(a[i])
+		normB += float64(b[i]) * float64(b[i])
+	}
+
+	if normA == 0 || normB == 0 {
+		return 0.0
+	}
+
+	return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
+}
+
+// applySoftmax applies softmax with temperature to scores and returns the
+// result in a new map. The maximum score is subtracted before exponentiation
+// for numerical stability. An empty input yields an empty map.
+func (r *RouterDCSelector) applySoftmax(scores map[string]float64) map[string]float64 {
+	result := make(map[string]float64, len(scores))
+	temperature := r.effectiveTemperature()
+
+	// Find max for numerical stability
+	maxScore := math.Inf(-1)
+	for _, s := range scores {
+		if s > maxScore {
+			maxScore = s
+		}
+	}
+
+	// Compute softmax
+	sum := 0.0
+	for model, s := range scores {
+		exp := math.Exp((s - maxScore) / temperature)
+		result[model] = exp
+		sum += exp
+	}
+
+	// Normalize
+	for model := range result {
+		result[model] /= sum
+	}
+
+	return result
+}
+
+// hashQuery creates a simple key for query tracking: the hex encoding of the
+// query's first 32 bytes (or of the whole query when it is shorter). It is a
+// prefix encoding rather than a digest, so queries sharing their first 32
+// bytes map to the same key, and an empty query maps to the empty string.
+func (r *RouterDCSelector) hashQuery(query string) string {
+	// Simple hash for query grouping (could use more sophisticated methods)
+	if len(query) < 32 {
+		return fmt.Sprintf("%x", query)
+	}
+	return fmt.Sprintf("%x", query[:32])
+}
+
+// fallbackSelection returns a fallback result when embedding-based selection fails
+func (r 
*RouterDCSelector) fallbackSelection(selCtx *SelectionContext, reason string) (*SelectionResult, error) { + if len(selCtx.CandidateModels) == 0 { + return nil, fmt.Errorf("no candidate models") + } + + firstModel := &selCtx.CandidateModels[0] + allScores := make(map[string]float64) + for i := range selCtx.CandidateModels { + allScores[selCtx.CandidateModels[i].Model] = 1.0 / float64(len(selCtx.CandidateModels)) + } + + logging.Warnf("[RouterDC] Fallback selection: %s, using first candidate %s", reason, firstModel.Model) + + return &SelectionResult{ + SelectedModel: firstModel.Model, + LoRAName: firstModel.LoRAName, + Score: allScores[firstModel.Model], + Confidence: 0.5, + Method: MethodRouterDC, + Reasoning: fmt.Sprintf("Fallback selection: %s", reason), + AllScores: allScores, + }, nil +} + +// GetModelEmbeddings returns all model embeddings (for debugging) +func (r *RouterDCSelector) GetModelEmbeddings() map[string][]float32 { + r.embeddingMu.RLock() + defer r.embeddingMu.RUnlock() + + result := make(map[string][]float32) + for k, v := range r.modelEmbeddings { + result[k] = v + } + return result +} diff --git a/src/semantic-router/pkg/selection/selector.go b/src/semantic-router/pkg/selection/selector.go new file mode 100644 index 0000000000..78264d6e2e --- /dev/null +++ b/src/semantic-router/pkg/selection/selector.go @@ -0,0 +1,194 @@ +/* +Copyright 2025 vLLM Semantic Router. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +// Package selection provides advanced model selection algorithms for intelligent routing. +// It implements multiple selection strategies including Elo rating, RouterDC (dual-contrastive +// learning), AutoMix (POMDP-based), and hybrid approaches that combine multiple techniques. +// +// Reference papers: +// - Elo: RouteLLM (arXiv:2406.18665) - Weighted Elo using Bradley-Terry model +// - RouterDC: Query-Based Router by Dual Contrastive Learning (arXiv:2409.19886) +// - AutoMix: Automatically Mixing Language Models (arXiv:2310.12963) +// - Hybrid LLM: Cost-Efficient Quality-Aware Query Routing (arXiv:2404.14618) +package selection + +import ( + "context" + "sync" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" +) + +// SelectionMethod defines the type of model selection algorithm +type SelectionMethod string + +const ( + // MethodElo uses Elo rating system with Bradley-Terry model + // Models are scored based on pairwise comparisons using preference feedback + MethodElo SelectionMethod = "elo" + + // MethodRouterDC uses dual-contrastive learning for query-to-model routing + // Learns query embeddings that match well with specific model capabilities + MethodRouterDC SelectionMethod = "router_dc" + + // MethodAutoMix uses POMDP-based cascaded routing with self-verification + // Routes to smaller models first, escalates based on self-verification confidence + MethodAutoMix SelectionMethod = "automix" + + // MethodHybrid combines multiple selection techniques with configurable weights + // Allows blending Elo, embedding similarity, and cost considerations + MethodHybrid SelectionMethod = "hybrid" + + // MethodStatic uses static scores from configuration (default behavior) + MethodStatic SelectionMethod = "static" +) + +// SelectionContext provides context for model selection decisions +type SelectionContext struct { + // Query is the user's input query text + Query string + + // QueryEmbedding is the precomputed embedding vector 
for the query (optional) + // If nil, selectors that need embeddings will compute them on demand + QueryEmbedding []float32 + + // ConversationHistory provides prior messages for context-aware selection + ConversationHistory []string + + // DecisionName is the name of the matched decision for category-specific selection + DecisionName string + + // CandidateModels is the list of models to select from + CandidateModels []config.ModelRef + + // CostWeight indicates how much to weight cost in selection (0.0-1.0) + // Higher values prefer cheaper models + CostWeight float64 + + // QualityWeight indicates how much to weight quality/score (0.0-1.0) + // Higher values prefer higher-quality models + QualityWeight float64 +} + +// SelectionResult contains the result of a model selection decision +type SelectionResult struct { + // SelectedModel is the name of the selected model + SelectedModel string + + // LoRAName is the LoRA adapter name to use (if applicable) + LoRAName string + + // Score is the selection score for the chosen model + Score float64 + + // Confidence indicates how confident the selector is in this choice + Confidence float64 + + // Method indicates which selection method was used + Method SelectionMethod + + // Reasoning provides human-readable explanation for the selection + Reasoning string + + // AllScores maps each candidate model to its computed score + AllScores map[string]float64 +} + +// Selector is the interface for model selection algorithms +type Selector interface { + // Select chooses the best model from candidates based on the selection context + Select(ctx context.Context, selCtx *SelectionContext) (*SelectionResult, error) + + // Method returns the selection method type + Method() SelectionMethod + + // UpdateFeedback allows the selector to learn from user feedback + // This is primarily used by Elo and learning-based methods + UpdateFeedback(ctx context.Context, feedback *Feedback) error +} + +// Feedback represents user feedback for 
model comparison +type Feedback struct { + // Query is the original query that was processed + Query string + + // WinnerModel is the model that was preferred + WinnerModel string + + // LoserModel is the model that was not preferred (can be empty for single feedback) + LoserModel string + + // Tie indicates if both models performed equally + Tie bool + + // DecisionName is the category/decision context + DecisionName string + + // Timestamp is when the feedback was recorded + Timestamp int64 +} + +// Registry maintains available selection methods and their configurations +type Registry struct { + selectors map[SelectionMethod]Selector + mu sync.RWMutex +} + +// NewRegistry creates a new selector registry +func NewRegistry() *Registry { + return &Registry{ + selectors: make(map[SelectionMethod]Selector), + } +} + +// Register adds a selector to the registry +func (r *Registry) Register(method SelectionMethod, selector Selector) { + r.mu.Lock() + defer r.mu.Unlock() + r.selectors[method] = selector +} + +// Get retrieves a selector by method type +func (r *Registry) Get(method SelectionMethod) (Selector, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + s, ok := r.selectors[method] + return s, ok +} + +// GlobalRegistry is the default registry for selection methods +var GlobalRegistry = NewRegistry() + +// Select uses the specified method to select a model +func Select(ctx context.Context, method SelectionMethod, selCtx *SelectionContext) (*SelectionResult, error) { + selector, ok := GlobalRegistry.Get(method) + if !ok { + // Fall back to static selection + selector, _ = GlobalRegistry.Get(MethodStatic) + } + if selector == nil { + // Ultimate fallback: return first candidate + return &SelectionResult{ + SelectedModel: selCtx.CandidateModels[0].Model, + LoRAName: selCtx.CandidateModels[0].LoRAName, + Score: 1.0, + Confidence: 1.0, + Method: MethodStatic, + Reasoning: "No selector available, using first candidate", + }, nil + } + return selector.Select(ctx, selCtx) +} 
diff --git a/src/semantic-router/pkg/selection/selector_test.go b/src/semantic-router/pkg/selection/selector_test.go new file mode 100644 index 0000000000..c05af4947b --- /dev/null +++ b/src/semantic-router/pkg/selection/selector_test.go @@ -0,0 +1,626 @@ +/* +Copyright 2025 vLLM Semantic Router. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package selection + +import ( + "context" + "testing" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" +) + +// Test helper to create candidate models +func createCandidateModels(names ...string) []config.ModelRef { + models := make([]config.ModelRef, len(names)) + for i, name := range names { + models[i] = config.ModelRef{Model: name} + } + return models +} + +func TestEloSelector_Select(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + candidates []config.ModelRef + setupRatings map[string]float64 + expectedModel string + expectError bool + }{ + { + name: "select highest rated model", + candidates: createCandidateModels("model-a", "model-b", "model-c"), + setupRatings: map[string]float64{"model-a": 1400, "model-b": 1600, "model-c": 1500}, + expectedModel: "model-b", + expectError: false, + }, + { + name: "fallback to default rating", + candidates: createCandidateModels("new-model-1", "new-model-2"), + setupRatings: map[string]float64{}, + expectedModel: "new-model-1", // First model when equal ratings + expectError: false, + }, + { + name: "no candidates", + candidates: 
[]config.ModelRef{}, + setupRatings: map[string]float64{}, + expectedModel: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + selector := NewEloSelector(DefaultEloConfig()) + + // Setup ratings + for model, rating := range tt.setupRatings { + selector.setGlobalRating(model, &ModelRating{Model: model, Rating: rating}) + } + + selCtx := &SelectionContext{ + Query: "test query", + CandidateModels: tt.candidates, + } + + result, err := selector.Select(ctx, selCtx) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result.SelectedModel != tt.expectedModel { + t.Errorf("expected model %s, got %s", tt.expectedModel, result.SelectedModel) + } + + if result.Method != MethodElo { + t.Errorf("expected method %s, got %s", MethodElo, result.Method) + } + }) + } +} + +func TestEloSelector_UpdateFeedback(t *testing.T) { + ctx := context.Background() + selector := NewEloSelector(DefaultEloConfig()) + + // Initialize ratings + selector.setGlobalRating("model-a", &ModelRating{Model: "model-a", Rating: 1500}) + selector.setGlobalRating("model-b", &ModelRating{Model: "model-b", Rating: 1500}) + + // Submit feedback: model-a wins against model-b + feedback := &Feedback{ + Query: "test query", + WinnerModel: "model-a", + LoserModel: "model-b", + Tie: false, + } + + err := selector.UpdateFeedback(ctx, feedback) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Check ratings updated + ratingA := selector.getGlobalRating("model-a") + ratingB := selector.getGlobalRating("model-b") + + if ratingA == nil { + t.Fatal("rating A should not be nil") + return // Explicit return after t.Fatal for staticcheck + } + if ratingB == nil { + t.Fatal("rating B should not be nil") + return // Explicit return after t.Fatal for staticcheck + } + + if ratingA.Rating <= 1500 { + t.Errorf("winner rating should 
increase, got %f", ratingA.Rating) + } + + if ratingB.Rating >= 1500 { + t.Errorf("loser rating should decrease, got %f", ratingB.Rating) + } + + if ratingA.Wins != 1 { + t.Errorf("winner wins should be 1, got %d", ratingA.Wins) + } + + if ratingB.Losses != 1 { + t.Errorf("loser losses should be 1, got %d", ratingB.Losses) + } +} + +func TestRouterDCSelector_Select(t *testing.T) { + ctx := context.Background() + + selector := NewRouterDCSelector(DefaultRouterDCConfig()) + + // Set up embedding function (mock) + selector.SetEmbeddingFunc(func(text string) ([]float32, error) { + // Return a simple embedding based on text length + embedding := make([]float32, 768) + for i := range embedding { + embedding[i] = float32(len(text)%10) / 10.0 + } + return embedding, nil + }) + + // Set model embeddings + modelAEmb := make([]float32, 768) + modelBEmb := make([]float32, 768) + for i := range modelAEmb { + modelAEmb[i] = 0.5 + modelBEmb[i] = 0.3 + } + selector.SetModelEmbedding("model-a", modelAEmb) + selector.SetModelEmbedding("model-b", modelBEmb) + + selCtx := &SelectionContext{ + Query: "test query", + CandidateModels: createCandidateModels("model-a", "model-b"), + } + + result, err := selector.Select(ctx, selCtx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Method != MethodRouterDC { + t.Errorf("expected method %s, got %s", MethodRouterDC, result.Method) + } + + if result.Score <= 0 { + t.Errorf("expected positive score, got %f", result.Score) + } +} + +func TestAutoMixSelector_Select(t *testing.T) { + ctx := context.Background() + + selector := NewAutoMixSelector(DefaultAutoMixConfig()) + + // Initialize capabilities + modelConfig := map[string]config.ModelParams{ + "small-model": {Pricing: config.ModelPricing{PromptPer1M: 0.5}}, + "large-model": {Pricing: config.ModelPricing{PromptPer1M: 5.0}}, + } + selector.InitializeFromConfig(modelConfig) + + // Set verification probabilities + selector.SetCapability("small-model", &ModelCapability{ + 
Model: "small-model", + Cost: 0.5, + AvgQuality: 0.7, + VerificationProb: 0.8, + ParamSize: 7.0, + }) + selector.SetCapability("large-model", &ModelCapability{ + Model: "large-model", + Cost: 5.0, + AvgQuality: 0.95, + VerificationProb: 0.95, + ParamSize: 70.0, + }) + + selCtx := &SelectionContext{ + Query: "test query", + CandidateModels: createCandidateModels("small-model", "large-model"), + CostWeight: 0.5, + QualityWeight: 0.5, + } + + result, err := selector.Select(ctx, selCtx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Method != MethodAutoMix { + t.Errorf("expected method %s, got %s", MethodAutoMix, result.Method) + } + + // With cost awareness, cheaper model might be selected + if result.SelectedModel == "" { + t.Error("expected a selected model") + } +} + +func TestHybridSelector_Select(t *testing.T) { + ctx := context.Background() + + cfg := DefaultHybridConfig() + cfg.EloWeight = 0.5 + cfg.RouterDCWeight = 0.0 // Disable RouterDC (no embeddings) + cfg.AutoMixWeight = 0.5 + cfg.CostWeight = 0.0 + + selector := NewHybridSelector(cfg) + + // Initialize Elo component + selector.eloSelector.setGlobalRating("model-a", &ModelRating{Model: "model-a", Rating: 1600}) + selector.eloSelector.setGlobalRating("model-b", &ModelRating{Model: "model-b", Rating: 1400}) + + // Initialize AutoMix component + selector.autoMixSelector.SetCapability("model-a", &ModelCapability{ + Model: "model-a", + AvgQuality: 0.9, + VerificationProb: 0.9, + ParamSize: 70.0, + }) + selector.autoMixSelector.SetCapability("model-b", &ModelCapability{ + Model: "model-b", + AvgQuality: 0.7, + VerificationProb: 0.8, + ParamSize: 7.0, + }) + + selCtx := &SelectionContext{ + Query: "test query", + CandidateModels: createCandidateModels("model-a", "model-b"), + } + + result, err := selector.Select(ctx, selCtx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Method != MethodHybrid { + t.Errorf("expected method %s, got %s", MethodHybrid, 
result.Method) + } + + // Model-a should win (higher Elo and quality) + if result.SelectedModel != "model-a" { + t.Errorf("expected model-a, got %s", result.SelectedModel) + } +} + +func TestStaticSelector_Select(t *testing.T) { + ctx := context.Background() + + selector := NewStaticSelector(DefaultStaticConfig()) + + // Set up category scores + selector.SetCategoryScore("coding", "code-model", 0.9) + selector.SetCategoryScore("coding", "general-model", 0.5) + + selCtx := &SelectionContext{ + Query: "write python code", + DecisionName: "coding", + CandidateModels: createCandidateModels("code-model", "general-model"), + } + + result, err := selector.Select(ctx, selCtx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Method != MethodStatic { + t.Errorf("expected method %s, got %s", MethodStatic, result.Method) + } + + if result.SelectedModel != "code-model" { + t.Errorf("expected code-model, got %s", result.SelectedModel) + } + + if result.Score != 0.9 { + t.Errorf("expected score 0.9, got %f", result.Score) + } +} + +func TestRegistry(t *testing.T) { + registry := NewRegistry() + + // Register selectors + registry.Register(MethodElo, NewEloSelector(nil)) + registry.Register(MethodStatic, NewStaticSelector(nil)) + + // Get registered selectors + eloSelector, ok := registry.Get(MethodElo) + if !ok || eloSelector == nil { + t.Error("expected Elo selector to be registered") + } + + staticSelector, ok := registry.Get(MethodStatic) + if !ok || staticSelector == nil { + t.Error("expected Static selector to be registered") + } + + // Get unregistered selector + _, ok = registry.Get(MethodRouterDC) + if ok { + t.Error("expected RouterDC to not be registered") + } +} + +func TestFactory_Create(t *testing.T) { + tests := []struct { + name string + method string + expectedMethod SelectionMethod + }{ + { + name: "create elo selector", + method: "elo", + expectedMethod: MethodElo, + }, + { + name: "create router_dc selector", + method: "router_dc", + 
expectedMethod: MethodRouterDC, + }, + { + name: "create automix selector", + method: "automix", + expectedMethod: MethodAutoMix, + }, + { + name: "create hybrid selector", + method: "hybrid", + expectedMethod: MethodHybrid, + }, + { + name: "create static selector (default)", + method: "static", + expectedMethod: MethodStatic, + }, + { + name: "unknown method defaults to static", + method: "unknown", + expectedMethod: MethodStatic, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &ModelSelectionConfig{Method: tt.method} + factory := NewFactory(cfg) + selector := factory.Create() + + if selector.Method() != tt.expectedMethod { + t.Errorf("expected method %s, got %s", tt.expectedMethod, selector.Method()) + } + }) + } +} + +func TestEloSelector_CategoryRatings(t *testing.T) { + ctx := context.Background() + + cfg := DefaultEloConfig() + cfg.CategoryWeighted = true + selector := NewEloSelector(cfg) + + // Set different ratings for different categories + selector.setCategoryRating("coding", "model-a", &ModelRating{Model: "model-a", Rating: 1700}) + selector.setCategoryRating("coding", "model-b", &ModelRating{Model: "model-b", Rating: 1300}) + selector.setCategoryRating("writing", "model-a", &ModelRating{Model: "model-a", Rating: 1300}) + selector.setCategoryRating("writing", "model-b", &ModelRating{Model: "model-b", Rating: 1700}) + + candidates := createCandidateModels("model-a", "model-b") + + // Test coding category + codingCtx := &SelectionContext{ + Query: "write code", + DecisionName: "coding", + CandidateModels: candidates, + } + result, err := selector.Select(ctx, codingCtx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.SelectedModel != "model-a" { + t.Errorf("expected model-a for coding, got %s", result.SelectedModel) + } + + // Test writing category + writingCtx := &SelectionContext{ + Query: "write essay", + DecisionName: "writing", + CandidateModels: candidates, + } + result, err = 
selector.Select(ctx, writingCtx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.SelectedModel != "model-b" { + t.Errorf("expected model-b for writing, got %s", result.SelectedModel) + } +} + +func TestEloSelector_GetLeaderboard(t *testing.T) { + selector := NewEloSelector(DefaultEloConfig()) + + // Set up ratings + selector.setGlobalRating("model-c", &ModelRating{Model: "model-c", Rating: 1400}) + selector.setGlobalRating("model-a", &ModelRating{Model: "model-a", Rating: 1600}) + selector.setGlobalRating("model-b", &ModelRating{Model: "model-b", Rating: 1500}) + + leaderboard := selector.GetLeaderboard("") + + if len(leaderboard) != 3 { + t.Errorf("expected 3 models in leaderboard, got %d", len(leaderboard)) + } + + // Should be sorted by rating descending + if leaderboard[0].Model != "model-a" { + t.Errorf("expected model-a first, got %s", leaderboard[0].Model) + } + if leaderboard[1].Model != "model-b" { + t.Errorf("expected model-b second, got %s", leaderboard[1].Model) + } + if leaderboard[2].Model != "model-c" { + t.Errorf("expected model-c third, got %s", leaderboard[2].Model) + } +} + +// TestEloSelector_MultiTurnEvolution tests that Elo ratings evolve correctly +// over multiple feedback rounds, demonstrating convergence and ranking stability. 
+func TestEloSelector_MultiTurnEvolution(t *testing.T) { + ctx := context.Background() + selector := NewEloSelector(DefaultEloConfig()) + + // Initialize three models with same starting rating + models := []string{"weak-model", "medium-model", "strong-model"} + for _, m := range models { + selector.setGlobalRating(m, &ModelRating{Model: m, Rating: DefaultEloRating}) + } + + // Simulate 10 rounds of feedback where strong > medium > weak + for round := 0; round < 10; round++ { + // Strong beats medium + _ = selector.UpdateFeedback(ctx, &Feedback{ + Query: "test", + WinnerModel: "strong-model", + LoserModel: "medium-model", + }) + + // Medium beats weak + _ = selector.UpdateFeedback(ctx, &Feedback{ + Query: "test", + WinnerModel: "medium-model", + LoserModel: "weak-model", + }) + + // Strong beats weak + _ = selector.UpdateFeedback(ctx, &Feedback{ + Query: "test", + WinnerModel: "strong-model", + LoserModel: "weak-model", + }) + } + + // Verify final rankings + strongRating := selector.getGlobalRating("strong-model") + mediumRating := selector.getGlobalRating("medium-model") + weakRating := selector.getGlobalRating("weak-model") + + if strongRating == nil || mediumRating == nil || weakRating == nil { + t.Fatal("ratings should not be nil") + return + } + + // Strong should have highest rating + if strongRating.Rating <= mediumRating.Rating { + t.Errorf("strong (%f) should beat medium (%f)", strongRating.Rating, mediumRating.Rating) + } + + // Medium should beat weak + if mediumRating.Rating <= weakRating.Rating { + t.Errorf("medium (%f) should beat weak (%f)", mediumRating.Rating, weakRating.Rating) + } + + // Win/loss records should reflect the matches + if strongRating.Wins != 20 { // 10 vs medium + 10 vs weak + t.Errorf("strong should have 20 wins, got %d", strongRating.Wins) + } + if weakRating.Losses != 20 { // 10 vs medium + 10 vs strong + t.Errorf("weak should have 20 losses, got %d", weakRating.Losses) + } + + // Verify leaderboard order + leaderboard := 
selector.GetLeaderboard("") + if len(leaderboard) < 3 { + t.Fatalf("expected at least 3 models, got %d", len(leaderboard)) + } + if leaderboard[0].Model != "strong-model" { + t.Errorf("strong-model should be first, got %s", leaderboard[0].Model) + } + if leaderboard[1].Model != "medium-model" { + t.Errorf("medium-model should be second, got %s", leaderboard[1].Model) + } + if leaderboard[2].Model != "weak-model" { + t.Errorf("weak-model should be third, got %s", leaderboard[2].Model) + } +} + +// TestEloSelector_TieHandling tests that ties are handled correctly +func TestEloSelector_TieHandling(t *testing.T) { + ctx := context.Background() + selector := NewEloSelector(DefaultEloConfig()) + + selector.setGlobalRating("model-a", &ModelRating{Model: "model-a", Rating: 1500}) + selector.setGlobalRating("model-b", &ModelRating{Model: "model-b", Rating: 1500}) + + // Submit a tie + err := selector.UpdateFeedback(ctx, &Feedback{ + Query: "test", + WinnerModel: "model-a", + LoserModel: "model-b", + Tie: true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ratingA := selector.getGlobalRating("model-a") + ratingB := selector.getGlobalRating("model-b") + + if ratingA == nil || ratingB == nil { + t.Fatal("ratings should not be nil") + return + } + + // Both should have a tie recorded + if ratingA.Ties != 1 { + t.Errorf("model-a should have 1 tie, got %d", ratingA.Ties) + } + if ratingB.Ties != 1 { + t.Errorf("model-b should have 1 tie, got %d", ratingB.Ties) + } + + // Ratings should remain close (tie moves both toward each other) + ratingDiff := ratingA.Rating - ratingB.Rating + if ratingDiff < -1 || ratingDiff > 1 { + t.Errorf("ratings should be nearly equal after tie, got diff %f", ratingDiff) + } +} + +// TestEloSelector_SelectionFollowsRatings verifies that Select() respects Elo ratings +func TestEloSelector_SelectionFollowsRatings(t *testing.T) { + ctx := context.Background() + selector := NewEloSelector(DefaultEloConfig()) + + // Set up ratings 
with clear winner + selector.setGlobalRating("low-rated", &ModelRating{Model: "low-rated", Rating: 1300}) + selector.setGlobalRating("high-rated", &ModelRating{Model: "high-rated", Rating: 1700}) + + selCtx := &SelectionContext{ + Query: "test query", + DecisionName: "test", + CandidateModels: createCandidateModels("low-rated", "high-rated"), + } + + result, err := selector.Select(ctx, selCtx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // High-rated should be selected + if result.SelectedModel != "high-rated" { + t.Errorf("expected high-rated, got %s", result.SelectedModel) + } +} diff --git a/src/semantic-router/pkg/selection/static.go b/src/semantic-router/pkg/selection/static.go new file mode 100644 index 0000000000..fdb36e7241 --- /dev/null +++ b/src/semantic-router/pkg/selection/static.go @@ -0,0 +1,171 @@ +/* +Copyright 2025 vLLM Semantic Router. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package selection + +import ( + "context" + "fmt" + "sync" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/logging" +) + +// StaticConfig configures the static selector +type StaticConfig struct { + // UseFirstCandidate always selects the first candidate (default behavior) + UseFirstCandidate bool `yaml:"use_first_candidate"` + + // CategoryScores maps category -> model -> score (from config) + CategoryScores map[string]map[string]float64 `yaml:"-"` +} + +// DefaultStaticConfig returns the default Static configuration +func DefaultStaticConfig() *StaticConfig { + return &StaticConfig{ + UseFirstCandidate: true, + CategoryScores: make(map[string]map[string]float64), + } +} + +// StaticSelector implements static model selection based on configuration scores +// This is the baseline selector that uses pre-configured scores without learning. +type StaticSelector struct { + config *StaticConfig + + // Category-specific scores from configuration + categoryScores map[string]map[string]float64 + scoresMu sync.RWMutex +} + +// NewStaticSelector creates a new Static selector +func NewStaticSelector(cfg *StaticConfig) *StaticSelector { + if cfg == nil { + cfg = DefaultStaticConfig() + } + return &StaticSelector{ + config: cfg, + categoryScores: make(map[string]map[string]float64), + } +} + +// Method returns the selection method type +func (s *StaticSelector) Method() SelectionMethod { + return MethodStatic +} + +// InitializeFromConfig sets up static scores from model configuration +func (s *StaticSelector) InitializeFromConfig(categories []config.Category) { + s.scoresMu.Lock() + defer s.scoresMu.Unlock() + + for _, category := range categories { + if s.categoryScores[category.Name] == nil { + s.categoryScores[category.Name] = make(map[string]float64) + } + for _, ms := range category.ModelScores { + s.categoryScores[category.Name][ms.Model] = ms.Score + } + 
} + + logging.Infof("[StaticSelector] Initialized scores for %d categories", len(s.categoryScores)) +} + +// SetCategoryScore sets a static score for a model in a category +func (s *StaticSelector) SetCategoryScore(category, model string, score float64) { + s.scoresMu.Lock() + defer s.scoresMu.Unlock() + + if s.categoryScores[category] == nil { + s.categoryScores[category] = make(map[string]float64) + } + s.categoryScores[category][model] = score +} + +// Select chooses the best model based on static configuration scores +func (s *StaticSelector) Select(ctx context.Context, selCtx *SelectionContext) (*SelectionResult, error) { + if len(selCtx.CandidateModels) == 0 { + return nil, fmt.Errorf("no candidate models provided") + } + + allScores := make(map[string]float64) + var bestModel *config.ModelRef + var bestScore float64 + + s.scoresMu.RLock() + categoryScores := s.categoryScores[selCtx.DecisionName] + s.scoresMu.RUnlock() + + for i := range selCtx.CandidateModels { + model := &selCtx.CandidateModels[i] + + // Get static score if available + score := 1.0 // Default score + if categoryScores != nil { + if cs, ok := categoryScores[model.Model]; ok { + score = cs + } + } + + allScores[model.Model] = score + + if score > bestScore || bestModel == nil { + bestScore = score + bestModel = model + } + } + + // If no scores found and useFirstCandidate is true, use first + if bestModel == nil || (s.config.UseFirstCandidate && bestScore == 1.0) { + bestModel = &selCtx.CandidateModels[0] + bestScore = allScores[bestModel.Model] + } + + reasoning := "Static selection from configuration" + if selCtx.DecisionName != "" { + reasoning = fmt.Sprintf("Static selection for category '%s'", selCtx.DecisionName) + } + + logging.Infof("[StaticSelector] Candidates: %v → Selected: %s (using first/highest)", + getModelNames(selCtx.CandidateModels), bestModel.Model) + + return &SelectionResult{ + SelectedModel: bestModel.Model, + LoRAName: bestModel.LoRAName, + Score: bestScore, + 
Confidence: 1.0, // Static selection is always "confident" + Method: MethodStatic, + Reasoning: reasoning, + AllScores: allScores, + }, nil +} + +// UpdateFeedback does nothing for static selector (no learning) +func (s *StaticSelector) UpdateFeedback(ctx context.Context, feedback *Feedback) error { + // Static selector doesn't learn from feedback + logging.Debugf("[StaticSelector] Ignoring feedback (static selector does not learn)") + return nil +} + +// getModelNames extracts model names from ModelRef slice +func getModelNames(models []config.ModelRef) []string { + names := make([]string, len(models)) + for i, m := range models { + names[i] = m.Model + } + return names +}