Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pkg/plugins/filter/by_label_selector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"

"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter"
"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)

func TestByLabelSelectorFactoryWithJSON(t *testing.T) {
Expand Down Expand Up @@ -298,7 +299,8 @@ func TestByLabelSelectorFiltering(t *testing.T) {
blf, ok := plugin.(*filter.ByLabelSelector)
require.True(t, ok, "plugin should be of type *ByLabelSelector")

ctx := context.Background()
ctx := utils.NewTestContext(t)

filteredPods := blf.Filter(ctx, nil, nil, pods)

var actualPodNames []string
Expand All @@ -322,7 +324,7 @@ func TestByLabelSelectorFilterEdgeCases(t *testing.T) {
blf, ok := plugin.(*filter.ByLabelSelector)
require.True(t, ok)

ctx := context.Background()
ctx := utils.NewTestContext(t)

t.Run("empty pods slice", func(t *testing.T) {
result := blf.Filter(ctx, nil, nil, []types.Pod{})
Expand Down
6 changes: 4 additions & 2 deletions pkg/plugins/filter/by_label_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package filter

import (
"context"
"encoding/json"
"fmt"
"testing"
Expand All @@ -12,6 +11,8 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"

"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)

func TestByLabelFactory(t *testing.T) {
Expand Down Expand Up @@ -244,7 +245,8 @@ func TestByLabelFiltering(t *testing.T) {
blf, ok := plugin.(*ByLabel)
require.True(t, ok, "plugin should be of type *ByLabel")

ctx := context.Background()
ctx := utils.NewTestContext(t)

filteredPods := blf.Filter(ctx, nil, nil, pods)

var actualPodNames []string
Expand Down
3 changes: 2 additions & 1 deletion pkg/plugins/profile/pd_profile_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"

"github.com/llm-d/llm-d-inference-scheduler/pkg/common"
"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)

func TestPdProfileHandlerFactory(t *testing.T) {
Expand Down Expand Up @@ -214,7 +215,7 @@ func newMockSchedulerProfile() *framework.SchedulerProfile {
}

func TestPdProfileHandler_Pick(t *testing.T) {
ctx := context.Background()
ctx := utils.NewTestContext(t)
request := &types.LLMRequest{
Body: &types.LLMRequestBody{
Completions: &types.CompletionsRequest{
Expand Down
28 changes: 18 additions & 10 deletions pkg/plugins/scorer/active_request_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package scorer

import (
"context"
"testing"
"time"

Expand All @@ -11,6 +10,8 @@ import (
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"

"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)

func TestActiveRequestScorer_Score(t *testing.T) {
Expand Down Expand Up @@ -87,10 +88,12 @@ func TestActiveRequestScorer_Score(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
scorer := NewActiveRequest(context.Background(), nil)
ctx := utils.NewTestContext(t)

scorer := NewActiveRequest(ctx, nil)
test.setupCache(scorer)

got := scorer.Score(context.Background(), nil, nil, test.input)
got := scorer.Score(ctx, nil, nil, test.input)

if diff := cmp.Diff(test.wantScores, got); diff != "" {
t.Errorf("Unexpected output (-want +got): %v", diff)
Expand All @@ -100,8 +103,7 @@ func TestActiveRequestScorer_Score(t *testing.T) {
}

func TestActiveRequestScorer_PreRequest(t *testing.T) {
ctx := context.Background()

ctx := utils.NewTestContext(t)
scorer := NewActiveRequest(ctx, nil)

podA := &types.PodMetrics{
Expand Down Expand Up @@ -169,7 +171,7 @@ func TestActiveRequestScorer_PreRequest(t *testing.T) {
}

func TestActiveRequestScorer_ResponseComplete(t *testing.T) {
ctx := context.Background()
ctx := utils.NewTestContext(t)

scorer := NewActiveRequest(ctx, nil)

Expand Down Expand Up @@ -225,7 +227,7 @@ func TestActiveRequestScorer_ResponseComplete(t *testing.T) {
}

func TestActiveRequestScorer_TTLExpiration(t *testing.T) {
ctx := context.Background()
ctx := utils.NewTestContext(t)

// Use very short timeout for test
params := &ActiveRequestParameters{RequestTimeout: "1s"}
Expand Down Expand Up @@ -274,8 +276,10 @@ func TestActiveRequestScorer_TTLExpiration(t *testing.T) {
}

func TestNewActiveRequestScorer_InvalidTimeout(t *testing.T) {
ctx := utils.NewTestContext(t)

params := &ActiveRequestParameters{RequestTimeout: "invalid"}
scorer := NewActiveRequest(context.Background(), params)
scorer := NewActiveRequest(ctx, params)

// Should use default timeout when invalid value is provided
if scorer == nil {
Expand All @@ -284,7 +288,9 @@ func TestNewActiveRequestScorer_InvalidTimeout(t *testing.T) {
}

func TestActiveRequestScorer_TypedName(t *testing.T) {
scorer := NewActiveRequest(context.Background(), nil)
ctx := utils.NewTestContext(t)

scorer := NewActiveRequest(ctx, nil)

typedName := scorer.TypedName()
if typedName.Type != ActiveRequestType {
Expand All @@ -293,7 +299,9 @@ func TestActiveRequestScorer_TypedName(t *testing.T) {
}

func TestActiveRequestScorer_WithName(t *testing.T) {
scorer := NewActiveRequest(context.Background(), nil)
ctx := utils.NewTestContext(t)

scorer := NewActiveRequest(ctx, nil)
testName := "test-scorer"

scorer = scorer.WithName(testName)
Expand Down
3 changes: 2 additions & 1 deletion pkg/plugins/scorer/load_aware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"

"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer"
"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)

func TestLoadBasedScorer(t *testing.T) {
Expand Down Expand Up @@ -44,7 +45,7 @@ func TestLoadBasedScorer(t *testing.T) {
}{
{
name: "load based scorer",
scorer: scorer.NewLoadAware(context.Background(), 10),
scorer: scorer.NewLoadAware(utils.NewTestContext(t), 10),
req: &types.LLMRequest{
TargetModel: "critical",
},
Expand Down
22 changes: 12 additions & 10 deletions pkg/plugins/scorer/no_hit_lru_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"

"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer"
"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)

var _ plugins.Handle = &fakeHandle{}
Expand Down Expand Up @@ -76,13 +77,13 @@ func TestNoHitLRUFactoryDependencyValidation(t *testing.T) {
}{
{
name: "missing prefix cache plugin - should work as optimization",
handle: newFakeHandle(context.Background()),
handle: newFakeHandle(utils.NewTestContext(t)),
expectError: false,
},
{
name: "prefix plugin present - should work",
handle: func() *fakeHandle {
h := newFakeHandle(context.Background())
h := newFakeHandle(utils.NewTestContext(t))
h.AddPlugin(prefix.PrefixCachePluginType, &stubPlugin{name: plugins.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}})
return h
}(),
Expand Down Expand Up @@ -146,7 +147,7 @@ func TestNoHitLRUScorer(t *testing.T) {
}{
{
name: "cold request - all pods never used",
scorer: scorer.NewNoHitLRU(context.Background(), nil),
scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil),
req: &types.LLMRequest{
TargetModel: "test-model",
},
Expand All @@ -163,7 +164,7 @@ func TestNoHitLRUScorer(t *testing.T) {
},
{
name: "cache hit - neutral scores",
scorer: scorer.NewNoHitLRU(context.Background(), nil),
scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil),
req: &types.LLMRequest{
TargetModel: "test-model",
},
Expand All @@ -182,7 +183,7 @@ func TestNoHitLRUScorer(t *testing.T) {
},
{
name: "single pod - max score",
scorer: scorer.NewNoHitLRU(context.Background(), nil),
scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil),
req: &types.LLMRequest{
TargetModel: "test-model",
},
Expand All @@ -205,7 +206,7 @@ func TestNoHitLRUScorer(t *testing.T) {
cycleState.Write(plugins.StateKey(prefix.PrefixCachePluginType), test.prefixState)
}

got := test.scorer.Score(context.Background(), cycleState, test.req, test.input)
got := test.scorer.Score(utils.NewTestContext(t), cycleState, test.req, test.input)

if diff := cmp.Diff(test.wantScores, got); diff != "" {
t.Errorf("%s: Unexpected output (-want +got): %v", test.description, diff)
Expand All @@ -215,7 +216,8 @@ func TestNoHitLRUScorer(t *testing.T) {
}

func TestNoHitLRUBasicFunctionality(t *testing.T) {
ctx := context.Background()
ctx := utils.NewTestContext(t)

scorer := scorer.NewNoHitLRU(ctx, nil)

podA := &types.PodMetrics{
Expand Down Expand Up @@ -257,7 +259,7 @@ func TestNoHitLRUBasicFunctionality(t *testing.T) {
}

func TestNoPrefixCacheStateFound(t *testing.T) {
ctx := context.Background()
ctx := utils.NewTestContext(t)
scorer := scorer.NewNoHitLRU(ctx, nil)

podA := &types.PodMetrics{
Expand All @@ -275,7 +277,7 @@ func TestNoPrefixCacheStateFound(t *testing.T) {
}

func TestNoHitLRUPreferLeastRecentlyUsedAfterColdRequests(t *testing.T) {
ctx := context.Background()
ctx := utils.NewTestContext(t)
scorer := scorer.NewNoHitLRU(ctx, nil)

podA := &types.PodMetrics{
Expand Down Expand Up @@ -395,7 +397,7 @@ func TestNoHitLRUPreferLeastRecentlyUsedAfterColdRequests(t *testing.T) {
}

func TestNoHitLRUEdgeCases(t *testing.T) {
ctx := context.Background()
ctx := utils.NewTestContext(t)
scorer := scorer.NewNoHitLRU(ctx, nil)

podA := &types.PodMetrics{
Expand Down
4 changes: 3 additions & 1 deletion pkg/plugins/scorer/precise_prefix_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"

"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)

func TestPrefixCacheTracking_Score(t *testing.T) {
Expand Down Expand Up @@ -556,7 +558,7 @@ func TestPrefixCacheTracking_Score(t *testing.T) {

for _, tt := range testcases {
t.Run(tt.name, func(t *testing.T) {
ctx := t.Context()
ctx := utils.NewTestContext(t)

kvcacheConfig, err := kvcache.NewDefaultConfig()
kvcacheConfig.TokenizersPoolConfig = &tokenization.Config{
Expand Down
3 changes: 2 additions & 1 deletion pkg/plugins/scorer/session_affinity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"

"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer"
"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)

func TestSessionAffinity_Score(t *testing.T) {
Expand Down Expand Up @@ -131,7 +132,7 @@ func TestSessionAffinity_ResponseComplete(t *testing.T) {
}

s := scorer.NewSessionAffinity()
ctx := context.Background()
ctx := utils.NewTestContext(t)

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions test/config/prefix_cache_mode_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package config_test

import (
"context"
"fmt"
"os"
"testing"
Expand All @@ -13,6 +12,7 @@ import (

"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer"
testutils "github.com/llm-d/llm-d-inference-scheduler/test/utils"
)

func TestPrecisePrefixCacheScorer(t *testing.T) {
Expand Down Expand Up @@ -42,7 +42,7 @@ schedulingProfiles:
`,
},
}
ctx := context.Background()
ctx := testutils.NewTestContext(t)
// Register llm-d-inference-scheduler plugins
plugins.RegisterAllPlugins()

Expand Down
23 changes: 23 additions & 0 deletions test/utils/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Package utils contains utilities for testing
//
//revive:disable:var-naming
package utils

//revive:enable:var-naming

import (
"context"
"testing"

"github.com/go-logr/logr/testr"
"sigs.k8s.io/controller-runtime/pkg/log"
)

// NewTestContext creates a new context with a logger associated with the testing.T.
// It simplifies the boilerplate of integrating klog/logr with unit tests.
func NewTestContext(t *testing.T) context.Context {
t.Helper()

logger := testr.New(t)
return log.IntoContext(context.Background(), logger)
}