diff --git a/pkg/plugins/scorer/no_hit_lru.go b/pkg/plugins/scorer/no_hit_lru.go
index 417cf05a56..a367caa349 100644
--- a/pkg/plugins/scorer/no_hit_lru.go
+++ b/pkg/plugins/scorer/no_hit_lru.go
@@ -21,6 +21,12 @@ const (
 
 	// defaultLRUSize is the maximum number of pods we'll consider in the cache
 	defaultLRUSize = 1024
+
+	// defaultPrefillProfile is the name of the prefill profile.
+	//
+	// It is hardcoded until a proper config interface is defined; see
+	// https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/2104/
+	defaultPrefillProfile = "prefill"
 )
 
 // compile-time type assertions
@@ -286,19 +292,35 @@ func (s *NoHitLRU) PreRequest(ctx context.Context, request *types.LLMRequest, sc
 		return
 	}
 
-	// Get the primary profile's target pod
-	primaryProfile := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
-	if primaryProfile == nil || len(primaryProfile.TargetPods) == 0 {
-		logger.Info("No target pod in primary profile")
-		return
+	// Track the pod chosen by the primary profile. Indexing a missing key yields
+	// a nil *types.ProfileRunResult, so the nil check alone covers absence.
+	primaryName := schedulingResult.PrimaryProfileName
+	if targetProfile := schedulingResult.ProfileResults[primaryName]; targetProfile != nil && len(targetProfile.TargetPods) != 0 {
+		s.moveTargetPodToFront(ctx, request, targetProfile, primaryName)
+	}
+
+	// In P/D deployments also track the prefill pod. Skip the lookup when the
+	// primary profile is itself the prefill profile so it is not handled twice.
+	if primaryName == defaultPrefillProfile {
+		return
+	}
+	if targetProfile := schedulingResult.ProfileResults[defaultPrefillProfile]; targetProfile != nil && len(targetProfile.TargetPods) != 0 {
+		s.moveTargetPodToFront(ctx, request, targetProfile, defaultPrefillProfile)
 	}
+}
+
+// moveTargetPodToFront moves the first target pod of targetProfile to the front
+// of the LRU cache. profileName only labels the debug log entry. Callers must
+// pass a non-nil targetProfile that has at least one target pod.
+func (s *NoHitLRU) moveTargetPodToFront(ctx context.Context, request *types.LLMRequest, targetProfile *types.ProfileRunResult, profileName string) {
+	logger := log.FromContext(ctx).V(logutil.DEBUG)
 
-	targetPod := primaryProfile.TargetPods[0]
+	targetPod := targetProfile.TargetPods[0]
 	podName := targetPod.GetPod().NamespacedName.String()
 
 	// Move the pod to the front of the LRU.
 	var present struct{} // dummy value
 	s.lruCache.Add(podName, present)
 
-	logger.Info("Updated LRU cache for cold request", "pod", podName, "requestId", request.RequestId)
+	logger.Info("Updated LRU cache for cold request", "profile", profileName, "pod", podName, "requestId", request.RequestId)
 }
diff --git a/pkg/plugins/scorer/no_hit_lru_test.go b/pkg/plugins/scorer/no_hit_lru_test.go
index 6890c998cb..a03dd3aae2 100644
--- a/pkg/plugins/scorer/no_hit_lru_test.go
+++ b/pkg/plugins/scorer/no_hit_lru_test.go
@@ -455,3 +455,146 @@ func TestNoHitLRUEdgeCases(t *testing.T) {
 		}
 	})
 }
+
+// TestNoHitLRUPrefillDecodeTracking checks that PreRequest records the target
+// pods of both the primary profile and the "prefill" profile in the LRU.
+func TestNoHitLRUPrefillDecodeTracking(t *testing.T) {
+	// Prefill worker pods
+	prefillPodA := &types.PodMetrics{
+		Pod:          &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "prefill-a", Namespace: "default"}},
+		MetricsState: &backendmetrics.MetricsState{},
+	}
+	prefillPodB := &types.PodMetrics{
+		Pod:          &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "prefill-b", Namespace: "default"}},
+		MetricsState: &backendmetrics.MetricsState{},
+	}
+
+	// Decode worker pods
+	decodePodA := &types.PodMetrics{
+		Pod:          &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "decode-a", Namespace: "default"}},
+		MetricsState: &backendmetrics.MetricsState{},
+	}
+	decodePodB := &types.PodMetrics{
+		Pod:          &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "decode-b", Namespace: "default"}},
+		MetricsState: &backendmetrics.MetricsState{},
+	}
+
+	prefillPods := []types.Pod{prefillPodA, prefillPodB}
+	decodePods := []types.Pod{decodePodA, decodePodB}
+
+	coldPrefixState := &types.CycleState{}
+	coldPrefixState.Write(plugins.StateKey(prefix.PrefixCachePluginType), &prefix.SchedulingContextState{
+		PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request
+	})
+
+	ctx := context.Background()
+
+	t.Run("P/D scenario - both profiles tracked separately", func(t *testing.T) {
+		scorer := scorer.NewNoHitLRU(ctx, nil)
+
+		// First cold request with P/D
+		req1 := &types.LLMRequest{RequestId: "pd-request-1"}
+		scorer.Score(ctx, coldPrefixState, req1, append(prefillPods, decodePods...))
+
+		// Simulate scheduling result with both prefill and decode profiles
+		pdResult := &types.SchedulingResult{
+			PrimaryProfileName: "decode",
+			ProfileResults: map[string]*types.ProfileRunResult{
+				"prefill": {
+					TargetPods: []types.Pod{prefillPodA},
+				},
+				"decode": {
+					TargetPods: []types.Pod{decodePodA},
+				},
+			},
+		}
+		scorer.PreRequest(ctx, req1, pdResult)
+
+		// Second cold request - both prefillPodB and decodePodB should score higher
+		// since prefillPodA and decodePodA were just used
+		req2 := &types.LLMRequest{RequestId: "pd-request-2"}
+		prefillScores := scorer.Score(ctx, coldPrefixState, req2, prefillPods)
+		decodeScores := scorer.Score(ctx, coldPrefixState, req2, decodePods)
+
+		if prefillScores[prefillPodB] <= prefillScores[prefillPodA] {
+			t.Errorf("Expected prefill-b to score higher than prefill-a after prefill-a was used: %+v", prefillScores)
+		}
+
+		if decodeScores[decodePodB] <= decodeScores[decodePodA] {
+			t.Errorf("Expected decode-b to score higher than decode-a after decode-a was used: %+v", decodeScores)
+		}
+	})
+
+	t.Run("non-P/D scenario - only primary profile exists", func(t *testing.T) {
+		req := &types.LLMRequest{RequestId: "non-pd-request"}
+		scorer := scorer.NewNoHitLRU(ctx, nil)
+		scorer.Score(ctx, coldPrefixState, req, decodePods)
+
+		// Scheduling result with only decode profile (no prefill)
+		result := &types.SchedulingResult{
+			PrimaryProfileName: "decode",
+			ProfileResults: map[string]*types.ProfileRunResult{
+				"decode": {
+					TargetPods: []types.Pod{decodePodA},
+				},
+				// No "prefill" profile in results
+			},
+		}
+		// Should not panic when prefill profile doesn't exist
+		scorer.PreRequest(ctx, req, result)
+
+		// Verify decodePodA was tracked
+		req2 := &types.LLMRequest{RequestId: "non-pd-request-2"}
+		scores := scorer.Score(ctx, coldPrefixState, req2, decodePods)
+
+		if scores[decodePodB] <= scores[decodePodA] {
+			t.Errorf("Expected decode-b to score higher than decode-a: %+v", scores)
+		}
+	})
+
+	t.Run("nil scheduling result - graceful handling", func(_ *testing.T) {
+		req := &types.LLMRequest{RequestId: "nil-result"}
+		scorer := scorer.NewNoHitLRU(ctx, nil)
+		scorer.Score(ctx, coldPrefixState, req, decodePods)
+
+		// Should not panic with nil result
+		scorer.PreRequest(ctx, req, nil)
+	})
+
+	t.Run("empty profile results - graceful handling", func(_ *testing.T) {
+		req := &types.LLMRequest{RequestId: "empty-results"}
+		scorer := scorer.NewNoHitLRU(ctx, nil)
+		scorer.Score(ctx, coldPrefixState, req, decodePods)
+
+		result := &types.SchedulingResult{
+			PrimaryProfileName: "decode",
+			ProfileResults:     map[string]*types.ProfileRunResult{},
+		}
+		// Should not panic with empty profile results
+		scorer.PreRequest(ctx, req, result)
+	})
+
+	t.Run("primary profile is the prefill profile", func(t *testing.T) {
+		req := &types.LLMRequest{RequestId: "prefill-primary-request"}
+		scorer := scorer.NewNoHitLRU(ctx, nil)
+		scorer.Score(ctx, coldPrefixState, req, prefillPods)
+
+		// The primary profile name equals the hardcoded prefill profile name
+		result := &types.SchedulingResult{
+			PrimaryProfileName: "prefill",
+			ProfileResults: map[string]*types.ProfileRunResult{
+				"prefill": {
+					TargetPods: []types.Pod{prefillPodA},
+				},
+			},
+		}
+		scorer.PreRequest(ctx, req, result)
+
+		req2 := &types.LLMRequest{RequestId: "prefill-primary-request-2"}
+		scores := scorer.Score(ctx, coldPrefixState, req2, prefillPods)
+
+		if scores[prefillPodB] <= scores[prefillPodA] {
+			t.Errorf("Expected prefill-b to score higher than prefill-a: %+v", scores)
+		}
+	})
+}