diff --git a/pkg/plugins/filter/by_label_selector_test.go b/pkg/plugins/filter/by_label_selector_test.go index 8ba43fafff..3dc16e9d06 100644 --- a/pkg/plugins/filter/by_label_selector_test.go +++ b/pkg/plugins/filter/by_label_selector_test.go @@ -14,6 +14,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter" + "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) func TestByLabelSelectorFactoryWithJSON(t *testing.T) { @@ -298,7 +299,8 @@ func TestByLabelSelectorFiltering(t *testing.T) { blf, ok := plugin.(*filter.ByLabelSelector) require.True(t, ok, "plugin should be of type *ByLabelSelector") - ctx := context.Background() + ctx := utils.NewTestContext(t) + filteredPods := blf.Filter(ctx, nil, nil, pods) var actualPodNames []string @@ -322,7 +324,7 @@ func TestByLabelSelectorFilterEdgeCases(t *testing.T) { blf, ok := plugin.(*filter.ByLabelSelector) require.True(t, ok) - ctx := context.Background() + ctx := utils.NewTestContext(t) t.Run("empty pods slice", func(t *testing.T) { result := blf.Filter(ctx, nil, nil, []types.Pod{}) diff --git a/pkg/plugins/filter/by_label_test.go b/pkg/plugins/filter/by_label_test.go index fe0ba9ecfa..a3af75a4b2 100644 --- a/pkg/plugins/filter/by_label_test.go +++ b/pkg/plugins/filter/by_label_test.go @@ -1,7 +1,6 @@ package filter import ( - "context" "encoding/json" "fmt" "testing" @@ -12,6 +11,8 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + + "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) func TestByLabelFactory(t *testing.T) { @@ -244,7 +245,8 @@ func TestByLabelFiltering(t *testing.T) { blf, ok := plugin.(*ByLabel) require.True(t, ok, "plugin should be of type *ByLabel") - ctx := context.Background() + ctx := utils.NewTestContext(t) + filteredPods := blf.Filter(ctx, nil, nil, pods) var actualPodNames []string diff --git a/pkg/plugins/profile/pd_profile_handler_test.go b/pkg/plugins/profile/pd_profile_handler_test.go index fc78b62121..9d3d3b3285 100644 --- a/pkg/plugins/profile/pd_profile_handler_test.go +++ b/pkg/plugins/profile/pd_profile_handler_test.go @@ -16,6 +16,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" + "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) func TestPdProfileHandlerFactory(t *testing.T) { @@ -214,7 +215,7 @@ func newMockSchedulerProfile() *framework.SchedulerProfile { } func TestPdProfileHandler_Pick(t *testing.T) { - ctx := context.Background() + ctx := utils.NewTestContext(t) request := &types.LLMRequest{ Body: &types.LLMRequestBody{ Completions: &types.CompletionsRequest{ diff --git a/pkg/plugins/scorer/active_request_test.go b/pkg/plugins/scorer/active_request_test.go index 065d3f4fe8..e7215ce1aa 100644 --- a/pkg/plugins/scorer/active_request_test.go +++ b/pkg/plugins/scorer/active_request_test.go @@ -1,7 +1,6 @@ package scorer import ( - "context" "testing" "time" @@ -11,6 +10,8 @@ import ( backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + + "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) func TestActiveRequestScorer_Score(t *testing.T) { @@ -87,10 +88,12 @@ func TestActiveRequestScorer_Score(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - scorer := NewActiveRequest(context.Background(), nil) + ctx := utils.NewTestContext(t) + + scorer := NewActiveRequest(ctx, nil) test.setupCache(scorer) - got := scorer.Score(context.Background(), nil, nil, test.input) + got := scorer.Score(ctx, nil, nil, test.input) if diff := cmp.Diff(test.wantScores, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) @@ -100,8 +103,7 @@ func TestActiveRequestScorer_Score(t *testing.T) { } func TestActiveRequestScorer_PreRequest(t *testing.T) { - ctx := context.Background() - + ctx := utils.NewTestContext(t) scorer := NewActiveRequest(ctx, nil) podA := &types.PodMetrics{ @@ -169,7 +171,7 @@ func TestActiveRequestScorer_PreRequest(t *testing.T) { } func TestActiveRequestScorer_ResponseComplete(t *testing.T) { - ctx := context.Background() + ctx := utils.NewTestContext(t) scorer := NewActiveRequest(ctx, nil) @@ -225,7 +227,7 @@ func TestActiveRequestScorer_ResponseComplete(t *testing.T) { } func TestActiveRequestScorer_TTLExpiration(t *testing.T) { - ctx := context.Background() + ctx := utils.NewTestContext(t) // Use very short timeout for test params := &ActiveRequestParameters{RequestTimeout: "1s"} @@ -274,8 +276,10 @@ func TestActiveRequestScorer_TTLExpiration(t *testing.T) { } func TestNewActiveRequestScorer_InvalidTimeout(t *testing.T) { + ctx := utils.NewTestContext(t) + params := &ActiveRequestParameters{RequestTimeout: "invalid"} - scorer := NewActiveRequest(context.Background(), params) + scorer := NewActiveRequest(ctx, params) // Should use default timeout when invalid value is provided if scorer == nil { @@ -284,7 +288,9 @@ func TestNewActiveRequestScorer_InvalidTimeout(t *testing.T) { } func TestActiveRequestScorer_TypedName(t *testing.T) { - scorer := NewActiveRequest(context.Background(), nil) + ctx := utils.NewTestContext(t) + + scorer := NewActiveRequest(ctx, nil) typedName := scorer.TypedName() if typedName.Type != ActiveRequestType { @@ -293,7 +299,9 @@ func TestActiveRequestScorer_TypedName(t *testing.T) { } func TestActiveRequestScorer_WithName(t *testing.T) { - scorer := NewActiveRequest(context.Background(), nil) + ctx := utils.NewTestContext(t) + + scorer := NewActiveRequest(ctx, nil) testName := "test-scorer" scorer = scorer.WithName(testName) diff --git a/pkg/plugins/scorer/load_aware_test.go b/pkg/plugins/scorer/load_aware_test.go index e454d22f7e..e693e99b57 100644 --- a/pkg/plugins/scorer/load_aware_test.go +++ b/pkg/plugins/scorer/load_aware_test.go @@ -13,6 +13,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" + "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) func TestLoadBasedScorer(t *testing.T) { @@ -44,7 +45,7 @@ func TestLoadBasedScorer(t *testing.T) { }{ { name: "load based scorer", - scorer: scorer.NewLoadAware(context.Background(), 10), + scorer: scorer.NewLoadAware(utils.NewTestContext(t), 10), req: &types.LLMRequest{ TargetModel: "critical", }, diff --git a/pkg/plugins/scorer/no_hit_lru_test.go b/pkg/plugins/scorer/no_hit_lru_test.go index 74eba957f0..a6af9ce20b 100644 --- a/pkg/plugins/scorer/no_hit_lru_test.go +++ b/pkg/plugins/scorer/no_hit_lru_test.go @@ -17,6 +17,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" + "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) var _ plugins.Handle = &fakeHandle{} @@ -76,13 +77,13 @@ func TestNoHitLRUFactoryDependencyValidation(t *testing.T) { }{ { name: "missing prefix cache plugin - should work as optimization", - handle: newFakeHandle(context.Background()), + handle: newFakeHandle(utils.NewTestContext(t)), expectError: false, }, { name: "prefix plugin present - should work", handle: func() *fakeHandle { - h := newFakeHandle(context.Background()) + h := newFakeHandle(utils.NewTestContext(t)) h.AddPlugin(prefix.PrefixCachePluginType, &stubPlugin{name: plugins.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}}) return h }(), @@ -146,7 +147,7 @@ func TestNoHitLRUScorer(t *testing.T) { }{ { name: "cold request - all pods never used", - scorer: scorer.NewNoHitLRU(context.Background(), nil), + scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil), req: &types.LLMRequest{ TargetModel: "test-model", }, @@ -163,7 +164,7 @@ func TestNoHitLRUScorer(t *testing.T) { }, { name: "cache hit - neutral scores", - scorer: scorer.NewNoHitLRU(context.Background(), nil), + scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil), req: &types.LLMRequest{ TargetModel: "test-model", }, @@ -182,7 +183,7 @@ func TestNoHitLRUScorer(t *testing.T) { }, { name: "single pod - max score", - scorer: scorer.NewNoHitLRU(context.Background(), nil), + scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil), req: &types.LLMRequest{ TargetModel: "test-model", }, @@ -205,7 +206,7 @@ func TestNoHitLRUScorer(t *testing.T) { cycleState.Write(plugins.StateKey(prefix.PrefixCachePluginType), test.prefixState) } - got := test.scorer.Score(context.Background(), cycleState, test.req, test.input) + got := test.scorer.Score(utils.NewTestContext(t), cycleState, test.req, test.input) if diff := cmp.Diff(test.wantScores, got); diff != "" { t.Errorf("%s: Unexpected output (-want +got): %v", test.description, diff) @@ -215,7 +216,8 @@ func TestNoHitLRUScorer(t *testing.T) { } func TestNoHitLRUBasicFunctionality(t *testing.T) { - ctx := context.Background() + ctx := utils.NewTestContext(t) + scorer := scorer.NewNoHitLRU(ctx, nil) podA := &types.PodMetrics{ @@ -257,7 +259,7 @@ func TestNoHitLRUBasicFunctionality(t *testing.T) { } func TestNoPrefixCacheStateFound(t *testing.T) { - ctx := context.Background() + ctx := utils.NewTestContext(t) scorer := scorer.NewNoHitLRU(ctx, nil) podA := &types.PodMetrics{ @@ -275,7 +277,7 @@ func TestNoPrefixCacheStateFound(t *testing.T) { } func TestNoHitLRUPreferLeastRecentlyUsedAfterColdRequests(t *testing.T) { - ctx := context.Background() + ctx := utils.NewTestContext(t) scorer := scorer.NewNoHitLRU(ctx, nil) podA := &types.PodMetrics{ @@ -395,7 +397,7 @@ func TestNoHitLRUPreferLeastRecentlyUsedAfterColdRequests(t *testing.T) { } func TestNoHitLRUEdgeCases(t *testing.T) { - ctx := context.Background() + ctx := utils.NewTestContext(t) scorer := scorer.NewNoHitLRU(ctx, nil) podA := &types.PodMetrics{ diff --git a/pkg/plugins/scorer/precise_prefix_cache_test.go b/pkg/plugins/scorer/precise_prefix_cache_test.go index e033228d90..66dfc87fe0 100644 --- a/pkg/plugins/scorer/precise_prefix_cache_test.go +++ b/pkg/plugins/scorer/precise_prefix_cache_test.go @@ -16,6 +16,8 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + + "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) func TestPrefixCacheTracking_Score(t *testing.T) { @@ -556,7 +558,7 @@ func TestPrefixCacheTracking_Score(t *testing.T) { for _, tt := range testcases { t.Run(tt.name, func(t *testing.T) { - ctx := t.Context() + ctx := utils.NewTestContext(t) kvcacheConfig, err := kvcache.NewDefaultConfig() kvcacheConfig.TokenizersPoolConfig = &tokenization.Config{ diff --git a/pkg/plugins/scorer/session_affinity_test.go b/pkg/plugins/scorer/session_affinity_test.go index 481b209a39..943b06eb4b 100644 --- a/pkg/plugins/scorer/session_affinity_test.go +++ b/pkg/plugins/scorer/session_affinity_test.go @@ -14,6 +14,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" + "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) func TestSessionAffinity_Score(t *testing.T) { @@ -131,7 +132,7 @@ func TestSessionAffinity_ResponseComplete(t *testing.T) { } s := scorer.NewSessionAffinity() - ctx := context.Background() + ctx := utils.NewTestContext(t) for _, test := range tests { t.Run(test.name, func(t *testing.T) { diff --git a/test/config/prefix_cache_mode_test.go b/test/config/prefix_cache_mode_test.go index b0fad97fd0..96256eccba 100644 --- a/test/config/prefix_cache_mode_test.go +++ b/test/config/prefix_cache_mode_test.go @@ -1,7 +1,6 @@ package config_test import ( - "context" "fmt" "os" "testing" @@ -13,6 +12,7 @@ import ( "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" + testutils "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) func TestPrecisePrefixCacheScorer(t *testing.T) { @@ -42,7 +42,7 @@ schedulingProfiles: `, }, } - ctx := context.Background() + ctx := testutils.NewTestContext(t) // Register llm-d-inference-scheduler plugins plugins.RegisterAllPlugins() diff --git a/test/utils/context.go b/test/utils/context.go new file mode 100644 index 0000000000..c821db9e0c --- /dev/null +++ b/test/utils/context.go @@ -0,0 +1,23 @@ +// Package utils contains utilities for testing +// +//revive:disable:var-naming +package utils + +//revive:enable:var-naming + +import ( + "context" + "testing" + + "github.com/go-logr/logr/testr" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +// NewTestContext creates a new context with a logger associated with the testing.T. +// It simplifies the boilerplate of integrating klog/logr with unit tests. +func NewTestContext(t *testing.T) context.Context { + t.Helper() + + logger := testr.New(t) + return log.IntoContext(context.Background(), logger) +}