diff --git a/Makefile b/Makefile index edbd8ec569..19e722ee22 100644 --- a/Makefile +++ b/Makefile @@ -161,6 +161,25 @@ test-unit-%: download-tokenizer install-python-deps check-dependencies ## Run un PYTHONPATH="$$KV_CACHE_PKG:$(VENV_DIR)/lib/python$(PYTHON_VERSION)/site-packages" \ CGO_CFLAGS=${$*_CGO_CFLAGS} CGO_LDFLAGS=${$*_CGO_LDFLAGS} go test $($*_LDFLAGS) -v $$($($*_TEST_FILES) | tr '\n' ' ') +.PHONY: test-filter +test-filter: download-tokenizer install-python-deps check-dependencies ## Run filtered unit tests (usage: make test-filter PATTERN=TestName TYPE=epp) + @if [ -z "$(PATTERN)" ]; then \ + echo "ERROR: PATTERN is required. Usage: make test-filter PATTERN=TestName [TYPE=epp|sidecar]"; \ + exit 1; \ + fi + @TEST_TYPE="$(if $(TYPE),$(TYPE),epp)"; \ + printf "\033[33;1m==== Running Filtered Tests (pattern: $(PATTERN), type: $$TEST_TYPE) ====\033[0m\n"; \ + KV_CACHE_PKG=$$(go list -m -f '{{.Dir}}/pkg/preprocessing/chat_completions' github.com/llm-d/llm-d-kv-cache-manager 2>/dev/null || echo ""); \ + if [ "$$TEST_TYPE" = "epp" ]; then \ + PYTHONPATH="$$KV_CACHE_PKG:$(VENV_DIR)/lib/python$(PYTHON_VERSION)/site-packages" \ + CGO_CFLAGS=$(epp_CGO_CFLAGS) CGO_LDFLAGS=$(epp_CGO_LDFLAGS) \ + go test $(epp_LDFLAGS) -v -run "$(PATTERN)" $$($(epp_TEST_FILES) | tr '\n' ' '); \ + else \ + PYTHONPATH="$$KV_CACHE_PKG:$(VENV_DIR)/lib/python$(PYTHON_VERSION)/site-packages" \ + CGO_CFLAGS=$(sidecar_CGO_CFLAGS) CGO_LDFLAGS=$(sidecar_CGO_LDFLAGS) \ + go test $(sidecar_LDFLAGS) -v -run "$(PATTERN)" $$($(sidecar_TEST_FILES) | tr '\n' ' '); \ + fi + .PHONY: test-integration test-integration: download-tokenizer check-dependencies ## Run integration tests @printf "\033[33;1m==== Running Integration Tests ====\033[0m\n" diff --git a/pkg/plugins/scorer/active_request.go b/pkg/plugins/scorer/active_request.go index 14e2e41692..84947c9921 100644 --- a/pkg/plugins/scorer/active_request.go +++ b/pkg/plugins/scorer/active_request.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "sync" "time" @@ -38,13 +39,13 @@ type ActiveRequestParameters struct { // requestEntry represents a single request in the cache type requestEntry struct { - PodName string + PodNames []string RequestID string } // String returns a string representation of the request entry. -func (r *requestEntry) String() string { - return fmt.Sprintf("%s.%s", r.PodName, r.RequestID) +func (r requestEntry) String() string { + return fmt.Sprintf("%s:%s", r.RequestID, strings.Join(r.PodNames, ".")) } // compile-time type assertion @@ -97,7 +98,9 @@ func NewActiveRequest(ctx context.Context, params *ActiveRequestParameters) *Act requestCache.OnEviction(func(_ context.Context, reason ttlcache.EvictionReason, item *ttlcache.Item[string, *requestEntry]) { if reason == ttlcache.EvictionReasonExpired { - scorer.decrementPodCount(item.Value().PodName) + for _, podName := range item.Value().PodNames { + scorer.decrementPodCount(podName) + } } }) @@ -166,47 +169,61 @@ func (s *ActiveRequest) Score(ctx context.Context, _ *types.CycleState, _ *types // PreRequest is called before a request is sent to the target pod. // It creates a new request entry in the cache with its own TTL and // increments the pod count for fast lookup. -func (s *ActiveRequest) PreRequest(ctx context.Context, request *types.LLMRequest, - schedulingResult *types.SchedulingResult) { +func (s *ActiveRequest) PreRequest( + ctx context.Context, + request *types.LLMRequest, + schedulingResult *types.SchedulingResult, +) { debugLogger := log.FromContext(ctx).V(logutil.DEBUG) - for _, profileResult := range schedulingResult.ProfileResults { // schedulingResult guaranteed not to be nil - if profileResult == nil || profileResult.TargetPods == nil || len(profileResult.TargetPods) == 0 { + podNames := make([]string, 0, len(schedulingResult.ProfileResults)) + for profileName, profileResult := range schedulingResult.ProfileResults { + if profileResult == nil || len(profileResult.TargetPods) == 0 { continue } - // create request entry for first pod only. TODO: support fallback pods - entry := &requestEntry{ - PodName: profileResult.TargetPods[0].GetPod().NamespacedName.String(), - RequestID: request.RequestId, - } - - // add to request cache with TTL - s.requestCache.Set(entry.String(), entry, 0) // Use default TTL - s.incrementPodCount(entry.PodName) - - debugLogger.Info("Added request to cache", "requestEntry", entry.String()) + podName := profileResult.TargetPods[0].GetPod().NamespacedName.String() + podNames = append(podNames, podName) + s.incrementPodCount(podName) + debugLogger.Info( + "Added request to cache", + "requestId", request.RequestId, + "podName", podName, + "profileName", profileName, + ) } + + // add to request cache + s.requestCache.Set(request.RequestId, &requestEntry{PodNames: podNames, RequestID: request.RequestId}, 0) // Use default TTL } // ResponseComplete is called after a response is sent to the client. // It removes the specific request entry from the cache and decrements // the pod count. -func (s *ActiveRequest) ResponseComplete(ctx context.Context, request *types.LLMRequest, - _ *requestcontrol.Response, targetPod *backend.Pod) { +func (s *ActiveRequest) ResponseComplete( + ctx context.Context, + request *types.LLMRequest, + _ *requestcontrol.Response, + targetPod *backend.Pod, +) { debugLogger := log.FromContext(ctx).V(logutil.DEBUG).WithName("ActiveRequest.ResponseComplete") if targetPod == nil { debugLogger.Info("Skipping ResponseComplete because targetPod is nil") return } - entry := requestEntry{targetPod.NamespacedName.String(), request.RequestId} - - if _, found := s.requestCache.GetAndDelete(entry.String()); found { - s.decrementPodCount(entry.PodName) - debugLogger.Info("Removed request from cache", "requestEntry", entry.String()) + if item, found := s.requestCache.GetAndDelete(request.RequestId); found { + entry := item.Value() + if entry != nil { + for _, podName := range entry.PodNames { + s.decrementPodCount(podName) + } + debugLogger.Info("Removed request from cache", "requestEntry", entry.String()) + } else { + debugLogger.Info("Request entry value is nil", "requestId", request.RequestId) + } } else { - debugLogger.Info("Request not found in cache", "requestEntry", entry.String()) + debugLogger.Info("Request not found in cache", "requestId", request.RequestId) } } diff --git a/pkg/plugins/scorer/active_request_test.go b/pkg/plugins/scorer/active_request_test.go index e7215ce1aa..1009aaa590 100644 --- a/pkg/plugins/scorer/active_request_test.go +++ b/pkg/plugins/scorer/active_request_test.go @@ -4,7 +4,8 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" @@ -14,25 +15,52 @@ import ( "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) -func TestActiveRequestScorer_Score(t *testing.T) { - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, +// Test helper functions + +func newTestPod(name string, queueSize int) *types.PodMetrics { + return &types.PodMetrics{ + Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: name, Namespace: "default"}}, MetricsState: &backendmetrics.MetricsState{ - WaitingQueueSize: 2, + WaitingQueueSize: queueSize, }, } - podB := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{ - WaitingQueueSize: 0, - }, +} + +func newTestRequest(id string) *types.LLMRequest { + return &types.LLMRequest{ + RequestId: id, } - podC := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-c", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{ - WaitingQueueSize: 15, - }, +} + +func newTestSchedulingResult(profilePods map[string]types.Pod) *types.SchedulingResult { + profileResults := make(map[string]*types.ProfileRunResult) + for profile, pod := range profilePods { + profileResults[profile] = &types.ProfileRunResult{ + TargetPods: []types.Pod{pod}, + } + } + return &types.SchedulingResult{ + ProfileResults: profileResults, } +} + +func (s *ActiveRequest) getPodCount(podName string) int { + s.mutex.RLock() + defer s.mutex.RUnlock() + return s.podCounts[podName] +} + +func (s *ActiveRequest) hasPodCount(podName string) bool { + s.mutex.RLock() + defer s.mutex.RUnlock() + _, exists := s.podCounts[podName] + return exists +} + +func TestActiveRequestScorer_Score(t *testing.T) { + podA := newTestPod("pod-a", 2) + podB := newTestPod("pod-b", 0) + podC := newTestPod("pod-c", 15) tests := []struct { name string @@ -95,9 +123,7 @@ func TestActiveRequestScorer_Score(t *testing.T) { got := scorer.Score(ctx, nil, nil, test.input) - if diff := cmp.Diff(test.wantScores, got); diff != "" { - t.Errorf("Unexpected output (-want +got): %v", diff) - } + assert.Equal(t, test.wantScores, got) }) } } @@ -106,124 +132,57 @@ func TestActiveRequestScorer_PreRequest(t *testing.T) { ctx := utils.NewTestContext(t) scorer := NewActiveRequest(ctx, nil) - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{ - WaitingQueueSize: 2, - }, - } + podA := newTestPod("pod-a", 2) + podB := newTestPod("pod-b", 0) - request := &types.LLMRequest{ - RequestId: "test-request-1", - } - - schedulingResult := &types.SchedulingResult{ - ProfileResults: map[string]*types.ProfileRunResult{ - "test-profile": { - TargetPods: []types.Pod{podA}, - }, - }, - } + testProfile := "test-profile" - // First request - scorer.PreRequest(ctx, request, schedulingResult) + t.Run("First request", func(t *testing.T) { + request := newTestRequest("test-request-1") + schedulingResult := newTestSchedulingResult(map[string]types.Pod{ + testProfile: podA, + }) - // Check cache and pod counts - compositeKey := "default/pod-a.test-request-1" - if !scorer.requestCache.Has(compositeKey) { - t.Errorf("Expected request to be in cache with key %s", compositeKey) - } + scorer.PreRequest(ctx, request, schedulingResult) - scorer.mutex.RLock() - count := scorer.podCounts["default/pod-a"] - scorer.mutex.RUnlock() - if count != 1 { - t.Errorf("Expected pod-a count to be 1, got %d", count) - } + assert.True(t, scorer.requestCache.Has(request.RequestId), "Expected request to be in cache") + assert.Equal(t, 1, scorer.getPodCount(podA.GetPod().NamespacedName.String())) + }) - // Second request with different ID to same pod - request2 := &types.LLMRequest{ - RequestId: "test-request-2", - } - schedulingResult2 := &types.SchedulingResult{ - ProfileResults: map[string]*types.ProfileRunResult{ - "test-profile": { - TargetPods: []types.Pod{podA}, - }, - }, - } - - scorer.PreRequest(ctx, request2, schedulingResult2) + t.Run("Second request to multiple pods", func(t *testing.T) { + request := newTestRequest("test-request-2") + schedulingResult := newTestSchedulingResult(map[string]types.Pod{ + testProfile: podA, + "prefill": podB, + }) - // Check incremented count - scorer.mutex.RLock() - count = scorer.podCounts["default/pod-a"] - scorer.mutex.RUnlock() - if count != 2 { - t.Errorf("Expected pod-a count to be 2, got %d", count) - } + scorer.PreRequest(ctx, request, schedulingResult) - // Check both requests are in cache - compositeKey2 := "default/pod-a.test-request-2" - if !scorer.requestCache.Has(compositeKey2) { - t.Errorf("Expected second request to be in cache with key %s", compositeKey2) - } + assert.True(t, scorer.requestCache.Has(request.RequestId), "Expected request to be in cache") + assert.Equal(t, 2, scorer.getPodCount(podA.GetPod().NamespacedName.String())) + assert.Equal(t, 1, scorer.getPodCount(podB.GetPod().NamespacedName.String())) + }) } func TestActiveRequestScorer_ResponseComplete(t *testing.T) { ctx := utils.NewTestContext(t) - scorer := NewActiveRequest(ctx, nil) - request := &types.LLMRequest{ - RequestId: "test-request-1", - } + podA := newTestPod("pod-a", 2) + request := newTestRequest("test-request-1") - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{ - WaitingQueueSize: 2, - }, - } // Setup initial state: add request through PreRequest - schedulingResult := &types.SchedulingResult{ - ProfileResults: map[string]*types.ProfileRunResult{ - "test-profile": { - TargetPods: []types.Pod{podA}, - }, - }, - } - + schedulingResult := newTestSchedulingResult(map[string]types.Pod{ + "test-profile": podA, + }) scorer.PreRequest(ctx, request, schedulingResult) - // Verify initial state - compositeKey := "default/pod-a.test-request-1" - if !scorer.requestCache.Has(compositeKey) { - t.Fatal("Request should be in cache before ResponseComplete") - } - - scorer.mutex.RLock() - initialCount := scorer.podCounts["default/pod-a"] - scorer.mutex.RUnlock() - if initialCount != 1 { - t.Fatalf("Expected initial count to be 1, got %d", initialCount) - } - - // Call PostResponse + // Call ResponseComplete scorer.ResponseComplete(ctx, request, &requestcontrol.Response{}, podA.GetPod()) - // Check request is removed from cache - if scorer.requestCache.Has(compositeKey) { - t.Errorf("Request should be removed from cache after ResponseComplete") - } - - // Check pod count is decremented and removed (since it was 1) - scorer.mutex.RLock() - _, exists := scorer.podCounts["default/pod-a"] - scorer.mutex.RUnlock() - if exists { - t.Errorf("Pod should be removed from podCounts when count reaches 0") - } + assert.False(t, scorer.requestCache.Has(request.RequestId)) + assert.False(t, scorer.hasPodCount(podA.GetPod().NamespacedName.String()), + "Pod count should be removed after decrement to zero") } func TestActiveRequestScorer_TTLExpiration(t *testing.T) { @@ -231,34 +190,19 @@ func TestActiveRequestScorer_TTLExpiration(t *testing.T) { // Use very short timeout for test params := &ActiveRequestParameters{RequestTimeout: "1s"} - scorer := NewActiveRequest(ctx, params) // 1 second timeout - - request := &types.LLMRequest{ - RequestId: "test-request-ttl", - } - - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, - } + scorer := NewActiveRequest(ctx, params) - schedulingResult := &types.SchedulingResult{ - ProfileResults: map[string]*types.ProfileRunResult{ - "test-profile": { - TargetPods: []types.Pod{podA}, - }, - }, - } + podA := newTestPod("pod-a", 0) + request := newTestRequest("test-request-ttl") + schedulingResult := newTestSchedulingResult(map[string]types.Pod{ + "test-profile": podA, + }) // Add request scorer.PreRequest(ctx, request, schedulingResult) // Verify request is added - scorer.mutex.RLock() - initialCount := scorer.podCounts["default/pod-a"] - scorer.mutex.RUnlock() - if initialCount != 1 { - t.Fatalf("Expected initial count to be 1, got %d", initialCount) - } + require.Equal(t, 1, scorer.getPodCount("default/pod-a"), "Expected initial count to be 1") // Wait for TTL expiration time.Sleep(2 * time.Second) @@ -267,12 +211,8 @@ func TestActiveRequestScorer_TTLExpiration(t *testing.T) { scorer.requestCache.DeleteExpired() // Check that pod count is decremented due to TTL expiration - scorer.mutex.RLock() - _, exists := scorer.podCounts["default/pod-a"] - scorer.mutex.RUnlock() - if exists { - t.Errorf("Pod should be removed from podCounts after TTL expiration") - } + assert.False(t, scorer.hasPodCount("default/pod-a"), + "Pod should be removed from podCounts after TTL expiration") } func TestNewActiveRequestScorer_InvalidTimeout(t *testing.T) { @@ -282,9 +222,7 @@ func TestNewActiveRequestScorer_InvalidTimeout(t *testing.T) { scorer := NewActiveRequest(ctx, params) // Should use default timeout when invalid value is provided - if scorer == nil { - t.Error("Expected scorer to be created even with invalid timeout") - } + assert.NotNil(t, scorer, "Expected scorer to be created even with invalid timeout") } func TestActiveRequestScorer_TypedName(t *testing.T) { @@ -292,10 +230,7 @@ func TestActiveRequestScorer_TypedName(t *testing.T) { scorer := NewActiveRequest(ctx, nil) - typedName := scorer.TypedName() - if typedName.Type != ActiveRequestType { - t.Errorf("Expected type %s, got %s", ActiveRequestType, typedName.Type) - } + assert.Equal(t, ActiveRequestType, scorer.TypedName().Type) } func TestActiveRequestScorer_WithName(t *testing.T) { @@ -306,7 +241,5 @@ func TestActiveRequestScorer_WithName(t *testing.T) { scorer = scorer.WithName(testName) - if scorer.TypedName().Name != testName { - t.Errorf("Expected name %s, got %s", testName, scorer.TypedName().Name) - } + assert.Equal(t, testName, scorer.TypedName().Name) }